/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <algorithm>
#include <functional>
#include <list>
#include <set>
#include <string>
#include <vector>

#include "ir/tensor.h"
#include "ir/param_info.h"
#include "ir/func_graph.h"
#include "base/core_ops.h"
#include "proto/mind_ir.pb.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
using IntPtr = std::shared_ptr<Int>;
using UIntPtr = std::shared_ptr<UInt>;
// ANF type to MindIR type map
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
  {kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
  {kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8},
  {kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16},
  {kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32},
  {kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64},
  {kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8},
  {kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16},
  {kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32},
  {kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64},
  {kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16},
  {kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT},
  {kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE},
  {kObjectTypeString, mind_ir::TensorProto_DataType_STRING},
};

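// Bit-width to MindIR data type maps, used when only a Number type
// (Int/UInt/Float) rather than a concrete TypeId is available.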
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_map = {
  {8, mind_ir::TensorProto_DataType_INT8},
  {16, mind_ir::TensorProto_DataType_INT16},
  {32, mind_ir::TensorProto_DataType_INT32},
  {64, mind_ir::TensorProto_DataType_INT64},
};

static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = {
  {8, mind_ir::TensorProto_DataType_UINT8},
  {16, mind_ir::TensorProto_DataType_UINT16},
  {32, mind_ir::TensorProto_DataType_UINT32},
  {64, mind_ir::TensorProto_DataType_UINT64},
};

static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
  {16, mind_ir::TensorProto_DataType_FLOAT16},
  {32, mind_ir::TensorProto_DataType_FLOAT},
  {64, mind_ir::TensorProto_DataType_FLOAT64},
};

// Different builders can be created here according to the target format.
class IrExportBuilder;
using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>;

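// IrExporter drives the export: it asks the builder to emit the model info and
// the graph, then returns the resulting ModelProto or its serialized string.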
class IrExporter {
 public:
  explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {}
  virtual ~IrExporter() = default;
  std::string GetDumpString(const FuncGraphPtr &func_graph);
  mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);

 private:
  IrExportBuilderPtr builder_;
};

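// IrExportBuilder walks a FuncGraph and fills the mind_ir protobuf messages
// (GraphProto, NodeProto, TensorProto, AttributeProto) node by node.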
class IrExportBuilder {
 public:
  IrExportBuilder() = default;
  ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
  std::string GetProtoString() const;
  void BuildModelInfo();
  void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
  mind_ir::ModelProto Model() { return model_; }

 private:
  void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                      bool save_tensor_data = false);
  void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                       bool save_tensor_data = false);
  void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
  void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
  void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
  std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);

  void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
  void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
  void SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
  void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
  void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
  void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
  void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
                           std::string *const seq_string);
  void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
  void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
  void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
  void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
  void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
  void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                   std::string *const seq_string);
  void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                  std::string *const seq_string);

  mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
  mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
  mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
  mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits);
  std::string GetNodeName(const AnfNodePtr &node);
  std::string GetUniqueNodeName(const AnfNodePtr &node);
  std::string GetOpTypeName(const AnfNodePtr &node);
  size_t GetNodeIndex() { return ++node_index_; }
  void ResetNodeIndex() { node_index_ = 0; }
  size_t GetTupleIndex() { return ++shape_index_; }
  void ResetTupleIndex() { shape_index_ = 0; }

  mind_ir::ModelProto model_;
  mind_ir::NodeProto *last_node_{nullptr};
  std::list<FuncGraphPtr> todo_;
  std::map<AnfNodePtr, std::string> node_index_map_;
  std::set<std::string> nodeName_;
  size_t node_index_{0};
  size_t shape_index_{0};
  bool top_graph{true};
};

using IrExporterPtr = std::shared_ptr<IrExporter>;

std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
  (void)GetDumpProto(func_graph);
  return builder_->GetProtoString();
}

mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
  if ((builder_ == nullptr) || (func_graph == nullptr)) {
    MS_LOG(EXCEPTION) << "Input params is null.";
  }

  // Export model info
  builder_->BuildModelInfo();

  // Export the model and return the proto
  builder_->BuildModel(func_graph, save_tensor_data);
  return builder_->Model();
}

std::string IrExportBuilder::GetProtoString() const {
  MS_LOG(DEBUG) << "BuildModel complete!";
  return model_.SerializeAsString();
}

void IrExportBuilder::BuildModelInfo() {
  constexpr auto ir_version = "0.1.0";
  constexpr auto mindspore_name = "MindSpore";
  model_.set_ir_version(ir_version);
  model_.set_producer_name(mindspore_name);
  model_.set_model_version(VERSION);
}

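// Builds the top-level graph first, then drains todo_ to export every
// FuncGraph referenced from it, skipping graphs that were already visited.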
void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
  MS_EXCEPTION_IF_NULL(func_graph);
  mind_ir::GraphProto *graph_proto = model_.mutable_graph();
  graph_proto->set_name(func_graph->ToString());
  graph_proto->set_bprop_hash(func_graph->bprop_hash());
  ResetNodeIndex();
  todo_.clear();
  nodeName_.clear();
  // Build the main funcGraph
  (void)nodeName_.insert(func_graph->ToString());
  top_graph = true;
  BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
  std::set<FuncGraphPtr> graphVisited;
  (void)graphVisited.insert(func_graph);
  top_graph = false;
  while (!todo_.empty()) {
    FuncGraphPtr fg = todo_.back();
    todo_.pop_back();
    if (graphVisited.count(fg) > 0) {
      continue;
    }
    if (nodeName_.count(fg->ToString()) > 0) {
      MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
    }
    (void)nodeName_.insert(fg->ToString());
    (void)graphVisited.insert(fg);
    auto graph = model_.add_functions();
    BuildFuncGraph(fg, graph, save_tensor_data);
  }
  // Release resources
  nodeName_.clear();
  node_index_map_.clear();
}

void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                                     bool save_tensor_data) {
  // Export funcGraph name.
  graph_proto->set_name(func_graph->ToString());
  // Export parameters
  // 1. parameters should be mapped to ValueInfoProto
  // 2. parameters with default value should be mapped to Initializer
  BuildParameters(func_graph, graph_proto, save_tensor_data);

  // Export operator nodes (including the output)
  BuildNodes(func_graph, graph_proto);
}

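// Top-graph parameters with a default value become graph parameters (optionally
// carrying the raw tensor data); all other parameters become graph inputs.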
void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
                                      bool save_tensor_data) {
  MS_EXCEPTION_IF_NULL(func_graph);
  for (auto &item : func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(item);
    auto param = item->cast<ParameterPtr>();
    if (param == nullptr) {
      MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
    }
    std::string param_name = GetUniqueNodeName(param);
    if (top_graph && param->has_default()) {
      MS_LOG(DEBUG) << "Parameter: '" << item->DebugString();
      mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
      parameter_proto->set_name(param_name);
      SetParamToTensorProto(param, parameter_proto);
      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
      if (tensor && save_tensor_data) {
        parameter_proto->set_raw_data(tensor->data_c(), static_cast<size_t>(tensor->data().nbytes()));
      }
    } else {
      mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
      input_proto->set_name(param_name);
      SetValueInfoProto(param, input_proto);
    }
    if (nodeName_.count(param_name) > 0) {
      MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
    }
    (void)nodeName_.insert(param_name);
  }
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
  auto iter = g_data_type_map.find(type_id);
  if (iter == g_data_type_map.end()) {
    MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
  }
  return iter->second;
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
  auto iter = g_data_bits_int_map.find(bits);
  if (iter == g_data_bits_int_map.end()) {
    MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
  }
  return iter->second;
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
  auto iter = g_data_bits_uint_map.find(bits);
  if (iter == g_data_bits_uint_map.end()) {
    MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
  }
  return iter->second;
}

mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
  auto iter = g_data_bits_float_map.find(bits);
  if (iter == g_data_bits_float_map.end()) {
    MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
  }
  return iter->second;
}

void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
  if (node == nullptr || value_proto == nullptr) {
    MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
  }
  MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
  const TypePtr &type = node->Type();
  const BaseShapePtr &shape = node->Shape();
  if (type == nullptr || shape == nullptr) {
    return;
  }
  if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
    auto tensor = type->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    auto elem_type = tensor->element();
    const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
    mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
    tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
    if (dims.size() == 0) {
      MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
    } else {
      for (const auto &dim : dims) {
        MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
        tensor_proto->add_dims(dim);
      }
    }
  } else if (type->isa<Tuple>()) {
    auto tup_shape = shape->cast<abstract::TupleShapePtr>();
    value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
  } else {
    value_proto->set_denotation(type->type_name());
  }
  MS_LOG(DEBUG) << "Value type: " << type->type_name();
}

void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
  }
  attr_proto->set_ref_attr_name("tensor:value0");
  attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
  mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
  tensor_proto->set_name("value0");
  auto data = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(data);
  tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
  auto dtype = data->data_type();
  auto shape = data->shape_c();
  tensor_proto->set_data_type(GetMindirDataType(dtype));
  for (const auto &dim : shape) {
    tensor_proto->add_dims(dim);
  }
}

void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
                                     mind_ir::TensorProto *const tensor_proto) {
  if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
    MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
  }
  auto tensor = type->cast<TensorTypePtr>();
  const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
  tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
  for (const auto &dim : dims) {
    tensor_proto->add_dims(dim);
  }
}

void IrExportBuilder::SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto) {
  if (param == nullptr || tensor_proto == nullptr) {
    MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
  }
  MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
  SetTensorProto(param->Type(), param->Shape(), tensor_proto);
}

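// Nodes are visited in topological order; the return node is translated into
// the graph output, every other CNode into a NodeProto.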
void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
  for (const AnfNodePtr &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode == func_graph->get_return()) {
      BuildOutput(cnode, graph_proto);
    } else {
      BuildCNode(cnode, graph_proto);
    }
  }
}

void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
  MS_EXCEPTION_IF_NULL(node);
  const size_t OutputSize = 2;
  if (node->size() != OutputSize) {
    MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
  }
  AnfNodePtr arg = node->input(1);
  std::string node_name = BuildInputNode(arg, graph_proto);
  mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
  output_proto->set_name(node_name);
  SetValueInfoProto(arg, output_proto);
}

std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
  // May be ValueNode/CNode/Parameter
  std::string type_name = "";
  if (IsValueNode<Primitive>(node)) {
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
    MS_EXCEPTION_IF_NULL(prim);
    type_name = prim->ToString();
  } else if (IsValueNode<FuncGraph>(node)) {
    FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
    MS_EXCEPTION_IF_NULL(fg);
    todo_.push_back(fg);
    type_name = "REF::" + fg->ToString();
  } else if (node->isa<CNode>() || node->isa<Parameter>()) {
    auto nodeName = GetUniqueNodeName(node);
    type_name = "REF::" + nodeName;
    if (nodeName_.count(nodeName) == 0) {
      MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
    }
  } else {
    MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
  }
  MS_LOG(DEBUG) << "ExportType: " << type_name;
  return type_name;
}

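// Recursively encodes the type/shape of a node into the attribute proto and a
// layout string such as "Tuple[shape1,shape2,]" that is kept in ref_attr_name.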
void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
                                          mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
  MS_EXCEPTION_IF_NULL(type);
  MS_EXCEPTION_IF_NULL(shape);
  MS_EXCEPTION_IF_NULL(seq_string);
  if (type->isa<Tuple>()) {
    *seq_string += "Tuple[";
    auto elements = type->cast<TuplePtr>()->elements();
    auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
    for (size_t i = 0; i < elements.size(); i++) {
      SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
    }
    *seq_string += "],";
  } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
    string shape_name = "shape" + std::to_string(GetTupleIndex());
    *seq_string += shape_name + ",";
    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
    tensor_proto->set_name(shape_name);
    SetTensorProto(type, shape, tensor_proto);
  } else if (type->isa<Number>()) {
    if (type->isa<Bool>()) {
      attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
    } else {
      string shape_name = "shape" + std::to_string(GetTupleIndex());
      *seq_string += shape_name + ",";
      mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
      tensor_proto->set_name(shape_name);
      tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
      tensor_proto->add_dims(1);
    }
  } else if (type->isa<Function>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH);
    *seq_string += type->type_name() + ",";
  } else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
    *seq_string += type->type_name() + ",";
  } else {
    MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
  }
}

void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
  // Get the shape of the cnode:
  // 1. shapes of tuple elements are collected recursively
  // 2. each shape is saved in a TensorProto
  // 3. the tuple layout string is saved in ref_attr_name
  MS_EXCEPTION_IF_NULL(node);
  auto type = node->Type();
  auto shape = node->Shape();
  if (type == nullptr || shape == nullptr) {
    return;
  }
  ResetTupleIndex();
  std::string seq_string = "shape:";
  mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
  SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
  attr_proto->set_ref_attr_name(seq_string);
  MS_LOG(DEBUG) << "CNode shape: " << seq_string;
}

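// A CNode becomes a NodeProto: its inputs are built (and named) first, then the
// node's shape information and the primitive's attributes are attached.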
void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
  auto inputs_size = node->size();
  if (inputs_size < 1) {
    MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
  }

  // Input nodes must be built before the cnode itself.
  std::vector<AnfNodePtr> op_inputs;
  std::vector<string> input_names;
  for (size_t i = 1; i < inputs_size; i++) {
    auto input = node->input(i);
    op_inputs.push_back(input);
    input_names.push_back(BuildInputNode(input, graph_proto));
  }

  // Build cnode
  mind_ir::NodeProto *node_proto = graph_proto->add_node();
  std::string output_name = GetUniqueNodeName(node);
  if (nodeName_.count(output_name) > 0) {
    MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name;
  }
  (void)nodeName_.insert(output_name);
  node_proto->add_output(output_name);
  node_proto->set_name(output_name);
  node_proto->set_domain(node->fullname_with_scope());
  AnfNodePtr op = node->input(0);
  std::string type_name = GetOpTypeName(op);
  node_proto->set_op_type(type_name);
  last_node_ = node_proto;
  // The node type may be a Tensor, a Function, or nullptr.
  SetShapeToNodeProto(node, node_proto);

  (void)std::for_each(input_names.begin(), input_names.end(),
                      [&node_proto](const string &name) { node_proto->add_input(name); });

  // Add primitive attrs
  if (IsValueNode<Primitive>(op)) {
    auto prim = GetValueNode<PrimitivePtr>(op);
    for (const auto &attr : prim->attrs()) {
      MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
      mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
      attr_proto->set_name(attr.first);
      auto attr_value = attr.second;
      CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
      SetValueToAttributeProto(attr_value, attr_proto);
    }
  }
}

std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
  std::string node_name = GetUniqueNodeName(node);
  // FuncGraph will be added to functions and the input name is the function name.
  if (IsValueNode<FuncGraph>(node)) {
    FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
    todo_.push_back(fg);
    return fg->ToString();
  }
  if (node->isa<ValueNode>()) {
    // When the input is a ValueNode, a Constant node has to be created for it.
    mind_ir::NodeProto *node_proto = graph_proto->add_node();
    node_proto->set_name(node_name);
    node_proto->add_output(node_name);
    SetAttributeProto(node, node_proto);
  }
  return node_name;
}

std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
  // Naming the anfnode:
  // 1. a parameter name is unique within one func_graph
  // 2. cnodes and valuenodes may share a name, so an index is appended to distinguish them.
  auto iter = node_index_map_.find(node);
  if (iter != node_index_map_.end()) {
    return iter->second;
  } else {
    std::string node_name = GetNodeName(node);
    // Keep compatibility with earlier naming:
    // CNode = FuncGraphName:CNodeName:index, Parameter = FuncGraphName:ParameterName
    if (node->isa<CNode>()) {
      node_name = node_name + ":" + std::to_string(GetNodeIndex());
    }
    // Avoid duplicate names.
    while (nodeName_.count(node_name) > 0) {
      node_name = node_name + "_" + std::to_string(GetNodeIndex());
    }
    node_index_map_[node] = node_name;
    return node_name;
  }
}

std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::string node_name = "";
  if (node->func_graph() != nullptr) {
    node_name = node->func_graph()->ToString() + ":";
  }
  if (node->isa<ValueNode>()) {
    // The value itself is not part of the name.
    node_name += node->AnfNode::ToString();
  } else {
    node_name += node->ToString();
  }
  MS_LOG(DEBUG) << "GetNodeName: " << node_name;
  return node_name;
}

void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
  if (node == nullptr || node_proto == nullptr) {
    MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  node_proto->set_op_type("Constant");
  mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
  attr_proto->set_name("value");
  MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
  SetValueToAttributeProto(value, attr_proto);
}

void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
  }
  attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
  mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
  if (value->isa<Int>()) {
    attr_proto->set_ref_attr_name("type:value0");
    tensor_proto->set_name("value0");
    auto int_value = value->cast<IntPtr>();
    tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
  } else if (value->isa<UInt>()) {
    attr_proto->set_ref_attr_name("type:value0");
    tensor_proto->set_name("value0");
    auto uint_value = value->cast<UIntPtr>();
    tensor_proto->set_data_type(GetMindirDataBitsUIntType(uint_value->nbits()));
  } else if (value->isa<Float>()) {
    attr_proto->set_ref_attr_name("type:value0");
    tensor_proto->set_name("value0");
    auto float_value = value->cast<FloatPtr>();
    tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
  } else if (value->isa<Bool>()) {
    attr_proto->set_ref_attr_name("type:value0");
    tensor_proto->set_name("value0");
    tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
  } else if (value->isa<TensorType>()) {
    attr_proto->set_ref_attr_name("type:tensor0");
    tensor_proto->set_name("tensor0");
    auto elem_type = value->cast<TensorTypePtr>()->element();
    if (elem_type->isa<Int>()) {
      auto int_value = elem_type->cast<IntPtr>();
      tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
    } else if (elem_type->isa<Float>()) {
      auto float_value = elem_type->cast<FloatPtr>();
      tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
    } else {
      MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
    }
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
  }
}

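// Dispatches on the runtime type of the value: scalars, type values, sequences,
// tensors, None and monads each use their own ref_attr_name convention.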
void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
  }
  if (value->isa<StringImm>() || value->isa<Scalar>()) {
    SetScalarToAttributeProto_ir(value, attr_proto);
  } else if (value->isa<Number>() || value->isa<TensorType>()) {
    SetTypeToAttributeProto(value, attr_proto);
  } else if (value->isa<ValueSequeue>()) {
    ResetTupleIndex();
    std::string seq_string = "scalar:";
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
    SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
    attr_proto->set_ref_attr_name(seq_string);
    MS_LOG(DEBUG) << "Attr string: " << seq_string;
  } else if (value->isa<tensor::Tensor>()) {
    SetTensorToAttributeProto(value, attr_proto);
  } else if (value->isa<None>()) {
    attr_proto->set_ref_attr_name("none");
    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
  } else if (value->isa<Monad>()) {
    if (value->isa<UMonad>()) {
      attr_proto->set_ref_attr_name("Monad:UMonad");
    } else if (value->isa<IOMonad>()) {
      attr_proto->set_ref_attr_name("Monad:IOMonad");
    } else {
      MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name();
    }
  } else {
    MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
  }
}

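// The _ir variant stores a single scalar in the attribute's singular fields,
// while the _irs variant below appends to the repeated fields so it can also be
// used for sequence elements.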
void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
  }
  attr_proto->set_ref_attr_name("scalar:value0");
  if (value->isa<StringImm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
    attr_proto->set_s(GetValue<std::string>(value));
  } else if (value->isa<BoolImm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
    int64_t attr_value = GetValue<bool>(value) ? 1 : 0;
    attr_proto->set_i(attr_value);
  } else if (value->isa<Int8Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
    attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
  } else if (value->isa<Int16Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
    attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
  } else if (value->isa<Int32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
    attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
  } else if (value->isa<Int64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
    attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
  } else if (value->isa<UInt8Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
    attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
  } else if (value->isa<UInt16Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
    attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
  } else if (value->isa<UInt32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
    attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
  } else if (value->isa<UInt64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
    attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value()));
  } else if (value->isa<FP32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
    attr_proto->set_f(GetValue<float>(value));
  } else if (value->isa<FP64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
    attr_proto->set_d(GetValue<double>(value));
  } else if (value->isa<tensor::Tensor>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
    SetTensorToAttributeProto(value, attr_proto);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
  }
}

void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
  }
  if (value->isa<Int>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
    auto int_value = value->cast<IntPtr>();
    tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
  } else if (value->isa<Float>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
    auto float_value = value->cast<FloatPtr>();
    tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
  } else if (value->isa<StringImm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
    attr_proto->add_strings(GetValue<std::string>(value));
  } else if (value->isa<BoolImm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
    int attr_value = GetValue<bool>(value) ? 1 : 0;
    attr_proto->add_ints(attr_value);
  } else if (value->isa<Int8Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
    attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
  } else if (value->isa<Int16Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
    attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
  } else if (value->isa<Int32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
    attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
  } else if (value->isa<Int64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
    attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
  } else if (value->isa<UInt8Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
    attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
  } else if (value->isa<UInt16Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
    attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
  } else if (value->isa<UInt32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
    attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
  } else if (value->isa<UInt64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
    attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value()));
  } else if (value->isa<FP32Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
    attr_proto->add_floats(GetValue<float>(value));
  } else if (value->isa<FP64Imm>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
    attr_proto->add_doubles(GetValue<double>(value));
  } else if (value->isa<tensor::Tensor>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
    SetTensorToAttributeProto(value, attr_proto);
  } else {
    MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
  }
}

void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
                                                 std::string *const seq_string) {
  string value_name = "value" + std::to_string(GetTupleIndex());
  if (seq_string != nullptr) {
    *seq_string += value_name + ",";
  }
  SetScalarToAttributeProto_irs(value, attr_proto);
}

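// Encodes nested ValueTuple/ValueList values: the layout goes into seq_string
// (e.g. "Tuple[value1,value2,]") and each leaf element into the attribute proto.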
void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
                                                  mind_ir::AttributeProto *const attr_proto,
                                                  std::string *const seq_string) {
  if (value == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
  }
  if (value->isa<ValueTuple>() && seq_string != nullptr) {
    *seq_string += "Tuple[";
    const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
    if (tuple_value->value().size() == 0) {
      *seq_string += "],";
      MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
      return;
    }
    for (const auto &item : tuple_value->value()) {
      if (item->isa<ValueTuple>()) {
        SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
      } else {
        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
      }
    }
    *seq_string += "],";
  } else if (value->isa<ValueList>() && seq_string != nullptr) {
    *seq_string += "List[";
    const ValueListPtr &list_value = value->cast<ValueListPtr>();
    if (list_value->value().size() == 0) {
      *seq_string += "],";
      MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
      return;
    }
    for (const auto &item : list_value->value()) {
      MS_EXCEPTION_IF_NULL(item);
      if (item->isa<ValueList>()) {
        SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
      } else {
        SetSeqElemToAttributeProto(item, attr_proto, seq_string);
      }
    }
    *seq_string += "],";
  }
}

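// GetBinaryProtoString and GetBinaryProto are the free-function entry points
// exposed by this file. A minimal, hypothetical caller might look like:
//
//   FuncGraphPtr func_graph = ...;  // graph to export
//   std::string mindir = GetBinaryProtoString(func_graph);
//   if (!mindir.empty()) {
//     // write the serialized model, e.g. to a .mindir file
//   }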
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
  auto builder = std::make_shared<IrExportBuilder>();
  if (builder == nullptr) {
    MS_LOG(ERROR) << "Create ir exporter failed!";
    return "";
  }
  auto exporter = std::make_shared<IrExporter>(builder);
  if (exporter == nullptr) {
    return "";
  }
  return exporter->GetDumpString(func_graph);
}

mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
  auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
  auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
  return result;
}
}  // namespace mindspore