1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18
#include "tools/lite_exporter/anf_exporter.h"
#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/op_name.h"
#include "mindspore/core/ops/op_utils.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "nnacl/op_base.h"
#include "ops/depend.h"
#include "ops/fusion/partial_fusion.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "ops/fusion/make_tuple_v2.h"
#include "src/common/log_util.h"
#include "src/common/ops/anf_utils.h"
#include "src/common/utils.h"
#include "src/litert/tensor_category.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
50
51 using mindspore::ops::PrimitiveC;
52
namespace {
// NOTE(review): presumably the position of the main graph in MetaGraphT::subGraph; not used in the code above -- confirm.
constexpr int kMainGraphIndex = 0;
// Positions of CNode data inputs; input 0 holds the primitive (or func graph), data starts at 1.
constexpr int kFirstDataIndex = 1;
constexpr int kSecondDataIndex = 2;
constexpr int kThirdDataIndex = 3;
// Position of the primitive / callee value node inside CNode::inputs().
constexpr int kPrimIndex = 0;
}  // namespace
60
61 namespace mindspore::lite {
62 namespace {
// Input position of the item index of a TupleGetItem CNode.
// NOTE(review): assumes layout {prim, tuple, index} -- confirm at the use site.
constexpr int kIndexOfValueInputOfGetTupleItem{2};
// Recursion-depth bound checked by GetFinalGraph before it gives up with an error.
constexpr int kMaxDepth{2048};
65
GetOrderedCNodes(const FuncGraphPtr fg)66 std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
67 MS_CHECK_TRUE_MSG(fg != nullptr, {}, "fg is nullptr.");
68 auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
69 auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
70 std::vector<AnfNodePtr> vecs{};
71 if (node == nullptr) {
72 return vecs;
73 }
74 if (node->isa<mindspore::CNode>()) {
75 auto cnode = node->cast<CNodePtr>();
76 MS_ASSERT(cnode != nullptr);
77 auto &inputs = cnode->inputs();
78 // Check if free variables used.
79 for (const auto &input : inputs) {
80 auto input_fg = GetValueNode<FuncGraphPtr>(input);
81 if (input_fg) {
82 for (auto &fv : input_fg->free_variables_nodes()) {
83 if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
84 vecs.push_back(fv);
85 }
86 }
87 }
88 }
89 (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
90 }
91 return vecs;
92 };
93
94 std::list<CNodePtr> cnodes{};
95 auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
96 for (const auto &node : nodes) {
97 auto cnode = dyn_cast<mindspore::CNode>(node);
98 if (cnode) {
99 cnodes.push_back(cnode);
100 }
101 }
102 return cnodes;
103 }
104
CreateTensorFromDataInfo(const lite::DataInfo & data_info,const std::string & name,const bool has_default)105 std::unique_ptr<schema::TensorT> CreateTensorFromDataInfo(const lite::DataInfo &data_info, const std::string &name,
106 const bool has_default) {
107 auto schema_tensor = std::make_unique<schema::TensorT>();
108 MS_CHECK_TRUE_MSG(schema_tensor != nullptr, nullptr, "schema_tensor is nullptr");
109 schema_tensor->format = static_cast<schema::Format>(data_info.format_);
110 schema_tensor->name = name;
111 schema_tensor->dims = data_info.shape_;
112 schema_tensor->dataType = data_info.data_type_;
113 schema_tensor->data = data_info.data_;
114 if (has_default) {
115 schema_tensor->nodeType = NodeType_ValueNode;
116 } else {
117 schema_tensor->nodeType = NodeType_CNode;
118 }
119 schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
120 schema_tensor->weightQuantCompressType =
121 static_cast<mindspore::schema::WeightQuantCompressType>(data_info.compress_type_);
122 return schema_tensor;
123 }
124 } // namespace
125
ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> & meta_graph,const std::shared_ptr<mindspore::Primitive> & primitive,const std::unique_ptr<schema::CNodeT> & dst_node)126 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
127 const std::shared_ptr<mindspore::Primitive> &primitive,
128 const std::unique_ptr<schema::CNodeT> &dst_node) {
129 MS_ASSERT(meta_graph != nullptr);
130 MS_ASSERT(primitive != nullptr);
131 MS_ASSERT(dst_node != nullptr);
132 // add quant param
133 MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
134 // activation
135 QuantParamsVector input_quant_params;
136 QuantParamsVector output_quant_params;
137 dst_node->quantType = schema::QuantType_QUANT_NONE;
138 auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
139 if (quant_tensor_info_ptr == nullptr) {
140 return RET_OK;
141 }
142 auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
143 CHECK_NULL_RETURN(quant_param_holder);
144 input_quant_params = quant_param_holder->get_input_quant_params();
145 output_quant_params = quant_param_holder->get_output_quant_params();
146 dst_node->quantType = static_cast<schema::QuantType>(static_cast<int>(quant_param_holder->quant_type()));
147
148 // convert input quant param
149 for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
150 if (i >= input_quant_params.size()) {
151 MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size() << " input, but only has "
152 << input_quant_params.size() << " quant params";
153 break;
154 }
155 auto activate_index = dst_node->inputIndex[i];
156 MS_CHECK_TRUE_MSG(GetAllTensorSize(meta_graph) > activate_index, RET_ERROR, "allTensors size is wrong.");
157 auto tensor_input = GetTensorFromAllTensor(meta_graph, activate_index);
158 CHECK_NULL_RETURN(tensor_input);
159
160 tensor_input->quantClusters = quant_param_holder->GetQuantClusters(i);
161
162 if (!TensorQuantParamsInited(*tensor_input)) {
163 tensor_input->quantParams.clear();
164 for (auto input_quant_param : input_quant_params[i]) {
165 auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
166 MS_CHECK_TRUE_MSG(input_quant_param_ptr != nullptr, RET_ERROR, "input_quant_param_ptr is nullptr");
167 MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
168 << " zp: " << input_quant_param_ptr->zeroPoint;
169 tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
170 }
171 }
172 }
173
174 // output_quant_params
175 for (size_t index = 0; index < dst_node->outputIndex.size(); ++index) {
176 if (index >= output_quant_params.size()) {
177 MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->outputIndex.size() << " output, but only has"
178 << output_quant_params.size() << " quant params";
179 break;
180 }
181 auto output_tensor = GetTensorFromAllTensor(meta_graph, dst_node->outputIndex[index]);
182 auto &output_quant_param = output_quant_params[index];
183 for (const auto &channel_quant_param : output_quant_param) {
184 if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_QUANT_WEIGHT) {
185 std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
186 std::make_unique<schema::QuantParamT>(channel_quant_param);
187 CHECK_NULL_RETURN(output_quant_param_ptr);
188 MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
189 << " zp: " << output_quant_param_ptr->zeroPoint;
190 output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
191 }
192 }
193 }
194
195 return RET_OK;
196 }
197
ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> & meta_graph,const CNodePtr & cnode,const std::shared_ptr<mindspore::Primitive> & primitive,const std::unique_ptr<schema::CNodeT> & dst_node)198 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, const CNodePtr &cnode,
199 const std::shared_ptr<mindspore::Primitive> &primitive,
200 const std::unique_ptr<schema::CNodeT> &dst_node) {
201 CHECK_NULL_RETURN(meta_graph);
202 CHECK_NULL_RETURN(dst_node);
203 CHECK_NULL_RETURN(cnode);
204 // quant_type not exist in cnode, return
205 auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
206 if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
207 if (quant_type_attr != nullptr) {
208 dst_node->quantType = static_cast<schema::QuantType>(GetValue<int32_t>(quant_type_attr));
209 } else {
210 MS_LOG(DEBUG) << "quant_type not exist in cnode, node name: " << dst_node->name;
211 return RET_OK;
212 }
213 } else {
214 dst_node->quantType = schema::QuantType_QUANT_NONE;
215 }
216
217 // convert input quant param
218 for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
219 auto activate_index = dst_node->inputIndex[i];
220 MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > activate_index, RET_ERROR, "allTensors size is wrong.");
221 auto tensor_input = meta_graph->allTensors[activate_index].get();
222 auto input_node = cnode->input(i + quant::kPrimOffset);
223 auto status = SetInputQuantParamToTensorT(primitive, input_node, tensor_input);
224 if (status != RET_NO_CHANGE && status != RET_OK) {
225 MS_LOG(ERROR) << "[input][" << i << "] node: " << dst_node->name << " SetInputQuantParamToTensorT failed.";
226 return status;
227 }
228 }
229
230 // output_quant_params
231 for (size_t i = 0; i < dst_node->outputIndex.size(); ++i) {
232 auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[i]].get();
233 auto quantization_param_value = primitive->GetAttr(quant::kQuantParam);
234 if (quantization_param_value == nullptr) {
235 MS_LOG(INFO) << "[output]node: " << dst_node->name << " output quant param Not exist.";
236 continue;
237 }
238 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
239 if (quantization_param_list.empty()) {
240 MS_LOG(INFO) << "[output]node: " << dst_node->name << " output quant param Not exist.";
241 continue;
242 }
243 if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_QUANT_WEIGHT) {
244 // Set QuantParamT into meta_graph tensor
245 // Not support cnode with multi-output
246 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
247 for (auto quant_param : quant_params) {
248 auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
249 MS_LOG(DEBUG) << "node: " << output_tensor->name << " scale: " << quant_param_ptr->scale
250 << " zp: " << quant_param_ptr->zeroPoint;
251 CHECK_NULL_RETURN(quant_param_ptr);
252 output_tensor->quantParams.emplace_back(std::move(quant_param_ptr));
253 }
254 }
255 }
256 return RET_OK;
257 }
258
SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> & primitive,const AnfNodePtr & input_node,mindspore::schema::TensorT * tensor_input)259 int AnfExporter::SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> &primitive,
260 const AnfNodePtr &input_node, mindspore::schema::TensorT *tensor_input) {
261 CHECK_NULL_RETURN(primitive);
262 CHECK_NULL_RETURN(input_node);
263 CHECK_NULL_RETURN(tensor_input);
264 if (IsGraphInput(input_node)) {
265 if (!primitive->HasAttr(quant::kGraphInputQuantParam)) {
266 return RET_NO_CHANGE;
267 }
268 if (TensorQuantParamsInited(*tensor_input)) {
269 MS_LOG(DEBUG) << input_node->fullname_with_scope() << " TensorT quant param exist.";
270 return RET_NO_CHANGE;
271 }
272 tensor_input->quantParams.clear();
273 auto quantization_param_value = primitive->GetAttr(quant::kGraphInputQuantParam);
274 auto quantization_param_ptr = quantization_param_value->cast<QuantizationParamPtr>();
275 CHECK_NULL_RETURN(quantization_param_ptr);
276 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_ptr);
277 for (auto quant_param : quant_params) {
278 auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
279 MS_LOG(DEBUG) << "node: " << input_node->fullname_with_scope() << " scale: " << quant_param_ptr->scale
280 << " zp: " << quant_param_ptr->zeroPoint;
281 tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
282 }
283 } else if (input_node->isa<mindspore::CNode>()) {
284 // input node has single output
285 auto input_cnode = input_node->cast<mindspore::CNodePtr>();
286 auto input_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
287 MS_CHECK_TRUE_MSG(input_primitive != nullptr, RET_ERROR, "Input node primitive nullptr.");
288 if (!input_primitive->HasAttr(quant::kQuantParam)) {
289 return RET_NO_CHANGE;
290 }
291 if (TensorQuantParamsInited(*tensor_input)) {
292 MS_LOG(DEBUG) << input_node->fullname_with_scope() << " TensorT quant param exist.";
293 return RET_NO_CHANGE;
294 }
295 tensor_input->quantParams.clear();
296 auto quantization_param_value = input_primitive->GetAttr(quant::kQuantParam);
297 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
298 if (quantization_param_list.empty()) {
299 MS_LOG(DEBUG) << input_node->fullname_with_scope() << " quantization param is empty.";
300 return RET_NO_CHANGE;
301 }
302 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
303 for (auto quant_param : quant_params) {
304 auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
305 MS_LOG(DEBUG) << "node: " << input_node->fullname_with_scope() << " scale: " << quant_param_ptr->scale
306 << " zp: " << quant_param_ptr->zeroPoint;
307 tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
308 }
309 } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
310 tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
311 MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NO_CHANGE);
312 auto quantization_params = input_tensor->quant_params();
313 if (quantization_params.empty()) {
314 MS_LOG(DEBUG) << input_node->fullname_with_scope() << " quantization param is empty.";
315 return RET_NO_CHANGE;
316 }
317 auto quantization_param = quantization_params.front();
318 auto cluster_centroid_list_attr = quantization_param->GetAttr(quant::kClusterCentroidList);
319 if (cluster_centroid_list_attr != nullptr) {
320 tensor_input->quantClusters = GetValue<std::vector<float>>(cluster_centroid_list_attr);
321 return RET_OK;
322 }
323 if (!TensorQuantParamsInited(*tensor_input)) {
324 tensor_input->quantParams.clear();
325 // Set QuantParamT into meta_graph tensor
326 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param);
327 for (auto quant_param : quant_params) {
328 auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
329 MS_LOG(DEBUG) << "node: " << tensor_input->name << " scale: " << quant_param_ptr->scale
330 << " zp: " << quant_param_ptr->zeroPoint;
331 CHECK_NULL_RETURN(quant_param_ptr);
332 tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
333 }
334 }
335 } else {
336 MS_LOG(WARNING) << input_node->fullname_with_scope() << " : " << input_node->type_name() << " not supported.";
337 }
338 return RET_OK;
339 }
340
CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const AnfNodePtr & input,size_t * tensor_index_ptr)341 int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
342 const AnfNodePtr &input, size_t *tensor_index_ptr) {
343 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
344 MS_CHECK_TRUE_MSG(input != nullptr, RET_NULL_PTR, "input is nullptr");
345 MS_CHECK_TRUE_MSG(tensor_index_ptr != nullptr, RET_NULL_PTR, "tensor_index_ptr is nullptr");
346 lite::DataInfo data_info;
347 auto param_node = input->cast<ParameterPtr>();
348 MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
349 if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
350 MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
351 return RET_ERROR;
352 }
353 auto schema_tensor = CreateTensorFromDataInfo(data_info, param_node->name(), param_node->has_default());
354 auto key = std::make_pair(input, 0);
355 *tensor_index_ptr = NewFbTensor(meta_graphT, schema_tensor.release());
356 SetNodeId(key, *tensor_index_ptr);
357 return RET_OK;
358 }
359
SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const size_t & subgraph_index)360 int AnfExporter::SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
361 const size_t &subgraph_index) {
362 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
363 auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
364 FuncGraphPtr fg = nullptr;
365 std::for_each(fg_subgraph_map_.begin(), fg_subgraph_map_.end(),
366 [&subgraph_index, &fg](const std::pair<const FuncGraphPtr, size_t> &it) {
367 if (it.second == subgraph_index) {
368 fg = it.first;
369 }
370 });
371
372 auto inputs = fg->get_inputs();
373 for (auto &input : inputs) {
374 auto key = std::make_pair(input, 0);
375 size_t tensor_index;
376 if (HasNodeIdKey(key)) {
377 subgraph->inputIndices.emplace_back(GetNodeId(key));
378 } else {
379 if (CreateNewTensorForParameter(meta_graphT, input, &tensor_index) != RET_OK) {
380 MS_LOG(ERROR) << "CreateNewTensorForParameter failed.";
381 return RET_ERROR;
382 }
383 subgraph->inputIndices.emplace_back(tensor_index);
384 }
385 }
386 return RET_OK;
387 }
388
SetSubGraphOutputIndex(const CNodePtr & cnode,const size_t subgraph_index,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * return_node)389 int AnfExporter::SetSubGraphOutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
390 const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
391 schema::CNodeT *return_node) {
392 MS_ASSERT(meta_graphT != nullptr);
393 MS_ASSERT(return_node != nullptr);
394 for (size_t i = kFirstDataIndex; i < cnode->size(); i++) {
395 auto input_node = cnode->input(i);
396 if (input_node == nullptr) {
397 MS_LOG(ERROR) << "output node is nullptr";
398 return RET_NULL_PTR;
399 } else if (input_node->isa<mindspore::CNode>()) {
400 auto ret = ConvertInputCNode(input_node, return_node);
401 if (ret != RET_OK) {
402 MS_LOG(ERROR) << "obtain outputs failed";
403 return ret;
404 }
405 } else if (input_node->isa<Parameter>()) {
406 auto key = std::make_pair(input_node, 0);
407 size_t tensor_index;
408 if (HasNodeIdKey(key)) {
409 return_node->inputIndex.emplace_back(GetNodeId(key));
410 } else {
411 if (CreateNewTensorForParameter(meta_graphT, input_node, &tensor_index) != RET_OK) {
412 MS_LOG(ERROR) << "CreateNewTensorForParameter failed.";
413 return RET_ERROR;
414 }
415 return_node->inputIndex.emplace_back(tensor_index);
416 }
417 if (IsContain(graph_inputs_, input_node->cast<AnfNodePtr>()) &&
418 graph_inputs_map_.find(input_node) == graph_inputs_map_.end()) {
419 graph_inputs_map_[input_node] = tensor_index;
420 }
421 } else {
422 MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
423 return RET_ERROR;
424 }
425 }
426 for (unsigned int &i : return_node->inputIndex) {
427 meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
428 }
429 return RET_OK;
430 }
431
HasExported(const FuncGraphPtr & func_graph)432 bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) {
433 if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) {
434 return true;
435 }
436 return false;
437 }
438
ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const bool & keep_graph,const bool & copy_primitive,const CNodePtr & partial_cnode,const std::unique_ptr<schema::CNodeT> & schema_cnode)439 int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const bool &keep_graph,
440 const bool ©_primitive, const CNodePtr &partial_cnode,
441 const std::unique_ptr<schema::CNodeT> &schema_cnode) {
442 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
443 MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
444 MS_CHECK_TRUE_MSG(schema_cnode != nullptr, RET_NULL_PTR, "schema_cnode is nullptr");
445 auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(partial_cnode->input(0));
446 MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "GetValueNode failed");
447 if (prim->name() != mindspore::ops::kNamePartialFusion) {
448 MS_LOG(INFO) << "not is partial";
449 return RET_OK;
450 }
451
452 auto partial_fusion_primc = schema_cnode->primitive->value.AsPartialFusion();
453 auto vnode = partial_cnode->input(kFirstDataIndex)->cast<ValueNodePtr>();
454 MS_CHECK_TRUE_MSG(partial_fusion_primc != nullptr, RET_NULL_PTR, "partial_fusion_primc is invalid");
455 MS_CHECK_TRUE_MSG(vnode != nullptr, RET_NULL_PTR, "vnode is invalid");
456 auto fg = vnode->value()->cast<FuncGraphPtr>();
457 MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "func graph is nullptr.");
458 if (fg_subgraph_map_.find(fg) != fg_subgraph_map_.end()) {
459 partial_fusion_primc->sub_graph_index = static_cast<int>(fg_subgraph_map_.at(fg));
460 return RET_OK;
461 }
462
463 partial_fusion_primc->sub_graph_index = static_cast<int>(meta_graphT->subGraph.size());
464 auto ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, partial_cnode);
465 if (ret != RET_OK) {
466 MS_LOG(ERROR) << "ExportSubgraph failed";
467 return ret;
468 }
469 return RET_OK;
470 }
471
InsertCallNode(const FuncGraphPtr & func_graph)472 std::list<CNodePtr> AnfExporter::InsertCallNode(const FuncGraphPtr &func_graph) {
473 MS_CHECK_TRUE_MSG(func_graph != nullptr, {}, "func_graph is nullptr");
474 auto cnodes = GetOrderedCNodes(func_graph);
475 for (auto it = cnodes.begin(); it != cnodes.end();) {
476 auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>((*it)->input(kPrimIndex));
477 if (prim == nullptr) {
478 auto fg = GetValueNode<FuncGraphPtr>((*it)->input(kPrimIndex));
479 if (fg != nullptr) {
480 auto partial_cnode = CreatePartialCnode(fg, (*it));
481 auto call_cnode = CreateCallCnode(fg, partial_cnode);
482 ++it;
483 it = cnodes.insert(it, call_cnode);
484 continue;
485 } else {
486 auto call_anf_prim_vnode = GetCallAnfPrim();
487 auto cnode_input = (*it)->inputs();
488 cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
489 (*it)->set_inputs(cnode_input);
490 }
491 }
492 ++it;
493 }
494 return cnodes;
495 }
496
SetNonTailCall(const CNodePtr & cnode,schema::CNodeT * node)497 void AnfExporter::SetNonTailCall(const CNodePtr &cnode, schema::CNodeT *node) {
498 if (cnode == nullptr || node == nullptr) {
499 MS_LOG(ERROR) << "conde or node is nullptr";
500 return;
501 }
502 if (!opt::CheckPrimitiveType(cnode, prim::kPrimCall)) {
503 return;
504 }
505 node->primitive->value.AsCall()->is_tail_call = false;
506 call_node_map_[cnode] = node;
507 return;
508 }
509
SetTailCallForReturn(const CNodePtr & return_cnode)510 int AnfExporter::SetTailCallForReturn(const CNodePtr &return_cnode) {
511 MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_NULL_PTR, "return_cnode is nullptr");
512 auto return_cnode_input_size = return_cnode->size();
513 for (size_t i = 1; i < return_cnode_input_size; ++i) {
514 if (!utils::isa<CNodePtr>(return_cnode->input(i))) {
515 continue;
516 }
517 if (!opt::CheckPrimitiveType(return_cnode->input(i), prim::kPrimCall)) {
518 continue;
519 }
520 auto call_cnode = return_cnode->input(i)->cast<CNodePtr>();
521 if (call_node_map_.find(call_cnode) == call_node_map_.end()) {
522 MS_LOG(ERROR) << "Not found call cnode in call_node_map.";
523 return RET_ERROR;
524 }
525 call_node_map_[call_cnode]->primitive->value.AsCall()->is_tail_call = true;
526 }
527 return RET_OK;
528 }
529
SetTailCallForNonOutput()530 int AnfExporter::SetTailCallForNonOutput() {
531 for (auto item : call_node_map_) {
532 auto call_cnode = item.first;
533 auto mg = call_cnode->func_graph()->manager();
534 if (mg == nullptr) {
535 MS_LOG(ERROR) << "manager is nullptr.";
536 return RET_NULL_PTR;
537 }
538 auto node_user = mg->node_users()[call_cnode];
539 if (node_user.empty()) {
540 (item.second)->primitive->value.AsCall()->is_tail_call = true;
541 }
542 }
543 return RET_OK;
544 }
545
GetNodeId(const std::pair<AnfNodePtr,size_t> & key)546 size_t AnfExporter::GetNodeId(const std::pair<AnfNodePtr, size_t> &key) {
547 node_id_map_mutex_.lock();
548 auto node_tensor_index = node_id_map_[key];
549 node_id_map_mutex_.unlock();
550 return node_tensor_index;
551 }
552
SetNodeId(const std::pair<AnfNodePtr,size_t> & key,size_t value)553 void AnfExporter::SetNodeId(const std::pair<AnfNodePtr, size_t> &key, size_t value) {
554 node_id_map_mutex_.lock();
555 node_id_map_[key] = value;
556 node_id_map_mutex_.unlock();
557 }
558
HasNodeIdKey(const std::pair<AnfNodePtr,size_t> & key)559 bool AnfExporter::HasNodeIdKey(const std::pair<AnfNodePtr, size_t> &key) {
560 node_id_map_mutex_.lock();
561 auto has_key = node_id_map_.find(key) != node_id_map_.end();
562 node_id_map_mutex_.unlock();
563 return has_key;
564 }
565
NewFbTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,mindspore::schema::TensorT * tensor)566 size_t AnfExporter::NewFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
567 mindspore::schema::TensorT *tensor) {
568 fb_graph_all_tensors_mutex_.lock();
569 auto insert_index = meta_graphT->allTensors.size();
570 meta_graphT->allTensors.emplace_back(std::move(tensor));
571 fb_graph_all_tensors_mutex_.unlock();
572 return insert_index;
573 }
574
InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,mindspore::schema::TensorT * tensor)575 void AnfExporter::InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
576 mindspore::schema::TensorT *tensor) {
577 fb_graph_all_tensors_mutex_.lock();
578 meta_graphT->allTensors.emplace_back(std::move(tensor));
579 fb_graph_all_tensors_mutex_.unlock();
580 }
581
GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> & meta_graphT)582 size_t AnfExporter::GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
583 fb_graph_all_tensors_mutex_.lock();
584 auto size = meta_graphT->allTensors.size();
585 fb_graph_all_tensors_mutex_.unlock();
586 return size;
587 }
588
GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,size_t index)589 mindspore::schema::TensorT *AnfExporter::GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
590 size_t index) {
591 fb_graph_all_tensors_mutex_.lock();
592 auto *tensor = meta_graphT->allTensors[index].get();
593 fb_graph_all_tensors_mutex_.unlock();
594 return tensor;
595 }
596
CaseToContinue(const string & prim_name)597 bool AnfExporter::CaseToContinue(const string &prim_name) {
598 return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::ops::kNameTupleGetItem ||
599 prim_name == mindspore::ops::kNameMakeTuple || prim_name == mindspore::ops::kNameMakeTupleV2;
600 }
601
602 struct Anf2FbItem {
603 public:
Anf2FbItemmindspore::lite::Anf2FbItem604 Anf2FbItem(const std::shared_ptr<mindspore::Primitive> &prim, CNodePtr cnode) : prim_(prim), cnode_(cnode) {
605 dst_node_ = nullptr;
606 }
607
608 std::shared_ptr<mindspore::Primitive> prim_;
609 CNodePtr cnode_;
610 schema::CNodeT *dst_node_;
611 };
612
Anf2Fb(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const size_t & subgraph_index,const bool & keep_graph,const bool & copy_primitive)613 int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
614 const size_t &subgraph_index, const bool &keep_graph, const bool ©_primitive) {
615 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
616 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
617 int ret = RET_OK;
618 auto cnodes = InsertCallNode(func_graph);
619 std::list<Anf2FbItem> convert_items;
620
621 // Do Modify FuncGraph in here and save convert item for next step
622 for (const auto &cnode : cnodes) {
623 auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
624 if (prim == nullptr) {
625 MS_LOG(ERROR) << "get prim from value node failed.";
626 return RET_ERROR;
627 }
628 ret = RemoveIfDepend(cnode);
629 if (ret != RET_OK) {
630 MS_LOG(ERROR) << "RemoveIfDepend failed";
631 return ret;
632 }
633 if (CaseToContinue(prim->name())) {
634 continue;
635 }
636 ret = RemoveIfMakeTuple(cnode);
637 if (ret != RET_OK) {
638 MS_LOG(ERROR) << "RemoveIfMakeTuple failed";
639 return ret;
640 }
641 auto node = std::make_unique<schema::CNodeT>();
642 if (node == nullptr) {
643 MS_LOG(ERROR) << "object failed to be constructed";
644 return RET_MEMORY_FAILED;
645 }
646
647 Anf2FbItem convert_item(prim, cnode);
648 convert_item.dst_node_ = node.release();
649 convert_items.push_back(convert_item);
650 }
651
652 // convert CNode into NodeT
653 for (const auto &item : convert_items) {
654 auto prim = item.prim_;
655 auto cnode = item.cnode_;
656
657 std::unique_ptr<schema::CNodeT> node(item.dst_node_);
658 std::unique_ptr<schema::PrimitiveT> primT;
659
660 if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
661 node->name = mindspore::ops::kNameReturn;
662 ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get());
663 if (ret != RET_OK) {
664 MS_LOG(ERROR) << "SetOpOutputN failed";
665 break;
666 }
667 ret = SetTailCallForReturn(cnode);
668 if (ret != RET_OK) {
669 MS_LOG(ERROR) << "SetTailCallForReturn failed";
670 break;
671 }
672 continue;
673 }
674 primT = GetPrimitiveT(cnode->input(kPrimIndex));
675 node->name = cnode->fullname_with_scope();
676 node->primitive = std::move(primT);
677 auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
678 node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
679
680 ret = SetOpOutputNode(cnode, meta_graphT, node.get());
681 if (ret != RET_OK) {
682 MS_LOG(ERROR) << "SetOpOutputNode failed";
683 break;
684 }
685
686 ret = SetOpInputNode(cnode, meta_graphT, node.get());
687 if (ret != RET_OK) {
688 MS_LOG(ERROR) << "SetOpInputNode failed";
689 break;
690 }
691 // set all call op to non tail call
692 if (opt::CheckPrimitiveType(cnode, prim::kPrimCall)) {
693 node->primitive->value.AsCall()->is_tail_call = false;
694 call_node_map_[cnode] = node.get();
695 }
696
697 ret = ExportPartialNode(meta_graphT, keep_graph, copy_primitive, cnode, node);
698 if (ret != RET_OK) {
699 MS_LOG(ERROR) << "ExportPartialNode failed.";
700 break;
701 }
702
703 ret = ConvertQuantParam(meta_graphT, prim, node);
704 if (ret != RET_OK) {
705 MS_LOG(ERROR) << "ConvertQuantParam failed";
706 break;
707 }
708
709 ret = ConvertQuantParam(meta_graphT, cnode, prim, node);
710 if (ret != RET_OK) {
711 MS_LOG(ERROR) << "New ConvertQuantParam failed";
712 break;
713 }
714
715 fb_graph_node_mutex_.lock();
716 meta_graphT->nodes.push_back(std::move(node));
717 meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
718 fb_graph_node_mutex_.unlock();
719 }
720 return ret;
721 }
722
ExportSubgraph(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,bool keep_graph,bool copy_primitive,const std::shared_ptr<AnfNode> & partial_anode)723 int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
724 bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode) {
725 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
726 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
727 if (HasExported(func_graph)) {
728 MS_LOG(INFO) << "Has been exported.";
729 return RET_OK;
730 }
731
732 auto subgraph_ptr = std::make_unique<schema::SubGraphT>();
733 CHECK_NULL_RETURN(subgraph_ptr);
734 meta_graphT->subGraph.emplace_back(std::move(subgraph_ptr));
735 auto subgraph_index = meta_graphT->subGraph.size() - 1;
736 fg_subgraph_map_[func_graph] = subgraph_index;
737 auto subgraph_name = func_graph->get_attr("graph_name");
738 MS_CHECK_TRUE_MSG(subgraph_name != nullptr, RET_ERROR, "subgraph_name is nullptr");
739 meta_graphT->subGraph.back()->name =
740 "subgraph_" + std::to_string(meta_graphT->subGraph.size() - 1) + "_" + GetValue<std::string>(subgraph_name);
741
742 auto ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
743 if (ret != RET_OK) {
744 MS_LOG(ERROR) << "Anf2Fb failed";
745 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
746 return ret;
747 }
748
749 ret = SetSubGraphInputIndex(meta_graphT, subgraph_index);
750 if (ret != RET_OK) {
751 MS_LOG(ERROR) << "SetSubGraphInputIndex failed";
752 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
753 return ret;
754 }
755
756 SetSubgraphTensorIndices(meta_graphT.get());
757
758 return RET_OK;
759 }
760
// Follows the chain of Call nodes at the end of func_graph and returns the graph whose output is the final output.
// - If func_graph's output is not a Call node, func_graph itself is returned.
// - Call on a Switch: the false branch's partial graph (Switch input kThirdDataIndex) is followed.
// - Call on a SwitchLayer: the first partial graph (SwitchLayer input kSecondDataIndex) is followed.
// - Otherwise the graph held in the callee's first data input is followed.
// `i` is the current recursion depth. nullptr is returned when the depth exceeds kMaxDepth or when the graph
// structure does not match the expected shapes (instead of dereferencing a null node).
FuncGraphPtr GetFinalGraph(const FuncGraphPtr &func_graph, int i) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  if (i > kMaxDepth) {
    MS_LOG(ERROR) << "exceed max depth " << kMaxDepth << ", i " << i;
    return nullptr;
  }
  i++;
  auto fg_output = func_graph->output();
  if (!opt::CheckPrimitiveType(fg_output, prim::kPrimCall)) {
    return func_graph;
  }
  auto call_cnode = fg_output->cast<CNodePtr>();
  MS_CHECK_TRUE_MSG(call_cnode != nullptr, nullptr, "cast call node failed");

  // If the call's callee is a switch, the meta graph output is the output of the switch's false partial graph.
  auto cnode = call_cnode->input(kFirstDataIndex)->cast<CNodePtr>();
  // The callee is not guaranteed to be a CNode; every branch below dereferences it, so reject it here.
  MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "callee of call node is not a cnode");
  if (IsSwitch(cnode)) {
    auto false_cnode = cnode->input(kThirdDataIndex)->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(false_cnode != nullptr, nullptr, "cast failed");
    auto false_fg = GetValueNode<FuncGraphPtr>(false_cnode->input(kFirstDataIndex));
    MS_CHECK_TRUE_MSG(false_fg != nullptr, nullptr, "GetValueNode failed");
    return GetFinalGraph(false_fg, i);
  }
  if (IsSwitchLayer(cnode)) {
    auto first_partial_cnode = cnode->input(kSecondDataIndex)->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(first_partial_cnode != nullptr, nullptr, "cast failed");
    auto next_fg = GetValueNode<FuncGraphPtr>(first_partial_cnode->input(kFirstDataIndex));
    MS_CHECK_TRUE_MSG(next_fg != nullptr, nullptr, "GetValueNode failed");
    return GetFinalGraph(next_fg, i);
  }
  auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kFirstDataIndex));
  MS_CHECK_TRUE_MSG(fg != nullptr, nullptr, "GetValueNode failed");
  return GetFinalGraph(fg, i);
}
797
SetMetaGraphInput(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT)798 int AnfExporter::SetMetaGraphInput(const FuncGraphPtr &func_graph,
799 const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
800 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
801 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
802 MS_ASSERT(func_graph != nullptr);
803 meta_graphT->inputIndex.clear();
804 for (const auto &input : func_graph->get_inputs()) {
805 auto iter = graph_inputs_map_.find(input);
806 if (iter == graph_inputs_map_.end()) {
807 MS_LOG(ERROR) << "input " << input->ToString() << " not found in graph" << std::endl;
808 return RET_ERROR;
809 }
810 meta_graphT->inputIndex.emplace_back(iter->second);
811 }
812 return RET_OK;
813 }
814
SetMetaGraphOutput(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT)815 int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph,
816 const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
817 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
818 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
819 FuncGraphPtr final_fg = nullptr;
820 if (meta_graphT->fmkType == static_cast<int32_t>(converter::kFmkTypeMs)) {
821 final_fg = func_graph;
822 } else {
823 int i = 0;
824 final_fg = GetFinalGraph(func_graph, i);
825 }
826 MS_CHECK_TRUE_MSG(final_fg != nullptr, RET_ERROR, "GetFinalGraph failed.");
827 auto final_meta_graph_index = fg_subgraph_map_.at(final_fg);
828 auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index);
829 meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end());
830
831 for (auto &output_index : meta_graphT->outputIndex) {
832 auto tensor = GetTensorFromAllTensor(meta_graphT, output_index);
833 if (tensor == nullptr) {
834 MS_LOG(ERROR) << "Set meta graph output failed: output tensor is null.";
835 return RET_ERROR;
836 }
837 ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
838 }
839
840 return RET_OK;
841 }
842
Export(const FuncGraphPtr & func_graph,bool keep_graph,bool copy_primitive,bool train_flag)843 schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
844 bool train_flag) {
845 MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
846 this->train_flag_ = train_flag;
847 // hardcode for nnie and train
848 this->graph_inputs_map_.clear();
849 auto meta_graphT = std::make_unique<schema::MetaGraphT>();
850 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, nullptr, "meta_graphT is nullptr");
851 auto fmk = func_graph->get_attr("fmk");
852 MS_CHECK_TRUE_MSG(fmk != nullptr, nullptr, "fmk is nullptr");
853 if (fmk->isa<Int64Imm>()) {
854 meta_graphT->fmkType = GetValue<int64_t>(fmk);
855 } else {
856 meta_graphT->fmkType = GetValue<int>(fmk);
857 }
858
859 graph_inputs_ = func_graph->get_inputs();
860
861 int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive);
862 if (ret != RET_OK) {
863 MS_LOG(ERROR) << "Export subgraph failed.";
864 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
865 return nullptr;
866 }
867
868 ret = SetTailCallForNonOutput();
869 if (ret != RET_OK) {
870 MS_LOG(ERROR) << "SetTailCallForNonOutput failed.";
871 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
872 return nullptr;
873 }
874
875 ret = SetMetaGraphInput(func_graph, meta_graphT);
876 if (ret != RET_OK) {
877 MS_LOG(ERROR) << "SetMetaGraphInput failed.";
878 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
879 return nullptr;
880 }
881 ret = SetMetaGraphOutput(func_graph, meta_graphT);
882 if (ret != RET_OK) {
883 MS_LOG(ERROR) << "SetMetaGraphOutput failed.";
884 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
885 return nullptr;
886 }
887
888 return meta_graphT.release();
889 }
890
ConvertInputCNodeCommonOp(const AnfNodePtr & input_anode,schema::CNodeT * output_cnode)891 int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
892 MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
893 if (this->train_flag_) {
894 auto key = std::make_pair(input_anode, 0);
895 if (HasNodeIdKey(key)) {
896 output_cnode->inputIndex.emplace_back(GetNodeId(key));
897 }
898 return RET_OK;
899 }
900 if (utils::isa<abstract::AbstractTuple>(input_anode->abstract())) {
901 auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input_anode->abstract());
902 MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
903 auto elements = tuple->elements();
904 for (size_t i = 0; i < elements.size(); i++) {
905 auto key = std::make_pair(input_anode, i);
906 if (HasNodeIdKey(key)) {
907 output_cnode->inputIndex.emplace_back(GetNodeId(key));
908 }
909 }
910 } else {
911 auto key = std::make_pair(input_anode, 0);
912 if (HasNodeIdKey(key)) {
913 output_cnode->inputIndex.emplace_back(GetNodeId(key));
914 }
915 }
916 return RET_OK;
917 }
918
ConvertInputCNode(const std::shared_ptr<AnfNode> & input_anode,schema::CNodeT * output_cnode)919 int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
920 auto input_cnode = utils::cast<CNodePtr>(input_anode);
921 MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "cast ptr failed");
922 auto input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
923 if (input_value_node == nullptr) {
924 if (!IsCall(input_cnode)) {
925 MS_LOG(ERROR) << "value node is invalid.";
926 return RET_ERROR;
927 } else {
928 auto call_anf_prim_vnode = GetCallAnfPrim();
929 auto cnode_input = input_cnode->inputs();
930 MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, RET_ERROR, "GetCallAnfPrim failed");
931 cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
932 input_cnode->set_inputs(cnode_input);
933 }
934 }
935
936 input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
937
938 if (input_value_node->value() == nullptr || !opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
939 return ConvertInputCNodeCommonOp(input_anode, output_cnode);
940 } else {
941 auto inputs = input_cnode->inputs();
942
943 if (inputs.size() != 3) {
944 MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << inputs.size();
945 return RET_ERROR;
946 }
947 auto get_item_input_cnode = inputs.at(1);
948 auto index_vnode = inputs.at(kIndexOfValueInputOfGetTupleItem);
949 if (!utils::isa<ValueNode>(index_vnode)) {
950 MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
951 return RET_ERROR;
952 }
953 auto value_node = utils::cast<ValueNodePtr>(index_vnode);
954 MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "cast to ValueNode failed");
955 auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
956 : GetValue<int>(value_node->value());
957 auto key = std::make_pair(get_item_input_cnode, idx);
958 if (!HasNodeIdKey(key)) {
959 key = std::make_pair(get_item_input_cnode, 0); // try name with 0
960 if (!HasNodeIdKey(key)) {
961 MS_LOG(ERROR) << "Can not find get_item output tensor "
962 << get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(idx);
963 return RET_ERROR;
964 }
965 }
966 output_cnode->inputIndex.emplace_back(GetNodeId(key));
967 }
968 return RET_OK;
969 }
970
ConvertInputParameter(const CNodePtr & cnode,size_t index,const PrimitivePtr & primitive,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * op_node,size_t * tensor_index_ptr)971 int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
972 const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node,
973 size_t *tensor_index_ptr) {
974 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr");
975 MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr");
976 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
977 MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr");
978 MS_CHECK_TRUE_MSG(tensor_index_ptr != nullptr, RET_NULL_PTR, "tensor_index_ptr is nullptr");
979 auto param_node = cnode->input(index)->cast<ParameterPtr>();
980 MS_ASSERT(param_node != nullptr);
981 auto key = std::make_pair(param_node, 0);
982 if (HasNodeIdKey(key)) {
983 op_node->inputIndex.emplace_back(GetNodeId(key));
984 return RET_OK;
985 }
986 DataInfo data_info;
987 if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
988 MS_LOG(ERROR) << "parse const node failed.";
989 return RET_ERROR;
990 }
991 auto schema_tensor = CreateTensorFromDataInfo(data_info, param_node->name(), param_node->has_default());
992 *tensor_index_ptr = NewFbTensor(meta_graphT, schema_tensor.release());
993 SetNodeId(key, *tensor_index_ptr);
994 op_node->inputIndex.emplace_back(*tensor_index_ptr);
995 return RET_OK;
996 }
997
ConvertInputValueNode(const CNodePtr & cnode,size_t index,const PrimitivePtr & primitive,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * op_node)998 int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
999 const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1000 schema::CNodeT *op_node) {
1001 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr");
1002 MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr");
1003 MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
1004 MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr");
1005 auto value_node = cnode->input(index)->cast<ValueNodePtr>();
1006 MS_ASSERT(value_node != nullptr);
1007 auto key = std::make_pair(value_node, 0);
1008 if (HasNodeIdKey(key)) {
1009 op_node->inputIndex.emplace_back(GetNodeId(key));
1010 return RET_OK;
1011 }
1012 DataInfo data_info;
1013 auto status =
1014 FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true);
1015 if (status == RET_NO_CHANGE) {
1016 return RET_OK;
1017 }
1018 if (status != RET_OK) {
1019 MS_LOG(ERROR) << "parse value node failed.";
1020 return status;
1021 }
1022 auto schema_tensor = std::make_unique<schema::TensorT>();
1023 MS_CHECK_TRUE_MSG(schema_tensor != nullptr, RET_ERROR, "schema is nullptr");
1024 schema_tensor->name = value_node->fullname_with_scope();
1025 schema_tensor->format = static_cast<schema::Format>(data_info.format_);
1026 schema_tensor->dataType = data_info.data_type_;
1027 schema_tensor->dims = data_info.shape_;
1028 schema_tensor->data = data_info.data_;
1029
1030 auto tensor_index = NewFbTensor(meta_graphT, schema_tensor.release());
1031 SetNodeId(key, tensor_index);
1032 op_node->inputIndex.emplace_back(tensor_index);
1033 return RET_OK;
1034 }
1035
SetOpInputNode(const CNodePtr & cnode,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * fb_node)1036 int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1037 schema::CNodeT *fb_node) {
1038 MS_ASSERT(meta_graphT != nullptr);
1039 MS_ASSERT(fb_node != nullptr);
1040 if (cnode->size() <= 1) {
1041 return RET_OK;
1042 }
1043 auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
1044 if (primitive_c == nullptr) {
1045 MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
1046 return RET_ERROR;
1047 }
1048 for (size_t i = 1; i < cnode->size(); i++) {
1049 auto input_node = cnode->input(i);
1050 if (input_node->isa<mindspore::CNode>()) {
1051 auto ret = ConvertInputCNode(input_node, fb_node);
1052 if (ret != RET_OK) {
1053 MS_LOG(ERROR) << "ConvertInputCNode failed";
1054 return ret;
1055 }
1056 } else if (input_node->isa<Parameter>()) {
1057 size_t tensor_index;
1058 auto ret = ConvertInputParameter(cnode, i, primitive_c, meta_graphT, fb_node, &tensor_index);
1059 if (ret != RET_OK) {
1060 MS_LOG(ERROR) << "ConvertInputParameter failed";
1061 return ret;
1062 }
1063 if (IsContain(graph_inputs_, input_node->cast<AnfNodePtr>()) &&
1064 graph_inputs_map_.find(input_node) == graph_inputs_map_.end()) {
1065 graph_inputs_map_[input_node] = tensor_index;
1066 }
1067 } else if (input_node->isa<ValueNode>()) {
1068 auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);
1069 if (ret != RET_OK) {
1070 MS_LOG(ERROR) << "ConvertInputValueNode failed";
1071 return RET_ERROR;
1072 }
1073 }
1074 }
1075 fb_node->name = cnode->fullname_with_scope();
1076 return RET_OK;
1077 }
1078
SetOpOutputNode(const CNodePtr & cnode,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * fb_node)1079 int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1080 schema::CNodeT *fb_node) {
1081 MS_ASSERT(meta_graphT != nullptr);
1082 MS_ASSERT(fb_node != nullptr);
1083 std::string cnode_name = fb_node->name;
1084
1085 // new anf export and import will add abstract tuple for control flow op, which contains abstract closure,
1086 // abstract tuple and abstract tensor. For inference, we don't need this information. So skip export abstract tuple
1087 // for control flow op. Just use a abstract tensor link the control flow ops.
1088 if (utils::isa<abstract::AbstractTuple>(cnode->abstract()) && !IsControlFlowOp(cnode)) {
1089 auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
1090 MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
1091 auto elements = tuple->elements();
1092 for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
1093 auto ms_tensor = new (std::nothrow) schema::TensorT();
1094 if (ms_tensor == nullptr) {
1095 MS_LOG(ERROR) << "new msTensor failed";
1096 return RET_ERROR;
1097 }
1098 ms_tensor->nodeType = NodeType_CNode;
1099 auto key = std::make_pair(cnode, i);
1100 if (!train_flag_) {
1101 auto val_ptr = cnode->GetAttr("outputs_names");
1102 std::string tensor_name = "";
1103 std::string name_surfix = "";
1104 auto val_index = i;
1105 if (elements.size() == 1) {
1106 key = std::make_pair(cnode, 0);
1107 val_index = 0;
1108 } else {
1109 name_surfix = "_o:" + std::to_string(i);
1110 }
1111 if (val_ptr != nullptr) {
1112 auto outputs_names = GetValue<std::vector<std::string>>(val_ptr);
1113 tensor_name = outputs_names[val_index];
1114 } else {
1115 tensor_name = cnode_name + name_surfix;
1116 }
1117
1118 if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
1119 MS_LOG(ERROR) << "abstract is not AbstractTensor";
1120 delete (ms_tensor);
1121 return RET_ERROR;
1122 }
1123 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
1124 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1125 auto type_ptr = abstract_tensor->element()->GetTypeTrack();
1126 MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1127 ms_tensor->dataType = type_ptr->type_id();
1128 ms_tensor->name = tensor_name;
1129
1130 auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1131 SetNodeId(key, tensor_index);
1132 fb_node->outputIndex.emplace_back(tensor_index);
1133 if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
1134 opt::CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm) ||
1135 opt::CheckPrimitiveType(cnode, prim::kPrimLayerNormFusion)) {
1136 break;
1137 }
1138 } else {
1139 auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1140 SetNodeId(key, tensor_index);
1141 fb_node->outputIndex.emplace_back(tensor_index);
1142 }
1143 }
1144 } else {
1145 auto ms_tensor = new (std::nothrow) schema::TensorT();
1146 if (ms_tensor == nullptr) {
1147 MS_LOG(ERROR) << "new tensor failed";
1148 return RET_ERROR;
1149 }
1150 auto type = kNumberTypeFloat32;
1151 if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
1152 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
1153 MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1154 auto typePtr = abstract_tensor->element()->GetTypeTrack();
1155 type = typePtr->type_id();
1156 }
1157 ms_tensor->dataType = type;
1158 ms_tensor->nodeType = NodeType_CNode;
1159 auto val_ptr = cnode->GetAttr("outputs_names");
1160 if (val_ptr != nullptr) {
1161 auto outputs_names = GetValue<std::vector<std::string>>(val_ptr);
1162 ms_tensor->name = outputs_names[0];
1163 } else {
1164 ms_tensor->name = cnode_name;
1165 }
1166 auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1167 auto key = std::make_pair(cnode, 0);
1168 SetNodeId(key, tensor_index);
1169 fb_node->outputIndex.emplace_back(tensor_index);
1170 }
1171 return RET_OK;
1172 }
1173
CreateCallCnode(const FuncGraphPtr & fg,const AnfNodePtr & node)1174 CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
1175 auto call_anf_prim_vnode = GetCallAnfPrim();
1176 MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, nullptr, "GetCallAnfPrim failed");
1177 std::vector<AnfNodePtr> inputs{call_anf_prim_vnode, node};
1178 auto cnode = fg->NewCNodeInOrder(inputs);
1179 MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "NewCNode failed");
1180 cnode->set_func_graph(fg);
1181 return cnode;
1182 }
1183
CreatePartialCnode(const FuncGraphPtr & fg,const AnfNodePtr & node)1184 CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
1185 if (utils::isa<CNodePtr>(node)) {
1186 auto cnode = utils::cast<CNodePtr>(node);
1187 MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cast ptr failed");
1188 auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(kPrimIndex));
1189 if (primitive_c != nullptr) {
1190 return cnode;
1191 }
1192 auto partial_anf_prim_vnode = GetPartialFusionPrim();
1193 auto cnode_input = cnode->inputs();
1194 MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
1195 cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
1196 cnode->set_inputs(cnode_input);
1197 return cnode;
1198 } else if (utils::isa<ValueNodePtr>(node)) {
1199 auto partial_anf_prim_vnode = GetPartialFusionPrim();
1200 MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
1201 std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
1202 auto cnode = fg->NewCNode(inputs);
1203 MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "New cnode failed");
1204 return cnode;
1205 } else {
1206 MS_LOG(ERROR) << "failed to create partial cnode.";
1207 return nullptr;
1208 }
1209 }
1210
Export(const FuncGraphPtr & func_graph,bool keep_graph,bool copy_primitive,bool train_flag)1211 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
1212 AnfExporter lite_exporter;
1213 return lite_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
1214 }
1215 } // namespace mindspore::lite
1216