1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <fstream>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <utility>
23 #include "include/common/debug/anf_ir_dump.h"
24 #include "include/common/debug/dump_proto.h"
25 #include "mindspore/core/ops/op_def.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "mindspore/core/ops/structure_ops.h"
29 #include "include/common/utils/compile_cache_context.h"
30
31 namespace {
32 using mindspore::CNodePtr;
33 using mindspore::FileUtils;
34 using mindspore::FuncGraph;
35 using mindspore::FuncGraphPtr;
36 using mindspore::ValueNode;
37 using mindspore::ValueNodePtr;
38
GetAllFuncGraphs(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)39 void GetAllFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
40 MS_ASSERT(all_func_graphs != nullptr);
41 MS_ASSERT(func_graph != nullptr);
42 if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
43 (void)(all_func_graphs->insert(func_graph));
44 } else {
45 return;
46 }
47 auto nodes = mindspore::TopoSort(func_graph->get_return());
48 for (auto &node : nodes) {
49 if (mindspore::IsValueNode<FuncGraph>(node)) {
50 MS_ASSERT(node->cast<ValueNodePtr>() != nullptr);
51 MS_ASSERT(node->cast<ValueNodePtr>()->value() != nullptr);
52 MS_ASSERT((node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
53 auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
54 GetAllFuncGraphs(new_fg, all_func_graphs);
55 }
56 if (mindspore::utils::isa<CNodePtr>(node)) {
57 auto cnode = node->cast<CNodePtr>();
58 MS_ASSERT(cnode != nullptr);
59 for (auto &weak_input : cnode->weak_inputs()) {
60 auto input = weak_input.lock();
61 MS_EXCEPTION_IF_NULL(input);
62 if (input->isa<ValueNode>()) {
63 if (mindspore::IsValueNode<FuncGraph>(input)) {
64 MS_ASSERT(input->cast<ValueNodePtr>() != nullptr);
65 MS_ASSERT(input->cast<ValueNodePtr>()->value() != nullptr);
66 MS_ASSERT((input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
67 auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
68 GetAllFuncGraphs(new_fg, all_func_graphs);
69 }
70 }
71 }
72 }
73 }
74 }
75
DeleteDirRecursively(const std::string & dir_name)76 bool DeleteDirRecursively(const std::string &dir_name) {
77 DIR *dir = opendir(dir_name.c_str());
78 dirent *dirent = nullptr;
79 std::vector<std::string> file_names{};
80 while ((dirent = readdir(dir)) != nullptr) {
81 if (strcmp(dirent->d_name, ".") != 0 && strcmp(dirent->d_name, "..") != 0) {
82 (void)(file_names.emplace_back(dirent->d_name));
83 }
84 }
85 for (auto &file_name : file_names) {
86 auto file_path = dir_name + "/" + file_name;
87 auto real_file_path = FileUtils::GetRealPath(file_path.c_str());
88 if (!real_file_path.has_value()) {
89 (void)(closedir(dir));
90 MS_LOG(ERROR) << "Cannot get pwd path";
91 return false;
92 }
93 auto result = unlink(real_file_path.value().c_str());
94 if (result != 0) {
95 (void)(closedir(dir));
96 MS_LOG(ERROR) << "Delete the file(" << real_file_path.value() << ") failed." << mindspore::ErrnoToString(errno);
97 return false;
98 }
99 }
100 (void)(closedir(dir));
101 return true;
102 }
103 }; // namespace
104
105 namespace mindspore {
// Serializes an abstract function (closure) into `attr_proto`.
// The closure kind is encoded in the proto `type` field, and the identifying
// payload (graph name / primitive unique name / node unique name) in `s`.
// For a union of functions, each member is exported recursively into `values`.
// Returns false when `abstract` is not an abstract function at all.
bool IrExportBuilder::SetAbstractFuncToAttributeProto(const abstract::AbstractBasePtr &abstract,
                                                      mind_ir::AttributeProto *const attr_proto) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(attr_proto);
  if (abstract->isa<abstract::FuncGraphAbstractClosure>()) {
    // Closure over a func graph: reference it by graph name.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FUNCGRAPHCLOSURE);
    auto func_name = abstract->cast<abstract::FuncGraphAbstractClosurePtr>()->func_graph()->ToString();
    attr_proto->set_s(func_name);
  } else if (abstract->isa<abstract::PrimitiveAbstractClosure>()) {
    // Closure over a primitive: reference it by its unique export name.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_PRIMITIVECLOSURE);
    auto prim = abstract->cast<abstract::PrimitiveAbstractClosurePtr>()->prim();
    attr_proto->set_s(GetPrimitiveUniqueName(prim));
  } else if (abstract->isa<abstract::PartialAbstractClosure>()) {
    // Partial application: reference the node that constructs the partial.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_PARTIALCLOSURE);
    auto node_ptr = abstract->cast<abstract::PartialAbstractClosurePtr>()->node();
    MS_EXCEPTION_IF_NULL(node_ptr);
    attr_proto->set_s(GetUniqueNodeName(node_ptr));
  } else if (abstract->isa<abstract::AbstractFuncUnion>()) {
    // Union of functions: export each atomic member as a nested value.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UNIONFUNCCLOSURE);
    auto visit_func = [this, &attr_proto](const abstract::AbstractFuncAtomPtr &poss) {
      auto element_attr_proto = attr_proto->add_values();
      if (!this->SetAbstractFuncToAttributeProto(poss, element_attr_proto)) {
        MS_LOG(EXCEPTION) << "Set union function abstract to proto error." << poss->ToString();
      }
    };
    abstract->cast<abstract::AbstractFunctionPtr>()->Visit(visit_func);
  } else {
    MS_LOG(ERROR) << "The parameter abstract is not an abstractFunction: " << abstract->ToString();
    return false;
  }
  return true;
}
138
GetPrimitiveUniqueName(const PrimitivePtr & primitive_ptr)139 std::string IrExportBuilder::GetPrimitiveUniqueName(const PrimitivePtr &primitive_ptr) {
140 auto it = primitive_name_map_.find(primitive_ptr);
141 if (it != primitive_name_map_.end()) {
142 return it->second;
143 }
144 // Remove this check if we find a way to handle save/load training model with flattened parameters.
145 if (IsPrimitiveEquals(primitive_ptr, prim::kPrimFlattenConcat)) {
146 MS_LOG(EXCEPTION) << "Export model with operator '" << primitive_ptr->name() << "' is not supported yet.\n"
147 << "Please remove 'net.flatten_weights()' in your script and try again.";
148 }
149 auto answer = primitive_ptr->name() + ":" + std::to_string(GetUniqueID());
150 primitive_name_map_[primitive_ptr] = answer;
151 return answer;
152 }
153
// Exports every primitive recorded in primitive_name_map_ into the model
// proto, including its type classification and (converted) attributes.
// Returns false if any attribute value fails to serialize.
bool IrExportBuilder::BuildPrimitives() {
  for (auto it = primitive_name_map_.begin(); it != primitive_name_map_.end(); ++it) {
    auto prim = it->first;
    // PyExecute carries runtime Python state that MindIR cannot represent.
    if (prim->name() == prim::kPrimPyExecute->name()) {
      MS_LOG(EXCEPTION) << "Cannot export a PyExecute CNode in MindIR.";
    }
    auto prim_proto = model_->add_primitives();

    prim_proto->set_name(it->second);
    prim_proto->set_op_type(prim->name());
    // function IsPrimitiveFunction: dynamic shape new primitive
    // attr is_primitive_function: default true, Lite MindIr false
    bool is_primitive_function =
      prim->GetAttr("primitive_function") == nullptr || GetValue<bool>(prim->GetAttr("primitive_function"));
    if (mindspore::ops::IsPrimitiveFunction(prim->name()) && is_primitive_function) {
      prim_proto->set_prim_type(mind_ir::PrimitiveProto_PrimType_PRIMITIVE_FUNCTION);
    } else {
      prim_proto->set_prim_type(mind_ir::PrimitiveProto_PrimType_PRIMITIVE);
    }

    // Unwrap a DoSignature-style wrapper (if any) so the attributes of the
    // real primitive are the ones exported below.
    auto real_prim = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
    if (real_prim != nullptr) {
      prim = real_prim;
    }

    prim_proto->set_instance_name(prim->instance_name());

    // Set primitive attributes
    for (const auto &attr : prim->attrs()) {
      MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
      // Blacklisted attributes are deliberately not exported.
      auto iter = g_export_attr_blacklist.find(attr.first);
      if (iter != g_export_attr_blacklist.end()) {
        continue;
      }
      if (attr.second == nullptr) {
        MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
        continue;
      }
      mind_ir::AttributeProto *attr_proto = prim_proto->add_attribute();
      attr_proto->set_name(attr.first);
      auto attr_value = attr.second;
      // Front-end graphs may need attr value conversion for export; kernel
      // graphs keep backend attr values as-is.
      if (!is_kernel_graph_) {
        CheckAndConvertUtils::ConvertAttrValueInExport(prim->name(), attr.first, &attr_value);
      }
      if (!SetValueToAttributeProto(attr_value, attr_proto)) {
        MS_LOG(ERROR) << "Set value to AttributeProto failed.";
        return false;
      }
    }  // Loop of attrs
  }    // Loop of primitives
  return true;
}
206
GetDumpString(const FuncGraphPtr & func_graph)207 std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
208 auto dump_proto = GetDumpProto(func_graph);
209 if (dump_proto == nullptr) {
210 MS_LOG(EXCEPTION) << "Get dump proto for graph " << func_graph->ToString() << " failed.";
211 }
212 return builder_->GetProtoString();
213 }
214
GetDumpProto(const FuncGraphPtr & func_graph)215 ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
216 if ((builder_ == nullptr) || (func_graph == nullptr)) {
217 MS_LOG(EXCEPTION) << "Input params is null.";
218 }
219
220 // Export model info
221 builder_->BuildModelInfo();
222
223 // Export model and return string
224 if (!builder_->BuildModel(func_graph)) {
225 return nullptr;
226 }
227 return builder_->Model();
228 }
229
GetDumpProto(const FuncGraphPtr & root_graph,const std::vector<FuncGraphPtr> & child_graphs,const std::vector<AnfNodePtr> & isolated_nodes)230 ModelProtoPtr IrExporter::GetDumpProto(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
231 const std::vector<AnfNodePtr> &isolated_nodes) {
232 // Export model info
233 builder_->BuildModelInfo();
234 // Export model and return string
235 if (!builder_->BuildModel(root_graph, child_graphs, isolated_nodes)) {
236 return nullptr;
237 }
238 return builder_->Model();
239 }
240
GetProtoString() const241 std::string IrExportBuilder::GetProtoString() const {
242 MS_LOG(DEBUG) << "BuildModel complete!";
243 return model_->SerializeAsString();
244 }
245
BuildModelInfo()246 void IrExportBuilder::BuildModelInfo() {
247 constexpr auto ir_version = "0.1.1";
248 constexpr auto mindspore_name = "MindSpore";
249 model_->set_ir_version(ir_version);
250 model_->set_producer_name(mindspore_name);
251 model_->set_model_version(VERSION);
252 model_->set_little_endian(common::IsLittleByteOrder());
253 model_->set_mind_ir_version(mind_ir::Version_MAX);
254 }
255
256 // build model for kernel graph
// build model for kernel graph
// Exports `root_graph` plus its `child_graphs` and `isolated_nodes` into the
// model proto. Parameters/attrs of all graphs are built before any nodes,
// because parameters may be referenced across graph boundaries.
// Clears and repopulates the builder's naming state; returns false on failure.
bool IrExportBuilder::BuildModel(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
                                 const std::vector<AnfNodePtr> &isolated_nodes) {
  MS_EXCEPTION_IF_NULL(root_graph);
  is_kernel_graph_ = root_graph->type_name() == kKernelGraphTypeName;
  nodeName_.clear();
  node_name_map_.clear();
  primitive_name_map_.clear();

  // Because param may be called across graphs, build params of all graphs first.
  auto build_params_attrs = [this](const FuncGraphPtr &graph, mind_ir::GraphProto *const proto) {
    if (!BuildParameters(graph, proto)) {
      MS_LOG(ERROR) << "Build graph parameters failed.";
      return false;
    }
    if (!BuildFuncGraphAttrs(graph, proto)) {
      MS_LOG(ERROR) << "Build graph parameters attrs failed.";
      return false;
    }
    return true;
  };

  (void)nodeName_.insert(root_graph->ToString());
  auto root_graph_proto = model_->mutable_graph();
  // build root graph params
  // top_graph controls weight-vs-input classification inside BuildParameters.
  top_graph = true;
  if (!(build_params_attrs(root_graph, root_graph_proto))) {
    return false;
  }
  root_graph_proto->set_name(root_graph->ToString());
  graph_protos_[root_graph] = root_graph_proto;
  // build child graph params
  top_graph = false;
  for (const auto &graph : child_graphs) {
    auto func_proto = model_->add_functions();
    func_proto->set_name(graph->ToString());
    (void)nodeName_.insert(graph->ToString());
    if (!(build_params_attrs(graph, func_proto))) {
      return false;
    }
    graph_protos_[graph] = func_proto;
  }
  // build nodes for root_graph, then child_graph
  if (!BuildNodes(root_graph, root_graph_proto)) {
    return false;
  }
  // Sort child graphs by name so the export order is deterministic.
  std::map<std::string, FuncGraphPtr> sorted_graphs;
  std::for_each(child_graphs.begin(), child_graphs.end(),
                [&sorted_graphs](const auto &iter) { sorted_graphs[iter->ToString()] = iter; });
  for (const auto &iter : sorted_graphs) {
    const auto &graph = iter.second;
    if (!BuildNodes(graph, graph_protos_[graph])) {
      return false;
    }
  }
  if (!BuildIsolatedNodes(isolated_nodes)) {
    return false;
  }
  // build primitives
  if (!BuildPrimitives()) {
    return false;
  }
  // Release resource
  nodeName_.clear();
  node_name_map_.clear();
  primitive_name_map_.clear();
  graph_protos_.clear();
  return true;
}
325
// Exports a front-end func graph (and its child graphs, as recorded in the
// compile-cache context) into the model proto. Resets the builder's naming
// state before use and releases it afterwards. Returns false on failure.
bool IrExportBuilder::BuildModel(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  mind_ir::GraphProto *graph_proto = model_->mutable_graph();
  graph_proto->set_name(func_graph->ToString());
  graph_proto->set_bprop_hash(func_graph->bprop_hash());
  graph_proto->set_bprop_filepath(func_graph->bprop_filepath());
  todo_.clear();
  nodeName_.clear();
  primitive_name_map_.clear();
  // Build the main funcGraph
  (void)nodeName_.insert(func_graph->ToString());
  top_graph = true;

  if (!BuildFuncGraph(func_graph, graph_proto)) {
    MS_LOG(ERROR) << "Build func_graph " << func_graph->ToString() << " failed.";
    return false;
  }

  // Build child funcGraphs
  std::set<FuncGraphPtr> graphVisited;
  (void)graphVisited.insert(func_graph);
  top_graph = false;

  // Child graphs come from the compile-cache context; seed the work list.
  auto &context = CompileCacheContext::GetInstance();
  const auto &child_graphs = context.GetChileGraphs();
  (void)(std::transform(child_graphs.begin(), child_graphs.end(), std::back_inserter(todo_),
                        [](const FuncGraphPtr &g) { return g; }));
  // Process the work list; graphs may be pushed while building (depth-first).
  while (!todo_.empty()) {
    FuncGraphPtr fg = todo_.back();
    todo_.pop_back();
    if (graphVisited.count(fg) > 0) {
      continue;
    }
    // Graph names double as node names, so they must be globally unique.
    if (nodeName_.count(fg->ToString()) > 0) {
      MS_LOG(ERROR) << "There is a duplicate name: " << fg->ToString();
      return false;
    }
    (void)nodeName_.insert(fg->ToString());
    (void)graphVisited.insert(fg);
    auto graph = model_->add_functions();
    if (!BuildFuncGraph(fg, graph)) {
      MS_LOG(ERROR) << "Build func_graph " << fg->ToString() << " failed.";
      return false;
    }
  }

  if (!BuildPrimitives()) {
    return false;
  }
  // Release resource
  nodeName_.clear();
  node_name_map_.clear();
  primitive_name_map_.clear();
  graph_protos_.clear();
  MS_LOG(INFO) << "BuildModel end.";
  return true;
}
383
BuildFuncGraph(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)384 bool IrExportBuilder::BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
385 graph_protos_[func_graph] = graph_proto;
386 // Export funcGraph name.
387 graph_proto->set_name(func_graph->ToString());
388 // Export parameters
389 // 1. parameters should be mapped to ValueInfoProto
390 // 2. parameters with default value should be mapped to Initializer
391 if (!BuildParameters(func_graph, graph_proto)) {
392 MS_LOG(ERROR) << "Build parameters failed.";
393 return false;
394 }
395
396 // Export graph attributes
397 if (!BuildFuncGraphAttrs(func_graph, graph_proto)) {
398 MS_LOG(ERROR) << "Build attributes for graph failed.";
399 return false;
400 }
401
402 // Export operator nodes(include output)
403 return BuildNodes(func_graph, graph_proto);
404 }
405
BuildFuncGraphAttrs(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)406 bool IrExportBuilder::BuildFuncGraphAttrs(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
407 MS_EXCEPTION_IF_NULL(func_graph);
408 MS_EXCEPTION_IF_NULL(graph_proto);
409 for (const auto &attr : func_graph->attrs()) {
410 MS_LOG(DEBUG) << "attr: " << attr.first << " " << attr.second->DumpText() << " " << attr.second->type_name();
411 auto iter = g_export_attr_blacklist.find(attr.first);
412 if (iter != g_export_attr_blacklist.end()) {
413 continue;
414 }
415 if (attr.second == nullptr) {
416 MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
417 continue;
418 }
419 mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
420 attr_proto->set_name(attr.first);
421 if (!SetValueToAttributeProto(attr.second, attr_proto)) {
422 MS_LOG(ERROR) << "Set value to AttributeProto for GraphProto failed.";
423 return false;
424 }
425 }
426 return true;
427 }
428
ExportWeight(const ParameterPtr & param,const std::string & param_name,mind_ir::GraphProto * const graph_proto)429 bool IrExportBuilder::ExportWeight(const ParameterPtr ¶m, const std::string ¶m_name,
430 mind_ir::GraphProto *const graph_proto) {
431 MS_LOG(DEBUG) << "Parameter: '" << param->DebugString();
432 auto param_abs = param->abstract();
433 MS_EXCEPTION_IF_NULL(param_abs);
434 if (param_abs->isa<abstract::AbstractMapTensor>()) {
435 auto *map_parameter_proto = graph_proto->add_map_parameter();
436 if (!ConvertMapParameterToMapTensorProto(param, map_parameter_proto)) {
437 MS_LOG(ERROR) << "Convert MapParameter " << param->ToString() << " to MapTensorProto failed.";
438 return false;
439 }
440 return true;
441 }
442 if (param_abs->isa<abstract::AbstractTensor>()) {
443 mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
444 parameter_proto->set_name(param_name);
445 if (!SetParamToTensorProto(param, parameter_proto)) {
446 MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
447 return false;
448 }
449 return true;
450 }
451 MS_LOG(ERROR) << "Only support MapTensor or Tensor as default param of Parameter, got: "
452 << param->default_param()->ToString();
453 return false;
454 }
455
// Exports the graph's parameters: weights (fv params / defaults on the top
// graph) become TensorProto/MapTensorProto entries, everything else becomes
// a graph input (ValueInfoProto). Also registers each parameter's unique
// name and rejects duplicates. Returns false on any failure.
bool IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(graph_proto);
  auto &context = CompileCacheContext::GetInstance();
  auto param_size = func_graph->parameters().size();
  MS_LOG(DEBUG) << "func graph: " << func_graph->ToString() << " parameter num:" << param_size
                << ", fv param num:" << func_graph->fv_param_count();
  for (size_t param_counter = 0; param_counter < param_size; ++param_counter) {
    auto &item = func_graph->parameters()[param_counter];
    MS_EXCEPTION_IF_NULL(item);
    auto param = item->cast<ParameterPtr>();
    if (param == nullptr) {
      MS_LOG(ERROR) << "Parameter: '" << item->ToString() << "' could not cast to parameter.";
      return false;
    }
    // Kernel graphs may share parameters across graphs: skip ones already
    // exported or owned by another graph.
    if (is_kernel_graph_ && (node_name_map_.find(param) != node_name_map_.end() || param->func_graph() != func_graph)) {
      continue;
    }
    std::string param_name = GetUniqueNodeName(param);
    // Remember the export name on the node for the compile cache.
    param->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(param_name));
    // Backend params mirrored from frontend params need only the name entry.
    if (is_kernel_graph_ && context.IsBackendParamGenFromFrontendParam(param)) {
      (void)nodeName_.insert(param_name);
      continue;
    }
    // Trailing fv params (and, for kernel graphs, params with defaults) on
    // the top graph are weights; everything else is a graph input.
    if (top_graph &&
        (param_counter >= param_size - func_graph->fv_param_count() || (is_kernel_graph_ && param->has_default()))) {
      if (!ExportWeight(param, param_name, graph_proto)) {
        MS_LOG(ERROR) << "Failed to export parameter weight:" << param->DebugString();
        return false;
      }
    } else {
      // export graph input
      mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
      input_proto->set_name(param_name);
      if (!SetValueInfoProto(param, input_proto)) {
        MS_LOG(ERROR) << "Set parameter " << param->DebugString() << " to TensorProto failed.";
        return false;
      }
    }
    if (nodeName_.count(param_name) > 0) {
      MS_LOG(ERROR) << "parameter name is duplicate:" << param_name;
      return false;
    }
    (void)nodeName_.insert(param_name);
  }
  return true;
}
503
SetQuantizationParamToAttrProto(const std::shared_ptr<QuantizationParam> & quantization_param,mind_ir::TensorProto_QuantParamProto * const quant_param_proto)504 bool IrExportBuilder::SetQuantizationParamToAttrProto(const std::shared_ptr<QuantizationParam> &quantization_param,
505 mind_ir::TensorProto_QuantParamProto *const quant_param_proto) {
506 quant_param_proto->set_quant_algo_name(quantization_param->quant_algo_name());
507 auto quant_param_attrs = quantization_param->attrs();
508 for (auto &quant_param_attr : quant_param_attrs) {
509 if (quant_param_attr.second == nullptr) {
510 MS_LOG(ERROR) << "attr: " << quant_param_attr.first << " has no value.";
511 continue;
512 }
513 auto attr_proto = quant_param_proto->add_attribute();
514 attr_proto->set_name(quant_param_attr.first);
515 auto value_ptr = quant_param_attr.second;
516 auto ret = SetValueToAttributeProto(value_ptr, attr_proto);
517 if (!ret) {
518 MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
519 return false;
520 }
521 }
522 return true;
523 }
524
SetFunctorToAttrProto(const FunctorPtr & func,mind_ir::AttributeProto * const attr_proto)525 bool IrExportBuilder::SetFunctorToAttrProto(const FunctorPtr &func, mind_ir::AttributeProto *const attr_proto) {
526 auto *functor_proto = attr_proto->mutable_functor();
527 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FUNCTOR);
528 if (func->isa<ShapeCalcBaseFunctor>()) {
529 functor_proto->set_type(mind_ir::FunctorProto_FunctorType_SHAPE_CALC_FUNCTOR);
530 } else {
531 MS_LOG(ERROR) << "Unknown functor: " << func->ToString();
532 return false;
533 }
534 functor_proto->set_name(func->name());
535 auto values = func->ToValue();
536 if (values == nullptr) {
537 values = kNone;
538 }
539 if (!SetValueToAttributeProto(values, functor_proto->add_values())) {
540 return false;
541 }
542 return true;
543 }
544
GetMindirDataType(TypeId type_id) const545 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataType(TypeId type_id) const {
546 auto iter = g_data_type_map.find(type_id);
547 if (iter == g_data_type_map.end()) {
548 MS_LOG(ERROR) << "Convert type error, unsupported type! " << type_id;
549 return mind_ir::TensorProto_DataType_UNDEFINED;
550 }
551 return iter->second;
552 }
553
GetMindirDataBitsIntType(int bits) const554 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits) const {
555 auto iter = g_data_bits_int_map.find(bits);
556 if (iter == g_data_bits_int_map.end()) {
557 MS_LOG(ERROR) << "Convert bits int error, unsupported bits! " << bits;
558 return mind_ir::TensorProto_DataType_UNDEFINED;
559 }
560 return iter->second;
561 }
562
GetMindirDataBitsUIntType(int bits) const563 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) const {
564 auto iter = g_data_bits_uint_map.find(bits);
565 if (iter == g_data_bits_uint_map.end()) {
566 MS_LOG(ERROR) << "Convert bits uint error, unsupported bits! " << bits;
567 return mind_ir::TensorProto_DataType_UNDEFINED;
568 }
569 return iter->second;
570 }
571
GetMindirDataBitsFloatType(int bits) const572 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) const {
573 auto iter = g_data_bits_float_map.find(bits);
574 if (iter == g_data_bits_float_map.end()) {
575 MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
576 return mind_ir::TensorProto_DataType_UNDEFINED;
577 }
578 return iter->second;
579 }
580
GetMindirDataBitsBFloatType(int bits) const581 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsBFloatType(int bits) const {
582 auto iter = g_data_bits_bfloat_map.find(bits);
583 if (iter == g_data_bits_bfloat_map.end()) {
584 MS_LOG(ERROR) << "Convert bits bfloat error, unsupported bits! " << bits;
585 return mind_ir::TensorProto_DataType_UNDEFINED;
586 }
587 return iter->second;
588 }
589
GetMindirDataBitsComplexType(int bits) const590 mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsComplexType(int bits) const {
591 auto iter = g_data_bits_complex_map.find(bits);
592 if (iter == g_data_bits_complex_map.end()) {
593 MS_LOG(ERROR) << "Convert bits float error, unsupported bits! " << bits;
594 return mind_ir::TensorProto_DataType_UNDEFINED;
595 }
596 return iter->second;
597 }
598
SetValueInfoProto(const AnfNodePtr & node,mind_ir::ValueInfoProto * const value_proto)599 bool IrExportBuilder::SetValueInfoProto(const AnfNodePtr &node, mind_ir::ValueInfoProto *const value_proto) {
600 if (node == nullptr || value_proto == nullptr) {
601 MS_LOG(EXCEPTION) << "AnfNode or ValueInfo is null!";
602 }
603 MS_LOG(DEBUG) << "SetValueInfoProto: " << node->DebugString();
604 const TypePtr &type = node->Type();
605 const BaseShapePtr &shape = node->Shape();
606 // For the bprop fg which has not been renormalized.
607 if (type == nullptr || shape == nullptr) {
608 return true;
609 }
610 if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
611 mind_ir::TensorProto *tensor_proto = value_proto->add_tensor();
612 if (!SetTensorProto(node->abstract(), tensor_proto)) {
613 return false;
614 }
615 } else {
616 mind_ir::AttributeProto *attribute = value_proto->mutable_attr_info();
617 if (!SetAbstractToNodeProto(node->abstract(), attribute)) {
618 MS_LOG(ERROR) << "Set shape to Proto for " << node->DebugString() << " failed.";
619 return false;
620 }
621 value_proto->set_denotation(type->type_name());
622 }
623 MS_LOG(DEBUG) << "Value type: " << type->type_name();
624 return true;
625 }
626
SetTensorToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)627 bool IrExportBuilder::SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
628 if (value == nullptr || attr_proto == nullptr) {
629 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
630 }
631 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
632 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
633 tensor_proto->set_name("value0");
634 auto data = value->cast<tensor::TensorPtr>();
635 MS_EXCEPTION_IF_NULL(data);
636 tensor_proto->set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
637 auto dtype = data->data_type();
638 auto shape = data->shape_c();
639 auto data_type = GetMindirDataType(dtype);
640 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
641 return false;
642 }
643 tensor_proto->set_data_type(data_type);
644 for (const auto &dim : shape) {
645 tensor_proto->add_dims(dim);
646 }
647 return true;
648 }
649
SetCSRTensorToProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)650 bool IrExportBuilder::SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
651 abstract::AbstractCSRTensorPtr csr_tensor_abs = abstract->cast<abstract::AbstractCSRTensorPtr>();
652 MS_EXCEPTION_IF_NULL(csr_tensor_abs);
653 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_CSR_TENSOR);
654 mind_ir::AttributeProto *indptr = attr_proto->add_values();
655 bool res = SetAbstractToNodeProto(csr_tensor_abs->indptr(), indptr);
656 mind_ir::AttributeProto *indices = attr_proto->add_values();
657 res = res && SetAbstractToNodeProto(csr_tensor_abs->indices(), indices);
658 mind_ir::AttributeProto *values = attr_proto->add_values();
659 res = res && SetAbstractToNodeProto(csr_tensor_abs->values(), values);
660 mind_ir::AttributeProto *shape = attr_proto->add_values();
661 res = res && SetAbstractToNodeProto(csr_tensor_abs->shape(), shape);
662 return res;
663 }
664
SetCOOTensorToProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)665 bool IrExportBuilder::SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
666 abstract::AbstractCOOTensorPtr coo_tensor_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
667 MS_EXCEPTION_IF_NULL(coo_tensor_abs);
668 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_COO_TENSOR);
669 mind_ir::AttributeProto *indices = attr_proto->add_values();
670 bool res = SetAbstractToNodeProto(coo_tensor_abs->indices(), indices);
671 mind_ir::AttributeProto *values = attr_proto->add_values();
672 res = res && SetAbstractToNodeProto(coo_tensor_abs->values(), values);
673 mind_ir::AttributeProto *shape = attr_proto->add_values();
674 res = res && SetAbstractToNodeProto(coo_tensor_abs->shape(), shape);
675 return res;
676 }
677
// Exports an abstract tensor's dtype, dims and (optional) name into a
// TensorProto. For RefType tensors, also records the ref key when present.
// Returns false for non-tensor abstracts or unsupported element types.
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
  auto type = abstract->BuildType();
  auto shape = abstract->BuildShape();
  if (!type->isa<TensorType>() || !shape->isa<abstract::Shape>()) {
    MS_LOG(ERROR) << "Type or shape is not supported! " << type->ToString();
    return false;
  }
  auto tensor = type->cast<TensorTypePtr>();
  auto tensor_shape = shape->cast<abstract::ShapePtr>();
  const auto &dims = tensor_shape->shape();
  auto data_type = GetMindirDataType(tensor->element()->type_id());
  if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
    return false;
  }
  tensor_proto->set_data_type(data_type);
  for (const auto &dim : dims) {
    tensor_proto->add_dims(dim);
  }

  // Only named abstracts carry a name worth exporting.
  if (!abstract->name().empty()) {
    tensor_proto->set_name(abstract->name());
  }
  // Deal Ref: non-ref tensors are done at this point.
  if (!type->isa<RefType>()) {
    return true;
  }

  auto abs_ref = abstract->cast<abstract::AbstractRefPtr>();
  if (abs_ref == nullptr) {
    MS_LOG(ERROR) << "The abstract " << abstract->ToString() << " should be AbstractRefTensor.";
    return false;
  }
  // A missing ref key is tolerated; the tensor is exported without one.
  // NOTE(review): ref_key_value() is dereferenced without a null check here;
  // presumably a RefType abstract always carries a ref_key_value — confirm.
  auto ref_key_value = abs_ref->ref_key_value()->cast<StringImmPtr>();
  if (ref_key_value == nullptr) {
    MS_LOG(INFO) << "The ref_key_value of abstract ref " << abstract->ToString() << " is nullptr";
    return true;
  }
  tensor_proto->set_ref_key(ref_key_value->value());
  return true;
}
718
SetParamToTensorProto(const ParameterPtr & param,mind_ir::TensorProto * const tensor_proto)719 bool IrExportBuilder::SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto) {
720 if (param == nullptr || tensor_proto == nullptr) {
721 MS_LOG(EXCEPTION) << "Parameter or TensorProto is null!";
722 }
723 MS_LOG(DEBUG) << "SetParamToTensorProto: " << param->DebugString();
724 if (!SetTensorProto(param->abstract(), tensor_proto)) {
725 MS_LOG(ERROR) << "Export Parameter to tensor proto failed.";
726 return false;
727 }
728 // export quant parameter info
729 auto tensor = param->default_param()->cast<tensor::TensorPtr>();
730 if (tensor != nullptr) {
731 tensor_proto->set_compression_type(static_cast<mind_ir::TensorProto_CompressionType>(tensor->compression_type()));
732 auto quant_params = tensor->quant_params();
733 for (const auto &quant_param : quant_params) {
734 auto quant_param_proto = tensor_proto->add_quant_params();
735 auto ret = SetQuantizationParamToAttrProto(quant_param, quant_param_proto);
736 if (ret != true) {
737 MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
738 return false;
739 }
740 }
741 }
742 return true;
743 }
744
ConvertMapParameterToMapTensorProto(const ParameterPtr & map_parameter,mind_ir::MapTensorProto * const map_tensor_proto)745 bool IrExportBuilder::ConvertMapParameterToMapTensorProto(const ParameterPtr &map_parameter,
746 mind_ir::MapTensorProto *const map_tensor_proto) {
747 if (map_parameter == nullptr || map_tensor_proto == nullptr) {
748 MS_LOG(EXCEPTION) << "MapParameter or MapTensorProto is null!";
749 }
750 MS_LOG(DEBUG) << "ConvertMapParameterToMapTensorProto: " << map_parameter->ToString();
751
752 // parameter name
753 map_tensor_proto->set_name(GetUniqueNodeName(map_parameter));
754
755 auto param_default = map_parameter->default_param();
756 MS_EXCEPTION_IF_NULL(param_default);
757 auto map_tensor = param_default->cast<tensor::MapTensorPtr>();
758 MS_EXCEPTION_IF_NULL(map_tensor);
759 // default value
760 auto default_value = map_tensor->default_value();
761 MS_EXCEPTION_IF_NULL(default_value);
762 auto *default_value_proto = map_tensor_proto->mutable_default_value();
763 MS_EXCEPTION_IF_NULL(default_value_proto);
764 if (!SetValueToAttributeProto(default_value, default_value_proto)) {
765 MS_LOG(ERROR) << "Export default value of MapTensor failed, default_value: " << default_value->ToString();
766 return false;
767 }
768 tensor::MapTensor::ExportData export_data = map_tensor->Export(this->incremental_);
769 // key_tensor
770 auto *key_tensor_proto = map_tensor_proto->mutable_key_tensor();
771 MS_EXCEPTION_IF_NULL(key_tensor_proto);
772 auto &key_tensor = export_data.key_tensor;
773 MS_EXCEPTION_IF_NULL(key_tensor);
774 if (!SetTensorProto(key_tensor->ToAbstract(), key_tensor_proto)) {
775 MS_LOG(ERROR) << "Export key tensor of MapTensor failed, key_tensor: " << key_tensor->ToString();
776 return false;
777 }
778 // value_tensor
779 auto *value_tensor_proto = map_tensor_proto->mutable_value_tensor();
780 MS_EXCEPTION_IF_NULL(value_tensor_proto);
781 auto &value_tensor = export_data.value_tensor;
782 MS_EXCEPTION_IF_NULL(value_tensor);
783 if (!SetTensorProto(value_tensor->ToAbstract(), value_tensor_proto)) {
784 MS_LOG(ERROR) << "Export value tensor of MapTensor failed, value_tensor: " << value_tensor->ToString();
785 return false;
786 }
787 // status_tensor
788 auto *status_tensor_proto = map_tensor_proto->mutable_status_tensor();
789 MS_EXCEPTION_IF_NULL(status_tensor_proto);
790 auto &status_tensor = export_data.status_tensor;
791 MS_EXCEPTION_IF_NULL(status_tensor);
792 if (!SetTensorProto(status_tensor->ToAbstract(), status_tensor_proto)) {
793 MS_LOG(ERROR) << "Export status tensor of MapTensor failed, status_tensor: " << status_tensor->ToString();
794 return false;
795 }
796 return true;
797 }
798
ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr & abstract,mind_ir::AttributeProto * const attr_proto)799 bool IrExportBuilder::ConvertAbstractMapTensorToAttrProto(const AbstractBasePtr &abstract,
800 mind_ir::AttributeProto *const attr_proto) {
801 auto map_tensor_abs = abstract->cast<abstract::AbstractMapTensorPtr>();
802 MS_EXCEPTION_IF_NULL(map_tensor_abs);
803
804 auto map_tensor_type = map_tensor_abs->map_tensor_type();
805 MS_EXCEPTION_IF_NULL(map_tensor_type);
806 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_MAP_TENSOR);
807 // key_tensor
808 auto key_dtype = map_tensor_type->key_dtype();
809 auto key_shape = {abstract::Shape::kShapeDimAny};
810 auto key_tensor_abs = std::make_shared<abstract::AbstractTensor>(key_dtype, key_shape);
811 auto *key_tensor_proto = attr_proto->add_tensors();
812 MS_EXCEPTION_IF_NULL(key_tensor_proto);
813 MS_EXCEPTION_IF_NULL(key_tensor_abs);
814 if (!SetTensorProto(key_tensor_abs, key_tensor_proto)) {
815 MS_LOG(ERROR) << "Export key tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
816 << abstract->ToString();
817 return false;
818 }
819 // value_dtype value_shape
820 auto value_dtype = map_tensor_type->key_dtype();
821 auto value_shape = map_tensor_abs->value_shape()->shape();
822 auto value_tensor_abs = std::make_shared<abstract::AbstractTensor>(value_dtype, value_shape);
823 auto *value_tensor_proto = attr_proto->add_tensors();
824 MS_EXCEPTION_IF_NULL(value_tensor_proto);
825 MS_EXCEPTION_IF_NULL(value_tensor_abs);
826 if (!SetTensorProto(value_tensor_abs, value_tensor_proto)) {
827 MS_LOG(ERROR) << "Export value tensor abstract of AbstractMapTensor failed, abstract_map_tensor: "
828 << abstract->ToString();
829 return false;
830 }
831 // default_value
832 auto default_value = map_tensor_abs->default_value();
833 if (default_value != nullptr) {
834 auto *default_value_proto = attr_proto->add_values();
835 MS_EXCEPTION_IF_NULL(default_value_proto);
836 if (!SetValueToAttributeProto(default_value, default_value_proto)) {
837 MS_LOG(ERROR) << "Export default value of AbstractMapTensor failed, abstract_map_tensor: "
838 << abstract->ToString();
839 return false;
840 }
841 }
842 return true;
843 }
844
BuildNodes(const FuncGraphPtr & func_graph,mind_ir::GraphProto * const graph_proto)845 bool IrExportBuilder::BuildNodes(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto) {
846 std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
847 for (const AnfNodePtr &node : nodes) {
848 MS_EXCEPTION_IF_NULL(node);
849 if (!node->isa<CNode>()) {
850 MS_LOG(DEBUG) << "Node: '" << node->ToString() << "' is not cnode";
851 continue;
852 }
853 if (is_kernel_graph_ && (node_name_map_.find(node) != node_name_map_.end() || node->func_graph() != func_graph)) {
854 continue;
855 }
856 auto cnode = node->cast<CNodePtr>();
857 if (cnode == func_graph->get_return()) {
858 if (!BuildOutput(cnode, graph_proto)) {
859 MS_LOG(ERROR) << "Build output for graph " << func_graph->ToString() << " failed.";
860 return false;
861 }
862 } else {
863 auto iter = graph_protos_.find(node->func_graph());
864 if (iter == graph_protos_.end()) {
865 MS_LOG(ERROR) << "Can not find the graph proto of func_graph " << node->func_graph()->ToString();
866 return false;
867 }
868 auto owner_graph_proto = iter->second;
869 if (!BuildCNode(cnode, owner_graph_proto)) {
870 MS_LOG(ERROR) << "Build proto for cnode " << cnode->DebugString() << " failed.";
871 return false;
872 }
873 }
874 }
875 return true;
876 }
877
BuildIsolatedCNode(const AnfNodePtr & node,std::set<AnfNodePtr> * visited)878 bool IrExportBuilder::BuildIsolatedCNode(const AnfNodePtr &node, std::set<AnfNodePtr> *visited) {
879 MS_EXCEPTION_IF_NULL(node);
880 auto iter = node_name_map_.find(node);
881 if (iter != node_name_map_.end()) {
882 return true;
883 }
884 MS_EXCEPTION_IF_NULL(visited);
885 if (visited->find(node) != visited->end()) {
886 MS_LOG(ERROR) << "There is a cycle when build node " << node->DebugString();
887 return false;
888 }
889 if (!node->isa<CNode>()) {
890 return false;
891 }
892 const auto &cnode = node->cast<CNodePtr>();
893 MS_EXCEPTION_IF_NULL(cnode);
894 const auto &graph = cnode->func_graph();
895 if (!graph) {
896 MS_LOG(ERROR) << "The isolated node " << node->DebugString() << " is not belongs to any graph.";
897 return false;
898 }
899 auto graph_proto = graph_protos_[graph];
900 auto input_size = cnode->size();
901 std::vector<string> input_names;
902 // build input nodes
903 for (size_t i = 1; i < input_size; i++) {
904 auto input = cnode->input(i);
905 MS_EXCEPTION_IF_NULL(input);
906 if (input->isa<Parameter>()) {
907 MS_LOG(ERROR) << "Only support that the isolated node's input is cnode or value_node, but the input is "
908 << input->DebugString();
909 return false;
910 }
911 std::string node_name;
912 if (input->isa<ValueNode>()) {
913 auto input_graph = input->func_graph();
914 auto input_proto = input_graph ? graph_protos_[input_graph] : graph_proto;
915 MS_EXCEPTION_IF_NULL(input_proto);
916 node_name = BuildInputNode(input, input_proto);
917 } else {
918 if (!BuildIsolatedCNode(input, visited)) {
919 MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
920 return false;
921 }
922 node_name = GetUniqueNodeName(input);
923 }
924 if (node_name.empty()) {
925 MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
926 return false;
927 }
928 input->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
929 input_names.push_back(node_name);
930 }
931 // build cnode
932 auto output_name = GetUniqueNodeName(cnode);
933 cnode->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(output_name));
934 if (nodeName_.count(output_name) > 0) {
935 MS_LOG(INFO) << "There is a duplicate name: " << output_name;
936 return true;
937 }
938 mind_ir::NodeProto *node_proto = graph_proto->add_node();
939 (void)nodeName_.insert(output_name);
940 node_proto->add_output(output_name);
941 node_proto->set_name(output_name);
942 node_proto->set_domain(cnode->fullname_with_scope());
943 AnfNodePtr op = cnode->input(0);
944 std::string type_name = GetOpTypeName(op);
945 if (type_name.empty()) {
946 MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
947 return false;
948 }
949 node_proto->set_op_type(type_name);
950 if (!SetAbstractToNodeProto(cnode, node_proto)) {
951 MS_LOG(DEBUG) << "Fail to export abstract of the node: " << node->DebugString();
952 }
953 (void)std::for_each(input_names.begin(), input_names.end(),
954 [&node_proto](const string &name) { node_proto->add_input(name); });
955 if (!BuildCNodeAttr(cnode, node_proto)) {
956 MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
957 return false;
958 }
959 (void)(visited->insert(node));
960 return true;
961 }
962
BuildIsolatedNodes(const std::vector<AnfNodePtr> & isolated_nodes)963 bool IrExportBuilder::BuildIsolatedNodes(const std::vector<AnfNodePtr> &isolated_nodes) {
964 for (const auto &node : isolated_nodes) {
965 if (!node->isa<CNode>()) {
966 MS_LOG(ERROR) << "Only support that the isolated node is cnode, but the node is " << node->DebugString();
967 return false;
968 }
969 if (mindspore::IsPrimitiveCNode(node, mindspore::prim::kPrimReturn)) {
970 MS_LOG(ERROR) << "Only support that the isolated node is not return node, but the node is "
971 << node->DebugString();
972 return false;
973 }
974 std::set<AnfNodePtr> visited;
975 if (!BuildIsolatedCNode(node, &visited)) {
976 MS_LOG(ERROR) << "Build isolated node " << node->DebugString() << " failed.";
977 return false;
978 }
979 }
980
981 return true;
982 }
983
BuildOutput(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)984 bool IrExportBuilder::BuildOutput(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
985 MS_EXCEPTION_IF_NULL(node);
986 MS_EXCEPTION_IF_NULL(graph_proto);
987 const int OutputSize = 2;
988 if (node->size() != OutputSize) {
989 MS_LOG(ERROR) << "Number of inputs of return node is not equal to 2.";
990 return false;
991 }
992 auto graph_name = graph_proto->name();
993 node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(graph_name + kReturnNode));
994 AnfNodePtr arg = node->input(1);
995 auto node_name = BuildInputNode(arg, graph_proto);
996 arg->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
997 if (node_name.empty()) {
998 MS_LOG(ERROR) << "Build input node failed for arg " << arg->DebugString();
999 return false;
1000 }
1001 mind_ir::ValueInfoProto *output_proto = graph_proto->add_output();
1002 output_proto->set_name(node_name);
1003 // for return node primitive export
1004 AnfNodePtr op = node->input(0);
1005 op->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(graph_name + kReturnPrimNode));
1006 return SetValueInfoProto(arg, output_proto);
1007 }
1008
std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) {
  // Resolve the mind_ir "op_type" reference string for a cnode's first input.
  // Returns an empty string on failure.
  // May be ValueNode/CNode/Parameter
  std::string type_name = "";
  if (IsValueNode<Primitive>(node)) {
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
    MS_EXCEPTION_IF_NULL(prim);
    auto do_sign_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
    if (do_sign_prim != nullptr && do_sign_prim->function() != nullptr &&
        do_sign_prim->function()->isa<MetaFuncGraph>()) {
      // A DoSignature primitive wrapping a MetaFuncGraph is referenced by
      // the inner meta func graph's name.
      type_name = "REF::MetaFuncGraph::" + do_sign_prim->function()->cast_ptr<MetaFuncGraph>()->name();
    } else {
      const auto &unique_name = GetPrimitiveUniqueName(prim);
      type_name = "REF::" + unique_name;
      // for valuenode export
      node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(unique_name));
    }
  } else if (IsValueNode<FuncGraph>(node)) {
    FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
    MS_EXCEPTION_IF_NULL(fg);
    // Queue the referenced graph so it is exported as a function later.
    todo_.push_back(fg);
    type_name = "REF::" + fg->ToString();
  } else if (node->isa<CNode>() || node->isa<Parameter>()) {
    auto nodeName = GetUniqueNodeName(node);
    type_name = "REF::" + nodeName;
    // The referenced node must already have been exported under this name.
    if (nodeName_.count(nodeName) == 0) {
      MS_LOG(ERROR) << "There is not the name: " << nodeName;
      return "";
    }
  } else if (IsValueNode<MindIRClassType>(node)) {
    auto class_type = GetValueNode<MindIRClassTypePtr>(node)->name();
    // class 'XXX' -> XXX
    constexpr int64_t path_begin_index = 7;
    auto str = std::string(class_type.begin() + path_begin_index, class_type.end() - 1);
    type_name = "REF::ClassType::" + str;
  } else if (IsValueNode<MetaFuncGraph>(node)) {
    auto meta_fg = GetValueNode<MetaFuncGraphPtr>(node);
    MS_EXCEPTION_IF_NULL(meta_fg);
    type_name = "REF::MetaFuncGraph::" + meta_fg->name();
  } else {
    // Unsupported node kind: signal failure with an empty result.
    MS_LOG(ERROR) << "Need to support op type: " << node->DebugString();
    return "";
  }
  MS_LOG(DEBUG) << "ExportType: " << type_name;
  return type_name;
}
1054
ExportSequence(const abstract::AbstractSequencePtr & seq_abs,mind_ir::AttributeProto * const attr_proto)1055 bool IrExportBuilder::ExportSequence(const abstract::AbstractSequencePtr &seq_abs,
1056 mind_ir::AttributeProto *const attr_proto) {
1057 if (seq_abs->isa<abstract::AbstractTuple>()) {
1058 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
1059 } else {
1060 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_LIST);
1061 }
1062 auto seq_info_proto = attr_proto->mutable_seq_info();
1063 seq_info_proto->set_is_dyn_len(seq_abs->dynamic_len());
1064
1065 auto elem_abs = seq_abs->dynamic_len_element_abs();
1066 if (elem_abs != nullptr) {
1067 mind_ir::AttributeProto *tuple_elem_proto = seq_info_proto->mutable_tuple_elem_item();
1068 if (!SetAbstractToNodeProto(elem_abs, tuple_elem_proto)) {
1069 return false;
1070 }
1071 }
1072
1073 const auto &elems = seq_abs->elements();
1074 for (const auto &item : elems) {
1075 mind_ir::AttributeProto *attr_values = attr_proto->add_values();
1076 if (!SetAbstractToNodeProto(item, attr_values)) {
1077 return false;
1078 }
1079 }
1080 return true;
1081 }
1082
bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir::AttributeProto *const attr_proto) {
  // Dispatch on the built type of the abstract and fill the attribute proto
  // accordingly. Returns false for unsupported types.
  auto type = abs->BuildType();
  auto shape = abs->BuildShape();
  // Not use abstract because the abstract of csr tensor is a subclass of AbstractTuple
  if (type->isa<Tuple>() || type->isa<List>()) {
    return ExportSequence(abs->cast<abstract::AbstractSequencePtr>(), attr_proto);
  } else if (type->isa<TensorType>() && shape->isa<abstract::Shape>()) {
    // Ordinary tensor: a single entry in the tensors field.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
    return SetTensorProto(abs, tensor_proto);
  } else if (type->isa<Number>()) {
    // Scalar: encoded as a dims={0} "tensor" carrying only the data type.
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_SCALAR);
    mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
    auto data_type = GetMindirDataType(type->type_id());
    tensor_proto->set_data_type(data_type);
    tensor_proto->add_dims(0);
  } else if (type->isa<Function>()) {
    if (!SetAbstractFuncToAttributeProto(abs, attr_proto)) {
      return false;
    }
  } else if (type->isa<String>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
  } else if (type->isa<UMonadType>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UMONAD);
  } else if (type->isa<IOMonadType>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_IOMONAD);
  } else if (type->isa<CSRTensorType>()) {
    auto csr_tensor_abs = abs->cast<abstract::AbstractCSRTensorPtr>();
    if (!SetCSRTensorToProto(csr_tensor_abs, attr_proto)) {
      return false;
    }
  } else if (type->isa<COOTensorType>()) {
    auto coo_tensor_abs = abs->cast<abstract::AbstractCOOTensorPtr>();
    if (!SetCOOTensorToProto(coo_tensor_abs, attr_proto)) {
      return false;
    }
  } else if (type->isa<TypeNone>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
  } else if (type->isa<MapTensorType>()) {
    return ConvertAbstractMapTensorToAttrProto(abs, attr_proto);
  } else {
    MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
    return false;
  }

  return true;
}
1130
SetAbstractToNodeProto(const CNodePtr & node,mind_ir::NodeProto * const node_proto)1131 bool IrExportBuilder::SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
1132 // Get shape of cnode
1133 // 1. need to get shape from tuple element
1134 // 2. save shape in TensorProto
1135 MS_EXCEPTION_IF_NULL(node);
1136 auto type = node->Type();
1137 auto shape = node->Shape();
1138 auto abs = node->abstract();
1139 // For the bprop fg which has not been renormalized.
1140 if (type == nullptr || shape == nullptr) {
1141 return true;
1142 }
1143 mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
1144 if (!SetAbstractToNodeProto(abs, attr_proto)) {
1145 MS_LOG(WARNING) << "Set shape to NodeProto for " << node->DebugString() << " failed. abs: " << abs->ToString();
1146 return false;
1147 }
1148 attr_proto->set_name("shape");
1149 return true;
1150 }
1151
BuildCNode(const CNodePtr & node,mind_ir::GraphProto * const graph_proto)1152 bool IrExportBuilder::BuildCNode(const CNodePtr &node, mind_ir::GraphProto *const graph_proto) {
1153 auto inputs_size = node->size();
1154 if (inputs_size < 1) {
1155 MS_LOG(ERROR) << "Inputs of node " << node->DebugString() << " is empty";
1156 return false;
1157 }
1158
1159 // Need to build input node before dealing with cnode
1160 std::vector<string> input_names;
1161 for (size_t i = 1; i < inputs_size; i++) {
1162 auto input = node->input(i);
1163 std::string node_name = BuildInputNode(input, graph_proto);
1164 input->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(node_name));
1165 if (node_name.empty()) {
1166 MS_LOG(ERROR) << "Build input node for " << input->DebugString() << " failed.";
1167 return false;
1168 }
1169 input_names.push_back(node_name);
1170 }
1171
1172 // Build cnode
1173 std::string output_name = GetUniqueNodeName(node);
1174 node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(output_name));
1175 if (nodeName_.count(output_name) > 0) {
1176 MS_LOG(INFO) << "There is a duplicate name: " << output_name;
1177 return true;
1178 }
1179
1180 mind_ir::NodeProto *node_proto = graph_proto->add_node();
1181 (void)nodeName_.insert(output_name);
1182 node_proto->add_output(output_name);
1183 node_proto->set_name(output_name);
1184 node_proto->set_domain(node->fullname_with_scope());
1185 AnfNodePtr op = node->input(0);
1186 std::string type_name = GetOpTypeName(op);
1187 if (type_name.empty()) {
1188 MS_LOG(ERROR) << "Get op type name for " << op->DebugString() << " failed.";
1189 return false;
1190 }
1191 node_proto->set_op_type(type_name);
1192 last_node_ = node_proto;
1193 if (!SetAbstractToNodeProto(node, node_proto)) {
1194 MS_LOG(DEBUG) << "Fail to export abstract of the node: " << node->DebugString();
1195 }
1196
1197 (void)std::for_each(input_names.begin(), input_names.end(),
1198 [&node_proto](const string &name) { node_proto->add_input(name); });
1199
1200 if (!BuildCNodeAttr(node, node_proto)) {
1201 MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
1202 return false;
1203 }
1204 return true;
1205 }
1206
BuildValueNode(const ValueNodePtr & node,const string & node_name,mind_ir::GraphProto * const graph_proto)1207 bool IrExportBuilder::BuildValueNode(const ValueNodePtr &node, const string &node_name,
1208 mind_ir::GraphProto *const graph_proto) {
1209 // FuncGraphNode don't need to be exported to the proto in this step
1210 // check the node has been exported before
1211 if (IsValueNode<FuncGraph>(node) || nodeName_.count(node_name) > 0) {
1212 return true;
1213 }
1214 (void)nodeName_.insert(node_name);
1215 // When node input is a ValueNode, need to create a Constant Node
1216 mind_ir::NodeProto *node_proto = graph_proto->add_node();
1217 node_proto->set_name(node_name);
1218 node_proto->add_output(node_name);
1219 if (!SetAttributeProto(node, node_proto)) {
1220 return false;
1221 }
1222 return true;
1223 }
1224
BuildInputNode(const AnfNodePtr & node,mind_ir::GraphProto * const graph_proto)1225 std::string IrExportBuilder::BuildInputNode(const AnfNodePtr &node, mind_ir::GraphProto *const graph_proto) {
1226 std::string node_name = GetUniqueNodeName(node);
1227 if (node->isa<ValueNode>()) {
1228 if (!BuildValueNode(node->cast<ValueNodePtr>(), node_name, graph_proto)) {
1229 MS_LOG(ERROR) << "Export ValueNode Failed";
1230 return "";
1231 }
1232 MS_LOG(DEBUG) << "Export ValueNode " << node->DebugString() << " success";
1233 }
1234 return node_name;
1235 }
1236
GetUniqueNodeName(const AnfNodePtr & node)1237 std::string IrExportBuilder::GetUniqueNodeName(const AnfNodePtr &node) {
1238 // Naming anfnode
1239 // 1. parameter is unique in one func_graph
1240 // 2. cnode and valuenode may be reduplicative, so add index to identify.
1241 auto iter = node_name_map_.find(node);
1242 if (iter != node_name_map_.end()) {
1243 return iter->second;
1244 }
1245 // FuncGraph will be added to functions and the input name is the function name.
1246 if (IsValueNode<FuncGraph>(node)) {
1247 FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
1248 todo_.push_back(fg);
1249 auto name = fg->ToString();
1250 node->set_user_data<std::string>(kUniqueCacheName, std::make_shared<std::string>(name));
1251 return name;
1252 }
1253
1254 std::string node_name = GetNodeName(node);
1255 // Compatible before. CNode = FuncGraphName:CNodeName:index ,Parameter = FuncGraphName:ParameterName
1256 if (node->isa<CNode>()) {
1257 node_name = node_name + ":" + std::to_string(GetUniqueID());
1258 }
1259 // Avoid duplicate name.
1260 while (nodeName_.count(node_name) > 0) {
1261 node_name = node_name + "_" + std::to_string(GetUniqueID());
1262 }
1263 node_name_map_[node] = node_name;
1264 return node_name;
1265 }
1266
GetNodeName(const AnfNodePtr & node) const1267 std::string IrExportBuilder::GetNodeName(const AnfNodePtr &node) const {
1268 MS_EXCEPTION_IF_NULL(node);
1269 std::string node_name = "";
1270 if (node->func_graph() != nullptr) {
1271 node_name = node->func_graph()->ToString() + ":";
1272 }
1273 if (node->isa<ValueNode>()) {
1274 // Needn't value
1275 node_name += node->AnfNode::ToString();
1276 } else {
1277 node_name += node->ToString();
1278 }
1279 MS_LOG(DEBUG) << "GetNodeName: " << node_name;
1280 return node_name;
1281 }
1282
SetAttributeProto(const AnfNodePtr & node,mind_ir::NodeProto * const node_proto)1283 bool IrExportBuilder::SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto) {
1284 if (node == nullptr || node_proto == nullptr) {
1285 MS_LOG(EXCEPTION) << "AnfNode or NodeProto is null!";
1286 }
1287 auto value_node = node->cast<ValueNodePtr>();
1288 MS_EXCEPTION_IF_NULL(value_node);
1289 auto value = value_node->value();
1290 MS_EXCEPTION_IF_NULL(value);
1291 node_proto->set_op_type("Constant");
1292 mind_ir::AttributeProto *attr_proto = node_proto->add_attribute();
1293 attr_proto->set_name("value");
1294 MS_LOG(DEBUG) << "Set Constant attribute: " << value->ToString();
1295 return SetValueToAttributeProto(value, attr_proto);
1296 }
1297
SetTensorTypeToAttributeProto(const ValuePtr & value,mind_ir::TensorProto * tensor_proto)1298 bool IrExportBuilder::SetTensorTypeToAttributeProto(const ValuePtr &value, mind_ir::TensorProto *tensor_proto) {
1299 tensor_proto->set_name("tensor0");
1300 auto elem_type = value->cast<TensorTypePtr>()->element();
1301 if (elem_type->isa<Int>()) {
1302 auto int_value = elem_type->cast<IntPtr>();
1303 auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1304 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1305 return false;
1306 }
1307 tensor_proto->set_data_type(data_type);
1308 } else if (elem_type->isa<Float>()) {
1309 auto float_value = elem_type->cast<FloatPtr>();
1310 auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1311 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1312 return false;
1313 }
1314 tensor_proto->set_data_type(data_type);
1315 } else {
1316 MS_LOG(ERROR) << "Unsupported type " << elem_type->type_name();
1317 return false;
1318 }
1319 return true;
1320 }
1321
SetTypeToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1322 bool IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1323 if (value == nullptr || attr_proto == nullptr) {
1324 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
1325 }
1326 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1327 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1328 if (value->isa<Int>()) {
1329 tensor_proto->set_name("value0");
1330 auto int_value = value->cast<IntPtr>();
1331 auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1332 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1333 return false;
1334 }
1335 tensor_proto->set_data_type(data_type);
1336 } else if (value->isa<UInt>()) {
1337 tensor_proto->set_name("value0");
1338 auto float_value = value->cast<UIntPtr>();
1339 auto data_type = GetMindirDataBitsUIntType(float_value->nbits());
1340 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1341 return false;
1342 }
1343 tensor_proto->set_data_type(data_type);
1344 } else if (value->isa<Float>()) {
1345 tensor_proto->set_name("value0");
1346 auto float_value = value->cast<FloatPtr>();
1347 auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1348 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1349 return false;
1350 }
1351 tensor_proto->set_data_type(data_type);
1352 } else if (value->isa<BFloat>()) {
1353 tensor_proto->set_name("value0");
1354 auto bfloat_value = value->cast<BFloatPtr>();
1355 MS_EXCEPTION_IF_NULL(bfloat_value);
1356 auto data_type = GetMindirDataBitsBFloatType(bfloat_value->nbits());
1357 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1358 return false;
1359 }
1360 tensor_proto->set_data_type(data_type);
1361 } else if (value->isa<Complex>()) {
1362 tensor_proto->set_name("value0");
1363 auto complex_value = value->cast<ComplexPtr>();
1364 auto data_type = GetMindirDataBitsComplexType(complex_value->nbits());
1365 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1366 return false;
1367 }
1368 tensor_proto->set_data_type(data_type);
1369 } else if (value->isa<Bool>()) {
1370 tensor_proto->set_name("value0");
1371 tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
1372 } else if (value->isa<TensorType>()) {
1373 return SetTensorTypeToAttributeProto(value, tensor_proto);
1374 } else {
1375 MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
1376 }
1377 return true;
1378 }
1379
SetNamedValueToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1380 bool IrExportBuilder::SetNamedValueToAttributeProto(const ValuePtr &value,
1381 mind_ir::AttributeProto *const attr_proto) const {
1382 if (value->isa<None>()) {
1383 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NONE);
1384 MS_LOG(DEBUG) << "Attr string: " << value->type_name();
1385 } else if (value->isa<MindIRClassType>()) {
1386 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_CLASS_TYPE);
1387 auto class_type = GetValue<MindIRClassTypePtr>(value)->name();
1388 // class 'XXX' -> XXX
1389 constexpr int64_t path_begin_index = 7;
1390 auto str = std::string(class_type.begin() + path_begin_index, class_type.end() - 1);
1391 attr_proto->set_s(str);
1392 } else if (value->isa<MindIRNameSpace>()) {
1393 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_NAME_SPACE);
1394 attr_proto->set_s(GetValue<MindIRNameSpacePtr>(value)->name_space());
1395 } else if (value->isa<MindIRSymbol>()) {
1396 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_SYMBOL);
1397 attr_proto->set_s(GetValue<MindIRSymbolPtr>(value)->symbol());
1398 } else {
1399 MS_LOG(ERROR) << "Unsupported named type: " << value->type_name();
1400 return false;
1401 }
1402 return true;
1403 }
1404
bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
  // Central dispatcher: serialize an arbitrary Value into an AttributeProto.
  // Branch order matters: StringImm/Scalar must be tested before the generic
  // Named check, and Monads are distinguished by their concrete subtype.
  if (value == nullptr) {
    MS_LOG(ERROR) << "Value is null.";
    return false;
  }
  MS_EXCEPTION_IF_NULL(attr_proto);
  if (value->isa<StringImm>() || value->isa<Scalar>()) {
    return SetScalarToAttributeProto_ir(value, attr_proto);
  } else if (value->isa<Number>() || value->isa<TensorType>()) {
    // Type objects (not values of that type) are exported as dtype-only tensors.
    return SetTypeToAttributeProto(value, attr_proto);
  } else if (value->isa<ValueSequence>()) {
    if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), attr_proto)) {
      MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
  } else if (value->isa<ValueDictionary>()) {
    if (!SetDictToAttributeProto(value->cast<ValueDictionaryPtr>(), attr_proto)) {
      MS_LOG(ERROR) << "Set dictionary to AttributeProto failed.";
      return false;
    }
    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
  } else if (value->isa<tensor::Tensor>()) {
    return SetTensorToAttributeProto(value, attr_proto);
  } else if (value->isa<Named>()) {
    // None / class type / namespace / symbol.
    return SetNamedValueToAttributeProto(value, attr_proto);
  } else if (value->isa<TypeNull>()) {
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TYPE_NULL);
    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
  } else if (value->isa<Monad>()) {
    if (value->isa<UMonad>()) {
      attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UMONAD);
    } else if (value->isa<IOMonad>()) {
      attr_proto->set_type(mind_ir::AttributeProto_AttributeType_IOMONAD);
    } else {
      MS_LOG(ERROR) << "Unsupported Monad type: " << value->type_name();
      return false;
    }
  } else if (value->isa<QuantizationParam>()) {
    // A single quantization parameter rides inside a one-entry tensors field.
    auto quantization_param = value->cast<std::shared_ptr<QuantizationParam>>();
    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
    auto tensor_proto = attr_proto->add_tensors();
    tensor_proto->set_name(attr_proto->name());
    auto quant_param_proto = tensor_proto->add_quant_params();
    auto ret = SetQuantizationParamToAttrProto(quantization_param, quant_param_proto);
    if (ret != true) {
      MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
      return false;
    }
  } else if (value->isa<Functor>()) {
    return SetFunctorToAttrProto(value->cast<FunctorPtr>(), attr_proto);
  } else {
    MS_LOG(ERROR) << "Unsupported type: " << value->type_name();
    return false;
  }
  return true;
}
1462
SetScalarToAttributeProto_ir(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1463 bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value,
1464 mind_ir::AttributeProto *const attr_proto) const {
1465 if (value == nullptr || attr_proto == nullptr) {
1466 MS_LOG(EXCEPTION) << "ValuePtr or AttributeProto is null!";
1467 }
1468 if (value->isa<StringImm>()) {
1469 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
1470 attr_proto->set_s(GetValue<std::string>(value));
1471 } else if (value->isa<BoolImm>()) {
1472 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
1473 int64_t attr_value = GetValue<bool>(value) ? 1 : 0;
1474 attr_proto->set_i(attr_value);
1475 } else if (SetScalarToAttributeProtoForInt_ir(value, attr_proto)) {
1476 return true;
1477 } else if (value->isa<FP32Imm>()) {
1478 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
1479 attr_proto->set_f(GetValue<float>(value));
1480 } else if (value->isa<FP64Imm>()) {
1481 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
1482 attr_proto->set_d(GetValue<double>(value));
1483 } else {
1484 MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
1485 return false;
1486 }
1487 return true;
1488 }
1489
SetScalarToAttributeProtoForInt_ir(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1490 bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value,
1491 mind_ir::AttributeProto *const attr_proto) const {
1492 if (value->isa<Int8Imm>()) {
1493 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
1494 attr_proto->set_i(value->cast<Int8ImmPtr>()->value());
1495 } else if (value->isa<Int16Imm>()) {
1496 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
1497 attr_proto->set_i(value->cast<Int16ImmPtr>()->value());
1498 } else if (value->isa<Int32Imm>()) {
1499 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
1500 attr_proto->set_i(value->cast<Int32ImmPtr>()->value());
1501 } else if (value->isa<Int64Imm>()) {
1502 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
1503 attr_proto->set_i(value->cast<Int64ImmPtr>()->value());
1504 } else if (value->isa<UInt8Imm>()) {
1505 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
1506 attr_proto->set_i(value->cast<UInt8ImmPtr>()->value());
1507 } else if (value->isa<UInt16Imm>()) {
1508 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
1509 attr_proto->set_i(value->cast<UInt16ImmPtr>()->value());
1510 } else if (value->isa<UInt32Imm>()) {
1511 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
1512 attr_proto->set_i(value->cast<UInt32ImmPtr>()->value());
1513 } else if (value->isa<UInt64Imm>()) {
1514 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
1515 attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value()));
1516 } else {
1517 return false;
1518 }
1519 return true;
1520 }
1521
SetTypeToAttributeProto_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1522 bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1523 if (attr_proto == nullptr) {
1524 MS_LOG(EXCEPTION) << "AttributeProto is null!";
1525 }
1526 if (value->isa<Int>()) {
1527 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1528 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1529 auto int_value = value->cast<IntPtr>();
1530 auto data_type = GetMindirDataBitsIntType(int_value->nbits());
1531 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1532 return false;
1533 }
1534 tensor_proto->set_data_type(data_type);
1535 } else if (value->isa<Float>()) {
1536 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1537 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1538 auto float_value = value->cast<FloatPtr>();
1539 auto data_type = GetMindirDataBitsFloatType(float_value->nbits());
1540 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1541 return false;
1542 }
1543 tensor_proto->set_data_type(data_type);
1544 } else if (value->isa<UInt>()) {
1545 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1546 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1547 auto uint_value = value->cast<UIntPtr>();
1548 auto data_type = GetMindirDataBitsUIntType(uint_value->nbits());
1549 if (data_type == mind_ir::TensorProto_DataType_UNDEFINED) {
1550 return false;
1551 }
1552 tensor_proto->set_data_type(data_type);
1553 } else if (value->isa<Bool>()) {
1554 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1555 mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
1556 tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
1557 } else if (value->isa<tensor::Tensor>()) {
1558 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1559 return SetTensorToAttributeProto(value, attr_proto);
1560 } else if (value->isa<QuantizationParam>()) {
1561 auto quantization_param = value->cast<std::shared_ptr<QuantizationParam>>();
1562 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
1563 auto tensor_proto = attr_proto->add_tensors();
1564 tensor_proto->set_name("quant_param");
1565 auto quant_param_proto = tensor_proto->add_quant_params();
1566 auto ret = SetQuantizationParamToAttrProto(quantization_param, quant_param_proto);
1567 if (ret != true) {
1568 MS_LOG(ERROR) << "QuantizationParam Set Value to AttributeProto Error";
1569 return false;
1570 }
1571 } else {
1572 MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
1573 }
1574 return true;
1575 }
1576
SetScalarToAttributeProto_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1577 bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value,
1578 mind_ir::AttributeProto *const attr_proto) const {
1579 if (attr_proto == nullptr) {
1580 MS_LOG(EXCEPTION) << "AttributeProto is null!";
1581 }
1582 if (value->isa<StringImm>()) {
1583 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_STRING);
1584 attr_proto->add_strings(GetValue<std::string>(value));
1585 } else if (value->isa<BoolImm>()) {
1586 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
1587 attr_proto->add_ints(GetValue<bool>(value));
1588 } else if (SetScalarToAttributeProtoForInt_irs(value, attr_proto)) {
1589 return true;
1590 } else if (value->isa<FP32Imm>()) {
1591 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT);
1592 attr_proto->add_floats(GetValue<float>(value));
1593 } else if (value->isa<FP64Imm>()) {
1594 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
1595 attr_proto->add_doubles(GetValue<double>(value));
1596 } else {
1597 MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
1598 return false;
1599 }
1600 return true;
1601 }
1602
SetScalarToAttributeProtoForInt_irs(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto) const1603 bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
1604 mind_ir::AttributeProto *const attr_proto) const {
1605 if (value->isa<Int8Imm>()) {
1606 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8);
1607 attr_proto->add_ints(value->cast<Int8ImmPtr>()->value());
1608 } else if (value->isa<Int16Imm>()) {
1609 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT16);
1610 attr_proto->add_ints(value->cast<Int16ImmPtr>()->value());
1611 } else if (value->isa<Int32Imm>()) {
1612 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT32);
1613 attr_proto->add_ints(value->cast<Int32ImmPtr>()->value());
1614 } else if (value->isa<Int64Imm>()) {
1615 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT64);
1616 attr_proto->add_ints(value->cast<Int64ImmPtr>()->value());
1617 } else if (value->isa<UInt8Imm>()) {
1618 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT8);
1619 attr_proto->add_ints(value->cast<UInt8ImmPtr>()->value());
1620 } else if (value->isa<UInt16Imm>()) {
1621 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT16);
1622 attr_proto->add_ints(value->cast<UInt16ImmPtr>()->value());
1623 } else if (value->isa<UInt32Imm>()) {
1624 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT32);
1625 attr_proto->add_ints(value->cast<UInt32ImmPtr>()->value());
1626 } else if (value->isa<UInt64Imm>()) {
1627 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64);
1628 attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value()));
1629 } else {
1630 return false;
1631 }
1632 return true;
1633 }
1634
SetSeqElemToAttributeProto(const ValuePtr & value,mind_ir::AttributeProto * const attr_proto)1635 bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
1636 if (value == nullptr) {
1637 MS_LOG(ERROR) << "Value is nullptr";
1638 return false;
1639 }
1640 if (value->isa<StringImm>() || value->isa<Scalar>()) {
1641 return SetScalarToAttributeProto_irs(value, attr_proto);
1642 }
1643 return SetTypeToAttributeProto_irs(value, attr_proto);
1644 }
1645
SetSequenceToAttributeProto(const ValueSequencePtr & value,mind_ir::AttributeProto * const attr_proto)1646 bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequencePtr &value,
1647 mind_ir::AttributeProto *const attr_proto) {
1648 if (value == nullptr || attr_proto == nullptr) {
1649 MS_LOG(EXCEPTION) << "ValueSequencePtr or AttributeProto is null!";
1650 }
1651 if (value->isa<ValueTuple>()) {
1652 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
1653 } else if (value->isa<ValueList>()) {
1654 attr_proto->set_type(mind_ir::AttributeProto_AttributeType_LIST);
1655 } else {
1656 MS_LOG(EXCEPTION) << "The sequance value should be ValueTuple or ValueList, but it is " << value->ToString();
1657 }
1658 auto value_sequence = value->cast<ValueSequencePtr>();
1659 MS_EXCEPTION_IF_NULL(value_sequence);
1660 const auto &values = value_sequence->value();
1661 if (values.empty()) {
1662 MS_LOG(DEBUG) << "SetSequenceToAttributeProto sequence size is 0";
1663 return true;
1664 }
1665 for (const auto &item : values) {
1666 mind_ir::AttributeProto *attr_values = attr_proto->add_values();
1667 MS_EXCEPTION_IF_NULL(item);
1668 if (item->isa<ValueSequence>()) {
1669 if (!SetSequenceToAttributeProto(item->cast<ValueSequencePtr>(), attr_values)) {
1670 MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
1671 return false;
1672 }
1673 } else {
1674 if (!SetSeqElemToAttributeProto(item, attr_values)) {
1675 MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
1676 return false;
1677 }
1678 }
1679 }
1680 return true;
1681 }
1682
// Serializes a ValueDictionary into `attr_proto` as a DICT attribute. Each
// entry becomes a child AttributeProto named after the key, whose single
// child holds the serialized value. Supports nested sequences, nested
// dictionaries, scalars/strings, and Number/Tensor values; any other value
// type raises an exception.
// NOTE(review): GetValue<std::string>(key) assumes every dictionary key is a
// string immediate — confirm against callers before relying on other key types.
bool IrExportBuilder::SetDictToAttributeProto(const ValueDictionaryPtr &value_dict,
                                              mind_ir::AttributeProto *const attr_proto) {
  if (value_dict == nullptr || attr_proto == nullptr) {
    MS_LOG(EXCEPTION) << "ValueDictionaryPtr or AttributeProto is null!";
  }
  attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DICT);
  const auto &values = value_dict->value();
  if (values.empty()) {
    // Empty dict is legal: DICT tag set, no children.
    MS_LOG(DEBUG) << "SetDictToAttributeProto dictionary size is 0";
    return true;
  }
  for (const auto &item : values) {
    // One child proto per entry; its name carries the dictionary key.
    mind_ir::AttributeProto *dict_item_proto = attr_proto->add_values();
    const auto &key = item.first;
    dict_item_proto->set_name(GetValue<std::string>(key));
    const auto &value = item.second;
    MS_EXCEPTION_IF_NULL(value);
    // The value is wrapped in one more nested proto so the value's own type
    // tag does not overwrite the entry proto's name/bookkeeping.
    mind_ir::AttributeProto *dict_item_value = dict_item_proto->add_values();
    if (value->isa<ValueSequence>()) {
      if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), dict_item_value)) {
        MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
        return false;
      }
    } else if (value->isa<ValueDictionary>()) {
      // Recurse for nested dictionaries.
      if (!SetDictToAttributeProto(value->cast<ValueDictionaryPtr>(), dict_item_value)) {
        MS_LOG(ERROR) << "Set dictionary to AttributeProto failed.";
        return false;
      }
    } else if (value->isa<StringImm>() || value->isa<Scalar>()) {
      if (!SetScalarToAttributeProto_irs(value, dict_item_value)) {
        MS_LOG(ERROR) << "Set StringImm or Scalar to AttributeProto failed.";
        return false;
      }
    } else if (value->isa<Number>() || value->isa<tensor::Tensor>()) {
      if (!SetTypeToAttributeProto_irs(value, dict_item_value)) {
        MS_LOG(ERROR) << "Set Number or Tensor to AttributeProto failed.";
        return false;
      }
    } else {
      MS_LOG(EXCEPTION) << "Unsupported type while converting ValueDictionary to AttributeProto: "
                        << value->type_name();
    }
  }
  return true;
}
1728
BuildCNodeAttr(const CNodePtr & node,mind_ir::NodeProto * const node_proto)1729 bool IrExportBuilder::BuildCNodeAttr(const CNodePtr &node, mind_ir::NodeProto *const node_proto) {
1730 for (const auto &attr : node->attrs()) {
1731 if (attr.second == nullptr) {
1732 MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
1733 continue;
1734 }
1735 mind_ir::AttributeProto *attr_proto = node_proto->add_node_attr();
1736 attr_proto->set_name(attr.first);
1737 if (!SetValueToAttributeProto(attr.second, attr_proto)) {
1738 MS_LOG(ERROR) << "Set value to node attr to node proto failed.";
1739 MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
1740 return false;
1741 }
1742 }
1743
1744 for (const auto &attr : node->primal_attrs()) {
1745 if (attr.second == nullptr) {
1746 MS_LOG(ERROR) << "attr: " << attr.first << " has no value.";
1747 continue;
1748 }
1749 mind_ir::AttributeProto *attr_proto = node_proto->add_primal_attr();
1750 attr_proto->set_name(attr.first);
1751 if (!SetValueToAttributeProto(attr.second, attr_proto)) {
1752 MS_LOG(ERROR) << "Set value to node primal attr to node proto failed.";
1753 MS_LOG(ERROR) << "node :" << node->DebugString() << "attr:{" << attr.first << "," << attr.second << "}";
1754 return false;
1755 }
1756 }
1757 return true;
1758 }
1759
GetBinaryProtoString(const FuncGraphPtr & func_graph,const bool & incremental)1760 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph, const bool &incremental) {
1761 auto builder = std::make_shared<IrExportBuilder>(incremental);
1762 if (builder == nullptr) {
1763 MS_LOG(ERROR) << "Create ir exporter failed!";
1764 return "";
1765 }
1766 auto exporter = std::make_shared<IrExporter>(builder);
1767 if (exporter == nullptr) {
1768 return "";
1769 }
1770 auto ret = exporter->GetDumpString(func_graph);
1771 return ret;
1772 }
1773
GenBinaryProto(const FuncGraphPtr & func_graph)1774 ModelProtoPtr GenBinaryProto(const FuncGraphPtr &func_graph) {
1775 auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1776 return exporter->GetDumpProto(func_graph);
1777 }
1778
DumpBinaryProto(const FuncGraphPtr & func_graph,const std::string & file_path)1779 bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path) {
1780 auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1781 auto proto = exporter->GetDumpProto(func_graph);
1782 MindIRExporter mindir_exporter;
1783 if (proto == nullptr) {
1784 MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
1785 return false;
1786 }
1787 return mindir_exporter.SaveProtoToFile(proto.get(), file_path);
1788 }
1789
DumpBinaryProto(const FuncGraphPtr & root_graph,const std::vector<FuncGraphPtr> & child_graphs,const std::vector<AnfNodePtr> & isolated_nodes,const std::string & file_path)1790 bool DumpBinaryProto(const FuncGraphPtr &root_graph, const std::vector<FuncGraphPtr> &child_graphs,
1791 const std::vector<AnfNodePtr> &isolated_nodes, const std::string &file_path) {
1792 auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
1793 auto proto = exporter->GetDumpProto(root_graph, child_graphs, isolated_nodes);
1794 if (proto == nullptr) {
1795 MS_LOG(ERROR) << "Get binary proto for graph " << root_graph->ToString() << " failed.";
1796 return false;
1797 }
1798 auto realpath = Common::CreatePrefixPath(file_path, true);
1799 if (!realpath.has_value()) {
1800 MS_LOG(ERROR) << "Get real path of file " << file_path << " failed.";
1801 return false;
1802 }
1803 ChangeFileMode(realpath.value(), S_IWUSR);
1804 std::ofstream fout(realpath.value());
1805 if (!fout.is_open()) {
1806 MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
1807 return false;
1808 }
1809 if (!proto->SerializeToOstream(&fout)) {
1810 MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
1811 fout.close();
1812 return false;
1813 }
1814 fout.close();
1815 ChangeFileMode(realpath.value(), S_IRUSR);
1816 return true;
1817 }
1818
ParserPath(const std::string & output_path)1819 bool MindIRExporter::ParserPath(const std::string &output_path) {
1820 if (!FileUtils::ParserPathAndModelName(output_path, &save_path_, &model_name_)) {
1821 MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
1822 return false;
1823 }
1824 #ifdef _WIN32
1825 save_model_path_ = save_path_ + "\\" + model_name_ + ".mindir";
1826 #else
1827 save_model_path_ = save_path_ + "/" + model_name_ + ".mindir";
1828 #endif
1829 return true;
1830 }
1831
// Exports `func_graph` to MindIR at `file_path`: parses the output path,
// preprocesses and serializes the graph into model_proto_, then writes either
// one combined file (SaveMindIRTogether) or a split graph+weights layout
// (SplitSave) depending on save_together_. Returns false on any failure.
// NOTE(review): param_layout_fg is accepted but never read in this function —
// confirm whether it is vestigial or consumed via an overriding path.
bool MindIRExporter::ExportProto(const FuncGraphPtr &func_graph, const std::string &file_path,
                                 const FuncGraphPtr &param_layout_fg) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "func_graph is nullptr.";
    return false;
  }

  if (!ParserPath(file_path)) {
    MS_LOG(ERROR) << "parse path failed.";
    return false;
  }

  // Serialize to protobuf using unique parameter name label.
  // Do preprocess on func_graph and check conditions for saving together.
  bool ret = PreProcSaveTogether(func_graph);
  if (!ret) {
    MS_LOG(ERROR) << "PreProcSaveTogether failed";
    return ret;
  }
#ifdef ENABLE_DUMP_IR
  // Dump the preprocessed graph for debugging when introductory IR dump is on.
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpIR("PreProcSaveTogether.ir", func_graph);
  }
#endif

  // save_together_ was decided by PreProcSaveTogether -> IfSaveTogether.
  if (save_together_) {
    MS_LOG(INFO) << "SaveMindIRTogether";
    ret = SaveMindIRTogether();
  } else {
    MS_LOG(INFO) << "SplitSave";
    ret = SplitSave();
  }
  if (!ret) {
    MS_LOG(ERROR) << "save mindir weight failed.";
    return ret;
  }
  return true;
}
1872
SaveMindIRTogether()1873 bool MindIRExporter::SaveMindIRTogether() {
1874 for (auto ¶m_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
1875 std::string proto_name = param_proto.name();
1876 auto para = GetFgParaAccordingToProtoName(proto_name);
1877 if (para == nullptr) {
1878 return false;
1879 }
1880 if (!para->has_default()) {
1881 continue;
1882 }
1883 auto data = para->default_param()->cast<tensor::TensorPtr>();
1884 param_proto.clear_raw_data();
1885 param_proto.set_raw_data(data->data_c(), static_cast<size_t>(data->data().nbytes()));
1886 }
1887 return SaveProtoToFile(&model_proto_, save_model_path_);
1888 }
1889
CreateParameterDir()1890 bool MindIRExporter::CreateParameterDir() {
1891 #ifdef _WIN32
1892 dir_name_ = save_path_ + "\\" + model_name_ + "_variables";
1893 #else
1894 dir_name_ = save_path_ + "/" + model_name_ + "_variables";
1895 #endif
1896 fs_ = system::Env::GetFileSystem();
1897 if (fs_ == nullptr) {
1898 MS_LOG(ERROR) << "create file system failed.";
1899 return false;
1900 }
1901
1902 if (fs_->FileExist(dir_name_)) {
1903 if (!DeleteDirRecursively(dir_name_)) {
1904 return false;
1905 }
1906 }
1907
1908 if (!fs_->CreateDir(dir_name_)) {
1909 MS_LOG(ERROR) << "create dir failed.";
1910 return false;
1911 }
1912
1913 ChangeFileMode(dir_name_, S_IWUSR | S_IRUSR | S_IXUSR);
1914 return true;
1915 }
1916
// Joins the parameter directory (dir_name_) with `external_file` using the
// platform path separator and returns the absolute path; caches the resolved
// directory in dir_path_.
// NOTE(review): GetRealPath(...).value() is read without a has_value() check;
// if dir_name_ cannot be resolved this throws — confirm CreateParameterDir()
// is always called first so the directory exists.
std::string MindIRExporter::CreateExternalPath(const std::string &external_file) {
  dir_path_ = FileUtils::GetRealPath(dir_name_.c_str()).value();
  std::string external_local_path{};
#ifdef _WIN32
  external_local_path = dir_path_ + "\\" + external_file;
#else
  external_local_path = dir_path_ + "/" + external_file;
#endif
  return external_local_path;
}
1927
SplitSave()1928 bool MindIRExporter::SplitSave() {
1929 MS_LOG(DEBUG) << "Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.";
1930 if (!CreateParameterDir()) {
1931 MS_LOG(ERROR) << "create parameter dir failed.";
1932 return false;
1933 }
1934
1935 int index = 0;
1936 std::string external_local = "data_" + std::to_string(index);
1937 auto external_local_path = CreateExternalPath(external_local);
1938 if (fs_->FileExist(external_local_path)) {
1939 if (!fs_->DeleteFile(external_local_path)) {
1940 MS_LOG(ERROR) << "delete file failed.";
1941 return false;
1942 }
1943 }
1944 int64_t parameter_size = 0;
1945 int64_t offset = OFFSET;
1946
1947 data_fs_ = FileUtils::OpenFile(external_local_path, std::ios::out | std::ios::binary | std::ios::trunc);
1948 if (data_fs_ == nullptr) {
1949 MS_LOG(ERROR) << "Open " << external_local_path << " failed";
1950 return false;
1951 }
1952 if (!ChangeParaDataFile(external_local)) {
1953 MS_LOG(ERROR) << "change parameter data file failed.";
1954 return false;
1955 }
1956
1957 for (auto ¶m_proto : *(model_proto_.mutable_graph()->mutable_parameter())) {
1958 std::string proto_name = param_proto.name();
1959 auto para = GetFgParaAccordingToProtoName(proto_name);
1960 if (para == nullptr) {
1961 return false;
1962 }
1963 if (!para->has_default()) {
1964 continue;
1965 }
1966 auto data = para->default_param()->cast<tensor::TensorPtr>();
1967 int64_t data_length = static_cast<int64_t>(data->data().nbytes());
1968 int64_t append_size = 0;
1969 if (data_length % OFFSET != 0) {
1970 append_size = OFFSET - (data_length % OFFSET);
1971 }
1972 parameter_size += ((append_size + data_length) / PARA_ROUND);
1973 if (parameter_size > static_cast<int64_t>(TOTAL_SAVE)) {
1974 index++;
1975 external_local = "data_" + std::to_string(index);
1976 data_fs_->close();
1977 delete data_fs_;
1978 data_fs_ = nullptr;
1979
1980 if (!ChangeParaDataFile(external_local)) {
1981 MS_LOG(ERROR) << "change parameter data file failed.";
1982 return false;
1983 }
1984 parameter_size = OFFSET / PARA_ROUND;
1985 }
1986 std::string external_local_data = model_name_ + "_variables/" + external_local;
1987 param_proto.mutable_external_data()->set_location(external_local_data);
1988 param_proto.mutable_external_data()->set_length(data_length);
1989 param_proto.mutable_external_data()->set_offset(offset);
1990
1991 data_fs_->write(static_cast<const char *>(data->data_c()), data_length);
1992 auto append_data = new char[append_size];
1993 if (append_data == nullptr) {
1994 return false;
1995 }
1996 data_fs_->write(append_data, append_size);
1997 offset += (data_length + append_size);
1998 delete[] append_data;
1999 }
2000 std::string split_model_file_name = "";
2001 #ifdef _WIN32
2002 split_model_file_name = save_path_ + "\\" + model_name_ + "_graph.mindir";
2003 #else
2004 split_model_file_name = save_path_ + "/" + model_name_ + "_graph.mindir";
2005 #endif
2006 return SaveProtoToFile(&model_proto_, split_model_file_name);
2007 }
2008
SaveProtoToFile(mind_ir::ModelProto * model_proto,const std::string & output_file)2009 bool MindIRExporter::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
2010 auto realpath = Common::CreatePrefixPath(output_file, true);
2011 if (!realpath.has_value()) {
2012 MS_LOG(ERROR) << "Get real path of file " << output_file << " failed.";
2013 return false;
2014 }
2015
2016 ChangeFileMode(realpath.value(), S_IWUSR);
2017 std::ofstream fout(realpath.value());
2018 if (!fout.is_open()) {
2019 MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
2020 return false;
2021 }
2022
2023 if (!model_proto->SerializeToOstream(&fout)) {
2024 MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
2025 fout.close();
2026 return false;
2027 }
2028
2029 fout.close();
2030 ChangeFileMode(realpath.value(), S_IRUSR);
2031 return true;
2032 }
2033
// (Re)creates the external data file `file` inside the variables directory
// and opens it in append mode as data_fs_, writing an OFFSET-byte preamble
// whose first byte records host endianness (1 = little-endian) and whose
// remaining bytes are zero.
// NOTE(review): data_fs_ is reassigned without closing/deleting any stream it
// previously pointed to — callers must release the old stream first; confirm.
bool MindIRExporter::ChangeParaDataFile(const std::string &file) {
  auto real_path = CreateExternalPath(file);
  if (fs_->FileExist(real_path)) {
    if (!fs_->DeleteFile(real_path)) {
      MS_LOG(ERROR) << "delete file failed.";
      return false;
    }
  }
  ChangeFileMode(real_path, S_IWUSR);
  data_fs_ = FileUtils::OpenFile(real_path, std::ios::app);
  if (data_fs_ == nullptr) {
    MS_LOG(ERROR) << "data_fs_ is nullptr.";
    return false;
  }
  // OFFSET-byte header: byte 0 is the endianness flag, the rest stay zero.
  char front_info[OFFSET]{0};
  front_info[0] = IsSystemLittleEndidan();
  (void)data_fs_->write(front_info, OFFSET);
  return true;
}
2053
IsSystemLittleEndidan() const2054 bool MindIRExporter::IsSystemLittleEndidan() const {
2055 int check = 0x01;
2056 auto address = reinterpret_cast<char *>(&check);
2057 return *address == 0x01;
2058 }
2059
// Prepares `func_graph` for export: reorders parameters and records the fv
// count (UpdateParamCount), serializes the graph into model_proto_ via the
// proto wire string round-trip, builds the name->Parameter dictionary
// (ParamDict), and decides whether graph and weights fit in one file
// (IfSaveTogether -> save_together_). Returns false on any stage failure.
bool MindIRExporter::PreProcSaveTogether(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "func_graph is nullptr.";
    return false;
  }

  if (!UpdateParamCount(func_graph)) {
    MS_LOG(ERROR) << "Update parameter count failed.";
    return false;
  }

  // Parse func_graph as model proto
  std::string proto_string = GetBinaryProtoString(func_graph);
  if (proto_string.empty()) {
    MS_LOG(ERROR) << "parse proto string failed.";
    return false;
  }

  // Round-trip through the wire format to populate the member proto.
  if (!model_proto_.ParseFromString(proto_string)) {
    MS_LOG(ERROR) << "parse model proto from string failed.";
    return false;
  }

  if (!ParamDict(func_graph)) {
    MS_LOG(ERROR) << "parse param form funcgraph failed.";
    return false;
  }

  if (!IfSaveTogether(&save_together_)) {
    MS_LOG(ERROR) << "error occur when check condition of saving together.";
    return false;
  }

  return true;
}
2095
IfSaveTogether(bool * save_together)2096 bool MindIRExporter::IfSaveTogether(bool *save_together) {
2097 size_t data_total = model_proto_.ByteSizeLong();
2098 for (auto ¶m_proto : model_proto_.graph().parameter()) {
2099 std::string proto_name = param_proto.name();
2100 auto para = GetFgParaAccordingToProtoName(proto_name);
2101 if (para == nullptr) {
2102 return false;
2103 }
2104 if (!para->has_default()) {
2105 continue;
2106 }
2107 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(para->default_param());
2108 if (tensor == nullptr) {
2109 MS_LOG(ERROR) << "param node default_param is not tensor.";
2110 return false;
2111 }
2112 data_total += tensor->Size();
2113 }
2114 if (data_total > TOTAL_SAVE) {
2115 *save_together = false;
2116 } else {
2117 *save_together = false;
2118 }
2119 return true;
2120 }
2121
GetFgParaAccordingToProtoName(const std::string & proto_name)2122 std::shared_ptr<Parameter> MindIRExporter::GetFgParaAccordingToProtoName(const std::string &proto_name) {
2123 auto beg_pos = proto_name.find_first_of(':') + 1;
2124 if (beg_pos >= proto_name.size()) {
2125 MS_LOG(ERROR) << "begin pos exceed proto name length.";
2126 return nullptr;
2127 }
2128 auto name = proto_name.substr(beg_pos);
2129 if (param_dict_.find(name) == param_dict_.end()) {
2130 MS_LOG(ERROR) << "param proto name: " << name << " is not in param dict.";
2131 return nullptr;
2132 }
2133 return param_dict_.at(name);
2134 }
2135
UpdateParamCount(const FuncGraphPtr & func_graph)2136 bool MindIRExporter::UpdateParamCount(const FuncGraphPtr &func_graph) {
2137 auto fv_count = 0;
2138 std::vector<AnfNodePtr> params;
2139 std::vector<AnfNodePtr> reorder_param;
2140 reorder_param.reserve(func_graph->parameters().size());
2141 for (const auto &node : func_graph->parameters()) {
2142 auto param_node = node->cast<ParameterPtr>();
2143 if (param_node == nullptr) {
2144 MS_LOG(ERROR) << "The parameters() in func graph should be all Parameter Node. but got " << node->DebugString();
2145 return false;
2146 }
2147 if (param_node->has_default()) {
2148 (void)params.emplace_back(param_node);
2149 ++fv_count;
2150 continue;
2151 }
2152 (void)reorder_param.emplace_back(param_node);
2153 }
2154
2155 std::copy(params.begin(), params.end(), std::back_inserter(reorder_param));
2156 func_graph->set_parameters(reorder_param);
2157 func_graph->set_fv_param_count(fv_count);
2158 return true;
2159 }
2160
ParamDict(const FuncGraphPtr & func_graph)2161 bool MindIRExporter::ParamDict(const FuncGraphPtr &func_graph) {
2162 std::set<FuncGraphPtr> all_func_graphs = {};
2163 GetAllFuncGraphs(func_graph, &all_func_graphs);
2164 for (auto &fg : all_func_graphs) {
2165 for (auto ¶ : fg->parameters()) {
2166 if (!para->isa<Parameter>()) {
2167 MS_LOG(ERROR) << "fg parameters contains non-parameter type node.";
2168 return false;
2169 }
2170 auto para_node = para->cast<ParameterPtr>();
2171 param_dict_[para->ToString()] = para_node;
2172 }
2173 }
2174 return true;
2175 }
2176 } // namespace mindspore
2177