• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/converter/export_model.h"
19 #include <fstream>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <vector>
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "include/backend/optimizer/optimizer.h"
28 #include "include/errorcode.h"
29 #include "ir/func_graph.h"
30 #include "tools/lite_exporter/anf_exporter.h"
31 #include "tools/optimizer/common/pass_manager_extends.h"
32 #include "tools/converter/graphdef_transform.h"
33 #include "tools/converter/optimizer_manager.h"
34 #include "tools/converter/parser/parser_utils.h"
35 #include "tools/optimizer/graph/control_flow_pass.h"
36 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
37 #include "nnacl/op_base.h"
38 #include "src/common/log_util.h"
39 
40 namespace mindspore {
41 namespace lite {
42 namespace {
43 using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
CloneGraphInputs(const FuncGraphPtr & origin,const FuncGraphPtr & mirror,NodesMap * origin_map,NodesMap * mirror_map)44 void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
45                       NodesMap *mirror_map) {
46   MS_ASSERT(origin != nullptr && mirror != nullptr);
47   MS_ASSERT(origin_map != nullptr && mirror_map != nullptr);
48   auto origin_inputs = origin->get_inputs();
49   for (auto &input : origin_inputs) {
50     auto mirror_input = mirror->add_parameter();
51     MS_CHECK_TRUE_RET_VOID(mirror_input != nullptr);
52     if (input->abstract() != nullptr) {
53       mirror_input->set_abstract(input->abstract()->Clone());
54     }
55     mirror_input->set_name(input->fullname_with_scope());
56     MS_ASSERT(origin_map->find(input->fullname_with_scope()) != origin_map->end());
57     MS_ASSERT(mirror_map->find(input->fullname_with_scope()) != mirror_map->end());
58     (*origin_map)[input->fullname_with_scope()].push_back(input);
59     (*mirror_map)[input->fullname_with_scope()].push_back(mirror_input);
60   }
61 }
62 
CloneParameterAndValueNode(const CNodePtr & cnode,size_t index,const FuncGraphPtr & mirror_graph,const FuncGraphManagerPtr & manager,const std::shared_ptr<ConverterPara> & param)63 AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
64                                       const FuncGraphManagerPtr &manager, const std::shared_ptr<ConverterPara> &param) {
65   MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
66   MS_CHECK_TRUE_RET(index < cnode->size(), nullptr);
67   auto node = cnode->input(index);
68   if (node == nullptr || utils::isa<mindspore::CNode>(node)) {
69     MS_LOG(ERROR) << "this func cannot copy cnode.";
70     return nullptr;
71   }
72   if (utils::isa<ValueNode>(node)) {
73     auto value_node = node->cast<ValueNodePtr>();
74     MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
75     auto value_ptr = value_node->value();
76     MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
77     if (utils::isa<Monad>(value_ptr)) {
78       std::shared_ptr<Monad> mirror_monad;
79       if (utils::isa<UMonad>(value_ptr)) {
80         mirror_monad = std::make_shared<UMonad>();
81       } else {
82         mirror_monad = std::make_shared<IOMonad>();
83       }
84       MS_CHECK_TRUE_RET(mirror_monad != nullptr, nullptr);
85       auto monad_abs = mirror_monad->ToAbstract();
86       MS_CHECK_TRUE_RET(monad_abs != nullptr, nullptr);
87       auto mirror_value_node = NewValueNode(mirror_monad);
88       MS_CHECK_TRUE_RET(mirror_value_node != nullptr, nullptr);
89       mirror_value_node->set_abstract(monad_abs);
90       return mirror_value_node;
91     }
92   }
93   DataInfo data_info;
94   STATUS status = RET_ERROR;
95   if (utils::isa<Parameter>(node)) {
96     status = FetchDataFromParameterNode(cnode, index, param->fmk_type, &data_info, true);
97   } else if (utils::isa<ValueNode>(node)) {
98     status = FetchDataFromValueNode(cnode, index, param->fmk_type, param->train_model, &data_info, true);
99   }
100   if (status != RET_OK && status != RET_NO_CHANGE) {
101     MS_LOG(ERROR) << "fetch data failed.";
102     return nullptr;
103   }
104   if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && data_info.data_.size() >= sizeof(int)) {
105     return NewValueNode(MakeValue<int64_t>(*reinterpret_cast<int *>(data_info.data_.data())));
106   }
107   ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
108   if (data_info.data_type_ == kObjectTypeTensorType) {
109     shape_vec = ShapeVector{static_cast<int64_t>(data_info.data_.size() / sizeof(int))};
110   }
111   std::shared_ptr<tensor::Tensor> tensor_info;
112   if (static_cast<TensorCompressionType>(data_info.compress_type_) == TensorCompressionType::kNoCompression) {
113     tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
114   } else {
115     tensor_info =
116       std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec, data_info.data_.size(),
117                                        static_cast<TensorCompressionType>(data_info.compress_type_));
118   }
119   MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr);
120   if (!data_info.data_.empty()) {
121     auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
122     if (tensor_data == nullptr || tensor_info->data().nbytes() < 0) {
123       MS_LOG(ERROR) << "tensor info data is nullptr or the size is smaller than zero.";
124       return nullptr;
125     }
126     if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
127       MS_LOG(ERROR) << "memcpy_s failed";
128       return nullptr;
129     }
130   }
131   tensor_info->set_quant_param(data_info.quant_params_);
132   auto mirror_parameter = mirror_graph->add_parameter();
133   MS_CHECK_TRUE_RET(mirror_parameter != nullptr, nullptr);
134 
135   mirror_parameter->set_name(node->fullname_with_scope());
136   mirror_parameter->set_default_param(tensor_info);
137   mirror_parameter->set_abstract(tensor_info->ToAbstract());
138   return mirror_parameter;
139 }
140 
ClonePrimitive(const CNodePtr & cnode)141 PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
142   MS_ASSERT(cnode != nullptr);
143   auto origin_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
144   if (origin_prim == nullptr) {
145     return nullptr;
146   }
147   PrimitivePtr prim;
148   auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
149   if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) {
150     prim = op_primc_fns[origin_prim->name()]();
151     MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
152   } else {
153     prim = std::make_shared<PrimitiveC>(origin_prim->name());
154     MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
155     prim->set_instance_name(origin_prim->name());
156   }
157   prim->SetAttrs(origin_prim->attrs());
158   if (prim->GetAttr("quant_params") != nullptr) {
159     auto quant_holder = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
160     prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(*quant_holder));
161   }
162   return prim;
163 }
164 }  // namespace
165 
CloneFuncGraph(const FuncGraphPtr & graph,const std::shared_ptr<ConverterPara> & param,std::map<FuncGraphPtr,FuncGraphPtr> * cloned_func_graph)166 FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param,
167                             std::map<FuncGraphPtr, FuncGraphPtr> *cloned_func_graph) {
168   MS_ASSERT(graph != nullptr);
169   MS_ASSERT(param != nullptr);
170   MS_ASSERT(cloned_func_graph != nullptr);
171   auto cloned_func_graph_iter = cloned_func_graph->find(graph);
172   if (cloned_func_graph_iter != cloned_func_graph->end()) {
173     return cloned_func_graph_iter->second;
174   }
175   auto mirror_graph = std::make_shared<FuncGraph>();
176   MS_CHECK_TRUE_RET(mirror_graph != nullptr, nullptr);
177   auto ret = cloned_func_graph->emplace(graph, mirror_graph);
178   if (!ret.second) {
179     MS_LOG(ERROR) << "emplace mirror graph into map failed.";
180     return nullptr;
181   }
182   mirror_graph->set_attrs(graph->attrs());
183   NodesMap origin_nodes;
184   NodesMap mirror_nodes;
185   CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes);
186   auto node_list = TopoSort(graph->get_return());
187   auto manager = graph->manager();
188   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
189   for (auto &node : node_list) {
190     if (!utils::isa<mindspore::CNode>(node)) {
191       continue;
192     }
193     auto cnode = node->cast<CNodePtr>();
194     std::vector<AnfNodePtr> node_inputs;
195     size_t begin_index = 1;
196     auto mirror_prim = ClonePrimitive(cnode);
197     if (mirror_prim == nullptr) {
198       begin_index = 0;
199     }
200     for (size_t i = begin_index; i < cnode->size(); ++i) {
201       auto origin_input = cnode->input(i);
202       MS_CHECK_TRUE_RET(origin_input != nullptr, nullptr);
203       AnfNodePtr mirror_input = nullptr;
204       auto value = origin_nodes[origin_input->fullname_with_scope()];
205       auto iter = std::find(value.begin(), value.end(), origin_input);
206       if (iter != value.end()) {
207         mirror_input = mirror_nodes[origin_input->fullname_with_scope()][iter - value.begin()];
208       }
209       if (mirror_input == nullptr) {
210         if (IsValueNode<FuncGraph>(origin_input)) {
211           auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
212           MS_CHECK_TRUE_RET(sub_func_graph != nullptr, nullptr);
213           auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, param, cloned_func_graph);
214           mirror_input = NewValueNode(mirror_sub_graph);
215         } else {
216           mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, manager, param);
217         }
218         if (mirror_input == nullptr) {
219           MS_LOG(ERROR) << "node input cannot be found.";
220           return nullptr;
221         }
222         origin_nodes[origin_input->fullname_with_scope()].push_back(origin_input);
223         mirror_nodes[origin_input->fullname_with_scope()].push_back(mirror_input);
224       }
225       node_inputs.push_back(mirror_input);
226     }
227     auto mirror_cnode =
228       mirror_prim == nullptr ? mirror_graph->NewCNode(node_inputs) : mirror_graph->NewCNode(mirror_prim, node_inputs);
229     MS_CHECK_TRUE_RET(mirror_cnode != nullptr, nullptr);
230     mirror_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
231     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
232     MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
233     auto quant_type_valueptr = primitive->GetAttr(quant::kQuantType);
234     if (quant_type_valueptr != nullptr) {
235       mirror_cnode->AddAttr(quant::kQuantType, quant_type_valueptr);
236     }
237     if (cnode->abstract() != nullptr) {
238       mirror_cnode->set_abstract(cnode->abstract()->Clone());
239     }
240     origin_nodes[cnode->fullname_with_scope()].push_back(cnode);
241     mirror_nodes[cnode->fullname_with_scope()].push_back(mirror_cnode);
242     if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
243       mirror_graph->set_return(mirror_cnode);
244     }
245   }
246   return mirror_graph;
247 }
248 
ExportModel(const FuncGraphPtr & graph,const std::shared_ptr<ConverterPara> & param)249 STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param) {
250   CHECK_NULL_RETURN(graph);
251   CHECK_NULL_RETURN(param);
252   std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
253   auto mirror_graph = CloneFuncGraph(graph, param, &cloned_func_graph);
254   if (mirror_graph == nullptr) {
255     MS_LOG(ERROR) << "Clone funcGraph failed.";
256     return RET_ERROR;
257   }
258   auto manager = Manage(mirror_graph, true);
259   MS_CHECK_TRUE_RET(manager != nullptr, RET_ERROR);
260   std::set<FuncGraphPtr> all_func_graphs;
261   GetAllFuncGraph(mirror_graph, &all_func_graphs);
262   for (auto &func_graph : all_func_graphs) {
263     manager->AddFuncGraph(func_graph);
264   }
265   auto clip_transfer = std::make_shared<opt::ClipConvertActivationPass>();
266   CHECK_NULL_RETURN(clip_transfer);
267   (void)clip_transfer->Run(mirror_graph);
268   if (!RunOptimizerPass(mirror_graph, {"ToNHWCFormat", "InferShapePass", "SpecialNodePostProcess"})) {
269     MS_LOG(ERROR) << "Run transpose opt pass failed.";
270     return RET_ERROR;
271   }
272   auto optimizer = std::make_shared<opt::GraphOptimizer>();
273   CHECK_NULL_RETURN(optimizer);
274   auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
275   CHECK_NULL_RETURN(graph_pm);
276   if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
277       param->fmk_type == converter::kFmkTypeOnnx) {
278     graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
279   }
280   optimizer->AddPassManager(graph_pm);
281   if (optimizer->Optimize(mirror_graph) == nullptr) {
282     MS_LOG(ERROR) << "run  graph pass failed.";
283     return RET_ERROR;
284   }
285   auto meta_graph = Export(mirror_graph);
286   if (meta_graph == nullptr) {
287     MS_LOG(ERROR) << "Export to meta graph return nullptr";
288     return RET_ERROR;
289   }
290   auto metagraph_transform = std::make_unique<GraphDefTransform>();
291   if (metagraph_transform == nullptr) {
292     MS_LOG(ERROR) << "Create metagraph_transform return nullptr";
293     delete meta_graph;
294     return RET_ERROR;
295   }
296   metagraph_transform->SetGraphDef(meta_graph);
297   auto status = metagraph_transform->Transform(param);
298   if (status != RET_OK) {
299     MS_LOG(ERROR) << "Transform meta graph failed " << status;
300     delete meta_graph;
301     return RET_ERROR;
302   }
303   // set output tensor names to the original names, the output_names is null in nnie converter.
304   auto output_names = ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
305   if (output_names.size() > meta_graph->outputIndex.size()) {
306     MS_LOG(ERROR) << "the num of setting output_names is greater than actual, " << output_names.size() << " > "
307                   << meta_graph->outputIndex.size() << ".";
308     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
309     delete meta_graph;
310     return RET_ERROR;
311   }
312   for (size_t idx = 0; idx < output_names.size(); idx++) {
313     auto &tensor = meta_graph->allTensors.at(meta_graph->outputIndex.at(idx));
314     tensor->name = output_names.at(idx);
315   }
316   meta_graph->version = Version();
317   status = MetaGraphSerializer::Save(*meta_graph, "model");
318   delete meta_graph;
319   std::ostringstream oss;
320   if (status != RET_OK) {
321     oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status);
322     MS_LOG(ERROR) << oss.str();
323     std::cout << oss.str() << std::endl;
324     return status;
325   }
326   return status;
327 }
328 }  // namespace lite
329 }  // namespace mindspore
330