1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/tensor.h"
#include "ir/param_info.h"
#include "ir/func_graph.h"
#include "base/core_ops.h"
#include "proto/mind_ir.pb.h"
#include "utils/check_convert_utils.h"
30
31 namespace mindspore {
32 using FloatPtr = std::shared_ptr<Float>;
33 using IntPtr = std::shared_ptr<Int>;
34 using UIntPtr = std::shared_ptr<UInt>;
35 // anf type to mindir type map
36 static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
37 {kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
38 {kNumberTypeInt8, mind_ir::TensorProto_DataType_INT8},
39 {kNumberTypeInt16, mind_ir::TensorProto_DataType_INT16},
40 {kNumberTypeInt32, mind_ir::TensorProto_DataType_INT32},
41 {kNumberTypeInt64, mind_ir::TensorProto_DataType_INT64},
42 {kNumberTypeUInt8, mind_ir::TensorProto_DataType_UINT8},
43 {kNumberTypeUInt16, mind_ir::TensorProto_DataType_UINT16},
44 {kNumberTypeUInt32, mind_ir::TensorProto_DataType_UINT32},
45 {kNumberTypeUInt64, mind_ir::TensorProto_DataType_UINT64},
46 {kNumberTypeFloat16, mind_ir::TensorProto_DataType_FLOAT16},
47 {kNumberTypeFloat32, mind_ir::TensorProto_DataType_FLOAT},
48 {kNumberTypeFloat64, mind_ir::TensorProto_DataType_DOUBLE},
49 {kObjectTypeString, mind_ir::TensorProto_DataType_STRING},
50 };
51
52 static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_map = {
53 {8, mind_ir::TensorProto_DataType_INT8},
54 {16, mind_ir::TensorProto_DataType_INT16},
55 {32, mind_ir::TensorProto_DataType_INT32},
56 {64, mind_ir::TensorProto_DataType_INT64},
57 };
58
59 static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = {
60 {8, mind_ir::TensorProto_DataType_UINT8},
61 {16, mind_ir::TensorProto_DataType_UINT16},
62 {32, mind_ir::TensorProto_DataType_UINT32},
63 {64, mind_ir::TensorProto_DataType_UINT64},
64 };
65
66 static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
67 {16, mind_ir::TensorProto_DataType_FLOAT16},
68 {32, mind_ir::TensorProto_DataType_FLOAT},
69 {64, mind_ir::TensorProto_DataType_FLOAT64},
70 };
71
72 // Can build different builder according to format
73 class IrExportBuilder;
74 using IrExportBuilderPtr = std::shared_ptr<IrExportBuilder>;
75
76 class IrExporter {
77 public:
IrExporter(IrExportBuilderPtr builder)78 explicit IrExporter(IrExportBuilderPtr builder) : builder_(std::move(builder)) {}
79 virtual ~IrExporter() = default;
80 std::string GetDumpString(const FuncGraphPtr &func_graph);
81 mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
82
83 private:
84 IrExportBuilderPtr builder_;
85 };
86
87 class IrExportBuilder {
88 public:
89 IrExportBuilder() = default;
~IrExportBuilder()90 ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
91 std::string GetProtoString() const;
92 void BuildModelInfo();
93 void BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data = false);
Model()94 mind_ir::ModelProto Model() { return model_; }
95
96 private:
97 void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
98 bool save_tensor_data = false);
99 void BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
100 bool save_tensor_data = false);
101 void BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
102 void BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
103 void BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto);
104 std::string BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto);
105
106 void SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto);
107 void SetValueInfoProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::ValueInfoProto *const value_proto);
108 void SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
109 void SetTensorProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::TensorProto *const tensor_proto);
110 void SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
111 void SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
112 void SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto,
113 std::string *const seq_string);
114 void SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
115 void SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
116 void SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
117 void SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
118 void SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
119 void SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto,
120 std::string *const seq_string);
121 void SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
122 std::string *const seq_string);
123
124 mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
125 mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
126 mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
127 mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits);
128 std::string GetNodeName(const AnfNodePtr &node);
129 std::string GetUniqueNodeName(const AnfNodePtr &node);
130 std::string GetOpTypeName(const AnfNodePtr &node);
GetNodeIndex()131 size_t GetNodeIndex() { return ++node_index_; }
ResetNodeIndex()132 void ResetNodeIndex() { node_index_ = 0; }
GetTupleIndex()133 size_t GetTupleIndex() { return ++shape_index_; }
ResetTupleIndex()134 void ResetTupleIndex() { shape_index_ = 0; }
135
136 mind_ir::ModelProto model_;
137 mind_ir::NodeProto *last_node_{nullptr};
138 std::list<FuncGraphPtr> todo_;
139 std::map<AnfNodePtr, std::string> node_index_map_;
140 std::set<std::string> nodeName_;
141 size_t node_index_{0};
142 size_t shape_index_{0};
143 bool top_graph{true};
144 };
145
146 using IrExporterPtr = std::shared_ptr<IrExporter>;
147
GetDumpString(const FuncGraphPtr & func_graph)148 std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
149 (void)GetDumpProto(func_graph);
150 return builder_->GetProtoString();
151 }
152
GetDumpProto(const FuncGraphPtr & func_graph,bool save_tensor_data)153 mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
154 if ((builder_ == nullptr) || (func_graph == nullptr)) {
155 MS_LOG(EXCEPTION) << "Input params is null.";
156 }
157
158 // Export model info
159 builder_->BuildModelInfo();
160
161 // Export model and return string
162 builder_->BuildModel(func_graph, save_tensor_data);
163 return builder_->Model();
164 }
165
GetProtoString() const166 std::string IrExportBuilder::GetProtoString() const {
167 MS_LOG(DEBUG) << "BuildModel complete!";
168 return model_.SerializeAsString();
169 }
170
BuildModelInfo()171 void IrExportBuilder::BuildModelInfo() {
172 constexpr auto ir_version = "0.1.0";
173 constexpr auto mindspore_name = "MindSpore";
174 model_.set_ir_version(ir_version);
175 model_.set_producer_name(mindspore_name);
176 model_.set_model_version(VERSION);
177 }
178
BuildModel(const FuncGraphPtr & func_graph,bool save_tensor_data)179 void IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph, bool save_tensor_data) {
180 MS_EXCEPTION_IF_NULL(func_graph);
181 mind_ir::GraphProto *graph_proto = model_.mutable_graph();
182 graph_proto->set_name(func_graph->ToString());
183 graph_proto->set_bprop_hash(func_graph->bprop_hash());
184 ResetNodeIndex();
185 todo_.clear();
186 nodeName_.clear();
187 // Build the main funcGraph
188 (void)nodeName_.insert(func_graph->ToString());
189 top_graph = true;
190 BuildFuncGraph(func_graph, graph_proto, save_tensor_data);
191 std::set<FuncGraphPtr> graphVisited;
192 (void)graphVisited.insert(func_graph);
193 top_graph = false;
194 while (!todo_.empty()) {
195 FuncGraphPtr fg = todo_.back();
196 todo_.pop_back();
197 if (graphVisited.count(fg) > 0) {
198 continue;
199 }
200 if (nodeName_.count(fg->ToString()) > 0) {
201 MS_LOG(EXCEPTION) << "There is a duplicate name: " << fg->ToString();
202 }
203 (void)nodeName_.insert(fg->ToString());
204 (void)graphVisited.insert(fg);
205 auto graph = model_.add_functions();
206 BuildFuncGraph(fg, graph, save_tensor_data);
207 }
208 // Release resource
209 nodeName_.clear();
210 node_index_map_.clear();
211 }
212
BuildFuncGraph(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto,bool save_tensor_data)213 void IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
214 bool save_tensor_data) {
215 // Export funcGraph name.
216 graph_proto->set_name(func_graph->ToString());
217 // Export parameters
218 // 1. parameters should be mapped to ValueInfoProto
219 // 2. parameters with default value should be mapped to Initializer
220 BuildParameters(func_graph, graph_proto, save_tensor_data);
221
222 // Export operator nodes(include output)
223 BuildNodes(func_graph, graph_proto);
224 }
225
BuildParameters(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto,bool save_tensor_data)226 void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto,
227 bool save_tensor_data) {
228 MS_EXCEPTION_IF_NULL(func_graph);
229 for (auto &item : func_graph->parameters()) {
230 MS_EXCEPTION_IF_NULL(item);
231 auto param = item->cast<ParameterPtr>();
232 if (param == nullptr) {
233 MS_LOG(EXCEPTION) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
234 }
235 std::string param_name = GetUniqueNodeName(param);
236 if (top_graph && param->has_default()) {
237 MS_LOG(DEBUG) << "Parameter: '" << item->DebugString();
238 mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
239 parameter_proto->set_name(param_name);
240 SetParamToTensorProto(param, parameter_proto);
241 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
242 if (tensor && save_tensor_data) {
243 parameter_proto->set_raw_data(tensor->data_c(), static_cast<size_t>(tensor->data().nbytes()));
244 }
245 } else {
246 mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
247 input_proto->set_name(param_name);
248 SetValueInfoProto(param, input_proto);
249 }
250 if (nodeName_.count(param_name) > 0) {
251 MS_LOG(EXCEPTION) << "parameter name is duplicate:" << param_name;
252 }
253 (void)nodeName_.insert(param_name);
254 }
255 }
256
GetMindirDataType(TypeId type_id)257 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) {
258 auto iter = g_data_type_map.find(type_id);
259 if (iter == g_data_type_map.end()) {
260 MS_LOG(EXCEPTION) << "Convert type error, unsupported type! " << type_id;
261 }
262 return iter->second;
263 }
264
GetMindirDataBitsIntType(int bits)265 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) {
266 auto iter = g_data_bits_int_map.find(bits);
267 if (iter == g_data_bits_int_map.end()) {
268 MS_LOG(EXCEPTION) << "Convert bits int error, unsupported bits! " << bits;
269 }
270 return iter->second;
271 }
272
GetMindirDataBitsUIntType(int bits)273 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
274 auto iter = g_data_bits_uint_map.find(bits);
275 if (iter == g_data_bits_uint_map.end()) {
276 MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
277 }
278 return iter->second;
279 }
280
GetMindirDataBitsFloatType(int bits)281 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
282 auto iter = g_data_bits_float_map.find(bits);
283 if (iter == g_data_bits_float_map.end()) {
284 MS_LOG(EXCEPTION) << "Convert bits float error, unsupported bits! " << bits;
285 }
286 return iter->second;
287 }
288
SetValueInfoProto(const AnfNodePtr & node,mind_ir::ValueInfoProto * const value_proto)289 void IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
290 if (node == nullptr || value_proto == nullptr) {
291 MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
292 }
293 MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
294 const TypePtr &type = node->Type();
295 const BaseShapePtr &shape = node->Shape();
296 if (type == nullptr || shape == nullptr) {
297 return;
298 }
299 if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
300 auto tensor = type->cast<TensorTypePtr>();
301 MS_EXCEPTION_IF_NULL(tensor);
302 auto elem_type = tensor->element();
303 const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
304 mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
305 tensor_proto->set_data_type(GetMindirDataType(elem_type->type_id()));
306 if (dims.size() == 0) {
307 MS_LOG(DEBUG) << "The dim of ValueInfoProto is 0.";
308 } else {
309 for (const auto &dim : dims) {
310 MS_LOG(DEBUG) << "SetValueInfoProto dim: " << dim;
311 tensor_proto->add_dims(dim);
312 }
313 }
314 } else if (type->isa<Tuple>()) {
315 auto tup_shape = shape->cast<abstract::TupleShapePtr>();
316 value_proto->set_denotation(type->type_name() + ":" + std::to_string(tup_shape->shape().size()));
317 } else {
318 value_proto->set_denotation(type->type_name());
319 }
320 MS_LOG(DEBUG) << "Value type: " << type->type_name();
321 }
322
SetTensorToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)323 void IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
324 if (value == nullptr || attr_proto == nullptr) {
325 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
326 }
327 attr_proto->set_ref_attr_name("tensor:value0");
328 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
329 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
330 tensor_proto->set_name("value0");
331 auto data = value->cast<tensor::TensorPtr>();
332 MS_EXCEPTION_IF_NULL(data);
333 tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
334 auto dtype = data->data_type();
335 auto shape = data->shape_c();
336 tensor_proto->set_data_type(GetMindirDataType(dtype));
337 for (const auto &dim : shape) {
338 tensor_proto->add_dims(dim);
339 }
340 }
341
SetTensorProto(const TypePtr & type,const BaseShapePtr & shape,mind_ir::TensorProto * const tensor_proto)342 void IrExportBuilder::SetTensorProto(const TypePtr &type, const BaseShapePtr &shape,
343 mind_ir::TensorProto *const tensor_proto) {
344 if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
345 MS_LOG(EXCEPTION) << "Type or shape is not supported! " << type->ToString();
346 }
347 auto tensor = type->cast<TensorTypePtr>();
348 const auto &dims = shape->cast<abstract::ShapePtr>()->shape();
349 tensor_proto->set_data_type(GetMindirDataType(tensor->element()->type_id()));
350 for (const auto &dim : dims) {
351 tensor_proto->add_dims(dim);
352 }
353 }
354
SetParamToTensorProto(const ParameterPtr & param,mind_ir::TensorProto * const tensor_proto)355 void IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
356 if (param == nullptr || tensor_proto == nullptr) {
357 MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
358 }
359 MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
360 SetTensorProto(param->Type(), param->Shape(), tensor_proto);
361 }
362
BuildNodes(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)363 void IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
364 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
365 for (const AnfNodePtr &node : nodes) {
366 MS_EXCEPTION_IF_NULL(node);
367 if (!node->isa<CNode>()) {
368 MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
369 continue;
370 }
371 auto cnode = node->cast<CNodePtr>();
372 if (cnode == func_graph->get_return()) {
373 BuildOutput(cnode, graph_proto);
374 } else {
375 BuildCNode(cnode, graph_proto);
376 }
377 }
378 }
379
BuildOutput(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)380 void IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
381 MS_EXCEPTION_IF_NULL(node);
382 const int OutputSize = 2;
383 if (node->size() != OutputSize) {
384 MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2.";
385 }
386 AnfNodePtr arg = node->input(1);
387 std::string node_name = BuildInputNode(arg, graph_proto);
388 mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
389 output_proto->set_name(node_name);
390 SetValueInfoProto(arg, output_proto);
391 }
392
GetOpTypeName(const AnfNodePtr & node)393 std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
394 // May be ValueNode/CNode/Parameter
395 std::string type_name = "";
396 if (IsValueNode<Primitive>(node)) {
397 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
398 MS_EXCEPTION_IF_NULL(prim);
399 type_name = prim->ToString();
400 } else if (IsValueNode<FuncGraph>(node)) {
401 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
402 MS_EXCEPTION_IF_NULL(fg);
403 todo_.push_back(fg);
404 type_name = "REF::" + fg->ToString();
405 } else if (node->isa<CNode>() || node->isa<Parameter>()) {
406 auto nodeName = GetUniqueNodeName(node);
407 type_name = "REF::" + nodeName;
408 if (nodeName_.count(nodeName) == 0) {
409 MS_LOG(EXCEPTION) << "There is not the name: " << nodeName;
410 }
411 } else {
412 MS_LOG(EXCEPTION) << "Need to support op type: " << node->type_name();
413 }
414 MS_LOG(DEBUG) << "ExportType: " << type_name;
415 return type_name;
416 }
417
SetShapeToNodeProto(const TypePtr & type,const BaseShapePtr & shape,mind_ir::AttributeProto * const attr_proto,std::string * const seq_string)418 void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape,
419 mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) {
420 MS_EXCEPTION_IF_NULL(type);
421 MS_EXCEPTION_IF_NULL(shape);
422 MS_EXCEPTION_IF_NULL(seq_string);
423 if (type->isa<Tuple>()) {
424 *seq_string += "Tuple[";
425 auto elements = type->cast<TuplePtr>()->elements();
426 auto tuple_shape = shape->cast<abstract::TupleShapePtr>()->shape();
427 for (size_t i = 0; i < elements.size(); i++) {
428 SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string);
429 }
430 *seq_string += "],";
431 } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
432 string shape_name = "shape" + std::to_string(GetTupleIndex());
433 *seq_string += shape_name + ",";
434 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
435 tensor_proto->set_name(shape_name);
436 SetTensorProto(type, shape, tensor_proto);
437 } else if (type->isa<Number>()) {
438 if (type->isa<Bool>()) {
439 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
440 } else {
441 string shape_name = "shape" + std::to_string(GetTupleIndex());
442 *seq_string += shape_name + ",";
443 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
444 tensor_proto->set_name(shape_name);
445 tensor_proto->set_data_type(mind_ir::TensorProto_DataType_UINT64);
446 tensor_proto->add_dims(1);
447 }
448 } else if (type->isa<Function>()) {
449 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_GRAPH);
450 *seq_string += type->type_name() + ",";
451 } else if (type->isa<String>() || type->isa<UMonadType>() || type->isa<IOMonadType>()) {
452 *seq_string += type->type_name() + ",";
453 } else {
454 MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name();
455 }
456 }
457
SetShapeToNodeProto(const CNodePtr & node,mind_ir::NodeProto * const node_proto)458 void IrExportBuilder::SetShapeToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
459 // Get shape of cnode
460 // 1. need to get shape from tuple element
461 // 2. save shape in TensorProto
462 // 3. save tuple string in ref_attr_name
463 MS_EXCEPTION_IF_NULL(node);
464 auto type = node->Type();
465 auto shape = node->Shape();
466 if (type == nullptr || shape == nullptr) {
467 return;
468 }
469 ResetTupleIndex();
470 std::string seq_string = "shape:";
471 mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
472 SetShapeToNodeProto(type, shape, attr_proto, &seq_string);
473 attr_proto->set_ref_attr_name(seq_string);
474 MS_LOG(DEBUG) << "CNode shape: " << seq_string;
475 }
476
BuildCNode(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)477 void IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
478 auto inputs_size = node->size();
479 if (inputs_size < 1) {
480 MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
481 }
482
483 // Need to build input node before dealing with cnode
484 std::vector<AnfNodePtr> op_inputs;
485 std::vector<string> input_names;
486 for (size_t i = 1; i < inputs_size; i++) {
487 auto input = node->input(i);
488 op_inputs.push_back(input);
489 input_names.push_back(BuildInputNode(input, graph_proto));
490 }
491
492 // Build cnode
493 mind_ir::NodeProto *node_proto = graph_proto->add_node();
494 std::string output_name = GetUniqueNodeName(node);
495 if (nodeName_.count(output_name) > 0) {
496 MS_LOG(EXCEPTION) << "There is a duplicate name: " << output_name;
497 }
498 (void)nodeName_.insert(output_name);
499 node_proto->add_output(output_name);
500 node_proto->set_name(output_name);
501 node_proto->set_domain(node->fullname_with_scope());
502 AnfNodePtr op = node->input(0);
503 std::string type_name = GetOpTypeName(op);
504 node_proto->set_op_type(type_name);
505 last_node_ = node_proto;
506 // Maybe Tensor or Function or nullptr
507 SetShapeToNodeProto(node, node_proto);
508
509 (void)std::for_each(input_names.begin(), input_names.end(),
510 [&node_proto](const string &name) { node_proto->add_input(name); });
511
512 // Add primitive attrs
513 if (IsValueNode<Primitive>(op)) {
514 auto prim = GetValueNode<PrimitivePtr>(op);
515 for (auto attr : prim->attrs()) {
516 MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
517 mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
518 attr_proto->set_name(attr.first);
519 auto attr_value = attr.second;
520 CheckAndConvertUtils::ConvertAttrValueInExport(type_name, attr.first, &attr_value);
521 SetValueToAttributeProto(attr_value, attr_proto);
522 }
523 }
524 }
525
BuildInputNode(const AnfNodePtr & node,mind_ir::GraphProto * const graph_proto)526 std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
527 std::string node_name = GetUniqueNodeName(node);
528 // FuncGraph will be added to functions and the input name is the function name.
529 if (IsValueNode<FuncGraph>(node)) {
530 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
531 todo_.push_back(fg);
532 return fg->ToString();
533 }
534 if (node->isa<ValueNode>()) {
535 // When node input is a ValueNode, need to create a Constant Node
536 mind_ir::NodeProto *node_proto = graph_proto->add_node();
537 node_proto->set_name(node_name);
538 node_proto->add_output(node_name);
539 SetAttributeProto(node, node_proto);
540 }
541 return node_name;
542 }
543
GetUniqueNodeName(const AnfNodePtr & node)544 std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
545 // Naming anfnode
546 // 1. parameter is unique in one func_graph
547 // 2. cnode and valuenode may be reduplicative, so add index to identify.
548 auto iter = node_index_map_.find(node);
549 if (iter != node_index_map_.end()) {
550 return iter->second;
551 } else {
552 std::string node_name = GetNodeName(node);
553 // Compatible before. CNode = FuncGraphName:CNodeName:index ,Parameter = FuncGraphName:ParameterName
554 if (node->isa<CNode>()) {
555 node_name = node_name + ":" + std::to_string(GetNodeIndex());
556 }
557 // Avoid duplicate name.
558 while (nodeName_.count(node_name) > 0) {
559 node_name = node_name + "_" + std::to_string(GetNodeIndex());
560 }
561 node_index_map_[node] = node_name;
562 return node_name;
563 }
564 }
565
GetNodeName(const AnfNodePtr & node)566 std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) {
567 MS_EXCEPTION_IF_NULL(node);
568 std::string node_name = "";
569 if (node->func_graph() != nullptr) {
570 node_name = node->func_graph()->ToString() + ":";
571 }
572 if (node->isa<ValueNode>()) {
573 // Needn't value
574 node_name += node->AnfNode::ToString();
575 } else {
576 node_name += node->ToString();
577 }
578 MS_LOG(DEBUG) << "GetNodeName: " << node_name;
579 return node_name;
580 }
581
SetAttributeProto(const AnfNodePtr & node,mind_ir::NodeProto * const node_proto)582 void IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
583 if (node == nullptr || node_proto == nullptr) {
584 MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
585 }
586 auto value_node = node->cast<ValueNodePtr>();
587 MS_EXCEPTION_IF_NULL(value_node);
588 auto value = value_node->value();
589 node_proto->set_op_type("Constant");
590 mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
591 attr_proto->set_name("value");
592 MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
593 SetValueToAttributeProto(value, attr_proto);
594 }
595
SetTypeToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)596 void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
597 if (value == nullptr || attr_proto == nullptr) {
598 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
599 }
600 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
601 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
602 if (value->isa<Int>()) {
603 attr_proto->set_ref_attr_name("type:value0");
604 tensor_proto->set_name("value0");
605 auto int_value = value->cast<IntPtr>();
606 tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
607 } else if (value->isa<UInt>()) {
608 attr_proto->set_ref_attr_name("type:value0");
609 tensor_proto->set_name("value0");
610 auto float_value = value->cast<UIntPtr>();
611 tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
612 } else if (value->isa<Float>()) {
613 attr_proto->set_ref_attr_name("type:value0");
614 tensor_proto->set_name("value0");
615 auto float_value = value->cast<FloatPtr>();
616 tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
617 } else if (value->isa<Bool>()) {
618 attr_proto->set_ref_attr_name("type:value0");
619 tensor_proto->set_name("value0");
620 tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
621 } else if (value->isa<TensorType>()) {
622 attr_proto->set_ref_attr_name("type:tensor0");
623 tensor_proto->set_name("tensor0");
624 auto elem_type = value->cast<TensorTypePtr>()->element();
625 if (elem_type->isa<Int>()) {
626 auto int_value = elem_type->cast<IntPtr>();
627 tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
628 } else if (elem_type->isa<Float>()) {
629 auto float_value = elem_type->cast<FloatPtr>();
630 tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
631 } else {
632 MS_LOG(EXCEPTION) << "Unsupported type " << elem_type->type_name();
633 }
634 } else {
635 MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
636 }
637 }
638
SetValueToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)639 void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
640 if (value == nullptr || attr_proto == nullptr) {
641 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
642 }
643 if (value->isa<StringImm>() || value->isa<Scalar>()) {
644 SetScalarToAttributeProto_ir(value, attr_proto);
645 } else if (value->isa<Number>() || value->isa<TensorType>()) {
646 SetTypeToAttributeProto(value, attr_proto);
647 } else if (value->isa<ValueSequeue>()) {
648 ResetTupleIndex();
649 std::string seq_string = "scalar:";
650 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
651 SetSequenceToAttributeProto(value->cast<ValueSequeuePtr>(), attr_proto, &seq_string);
652 attr_proto->set_ref_attr_name(seq_string);
653 MS_LOG(DEBUG) << "Attr string: " << seq_string;
654 } else if (value->isa<tensor::Tensor>()) {
655 SetTensorToAttributeProto(value, attr_proto);
656 } else if (value->isa<None>()) {
657 attr_proto->set_ref_attr_name("none");
658 MS_LOG(DEBUG) << "Attr string: " << value->type_name();
659 } else if (value->isa<Monad>()) {
660 if (value->isa<UMonad>()) {
661 attr_proto->set_ref_attr_name("Monad:UMonad");
662 } else if (value->isa<IOMonad>()) {
663 attr_proto->set_ref_attr_name("Monad:IOMonad");
664 } else {
665 MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name();
666 }
667 } else {
668 MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
669 }
670 }
671
SetScalarToAttributeProto_ir(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)672 void IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
673 if (value == nullptr || attr_proto == nullptr) {
674 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
675 }
676 attr_proto->set_ref_attr_name("scalar:value0");
677 if (value->isa<StringImm>()) {
678 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
679 attr_proto->set_s(GetValue<std::string>(value));
680 } else if (value->isa<BoolImm>()) {
681 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
682 int64_t attr_value = GetValue<bool>(value) ? 1 : 0;
683 attr_proto->set_i(attr_value);
684 } else if (value->isa<Int8Imm>()) {
685 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
686 attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
687 } else if (value->isa<Int16Imm>()) {
688 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
689 attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
690 } else if (value->isa<Int32Imm>()) {
691 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
692 attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
693 } else if (value->isa<Int64Imm>()) {
694 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
695 attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
696 } else if (value->isa<UInt8Imm>()) {
697 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
698 attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
699 } else if (value->isa<UInt16Imm>()) {
700 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
701 attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
702 } else if (value->isa<UInt32Imm>()) {
703 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
704 attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
705 } else if (value->isa<UInt64Imm>()) {
706 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
707 attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value()));
708 } else if (value->isa<FP32Imm>()) {
709 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
710 attr_proto->set_f(GetValue<float>(value));
711 } else if (value->isa<FP64Imm>()) {
712 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
713 attr_proto->set_d(GetValue<double>(value));
714 } else if (value->isa<tensor::Tensor>()) {
715 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
716 SetTensorToAttributeProto(value, attr_proto);
717 } else {
718 MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
719 }
720 }
721
SetScalarToAttributeProto_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)722 void IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
723 if (value == nullptr || attr_proto == nullptr) {
724 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
725 }
726 if (value->isa<Int>()) {
727 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
728 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
729 auto int_value = value->cast<IntPtr>();
730 tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
731 } else if (value->isa<Float>()) {
732 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
733 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
734 auto float_value = value->cast<FloatPtr>();
735 tensor_proto->set_data_type(GetMindirDataBitsFloatType(float_value->nbits()));
736 } else if (value->isa<StringImm>()) {
737 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
738 attr_proto->add_strings(GetValue<std::string>(value));
739 } else if (value->isa<BoolImm>()) {
740 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
741 int attr_value = GetValue<bool>(value) ? 1 : 0;
742 attr_proto->add_ints(attr_value);
743 } else if (value->isa<Int8Imm>()) {
744 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
745 attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
746 } else if (value->isa<Int16Imm>()) {
747 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
748 attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
749 } else if (value->isa<Int32Imm>()) {
750 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
751 attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
752 } else if (value->isa<Int64Imm>()) {
753 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
754 attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
755 } else if (value->isa<UInt8Imm>()) {
756 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
757 attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
758 } else if (value->isa<UInt16Imm>()) {
759 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
760 attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
761 } else if (value->isa<UInt32Imm>()) {
762 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
763 attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
764 } else if (value->isa<UInt64Imm>()) {
765 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
766 attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value()));
767 } else if (value->isa<FP32Imm>()) {
768 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
769 attr_proto->add_floats(GetValue<float>(value));
770 } else if (value->isa<FP64Imm>()) {
771 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
772 attr_proto->add_doubles(GetValue<double>(value));
773 } else if (value->isa<tensor::Tensor>()) {
774 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
775 SetTensorToAttributeProto(value, attr_proto);
776 } else {
777 MS_LOG(EXCEPTION) << "Unsupported scalar type: " << value->type_name();
778 }
779 }
780
SetSeqElemToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto,std::string * const seq_string)781 void IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
782 std::string *const seq_string) {
783 string value_name = "value" + std::to_string(GetTupleIndex());
784 if (seq_string != nullptr) {
785 *seq_string += value_name + ",";
786 }
787 SetScalarToAttributeProto_irs(value, attr_proto);
788 }
789
SetSequenceToAttributeProto(const ValueSequeuePtr & value,mind_ir::AttributeProto * const attr_proto,std::string * const seq_string)790 void IrExportBuilder::SetSequenceToAttributeProto(const ValueSequeuePtr &value,
791 mind_ir::AttributeProto *const attr_proto,
792 std::string *const seq_string) {
793 if (value == nullptr || attr_proto == nullptr) {
794 MS_LOG(EXCEPTION) << "ValueSequeuePtr or AttributeProto is null!";
795 }
796 if (value->isa<ValueTuple>() && seq_string != nullptr) {
797 *seq_string += "Tuple[";
798 const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
799 if (tuple_value->value().size() == 0) {
800 *seq_string += "],";
801 MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
802 return;
803 }
804 for (const auto &item : tuple_value->value()) {
805 if (item->isa<ValueTuple>()) {
806 SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string);
807 } else {
808 SetSeqElemToAttributeProto(item, attr_proto, seq_string);
809 }
810 }
811 *seq_string += "],";
812 } else if (value->isa<ValueList>() && seq_string != nullptr) {
813 *seq_string += "List[";
814 const ValueListPtr &list_value = value->cast<ValueListPtr>();
815 if (list_value->value().size() == 0) {
816 *seq_string += "],";
817 MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
818 return;
819 }
820 for (const auto &item : list_value->value()) {
821 MS_EXCEPTION_IF_NULL(item);
822 if (item->isa<ValueList>()) {
823 SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string);
824 } else {
825 SetSeqElemToAttributeProto(item, attr_proto, seq_string);
826 }
827 }
828 *seq_string += "],";
829 }
830 }
831
GetBinaryProtoString(const FuncGraphPtr & func_graph)832 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
833 auto builder = std::make_shared<IrExportBuilder>();
834 if (builder == nullptr) {
835 MS_LOG(ERROR) << "Create ir exporter failed!";
836 return "";
837 }
838 auto exporter = std::make_shared<IrExporter>(builder);
839 if (exporter == nullptr) {
840 return "";
841 }
842 return exporter->GetDumpString(func_graph);
843 }
844
GetBinaryProto(const FuncGraphPtr & func_graph,bool save_tensor_data)845 mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph, bool save_tensor_data) {
846 auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
847 auto result = exporter->GetDumpProto(func_graph, save_tensor_data);
848 return result;
849 }
850 } // namespace mindspore
851