• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/export_model.h"
18 #include <fstream>
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include "backend/optimizer/common/optimizer.h"
24 #include "include/errorcode.h"
25 #include "include/version.h"
26 #include "ir/func_graph.h"
27 #include "tools/anf_exporter/anf_exporter.h"
28 #include "tools/converter/graphdef_transform.h"
29 #include "tools/converter/optimizer_manager.h"
30 #include "tools/optimizer/graph/control_flow_pass.h"
31 #include "nnacl/op_base.h"
32 #include "src/common/log_util.h"
33 
34 namespace mindspore {
35 namespace lite {
36 namespace {
37 using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
CloneGraphInputs(const FuncGraphPtr & origin,const FuncGraphPtr & mirror,NodesMap * origin_map,NodesMap * mirror_map)38 void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
39                       NodesMap *mirror_map) {
40   MS_ASSERT(origin != nullptr && mirror != nullptr);
41   MS_ASSERT(origin_map != nullptr && mirror_map != nullptr);
42   auto origin_inputs = origin->get_inputs();
43   for (auto &input : origin_inputs) {
44     auto mirror_input = mirror->add_parameter();
45     MS_CHECK_TRUE_RET_VOID(mirror_input != nullptr);
46     if (input->abstract() != nullptr) {
47       mirror_input->set_abstract(input->abstract()->Clone());
48     }
49     mirror_input->set_name(input->fullname_with_scope());
50     (*origin_map)[input->fullname_with_scope()].push_back(input);
51     (*mirror_map)[input->fullname_with_scope()].push_back(mirror_input);
52   }
53 }
54 
CloneParameterAndValueNode(const CNodePtr & cnode,size_t index,const FuncGraphPtr & mirror_graph,const converter::Flags * flags)55 AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
56                                       const converter::Flags *flags) {
57   MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
58   if (index >= cnode->size()) {
59     MS_LOG(ERROR) << "input index out of range.";
60     return nullptr;
61   }
62   auto node = cnode->input(index);
63   if (utils::isa<mindspore::CNode>(node)) {
64     MS_LOG(ERROR) << "this func cannot copy cnode.";
65     return nullptr;
66   }
67   if (utils::isa<ValueNode>(node)) {
68     auto value_node = node->cast<ValueNodePtr>();
69     auto value_ptr = value_node->value();
70     MS_ASSERT(value_ptr != nullptr);
71     if (utils::isa<Monad>(value_ptr)) {
72       std::shared_ptr<Monad> mirror_monad;
73       if (utils::isa<UMonad>(value_ptr)) {
74         mirror_monad = std::make_shared<UMonad>();
75       } else {
76         mirror_monad = std::make_shared<IOMonad>();
77       }
78       MS_CHECK_TRUE_RET(mirror_monad != nullptr, nullptr);
79       auto monad_abs = mirror_monad->ToAbstract();
80       auto mirror_value_node = NewValueNode(mirror_monad);
81       MS_CHECK_TRUE_RET(mirror_value_node != nullptr, nullptr);
82       mirror_value_node->set_abstract(monad_abs);
83       return mirror_value_node;
84     }
85   }
86   DataInfo data_info;
87   STATUS status;
88   if (utils::isa<Parameter>(node)) {
89     status = FetchDataFromParameterNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
90   } else if (utils::isa<ValueNode>(node)) {
91     status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info);
92   } else {
93     status = RET_ERROR;
94   }
95   if (status != RET_OK && status != RET_NO_CHANGE) {
96     MS_LOG(ERROR) << "fetch data failed.";
97     return nullptr;
98   }
99   if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && !data_info.data_.empty()) {
100     return NewValueNode(MakeValue<int>(*reinterpret_cast<int *>(data_info.data_.data())));
101   }
102   ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
103   auto tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
104   MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr);
105   if (!data_info.data_.empty()) {
106     auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
107     if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
108       MS_LOG(ERROR) << "memcpy_s failed";
109       return nullptr;
110     }
111   }
112   auto mirror_parameter = mirror_graph->add_parameter();
113   MS_CHECK_TRUE_RET(mirror_parameter != nullptr, nullptr);
114   if (node->abstract() != nullptr) {
115     mirror_parameter->set_abstract(node->abstract()->Clone());
116   }
117   mirror_parameter->set_name(node->fullname_with_scope());
118   mirror_parameter->set_default_param(tensor_info);
119   return mirror_parameter;
120 }
121 
ClonePrimitive(const CNodePtr & cnode)122 PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
123   MS_ASSERT(cnode != nullptr);
124   auto origin_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
125   MS_ASSERT(origin_prim != nullptr);
126   PrimitivePtr prim;
127   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
128   if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) {
129     prim = op_primc_fns[origin_prim->name()]();
130   } else {
131     prim = std::make_shared<PrimitiveC>(origin_prim->name());
132     MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
133     prim->set_instance_name(origin_prim->name());
134   }
135   prim->SetAttrs(origin_prim->attrs());
136   return prim;
137 }
138 
CloneFuncGraph(const FuncGraphPtr & graph,const converter::Flags * flags)139 FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags) {
140   MS_ASSERT(graph != nullptr);
141   auto mirror_graph = std::make_shared<FuncGraph>();
142   MS_CHECK_TRUE_RET(mirror_graph != nullptr, nullptr);
143   mirror_graph->set_attrs(graph->attrs());
144   NodesMap origin_nodes;
145   NodesMap mirror_nodes;
146   CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes);
147   auto node_list = TopoSort(graph->get_return());
148   for (auto &node : node_list) {
149     if (!utils::isa<mindspore::CNode>(node)) {
150       continue;
151     }
152     auto cnode = node->cast<CNodePtr>();
153     auto mirrro_prim = ClonePrimitive(cnode);
154     std::vector<AnfNodePtr> node_inputs;
155     for (size_t i = 1; i < cnode->size(); ++i) {
156       auto origin_input = cnode->input(i);
157       MS_CHECK_TRUE_RET(origin_input != nullptr, nullptr);
158       AnfNodePtr mirror_input = nullptr;
159       auto value = origin_nodes[origin_input->fullname_with_scope()];
160       auto iter = std::find(value.begin(), value.end(), origin_input);
161       if (iter != value.end()) {
162         mirror_input = mirror_nodes[origin_input->fullname_with_scope()][iter - value.begin()];
163       }
164       if (mirror_input == nullptr) {
165         if (IsValueNode<FuncGraph>(origin_input)) {
166           auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
167           auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags);
168           mirror_input = NewValueNode(mirror_sub_graph);
169         } else {
170           mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags);
171         }
172         if (mirror_input == nullptr) {
173           MS_LOG(ERROR) << "node input cannot be found.";
174           return nullptr;
175         }
176         origin_nodes[origin_input->fullname_with_scope()].push_back(origin_input);
177         mirror_nodes[origin_input->fullname_with_scope()].push_back(mirror_input);
178       }
179       node_inputs.push_back(mirror_input);
180     }
181     auto mirror_cnode = mirror_graph->NewCNode(mirrro_prim, node_inputs);
182     MS_CHECK_TRUE_RET(mirror_cnode != nullptr, nullptr);
183     mirror_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
184     if (cnode->abstract() != nullptr) {
185       mirror_cnode->set_abstract(cnode->abstract()->Clone());
186     }
187     origin_nodes[cnode->fullname_with_scope()].push_back(cnode);
188     mirror_nodes[cnode->fullname_with_scope()].push_back(mirror_cnode);
189     if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
190       mirror_graph->set_return(mirror_cnode);
191     }
192   }
193   return mirror_graph;
194 }
195 }  // namespace
196 
ExportModel(const FuncGraphPtr & graph,const converter::Flags * flags)197 STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
198   MS_ASSERT(graph != nullptr && flags != nullptr);
199   auto mirror_graph = CloneFuncGraph(graph, flags);
200   if (mirror_graph == nullptr) {
201     MS_LOG(ERROR) << "Clone funcGraph failed.";
202     return RET_ERROR;
203   }
204   (void)Manage(mirror_graph, true);
205   if (!RunOptimizerPass(mirror_graph, {"ToNHWCFormat", "InferShapePass", "DecreaseTransposeAlgo"})) {
206     MS_LOG(ERROR) << "Run transpose opt pass failed.";
207     return RET_ERROR;
208   }
209   auto optimizer = std::make_shared<opt::GraphOptimizer>();
210   CHECK_NULL_RETURN(optimizer);
211   auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
212   CHECK_NULL_RETURN(graph_pm);
213   if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf ||
214       flags->fmk == converter::kFmkTypeOnnx) {
215     graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
216   }
217   optimizer->AddPassManager(graph_pm);
218   if (optimizer->Optimize(mirror_graph) == nullptr) {
219     MS_LOG(ERROR) << "run  graph pass failed.";
220     return RET_ERROR;
221   }
222   auto meta_graph = Export(mirror_graph);
223   if (meta_graph == nullptr) {
224     MS_LOG(ERROR) << "Export to meta graph return nullptr";
225     return RET_ERROR;
226   }
227   auto metagraph_transform = std::make_unique<GraphDefTransform>();
228   CHECK_NULL_RETURN(metagraph_transform);
229   metagraph_transform->SetGraphDef(meta_graph);
230   auto status = metagraph_transform->Transform(*flags);
231   if (status != RET_OK) {
232     MS_LOG(ERROR) << "Transform meta graph failed " << status;
233     return RET_ERROR;
234   }
235   meta_graph->version = Version();
236   status = Storage::Save(*meta_graph, "model");
237   std::ostringstream oss;
238   if (status != RET_OK) {
239     oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status);
240     MS_LOG(ERROR) << oss.str();
241     std::cout << oss.str() << std::endl;
242     return status;
243   }
244 
245   delete meta_graph;
246   return status;
247 }
248 }  // namespace lite
249 }  // namespace mindspore
250