• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 
#include "tools/lite_exporter/anf_exporter.h"
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/op_name.h"
#include "mindspore/core/ops/op_utils.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "nnacl/op_base.h"
#include "ops/depend.h"
#include "ops/fusion/partial_fusion.h"
#include "ops/make_tuple.h"
#include "ops/return.h"
#include "ops/tuple_get_item.h"
#include "ops/fusion/make_tuple_v2.h"
#include "src/common/log_util.h"
#include "src/common/ops/anf_utils.h"
#include "src/common/utils.h"
#include "src/litert/tensor_category.h"
#include "tools/common/graph_util.h"
#include "tools/common/meta_graph_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
50 
51 using mindspore::ops::PrimitiveC;
52 
namespace {
// Presumably the subgraph index of the main graph; not referenced in this part of the file — TODO confirm.
constexpr int kMainGraphIndex = 0;
// CNode input slots: input 0 holds the primitive (kPrimIndex); data inputs start at 1.
constexpr int kFirstDataIndex = 1;
constexpr int kSecondDataIndex = 2;
constexpr int kThirdDataIndex = 3;
constexpr int kPrimIndex = 0;
}  // namespace
60 
61 namespace mindspore::lite {
62 namespace {
63 constexpr int kIndexOfValueInputOfGetTupleItem = 2;
64 constexpr int kMaxDepth = 2048;
65 
GetOrderedCNodes(const FuncGraphPtr fg)66 std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
67   MS_CHECK_TRUE_MSG(fg != nullptr, {}, "fg is nullptr.");
68   auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
69   auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
70     std::vector<AnfNodePtr> vecs{};
71     if (node == nullptr) {
72       return vecs;
73     }
74     if (node->isa<mindspore::CNode>()) {
75       auto cnode = node->cast<CNodePtr>();
76       MS_ASSERT(cnode != nullptr);
77       auto &inputs = cnode->inputs();
78       // Check if free variables used.
79       for (const auto &input : inputs) {
80         auto input_fg = GetValueNode<FuncGraphPtr>(input);
81         if (input_fg) {
82           for (auto &fv : input_fg->free_variables_nodes()) {
83             if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
84               vecs.push_back(fv);
85             }
86           }
87         }
88       }
89       (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
90     }
91     return vecs;
92   };
93 
94   std::list<CNodePtr> cnodes{};
95   auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
96   for (const auto &node : nodes) {
97     auto cnode = dyn_cast<mindspore::CNode>(node);
98     if (cnode) {
99       cnodes.push_back(cnode);
100     }
101   }
102   return cnodes;
103 }
104 
CreateTensorFromDataInfo(const lite::DataInfo & data_info,const std::string & name,const bool has_default)105 std::unique_ptr<schema::TensorT> CreateTensorFromDataInfo(const lite::DataInfo &data_info, const std::string &name,
106                                                           const bool has_default) {
107   auto schema_tensor = std::make_unique<schema::TensorT>();
108   MS_CHECK_TRUE_MSG(schema_tensor != nullptr, nullptr, "schema_tensor is nullptr");
109   schema_tensor->format = static_cast<schema::Format>(data_info.format_);
110   schema_tensor->name = name;
111   schema_tensor->dims = data_info.shape_;
112   schema_tensor->dataType = data_info.data_type_;
113   schema_tensor->data = data_info.data_;
114   if (has_default) {
115     schema_tensor->nodeType = NodeType_ValueNode;
116   } else {
117     schema_tensor->nodeType = NodeType_CNode;
118   }
119   schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
120   schema_tensor->weightQuantCompressType =
121     static_cast<mindspore::schema::WeightQuantCompressType>(data_info.compress_type_);
122   return schema_tensor;
123 }
124 }  // namespace
125 
ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> & meta_graph,const std::shared_ptr<mindspore::Primitive> & primitive,const std::unique_ptr<schema::CNodeT> & dst_node)126 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
127                                    const std::shared_ptr<mindspore::Primitive> &primitive,
128                                    const std::unique_ptr<schema::CNodeT> &dst_node) {
129   MS_ASSERT(meta_graph != nullptr);
130   MS_ASSERT(primitive != nullptr);
131   MS_ASSERT(dst_node != nullptr);
132   // add quant param
133   MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
134   // activation
135   QuantParamsVector input_quant_params;
136   QuantParamsVector output_quant_params;
137   dst_node->quantType = schema::QuantType_QUANT_NONE;
138   auto quant_tensor_info_ptr = primitive->GetAttr("quant_params");
139   if (quant_tensor_info_ptr == nullptr) {
140     return RET_OK;
141   }
142   auto quant_param_holder = quant_tensor_info_ptr->cast<QuantParamHolderPtr>();
143   CHECK_NULL_RETURN(quant_param_holder);
144   input_quant_params = quant_param_holder->get_input_quant_params();
145   output_quant_params = quant_param_holder->get_output_quant_params();
146   dst_node->quantType = static_cast<schema::QuantType>(static_cast<int>(quant_param_holder->quant_type()));
147 
148   // convert input quant param
149   for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
150     if (i >= input_quant_params.size()) {
151       MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->inputIndex.size() << " input, but only has "
152                    << input_quant_params.size() << " quant params";
153       break;
154     }
155     auto activate_index = dst_node->inputIndex[i];
156     MS_CHECK_TRUE_MSG(GetAllTensorSize(meta_graph) > activate_index, RET_ERROR, "allTensors size is wrong.");
157     auto tensor_input = GetTensorFromAllTensor(meta_graph, activate_index);
158     CHECK_NULL_RETURN(tensor_input);
159 
160     tensor_input->quantClusters = quant_param_holder->GetQuantClusters(i);
161 
162     if (!TensorQuantParamsInited(*tensor_input)) {
163       tensor_input->quantParams.clear();
164       for (auto input_quant_param : input_quant_params[i]) {
165         auto input_quant_param_ptr = std::make_unique<schema::QuantParamT>(input_quant_param);
166         MS_CHECK_TRUE_MSG(input_quant_param_ptr != nullptr, RET_ERROR, "input_quant_param_ptr is nullptr");
167         MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
168                       << " zp: " << input_quant_param_ptr->zeroPoint;
169         tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
170       }
171     }
172   }
173 
174   // output_quant_params
175   for (size_t index = 0; index < dst_node->outputIndex.size(); ++index) {
176     if (index >= output_quant_params.size()) {
177       MS_LOG(INFO) << "node: " << dst_node->name << " has " << dst_node->outputIndex.size() << " output, but only has"
178                    << output_quant_params.size() << " quant params";
179       break;
180     }
181     auto output_tensor = GetTensorFromAllTensor(meta_graph, dst_node->outputIndex[index]);
182     auto &output_quant_param = output_quant_params[index];
183     for (const auto &channel_quant_param : output_quant_param) {
184       if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_QUANT_WEIGHT) {
185         std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
186           std::make_unique<schema::QuantParamT>(channel_quant_param);
187         CHECK_NULL_RETURN(output_quant_param_ptr);
188         MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
189                       << " zp: " << output_quant_param_ptr->zeroPoint;
190         output_tensor->quantParams.emplace_back(std::move(output_quant_param_ptr));
191       }
192     }
193   }
194 
195   return RET_OK;
196 }
197 
ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> & meta_graph,const CNodePtr & cnode,const std::shared_ptr<mindspore::Primitive> & primitive,const std::unique_ptr<schema::CNodeT> & dst_node)198 int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, const CNodePtr &cnode,
199                                    const std::shared_ptr<mindspore::Primitive> &primitive,
200                                    const std::unique_ptr<schema::CNodeT> &dst_node) {
201   CHECK_NULL_RETURN(meta_graph);
202   CHECK_NULL_RETURN(dst_node);
203   CHECK_NULL_RETURN(cnode);
204   // quant_type not exist in cnode, return
205   auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
206   if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
207     if (quant_type_attr != nullptr) {
208       dst_node->quantType = static_cast<schema::QuantType>(GetValue<int32_t>(quant_type_attr));
209     } else {
210       MS_LOG(DEBUG) << "quant_type not exist in cnode, node name: " << dst_node->name;
211       return RET_OK;
212     }
213   } else {
214     dst_node->quantType = schema::QuantType_QUANT_NONE;
215   }
216 
217   // convert input quant param
218   for (size_t i = 0; i < dst_node->inputIndex.size(); i++) {
219     auto activate_index = dst_node->inputIndex[i];
220     MS_CHECK_TRUE_MSG(meta_graph->allTensors.size() > activate_index, RET_ERROR, "allTensors size is wrong.");
221     auto tensor_input = meta_graph->allTensors[activate_index].get();
222     auto input_node = cnode->input(i + quant::kPrimOffset);
223     auto status = SetInputQuantParamToTensorT(primitive, input_node, tensor_input);
224     if (status != RET_NO_CHANGE && status != RET_OK) {
225       MS_LOG(ERROR) << "[input][" << i << "] node: " << dst_node->name << " SetInputQuantParamToTensorT failed.";
226       return status;
227     }
228   }
229 
230   // output_quant_params
231   for (size_t i = 0; i < dst_node->outputIndex.size(); ++i) {
232     auto output_tensor = meta_graph->allTensors[dst_node->outputIndex[i]].get();
233     auto quantization_param_value = primitive->GetAttr(quant::kQuantParam);
234     if (quantization_param_value == nullptr) {
235       MS_LOG(INFO) << "[output]node: " << dst_node->name << " output quant param Not exist.";
236       continue;
237     }
238     auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
239     if (quantization_param_list.empty()) {
240       MS_LOG(INFO) << "[output]node: " << dst_node->name << " output quant param Not exist.";
241       continue;
242     }
243     if (output_tensor->quantParams.empty() && dst_node->quantType != schema::QuantType_QUANT_WEIGHT) {
244       // Set QuantParamT into meta_graph tensor
245       // Not support cnode with multi-output
246       auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
247       for (auto quant_param : quant_params) {
248         auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
249         MS_LOG(DEBUG) << "node: " << output_tensor->name << " scale: " << quant_param_ptr->scale
250                       << " zp: " << quant_param_ptr->zeroPoint;
251         CHECK_NULL_RETURN(quant_param_ptr);
252         output_tensor->quantParams.emplace_back(std::move(quant_param_ptr));
253       }
254     }
255   }
256   return RET_OK;
257 }
258 
SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> & primitive,const AnfNodePtr & input_node,mindspore::schema::TensorT * tensor_input)259 int AnfExporter::SetInputQuantParamToTensorT(const std::shared_ptr<mindspore::Primitive> &primitive,
260                                              const AnfNodePtr &input_node, mindspore::schema::TensorT *tensor_input) {
261   CHECK_NULL_RETURN(primitive);
262   CHECK_NULL_RETURN(input_node);
263   CHECK_NULL_RETURN(tensor_input);
264   if (IsGraphInput(input_node)) {
265     if (!primitive->HasAttr(quant::kGraphInputQuantParam)) {
266       return RET_NO_CHANGE;
267     }
268     if (TensorQuantParamsInited(*tensor_input)) {
269       MS_LOG(DEBUG) << input_node->fullname_with_scope() << " TensorT quant param exist.";
270       return RET_NO_CHANGE;
271     }
272     tensor_input->quantParams.clear();
273     auto quantization_param_value = primitive->GetAttr(quant::kGraphInputQuantParam);
274     auto quantization_param_ptr = quantization_param_value->cast<QuantizationParamPtr>();
275     CHECK_NULL_RETURN(quantization_param_ptr);
276     auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_ptr);
277     for (auto quant_param : quant_params) {
278       auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
279       MS_LOG(DEBUG) << "node: " << input_node->fullname_with_scope() << " scale: " << quant_param_ptr->scale
280                     << " zp: " << quant_param_ptr->zeroPoint;
281       tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
282     }
283   } else if (input_node->isa<mindspore::CNode>()) {
284     // input node has single output
285     auto input_cnode = input_node->cast<mindspore::CNodePtr>();
286     auto input_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
287     MS_CHECK_TRUE_MSG(input_primitive != nullptr, RET_ERROR, "Input node primitive nullptr.");
288     if (!input_primitive->HasAttr(quant::kQuantParam)) {
289       return RET_NO_CHANGE;
290     }
291     if (TensorQuantParamsInited(*tensor_input)) {
292       MS_LOG(DEBUG) << input_node->fullname_with_scope() << " TensorT quant param exist.";
293       return RET_NO_CHANGE;
294     }
295     tensor_input->quantParams.clear();
296     auto quantization_param_value = input_primitive->GetAttr(quant::kQuantParam);
297     auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
298     if (quantization_param_list.empty()) {
299       MS_LOG(DEBUG) << input_node->fullname_with_scope() << " quantization param is empty.";
300       return RET_NO_CHANGE;
301     }
302     auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
303     for (auto quant_param : quant_params) {
304       auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
305       MS_LOG(DEBUG) << "node: " << input_node->fullname_with_scope() << " scale: " << quant_param_ptr->scale
306                     << " zp: " << quant_param_ptr->zeroPoint;
307       tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
308     }
309   } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
310     tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
311     MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NO_CHANGE);
312     auto quantization_params = input_tensor->quant_params();
313     if (quantization_params.empty()) {
314       MS_LOG(DEBUG) << input_node->fullname_with_scope() << " quantization param is empty.";
315       return RET_NO_CHANGE;
316     }
317     auto quantization_param = quantization_params.front();
318     auto cluster_centroid_list_attr = quantization_param->GetAttr(quant::kClusterCentroidList);
319     if (cluster_centroid_list_attr != nullptr) {
320       tensor_input->quantClusters = GetValue<std::vector<float>>(cluster_centroid_list_attr);
321       return RET_OK;
322     }
323     if (!TensorQuantParamsInited(*tensor_input)) {
324       tensor_input->quantParams.clear();
325       // Set QuantParamT into meta_graph tensor
326       auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param);
327       for (auto quant_param : quant_params) {
328         auto quant_param_ptr = std::make_unique<schema::QuantParamT>(quant_param);
329         MS_LOG(DEBUG) << "node: " << tensor_input->name << " scale: " << quant_param_ptr->scale
330                       << " zp: " << quant_param_ptr->zeroPoint;
331         CHECK_NULL_RETURN(quant_param_ptr);
332         tensor_input->quantParams.emplace_back(std::move(quant_param_ptr));
333       }
334     }
335   } else {
336     MS_LOG(WARNING) << input_node->fullname_with_scope() << " : " << input_node->type_name() << " not supported.";
337   }
338   return RET_OK;
339 }
340 
CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const AnfNodePtr & input,size_t * tensor_index_ptr)341 int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
342                                              const AnfNodePtr &input, size_t *tensor_index_ptr) {
343   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
344   MS_CHECK_TRUE_MSG(input != nullptr, RET_NULL_PTR, "input is nullptr");
345   MS_CHECK_TRUE_MSG(tensor_index_ptr != nullptr, RET_NULL_PTR, "tensor_index_ptr is nullptr");
346   lite::DataInfo data_info;
347   auto param_node = input->cast<ParameterPtr>();
348   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed");
349   if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
350     MS_LOG(ERROR) << "FetchFromDefaultParam failed.";
351     return RET_ERROR;
352   }
353   auto schema_tensor = CreateTensorFromDataInfo(data_info, param_node->name(), param_node->has_default());
354   auto key = std::make_pair(input, 0);
355   *tensor_index_ptr = NewFbTensor(meta_graphT, schema_tensor.release());
356   SetNodeId(key, *tensor_index_ptr);
357   return RET_OK;
358 }
359 
SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const size_t & subgraph_index)360 int AnfExporter::SetSubGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
361                                        const size_t &subgraph_index) {
362   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
363   auto &subgraph = meta_graphT->subGraph.at(subgraph_index);
364   FuncGraphPtr fg = nullptr;
365   std::for_each(fg_subgraph_map_.begin(), fg_subgraph_map_.end(),
366                 [&subgraph_index, &fg](const std::pair<const FuncGraphPtr, size_t> &it) {
367                   if (it.second == subgraph_index) {
368                     fg = it.first;
369                   }
370                 });
371 
372   auto inputs = fg->get_inputs();
373   for (auto &input : inputs) {
374     auto key = std::make_pair(input, 0);
375     size_t tensor_index;
376     if (HasNodeIdKey(key)) {
377       subgraph->inputIndices.emplace_back(GetNodeId(key));
378     } else {
379       if (CreateNewTensorForParameter(meta_graphT, input, &tensor_index) != RET_OK) {
380         MS_LOG(ERROR) << "CreateNewTensorForParameter failed.";
381         return RET_ERROR;
382       }
383       subgraph->inputIndices.emplace_back(tensor_index);
384     }
385   }
386   return RET_OK;
387 }
388 
SetSubGraphOutputIndex(const CNodePtr & cnode,const size_t subgraph_index,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * return_node)389 int AnfExporter::SetSubGraphOutputIndex(const CNodePtr &cnode, const size_t subgraph_index,
390                                         const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
391                                         schema::CNodeT *return_node) {
392   MS_ASSERT(meta_graphT != nullptr);
393   MS_ASSERT(return_node != nullptr);
394   for (size_t i = kFirstDataIndex; i < cnode->size(); i++) {
395     auto input_node = cnode->input(i);
396     if (input_node == nullptr) {
397       MS_LOG(ERROR) << "output node is nullptr";
398       return RET_NULL_PTR;
399     } else if (input_node->isa<mindspore::CNode>()) {
400       auto ret = ConvertInputCNode(input_node, return_node);
401       if (ret != RET_OK) {
402         MS_LOG(ERROR) << "obtain outputs failed";
403         return ret;
404       }
405     } else if (input_node->isa<Parameter>()) {
406       auto key = std::make_pair(input_node, 0);
407       size_t tensor_index;
408       if (HasNodeIdKey(key)) {
409         return_node->inputIndex.emplace_back(GetNodeId(key));
410       } else {
411         if (CreateNewTensorForParameter(meta_graphT, input_node, &tensor_index) != RET_OK) {
412           MS_LOG(ERROR) << "CreateNewTensorForParameter failed.";
413           return RET_ERROR;
414         }
415         return_node->inputIndex.emplace_back(tensor_index);
416       }
417       if (IsContain(graph_inputs_, input_node->cast<AnfNodePtr>()) &&
418           graph_inputs_map_.find(input_node) == graph_inputs_map_.end()) {
419         graph_inputs_map_[input_node] = tensor_index;
420       }
421     } else {
422       MS_LOG(ERROR) << "the node " << input_node->fullname_with_scope().c_str() << "is not output node";
423       return RET_ERROR;
424     }
425   }
426   for (unsigned int &i : return_node->inputIndex) {
427     meta_graphT->subGraph.at(subgraph_index)->outputIndices.push_back(i);
428   }
429   return RET_OK;
430 }
431 
HasExported(const FuncGraphPtr & func_graph)432 bool AnfExporter::HasExported(const FuncGraphPtr &func_graph) {
433   if (fg_subgraph_map_.find(func_graph) != fg_subgraph_map_.end()) {
434     return true;
435   }
436   return false;
437 }
438 
ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const bool & keep_graph,const bool & copy_primitive,const CNodePtr & partial_cnode,const std::unique_ptr<schema::CNodeT> & schema_cnode)439 int AnfExporter::ExportPartialNode(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const bool &keep_graph,
440                                    const bool &copy_primitive, const CNodePtr &partial_cnode,
441                                    const std::unique_ptr<schema::CNodeT> &schema_cnode) {
442   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
443   MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr");
444   MS_CHECK_TRUE_MSG(schema_cnode != nullptr, RET_NULL_PTR, "schema_cnode is nullptr");
445   auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(partial_cnode->input(0));
446   MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "GetValueNode failed");
447   if (prim->name() != mindspore::ops::kNamePartialFusion) {
448     MS_LOG(INFO) << "not is partial";
449     return RET_OK;
450   }
451 
452   auto partial_fusion_primc = schema_cnode->primitive->value.AsPartialFusion();
453   auto vnode = partial_cnode->input(kFirstDataIndex)->cast<ValueNodePtr>();
454   MS_CHECK_TRUE_MSG(partial_fusion_primc != nullptr, RET_NULL_PTR, "partial_fusion_primc is invalid");
455   MS_CHECK_TRUE_MSG(vnode != nullptr, RET_NULL_PTR, "vnode is invalid");
456   auto fg = vnode->value()->cast<FuncGraphPtr>();
457   MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "func graph is nullptr.");
458   if (fg_subgraph_map_.find(fg) != fg_subgraph_map_.end()) {
459     partial_fusion_primc->sub_graph_index = static_cast<int>(fg_subgraph_map_.at(fg));
460     return RET_OK;
461   }
462 
463   partial_fusion_primc->sub_graph_index = static_cast<int>(meta_graphT->subGraph.size());
464   auto ret = ExportSubgraph(fg, meta_graphT, keep_graph, copy_primitive, partial_cnode);
465   if (ret != RET_OK) {
466     MS_LOG(ERROR) << "ExportSubgraph failed";
467     return ret;
468   }
469   return RET_OK;
470 }
471 
InsertCallNode(const FuncGraphPtr & func_graph)472 std::list<CNodePtr> AnfExporter::InsertCallNode(const FuncGraphPtr &func_graph) {
473   MS_CHECK_TRUE_MSG(func_graph != nullptr, {}, "func_graph is nullptr");
474   auto cnodes = GetOrderedCNodes(func_graph);
475   for (auto it = cnodes.begin(); it != cnodes.end();) {
476     auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>((*it)->input(kPrimIndex));
477     if (prim == nullptr) {
478       auto fg = GetValueNode<FuncGraphPtr>((*it)->input(kPrimIndex));
479       if (fg != nullptr) {
480         auto partial_cnode = CreatePartialCnode(fg, (*it));
481         auto call_cnode = CreateCallCnode(fg, partial_cnode);
482         ++it;
483         it = cnodes.insert(it, call_cnode);
484         continue;
485       } else {
486         auto call_anf_prim_vnode = GetCallAnfPrim();
487         auto cnode_input = (*it)->inputs();
488         cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
489         (*it)->set_inputs(cnode_input);
490       }
491     }
492     ++it;
493   }
494   return cnodes;
495 }
496 
SetNonTailCall(const CNodePtr & cnode,schema::CNodeT * node)497 void AnfExporter::SetNonTailCall(const CNodePtr &cnode, schema::CNodeT *node) {
498   if (cnode == nullptr || node == nullptr) {
499     MS_LOG(ERROR) << "conde or node is nullptr";
500     return;
501   }
502   if (!opt::CheckPrimitiveType(cnode, prim::kPrimCall)) {
503     return;
504   }
505   node->primitive->value.AsCall()->is_tail_call = false;
506   call_node_map_[cnode] = node;
507   return;
508 }
509 
SetTailCallForReturn(const CNodePtr & return_cnode)510 int AnfExporter::SetTailCallForReturn(const CNodePtr &return_cnode) {
511   MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_NULL_PTR, "return_cnode is nullptr");
512   auto return_cnode_input_size = return_cnode->size();
513   for (size_t i = 1; i < return_cnode_input_size; ++i) {
514     if (!utils::isa<CNodePtr>(return_cnode->input(i))) {
515       continue;
516     }
517     if (!opt::CheckPrimitiveType(return_cnode->input(i), prim::kPrimCall)) {
518       continue;
519     }
520     auto call_cnode = return_cnode->input(i)->cast<CNodePtr>();
521     if (call_node_map_.find(call_cnode) == call_node_map_.end()) {
522       MS_LOG(ERROR) << "Not found call cnode in call_node_map.";
523       return RET_ERROR;
524     }
525     call_node_map_[call_cnode]->primitive->value.AsCall()->is_tail_call = true;
526   }
527   return RET_OK;
528 }
529 
SetTailCallForNonOutput()530 int AnfExporter::SetTailCallForNonOutput() {
531   for (auto item : call_node_map_) {
532     auto call_cnode = item.first;
533     auto mg = call_cnode->func_graph()->manager();
534     if (mg == nullptr) {
535       MS_LOG(ERROR) << "manager is nullptr.";
536       return RET_NULL_PTR;
537     }
538     auto node_user = mg->node_users()[call_cnode];
539     if (node_user.empty()) {
540       (item.second)->primitive->value.AsCall()->is_tail_call = true;
541     }
542   }
543   return RET_OK;
544 }
545 
GetNodeId(const std::pair<AnfNodePtr,size_t> & key)546 size_t AnfExporter::GetNodeId(const std::pair<AnfNodePtr, size_t> &key) {
547   node_id_map_mutex_.lock();
548   auto node_tensor_index = node_id_map_[key];
549   node_id_map_mutex_.unlock();
550   return node_tensor_index;
551 }
552 
SetNodeId(const std::pair<AnfNodePtr,size_t> & key,size_t value)553 void AnfExporter::SetNodeId(const std::pair<AnfNodePtr, size_t> &key, size_t value) {
554   node_id_map_mutex_.lock();
555   node_id_map_[key] = value;
556   node_id_map_mutex_.unlock();
557 }
558 
HasNodeIdKey(const std::pair<AnfNodePtr,size_t> & key)559 bool AnfExporter::HasNodeIdKey(const std::pair<AnfNodePtr, size_t> &key) {
560   node_id_map_mutex_.lock();
561   auto has_key = node_id_map_.find(key) != node_id_map_.end();
562   node_id_map_mutex_.unlock();
563   return has_key;
564 }
565 
NewFbTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,mindspore::schema::TensorT * tensor)566 size_t AnfExporter::NewFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
567                                 mindspore::schema::TensorT *tensor) {
568   fb_graph_all_tensors_mutex_.lock();
569   auto insert_index = meta_graphT->allTensors.size();
570   meta_graphT->allTensors.emplace_back(std::move(tensor));
571   fb_graph_all_tensors_mutex_.unlock();
572   return insert_index;
573 }
574 
InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,mindspore::schema::TensorT * tensor)575 void AnfExporter::InsertFbTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
576                                  mindspore::schema::TensorT *tensor) {
577   fb_graph_all_tensors_mutex_.lock();
578   meta_graphT->allTensors.emplace_back(std::move(tensor));
579   fb_graph_all_tensors_mutex_.unlock();
580 }
581 
GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> & meta_graphT)582 size_t AnfExporter::GetAllTensorSize(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
583   fb_graph_all_tensors_mutex_.lock();
584   auto size = meta_graphT->allTensors.size();
585   fb_graph_all_tensors_mutex_.unlock();
586   return size;
587 }
588 
GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> & meta_graphT,size_t index)589 mindspore::schema::TensorT *AnfExporter::GetTensorFromAllTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
590                                                                 size_t index) {
591   fb_graph_all_tensors_mutex_.lock();
592   auto *tensor = meta_graphT->allTensors[index].get();
593   fb_graph_all_tensors_mutex_.unlock();
594   return tensor;
595 }
596 
CaseToContinue(const string & prim_name)597 bool AnfExporter::CaseToContinue(const string &prim_name) {
598   return prim_name == mindspore::ops::kNameDepend || prim_name == mindspore::ops::kNameTupleGetItem ||
599          prim_name == mindspore::ops::kNameMakeTuple || prim_name == mindspore::ops::kNameMakeTupleV2;
600 }
601 
602 struct Anf2FbItem {
603  public:
Anf2FbItemmindspore::lite::Anf2FbItem604   Anf2FbItem(const std::shared_ptr<mindspore::Primitive> &prim, CNodePtr cnode) : prim_(prim), cnode_(cnode) {
605     dst_node_ = nullptr;
606   }
607 
608   std::shared_ptr<mindspore::Primitive> prim_;
609   CNodePtr cnode_;
610   schema::CNodeT *dst_node_;
611 };
612 
Anf2Fb(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,const size_t & subgraph_index,const bool & keep_graph,const bool & copy_primitive)613 int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
614                         const size_t &subgraph_index, const bool &keep_graph, const bool &copy_primitive) {
615   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
616   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
617   int ret = RET_OK;
618   auto cnodes = InsertCallNode(func_graph);
619   std::list<Anf2FbItem> convert_items;
620 
621   // Do Modify FuncGraph in here and save convert item for next step
622   for (const auto &cnode : cnodes) {
623     auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
624     if (prim == nullptr) {
625       MS_LOG(ERROR) << "get prim from value node failed.";
626       return RET_ERROR;
627     }
628     ret = RemoveIfDepend(cnode);
629     if (ret != RET_OK) {
630       MS_LOG(ERROR) << "RemoveIfDepend failed";
631       return ret;
632     }
633     if (CaseToContinue(prim->name())) {
634       continue;
635     }
636     ret = RemoveIfMakeTuple(cnode);
637     if (ret != RET_OK) {
638       MS_LOG(ERROR) << "RemoveIfMakeTuple failed";
639       return ret;
640     }
641     auto node = std::make_unique<schema::CNodeT>();
642     if (node == nullptr) {
643       MS_LOG(ERROR) << "object failed to be constructed";
644       return RET_MEMORY_FAILED;
645     }
646 
647     Anf2FbItem convert_item(prim, cnode);
648     convert_item.dst_node_ = node.release();
649     convert_items.push_back(convert_item);
650   }
651 
652   // convert CNode into NodeT
653   for (const auto &item : convert_items) {
654     auto prim = item.prim_;
655     auto cnode = item.cnode_;
656 
657     std::unique_ptr<schema::CNodeT> node(item.dst_node_);
658     std::unique_ptr<schema::PrimitiveT> primT;
659 
660     if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)) {
661       node->name = mindspore::ops::kNameReturn;
662       ret = SetSubGraphOutputIndex(cnode, subgraph_index, meta_graphT, node.get());
663       if (ret != RET_OK) {
664         MS_LOG(ERROR) << "SetOpOutputN failed";
665         break;
666       }
667       ret = SetTailCallForReturn(cnode);
668       if (ret != RET_OK) {
669         MS_LOG(ERROR) << "SetTailCallForReturn failed";
670         break;
671       }
672       continue;
673     }
674     primT = GetPrimitiveT(cnode->input(kPrimIndex));
675     node->name = cnode->fullname_with_scope();
676     node->primitive = std::move(primT);
677     auto device_type_attr = cnode->GetAttr(mindspore::ops::kDeviceType);
678     node->deviceType = (device_type_attr != nullptr) ? GetValue<int32_t>(device_type_attr) : -1;
679 
680     ret = SetOpOutputNode(cnode, meta_graphT, node.get());
681     if (ret != RET_OK) {
682       MS_LOG(ERROR) << "SetOpOutputNode failed";
683       break;
684     }
685 
686     ret = SetOpInputNode(cnode, meta_graphT, node.get());
687     if (ret != RET_OK) {
688       MS_LOG(ERROR) << "SetOpInputNode failed";
689       break;
690     }
691     // set all call op to non tail call
692     if (opt::CheckPrimitiveType(cnode, prim::kPrimCall)) {
693       node->primitive->value.AsCall()->is_tail_call = false;
694       call_node_map_[cnode] = node.get();
695     }
696 
697     ret = ExportPartialNode(meta_graphT, keep_graph, copy_primitive, cnode, node);
698     if (ret != RET_OK) {
699       MS_LOG(ERROR) << "ExportPartialNode failed.";
700       break;
701     }
702 
703     ret = ConvertQuantParam(meta_graphT, prim, node);
704     if (ret != RET_OK) {
705       MS_LOG(ERROR) << "ConvertQuantParam failed";
706       break;
707     }
708 
709     ret = ConvertQuantParam(meta_graphT, cnode, prim, node);
710     if (ret != RET_OK) {
711       MS_LOG(ERROR) << "New ConvertQuantParam failed";
712       break;
713     }
714 
715     fb_graph_node_mutex_.lock();
716     meta_graphT->nodes.push_back(std::move(node));
717     meta_graphT->subGraph.at(subgraph_index)->nodeIndices.push_back(node_idx_++);
718     fb_graph_node_mutex_.unlock();
719   }
720   return ret;
721 }
722 
ExportSubgraph(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,bool keep_graph,bool copy_primitive,const std::shared_ptr<AnfNode> & partial_anode)723 int AnfExporter::ExportSubgraph(const FuncGraphPtr &func_graph, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
724                                 bool keep_graph, bool copy_primitive, const std::shared_ptr<AnfNode> &partial_anode) {
725   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
726   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
727   if (HasExported(func_graph)) {
728     MS_LOG(INFO) << "Has been exported.";
729     return RET_OK;
730   }
731 
732   auto subgraph_ptr = std::make_unique<schema::SubGraphT>();
733   CHECK_NULL_RETURN(subgraph_ptr);
734   meta_graphT->subGraph.emplace_back(std::move(subgraph_ptr));
735   auto subgraph_index = meta_graphT->subGraph.size() - 1;
736   fg_subgraph_map_[func_graph] = subgraph_index;
737   auto subgraph_name = func_graph->get_attr("graph_name");
738   MS_CHECK_TRUE_MSG(subgraph_name != nullptr, RET_ERROR, "subgraph_name is nullptr");
739   meta_graphT->subGraph.back()->name =
740     "subgraph_" + std::to_string(meta_graphT->subGraph.size() - 1) + "_" + GetValue<std::string>(subgraph_name);
741 
742   auto ret = Anf2Fb(func_graph, meta_graphT, subgraph_index, keep_graph, copy_primitive);
743   if (ret != RET_OK) {
744     MS_LOG(ERROR) << "Anf2Fb failed";
745     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
746     return ret;
747   }
748 
749   ret = SetSubGraphInputIndex(meta_graphT, subgraph_index);
750   if (ret != RET_OK) {
751     MS_LOG(ERROR) << "SetSubGraphInputIndex failed";
752     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
753     return ret;
754   }
755 
756   SetSubgraphTensorIndices(meta_graphT.get());
757 
758   return RET_OK;
759 }
760 
// Follows the chain of graphs whose output is a Call and returns the graph that finally produces the output.
//  - Call(Switch(...)) continues into the graph of the Switch's false-branch partial.
//  - Call(SwitchLayer(...)) continues into the graph of the first partial.
//  - Otherwise it continues into the graph held by the call's first data input.
// `i` is the current recursion depth. Returns nullptr on malformed graphs or when the depth exceeds kMaxDepth.
FuncGraphPtr GetFinalGraph(const FuncGraphPtr &func_graph, int i) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
  if (i > kMaxDepth) {
    MS_LOG(ERROR) << "exceed max depth " << kMaxDepth << ", i " << i;
    return nullptr;
  }
  i++;
  auto fg_output = func_graph->output();
  if (!opt::CheckPrimitiveType(fg_output, prim::kPrimCall)) {
    return func_graph;
  }
  auto call_cnode = fg_output->cast<CNodePtr>();
  MS_CHECK_TRUE_MSG(call_cnode != nullptr, nullptr, "cast call node failed");
  MS_CHECK_TRUE_MSG(call_cnode->size() > kFirstDataIndex, nullptr, "call node has no data input");

  // The call's first data input must be a CNode (Switch / SwitchLayer / Partial); anything else is malformed.
  auto cnode = call_cnode->input(kFirstDataIndex)->cast<CNodePtr>();
  MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "call input is not a cnode");
  if (IsSwitch(cnode)) {
    // For a switch, the meta graph output is the output of the false branch's partial graph.
    auto false_cnode = cnode->input(kThirdDataIndex)->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(false_cnode != nullptr, nullptr, "cast failed");
    auto false_fg = GetValueNode<FuncGraphPtr>(false_cnode->input(kFirstDataIndex));
    MS_CHECK_TRUE_MSG(false_fg != nullptr, nullptr, "GetValueNode failed");
    return GetFinalGraph(false_fg, i);
  } else if (IsSwitchLayer(cnode)) {
    auto first_partial_cnode = cnode->input(kSecondDataIndex)->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(first_partial_cnode != nullptr, nullptr, "cast failed");
    auto next_fg = GetValueNode<FuncGraphPtr>(first_partial_cnode->input(kFirstDataIndex));
    MS_CHECK_TRUE_MSG(next_fg != nullptr, nullptr, "GetValueNode failed");
    return GetFinalGraph(next_fg, i);
  } else {
    auto fg = GetValueNode<FuncGraphPtr>(cnode->input(kFirstDataIndex));
    MS_CHECK_TRUE_MSG(fg != nullptr, nullptr, "GetValueNode failed");
    return GetFinalGraph(fg, i);
  }
}
797 
SetMetaGraphInput(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT)798 int AnfExporter::SetMetaGraphInput(const FuncGraphPtr &func_graph,
799                                    const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
800   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
801   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
802   MS_ASSERT(func_graph != nullptr);
803   meta_graphT->inputIndex.clear();
804   for (const auto &input : func_graph->get_inputs()) {
805     auto iter = graph_inputs_map_.find(input);
806     if (iter == graph_inputs_map_.end()) {
807       MS_LOG(ERROR) << "input " << input->ToString() << " not found in graph" << std::endl;
808       return RET_ERROR;
809     }
810     meta_graphT->inputIndex.emplace_back(iter->second);
811   }
812   return RET_OK;
813 }
814 
SetMetaGraphOutput(const FuncGraphPtr & func_graph,const std::unique_ptr<schema::MetaGraphT> & meta_graphT)815 int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph,
816                                     const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
817   MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr");
818   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
819   FuncGraphPtr final_fg = nullptr;
820   if (meta_graphT->fmkType == static_cast<int32_t>(converter::kFmkTypeMs)) {
821     final_fg = func_graph;
822   } else {
823     int i = 0;
824     final_fg = GetFinalGraph(func_graph, i);
825   }
826   MS_CHECK_TRUE_MSG(final_fg != nullptr, RET_ERROR, "GetFinalGraph failed.");
827   auto final_meta_graph_index = fg_subgraph_map_.at(final_fg);
828   auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index);
829   meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end());
830 
831   for (auto &output_index : meta_graphT->outputIndex) {
832     auto tensor = GetTensorFromAllTensor(meta_graphT, output_index);
833     if (tensor == nullptr) {
834       MS_LOG(ERROR) << "Set meta graph output failed: output tensor is null.";
835       return RET_ERROR;
836     }
837     ConverterInnerContext::GetInstance()->UpdateGraphOutputDType(meta_graphT->outputIndex.size(), tensor->dataType);
838   }
839 
840   return RET_OK;
841 }
842 
Export(const FuncGraphPtr & func_graph,bool keep_graph,bool copy_primitive,bool train_flag)843 schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive,
844                                         bool train_flag) {
845   MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr");
846   this->train_flag_ = train_flag;
847   // hardcode for nnie and train
848   this->graph_inputs_map_.clear();
849   auto meta_graphT = std::make_unique<schema::MetaGraphT>();
850   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, nullptr, "meta_graphT is nullptr");
851   auto fmk = func_graph->get_attr("fmk");
852   MS_CHECK_TRUE_MSG(fmk != nullptr, nullptr, "fmk is nullptr");
853   if (fmk->isa<Int64Imm>()) {
854     meta_graphT->fmkType = GetValue<int64_t>(fmk);
855   } else {
856     meta_graphT->fmkType = GetValue<int>(fmk);
857   }
858 
859   graph_inputs_ = func_graph->get_inputs();
860 
861   int ret = ExportSubgraph(func_graph, meta_graphT, keep_graph, copy_primitive);
862   if (ret != RET_OK) {
863     MS_LOG(ERROR) << "Export subgraph failed.";
864     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
865     return nullptr;
866   }
867 
868   ret = SetTailCallForNonOutput();
869   if (ret != RET_OK) {
870     MS_LOG(ERROR) << "SetTailCallForNonOutput failed.";
871     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
872     return nullptr;
873   }
874 
875   ret = SetMetaGraphInput(func_graph, meta_graphT);
876   if (ret != RET_OK) {
877     MS_LOG(ERROR) << "SetMetaGraphInput failed.";
878     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
879     return nullptr;
880   }
881   ret = SetMetaGraphOutput(func_graph, meta_graphT);
882   if (ret != RET_OK) {
883     MS_LOG(ERROR) << "SetMetaGraphOutput failed.";
884     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret);
885     return nullptr;
886   }
887 
888   return meta_graphT.release();
889 }
890 
ConvertInputCNodeCommonOp(const AnfNodePtr & input_anode,schema::CNodeT * output_cnode)891 int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode) {
892   MS_ASSERT(input_anode != nullptr && output_cnode != nullptr);
893   if (this->train_flag_) {
894     auto key = std::make_pair(input_anode, 0);
895     if (HasNodeIdKey(key)) {
896       output_cnode->inputIndex.emplace_back(GetNodeId(key));
897     }
898     return RET_OK;
899   }
900   if (utils::isa<abstract::AbstractTuple>(input_anode->abstract())) {
901     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(input_anode->abstract());
902     MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
903     auto elements = tuple->elements();
904     for (size_t i = 0; i < elements.size(); i++) {
905       auto key = std::make_pair(input_anode, i);
906       if (HasNodeIdKey(key)) {
907         output_cnode->inputIndex.emplace_back(GetNodeId(key));
908       }
909     }
910   } else {
911     auto key = std::make_pair(input_anode, 0);
912     if (HasNodeIdKey(key)) {
913       output_cnode->inputIndex.emplace_back(GetNodeId(key));
914     }
915   }
916   return RET_OK;
917 }
918 
ConvertInputCNode(const std::shared_ptr<AnfNode> & input_anode,schema::CNodeT * output_cnode)919 int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode) {
920   auto input_cnode = utils::cast<CNodePtr>(input_anode);
921   MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "cast ptr failed");
922   auto input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
923   if (input_value_node == nullptr) {
924     if (!IsCall(input_cnode)) {
925       MS_LOG(ERROR) << "value node is invalid.";
926       return RET_ERROR;
927     } else {
928       auto call_anf_prim_vnode = GetCallAnfPrim();
929       auto cnode_input = input_cnode->inputs();
930       MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, RET_ERROR, "GetCallAnfPrim failed");
931       cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode);
932       input_cnode->set_inputs(cnode_input);
933     }
934   }
935 
936   input_value_node = input_cnode->input(kPrimIndex)->cast<ValueNodePtr>();
937 
938   if (input_value_node->value() == nullptr || !opt::CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) {
939     return ConvertInputCNodeCommonOp(input_anode, output_cnode);
940   } else {
941     auto inputs = input_cnode->inputs();
942 
943     if (inputs.size() != 3) {
944       MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << inputs.size();
945       return RET_ERROR;
946     }
947     auto get_item_input_cnode = inputs.at(1);
948     auto index_vnode = inputs.at(kIndexOfValueInputOfGetTupleItem);
949     if (!utils::isa<ValueNode>(index_vnode)) {
950       MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
951       return RET_ERROR;
952     }
953     auto value_node = utils::cast<ValueNodePtr>(index_vnode);
954     MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "cast to ValueNode failed");
955     auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue<int64_t>(value_node->value())
956                                                                               : GetValue<int>(value_node->value());
957     auto key = std::make_pair(get_item_input_cnode, idx);
958     if (!HasNodeIdKey(key)) {
959       key = std::make_pair(get_item_input_cnode, 0);  // try name with 0
960       if (!HasNodeIdKey(key)) {
961         MS_LOG(ERROR) << "Can not find get_item output tensor "
962                       << get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(idx);
963         return RET_ERROR;
964       }
965     }
966     output_cnode->inputIndex.emplace_back(GetNodeId(key));
967   }
968   return RET_OK;
969 }
970 
ConvertInputParameter(const CNodePtr & cnode,size_t index,const PrimitivePtr & primitive,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * op_node,size_t * tensor_index_ptr)971 int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
972                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node,
973                                        size_t *tensor_index_ptr) {
974   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr");
975   MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr");
976   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
977   MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr");
978   MS_CHECK_TRUE_MSG(tensor_index_ptr != nullptr, RET_NULL_PTR, "tensor_index_ptr is nullptr");
979   auto param_node = cnode->input(index)->cast<ParameterPtr>();
980   MS_ASSERT(param_node != nullptr);
981   auto key = std::make_pair(param_node, 0);
982   if (HasNodeIdKey(key)) {
983     op_node->inputIndex.emplace_back(GetNodeId(key));
984     return RET_OK;
985   }
986   DataInfo data_info;
987   if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), &data_info, true) != RET_OK) {
988     MS_LOG(ERROR) << "parse const node failed.";
989     return RET_ERROR;
990   }
991   auto schema_tensor = CreateTensorFromDataInfo(data_info, param_node->name(), param_node->has_default());
992   *tensor_index_ptr = NewFbTensor(meta_graphT, schema_tensor.release());
993   SetNodeId(key, *tensor_index_ptr);
994   op_node->inputIndex.emplace_back(*tensor_index_ptr);
995   return RET_OK;
996 }
997 
ConvertInputValueNode(const CNodePtr & cnode,size_t index,const PrimitivePtr & primitive,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * op_node)998 int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
999                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1000                                        schema::CNodeT *op_node) {
1001   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr");
1002   MS_CHECK_TRUE_MSG(primitive != nullptr, RET_NULL_PTR, "primitive is nullptr");
1003   MS_CHECK_TRUE_MSG(meta_graphT != nullptr, RET_NULL_PTR, "meta_graphT is nullptr");
1004   MS_CHECK_TRUE_MSG(op_node != nullptr, RET_NULL_PTR, "op_node is nullptr");
1005   auto value_node = cnode->input(index)->cast<ValueNodePtr>();
1006   MS_ASSERT(value_node != nullptr);
1007   auto key = std::make_pair(value_node, 0);
1008   if (HasNodeIdKey(key)) {
1009     op_node->inputIndex.emplace_back(GetNodeId(key));
1010     return RET_OK;
1011   }
1012   DataInfo data_info;
1013   auto status =
1014     FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info, true);
1015   if (status == RET_NO_CHANGE) {
1016     return RET_OK;
1017   }
1018   if (status != RET_OK) {
1019     MS_LOG(ERROR) << "parse value node failed.";
1020     return status;
1021   }
1022   auto schema_tensor = std::make_unique<schema::TensorT>();
1023   MS_CHECK_TRUE_MSG(schema_tensor != nullptr, RET_ERROR, "schema is nullptr");
1024   schema_tensor->name = value_node->fullname_with_scope();
1025   schema_tensor->format = static_cast<schema::Format>(data_info.format_);
1026   schema_tensor->dataType = data_info.data_type_;
1027   schema_tensor->dims = data_info.shape_;
1028   schema_tensor->data = data_info.data_;
1029 
1030   auto tensor_index = NewFbTensor(meta_graphT, schema_tensor.release());
1031   SetNodeId(key, tensor_index);
1032   op_node->inputIndex.emplace_back(tensor_index);
1033   return RET_OK;
1034 }
1035 
SetOpInputNode(const CNodePtr & cnode,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * fb_node)1036 int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1037                                 schema::CNodeT *fb_node) {
1038   MS_ASSERT(meta_graphT != nullptr);
1039   MS_ASSERT(fb_node != nullptr);
1040   if (cnode->size() <= 1) {
1041     return RET_OK;
1042   }
1043   auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
1044   if (primitive_c == nullptr) {
1045     MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
1046     return RET_ERROR;
1047   }
1048   for (size_t i = 1; i < cnode->size(); i++) {
1049     auto input_node = cnode->input(i);
1050     if (input_node->isa<mindspore::CNode>()) {
1051       auto ret = ConvertInputCNode(input_node, fb_node);
1052       if (ret != RET_OK) {
1053         MS_LOG(ERROR) << "ConvertInputCNode failed";
1054         return ret;
1055       }
1056     } else if (input_node->isa<Parameter>()) {
1057       size_t tensor_index;
1058       auto ret = ConvertInputParameter(cnode, i, primitive_c, meta_graphT, fb_node, &tensor_index);
1059       if (ret != RET_OK) {
1060         MS_LOG(ERROR) << "ConvertInputParameter failed";
1061         return ret;
1062       }
1063       if (IsContain(graph_inputs_, input_node->cast<AnfNodePtr>()) &&
1064           graph_inputs_map_.find(input_node) == graph_inputs_map_.end()) {
1065         graph_inputs_map_[input_node] = tensor_index;
1066       }
1067     } else if (input_node->isa<ValueNode>()) {
1068       auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);
1069       if (ret != RET_OK) {
1070         MS_LOG(ERROR) << "ConvertInputValueNode failed";
1071         return RET_ERROR;
1072       }
1073     }
1074   }
1075   fb_node->name = cnode->fullname_with_scope();
1076   return RET_OK;
1077 }
1078 
SetOpOutputNode(const CNodePtr & cnode,const std::unique_ptr<schema::MetaGraphT> & meta_graphT,schema::CNodeT * fb_node)1079 int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
1080                                  schema::CNodeT *fb_node) {
1081   MS_ASSERT(meta_graphT != nullptr);
1082   MS_ASSERT(fb_node != nullptr);
1083   std::string cnode_name = fb_node->name;
1084 
1085   // new anf export and import will add abstract tuple for control flow op, which contains abstract closure,
1086   // abstract tuple and abstract tensor. For inference, we don't need this information. So skip export abstract tuple
1087   // for control flow op. Just use a abstract tensor link the control flow ops.
1088   if (utils::isa<abstract::AbstractTuple>(cnode->abstract()) && !IsControlFlowOp(cnode)) {
1089     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
1090     MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
1091     auto elements = tuple->elements();
1092     for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) {
1093       auto ms_tensor = new (std::nothrow) schema::TensorT();
1094       if (ms_tensor == nullptr) {
1095         MS_LOG(ERROR) << "new msTensor failed";
1096         return RET_ERROR;
1097       }
1098       ms_tensor->nodeType = NodeType_CNode;
1099       auto key = std::make_pair(cnode, i);
1100       if (!train_flag_) {
1101         auto val_ptr = cnode->GetAttr("outputs_names");
1102         std::string tensor_name = "";
1103         std::string name_surfix = "";
1104         auto val_index = i;
1105         if (elements.size() == 1) {
1106           key = std::make_pair(cnode, 0);
1107           val_index = 0;
1108         } else {
1109           name_surfix = "_o:" + std::to_string(i);
1110         }
1111         if (val_ptr != nullptr) {
1112           auto outputs_names = GetValue<std::vector<std::string>>(val_ptr);
1113           tensor_name = outputs_names[val_index];
1114         } else {
1115           tensor_name = cnode_name + name_surfix;
1116         }
1117 
1118         if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) {
1119           MS_LOG(ERROR) << "abstract is not AbstractTensor";
1120           delete (ms_tensor);
1121           return RET_ERROR;
1122         }
1123         auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]);
1124         MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1125         auto type_ptr = abstract_tensor->element()->GetTypeTrack();
1126         MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr");
1127         ms_tensor->dataType = type_ptr->type_id();
1128         ms_tensor->name = tensor_name;
1129 
1130         auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1131         SetNodeId(key, tensor_index);
1132         fb_node->outputIndex.emplace_back(tensor_index);
1133         if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) ||
1134             opt::CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm) ||
1135             opt::CheckPrimitiveType(cnode, prim::kPrimLayerNormFusion)) {
1136           break;
1137         }
1138       } else {
1139         auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1140         SetNodeId(key, tensor_index);
1141         fb_node->outputIndex.emplace_back(tensor_index);
1142       }
1143     }
1144   } else {
1145     auto ms_tensor = new (std::nothrow) schema::TensorT();
1146     if (ms_tensor == nullptr) {
1147       MS_LOG(ERROR) << "new tensor failed";
1148       return RET_ERROR;
1149     }
1150     auto type = kNumberTypeFloat32;
1151     if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) {
1152       auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract());
1153       MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "Cast to abstract tensor failed!");
1154       auto typePtr = abstract_tensor->element()->GetTypeTrack();
1155       type = typePtr->type_id();
1156     }
1157     ms_tensor->dataType = type;
1158     ms_tensor->nodeType = NodeType_CNode;
1159     auto val_ptr = cnode->GetAttr("outputs_names");
1160     if (val_ptr != nullptr) {
1161       auto outputs_names = GetValue<std::vector<std::string>>(val_ptr);
1162       ms_tensor->name = outputs_names[0];
1163     } else {
1164       ms_tensor->name = cnode_name;
1165     }
1166     auto tensor_index = NewFbTensor(meta_graphT, ms_tensor);
1167     auto key = std::make_pair(cnode, 0);
1168     SetNodeId(key, tensor_index);
1169     fb_node->outputIndex.emplace_back(tensor_index);
1170   }
1171   return RET_OK;
1172 }
1173 
CreateCallCnode(const FuncGraphPtr & fg,const AnfNodePtr & node)1174 CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
1175   auto call_anf_prim_vnode = GetCallAnfPrim();
1176   MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, nullptr, "GetCallAnfPrim failed");
1177   std::vector<AnfNodePtr> inputs{call_anf_prim_vnode, node};
1178   auto cnode = fg->NewCNodeInOrder(inputs);
1179   MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "NewCNode failed");
1180   cnode->set_func_graph(fg);
1181   return cnode;
1182 }
1183 
CreatePartialCnode(const FuncGraphPtr & fg,const AnfNodePtr & node)1184 CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) {
1185   if (utils::isa<CNodePtr>(node)) {
1186     auto cnode = utils::cast<CNodePtr>(node);
1187     MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cast ptr failed");
1188     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(kPrimIndex));
1189     if (primitive_c != nullptr) {
1190       return cnode;
1191     }
1192     auto partial_anf_prim_vnode = GetPartialFusionPrim();
1193     auto cnode_input = cnode->inputs();
1194     MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
1195     cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode);
1196     cnode->set_inputs(cnode_input);
1197     return cnode;
1198   } else if (utils::isa<ValueNodePtr>(node)) {
1199     auto partial_anf_prim_vnode = GetPartialFusionPrim();
1200     MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed");
1201     std::vector<AnfNodePtr> inputs{partial_anf_prim_vnode, node};
1202     auto cnode = fg->NewCNode(inputs);
1203     MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "New cnode failed");
1204     return cnode;
1205   } else {
1206     MS_LOG(ERROR) << "failed to create partial cnode.";
1207     return nullptr;
1208   }
1209 }
1210 
Export(const FuncGraphPtr & func_graph,bool keep_graph,bool copy_primitive,bool train_flag)1211 schema::MetaGraphT *Export(const FuncGraphPtr &func_graph, bool keep_graph, bool copy_primitive, bool train_flag) {
1212   AnfExporter lite_exporter;
1213   return lite_exporter.Export(func_graph, keep_graph, copy_primitive, train_flag);
1214 }
1215 }  // namespace mindspore::lite
1216