1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/converter/export_model.h"
19 #include <fstream>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <vector>
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "include/backend/optimizer/optimizer.h"
28 #include "include/errorcode.h"
29 #include "ir/func_graph.h"
30 #include "tools/lite_exporter/anf_exporter.h"
31 #include "tools/optimizer/common/pass_manager_extends.h"
32 #include "tools/converter/graphdef_transform.h"
33 #include "tools/converter/optimizer_manager.h"
34 #include "tools/converter/parser/parser_utils.h"
35 #include "tools/optimizer/graph/control_flow_pass.h"
36 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
37 #include "nnacl/op_base.h"
38 #include "src/common/log_util.h"
39
40 namespace mindspore {
41 namespace lite {
42 namespace {
// Maps a node's fullname_with_scope() to every node recorded under that name.
using NodesMap = std::map<std::string, std::vector<AnfNodePtr>>;
CloneGraphInputs(const FuncGraphPtr & origin,const FuncGraphPtr & mirror,NodesMap * origin_map,NodesMap * mirror_map)44 void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, NodesMap *origin_map,
45 NodesMap *mirror_map) {
46 MS_ASSERT(origin != nullptr && mirror != nullptr);
47 MS_ASSERT(origin_map != nullptr && mirror_map != nullptr);
48 auto origin_inputs = origin->get_inputs();
49 for (auto &input : origin_inputs) {
50 auto mirror_input = mirror->add_parameter();
51 MS_CHECK_TRUE_RET_VOID(mirror_input != nullptr);
52 if (input->abstract() != nullptr) {
53 mirror_input->set_abstract(input->abstract()->Clone());
54 }
55 mirror_input->set_name(input->fullname_with_scope());
56 MS_ASSERT(origin_map->find(input->fullname_with_scope()) != origin_map->end());
57 MS_ASSERT(mirror_map->find(input->fullname_with_scope()) != mirror_map->end());
58 (*origin_map)[input->fullname_with_scope()].push_back(input);
59 (*mirror_map)[input->fullname_with_scope()].push_back(mirror_input);
60 }
61 }
62
CloneParameterAndValueNode(const CNodePtr & cnode,size_t index,const FuncGraphPtr & mirror_graph,const FuncGraphManagerPtr & manager,const std::shared_ptr<ConverterPara> & param)63 AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
64 const FuncGraphManagerPtr &manager, const std::shared_ptr<ConverterPara> ¶m) {
65 MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
66 MS_CHECK_TRUE_RET(index < cnode->size(), nullptr);
67 auto node = cnode->input(index);
68 if (node == nullptr || utils::isa<mindspore::CNode>(node)) {
69 MS_LOG(ERROR) << "this func cannot copy cnode.";
70 return nullptr;
71 }
72 if (utils::isa<ValueNode>(node)) {
73 auto value_node = node->cast<ValueNodePtr>();
74 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
75 auto value_ptr = value_node->value();
76 MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
77 if (utils::isa<Monad>(value_ptr)) {
78 std::shared_ptr<Monad> mirror_monad;
79 if (utils::isa<UMonad>(value_ptr)) {
80 mirror_monad = std::make_shared<UMonad>();
81 } else {
82 mirror_monad = std::make_shared<IOMonad>();
83 }
84 MS_CHECK_TRUE_RET(mirror_monad != nullptr, nullptr);
85 auto monad_abs = mirror_monad->ToAbstract();
86 MS_CHECK_TRUE_RET(monad_abs != nullptr, nullptr);
87 auto mirror_value_node = NewValueNode(mirror_monad);
88 MS_CHECK_TRUE_RET(mirror_value_node != nullptr, nullptr);
89 mirror_value_node->set_abstract(monad_abs);
90 return mirror_value_node;
91 }
92 }
93 DataInfo data_info;
94 STATUS status = RET_ERROR;
95 if (utils::isa<Parameter>(node)) {
96 status = FetchDataFromParameterNode(cnode, index, param->fmk_type, &data_info, true);
97 } else if (utils::isa<ValueNode>(node)) {
98 status = FetchDataFromValueNode(cnode, index, param->fmk_type, param->train_model, &data_info, true);
99 }
100 if (status != RET_OK && status != RET_NO_CHANGE) {
101 MS_LOG(ERROR) << "fetch data failed.";
102 return nullptr;
103 }
104 if (opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) && data_info.data_.size() >= sizeof(int)) {
105 return NewValueNode(MakeValue<int64_t>(*reinterpret_cast<int *>(data_info.data_.data())));
106 }
107 ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end());
108 if (data_info.data_type_ == kObjectTypeTensorType) {
109 shape_vec = ShapeVector{static_cast<int64_t>(data_info.data_.size() / sizeof(int))};
110 }
111 std::shared_ptr<tensor::Tensor> tensor_info;
112 if (static_cast<TensorCompressionType>(data_info.compress_type_) == TensorCompressionType::kNoCompression) {
113 tensor_info = std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec);
114 } else {
115 tensor_info =
116 std::make_shared<tensor::Tensor>(static_cast<TypeId>(data_info.data_type_), shape_vec, data_info.data_.size(),
117 static_cast<TensorCompressionType>(data_info.compress_type_));
118 }
119 MS_CHECK_TRUE_RET(tensor_info != nullptr, nullptr);
120 if (!data_info.data_.empty()) {
121 auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
122 if (tensor_data == nullptr || tensor_info->data().nbytes() < 0) {
123 MS_LOG(ERROR) << "tensor info data is nullptr or the size is smaller than zero.";
124 return nullptr;
125 }
126 if (memcpy_s(tensor_data, tensor_info->data().nbytes(), data_info.data_.data(), data_info.data_.size()) != EOK) {
127 MS_LOG(ERROR) << "memcpy_s failed";
128 return nullptr;
129 }
130 }
131 tensor_info->set_quant_param(data_info.quant_params_);
132 auto mirror_parameter = mirror_graph->add_parameter();
133 MS_CHECK_TRUE_RET(mirror_parameter != nullptr, nullptr);
134
135 mirror_parameter->set_name(node->fullname_with_scope());
136 mirror_parameter->set_default_param(tensor_info);
137 mirror_parameter->set_abstract(tensor_info->ToAbstract());
138 return mirror_parameter;
139 }
140
ClonePrimitive(const CNodePtr & cnode)141 PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
142 MS_ASSERT(cnode != nullptr);
143 auto origin_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
144 if (origin_prim == nullptr) {
145 return nullptr;
146 }
147 PrimitivePtr prim;
148 auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
149 if (op_primc_fns.find(origin_prim->name()) != op_primc_fns.end()) {
150 prim = op_primc_fns[origin_prim->name()]();
151 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
152 } else {
153 prim = std::make_shared<PrimitiveC>(origin_prim->name());
154 MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
155 prim->set_instance_name(origin_prim->name());
156 }
157 prim->SetAttrs(origin_prim->attrs());
158 if (prim->GetAttr("quant_params") != nullptr) {
159 auto quant_holder = prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
160 prim->AddAttr("quant_params", std::make_shared<QuantParamHolder>(*quant_holder));
161 }
162 return prim;
163 }
164 } // namespace
165
// Clones `graph`, together with every sub graph referenced by its cnodes, into a new FuncGraph.
//
// `cloned_func_graph` maps every origin graph already handled to its mirror. A graph found in the map is returned
// directly; a new mirror is registered in the map before its nodes are cloned, so a reference to `graph` that is
// met while cloning resolves to the same mirror.
//
// Cnodes are visited in the order given by TopoSort(graph->get_return()). For each cnode the inputs are resolved
// to already mirrored nodes, to mirrored sub graphs, or to clones made by CloneParameterAndValueNode, and a mirror
// cnode with the same full name, quant type attribute and abstract is created.
//
// Returns the mirror graph, or nullptr on any failure.
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param,
                            std::map<FuncGraphPtr, FuncGraphPtr> *cloned_func_graph) {
  MS_CHECK_TRUE_RET(graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(param != nullptr, nullptr);
  MS_CHECK_TRUE_RET(cloned_func_graph != nullptr, nullptr);
  auto cloned_func_graph_iter = cloned_func_graph->find(graph);
  if (cloned_func_graph_iter != cloned_func_graph->end()) {
    return cloned_func_graph_iter->second;
  }
  auto mirror_graph = std::make_shared<FuncGraph>();
  MS_CHECK_TRUE_RET(mirror_graph != nullptr, nullptr);
  auto ret = cloned_func_graph->emplace(graph, mirror_graph);
  if (!ret.second) {
    MS_LOG(ERROR) << "emplace mirror graph into map failed.";
    return nullptr;
  }
  mirror_graph->set_attrs(graph->attrs());
  // Origin nodes and their mirrors are stored at the same index under the same name in these two maps.
  NodesMap origin_nodes;
  NodesMap mirror_nodes;
  CloneGraphInputs(graph, mirror_graph, &origin_nodes, &mirror_nodes);
  auto node_list = TopoSort(graph->get_return());
  auto manager = graph->manager();
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
  for (auto &node : node_list) {
    if (!utils::isa<mindspore::CNode>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
    std::vector<AnfNodePtr> node_inputs;
    size_t begin_index = 1;
    auto mirror_prim = ClonePrimitive(cnode);
    if (mirror_prim == nullptr) {
      // No cloned primitive is available, so input 0 is mirrored like any other input.
      begin_index = 0;
    }
    for (size_t i = begin_index; i < cnode->size(); ++i) {
      auto origin_input = cnode->input(i);
      MS_CHECK_TRUE_RET(origin_input != nullptr, nullptr);
      const auto input_name = origin_input->fullname_with_scope();
      AnfNodePtr mirror_input = nullptr;
      // Reference, not copy: the vector is only searched here and extended below.
      auto &known_origins = origin_nodes[input_name];
      auto iter = std::find(known_origins.begin(), known_origins.end(), origin_input);
      if (iter != known_origins.end()) {
        mirror_input = mirror_nodes[input_name][iter - known_origins.begin()];
      }
      if (mirror_input == nullptr) {
        if (IsValueNode<FuncGraph>(origin_input)) {
          auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
          MS_CHECK_TRUE_RET(sub_func_graph != nullptr, nullptr);
          auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, param, cloned_func_graph);
          // NewValueNode would wrap a null graph into a non-null node and hide the failure, so check it here.
          if (mirror_sub_graph == nullptr) {
            MS_LOG(ERROR) << "clone sub graph failed.";
            return nullptr;
          }
          mirror_input = NewValueNode(mirror_sub_graph);
        } else {
          mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, manager, param);
        }
        if (mirror_input == nullptr) {
          MS_LOG(ERROR) << "node input cannot be found.";
          return nullptr;
        }
        known_origins.push_back(origin_input);
        mirror_nodes[input_name].push_back(mirror_input);
      }
      node_inputs.push_back(mirror_input);
    }
    auto mirror_cnode =
      mirror_prim == nullptr ? mirror_graph->NewCNode(node_inputs) : mirror_graph->NewCNode(mirror_prim, node_inputs);
    MS_CHECK_TRUE_RET(mirror_cnode != nullptr, nullptr);
    mirror_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
    // A cnode whose input 0 is not a primitive has no quant type to copy; it must not abort the clone.
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive != nullptr) {
      auto quant_type_valueptr = primitive->GetAttr(quant::kQuantType);
      if (quant_type_valueptr != nullptr) {
        mirror_cnode->AddAttr(quant::kQuantType, quant_type_valueptr);
      }
    }
    if (cnode->abstract() != nullptr) {
      mirror_cnode->set_abstract(cnode->abstract()->Clone());
    }
    const auto cnode_name = cnode->fullname_with_scope();
    origin_nodes[cnode_name].push_back(cnode);
    mirror_nodes[cnode_name].push_back(mirror_cnode);
    if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
      mirror_graph->set_return(mirror_cnode);
    }
  }
  return mirror_graph;
}
248
ExportModel(const FuncGraphPtr & graph,const std::shared_ptr<ConverterPara> & param)249 STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> ¶m) {
250 CHECK_NULL_RETURN(graph);
251 CHECK_NULL_RETURN(param);
252 std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
253 auto mirror_graph = CloneFuncGraph(graph, param, &cloned_func_graph);
254 if (mirror_graph == nullptr) {
255 MS_LOG(ERROR) << "Clone funcGraph failed.";
256 return RET_ERROR;
257 }
258 auto manager = Manage(mirror_graph, true);
259 MS_CHECK_TRUE_RET(manager != nullptr, RET_ERROR);
260 std::set<FuncGraphPtr> all_func_graphs;
261 GetAllFuncGraph(mirror_graph, &all_func_graphs);
262 for (auto &func_graph : all_func_graphs) {
263 manager->AddFuncGraph(func_graph);
264 }
265 auto clip_transfer = std::make_shared<opt::ClipConvertActivationPass>();
266 CHECK_NULL_RETURN(clip_transfer);
267 (void)clip_transfer->Run(mirror_graph);
268 if (!RunOptimizerPass(mirror_graph, {"ToNHWCFormat", "InferShapePass", "SpecialNodePostProcess"})) {
269 MS_LOG(ERROR) << "Run transpose opt pass failed.";
270 return RET_ERROR;
271 }
272 auto optimizer = std::make_shared<opt::GraphOptimizer>();
273 CHECK_NULL_RETURN(optimizer);
274 auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
275 CHECK_NULL_RETURN(graph_pm);
276 if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
277 param->fmk_type == converter::kFmkTypeOnnx) {
278 graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
279 }
280 optimizer->AddPassManager(graph_pm);
281 if (optimizer->Optimize(mirror_graph) == nullptr) {
282 MS_LOG(ERROR) << "run graph pass failed.";
283 return RET_ERROR;
284 }
285 auto meta_graph = Export(mirror_graph);
286 if (meta_graph == nullptr) {
287 MS_LOG(ERROR) << "Export to meta graph return nullptr";
288 return RET_ERROR;
289 }
290 auto metagraph_transform = std::make_unique<GraphDefTransform>();
291 if (metagraph_transform == nullptr) {
292 MS_LOG(ERROR) << "Create metagraph_transform return nullptr";
293 delete meta_graph;
294 return RET_ERROR;
295 }
296 metagraph_transform->SetGraphDef(meta_graph);
297 auto status = metagraph_transform->Transform(param);
298 if (status != RET_OK) {
299 MS_LOG(ERROR) << "Transform meta graph failed " << status;
300 delete meta_graph;
301 return RET_ERROR;
302 }
303 // set output tensor names to the original names, the output_names is null in nnie converter.
304 auto output_names = ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
305 if (output_names.size() > meta_graph->outputIndex.size()) {
306 MS_LOG(ERROR) << "the num of setting output_names is greater than actual, " << output_names.size() << " > "
307 << meta_graph->outputIndex.size() << ".";
308 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
309 delete meta_graph;
310 return RET_ERROR;
311 }
312 for (size_t idx = 0; idx < output_names.size(); idx++) {
313 auto &tensor = meta_graph->allTensors.at(meta_graph->outputIndex.at(idx));
314 tensor->name = output_names.at(idx);
315 }
316 meta_graph->version = Version();
317 status = MetaGraphSerializer::Save(*meta_graph, "model");
318 delete meta_graph;
319 std::ostringstream oss;
320 if (status != RET_OK) {
321 oss << "SAVE GRAPH FAILED:" << status << " " << lite::GetErrorInfo(status);
322 MS_LOG(ERROR) << oss.str();
323 std::cout << oss.str() << std::endl;
324 return status;
325 }
326 return status;
327 }
328 } // namespace lite
329 } // namespace mindspore
330