• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <fstream>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <utility>
23 #include "include/common/debug/anf_ir_dump.h"
24 #include "include/common/debug/dump_proto.h"
25 #include "mindspore/core/ops/op_def.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "mindspore/core/ops/structure_ops.h"
29 #include "include/common/utils/compile_cache_context.h"
30 
31 namespace {
32 using mindspore::CNodePtr;
33 using mindspore::FileUtils;
34 using mindspore::FuncGraph;
35 using mindspore::FuncGraphPtr;
36 using mindspore::ValueNode;
37 using mindspore::ValueNodePtr;
38 
GetAllFuncGraphs(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)39 void GetAllFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
40   MS_ASSERT(all_func_graphs != nullptr);
41   MS_ASSERT(func_graph != nullptr);
42   if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
43     (void)(all_func_graphs->insert(func_graph));
44   } else {
45     return;
46   }
47   auto nodes = mindspore::TopoSort(func_graph->get_return());
48   for (auto &node : nodes) {
49     if (mindspore::IsValueNode<FuncGraph>(node)) {
50       MS_ASSERT(node->cast<ValueNodePtr>() != nullptr);
51       MS_ASSERT(node->cast<ValueNodePtr>()->value() != nullptr);
52       MS_ASSERT((node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
53       auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
54       GetAllFuncGraphs(new_fg, all_func_graphs);
55     }
56     if (mindspore::utils::isa<CNodePtr>(node)) {
57       auto cnode = node->cast<CNodePtr>();
58       MS_ASSERT(cnode != nullptr);
59       for (auto &weak_input : cnode->weak_inputs()) {
60         auto input = weak_input.lock();
61         MS_EXCEPTION_IF_NULL(input);
62         if (input->isa<ValueNode>()) {
63           if (mindspore::IsValueNode<FuncGraph>(input)) {
64             MS_ASSERT(input->cast<ValueNodePtr>() != nullptr);
65             MS_ASSERT(input->cast<ValueNodePtr>()->value() != nullptr);
66             MS_ASSERT((input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
67             auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
68             GetAllFuncGraphs(new_fg, all_func_graphs);
69           }
70         }
71       }
72     }
73   }
74 }
75 
DeleteDirRecursively(const std::string & dir_name)76 bool DeleteDirRecursively(const std::string &dir_name) {
77   DIR *dir = opendir(dir_name.c_str());
78   dirent *dirent = nullptr;
79   std::vector<std::string> file_names{};
80   while ((dirent = readdir(dir)) != nullptr) {
81     if (strcmp(dirent->d_name, ".") != 0 && strcmp(dirent->d_name, "..") != 0) {
82       (void)(file_names.emplace_back(dirent->d_name));
83     }
84   }
85   for (auto &file_name : file_names) {
86     auto file_path = dir_name + "/" + file_name;
87     auto real_file_path = FileUtils::GetRealPath(file_path.c_str());
88     if (!real_file_path.has_value()) {
89       (void)(closedir(dir));
90       MS_LOG(ERROR) << "Cannot get pwd path";
91       return false;
92     }
93     auto result = unlink(real_file_path.value().c_str());
94     if (result != 0) {
95       (void)(closedir(dir));
96       MS_LOG(ERROR) << "Delete the file(" << real_file_path.value() << ") failed." << mindspore::ErrnoToString(errno);
97       return false;
98     }
99   }
100   (void)(closedir(dir));
101   return true;
102 }
103 };  // namespace
104 
105 namespace mindspore {
SetAbstractFuncToAttributeProto(const abstract::AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)106 bool IrExportBuilder::SetAbstractFuncToAttributeProto(const abstract::AbstractBasePtr &abstract,
107                                                       mind_ir::AttributeProto *const attr_proto) {
108   MS_EXCEPTION_IF_NULL(abstract);
109   MS_EXCEPTION_IF_NULL(attr_proto);
110   if (abstract->isa<abstract::FuncGraphAbstractClosure>()) {
111     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE);
112     auto func_name = abstract->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph()->ToString();
113     attr_proto->set_s(func_name);
114   } else if (abstract->isa<abstract::PrimitiveAbstractClosure>()) {
115     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE);
116     auto prim = abstract->cast<abstract::PrimitiveAbstractClosurePtr>()->prim();
117     attr_proto->set_s(GetPrimitiveUniqueName(prim));
118   } else if (abstract->isa<abstract::PartialAbstractClosure>()) {
119     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE);
120     auto node_ptr = abstract->cast<abstract::PartialAbstractClosurePtr>()->node();
121     MS_EXCEPTION_IF_NULL(node_ptr);
122     attr_proto->set_s(GetUniqueNodeName(node_ptr));
123   } else if (abstract->isa<abstract::AbstractFuncUnion>()) {
124     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE);
125     auto visit_func = [this, &attr_proto](const abstract::AbstractFuncAtomPtr &poss) {
126       auto element_attr_proto = attr_proto->add_values();
127       if (!this->SetAbstractFuncToAttributeProto(poss, element_attr_proto)) {
128         MS_LOG(EXCEPTION) << "Set union function abstract to proto error." << poss->ToString();
129       }
130     };
131     abstract->cast<abstract::AbstractFunctionPtr>()->Visit(visit_func);
132   } else {
133     MS_LOG(ERROR) << "The parameter abstract is not an abstractFunction: " << abstract->ToString();
134     return false;
135   }
136   return true;
137 }
138 
GetPrimitiveUniqueName(const PrimitivePtr & primitive_ptr)139 std::string IrExportBuilder::GetPrimitiveUniqueName(const PrimitivePtr &primitive_ptr) {
140   auto it = primitive_name_map_.find(primitive_ptr);
141   if (it != primitive_name_map_.end()) {
142     return it->second;
143   }
144   // Remove this check if we find a way to handle save/load training model with flattened parameters.
145   if (IsPrimitiveEquals(primitive_ptr, prim::kPrimFlattenConcat)) {
146     MS_LOG(EXCEPTION) << "Export model with operator '" << primitive_ptr->name() << "' is not supported yet.\n"
147                       << "Please remove 'net.flatten_weights()' in your script and try again.";
148   }
149   auto answer = primitive_ptr->name() + ":" + std::to_string(GetUniqueID());
150   primitive_name_map_[primitive_ptr] = answer;
151   return answer;
152 }
153 
BuildPrimitives()154 bool IrExportBuilder::BuildPrimitives() {
155   for (auto it = primitive_name_map_.begin(); it != primitive_name_map_.end(); ++it) {
156     auto prim = it->first;
157     if (prim->name() == prim::kPrimPyExecute->name()) {
158       MS_LOG(EXCEPTION) << "Cannot export a PyExecute CNode in MindIR.";
159     }
160     auto prim_proto = model_->add_primitives();
161 
162     prim_proto->set_name(it->second);
163     prim_proto->set_op_type(prim->name());
164     // function IsPrimitiveFunction: dynamic shape new primitive
165     // attr is_primitive_function: default true, Lite MindIr false
166     bool is_primitive_function =
167       prim->GetAttr("primitive_function") == nullptr || GetValue<bool>(prim->GetAttr("primitive_function"));
168     if (mindspore::ops::IsPrimitiveFunction(prim->name()) && is_primitive_function) {
169       prim_proto->set_prim_type(mind_ir::PrimitiveProto_PrimType_PRIMITIVE_FUNCTION);
170     } else {
171       prim_proto->set_prim_type(mind_ir::PrimitiveProto_PrimType_PRIMITIVE);
172     }
173 
174     auto real_prim = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
175     if (real_prim != nullptr) {
176       prim = real_prim;
177     }
178 
179     prim_proto->set_instance_name(prim->instance_name());
180 
181     // Set primitive attributes
182     for (const auto &attr : prim->attrs()) {
183       MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
184       auto iter = g_export_attr_blacklist.find(attr.first);
185       if (iter != g_export_attr_blacklist.end()) {
186         continue;
187       }
188       if (attr.second == nullptr) {
189         MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
190         continue;
191       }
192       mind_ir::AttributeProto *attr_proto = prim_proto->add_attribute();
193       attr_proto->set_name(attr.first);
194       auto attr_value = attr.second;
195       if (!is_kernel_graph_) {
196         CheckAndConvertUtils::ConvertAttrValueInExport(prim->name(), attr.first, &attr_value);
197       }
198       if (!SetValueToAttributeProto(attr_value, attr_proto)) {
199         MS_LOG(ERROR) << "Set value to AttributeProto failed.";
200         return false;
201       }
202     }  // Loop of attrs
203   }    // Loop of primitives
204   return true;
205 }
206 
GetDumpString(const FuncGraphPtr & func_graph)207 std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
208   auto dump_proto = GetDumpProto(func_graph);
209   if (dump_proto == nullptr) {
210     MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed.";
211   }
212   return builder_->GetProtoString();
213 }
214 
GetDumpProto(const FuncGraphPtr & func_graph)215 ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
216   if ((builder_ == nullptr) || (func_graph == nullptr)) {
217     MS_LOG(EXCEPTION) << "Input params is null.";
218   }
219 
220   // Export model info
221   builder_->BuildModelInfo();
222 
223   // Export model and return string
224   if (!builder_->BuildModel(func_graph)) {
225     return nullptr;
226   }
227   return builder_->Model();
228 }
229 
GetDumpProto(const FuncGraphPtr & root_graph,const std::vector<FuncGraphPtr> & child_graphs,const std::vector<AnfNodePtr> & isolated_nodes)230 ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
231                                        const std::vector<AnfNodePtr> &isolated_nodes) {
232   // Export model info
233   builder_->BuildModelInfo();
234   // Export model and return string
235   if (!builder_->BuildModel(root_graph, child_graphs, isolated_nodes)) {
236     return nullptr;
237   }
238   return builder_->Model();
239 }
240 
GetProtoString() const241 std::string IrExportBuilder::GetProtoString() const {
242   MS_LOG(DEBUG) << "BuildModel complete!";
243   return model_->SerializeAsString();
244 }
245 
BuildModelInfo()246 void IrExportBuilder::BuildModelInfo() {
247   constexpr auto ir_version = "0.1.1";
248   constexpr auto mindspore_name = "MindSpore";
249   model_->set_ir_version(ir_version);
250   model_->set_producer_name(mindspore_name);
251   model_->set_model_version(VERSION);
252   model_->set_little_endian(common::IsLittleByteOrder());
253   model_->set_mind_ir_version(mind_ir::Version_MAX);
254 }
255 
256 // build model for kernel graph
BuildModel(const FuncGraphPtr & root_graph,const std::vector<FuncGraphPtr> & child_graphs,const std::vector<AnfNodePtr> & isolated_nodes)257 bool IrExportBuilder::BuildModel(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
258                                  const std::vector<AnfNodePtr> &isolated_nodes) {
259   MS_EXCEPTION_IF_NULL(root_graph);
260   is_kernel_graph_ = root_graph->type_name() == kKernelGraphTypeName;
261   nodeName_.clear();
262   node_name_map_.clear();
263   primitive_name_map_.clear();
264 
265   // Because param may be called across graphs, build params of all graphs first.
266   auto build_params_attrs = [this](const FuncGraphPtr &graph, mind_ir::GraphProto *const proto) {
267     if (!BuildParameters(graph, proto)) {
268       MS_LOG(ERROR) << "Build graph parameters failed.";
269       return false;
270     }
271     if (!BuildFuncGraphAttrs(graph, proto)) {
272       MS_LOG(ERROR) << "Build graph parameters attrs failed.";
273       return false;
274     }
275     return true;
276   };
277 
278   (void)nodeName_.insert(root_graph->ToString());
279   auto root_graph_proto = model_->mutable_graph();
280   // build root graph params
281   top_graph = true;
282   if (!(build_params_attrs(root_graph, root_graph_proto))) {
283     return false;
284   }
285   root_graph_proto->set_name(root_graph->ToString());
286   graph_protos_[root_graph] = root_graph_proto;
287   // build child graph params
288   top_graph = false;
289   for (const auto &graph : child_graphs) {
290     auto func_proto = model_->add_functions();
291     func_proto->set_name(graph->ToString());
292     (void)nodeName_.insert(graph->ToString());
293     if (!(build_params_attrs(graph, func_proto))) {
294       return false;
295     }
296     graph_protos_[graph] = func_proto;
297   }
298   // build nodes for root_graph, then child_graph
299   if (!BuildNodes(root_graph, root_graph_proto)) {
300     return false;
301   }
302   std::map<std::string, FuncGraphPtr> sorted_graphs;
303   std::for_each(child_graphs.begin(), child_graphs.end(),
304                 [&sorted_graphs](const auto &iter) { sorted_graphs[iter->ToString()] = iter; });
305   for (const auto &iter : sorted_graphs) {
306     const auto &graph = iter.second;
307     if (!BuildNodes(graph, graph_protos_[graph])) {
308       return false;
309     }
310   }
311   if (!BuildIsolatedNodes(isolated_nodes)) {
312     return false;
313   }
314   // build primitives
315   if (!BuildPrimitives()) {
316     return false;
317   }
318   // Release resource
319   nodeName_.clear();
320   node_name_map_.clear();
321   primitive_name_map_.clear();
322   graph_protos_.clear();
323   return true;
324 }
325 
BuildModel(const FuncGraphPtr & func_graph)326 bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
327   MS_EXCEPTION_IF_NULL(func_graph);
328   mind_ir::GraphProto *graph_proto = model_->mutable_graph();
329   graph_proto->set_name(func_graph->ToString());
330   graph_proto->set_bprop_hash(func_graph->bprop_hash());
331   graph_proto->set_bprop_filepath(func_graph->bprop_filepath());
332   todo_.clear();
333   nodeName_.clear();
334   primitive_name_map_.clear();
335   // Build the main funcGraph
336   (void)nodeName_.insert(func_graph->ToString());
337   top_graph = true;
338 
339   if (!BuildFuncGraph(func_graph, graph_proto)) {
340     MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
341     return false;
342   }
343 
344   // Build child funcGraphs
345   std::set<FuncGraphPtr> graphVisited;
346   (void)graphVisited.insert(func_graph);
347   top_graph = false;
348 
349   auto &context = CompileCacheContext::GetInstance();
350   const auto &child_graphs = context.GetChileGraphs();
351   (void)(std::transform(child_graphs.begin(), child_graphs.end(), std::back_inserter(todo_),
352                         [](const FuncGraphPtr &g) { return g; }));
353   while (!todo_.empty()) {
354     FuncGraphPtr fg = todo_.back();
355     todo_.pop_back();
356     if (graphVisited.count(fg) > 0) {
357       continue;
358     }
359     if (nodeName_.count(fg->ToString()) > 0) {
360       MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString();
361       return false;
362     }
363     (void)nodeName_.insert(fg->ToString());
364     (void)graphVisited.insert(fg);
365     auto graph = model_->add_functions();
366     if (!BuildFuncGraph(fg, graph)) {
367       MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
368       return false;
369     }
370   }
371 
372   if (!BuildPrimitives()) {
373     return false;
374   }
375   // Release resource
376   nodeName_.clear();
377   node_name_map_.clear();
378   primitive_name_map_.clear();
379   graph_protos_.clear();
380   MS_LOG(INFO) << "BuildModel end.";
381   return true;
382 }
383 
BuildFuncGraph(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)384 bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
385   graph_protos_[func_graph] = graph_proto;
386   // Export funcGraph name.
387   graph_proto->set_name(func_graph->ToString());
388   // Export parameters
389   // 1. parameters should be mapped to ValueInfoProto
390   // 2. parameters with default value should be mapped to Initializer
391   if (!BuildParameters(func_graph, graph_proto)) {
392     MS_LOG(ERROR) << "Build parameters failed.";
393     return false;
394   }
395 
396   // Export graph attributes
397   if (!BuildFuncGraphAttrs(func_graph, graph_proto)) {
398     MS_LOG(ERROR) << "Build attributes for graph failed.";
399     return false;
400   }
401 
402   // Export operator nodes(include output)
403   return BuildNodes(func_graph, graph_proto);
404 }
405 
BuildFuncGraphAttrs(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)406 bool IrExportBuilder::BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
407   MS_EXCEPTION_IF_NULL(func_graph);
408   MS_EXCEPTION_IF_NULL(graph_proto);
409   for (const auto &attr : func_graph->attrs()) {
410     MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
411     auto iter = g_export_attr_blacklist.find(attr.first);
412     if (iter != g_export_attr_blacklist.end()) {
413       continue;
414     }
415     if (attr.second == nullptr) {
416       MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
417       continue;
418     }
419     mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
420     attr_proto->set_name(attr.first);
421     if (!SetValueToAttributeProto(attr.second, attr_proto)) {
422       MS_LOG(ERROR) << "Set value to AttributeProto for GraphProto failed.";
423       return false;
424     }
425   }
426   return true;
427 }
428 
ExportWeight(const ParameterPtr & param,const std::string & param_name,mind_ir::GraphProto * const graph_proto)429 bool IrExportBuilder::ExportWeight(const ParameterPtr &param, const std::string &param_name,
430                                    mind_ir::GraphProto *const graph_proto) {
431   MS_LOG(DEBUG) << "Parameter: '" << param->DebugString();
432   auto param_abs = param->abstract();
433   MS_EXCEPTION_IF_NULL(param_abs);
434   if (param_abs->isa<abstract::AbstractMapTensor>()) {
435     auto *map_parameter_proto = graph_proto->add_map_parameter();
436     if (!ConvertMapParameterToMapTensorProto(param, map_parameter_proto)) {
437       MS_LOG(ERROR) << "Convert MapParameter " << param->ToString() << " to MapTensorProto failed.";
438       return false;
439     }
440     return true;
441   }
442   if (param_abs->isa<abstract::AbstractTensor>()) {
443     mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
444     parameter_proto->set_name(param_name);
445     if (!SetParamToTensorProto(param, parameter_proto)) {
446       MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
447       return false;
448     }
449     return true;
450   }
451   MS_LOG(ERROR) << "Only support MapTensor or Tensor as default param of Parameter, got: "
452                 << param->default_param()->ToString();
453   return false;
454 }
455 
BuildParameters(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)456 bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
457   MS_EXCEPTION_IF_NULL(func_graph);
458   MS_EXCEPTION_IF_NULL(graph_proto);
459   auto &context = CompileCacheContext::GetInstance();
460   auto param_size = func_graph->parameters().size();
461   MS_LOG(DEBUG) << "func graph: " << func_graph->ToString() << " parameter num:" << param_size
462                 << ", fv param num:" << func_graph->fv_param_count();
463   for (size_t param_counter = 0; param_counter < param_size; ++param_counter) {
464     auto &item = func_graph->parameters()[param_counter];
465     MS_EXCEPTION_IF_NULL(item);
466     auto param = item->cast<ParameterPtr>();
467     if (param == nullptr) {
468       MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
469       return false;
470     }
471     if (is_kernel_graph_ && (node_name_map_.find(param) != node_name_map_.end() || param->func_graph() != func_graph)) {
472       continue;
473     }
474     std::string param_name = GetUniqueNodeName(param);
475     param->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(param_name));
476     if (is_kernel_graph_ && context.IsBackendParamGenFromFrontendParam(param)) {
477       (void)nodeName_.insert(param_name);
478       continue;
479     }
480     if (top_graph &&
481         (param_counter >= param_size - func_graph->fv_param_count() || (is_kernel_graph_ && param->has_default()))) {
482       if (!ExportWeight(param, param_name, graph_proto)) {
483         MS_LOG(ERROR) << "Failed to export parameter weight:" << param->DebugString();
484         return false;
485       }
486     } else {
487       // export graph input
488       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
489       input_proto->set_name(param_name);
490       if (!SetValueInfoProto(param, input_proto)) {
491         MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
492         return false;
493       }
494     }
495     if (nodeName_.count(param_name) > 0) {
496       MS_LOG(ERROR) << "parameter name is duplicate:" << param_name;
497       return false;
498     }
499     (void)nodeName_.insert(param_name);
500   }
501   return true;
502 }
503 
SetQuantizationParamToAttrProto(const std::shared_ptr<QuantizationParam> & quantization_param,mind_ir::TensorProto_QuantParamProto * const quant_param_proto)504 bool IrExportBuilder::SetQuantizationParamToAttrProto(const std::shared_ptr<QuantizationParam> &quantization_param,
505                                                       mind_ir::TensorProto_QuantParamProto *const quant_param_proto) {
506   quant_param_proto->set_quant_algo_name(quantization_param->quant_algo_name());
507   auto quant_param_attrs = quantization_param->attrs();
508   for (auto &quant_param_attr : quant_param_attrs) {
509     if (quant_param_attr.second == nullptr) {
510       MS_LOG(ERROR) << "attr: " << quant_param_attr.first << " has no value.";
511       continue;
512     }
513     auto attr_proto = quant_param_proto->add_attribute();
514     attr_proto->set_name(quant_param_attr.first);
515     auto value_ptr = quant_param_attr.second;
516     auto ret = SetValueToAttributeProto(value_ptr, attr_proto);
517     if (!ret) {
518       MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
519       return false;
520     }
521   }
522   return true;
523 }
524 
SetFunctorToAttrProto(const FunctorPtr & func,mind_ir::AttributeProto * const attr_proto)525 bool IrExportBuilder::SetFunctorToAttrProto(const FunctorPtr &func, mind_ir::AttributeProto *const attr_proto) {
526   auto *functor_proto = attr_proto->mutable_functor();
527   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FUNCTOR);
528   if (func->isa<ShapeCalcBaseFunctor>()) {
529     functor_proto->set_type(mind_ir::FunctorProto_FunctorType_SHAPE_CALC_FUNCTOR);
530   } else {
531     MS_LOG(ERROR) << "Unknown functor: " << func->ToString();
532     return false;
533   }
534   functor_proto->set_name(func->name());
535   auto values = func->ToValue();
536   if (values == nullptr) {
537     values = kNone;
538   }
539   if (!SetValueToAttributeProto(values, functor_proto->add_values())) {
540     return false;
541   }
542   return true;
543 }
544 
GetMindirDataType(TypeId type_id) const545 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) const {
546   auto iter = g_data_type_map.find(type_id);
547   if (iter == g_data_type_map.end()) {
548     MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id;
549     return mind_ir::TensorProto_DataType_UNDEFINED;
550   }
551   return iter->second;
552 }
553 
GetMindirDataBitsIntType(int bits) const554 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) const {
555   auto iter = g_data_bits_int_map.find(bits);
556   if (iter == g_data_bits_int_map.end()) {
557     MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits;
558     return mind_ir::TensorProto_DataType_UNDEFINED;
559   }
560   return iter->second;
561 }
562 
GetMindirDataBitsUIntType(int bits) const563 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) const {
564   auto iter = g_data_bits_uint_map.find(bits);
565   if (iter == g_data_bits_uint_map.end()) {
566     MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits;
567     return mind_ir::TensorProto_DataType_UNDEFINED;
568   }
569   return iter->second;
570 }
571 
GetMindirDataBitsFloatType(int bits) const572 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) const {
573   auto iter = g_data_bits_float_map.find(bits);
574   if (iter == g_data_bits_float_map.end()) {
575     MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
576     return mind_ir::TensorProto_DataType_UNDEFINED;
577   }
578   return iter->second;
579 }
580 
GetMindirDataBitsBFloatType(int bits) const581 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsBFloatType(int bits) const {
582   auto iter = g_data_bits_bfloat_map.find(bits);
583   if (iter == g_data_bits_bfloat_map.end()) {
584     MS_LOG(ERROR) << "Convert bits bfloat error, unsupported bits! " << bits;
585     return mind_ir::TensorProto_DataType_UNDEFINED;
586   }
587   return iter->second;
588 }
589 
GetMindirDataBitsComplexType(int bits) const590 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsComplexType(int bits) const {
591   auto iter = g_data_bits_complex_map.find(bits);
592   if (iter == g_data_bits_complex_map.end()) {
593     MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
594     return mind_ir::TensorProto_DataType_UNDEFINED;
595   }
596   return iter->second;
597 }
598 
SetValueInfoProto(const AnfNodePtr & node,mind_ir::ValueInfoProto * const value_proto)599 bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
600   if (node == nullptr || value_proto == nullptr) {
601     MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
602   }
603   MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
604   const TypePtr &type = node->Type();
605   const BaseShapePtr &shape = node->Shape();
606   // For the bprop fg which has not been renormalized.
607   if (type == nullptr || shape == nullptr) {
608     return true;
609   }
610   if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
611     mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
612     if (!SetTensorProto(node->abstract(), tensor_proto)) {
613       return false;
614     }
615   } else {
616     mind_ir::AttributeProto *attribute = value_proto->mutable_attr_info();
617     if (!SetAbstractToNodeProto(node->abstract(), attribute)) {
618       MS_LOG(ERROR) << "Set shape to Proto for " << node->DebugString() << " failed.";
619       return false;
620     }
621     value_proto->set_denotation(type->type_name());
622   }
623   MS_LOG(DEBUG) << "Value type: " << type->type_name();
624   return true;
625 }
626 
SetTensorToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)627 bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
628   if (value == nullptr || attr_proto == nullptr) {
629     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
630   }
631   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
632   mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
633   tensor_proto->set_name("value0");
634   auto data = value->cast<tensor::TensorPtr>();
635   MS_EXCEPTION_IF_NULL(data);
636   tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
637   auto dtype = data->data_type();
638   auto shape = data->shape_c();
639   auto data_type = GetMindirDataType(dtype);
640   if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
641     return false;
642   }
643   tensor_proto->set_data_type(data_type);
644   for (const auto &dim : shape) {
645     tensor_proto->add_dims(dim);
646   }
647   return true;
648 }
649 
SetCSRTensorToProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)650 bool IrExportBuilder::SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
651   abstract::AbstractCSRTensorPtr csr_tensor_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
652   MS_EXCEPTION_IF_NULL(csr_tensor_abs);
653   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_CSR_TENSOR);
654   mind_ir::AttributeProto *indptr = attr_proto->add_values();
655   bool res = SetAbstractToNodeProto(csr_tensor_abs->indptr(), indptr);
656   mind_ir::AttributeProto *indices = attr_proto->add_values();
657   res = res && SetAbstractToNodeProto(csr_tensor_abs->indices(), indices);
658   mind_ir::AttributeProto *values = attr_proto->add_values();
659   res = res && SetAbstractToNodeProto(csr_tensor_abs->values(), values);
660   mind_ir::AttributeProto *shape = attr_proto->add_values();
661   res = res && SetAbstractToNodeProto(csr_tensor_abs->shape(), shape);
662   return res;
663 }
664 
SetCOOTensorToProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)665 bool IrExportBuilder::SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
666   abstract::AbstractCOOTensorPtr coo_tensor_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
667   MS_EXCEPTION_IF_NULL(coo_tensor_abs);
668   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_COO_TENSOR);
669   mind_ir::AttributeProto *indices = attr_proto->add_values();
670   bool res = SetAbstractToNodeProto(coo_tensor_abs->indices(), indices);
671   mind_ir::AttributeProto *values = attr_proto->add_values();
672   res = res && SetAbstractToNodeProto(coo_tensor_abs->values(), values);
673   mind_ir::AttributeProto *shape = attr_proto->add_values();
674   res = res && SetAbstractToNodeProto(coo_tensor_abs->shape(), shape);
675   return res;
676 }
677 
SetTensorProto(const AbstractBasePtr & abstract,mind_ir::TensorProto * const tensor_proto)678 bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
679   auto type = abstract->BuildType();
680   auto shape = abstract->BuildShape();
681   if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
682     MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
683     return false;
684   }
685   auto tensor = type->cast<TensorTypePtr>();
686   auto tensor_shape = shape->cast<abstract::ShapePtr>();
687   const auto &dims = tensor_shape->shape();
688   auto data_type = GetMindirDataType(tensor->element()->type_id());
689   if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
690     return false;
691   }
692   tensor_proto->set_data_type(data_type);
693   for (const auto &dim : dims) {
694     tensor_proto->add_dims(dim);
695   }
696 
697   if (!abstract->name().empty()) {
698     tensor_proto->set_name(abstract->name());
699   }
700   // Deal Ref
701   if (!type->isa<RefType>()) {
702     return true;
703   }
704 
705   auto abs_ref = abstract->cast<abstract::AbstractRefPtr>();
706   if (abs_ref == nullptr) {
707     MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRefTensor.";
708     return false;
709   }
710   auto ref_key_value = abs_ref->ref_key_value()->cast<StringImmPtr>();
711   if (ref_key_value == nullptr) {
712     MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
713     return true;
714   }
715   tensor_proto->set_ref_key(ref_key_value->value());
716   return true;
717 }
718 
SetParamToTensorProto(const ParameterPtr & param,mind_ir::TensorProto * const tensor_proto)719 bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
720   if (param == nullptr || tensor_proto == nullptr) {
721     MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
722   }
723   MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
724   if (!SetTensorProto(param->abstract(), tensor_proto)) {
725     MS_LOG(ERROR) << "Export Parameter to tensor proto failed.";
726     return false;
727   }
728   // export quant parameter info
729   auto tensor = param->default_param()->cast<tensor::TensorPtr>();
730   if (tensor != nullptr) {
731     tensor_proto->set_compression_type(static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
732     auto quant_params = tensor->quant_params();
733     for (const auto &quant_param : quant_params) {
734       auto quant_param_proto = tensor_proto->add_quant_params();
735       auto ret = SetQuantizationParamToAttrProto(quant_param, quant_param_proto);
736       if (ret != true) {
737         MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
738         return false;
739       }
740     }
741   }
742   return true;
743 }
744 
ConvertMapParameterToMapTensorProto(const ParameterPtr & map_parameter,mind_ir::MapTensorProto * const map_tensor_proto)745 bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
746                                                           mind_ir::MapTensorProto *const map_tensor_proto) {
747   if (map_parameter == nullptr || map_tensor_proto == nullptr) {
748     MS_LOG(EXCEPTION) << "MapParameter or MapTensorProto is null!";
749   }
750   MS_LOG(DEBUG) << "ConvertMapParameterToMapTensorProto: " << map_parameter->ToString();
751 
752   // parameter name
753   map_tensor_proto->set_name(GetUniqueNodeName(map_parameter));
754 
755   auto param_default = map_parameter->default_param();
756   MS_EXCEPTION_IF_NULL(param_default);
757   auto map_tensor = param_default->cast<tensor::MapTensorPtr>();
758   MS_EXCEPTION_IF_NULL(map_tensor);
759   // default value
760   auto default_value = map_tensor->default_value();
761   MS_EXCEPTION_IF_NULL(default_value);
762   auto *default_value_proto = map_tensor_proto->mutable_default_value();
763   MS_EXCEPTION_IF_NULL(default_value_proto);
764   if (!SetValueToAttributeProto(default_value, default_value_proto)) {
765     MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
766     return false;
767   }
768   tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
769   // key_tensor
770   auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
771   MS_EXCEPTION_IF_NULL(key_tensor_proto);
772   auto &key_tensor = export_data.key_tensor;
773   MS_EXCEPTION_IF_NULL(key_tensor);
774   if (!SetTensorProto(key_tensor->ToAbstract(), key_tensor_proto)) {
775     MS_LOG(ERROR) << "Export key tensor of MapTensor failed, key_tensor: " << key_tensor->ToString();
776     return false;
777   }
778   // value_tensor
779   auto *value_tensor_proto = map_tensor_proto->mutable_value_tensor();
780   MS_EXCEPTION_IF_NULL(value_tensor_proto);
781   auto &value_tensor = export_data.value_tensor;
782   MS_EXCEPTION_IF_NULL(value_tensor);
783   if (!SetTensorProto(value_tensor->ToAbstract(), value_tensor_proto)) {
784     MS_LOG(ERROR) << "Export value tensor of MapTensor failed, value_tensor: " << value_tensor->ToString();
785     return false;
786   }
787   // status_tensor
788   auto *status_tensor_proto = map_tensor_proto->mutable_status_tensor();
789   MS_EXCEPTION_IF_NULL(status_tensor_proto);
790   auto &status_tensor = export_data.status_tensor;
791   MS_EXCEPTION_IF_NULL(status_tensor);
792   if (!SetTensorProto(status_tensor->ToAbstract(), status_tensor_proto)) {
793     MS_LOG(ERROR) << "Export status tensor of MapTensor failed, status_tensor: " << status_tensor->ToString();
794     return false;
795   }
796   return true;
797 }
798 
ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)799 bool IrExportBuilder::ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract,
800                                                           mind_ir::AttributeProto *const attr_proto) {
801   auto map_tensor_abs = abstract->cast<abstract::AbstractMapTensorPtr>();
802   MS_EXCEPTION_IF_NULL(map_tensor_abs);
803 
804   auto map_tensor_type = map_tensor_abs->map_tensor_type();
805   MS_EXCEPTION_IF_NULL(map_tensor_type);
806   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_MAP_TENSOR);
807   // key_tensor
808   auto key_dtype = map_tensor_type->key_dtype();
809   auto key_shape = {abstract::Shape::kShapeDimAny};
810   auto key_tensor_abs = std::make_shared<abstract::AbstractTensor>(key_dtype, key_shape);
811   auto *key_tensor_proto = attr_proto->add_tensors();
812   MS_EXCEPTION_IF_NULL(key_tensor_proto);
813   MS_EXCEPTION_IF_NULL(key_tensor_abs);
814   if (!SetTensorProto(key_tensor_abs, key_tensor_proto)) {
815     MS_LOG(ERROR) << "Export key tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
816                   << abstract->ToString();
817     return false;
818   }
819   // value_dtype value_shape
820   auto value_dtype = map_tensor_type->key_dtype();
821   auto value_shape = map_tensor_abs->value_shape()->shape();
822   auto value_tensor_abs = std::make_shared<abstract::AbstractTensor>(value_dtype, value_shape);
823   auto *value_tensor_proto = attr_proto->add_tensors();
824   MS_EXCEPTION_IF_NULL(value_tensor_proto);
825   MS_EXCEPTION_IF_NULL(value_tensor_abs);
826   if (!SetTensorProto(value_tensor_abs, value_tensor_proto)) {
827     MS_LOG(ERROR) << "Export value tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
828                   << abstract->ToString();
829     return false;
830   }
831   // default_value
832   auto default_value = map_tensor_abs->default_value();
833   if (default_value != nullptr) {
834     auto *default_value_proto = attr_proto->add_values();
835     MS_EXCEPTION_IF_NULL(default_value_proto);
836     if (!SetValueToAttributeProto(default_value, default_value_proto)) {
837       MS_LOG(ERROR) << "Export default value of AbstractMapTensor failed, abstract_map_tensor: "
838                     << abstract->ToString();
839       return false;
840     }
841   }
842   return true;
843 }
844 
BuildNodes(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)845 bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
846   std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
847   for (const AnfNodePtr &node : nodes) {
848     MS_EXCEPTION_IF_NULL(node);
849     if (!node->isa<CNode>()) {
850       MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
851       continue;
852     }
853     if (is_kernel_graph_ && (node_name_map_.find(node) != node_name_map_.end() || node->func_graph() != func_graph)) {
854       continue;
855     }
856     auto cnode = node->cast<CNodePtr>();
857     if (cnode == func_graph->get_return()) {
858       if (!BuildOutput(cnode, graph_proto)) {
859         MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed.";
860         return false;
861       }
862     } else {
863       auto iter = graph_protos_.find(node->func_graph());
864       if (iter == graph_protos_.end()) {
865         MS_LOG(ERROR) << "Can not find the graph proto of func_graph " << node->func_graph()->ToString();
866         return false;
867       }
868       auto owner_graph_proto = iter->second;
869       if (!BuildCNode(cnode, owner_graph_proto)) {
870         MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed.";
871         return false;
872       }
873     }
874   }
875   return true;
876 }
877 
BuildIsolatedCNode(const AnfNodePtr & node,std::set<AnfNodePtr> * visited)878 bool IrExportBuilder::BuildIsolatedCNode(const AnfNodePtr &node, std::set<AnfNodePtr> *visited) {
879   MS_EXCEPTION_IF_NULL(node);
880   auto iter = node_name_map_.find(node);
881   if (iter != node_name_map_.end()) {
882     return true;
883   }
884   MS_EXCEPTION_IF_NULL(visited);
885   if (visited->find(node) != visited->end()) {
886     MS_LOG(ERROR) << "There is a cycle when build node " << node->DebugString();
887     return false;
888   }
889   if (!node->isa<CNode>()) {
890     return false;
891   }
892   const auto &cnode = node->cast<CNodePtr>();
893   MS_EXCEPTION_IF_NULL(cnode);
894   const auto &graph = cnode->func_graph();
895   if (!graph) {
896     MS_LOG(ERROR) << "The isolated node " << node->DebugString() << " is not belongs to any graph.";
897     return false;
898   }
899   auto graph_proto = graph_protos_[graph];
900   auto input_size = cnode->size();
901   std::vector<string> input_names;
902   // build input nodes
903   for (size_t i = 1; i < input_size; i++) {
904     auto input = cnode->input(i);
905     MS_EXCEPTION_IF_NULL(input);
906     if (input->isa<Parameter>()) {
907       MS_LOG(ERROR) << "Only support that the isolated node's input is cnode or value_node, but the input is "
908                     << input->DebugString();
909       return false;
910     }
911     std::string node_name;
912     if (input->isa<ValueNode>()) {
913       auto input_graph = input->func_graph();
914       auto input_proto = input_graph ? graph_protos_[input_graph] : graph_proto;
915       MS_EXCEPTION_IF_NULL(input_proto);
916       node_name = BuildInputNode(input, input_proto);
917     } else {
918       if (!BuildIsolatedCNode(input, visited)) {
919         MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
920         return false;
921       }
922       node_name = GetUniqueNodeName(input);
923     }
924     if (node_name.empty()) {
925       MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
926       return false;
927     }
928     input->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
929     input_names.push_back(node_name);
930   }
931   // build cnode
932   auto output_name = GetUniqueNodeName(cnode);
933   cnode->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(output_name));
934   if (nodeName_.count(output_name) > 0) {
935     MS_LOG(INFO) << "There is a duplicate name: " << output_name;
936     return true;
937   }
938   mind_ir::NodeProto *node_proto = graph_proto->add_node();
939   (void)nodeName_.insert(output_name);
940   node_proto->add_output(output_name);
941   node_proto->set_name(output_name);
942   node_proto->set_domain(cnode->fullname_with_scope());
943   AnfNodePtr op = cnode->input(0);
944   std::string type_name = GetOpTypeName(op);
945   if (type_name.empty()) {
946     MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
947     return false;
948   }
949   node_proto->set_op_type(type_name);
950   if (!SetAbstractToNodeProto(cnode, node_proto)) {
951     MS_LOG(DEBUG) << "Fail to export abstract of the node: " << node->DebugString();
952   }
953   (void)std::for_each(input_names.begin(), input_names.end(),
954                       [&node_proto](const string &name) { node_proto->add_input(name); });
955   if (!BuildCNodeAttr(cnode, node_proto)) {
956     MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
957     return false;
958   }
959   (void)(visited->insert(node));
960   return true;
961 }
962 
BuildIsolatedNodes(const std::vector<AnfNodePtr> & isolated_nodes)963 bool IrExportBuilder::BuildIsolatedNodes(const std::vector<AnfNodePtr> &isolated_nodes) {
964   for (const auto &node : isolated_nodes) {
965     if (!node->isa<CNode>()) {
966       MS_LOG(ERROR) << "Only support that the isolated node is cnode, but the node is " << node->DebugString();
967       return false;
968     }
969     if (mindspore::IsPrimitiveCNode(node, mindspore::prim::kPrimReturn)) {
970       MS_LOG(ERROR) << "Only support that the isolated node is not return node, but the node is "
971                     << node->DebugString();
972       return false;
973     }
974     std::set<AnfNodePtr> visited;
975     if (!BuildIsolatedCNode(node, &visited)) {
976       MS_LOG(ERROR) << "Build isolated node " << node->DebugString() << " failed.";
977       return false;
978     }
979   }
980 
981   return true;
982 }
983 
BuildOutput(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)984 bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
985   MS_EXCEPTION_IF_NULL(node);
986   MS_EXCEPTION_IF_NULL(graph_proto);
987   const int OutputSize = 2;
988   if (node->size() != OutputSize) {
989     MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2.";
990     return false;
991   }
992   auto graph_name = graph_proto->name();
993   node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(graph_name + kReturnNode));
994   AnfNodePtr arg = node->input(1);
995   auto node_name = BuildInputNode(arg, graph_proto);
996   arg->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
997   if (node_name.empty()) {
998     MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString();
999     return false;
1000   }
1001   mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
1002   output_proto->set_name(node_name);
1003   // for return node primitive export
1004   AnfNodePtr op = node->input(0);
1005   op->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(graph_name + kReturnPrimNode));
1006   return SetValueInfoProto(arg, output_proto);
1007 }
1008 
GetOpTypeName(const AnfNodePtr & node)1009 std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
1010   // May be ValueNode/CNode/Parameter
1011   std::string type_name = "";
1012   if (IsValueNode<Primitive>(node)) {
1013     PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
1014     MS_EXCEPTION_IF_NULL(prim);
1015     auto do_sign_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
1016     if (do_sign_prim != nullptr && do_sign_prim->function() != nullptr &&
1017         do_sign_prim->function()->isa<MetaFuncGraph>()) {
1018       type_name = "REF::MetaFuncGraph::" + do_sign_prim->function()->cast_ptr<MetaFuncGraph>()->name();
1019     } else {
1020       const auto &unique_name = GetPrimitiveUniqueName(prim);
1021       type_name = "REF::" + unique_name;
1022       // for valuenode export
1023       node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(unique_name));
1024     }
1025   } else if (IsValueNode<FuncGraph>(node)) {
1026     FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
1027     MS_EXCEPTION_IF_NULL(fg);
1028     todo_.push_back(fg);
1029     type_name = "REF::" + fg->ToString();
1030   } else if (node->isa<CNode>() || node->isa<Parameter>()) {
1031     auto nodeName = GetUniqueNodeName(node);
1032     type_name = "REF::" + nodeName;
1033     if (nodeName_.count(nodeName) == 0) {
1034       MS_LOG(ERROR) << "There is not the name: " << nodeName;
1035       return "";
1036     }
1037   } else if (IsValueNode<MindIRClassType>(node)) {
1038     auto class_type = GetValueNode<MindIRClassTypePtr>(node)->name();
1039     // class 'XXX' -> XXX
1040     constexpr int64_t path_begin_index = 7;
1041     auto str = std::string(class_type.begin() + path_begin_index, class_type.end() - 1);
1042     type_name = "REF::ClassType::" + str;
1043   } else if (IsValueNode<MetaFuncGraph>(node)) {
1044     auto meta_fg = GetValueNode<MetaFuncGraphPtr>(node);
1045     MS_EXCEPTION_IF_NULL(meta_fg);
1046     type_name = "REF::MetaFuncGraph::" + meta_fg->name();
1047   } else {
1048     MS_LOG(ERROR) << "Need to support op type: " << node->DebugString();
1049     return "";
1050   }
1051   MS_LOG(DEBUG) << "ExportType: " << type_name;
1052   return type_name;
1053 }
1054 
ExportSequence(const abstract::AbstractSequencePtr & seq_abs,mind_ir::AttributeProto * const attr_proto)1055 bool IrExportBuilder::ExportSequence(const abstract::AbstractSequencePtr &seq_abs,
1056                                      mind_ir::AttributeProto *const attr_proto) {
1057   if (seq_abs->isa<abstract::AbstractTuple>()) {
1058     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
1059   } else {
1060     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_LIST);
1061   }
1062   auto seq_info_proto = attr_proto->mutable_seq_info();
1063   seq_info_proto->set_is_dyn_len(seq_abs->dynamic_len());
1064 
1065   auto elem_abs = seq_abs->dynamic_len_element_abs();
1066   if (elem_abs != nullptr) {
1067     mind_ir::AttributeProto *tuple_elem_proto = seq_info_proto->mutable_tuple_elem_item();
1068     if (!SetAbstractToNodeProto(elem_abs, tuple_elem_proto)) {
1069       return false;
1070     }
1071   }
1072 
1073   const auto &elems = seq_abs->elements();
1074   for (const auto &item : elems) {
1075     mind_ir::AttributeProto *attr_values = attr_proto->add_values();
1076     if (!SetAbstractToNodeProto(item, attr_values)) {
1077       return false;
1078     }
1079   }
1080   return true;
1081 }
1082 
SetAbstractToNodeProto(const AbstractBasePtr & abs,mind_ir::AttributeProto * const attr_proto)1083 bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto) {
1084   auto type = abs->BuildType();
1085   auto shape = abs->BuildShape();
1086   // Not use abstract because the abstract of csr tensor is a subclass of AbstractTuple
1087   if (type->isa<Tuple>() || type->isa<List>()) {
1088     return ExportSequence(abs->cast<abstract::AbstractSequencePtr>(), attr_proto);
1089   } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
1090     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1091     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1092     return SetTensorProto(abs, tensor_proto);
1093   } else if (type->isa<Number>()) {
1094     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_SCALAR);
1095     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1096     auto data_type = GetMindirDataType(type->type_id());
1097     tensor_proto->set_data_type(data_type);
1098     tensor_proto->add_dims(0);
1099   } else if (type->isa<Function>()) {
1100     if (!SetAbstractFuncToAttributeProto(abs, attr_proto)) {
1101       return false;
1102     }
1103   } else if (type->isa<String>()) {
1104     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
1105   } else if (type->isa<UMonadType>()) {
1106     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UMONAD);
1107   } else if (type->isa<IOMonadType>()) {
1108     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_IOMONAD);
1109   } else if (type->isa<CSRTensorType>()) {
1110     auto csr_tensor_abs = abs->cast<abstract::AbstractCSRTensorPtr>();
1111     if (!SetCSRTensorToProto(csr_tensor_abs, attr_proto)) {
1112       return false;
1113     }
1114   } else if (type->isa<COOTensorType>()) {
1115     auto coo_tensor_abs = abs->cast<abstract::AbstractCOOTensorPtr>();
1116     if (!SetCOOTensorToProto(coo_tensor_abs, attr_proto)) {
1117       return false;
1118     }
1119   } else if (type->isa<TypeNone>()) {
1120     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
1121   } else if (type->isa<MapTensorType>()) {
1122     return ConvertAbstractMapTensorToAttrProto(abs, attr_proto);
1123   } else {
1124     MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
1125     return false;
1126   }
1127 
1128   return true;
1129 }
1130 
SetAbstractToNodeProto(const CNodePtr & node,mind_ir::NodeProto * const node_proto)1131 bool IrExportBuilder::SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
1132   // Get shape of cnode
1133   // 1. need to get shape from tuple element
1134   // 2. save shape in TensorProto
1135   MS_EXCEPTION_IF_NULL(node);
1136   auto type = node->Type();
1137   auto shape = node->Shape();
1138   auto abs = node->abstract();
1139   // For the bprop fg which has not been renormalized.
1140   if (type == nullptr || shape == nullptr) {
1141     return true;
1142   }
1143   mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
1144   if (!SetAbstractToNodeProto(abs, attr_proto)) {
1145     MS_LOG(WARNING) << "Set shape to NodeProto for " << node->DebugString() << " failed. abs: " << abs->ToString();
1146     return false;
1147   }
1148   attr_proto->set_name("shape");
1149   return true;
1150 }
1151 
BuildCNode(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)1152 bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
1153   auto inputs_size = node->size();
1154   if (inputs_size < 1) {
1155     MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty";
1156     return false;
1157   }
1158 
1159   // Need to build input node before dealing with cnode
1160   std::vector<string> input_names;
1161   for (size_t i = 1; i < inputs_size; i++) {
1162     auto input = node->input(i);
1163     std::string node_name = BuildInputNode(input, graph_proto);
1164     input->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
1165     if (node_name.empty()) {
1166       MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
1167       return false;
1168     }
1169     input_names.push_back(node_name);
1170   }
1171 
1172   // Build cnode
1173   std::string output_name = GetUniqueNodeName(node);
1174   node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(output_name));
1175   if (nodeName_.count(output_name) > 0) {
1176     MS_LOG(INFO) << "There is a duplicate name: " << output_name;
1177     return true;
1178   }
1179 
1180   mind_ir::NodeProto *node_proto = graph_proto->add_node();
1181   (void)nodeName_.insert(output_name);
1182   node_proto->add_output(output_name);
1183   node_proto->set_name(output_name);
1184   node_proto->set_domain(node->fullname_with_scope());
1185   AnfNodePtr op = node->input(0);
1186   std::string type_name = GetOpTypeName(op);
1187   if (type_name.empty()) {
1188     MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
1189     return false;
1190   }
1191   node_proto->set_op_type(type_name);
1192   last_node_ = node_proto;
1193   if (!SetAbstractToNodeProto(node, node_proto)) {
1194     MS_LOG(DEBUG) << "Fail to export abstract of the node: " << node->DebugString();
1195   }
1196 
1197   (void)std::for_each(input_names.begin(), input_names.end(),
1198                       [&node_proto](const string &name) { node_proto->add_input(name); });
1199 
1200   if (!BuildCNodeAttr(node, node_proto)) {
1201     MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
1202     return false;
1203   }
1204   return true;
1205 }
1206 
BuildValueNode(const ValueNodePtr & node,const string & node_name,mind_ir::GraphProto * const graph_proto)1207 bool IrExportBuilder::BuildValueNode(const ValueNodePtr &node, const string &node_name,
1208                                      mind_ir::GraphProto *const graph_proto) {
1209   // FuncGraphNode don't need to be exported to the proto in this step
1210   // check the node has been exported before
1211   if (IsValueNode<FuncGraph>(node) || nodeName_.count(node_name) > 0) {
1212     return true;
1213   }
1214   (void)nodeName_.insert(node_name);
1215   // When node input is a ValueNode, need to create a Constant Node
1216   mind_ir::NodeProto *node_proto = graph_proto->add_node();
1217   node_proto->set_name(node_name);
1218   node_proto->add_output(node_name);
1219   if (!SetAttributeProto(node, node_proto)) {
1220     return false;
1221   }
1222   return true;
1223 }
1224 
BuildInputNode(const AnfNodePtr & node,mind_ir::GraphProto * const graph_proto)1225 std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
1226   std::string node_name = GetUniqueNodeName(node);
1227   if (node->isa<ValueNode>()) {
1228     if (!BuildValueNode(node->cast<ValueNodePtr>(), node_name, graph_proto)) {
1229       MS_LOG(ERROR) << "Export ValueNode Failed";
1230       return "";
1231     }
1232     MS_LOG(DEBUG) << "Export ValueNode " << node->DebugString() << " success";
1233   }
1234   return node_name;
1235 }
1236 
GetUniqueNodeName(const AnfNodePtr & node)1237 std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
1238   // Naming anfnode
1239   // 1. parameter is unique in one func_graph
1240   // 2. cnode and valuenode may be reduplicative, so add index to identify.
1241   auto iter = node_name_map_.find(node);
1242   if (iter != node_name_map_.end()) {
1243     return iter->second;
1244   }
1245   // FuncGraph will be added to functions and the input name is the function name.
1246   if (IsValueNode<FuncGraph>(node)) {
1247     FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
1248     todo_.push_back(fg);
1249     auto name = fg->ToString();
1250     node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(name));
1251     return name;
1252   }
1253 
1254   std::string node_name = GetNodeName(node);
1255   // Compatible before. CNode = FuncGraphName:CNodeName:index ,Parameter = FuncGraphName:ParameterName
1256   if (node->isa<CNode>()) {
1257     node_name = node_name + ":" + std::to_string(GetUniqueID());
1258   }
1259   // Avoid duplicate name.
1260   while (nodeName_.count(node_name) > 0) {
1261     node_name = node_name + "_" + std::to_string(GetUniqueID());
1262   }
1263   node_name_map_[node] = node_name;
1264   return node_name;
1265 }
1266 
GetNodeName(const AnfNodePtr & node) const1267 std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) const {
1268   MS_EXCEPTION_IF_NULL(node);
1269   std::string node_name = "";
1270   if (node->func_graph() != nullptr) {
1271     node_name = node->func_graph()->ToString() + ":";
1272   }
1273   if (node->isa<ValueNode>()) {
1274     // Needn't value
1275     node_name += node->AnfNode::ToString();
1276   } else {
1277     node_name += node->ToString();
1278   }
1279   MS_LOG(DEBUG) << "GetNodeName: " << node_name;
1280   return node_name;
1281 }
1282 
SetAttributeProto(const AnfNodePtr & node,mind_ir::NodeProto * const node_proto)1283 bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
1284   if (node == nullptr || node_proto == nullptr) {
1285     MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
1286   }
1287   auto value_node = node->cast<ValueNodePtr>();
1288   MS_EXCEPTION_IF_NULL(value_node);
1289   auto value = value_node->value();
1290   MS_EXCEPTION_IF_NULL(value);
1291   node_proto->set_op_type("Constant");
1292   mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
1293   attr_proto->set_name("value");
1294   MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
1295   return SetValueToAttributeProto(value, attr_proto);
1296 }
1297 
SetTensorTypeToAttributeProto(const ValuePtr & value,mind_ir::TensorProto * tensor_proto)1298 bool IrExportBuilder::SetTensorTypeToAttributeProto(const ValuePtr &value, mind_ir::TensorProto *tensor_proto) {
1299   tensor_proto->set_name("tensor0");
1300   auto elem_type = value->cast<TensorTypePtr>()->element();
1301   if (elem_type->isa<Int>()) {
1302     auto int_value = elem_type->cast<IntPtr>();
1303     auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1304     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1305       return false;
1306     }
1307     tensor_proto->set_data_type(data_type);
1308   } else if (elem_type->isa<Float>()) {
1309     auto float_value = elem_type->cast<FloatPtr>();
1310     auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1311     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1312       return false;
1313     }
1314     tensor_proto->set_data_type(data_type);
1315   } else {
1316     MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name();
1317     return false;
1318   }
1319   return true;
1320 }
1321 
SetTypeToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1322 bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1323   if (value == nullptr || attr_proto == nullptr) {
1324     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
1325   }
1326   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1327   mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1328   if (value->isa<Int>()) {
1329     tensor_proto->set_name("value0");
1330     auto int_value = value->cast<IntPtr>();
1331     auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1332     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1333       return false;
1334     }
1335     tensor_proto->set_data_type(data_type);
1336   } else if (value->isa<UInt>()) {
1337     tensor_proto->set_name("value0");
1338     auto float_value = value->cast<UIntPtr>();
1339     auto data_type = GetMindirDataBitsUIntType(float_value->nbits());
1340     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1341       return false;
1342     }
1343     tensor_proto->set_data_type(data_type);
1344   } else if (value->isa<Float>()) {
1345     tensor_proto->set_name("value0");
1346     auto float_value = value->cast<FloatPtr>();
1347     auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1348     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1349       return false;
1350     }
1351     tensor_proto->set_data_type(data_type);
1352   } else if (value->isa<BFloat>()) {
1353     tensor_proto->set_name("value0");
1354     auto bfloat_value = value->cast<BFloatPtr>();
1355     MS_EXCEPTION_IF_NULL(bfloat_value);
1356     auto data_type = GetMindirDataBitsBFloatType(bfloat_value->nbits());
1357     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1358       return false;
1359     }
1360     tensor_proto->set_data_type(data_type);
1361   } else if (value->isa<Complex>()) {
1362     tensor_proto->set_name("value0");
1363     auto complex_value = value->cast<ComplexPtr>();
1364     auto data_type = GetMindirDataBitsComplexType(complex_value->nbits());
1365     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1366       return false;
1367     }
1368     tensor_proto->set_data_type(data_type);
1369   } else if (value->isa<Bool>()) {
1370     tensor_proto->set_name("value0");
1371     tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
1372   } else if (value->isa<TensorType>()) {
1373     return SetTensorTypeToAttributeProto(value, tensor_proto);
1374   } else {
1375     MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
1376   }
1377   return true;
1378 }
1379 
SetNamedValueToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1380 bool IrExportBuilder::SetNamedValueToAttributeProto(const ValuePtr &value,
1381                                                     mind_ir::AttributeProto *const attr_proto) const {
1382   if (value->isa<None>()) {
1383     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
1384     MS_LOG(DEBUG) << "Attr string: " << value->type_name();
1385   } else if (value->isa<MindIRClassType>()) {
1386     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_CLASS_TYPE);
1387     auto class_type = GetValue<MindIRClassTypePtr>(value)->name();
1388     // class 'XXX' -> XXX
1389     constexpr int64_t path_begin_index = 7;
1390     auto str = std::string(class_type.begin() + path_begin_index, class_type.end() - 1);
1391     attr_proto->set_s(str);
1392   } else if (value->isa<MindIRNameSpace>()) {
1393     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NAME_SPACE);
1394     attr_proto->set_s(GetValue<MindIRNameSpacePtr>(value)->name_space());
1395   } else if (value->isa<MindIRSymbol>()) {
1396     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_SYMBOL);
1397     attr_proto->set_s(GetValue<MindIRSymbolPtr>(value)->symbol());
1398   } else {
1399     MS_LOG(ERROR) << "Unsupported named type: " << value->type_name();
1400     return false;
1401   }
1402   return true;
1403 }
1404 
SetValueToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1405 bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1406   if (value == nullptr) {
1407     MS_LOG(ERROR) << "Value is null.";
1408     return false;
1409   }
1410   MS_EXCEPTION_IF_NULL(attr_proto);
1411   if (value->isa<StringImm>() || value->isa<Scalar>()) {
1412     return SetScalarToAttributeProto_ir(value, attr_proto);
1413   } else if (value->isa<Number>() || value->isa<TensorType>()) {
1414     return SetTypeToAttributeProto(value, attr_proto);
1415   } else if (value->isa<ValueSequence>()) {
1416     if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), attr_proto)) {
1417       MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
1418       return false;
1419     }
1420     MS_LOG(DEBUG) << "Attr string: " << value->type_name();
1421   } else if (value->isa<ValueDictionary>()) {
1422     if (!SetDictToAttributeProto(value->cast<ValueDictionaryPtr>(), attr_proto)) {
1423       MS_LOG(ERROR) << "Set dictionary to AttributeProto failed.";
1424       return false;
1425     }
1426     MS_LOG(DEBUG) << "Attr string: " << value->type_name();
1427   } else if (value->isa<tensor::Tensor>()) {
1428     return SetTensorToAttributeProto(value, attr_proto);
1429   } else if (value->isa<Named>()) {
1430     return SetNamedValueToAttributeProto(value, attr_proto);
1431   } else if (value->isa<TypeNull>()) {
1432     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TYPE_NULL);
1433     MS_LOG(DEBUG) << "Attr string: " << value->type_name();
1434   } else if (value->isa<Monad>()) {
1435     if (value->isa<UMonad>()) {
1436       attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UMONAD);
1437     } else if (value->isa<IOMonad>()) {
1438       attr_proto->set_type(mind_ir::AttributeProto_AttributeType_IOMONAD);
1439     } else {
1440       MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name();
1441       return false;
1442     }
1443   } else if (value->isa<QuantizationParam>()) {
1444     auto quantization_param = value->cast<std::shared_ptr<QuantizationParam>>();
1445     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1446     auto tensor_proto = attr_proto->add_tensors();
1447     tensor_proto->set_name(attr_proto->name());
1448     auto quant_param_proto = tensor_proto->add_quant_params();
1449     auto ret = SetQuantizationParamToAttrProto(quantization_param, quant_param_proto);
1450     if (ret != true) {
1451       MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
1452       return false;
1453     }
1454   } else if (value->isa<Functor>()) {
1455     return SetFunctorToAttrProto(value->cast<FunctorPtr>(), attr_proto);
1456   } else {
1457     MS_LOG(ERROR) << "Unsupported type: " << value->type_name();
1458     return false;
1459   }
1460   return true;
1461 }
1462 
SetScalarToAttributeProto_ir(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1463 bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value,
1464                                                    mind_ir::AttributeProto *const attr_proto) const {
1465   if (value == nullptr || attr_proto == nullptr) {
1466     MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
1467   }
1468   if (value->isa<StringImm>()) {
1469     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
1470     attr_proto->set_s(GetValue<std::string>(value));
1471   } else if (value->isa<BoolImm>()) {
1472     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
1473     int64_t attr_value = GetValue<bool>(value) ? 1 : 0;
1474     attr_proto->set_i(attr_value);
1475   } else if (SetScalarToAttributeProtoForInt_ir(value, attr_proto)) {
1476     return true;
1477   } else if (value->isa<FP32Imm>()) {
1478     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
1479     attr_proto->set_f(GetValue<float>(value));
1480   } else if (value->isa<FP64Imm>()) {
1481     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
1482     attr_proto->set_d(GetValue<double>(value));
1483   } else {
1484     MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
1485     return false;
1486   }
1487   return true;
1488 }
1489 
SetScalarToAttributeProtoForInt_ir(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1490 bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value,
1491                                                          mind_ir::AttributeProto *const attr_proto) const {
1492   if (value->isa<Int8Imm>()) {
1493     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
1494     attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
1495   } else if (value->isa<Int16Imm>()) {
1496     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
1497     attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
1498   } else if (value->isa<Int32Imm>()) {
1499     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
1500     attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
1501   } else if (value->isa<Int64Imm>()) {
1502     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
1503     attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
1504   } else if (value->isa<UInt8Imm>()) {
1505     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
1506     attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
1507   } else if (value->isa<UInt16Imm>()) {
1508     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
1509     attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
1510   } else if (value->isa<UInt32Imm>()) {
1511     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
1512     attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
1513   } else if (value->isa<UInt64Imm>()) {
1514     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
1515     attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value()));
1516   } else {
1517     return false;
1518   }
1519   return true;
1520 }
1521 
SetTypeToAttributeProto_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1522 bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1523   if (attr_proto == nullptr) {
1524     MS_LOG(EXCEPTION) << "AttributeProto is null!";
1525   }
1526   if (value->isa<Int>()) {
1527     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1528     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1529     auto int_value = value->cast<IntPtr>();
1530     auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1531     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1532       return false;
1533     }
1534     tensor_proto->set_data_type(data_type);
1535   } else if (value->isa<Float>()) {
1536     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1537     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1538     auto float_value = value->cast<FloatPtr>();
1539     auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1540     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1541       return false;
1542     }
1543     tensor_proto->set_data_type(data_type);
1544   } else if (value->isa<UInt>()) {
1545     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1546     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1547     auto uint_value = value->cast<UIntPtr>();
1548     auto data_type = GetMindirDataBitsUIntType(uint_value->nbits());
1549     if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1550       return false;
1551     }
1552     tensor_proto->set_data_type(data_type);
1553   } else if (value->isa<Bool>()) {
1554     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1555     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1556     tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
1557   } else if (value->isa<tensor::Tensor>()) {
1558     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1559     return SetTensorToAttributeProto(value, attr_proto);
1560   } else if (value->isa<QuantizationParam>()) {
1561     auto quantization_param = value->cast<std::shared_ptr<QuantizationParam>>();
1562     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1563     auto tensor_proto = attr_proto->add_tensors();
1564     tensor_proto->set_name("quant_param");
1565     auto quant_param_proto = tensor_proto->add_quant_params();
1566     auto ret = SetQuantizationParamToAttrProto(quantization_param, quant_param_proto);
1567     if (ret != true) {
1568       MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
1569       return false;
1570     }
1571   } else {
1572     MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
1573   }
1574   return true;
1575 }
1576 
SetScalarToAttributeProto_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1577 bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value,
1578                                                     mind_ir::AttributeProto *const attr_proto) const {
1579   if (attr_proto == nullptr) {
1580     MS_LOG(EXCEPTION) << "AttributeProto is null!";
1581   }
1582   if (value->isa<StringImm>()) {
1583     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
1584     attr_proto->add_strings(GetValue<std::string>(value));
1585   } else if (value->isa<BoolImm>()) {
1586     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
1587     attr_proto->add_ints(GetValue<bool>(value));
1588   } else if (SetScalarToAttributeProtoForInt_irs(value, attr_proto)) {
1589     return true;
1590   } else if (value->isa<FP32Imm>()) {
1591     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
1592     attr_proto->add_floats(GetValue<float>(value));
1593   } else if (value->isa<FP64Imm>()) {
1594     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
1595     attr_proto->add_doubles(GetValue<double>(value));
1596   } else {
1597     MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
1598     return false;
1599   }
1600   return true;
1601 }
1602 
SetScalarToAttributeProtoForInt_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1603 bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
1604                                                           mind_ir::AttributeProto *const attr_proto) const {
1605   if (value->isa<Int8Imm>()) {
1606     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
1607     attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
1608   } else if (value->isa<Int16Imm>()) {
1609     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
1610     attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
1611   } else if (value->isa<Int32Imm>()) {
1612     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
1613     attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
1614   } else if (value->isa<Int64Imm>()) {
1615     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
1616     attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
1617   } else if (value->isa<UInt8Imm>()) {
1618     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
1619     attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
1620   } else if (value->isa<UInt16Imm>()) {
1621     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
1622     attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
1623   } else if (value->isa<UInt32Imm>()) {
1624     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
1625     attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
1626   } else if (value->isa<UInt64Imm>()) {
1627     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
1628     attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value()));
1629   } else {
1630     return false;
1631   }
1632   return true;
1633 }
1634 
SetSeqElemToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1635 bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1636   if (value == nullptr) {
1637     MS_LOG(ERROR) << "Value is nullptr";
1638     return false;
1639   }
1640   if (value->isa<StringImm>() || value->isa<Scalar>()) {
1641     return SetScalarToAttributeProto_irs(value, attr_proto);
1642   }
1643   return SetTypeToAttributeProto_irs(value, attr_proto);
1644 }
1645 
SetSequenceToAttributeProto(const ValueSequencePtr & value,mind_ir::AttributeProto * const attr_proto)1646 bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequencePtr &value,
1647                                                   mind_ir::AttributeProto *const attr_proto) {
1648   if (value == nullptr || attr_proto == nullptr) {
1649     MS_LOG(EXCEPTION) << "ValueSequencePtr or AttributeProto is null!";
1650   }
1651   if (value->isa<ValueTuple>()) {
1652     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
1653   } else if (value->isa<ValueList>()) {
1654     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_LIST);
1655   } else {
1656     MS_LOG(EXCEPTION) << "The sequance value should be ValueTuple or ValueList, but it is " << value->ToString();
1657   }
1658   auto value_sequence = value->cast<ValueSequencePtr>();
1659   MS_EXCEPTION_IF_NULL(value_sequence);
1660   const auto &values = value_sequence->value();
1661   if (values.empty()) {
1662     MS_LOG(DEBUG) << "SetSequenceToAttributeProto sequence size is 0";
1663     return true;
1664   }
1665   for (const auto &item : values) {
1666     mind_ir::AttributeProto *attr_values = attr_proto->add_values();
1667     MS_EXCEPTION_IF_NULL(item);
1668     if (item->isa<ValueSequence>()) {
1669       if (!SetSequenceToAttributeProto(item->cast<ValueSequencePtr>(), attr_values)) {
1670         MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
1671         return false;
1672       }
1673     } else {
1674       if (!SetSeqElemToAttributeProto(item, attr_values)) {
1675         MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
1676         return false;
1677       }
1678     }
1679   }
1680   return true;
1681 }
1682 
SetDictToAttributeProto(const ValueDictionaryPtr & value_dict,mind_ir::AttributeProto * const attr_proto)1683 bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_dict,
1684                                               mind_ir::AttributeProto *const attr_proto) {
1685   if (value_dict == nullptr || attr_proto == nullptr) {
1686     MS_LOG(EXCEPTION) << "ValueDictionaryPtr or AttributeProto is null!";
1687   }
1688   attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DICT);
1689   const auto &values = value_dict->value();
1690   if (values.empty()) {
1691     MS_LOG(DEBUG) << "SetDictToAttributeProto dictionary size is 0";
1692     return true;
1693   }
1694   for (const auto &item : values) {
1695     mind_ir::AttributeProto *dict_item_proto = attr_proto->add_values();
1696     const auto &key = item.first;
1697     dict_item_proto->set_name(GetValue<std::string>(key));
1698     const auto &value = item.second;
1699     MS_EXCEPTION_IF_NULL(value);
1700     mind_ir::AttributeProto *dict_item_value = dict_item_proto->add_values();
1701     if (value->isa<ValueSequence>()) {
1702       if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), dict_item_value)) {
1703         MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
1704         return false;
1705       }
1706     } else if (value->isa<ValueDictionary>()) {
1707       if (!SetDictToAttributeProto(value->cast<ValueDictionaryPtr>(), dict_item_value)) {
1708         MS_LOG(ERROR) << "Set dictionary to AttributeProto failed.";
1709         return false;
1710       }
1711     } else if (value->isa<StringImm>() || value->isa<Scalar>()) {
1712       if (!SetScalarToAttributeProto_irs(value, dict_item_value)) {
1713         MS_LOG(ERROR) << "Set StringImm or Scalar to AttributeProto failed.";
1714         return false;
1715       }
1716     } else if (value->isa<Number>() || value->isa<tensor::Tensor>()) {
1717       if (!SetTypeToAttributeProto_irs(value, dict_item_value)) {
1718         MS_LOG(ERROR) << "Set Number or Tensor to AttributeProto failed.";
1719         return false;
1720       }
1721     } else {
1722       MS_LOG(EXCEPTION) << "Unsupported type while converting ValueDictionary to AttributeProto: "
1723                         << value->type_name();
1724     }
1725   }
1726   return true;
1727 }
1728 
BuildCNodeAttr(const CNodePtr & node,mind_ir::NodeProto * const node_proto)1729 bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
1730   for (const auto &attr : node->attrs()) {
1731     if (attr.second == nullptr) {
1732       MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
1733       continue;
1734     }
1735     mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr();
1736     attr_proto->set_name(attr.first);
1737     if (!SetValueToAttributeProto(attr.second, attr_proto)) {
1738       MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
1739       MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
1740       return false;
1741     }
1742   }
1743 
1744   for (const auto &attr : node->primal_attrs()) {
1745     if (attr.second == nullptr) {
1746       MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
1747       continue;
1748     }
1749     mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr();
1750     attr_proto->set_name(attr.first);
1751     if (!SetValueToAttributeProto(attr.second, attr_proto)) {
1752       MS_LOG(ERROR) << "Set value to node primal attr to node proto failed.";
1753       MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
1754       return false;
1755     }
1756   }
1757   return true;
1758 }
1759 
GetBinaryProtoString(const FuncGraphPtr & func_graph,const bool & incremental)1760 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) {
1761   auto builder = std::make_shared<IrExportBuilder>(incremental);
1762   if (builder == nullptr) {
1763     MS_LOG(ERROR) << "Create ir exporter failed!";
1764     return "";
1765   }
1766   auto exporter = std::make_shared<IrExporter>(builder);
1767   if (exporter == nullptr) {
1768     return "";
1769   }
1770   auto ret = exporter->GetDumpString(func_graph);
1771   return ret;
1772 }
1773 
GenBinaryProto(const FuncGraphPtr & func_graph)1774 ModelProtoPtr GenBinaryProto(const FuncGraphPtr &func_graph) {
1775   auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1776   return exporter->GetDumpProto(func_graph);
1777 }
1778 
DumpBinaryProto(const FuncGraphPtr & func_graph,const std::string & file_path)1779 bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path) {
1780   auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1781   auto proto = exporter->GetDumpProto(func_graph);
1782   MindIRExporter mindir_exporter;
1783   if (proto == nullptr) {
1784     MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
1785     return false;
1786   }
1787   return mindir_exporter.SaveProtoToFile(proto.get(), file_path);
1788 }
1789 
DumpBinaryProto(const FuncGraphPtr & root_graph,const std::vector<FuncGraphPtr> & child_graphs,const std::vector<AnfNodePtr> & isolated_nodes,const std::string & file_path)1790 bool DumpBinaryProto(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
1791                      const std::vector<AnfNodePtr> &isolated_nodes, const std::string &file_path) {
1792   auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1793   auto proto = exporter->GetDumpProto(root_graph, child_graphs, isolated_nodes);
1794   if (proto == nullptr) {
1795     MS_LOG(ERROR) << "Get binary proto for graph " << root_graph->ToString() << " failed.";
1796     return false;
1797   }
1798   auto realpath = Common::CreatePrefixPath(file_path, true);
1799   if (!realpath.has_value()) {
1800     MS_LOG(ERROR) << "Get real path of file " << file_path << " failed.";
1801     return false;
1802   }
1803   ChangeFileMode(realpath.value(), S_IWUSR);
1804   std::ofstream fout(realpath.value());
1805   if (!fout.is_open()) {
1806     MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
1807     return false;
1808   }
1809   if (!proto->SerializeToOstream(&fout)) {
1810     MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
1811     fout.close();
1812     return false;
1813   }
1814   fout.close();
1815   ChangeFileMode(realpath.value(), S_IRUSR);
1816   return true;
1817 }
1818 
ParserPath(const std::string & output_path)1819 bool MindIRExporter::ParserPath(const std::string &output_path) {
1820   if (!FileUtils::ParserPathAndModelName(output_path, &save_path_, &model_name_)) {
1821     MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
1822     return false;
1823   }
1824 #ifdef _WIN32
1825   save_model_path_ = save_path_ + "\\" + model_name_ + ".mindir";
1826 #else
1827   save_model_path_ = save_path_ + "/" + model_name_ + ".mindir";
1828 #endif
1829   return true;
1830 }
1831 
ExportProto(const FuncGraphPtr & func_graph,const std::string & file_path,const FuncGraphPtr & param_layout_fg)1832 bool MindIRExporter::ExportProto(const FuncGraphPtr &func_graph, const std::string &file_path,
1833                                  const FuncGraphPtr &param_layout_fg) {
1834   if (func_graph == nullptr) {
1835     MS_LOG(ERROR) << "func_graph is nullptr.";
1836     return false;
1837   }
1838 
1839   if (!ParserPath(file_path)) {
1840     MS_LOG(ERROR) << "parse path failed.";
1841     return false;
1842   }
1843 
1844   // Serialize to protobuf using unique parameter name label.
1845   // Do preprocess on func_graph and check conditions for saving together.
1846   bool ret = PreProcSaveTogether(func_graph);
1847   if (!ret) {
1848     MS_LOG(ERROR) << "PreProcSaveTogether failed";
1849     return ret;
1850   }
1851 #ifdef ENABLE_DUMP_IR
1852   auto context = MsContext::GetInstance();
1853   MS_EXCEPTION_IF_NULL(context);
1854   if (context->CanDump(kIntroductory)) {
1855     DumpIR("PreProcSaveTogether.ir", func_graph);
1856   }
1857 #endif
1858 
1859   if (save_together_) {
1860     MS_LOG(INFO) << "SaveMindIRTogether";
1861     ret = SaveMindIRTogether();
1862   } else {
1863     MS_LOG(INFO) << "SplitSave";
1864     ret = SplitSave();
1865   }
1866   if (!ret) {
1867     MS_LOG(ERROR) << "save mindir weight failed.";
1868     return ret;
1869   }
1870   return true;
1871 }
1872 
SaveMindIRTogether()1873 bool MindIRExporter::SaveMindIRTogether() {
1874   for (auto &param_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
1875     std::string proto_name = param_proto.name();
1876     auto para = GetFgParaAccordingToProtoName(proto_name);
1877     if (para == nullptr) {
1878       return false;
1879     }
1880     if (!para->has_default()) {
1881       continue;
1882     }
1883     auto data = para->default_param()->cast<tensor::TensorPtr>();
1884     param_proto.clear_raw_data();
1885     param_proto.set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
1886   }
1887   return SaveProtoToFile(&model_proto_, save_model_path_);
1888 }
1889 
CreateParameterDir()1890 bool MindIRExporter::CreateParameterDir() {
1891 #ifdef _WIN32
1892   dir_name_ = save_path_ + "\\" + model_name_ + "_variables";
1893 #else
1894   dir_name_ = save_path_ + "/" + model_name_ + "_variables";
1895 #endif
1896   fs_ = system::Env::GetFileSystem();
1897   if (fs_ == nullptr) {
1898     MS_LOG(ERROR) << "create file system failed.";
1899     return false;
1900   }
1901 
1902   if (fs_->FileExist(dir_name_)) {
1903     if (!DeleteDirRecursively(dir_name_)) {
1904       return false;
1905     }
1906   }
1907 
1908   if (!fs_->CreateDir(dir_name_)) {
1909     MS_LOG(ERROR) << "create dir failed.";
1910     return false;
1911   }
1912 
1913   ChangeFileMode(dir_name_, S_IWUSR | S_IRUSR | S_IXUSR);
1914   return true;
1915 }
1916 
CreateExternalPath(const std::string & external_file)1917 std::string MindIRExporter::CreateExternalPath(const std::string &external_file) {
1918   dir_path_ = FileUtils::GetRealPath(dir_name_.c_str()).value();
1919   std::string external_local_path{};
1920 #ifdef _WIN32
1921   external_local_path = dir_path_ + "\\" + external_file;
1922 #else
1923   external_local_path = dir_path_ + "/" + external_file;
1924 #endif
1925   return external_local_path;
1926 }
1927 
SplitSave()1928 bool MindIRExporter::SplitSave() {
1929   MS_LOG(DEBUG) << "Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.";
1930   if (!CreateParameterDir()) {
1931     MS_LOG(ERROR) << "create parameter dir failed.";
1932     return false;
1933   }
1934 
1935   int index = 0;
1936   std::string external_local = "data_" + std::to_string(index);
1937   auto external_local_path = CreateExternalPath(external_local);
1938   if (fs_->FileExist(external_local_path)) {
1939     if (!fs_->DeleteFile(external_local_path)) {
1940       MS_LOG(ERROR) << "delete file failed.";
1941       return false;
1942     }
1943   }
1944   int64_t parameter_size = 0;
1945   int64_t offset = OFFSET;
1946 
1947   data_fs_ = FileUtils::OpenFile(external_local_path, std::ios::out | std::ios::binary | std::ios::trunc);
1948   if (data_fs_ == nullptr) {
1949     MS_LOG(ERROR) << "Open " << external_local_path << " failed";
1950     return false;
1951   }
1952   if (!ChangeParaDataFile(external_local)) {
1953     MS_LOG(ERROR) << "change parameter data file failed.";
1954     return false;
1955   }
1956 
1957   for (auto &param_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
1958     std::string proto_name = param_proto.name();
1959     auto para = GetFgParaAccordingToProtoName(proto_name);
1960     if (para == nullptr) {
1961       return false;
1962     }
1963     if (!para->has_default()) {
1964       continue;
1965     }
1966     auto data = para->default_param()->cast<tensor::TensorPtr>();
1967     int64_t data_length = static_cast<int64_t>(data->data().nbytes());
1968     int64_t append_size = 0;
1969     if (data_length % OFFSET != 0) {
1970       append_size = OFFSET - (data_length % OFFSET);
1971     }
1972     parameter_size += ((append_size + data_length) / PARA_ROUND);
1973     if (parameter_size > static_cast<int64_t>(TOTAL_SAVE)) {
1974       index++;
1975       external_local = "data_" + std::to_string(index);
1976       data_fs_->close();
1977       delete data_fs_;
1978       data_fs_ = nullptr;
1979 
1980       if (!ChangeParaDataFile(external_local)) {
1981         MS_LOG(ERROR) << "change parameter data file failed.";
1982         return false;
1983       }
1984       parameter_size = OFFSET / PARA_ROUND;
1985     }
1986     std::string external_local_data = model_name_ + "_variables/" + external_local;
1987     param_proto.mutable_external_data()->set_location(external_local_data);
1988     param_proto.mutable_external_data()->set_length(data_length);
1989     param_proto.mutable_external_data()->set_offset(offset);
1990 
1991     data_fs_->write(static_cast<const char *>(data->data_c()), data_length);
1992     auto append_data = new char[append_size];
1993     if (append_data == nullptr) {
1994       return false;
1995     }
1996     data_fs_->write(append_data, append_size);
1997     offset += (data_length + append_size);
1998     delete[] append_data;
1999   }
2000   std::string split_model_file_name = "";
2001 #ifdef _WIN32
2002   split_model_file_name = save_path_ + "\\" + model_name_ + "_graph.mindir";
2003 #else
2004   split_model_file_name = save_path_ + "/" + model_name_ + "_graph.mindir";
2005 #endif
2006   return SaveProtoToFile(&model_proto_, split_model_file_name);
2007 }
2008 
SaveProtoToFile(mind_ir::ModelProto * model_proto,const std::string & output_file)2009 bool MindIRExporter::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
2010   auto realpath = Common::CreatePrefixPath(output_file, true);
2011   if (!realpath.has_value()) {
2012     MS_LOG(ERROR) << "Get real path of file " << output_file << " failed.";
2013     return false;
2014   }
2015 
2016   ChangeFileMode(realpath.value(), S_IWUSR);
2017   std::ofstream fout(realpath.value());
2018   if (!fout.is_open()) {
2019     MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
2020     return false;
2021   }
2022 
2023   if (!model_proto->SerializeToOstream(&fout)) {
2024     MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
2025     fout.close();
2026     return false;
2027   }
2028 
2029   fout.close();
2030   ChangeFileMode(realpath.value(), S_IRUSR);
2031   return true;
2032 }
2033 
ChangeParaDataFile(const std::string & file)2034 bool MindIRExporter::ChangeParaDataFile(const std::string &file) {
2035   auto real_path = CreateExternalPath(file);
2036   if (fs_->FileExist(real_path)) {
2037     if (!fs_->DeleteFile(real_path)) {
2038       MS_LOG(ERROR) << "delete file failed.";
2039       return false;
2040     }
2041   }
2042   ChangeFileMode(real_path, S_IWUSR);
2043   data_fs_ = FileUtils::OpenFile(real_path, std::ios::app);
2044   if (data_fs_ == nullptr) {
2045     MS_LOG(ERROR) << "data_fs_ is nullptr.";
2046     return false;
2047   }
2048   char front_info[OFFSET]{0};
2049   front_info[0] = IsSystemLittleEndidan();
2050   (void)data_fs_->write(front_info, OFFSET);
2051   return true;
2052 }
2053 
IsSystemLittleEndidan() const2054 bool MindIRExporter::IsSystemLittleEndidan() const {
2055   int check = 0x01;
2056   auto address = reinterpret_cast<char *>(&check);
2057   return *address == 0x01;
2058 }
2059 
PreProcSaveTogether(const FuncGraphPtr & func_graph)2060 bool MindIRExporter::PreProcSaveTogether(const FuncGraphPtr &func_graph) {
2061   if (func_graph == nullptr) {
2062     MS_LOG(ERROR) << "func_graph is nullptr.";
2063     return false;
2064   }
2065 
2066   if (!UpdateParamCount(func_graph)) {
2067     MS_LOG(ERROR) << "Update parameter count failed.";
2068     return false;
2069   }
2070 
2071   // Parse func_graph as model proto
2072   std::string proto_string = GetBinaryProtoString(func_graph);
2073   if (proto_string.empty()) {
2074     MS_LOG(ERROR) << "parse proto string failed.";
2075     return false;
2076   }
2077 
2078   if (!model_proto_.ParseFromString(proto_string)) {
2079     MS_LOG(ERROR) << "parse model proto from string failed.";
2080     return false;
2081   }
2082 
2083   if (!ParamDict(func_graph)) {
2084     MS_LOG(ERROR) << "parse param form funcgraph failed.";
2085     return false;
2086   }
2087 
2088   if (!IfSaveTogether(&save_together_)) {
2089     MS_LOG(ERROR) << "error occur when check condition of saving together.";
2090     return false;
2091   }
2092 
2093   return true;
2094 }
2095 
IfSaveTogether(bool * save_together)2096 bool MindIRExporter::IfSaveTogether(bool *save_together) {
2097   size_t data_total = model_proto_.ByteSizeLong();
2098   for (auto &param_proto : model_proto_.graph().parameter()) {
2099     std::string proto_name = param_proto.name();
2100     auto para = GetFgParaAccordingToProtoName(proto_name);
2101     if (para == nullptr) {
2102       return false;
2103     }
2104     if (!para->has_default()) {
2105       continue;
2106     }
2107     auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(para->default_param());
2108     if (tensor == nullptr) {
2109       MS_LOG(ERROR) << "param node default_param is not tensor.";
2110       return false;
2111     }
2112     data_total += tensor->Size();
2113   }
2114   if (data_total > TOTAL_SAVE) {
2115     *save_together = false;
2116   } else {
2117     *save_together = false;
2118   }
2119   return true;
2120 }
2121 
GetFgParaAccordingToProtoName(const std::string & proto_name)2122 std::shared_ptr<Parameter> MindIRExporter::GetFgParaAccordingToProtoName(const std::string &proto_name) {
2123   auto beg_pos = proto_name.find_first_of(':') + 1;
2124   if (beg_pos >= proto_name.size()) {
2125     MS_LOG(ERROR) << "begin pos exceed proto name length.";
2126     return nullptr;
2127   }
2128   auto name = proto_name.substr(beg_pos);
2129   if (param_dict_.find(name) == param_dict_.end()) {
2130     MS_LOG(ERROR) << "param proto name: " << name << " is not in param dict.";
2131     return nullptr;
2132   }
2133   return param_dict_.at(name);
2134 }
2135 
UpdateParamCount(const FuncGraphPtr & func_graph)2136 bool MindIRExporter::UpdateParamCount(const FuncGraphPtr &func_graph) {
2137   auto fv_count = 0;
2138   std::vector<AnfNodePtr> params;
2139   std::vector<AnfNodePtr> reorder_param;
2140   reorder_param.reserve(func_graph->parameters().size());
2141   for (const auto &node : func_graph->parameters()) {
2142     auto param_node = node->cast<ParameterPtr>();
2143     if (param_node == nullptr) {
2144       MS_LOG(ERROR) << "The parameters() in func graph should be all Parameter Node. but got " << node->DebugString();
2145       return false;
2146     }
2147     if (param_node->has_default()) {
2148       (void)params.emplace_back(param_node);
2149       ++fv_count;
2150       continue;
2151     }
2152     (void)reorder_param.emplace_back(param_node);
2153   }
2154 
2155   std::copy(params.begin(), params.end(), std::back_inserter(reorder_param));
2156   func_graph->set_parameters(reorder_param);
2157   func_graph->set_fv_param_count(fv_count);
2158   return true;
2159 }
2160 
ParamDict(const FuncGraphPtr & func_graph)2161 bool MindIRExporter::ParamDict(const FuncGraphPtr &func_graph) {
2162   std::set<FuncGraphPtr> all_func_graphs = {};
2163   GetAllFuncGraphs(func_graph, &all_func_graphs);
2164   for (auto &fg : all_func_graphs) {
2165     for (auto &para : fg->parameters()) {
2166       if (!para->isa<Parameter>()) {
2167         MS_LOG(ERROR) << "fg parameters contains non-parameter type node.";
2168         return false;
2169       }
2170       auto para_node = para->cast<ParameterPtr>();
2171       param_dict_[para->ToString()] = para_node;
2172     }
2173   }
2174   return true;
2175 }
2176 }  // namespace mindspore
2177