1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/export_model.h"
18 #include <fstream>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include "backend/optimizer/common/optimizer.h"
24 #include "include/errorcode.h"
25 #include "include/version.h"
26 #include "ir/func_graph.h"
27 #include "tools/anf_exporter/anf_exporter.h"
28 #include "tools/converter/graphdef_transform.h"
29 #include "tools/converter/optimizer_manager.h"
30 #include "tools/optimizer/graph/control_flow_pass.h"
31 #include "nnacl/op_base.h"
32 #include "src/common/log_util.h"
33
34 namespace mindspore {
35 namespace lite {
36 namespace {
37 using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
CloneGraphInputs(const FuncGraphPtr & origin,const FuncGraphPtr & mirror,NodesMap * origin_map,NodesMap * mirror_map)38 void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
39 NodesMap *mirror_map) {
40 MS_ASSERT(origin != nullptr && mirror != nullptr);
41 MS_ASSERT(origin_map != nullptr && mirror_map != nullptr);
42 auto origin_inputs = origin->get_inputs();
43 for (auto &input : origin_inputs) {
44 auto mirror_input = mirror->add_parameter();
45 MS_CHECK_TRUE_RET_VOID(mirror_input != nullptr);
46 if (input->abstract() != nullptr) {
47 mirror_input->set_abstract(input->abstract()->Clone());
48 }
49 mirror_input->set_name(input->fullname_with_scope());
50 (*origin_map)[input->fullname_with_scope()].push_back(input);
51 (*mirror_map)[input->fullname_with_scope()].push_back(mirror_input);
52 }
53 }
54
CloneParameterAndValueNode(const CNodePtr & cnode,size_t index,const FuncGraphPtr & mirror_graph,const converter::Flags * flags)55 AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
56 const converter::Flags *flags) {
57 MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
58 if (index >= cnode->size()) {
59 MS_LOG(ERROR) << "input index out of range.";
60 return nullptr;
61 }
62 auto node = cnode->input(index);
63 if (utils::isa<mindspore::CNode>(node)) {
64 MS_LOG(ERROR) << "this func cannot copy cnode.";
65 return nullptr;
66 }
67 if (utils::isa<ValueNode>(node)) {
68 auto value_node = node->cast<ValueNodePtr>();
69 auto value_ptr = value_node->value();
70 MS_ASSERT(value_ptr != nullptr);
71 if (utils::isa<Monad>(value_ptr)) {
72 std::shared_ptr<Monad> mirror_monad;
73 if (utils::isa<UMonad>(value_ptr)) {
74 mirror_monad = std::make_shared<UMonad>();
75 } else {
76 mirror_monad = std::make_shared<IOMonad>();
77 }
78 MS_CHECK_TRUE_RET(mirror_monad != nullptr, nullptr);
79 auto monad_abs = mirror_monad->ToAbstract();
80 auto mirror_value_node = NewValueNode(mirror_monad);
81 MS_CHECK_TRUE_RET(mirror_value_node != nullptr, nullptr);
82 mirror_value_node->set_abstract(monad_abs);
83 return mirror_value_node;
84 }
85 }
86 DataInfo data_info;
87 STATUS status;
88 if (utils::isa<Parameter>(node)) {
89 status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
90 } else if (utils::isa<ValueNode>(node)) {
91 status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
92 } else {
93 status = RET_ERROR;
94 }
95 if (status != RET_OK && status != RET_NO_CHANGE) {
96 MS_LOG(ERROR) << "fetch data failed.";
97 return nullptr;
98 }
99 if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && !data_info.data_.empty()) {
100 return NewValueNode(MakeValue<int>(*reinterpret_cast<int *>(data_info.data_.data())));
101 }
102 ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
103 auto tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
104 MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr);
105 if (!data_info.data_.empty()) {
106 auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
107 if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
108 MS_LOG(ERROR) << "memcpy_s failed";
109 return nullptr;
110 }
111 }
112 auto mirror_parameter = mirror_graph->add_parameter();
113 MS_CHECK_TRUE_RET(mirror_parameter != nullptr, nullptr);
114 if (node->abstract() != nullptr) {
115 mirror_parameter->set_abstract(node->abstract()->Clone());
116 }
117 mirror_parameter->set_name(node->fullname_with_scope());
118 mirror_parameter->set_default_param(tensor_info);
119 return mirror_parameter;
120 }
121
ClonePrimitive(const CNodePtr & cnode)122 PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
123 MS_ASSERT(cnode != nullptr);
124 auto origin_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
125 MS_ASSERT(origin_prim != nullptr);
126 PrimitivePtr prim;
127 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
128 if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) {
129 prim = op_primc_fns[origin_prim->name()]();
130 } else {
131 prim = std::make_shared<PrimitiveC>(origin_prim->name());
132 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
133 prim->set_instance_name(origin_prim->name());
134 }
135 prim->SetAttrs(origin_prim->attrs());
136 return prim;
137 }
138
// Deep-copies `graph` into a new FuncGraph and returns it, or returns nullptr (after logging) on failure.
//
// Graph inputs become fresh parameters (CloneGraphInputs). Every CNode reachable from the return node is
// visited in TopoSort order and rebuilt with a cloned primitive, a cloned abstract and the original
// fullname_with_scope. Operands are handled as follows:
// - CNodes cloned earlier are reused.
// - FuncGraph value nodes are cloned recursively.
// - Other Parameter / ValueNode operands go through CloneParameterAndValueNode.
// Origin-to-mirror correspondences are tracked by name in two parallel NodesMaps (same key, same position).
//
// NOTE(review): sub-graphs are cloned recursively with no visited set. A graph that (transitively)
// references itself would recurse without bound. The same sub-graph reached through different value
// nodes is cloned once per value node. Confirm neither case can reach this point.
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags) {
  MS_ASSERT(graph != nullptr);
  auto origin_return = graph->get_return();
  if (origin_return == nullptr) {
    MS_LOG(ERROR) << "graph has no return node.";
    return nullptr;
  }
  auto mirror_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_RET(mirror_graph != nullptr, nullptr);
  mirror_graph->set_attrs(graph->attrs());
  NodesMap origin_nodes;
  NodesMap mirror_nodes;
  CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes);
  // CloneGraphInputs returns no status, so check that every graph input received a mirror parameter.
  if (mirror_graph->parameters().size() != graph->get_inputs().size()) {
    MS_LOG(ERROR) << "clone graph inputs failed.";
    return nullptr;
  }

  // Returns the mirror already recorded for `origin` under `name`, or nullptr if none exists yet.
  // Uses find() so that a miss neither copies a vector nor default-inserts an empty entry.
  auto find_mirror = [&origin_nodes, &mirror_nodes](const std::string &name,
                                                    const AnfNodePtr &origin) -> AnfNodePtr {
    auto origin_iter = origin_nodes.find(name);
    if (origin_iter == origin_nodes.end()) {
      return nullptr;
    }
    const auto &candidates = origin_iter->second;
    auto pos = std::find(candidates.begin(), candidates.end(), origin);
    if (pos == candidates.end()) {
      return nullptr;
    }
    return mirror_nodes[name][static_cast<size_t>(pos - candidates.begin())];
  };

  auto node_list = TopoSort(origin_return);
  for (auto &node : node_list) {
    if (!utils::isa<mindspore::CNode>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto mirror_prim = ClonePrimitive(cnode);
    if (mirror_prim == nullptr) {
      MS_LOG(ERROR) << "clone primitive failed: " << cnode->fullname_with_scope();
      return nullptr;
    }
    std::vector<AnfNodePtr> node_inputs;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto origin_input = cnode->input(i);
      MS_CHECK_TRUE_RET(origin_input != nullptr, nullptr);
      const auto input_name = origin_input->fullname_with_scope();
      AnfNodePtr mirror_input = find_mirror(input_name, origin_input);
      if (mirror_input == nullptr) {
        if (IsValueNode<FuncGraph>(origin_input)) {
          auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
          auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags);
          // Without this check a failed sub-graph clone was wrapped in a (non-null) value node and
          // silently embedded in the mirror graph.
          if (mirror_sub_graph == nullptr) {
            MS_LOG(ERROR) << "clone sub graph failed: " << input_name;
            return nullptr;
          }
          mirror_input = NewValueNode(mirror_sub_graph);
        } else {
          mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags);
        }
        if (mirror_input == nullptr) {
          MS_LOG(ERROR) << "node input cannot be found.";
          return nullptr;
        }
        origin_nodes[input_name].push_back(origin_input);
        mirror_nodes[input_name].push_back(mirror_input);
      }
      node_inputs.push_back(mirror_input);
    }
    auto mirror_cnode = mirror_graph->NewCNode(mirror_prim, node_inputs);
    MS_CHECK_TRUE_RET(mirror_cnode != nullptr, nullptr);
    const auto cnode_name = cnode->fullname_with_scope();
    mirror_cnode->set_fullname_with_scope(cnode_name);
    if (cnode->abstract() != nullptr) {
      mirror_cnode->set_abstract(cnode->abstract()->Clone());
    }
    origin_nodes[cnode_name].push_back(cnode);
    mirror_nodes[cnode_name].push_back(mirror_cnode);
    if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
      mirror_graph->set_return(mirror_cnode);
    }
  }
  return mirror_graph;
}
195 } // namespace
196
ExportModel(const FuncGraphPtr & graph,const converter::Flags * flags)197 STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
198 MS_ASSERT(graph != nullptr && flags != nullptr);
199 auto mirror_graph = CloneFuncGraph(graph, flags);
200 if (mirror_graph == nullptr) {
201 MS_LOG(ERROR) << "Clone funcGraph failed.";
202 return RET_ERROR;
203 }
204 (void)Manage(mirror_graph, true);
205 if (!RunOptimizerPass(mirror_graph, {"ToNHWCFormat", "InferShapePass", "DecreaseTransposeAlgo"})) {
206 MS_LOG(ERROR) << "Run transpose opt pass failed.";
207 return RET_ERROR;
208 }
209 auto optimizer = std::make_shared<opt::GraphOptimizer>();
210 CHECK_NULL_RETURN(optimizer);
211 auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
212 CHECK_NULL_RETURN(graph_pm);
213 if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf ||
214 flags->fmk == converter::kFmkTypeOnnx) {
215 graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
216 }
217 optimizer->AddPassManager(graph_pm);
218 if (optimizer->Optimize(mirror_graph) == nullptr) {
219 MS_LOG(ERROR) << "run graph pass failed.";
220 return RET_ERROR;
221 }
222 auto meta_graph = Export(mirror_graph);
223 if (meta_graph == nullptr) {
224 MS_LOG(ERROR) << "Export to meta graph return nullptr";
225 return RET_ERROR;
226 }
227 auto metagraph_transform = std::make_unique<GraphDefTransform>();
228 CHECK_NULL_RETURN(metagraph_transform);
229 metagraph_transform->SetGraphDef(meta_graph);
230 auto status = metagraph_transform->Transform(*flags);
231 if (status != RET_OK) {
232 MS_LOG(ERROR) << "Transform meta graph failed " << status;
233 return RET_ERROR;
234 }
235 meta_graph->version = Version();
236 status = Storage::Save(*meta_graph, "model");
237 std::ostringstream oss;
238 if (status != RET_OK) {
239 oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status);
240 MS_LOG(ERROR) << oss.str();
241 std::cout << oss.str() << std::endl;
242 return status;
243 }
244
245 delete meta_graph;
246 return status;
247 }
248 } // namespace lite
249 } // namespace mindspore
250