• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/convert.h"
18 
19 #include <cinttypes>
20 #include <algorithm>
21 #include <stack>
22 #include "utils/utils.h"
23 
24 #include "base/core_ops.h"
25 #include "frontend/operator/ops.h"
26 #include "utils/log_adapter.h"
27 #include "ir/graph_utils.h"
28 #include "utils/symbolic.h"
29 #include "utils/config_manager.h"
30 #include "utils/convert_utils.h"
31 #include "utils/ms_context.h"
32 #include "utils/check_convert_utils.h"
33 #include "transform/graph_ir/op_adapter_map.h"
34 #include "ops/state_ops.h"
35 #include "ops/array_ops.h"
36 #include "ops/elewise_calculation_ops.h"
37 #include "ops/math_ops.h"
38 #ifdef ENABLE_GE
39 #include "ops/save_ops.h"
40 #endif
41 
42 namespace mindspore {
43 namespace transform {
44 using std::endl;
45 
46 using ge::Operator;
47 using mindspore::kAnyValue;
48 using std::make_shared;
49 using std::shared_ptr;
50 using std::string;
51 using std::vector;
52 using Variable = ge::op::Variable;
53 using Constant = ge::op::Constant;
54 using Assign = ge::op::Assign;
55 using Data = ge::op::Data;
56 
57 namespace {
GetOrderedCNodes(const FuncGraphPtr fg)58 std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
59   MS_EXCEPTION_IF_NULL(fg);
60   auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
61   auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
62     std::vector<AnfNodePtr> vecs;
63     if (node == nullptr) {
64       return vecs;
65     }
66     if (node->isa<CNode>()) {
67       auto cnode = node->cast<CNodePtr>();
68       auto &inputs = cnode->inputs();
69       // Check if free variables used.
70       for (const auto &input : inputs) {
71         auto input_fg = GetValueNode<FuncGraphPtr>(input);
72         if (input_fg) {
73           for (auto &fv : input_fg->free_variables_nodes()) {
74             if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
75               vecs.push_back(fv);
76             }
77           }
78         }
79       }
80       (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
81     }
82     return vecs;
83   };
84 
85   return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
86 }
87 }  // namespace
88 
// ---------------implementation of DfGraphConvertor-------------
IsCaseNode(const CNodePtr node)90 bool IsCaseNode(const CNodePtr node) {
91   MS_EXCEPTION_IF_NULL(node);
92   if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
93       GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
94     return true;
95   }
96   return false;
97 }
98 
GetCNodeTargetFuncName(const CNodePtr cnode)99 std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
100   if (IsCaseNode(cnode)) {
101     return string(kNameCase);
102   }
103   auto name = GetCNodeFuncName(cnode);
104   if (name == "switch_layer") {
105     name = "";
106   }
107   return name;
108 }
109 
FindAdapter(const AnfNodePtr node,bool train)110 OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
111   MS_EXCEPTION_IF_NULL(node);
112   if (node->isa<CNode>()) {
113     auto cnode = node->cast<CNodePtr>();
114 
115     std::string name = kNameCustomOp;
116     if (!IsCustomCNode(cnode)) {
117       name = GetCNodeTargetFuncName(cnode);
118     }
119 
120     auto it_adpt = OpAdapterMap::get().find(name);
121     if (it_adpt != OpAdapterMap::get().end()) {
122       return it_adpt->second->Get(train);
123     }
124     MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name;
125   }
126 
127   if (node->isa<ValueNode>()) {
128     return OpAdapterMap::get()[kNameConst]->Get(train);
129   }
130   if (node->isa<Parameter>()) {
131     return OpAdapterMap::get()[kNameParam]->Get(train);
132   }
133   return OpAdapterPtr(nullptr);
134 }
135 
InitLoopVar(std::vector<ge::Operator> * init_input)136 void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
137   MS_EXCEPTION_IF_NULL(init_input);
138   if (this->training_) {
139     GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
140     auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
141     auto var_loop_cond = std::make_shared<Variable>("npu_runconfig/loop_cond");
142     auto var_one = std::make_shared<Variable>("npu_runconfig/one");
143     auto var_zero = std::make_shared<Variable>("npu_runconfig/zero");
144     (void)var_iter_num->update_output_desc_y(desc);
145     (void)var_loop_cond->update_output_desc_y(desc);
146     (void)var_one->update_output_desc_y(desc);
147     (void)var_zero->update_output_desc_y(desc);
148     vars_["npu_runconfig/iterations_per_loop"] = var_iter_num;
149     vars_["npu_runconfig/loop_cond"] = var_loop_cond;
150     vars_["npu_runconfig/one"] = var_one;
151     vars_["npu_runconfig/zero"] = var_zero;
152 
153     int64_t value = 0;
154     auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
155     if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
156       value = ConfigManager::GetInstance().iter_num();
157     } else {
158       MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1";
159       value = 1;
160       ConfigManager::GetInstance().set_iter_num(value);
161     }
162     value -= 1;  // iteration start from 0, the max iteration number for n loop should be n-1
163     (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
164 
165     auto const_loop_cond = std::make_shared<Constant>("const/npu_runconfig/loop_cond");
166     value = 0;
167     (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
168 
169     auto const_one = std::make_shared<Constant>("const/npu_runconfig/one");
170     value = 1;
171     (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
172 
173     auto const_zero = std::make_shared<Constant>("const/npu_runconfig/zero");
174     value = 0;
175     (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
176 
177     (void)const_iter_num->update_output_desc_y(desc);
178     (void)const_loop_cond->update_output_desc_y(desc);
179     (void)const_one->update_output_desc_y(desc);
180     (void)const_zero->update_output_desc_y(desc);
181 
182     auto assign_iter_num = std::make_shared<Assign>("assign/npu_runconfig/iterations_per_loop");
183     (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num);
184     auto assign_loop_cond = std::make_shared<Assign>("assign/npu_runconfig/loop_cond");
185     (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond);
186     auto assign_one = std::make_shared<Assign>("assign/npu_runconfig/one");
187     (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one);
188     auto assign_zero = std::make_shared<Assign>("assign/npu_runconfig/zero");
189     (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero);
190 
191     init_input->push_back(*var_iter_num);
192     init_input->push_back(*var_loop_cond);
193     init_input->push_back(*var_one);
194     init_input->push_back(*var_zero);
195     init_ops_.push_back(var_iter_num);
196     init_ops_.push_back(var_loop_cond);
197     init_ops_.push_back(var_one);
198     init_ops_.push_back(var_zero);
199     init_ops_.push_back(const_iter_num);
200     init_ops_.push_back(const_loop_cond);
201     init_ops_.push_back(const_one);
202     init_ops_.push_back(const_zero);
203     init_ops_.push_back(assign_iter_num);
204     init_ops_.push_back(assign_loop_cond);
205     init_ops_.push_back(assign_one);
206     init_ops_.push_back(assign_zero);
207   }
208 }
209 
FindAdapter(const std::string & name,bool train)210 OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) {
211   auto it = OpAdapterMap::get().find(name);
212   if (it != OpAdapterMap::get().end()) {
213     return it->second->Get(train);
214   }
215   MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name;
216 }
217 
DrawParamInitSubGraph(const std::string & name,const AnfNodePtr & it)218 void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) {
219   // draw init subgraph
220   init_sout_ << "op_assign" << it.get() << "[label=<";
221   init_sout_ << "<table border='1' cellborder='1'>" << endl;
222   init_sout_ << "<tr>";
223   init_sout_ << "<td port='1'>resource</td>";
224   init_sout_ << "<td port='2'>value</td>";
225   init_sout_ << "</tr>" << endl;
226   init_sout_ << "<tr><td colspan=\"2\">"
227              << "\"assign_" << name << "\"</td></tr>" << endl;
228   init_sout_ << "</table>> shape=plaintext]" << endl;
229   init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl;
230   init_sout_ << "const" << it.get() << "[label= \"" << name << "_const"
231              << "\" shape=ellipse]" << endl;
232   init_sout_ << "param" << it.get() << "->"
233              << "op_assign" << it.get() << ":1" << endl;
234   init_sout_ << "const" << it.get() << "->"
235              << "op_assign" << it.get() << ":2" << endl;
236 }
237 
SetupParamInitSubGraph(const TensorOrderMap & tensors,std::vector<ge::Operator> * init_input)238 void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) {
239   DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
240   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
241 
242   for (auto &it : nodes) {
243     MS_EXCEPTION_IF_NULL(it);
244     if (it->isa<ValueNode>()) {
245       if (IsValueNode<SymbolicKeyInstance>(it)) {
246         auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
247         auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
248         auto iter = vars_.find(name);  // get corresponding variable op
249         if (iter != vars_.end()) {
250           op_cache_[it.get()] = iter->second;
251           // #ifdef DRAW_GE_GRAPH
252           compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
253                         << "[style=\"dotted\"]" << endl;
254           // #endif
255         }
256       } else if (IsValueNode<RefKey>(it)) {
257         auto refkey = GetValueNode<RefKeyPtr>(it);
258         MS_EXCEPTION_IF_NULL(refkey);
259         auto name = refkey->tag();
260         auto iter = vars_.find(name);  // get corresponding variable op
261         if (iter != vars_.end()) {
262           op_cache_[it.get()] = iter->second;
263           compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
264                         << "[style=\"dotted\"]" << endl;
265         }
266       }
267     }
268   }
269 
270   for (auto &it : tensors) {
271     if (vars_.find(it.first) == vars_.end()) {
272       MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph.";
273       vars_[it.first] = nullptr;
274     }
275   }
276 
277   // set up init sub graph
278   if (init_input->size()) {
279     // init sub graph needs no input
280     MS_LOG(INFO) << "Build data init subgraph.";
281     (void)init_graph->SetInputs(*init_input);
282     this->init_graph_ = init_graph;
283   } else {
284     this->init_graph_ = nullptr;
285   }
286 }
287 
MakeDatasetHandler(const std::string & name,const size_t & input_idx,const AnfNodePtr & it)288 void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
289   MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
290   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
291     auto getnext_idx = static_cast<int64_t>(input_idx);
292     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
293     if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) {
294       getnext_idx = param.input_indexes()[input_idx] - 1;  // input_idx start from 0.
295       MS_LOG(INFO) << "remap input_index:" << input_idx << " to getnext_index:" << getnext_idx << ".";
296     }
297     // use iterator_getnext op with output_name instead of data op in BuildGraph.
298     if (dataset_iter_getnext_ != nullptr) {
299       out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx));
300     }
301   }
302 }
303 
SetupBroadcast(const std::shared_ptr<HcomBroadcast> & broadcast,const std::vector<GeTensorDesc> & broadcast_desc,const DfGraphPtr & broadcast_graph,std::vector<ge::Operator> broadcast_input)304 void DfGraphConvertor::SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast,
305                                       const std::vector<GeTensorDesc> &broadcast_desc,
306                                       const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input) {
307   MS_LOG(INFO) << "build broadcast subgraph";
308   if (broadcast_desc.size() != broadcast_input.size()) {
309     MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input";
310   }
311   (void)broadcast->create_dynamic_input_x(static_cast<unsigned int>(broadcast_input.size()));
312   (void)broadcast->create_dynamic_output_y(static_cast<unsigned int>(broadcast_desc.size()));
313   for (unsigned int i = 0; i < broadcast_input.size(); i++) {
314     (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]);
315     (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]);
316   }
317   (void)broadcast_graph->SetInputs(broadcast_input);
318   this->broadcast_graph_ = broadcast_graph;
319 }
320 
InitParamWithData(const TensorOrderMap & tensors)321 void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
322   int index = 0;
323   std::vector<Operator> init_input;
324   for (auto it : tensors) {
325     std::string name = it.first;
326     auto node_itor = params_.find(name);
327     // if name not in params_, create a node in graph
328     if (node_itor == params_.end()) {
329       MS_LOG(WARNING) << name << " is not in params, and create a new node.";
330       ParameterPtr param = std::make_shared<Parameter>(nullptr);
331       name = name + "_temp";
332       param->set_name(name);
333       (void)ConvertParameter(param);
334       node_itor = params_.find(name);
335     }
336     auto node = node_itor->second;
337     auto op_itor = op_cache_.find(node.get());
338     if (op_itor == op_cache_.end()) {
339       MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << ".";
340     }
341     auto adpt = FindAdapter(kNameParam, training_);
342     if (adpt == nullptr) continue;
343     auto param_op = adpt->generate(name + "_data");
344     MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << ".";
345 
346     if (!training_) {
347       auto adpt_const = FindAdapter(kNameConst, training_);
348       if (adpt_const == nullptr) continue;
349       auto const_op = adpt_const->generate(name + "_const");
350       (void)adpt_const->setAttr(const_op, "value", it.second);
351 
352       auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW);
353       if (const_op_desc == nullptr) {
354         MS_LOG(WARNING) << "Create variable " << name << " output descriptor failed!";
355         continue;
356       }
357       (void)std::static_pointer_cast<Constant>(const_op)->update_output_desc_y(*const_op_desc);
358 
359       vars_[name] = const_op;
360       op_itor->second = const_op;
361       continue;
362     }
363 
364     // create tensor descriptor for output descriptor
365     auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW);
366     if (desc == nullptr) {
367       MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
368       continue;
369     }
370 
371     // we need three variable ops for each graph with same name
372     // build init subgraph
373     if (it.second->is_init() == 0) {
374       (void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
375       auto init_var = std::make_shared<Variable>(name);
376       auto assign_op = std::make_shared<Assign>("assign_" + name);
377       (void)init_var->update_output_desc_y(*desc);
378       (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
379       init_input.push_back(*init_var);
380       init_ops_.push_back(param_op);
381       init_ops_.push_back(assign_op);
382       init_ops_.push_back(init_var);
383     }
384 
385     auto variable = std::make_shared<Variable>(name);
386     (void)variable->update_output_desc_y(*desc);
387     // do not use read variable while variable sink
388     MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << ".";
389     op_itor->second = variable;  // replace parameter with variable
390     vars_[name] = variable;      // prevent the variable operator from being freed
391     DrawParamInitSubGraph(name, node);
392   }
393   InitLoopVar(&init_input);
394   SetupParamInitSubGraph(tensors, &init_input);
395 }
396 
397 // convert all parameter need initialize to variable
InitParam(const TensorOrderMap & tensors)398 DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
399   size_t input_idx = 0;
400   if (error_ != 0) {
401     return *this;
402   }
403   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
404     error_ = INVALID_ARGUMENT;
405     MS_LOG(ERROR) << "Invalid AnfGraph in InitParam.";
406     return *this;
407   }
408 
409   // Processing input with MakeDatasetHandler
410   for (auto &it : anf_graph_->parameters()) {
411     auto op_itor = op_cache_.find(it.get());  // converted node
412     if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
413       string name = std::static_pointer_cast<Parameter>(it)->name();
414       auto tensor_itor = tensors.find(name);  // in init value map
415       if (tensor_itor == tensors.end()) {
416         DfGraphConvertor::MakeDatasetHandler(name, input_idx, it);
417         input_idx++;
418       }
419     }
420   }
421   InitParamWithData(tensors);
422   init_sout_ << "}" << endl;
423   return *this;
424 }
425 
#if (defined ENABLE_GE)
// Builds save_ckp_graph_, a "checkpoint" graph with a single Save op ("save_parms").
// The Save op gets one Variable input per vars_ entry that has an operator and whose name
// contains no '/'; this excludes, for example, the "npu_runconfig/..." loop variables. A Graphviz
// dump is written to checkpoint_sout_. save_ckp_graph_ is set to nullptr when nothing qualifies.
void DfGraphConvertor::BuildSaveCheckpointGraph() {
  // Single predicate shared by the input count and the loop, taking the map's real value_type by
  // reference (the old explicit pair<std::string, ...> parameter forced a temporary copy per call).
  auto is_skipped = [](const auto &entry) {
    return entry.second == nullptr || entry.first.find('/') != std::string::npos;
  };
  const auto skipped_count = static_cast<size_t>(std::count_if(vars_.begin(), vars_.end(), is_skipped));

  std::vector<Operator> graph_inputs;
  ge::op::Save save_op("save_parms");
  (void)save_op.create_dynamic_input_tensors(vars_.size() - skipped_count);

  bool has_inputs = false;
  size_t index = 0;
  // for each "parameter" in anf graph excluding "input"
  for (const auto &entry : vars_) {
    if (is_skipped(entry)) {
      continue;
    }
    const std::string &name = entry.first;
    Variable variable(name);
    (void)variable.update_output_desc_y(entry.second->GetOutputDesc(0));
    (void)save_op.set_dynamic_input_tensors(index++, variable);
    graph_inputs.push_back(variable);

    if (!has_inputs) {
      checkpoint_sout_ << "op_save" << &save_op << "[label=<";
      checkpoint_sout_ << "<table border='1' cellborder='1'>" << endl;
      checkpoint_sout_ << "<tr><td port='1'>tensor</td></tr>" << endl;
      checkpoint_sout_ << "<tr><td colspan=\"1\">"
                       << "\"saveop"
                       << "\"</td></tr>" << endl;
      checkpoint_sout_ << "</table>> shape=plaintext]" << endl;
    }

    checkpoint_sout_ << "param" << entry.second << "[shape=octagon, label=\"" << name << "\"]" << endl;
    checkpoint_sout_ << "param" << entry.second << "->"
                     << "op_save" << &save_op << ":1" << endl;
    has_inputs = true;
  }

  if (has_inputs) {
    std::vector<Operator> graph_output;
    graph_output.emplace_back(save_op);
    DfGraphPtr checkpoint_graph = std::make_shared<DfGraph>("checkpoint");
    (void)checkpoint_graph->SetInputs(graph_inputs);
    (void)checkpoint_graph->SetOutputs(graph_output);
    this->save_ckp_graph_ = checkpoint_graph;
  } else {
    this->save_ckp_graph_ = nullptr;
  }

  checkpoint_sout_ << "}" << endl;
}
#endif
481 
GenerateBroadcastGraph(const TensorOrderMap & tensors)482 DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) {
483   if (error_ != 0) {
484     return *this;
485   }
486   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
487     error_ = INVALID_ARGUMENT;
488     MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph";
489     return *this;
490   }
491 
492   DfGraphPtr broadcast_graph = std::make_shared<DfGraph>("broadcast");
493   // collect the operators create for broadcast sub graph, in order to avoid auto release
494   std::vector<Operator> broadcast_input;
495   std::vector<GeTensorDesc> broadcast_desc;
496   auto broadcast = std::make_shared<HcomBroadcast>("broadcast_parameter");
497   (void)broadcast->set_attr_root_rank(0);
498   (void)broadcast->set_attr_group("hccl_world_group");
499   broadcast_ops_.push_back(broadcast);
500 
501   // find every parameter, build broadcast subgraph (or initialize the parameter with constant)
502   for (auto &it : anf_graph_->parameters()) {
503     auto op_itor = op_cache_.find(it.get());  // converted node
504     if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
505       string name = std::static_pointer_cast<Parameter>(it)->name();
506       auto tensor_itor = tensors.find(name);  // in init tensor map
507       if (tensor_itor != tensors.end()) {
508         auto tensor = tensor_itor->second;
509         auto shape_ge = tensor->shape_c();
510 
511         // create tensor descriptor for output descriptor
512         auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW);
513         if (desc == nullptr) {
514           MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
515           continue;
516         }
517 
518         // build broadcast subgraph
519         if (distribute_) {
520           auto broadcast_var = std::make_shared<Variable>(name);
521           (void)broadcast_var->update_output_desc_y(*desc);
522           broadcast_input.push_back(*broadcast_var);
523           broadcast_desc.push_back(*desc);
524           broadcast_ops_.push_back(broadcast_var);
525         }
526       }
527     }
528   }
529 
530   // set up broadcast sub graph
531   if (!broadcast_input.empty()) {
532     DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input);
533   } else {
534     this->broadcast_graph_ = nullptr;
535   }
536   return *this;
537 }
538 
GenerateCheckpointGraph()539 DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
540   if (error_ != 0) {
541     MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << ".";
542     return *this;
543   }
544   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
545     error_ = INVALID_ARGUMENT;
546     MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph";
547     return *this;
548   }
549 #if (defined ENABLE_GE)
550   BuildSaveCheckpointGraph();
551   // Restoring from checkpoint file is done by pyfront, not in graph now.
552 #endif
553   return *this;
554 }
555 
ConvertAllNode()556 DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
557   if (error_ != 0) {
558     return *this;
559   }
560   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
561     MS_LOG(ERROR) << "Invalid AnfGraph";
562     error_ = FAILED;
563     return *this;
564   }
565 
566   compute_sout_.clear();
567   compute_sout_ << "digraph {" << endl;
568   init_sout_.clear();
569   init_sout_ << "digraph {" << endl;
570 #if (defined ENABLE_GE)
571   checkpoint_sout_.clear();
572   checkpoint_sout_ << "digraph {" << endl;
573 #endif
574   restore_checkpoint_sout_.clear();
575   restore_checkpoint_sout_ << "digraph {" << endl;
576 
577   // Convert all anf node to Operator
578   MS_LOG(DEBUG) << "convert all node";
579   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
580   for (auto &it : nodes) {
581     (void)Convert(it);
582     if (this->error_ != 0) {
583       MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << ".";
584     }
585   }
586 
587   // Create dataset iterator and iterator_getnext node
588   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
589     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
590     MS_LOG(INFO) << "Dataset param is " << param.ToString() << ".";
591     // GetNext
592     auto iter_getnext_op = make_shared<ge::op::GetNext>("get_next_tmp");
593     (void)iter_getnext_op->set_attr_output_types(param.ge_types());
594     (void)iter_getnext_op->set_attr_output_shapes(param.shapes());
595     (void)iter_getnext_op->set_attr_channel_name(param.queue_name());
596 
597     // save iter_getnext_op for later use
598     dataset_iter_getnext_ = iter_getnext_op;
599   }
600 
601   // return the data flow graph
602   return *this;
603 }
604 
TraceOutputFromTupleGetItem(const AnfNodePtr & anf_out)605 void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) {
606   auto it = out_handle_cache_.find(anf_out.get());
607   if (it != out_handle_cache_.end()) {
608     OutHandler handle = it->second;
609     auto op = handle.op;
610     if (op != nullptr) {
611       MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out;
612       graph_outputs_.emplace_back(std::make_pair(*op, handle.out));
613     } else {
614       MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted";
615     }
616   } else {
617     // invalid tuple_getitem e.g. tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple())
618     MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope();
619   }
620 }
621 
TraceOutput(const AnfNodePtr node)622 void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
623   MS_EXCEPTION_IF_NULL(node);
624   AnfNodePtr anf_out = node;
625   AnfNodePtr pre_node = nullptr;
626 
627   // Trace value node
628   if (node->isa<ValueNode>()) {
629     auto op = Convert(anf_out);
630     if (op != nullptr) {
631       graph_outputs_.emplace_back(std::make_pair(*op, ""));
632       AddGraphConstInput(op);
633     }
634     return;
635   }
636 
637   // Trace Parameter node
638   TraceOutputFromParameter(anf_out);
639   // Then trace cnode
640   if (!node->isa<CNode>()) {
641     return;
642   }
643 
644   // trace tuple_getitem
645   while (anf_out->isa<CNode>() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) {
646     pre_node = anf_out;
647     anf_out = anf_out->cast<CNodePtr>()->input(1);
648   }
649   // trace every element of make_tuple
650   auto c = anf_out->cast<CNodePtr>();
651   std::string name = "";
652   if (anf_out->isa<CNode>()) {
653     name = GetCNodeTargetFuncName(c);
654   }
655 
656   if (name == "MakeTuple") {
657     for (unsigned int i = 1; i < c->inputs().size(); i++) {
658       TraceOutput(c->input(i));
659     }
660   } else if (name == prim::kPrimDepend->name()) {
661     if (c->inputs().size() < 3) {  // "Depend" primitive have 3 inputs
662       MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
663     }
664     TraceOutput(c->input(1));
665   } else if (name == prim::kTupleGetItem) {
666     TraceOutputFromTupleGetItem(anf_out);
667   } else {
668     // add outputs
669     auto op = Convert(anf_out);
670     std::string index;
671     if (op != nullptr) {
672       if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) {
673         auto item = out_handle_cache_.find(pre_node.get());
674         if (item != out_handle_cache_.end()) {
675           index = item->second.out;
676         } else {
677           MS_LOG(WARNING) << "Can't get operator: " << anf_out->fullname_with_scope() << " 's output item";
678         }
679       }
680       MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index;
681       graph_outputs_.emplace_back(make_pair(*op, index));
682     }
683   }
684 }
685 
TraceOutputFromParameter(const AnfNodePtr & anf_out)686 void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) {
687   MS_EXCEPTION_IF_NULL(anf_out);
688   if (anf_out->isa<Parameter>()) {
689     MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope();
690     auto it = out_handle_cache_.find(anf_out.get());
691     if (it != out_handle_cache_.end()) {
692       // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler.
693       OutHandler handle = it->second;
694       auto op = handle.op;
695       MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out;
696       graph_outputs_.emplace_back(make_pair(*op, handle.out));
697     } else {
698       // common parameter case
699       auto op = Convert(anf_out);
700       if (op != nullptr) {
701         MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType();
702         graph_outputs_.emplace_back(std::make_pair(*op, ""));
703       }
704     }
705   }
706 }
707 
SetupDatasetIterGetNextNode(const OperatorPtr & op)708 void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
709   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
710     DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
711     size_t output_num = param.ge_types().size();
712     MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << ".";
713     // set iterator_getnext op's output num
714     shared_ptr<ge::op::GetNext> iter_getnext = std::static_pointer_cast<ge::op::GetNext>(op);
715     (void)iter_getnext->create_dynamic_output_y(static_cast<unsigned int>(output_num));
716 
717     for (uint32_t i = 0; i < output_num; i++) {
718       ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]);
719       // we don't SetRealDimCnt here since GE do not use this output's real-dim
720       (void)iter_getnext->update_dynamic_output_desc_y((i), desc);
721     }
722   }
723   return;
724 }
725 
SetSubgraph(AnfNodePtr node)726 void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
727   if (!node->isa<CNode>()) {
728     return;
729   }
730   auto cnode = node->cast<CNodePtr>();
731   if (!IsCaseNode(cnode)) {
732     return;
733   }
734   std::vector<AnfNodePtr> case_inputs;
735   for (size_t i = 1; i < cnode->inputs().size(); i++) {
736     case_inputs.emplace_back(cnode->input(i));
737   }
738   std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
739   auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
740 
741   for (size_t i = 1; i < bnode->inputs().size(); i++) {
742     auto branch_node = bnode->input(i)->cast<CNodePtr>();
743     for (size_t j = 2; j < branch_node->inputs().size(); j++) {
744       if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
745         case_inputs.emplace_back(branch_node->input(j));
746       }
747     }
748   }
749 
750   for (size_t i = 1; i < bnode->inputs().size(); i++) {
751     ProcessSubgraph(bnode->input(i), case_inputs);
752   }
753 
754   for (size_t i = 1; i < bnode->inputs().size(); i++) {
755     branches->emplace_back(branches_map_[bnode->input(i).get()]);
756   }
757 
758   if (op_cache_.find(node.get()) == op_cache_.end()) {
759     return;
760   }
761 
762   OpAdapterPtr adpt = FindAdapter(node, training_);
763   if (adpt == nullptr) {
764     MS_LOG(DEBUG) << "Not found adapter";
765     return;
766   }
767 
768   OperatorPtr op = Convert(node);
769   adpt->setSubgraph(op, 0, branches);
770   return;
771 }
772 
GetCaseNodeInput(const CNodePtr node,const CNodePtr input_node)773 void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) {
774   std::vector<AnfNodePtr> case_inputs;
775   for (size_t i = 1; i < node->inputs().size(); i++) {
776     case_inputs.emplace_back(node->input(i));
777   }
778   auto bnode = input_node->input(2)->cast<CNodePtr>();
779   MS_EXCEPTION_IF_NULL(bnode);
780   for (size_t i = 1; i < bnode->inputs().size(); i++) {
781     auto branch_node = bnode->input(i)->cast<CNodePtr>();
782     MS_EXCEPTION_IF_NULL(branch_node);
783     for (size_t j = 2; j < branch_node->inputs().size(); j++) {
784       if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
785         case_inputs.emplace_back(branch_node->input(j));
786       }
787     }
788   }
789 
790   const size_t case_index = 1;
791   const size_t make_tuple_index = 2;
792 
793   AnfNodePtr case_index_iter = input_node->input(case_index);
794   AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index);
795   auto make_tuple_node = make_tuple_iter->cast<CNodePtr>();
796   std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
797 
798   for (size_t i = 0; i < case_inputs.size(); i++) {
799     auto item = case_inputs[i];
800     auto op = Convert(item);
801     if (op != nullptr) {
802       tuple_items->emplace_back(OutHandler(op, "", item));
803     } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
804       tuple_items->push_back(out_handle_cache_[item.get()]);
805     } else {
806       MS_LOG(DEBUG) << "Add an empty out handler: " << item->ToString();
807       tuple_items->push_back(OutHandler());
808     }
809   }
810 
811   tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items;
812 
813   std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>();
814   case_input_items->emplace_back(case_index_iter);
815   case_input_items->emplace_back(make_tuple_iter);
816   case_input_handle_cache_[node.get()] = case_input_items;
817 }
818 
UpdateTupleOutCache()819 void DfGraphConvertor::UpdateTupleOutCache() {
820   for (auto &it : tuple_out_handle_cache_) {
821     std::size_t len = it.second->size();
822     for (std::size_t i = 0; i < len; i++) {
823       OutHandler handle = (*it.second)[i];
824       if (handle.op == nullptr) {
825         continue;
826       }
827       string name = handle.op->GetName();
828       if (vars_.count(name) && (vars_[name] != nullptr)) {
829         (*it.second)[i] = OutHandler(vars_[name], handle.out, handle.node);
830         MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name;
831       }
832     }
833   }
834 }
835 
BuildGraph()836 DfGraphConvertor &DfGraphConvertor::BuildGraph() {
837   SetupDatasetIterGetNextNode(dataset_iter_getnext_);
838 
839   if (error_ != 0) {
840     return *this;
841   }
842 
843   // Case node set input.
844   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
845   for (auto &it : nodes) {
846     if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
847       auto node = it->cast<CNodePtr>();
848       auto input_node = node->input(0)->cast<CNodePtr>();
849       GetCaseNodeInput(node, input_node);
850     }
851   }
852 
853   // update tuple_out_handle_cache_
854   UpdateTupleOutCache();
855 
856   // set up dependencies
857   MS_LOG(DEBUG) << "set up dependencies";
858   nodes = GetOrderedCNodes(anf_graph_);
859   for (auto &it : nodes) {
860     SetNodeInput(it);
861     SetOpControlInput(it);
862     SetSubgraph(it);
863     UpdateOpDesc(it);
864   }
865 
866   if (error_ == 0) {
867     df_graph_ = make_shared<DfGraph>(anf_graph_->ToString());
868   } else {
869     return *this;
870   }
871 
872   // set graph input according to the order from anf graph
873   std::vector<Operator> inputs;
874   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
875     inputs.push_back(*dataset_iter_getnext_);
876   } else {
877     auto params = anf_graph_->parameters();
878     if (use_inputs_) {
879       params = inputs_;
880       auto anf_params = anf_graph_->parameters();
881       for (size_t i = 0; i < params.size(); i++) {
882         for (size_t j = 0; j < anf_params.size(); j++) {
883           if (params[i]->ToString() == anf_params[j]->ToString()) {
884             params[i] = anf_params[j];
885           }
886         }
887       }
888     }
889 
890     int index = 0;
891     for (auto &it : params) {
892       auto name = std::static_pointer_cast<Parameter>(it)->name();
893       //  the parameters which has not been converted to var
894       if (vars_.find(name) == vars_.end()) {
895         auto op = Convert(it);
896         MS_EXCEPTION_IF_NULL(op);
897         MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index;
898         if (op == nullptr) {
899           MS_LOG(ERROR) << "Convert graph failed!";
900           return *this;
901         }
902         UpdateDataOpDesc(it, op);
903         if (HasAbstractMonad(it)) {
904           MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
905           continue;
906         }
907         MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
908         (void)std::static_pointer_cast<Data>(op)->set_attr_index(index++);
909         inputs.push_back(*op);
910       } else if (vars_[name] != nullptr) {
911         MS_LOG(INFO) << "add var input " << it->ToString();
912         auto op = Convert(it);
913         MS_EXCEPTION_IF_NULL(op);
914         inputs.push_back(*op);
915       }
916     }
917   }
918 
919   MS_LOG(DEBUG) << "trace output";
920   graph_outputs_.clear();
921   TraceOutput(anf_graph_->get_return()->input(1));
922 
923   // Add const nodes as graph input for some operator work with constant
924   MS_LOG(INFO) << "graph const input size: " << graph_const_inputs_.size();
925   std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs),
926                  [](OperatorPtr x) { return *x; });
927 
928   MS_LOG(INFO) << "set graph input num: " << inputs.size();
929   (void)df_graph_->SetInputs(inputs);
930 
931   // set graph output
932   // set the value of finale return apply node as the output of dataflow graph
933   MS_LOG(DEBUG) << "set output";
934   MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size();
935   (void)df_graph_->SetOutputs(graph_outputs_);
936 
937   compute_sout_ << "}" << endl;
938   // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag.
939   if (ConfigManager::GetInstance().iter_num() > 1) {
940     df_graph_->SetNeedIteration(true);
941   }
942   return *this;
943 }
944 
UpdateDataOpDesc(const AnfNodePtr & it,const OperatorPtr & op) const945 void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
946   auto node = std::static_pointer_cast<AnfNode>(it);
947   if (node == nullptr) {
948     MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
949     return;
950   }
951 
952   std::vector<int64_t> shape;
953   if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
954     shape = normal_shape_ptr->shape();
955   } else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
956     shape = {};
957   } else {
958     MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
959     return;
960   }
961 
962   if (node->Type() == nullptr) {
963     MS_LOG(INFO) << "Invalid type to update data op descriptor.";
964     return;
965   }
966   TypeId me_type = node->Type()->type_id();
967   if (kObjectTypeTensorType == me_type) {
968     me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
969   }
970   std::ostringstream buf;
971   buf << "[" << shape << "]";
972   MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
973   std::string format = "NCHW";
974   if (it->isa<Parameter>()) {
975     auto param = it->cast<ParameterPtr>();
976     std::string param_name = param->DebugString();
977     auto param_format = param_format_.find(param_name);
978     if (param_format != param_format_.end()) {
979       format = param_format->second;
980       MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
981     }
982   }
983   auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
984   if (desc == nullptr) {
985     MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
986   } else {
987     (void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
988     (void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
989   }
990 }
991 
GetComputeGraph()992 DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; }
993 
GetInitGraph()994 DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; }
995 
GetSaveCheckpointGraph()996 DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; }
997 
GetBroadcastGraph()998 DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; }
999 
IsSourceEdgeNode(const AnfNodePtr & node)1000 bool DfGraphConvertor::IsSourceEdgeNode(const AnfNodePtr &node) {
1001   if (!node->isa<CNode>()) {
1002     return false;
1003   }
1004   auto cnode = node->cast<CNodePtr>();
1005   if (!IsCustomCNode(cnode)) {
1006     std::string name = GetCNodeTargetFuncName(cnode);
1007     if (name.empty()) {
1008       return false;
1009     }
1010 
1011     // Ignore apply node Depend, UpdateState, make_tuple. make_tuple in ge pipeline.
1012     if ((name == prim::kPrimDepend->name()) || (name == prim::kPrimUpdateState->name()) ||
1013         (name == prim::kPrimReturn->name()) || (name == prim::kPrimMakeTuple->name())) {
1014       return false;
1015     }
1016   }
1017   // Load and other normal primitives which contain monad node.
1018   auto has_monad = std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
1019                                [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); });
1020   if (has_monad) {
1021     return true;
1022   }
1023 
1024   // primitive with make_tuple as input
1025   for (auto &input : cnode->inputs()) {
1026     if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
1027       auto tuple = input->cast<CNodePtr>();
1028       auto ret = std::any_of(tuple->inputs().begin(), tuple->inputs().end(),
1029                              [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); });
1030       if (ret) {
1031         return true;
1032       }
1033     }
1034   }
1035 
1036   return false;
1037 }
1038 
IsControlEdgeNode(const AnfNodePtr & node)1039 bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) {
1040   if (!node->isa<CNode>()) {
1041     return false;
1042   }
1043   auto cnode = node->cast<CNodePtr>();
1044   if (!IsCustomCNode(cnode)) {
1045     std::string name = GetCNodeTargetFuncName(cnode);
1046     if (name.empty()) {
1047       return false;
1048     }
1049 
1050     // Ignore apply node of Load, Depend, UpdateState, make_tuple, return
1051     if ((name == prim::kPrimLoad->name()) || (name == prim::kPrimDepend->name()) ||
1052         (name == prim::kPrimUpdateState->name()) || (name == prim::kPrimMakeTuple->name()) ||
1053         (name == prim::kPrimReturn->name())) {
1054       return false;
1055     }
1056   }
1057   return true;
1058 }
1059 
ToOperatorPtr(const AnfNodePtr & node)1060 OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) {
1061   auto op = Convert(GetRealOpNode(node));
1062   if (op == nullptr) {
1063     MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString();
1064     error_ = FAILED;
1065     return nullptr;
1066   }
1067   return op;
1068 }
1069 
AddEdgeToCache(const AnfNodePtr & src,const AnfNodePtr & dest)1070 void DfGraphConvertor::AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest) {
1071   auto item = monad_control_edge_cache_.find(src);
1072   if (item == monad_control_edge_cache_.end()) {
1073     monad_control_edge_cache_[src] = std::set<AnfNodePtr>{dest};
1074   } else {
1075     item->second.insert(dest);
1076   }
1077 }
1078 
AddEdgeForLoad(const AnfNodePtr & node)1079 void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
1080   auto func_graph = node->func_graph();
1081   MS_EXCEPTION_IF_NULL(func_graph);
1082   auto mng = func_graph->manager();
1083   if (mng == nullptr) {
1084     mng = Manage(func_graph, true);
1085     func_graph->set_manager(mng);
1086   }
1087   auto manager = func_graph->manager();
1088   MS_EXCEPTION_IF_NULL(manager);
1089   if (manager->node_users().find(node) == manager->node_users().end()) {
1090     MS_LOG(EXCEPTION) << "Can't find node in nodes_users.";
1091   }
1092   auto &users = manager->node_users()[node];
1093   std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1094   std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1095   for (const auto &iter : users) {
1096     auto user_node = iter.first;
1097     auto name = GetCNodeTargetFuncName(user_node->cast<CNodePtr>());
1098     if (name == prim::kPrimUpdateState->name()) {
1099       FindDestOps(user_node, dst_node_list, false);
1100       continue;
1101     }
1102     if (IsControlEdgeNode(user_node)) {
1103       src_node_list->push_back(user_node);
1104       continue;
1105     }
1106     FindDestOps(user_node, src_node_list, false);
1107   }
1108 
1109   // add to cache
1110   for (auto &dest : *dst_node_list) {
1111     for (auto &src : *src_node_list) {
1112       AddEdgeToCache(src, dest);
1113     }
1114   }
1115 }
1116 
FindDestOps(const AnfNodePtr & node,const std::shared_ptr<std::vector<AnfNodePtr>> & node_list,bool top)1117 void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
1118                                    bool top) {
1119   MS_EXCEPTION_IF_NULL(node);
1120   auto func_graph = node->func_graph();
1121   MS_EXCEPTION_IF_NULL(func_graph);
1122   auto mng = func_graph->manager();
1123   if (mng == nullptr) {
1124     mng = Manage(func_graph, true);
1125     func_graph->set_manager(mng);
1126   }
1127   auto manager = func_graph->manager();
1128   MS_EXCEPTION_IF_NULL(manager);
1129 
1130   auto users = manager->node_users()[node];
1131   for (const auto &iter : users) {
1132     auto user_node = iter.first;
1133     if (IsControlEdgeNode(user_node)) {
1134       if (!top) {
1135         node_list->push_back(user_node);
1136       }
1137     } else {
1138       FindDestOps(user_node, node_list, false);
1139     }
1140   }
1141 }
1142 
AutoMonadCollectInput(const AnfNodePtr & node)1143 void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
1144   if (!IsSourceEdgeNode(node)) {
1145     return;
1146   }
1147 
1148   // Add control edge if contain monad input.
1149   std::string name = GetCNodeTargetFuncName(node->cast<CNodePtr>());
1150   if (name == prim::kPrimLoad->name()) {
1151     AddEdgeForLoad(node);
1152   } else {
1153     auto src_ops = ToOperatorPtr(node);
1154     if (src_ops != nullptr) {
1155       // Find dest ops list
1156       std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1157       FindDestOps(node, dst_node_list, true);
1158       for (auto &dest : *dst_node_list) {
1159         AddEdgeToCache(node, dest);
1160       }
1161     }
1162   }
1163 }
1164 
AutoMonadSetInput(const AnfNodePtr & node)1165 void DfGraphConvertor::AutoMonadSetInput(const AnfNodePtr &node) {
1166   if (monad_control_edge_cache_.find(node) == monad_control_edge_cache_.end()) {
1167     return;
1168   }
1169 
1170   auto src_ops = ToOperatorPtr(node);
1171   if (src_ops != nullptr) {
1172     for (auto &dest : monad_control_edge_cache_[node]) {
1173       auto dest_ops = ToOperatorPtr(dest);
1174       if (dest_ops == nullptr) {
1175         continue;
1176       }
1177       (void)dest_ops->AddControlInput(*src_ops);
1178 #ifdef DRAW_GE_GRAPH
1179       compute_sout_ << op_draw_name_[node.get()] << " -> " << op_draw_name_[dest.get()] << "[style=\"dotted\"]" << endl;
1180 #endif
1181     }
1182   }
1183 }
1184 
AutoMonadSetControlInput(const AnfNodePtr & node)1185 void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) {
1186   AutoMonadCollectInput(node);
1187   AutoMonadSetInput(node);
1188 }
1189 
SetOpControlInput(const AnfNodePtr & node)1190 void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) {
1191   MS_EXCEPTION_IF_NULL(node);
1192   AutoMonadSetControlInput(node);
1193   if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) {
1194     return;
1195   }
1196 
1197   std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()];
1198   if ((control_edges.empty())) {
1199     MS_LOG(ERROR) << "Get control edge node's src or dest operator failed";
1200     return;
1201   }
1202 
1203   for (auto &item : control_edges) {
1204     (void)item.dest_op->AddControlInput(*item.src_op);
1205   }
1206 }
1207 
1208 const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)};
1209 
ParseLoadInput(const CNodePtr & cnode)1210 AnfNodePtr DfGraphConvertor::ParseLoadInput(const CNodePtr &cnode) {
1211   if (cnode->inputs().size() < 3) {
1212     MS_LOG(EXCEPTION) << "input size error, " << cnode->ToString();
1213   }
1214   const size_t para_index = 1;
1215   return cnode->input(para_index);
1216 }
1217 
SetTupleOpInput(const OpAdapterPtr & adpt,const CNodePtr & node,const AnfNodePtr & pred,const OperatorPtr & src,int index)1218 void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred,
1219                                        const OperatorPtr &src, int index) {
1220   std::shared_ptr<std::vector<OutHandler>> handler_vec = tuple_out_handle_cache_[pred.get()];
1221   std::shared_ptr<std::vector<OutHandler>> handler_vec_without_monad = std::make_shared<std::vector<OutHandler>>();
1222   bool with_monad = false;
1223   for (auto &handler : *handler_vec) {
1224     // when tuple with monad type element, the handler operator is nullptr, should be ignored.
1225     if (handler.op == nullptr) {
1226       if ((handler.node != nullptr) && !HasAbstractMonad(handler.node)) {
1227         MS_LOG(WARNING) << "Unsupported node in tuple : " << node->ToString();
1228       }
1229       continue;
1230     }
1231     with_monad = true;
1232     handler_vec_without_monad->push_back(handler);
1233   }
1234   int ret = adpt->setInput(src, index, handler_vec_without_monad);
1235   if ((ret == 0) && pred->isa<CNode>() && (pred->cast<CNodePtr>()->inputs().size() == handler_vec->size() + 1)) {
1236     for (unsigned int j = 0; j < handler_vec_without_monad->size(); j++) {
1237       AnfNodePtr input_node = pred->cast<CNodePtr>()->input(j + 1);
1238       if (with_monad) {
1239         input_node = handler_vec_without_monad->at(j).node;
1240       }
1241       compute_sout_ << op_draw_name_[input_node.get()] << " -> " << op_draw_name_[node.get()] << ":" << index << endl;
1242       AddGraphConstInput(handler_vec_without_monad->at(j).op);
1243     }
1244     return;
1245   }
1246   MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString();
1247 }
// Resolve `input` of CNode `node` to the node that actually feeds it:
//   - Depend chains are followed through input(1);
//   - monad, None and UpdateState inputs yield nullptr, meaning "skip this input";
//   - Load is replaced by its loaded parameter.
// In inference mode (!training_), if `node` is an Assign-family op (trans_var_list) and the resolved
// input is a Parameter cached as a Const/Constant operator with a known non-null variable in vars_,
// the cached operator is replaced by a GE Variable carrying that variable's output descriptor.
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
  if (input == nullptr || node == nullptr) {
    return nullptr;
  }
  AnfNodePtr pred = input;
  while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
    pred = pred->cast<CNodePtr>()->input(1);
  }

  // skip input of UMonad, IOMonad
  if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
    return nullptr;
  }

  // skip input of the None, UpdateState
  if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
    return nullptr;
  }

  if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
    pred = ParseLoadInput(pred->cast<CNodePtr>());
  }

  // transform "Const" op to "Variable" op when the next node is "Assign" op.
  if (training_ || !pred->isa<Parameter>()) {
    return pred;
  }
  const std::string c_name = GetCNodeTargetFuncName(node);
  if (std::find(trans_var_list.begin(), trans_var_list.end(), c_name) == trans_var_list.end()) {
    return pred;
  }

  std::string name = std::static_pointer_cast<Parameter>(pred)->name();
  auto op_itor = op_cache_.find(pred.get());
  if (op_itor == op_cache_.end()) {
    MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
  }
  const bool is_const_op = op_itor->second != nullptr &&
                           (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const");
  if (!is_const_op) {
    return pred;
  }
  // vars_ may map a name to nullptr; only a real variable supplies a descriptor to copy.
  auto var_itor = vars_.find(name);
  if (var_itor == vars_.end() || var_itor->second == nullptr) {
    return pred;
  }
  auto variable = std::make_shared<Variable>(name);
  auto desc = var_itor->second->GetOutputDesc("y");
  (void)variable->update_output_desc_y(desc);
  MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
  op_itor->second = variable;  // replace parameter with variable
  var_itor->second = variable;
  return pred;
}
1293 
SetOpInput(const OpAdapterPtr & adpt,const CNodePtr & node)1294 void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
1295   OperatorPtr src = Convert(node);
1296   int case_flag = 0;
1297   auto &inputs = node->inputs();
1298   size_t input_size = inputs.size();
1299   if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
1300     case_flag = 1;
1301     input_size = case_input_handle_cache_[node.get()]->size() + 1;
1302   }
1303 
1304   for (size_t i = 1; i < input_size; i++) {
1305     AnfNodePtr pred = nullptr;
1306     if (case_flag != 0) {
1307       pred = case_input_handle_cache_[node.get()]->at(i - 1);
1308     } else {
1309       pred = inputs[i];
1310     }
1311     pred = GetRealInputNode(node, pred);
1312     if (pred == nullptr) {
1313       continue;
1314     }
1315 
1316     int index = SizeToInt(i);
1317     // find in out_hadnle_cache_ first
1318     auto it = out_handle_cache_.find(pred.get());
1319     if (it != out_handle_cache_.end()) {
1320       int ret = adpt->setInput(src, index, it->second);
1321       if (ret == 0) {
1322         if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kTupleGetItem) {
1323           compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
1324                         << ":" << i << endl;
1325         } else if (pred->isa<Parameter>()) {
1326           compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
1327         } else {
1328           // don't draw anything.
1329           MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case.";
1330         }
1331         AddGraphConstInput(it->second.op);
1332       }
1333     } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) {
1334       SetTupleOpInput(adpt, node, pred, src, index);
1335     } else {
1336       auto op = Convert(pred);
1337       int ret = adpt->setInput(src, index, op);
1338       if (ret == 0) {
1339         compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
1340         AddGraphConstInput(op);
1341       }
1342     }
1343   }
1344 }
1345 
AddGraphConstInput(const OperatorPtr & op)1346 void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) {
1347   if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") {
1348     graph_const_inputs_.push_back(op);
1349   }
1350 }
1351 
SetNodeInput(const AnfNodePtr node)1352 void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
1353   if (!node->isa<CNode>()) {
1354     return;
1355   }
1356   if (op_cache_.find(node.get()) == op_cache_.end()) {
1357     return;
1358   }
1359   auto cnode = node->cast<CNodePtr>();
1360   OpAdapterPtr adpt = FindAdapter(cnode, training_);
1361   if (adpt == nullptr) {
1362     error_ = NOT_FOUND;
1363     return;
1364   }
1365 
1366   // get Operator from op_cache_, use adapter to set Inputs
1367   DfGraphConvertor::SetOpInput(adpt, cnode);
1368 }
1369 
ProcessSubgraph(AnfNodePtr node,const std::vector<AnfNodePtr> & inputs)1370 void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
1371   if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
1372     return;
1373   }
1374   auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1375   MS_EXCEPTION_IF_NULL(graph_node);
1376   FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1377   DfGraphConvertor converter(anf_graph);
1378   converter.use_inputs_ = true;
1379   converter.inputs_ = inputs;
1380   (void)converter.ConvertAllNode().BuildGraph();
1381 #ifdef ENABLE_DUMP_IR
1382   std::string name = graph_node->ToString() + "_ge_graph.dot";
1383   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
1384     converter.DrawComputeGraph(name);
1385   }
1386 #endif
1387   branches_map_[node.get()] = *(converter.df_graph_);
1388 }
1389 
1390 // Update GE op's shape and type info
UpdateOpDesc(const AnfNodePtr node)1391 void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
1392   if (node == nullptr || !node->isa<CNode>()) {
1393     return;
1394   }
1395 
1396   if (op_cache_.find(node.get()) == op_cache_.end()) {
1397     return;
1398   }
1399 
1400   OpAdapterPtr adpt = FindAdapter(node, training_);
1401   if (adpt == nullptr) {
1402     error_ = NOT_FOUND;
1403     return;
1404   }
1405 
1406   // get Operator from op_cache_
1407   OperatorPtr op = Convert(node);
1408 
1409   adpt->updateOutputDesc(op, node->Shape(), node->Type(), node);
1410 }
1411 
Convert(const AnfNodePtr node)1412 OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
1413   if (node == nullptr) {
1414     MS_LOG(ERROR) << "node is nullptr";
1415     error_ = NOT_FOUND;
1416     return nullptr;
1417   }
1418   // find in cache
1419   if (op_cache_.count(node.get())) {
1420     return op_cache_[node.get()];
1421   }
1422 
1423   // do not convert primitive node, Load, UpdateState
1424   if (IsValueNode<Primitive>(node) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1425       IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1426     return nullptr;
1427   }
1428 
1429   // convert a new one
1430   if (node->isa<CNode>()) {
1431     return ConvertCNode(node->cast<CNodePtr>());
1432   }
1433   if (node->isa<Parameter>()) {
1434     return ConvertParameter(node);
1435   }
1436   if (node->isa<ValueNode>()) {
1437     if (IsValueNode<Monad>(node)) {
1438       return nullptr;
1439     }
1440     return ConvertValueNode(node->cast<ValueNodePtr>());
1441   }
1442 
1443   MS_LOG(ERROR) << "Invalid AnfNode";
1444   error_ = INVALID_ARGUMENT;
1445   return nullptr;
1446 }
1447 
ConvertMakeTuple(const CNodePtr node)1448 void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
1449   std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1450   // convert each tuple item to a OutHandler
1451   for (size_t i = 1; i < node->inputs().size(); i++) {
1452     AnfNodePtr item = node->input(i);
1453     if (IsPrimitiveCNode(item, prim::kPrimLoad)) {
1454       item = ParseLoadInput(item->cast<CNodePtr>());
1455     }
1456     OperatorPtr op = Convert(item);
1457     if (op != nullptr) {
1458       tuple_items->emplace_back(OutHandler(op, "", item));
1459     } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
1460       tuple_items->push_back(out_handle_cache_[item.get()]);
1461     } else {
1462       tuple_items->push_back(OutHandler(nullptr, "", item));
1463     }
1464   }
1465 
1466   MS_LOG(DEBUG) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size();
1467   tuple_out_handle_cache_[node.get()] = tuple_items;
1468 }
1469 
ConvertTopK(const CNodePtr node)1470 void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
1471   MS_EXCEPTION_IF_NULL(node);
1472   MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
1473   auto value_ptr = node->input(2)->cast<ValueNodePtr>();
1474   std::ostringstream ss;
1475   ss << "op" << value_ptr.get();
1476   op_draw_name_[value_ptr.get()] = ss.str();
1477   compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
1478   MS_EXCEPTION_IF_NULL(value_ptr);
1479   auto input_value = value_ptr->value();
1480   auto int64_value = GetValue<int64_t>(input_value);
1481   OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
1482   auto op = adpt->generate(value_ptr);
1483   adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));
1484   op_cache_[value_ptr.get()] = op;
1485 }
1486 
CastToInt(const ValuePtr & value)1487 std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) {
1488   if (value == nullptr) {
1489     MS_LOG(WARNING) << "Value ptr is nullptr.";
1490     return {};
1491   }
1492   std::vector<int64_t> cur_value = {};
1493   if (utils::isa<ValueSequeuePtr>(value)) {
1494     auto val_seq_ptr = value->cast<ValueSequeuePtr>();
1495     MS_EXCEPTION_IF_NULL(val_seq_ptr);
1496     if (!val_seq_ptr->value().empty()) {
1497       auto first_val = val_seq_ptr->value().front();
1498       MS_EXCEPTION_IF_NULL(first_val);
1499       MS_EXCEPTION_IF_NULL(first_val->type());
1500       if (first_val->type()->number_type() == kNumberTypeInt64) {
1501         cur_value = GetValue<std::vector<int64_t>>(value);
1502       } else {
1503         auto origin_value = GetValue<std::vector<int>>(value);
1504         (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
1505                              [](int index) { return static_cast<int64_t>(index); });
1506       }
1507     }
1508   } else {
1509     MS_EXCEPTION_IF_NULL(value->type());
1510     if (value->type()->number_type() == kNumberTypeInt64) {
1511       cur_value.push_back(GetValue<int64_t>(value));
1512     } else {
1513       cur_value.push_back(static_cast<int64_t>(GetValue<int>(value)));
1514     }
1515   }
1516   return cur_value;
1517 }
1518 
ConvertReshape(const CNodePtr node)1519 void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
1520   MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
1521   const auto kInputNum = 3;
1522   if (node->size() < kInputNum) {
1523     MS_LOG(WARNING) << "Reshape must have two inputs.";
1524     return;
1525   }
1526   OpAdapterPtr adpt = FindAdapter(node, training_);
1527   if (adpt == nullptr) {
1528     return;
1529   }
1530   auto op = adpt->generate(node);
1531   MS_EXCEPTION_IF_NULL(op);
1532   // get shape form attr
1533   auto value_node = node->input(0)->cast<ValueNodePtr>();
1534   MS_EXCEPTION_IF_NULL(value_node);
1535   MS_EXCEPTION_IF_NULL(value_node->value());
1536   auto primitive = value_node->value()->cast<PrimitivePtr>();
1537   MS_EXCEPTION_IF_NULL(primitive);
1538   auto value = primitive->GetAttr("shape");
1539   std::vector<int64_t> list;
1540   list = CastToInt(value);
1541 
1542   (void)op->SetAttr("shape", list);
1543   op_cache_[node.get()] = op;
1544 }
1545 
TraceTupleGetItem(const CNodePtr & node,uint64_t * index)1546 AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
1547   const int TUPLE_GET_ITEM_INDEX = 2;
1548   if (node->inputs().size() < 3) {  // "tuple_getitem" primitive must have 3 inputs
1549     MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3";
1550   }
1551   auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX];
1552   if (!index_node->isa<ValueNode>()) {
1553     error_ = INVALID_ARGUMENT;
1554     MS_LOG(EXCEPTION) << "can't convert get item with non-constant index";
1555   }
1556   *index = LongToUlong(GetValue<int64_t>(GetValueNode(index_node)));
1557   return node->inputs()[1];
1558 }
1559 
TraceDepend(const CNodePtr & node)1560 AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) {
1561   auto cnode = node->cast<CNodePtr>();
1562   if (cnode->inputs().size() < 3) {  // "Depend" primitive have 3 inputs
1563     MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3";
1564   }
1565   return cnode->inputs()[1];
1566 }
1567 
TraceMakeTuple(const CNodePtr & node,uint64_t index)1568 AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, uint64_t index) {
1569   if (index + 1 >= node->inputs().size()) {
1570     MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index;
1571   }
1572   return node->inputs()[index + 1];
1573 }
1574 
GetHandler(const AnfNodePtr & node,const std::stack<uint64_t> & index_stack,AnfNode * const draw_index)1575 OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack,
1576                                         AnfNode *const draw_index) {
1577   if (node == nullptr) {
1578     MS_LOG(ERROR) << "Get nullptr while trace real op";
1579     return OutHandler(nullptr, "");
1580   }
1581   std::ostringstream ss;
1582   ss << "op" << node.get();
1583   if (index_stack.empty()) {
1584     op_draw_name_[draw_index] = ss.str();
1585     return OutHandler(Convert(node), "");
1586   } else {
1587     OpAdapterPtr adpt = FindAdapter(node, training_);
1588     if (adpt == nullptr) {
1589       MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!";
1590       error_ = NOT_FOUND;
1591       return OutHandler(nullptr, "");
1592     }
1593     OperatorPtr op = Convert(node);
1594     if (op == nullptr) {
1595       error_ = NOT_FOUND;
1596       MS_LOG(ERROR) << "Can not convert node for trace real op";
1597       return OutHandler(nullptr, "");
1598     }
1599     op_draw_name_[draw_index] = ss.str();
1600     return adpt->getOutput(Convert(node), UintToInt(index_stack.top()));
1601   }
1602 }
1603 
1604 // get the real operator through maketuple tuple_getitem depend
TraceRealOp(AnfNodePtr node)1605 OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) {
1606   bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1607               IsPrimitiveCNode(node, prim::kPrimDepend);
1608   std::stack<uint64_t> index_stack;
1609   auto draw_index = node.get();
1610   while (flag) {
1611     flag = false;
1612     if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1613       uint64_t index;
1614       node = TraceTupleGetItem(node->cast<CNodePtr>(), &index);
1615       index_stack.push(index);
1616       flag = true;
1617     } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
1618       if (index_stack.empty()) {
1619         MS_LOG(ERROR) << "TraceRealOp find a make_tuple node";
1620         return OutHandler(nullptr, "");
1621       } else {
1622         node = TraceMakeTuple(node->cast<CNodePtr>(), index_stack.top());
1623         index_stack.pop();
1624         flag = true;
1625       }
1626     } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
1627       node = TraceDepend(node->cast<CNodePtr>());
1628       flag = true;
1629     }
1630   }
1631   return GetHandler(node, index_stack, draw_index);
1632 }
1633 
ConvertTupleGetItem(const CNodePtr node)1634 void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) {
1635   auto handle = TraceRealOp(node);
1636   if (handle.op == nullptr) {
1637     MS_LOG(ERROR) << "Failed to trace tuple get item";
1638     return;
1639   }
1640   out_handle_cache_[node.get()] = handle;
1641 }
1642 
1643 // Get the real op for tuple_getitem through make tuple, or depend
GetRealOpNode(AnfNodePtr node)1644 AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) {
1645   const int TUPLE_GET_ITEM_INDEX = 2;
1646   if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1647     auto node_inputs = node->cast<CNodePtr>()->inputs();
1648     if (node_inputs.size() != 3) {  // "tuple_getitem" primitive must have 3 inputs
1649       MS_LOG(ERROR) << "tuple get item node not correct!";
1650       error_ = FAILED;
1651       return node;
1652     }
1653     MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]);
1654     if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa<ValueNode>()) {
1655       error_ = INVALID_ARGUMENT;
1656       MS_LOG(EXCEPTION) << "can't convert get item with non-constant index";
1657     }
1658     auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast<Int32ImmPtr>();
1659     if (value_ptr == nullptr) {
1660       MS_LOG(ERROR) << "Can not convert get item as value is nullptr!";
1661       error_ = FAILED;
1662       return node;
1663     }
1664     int64_t index = value_ptr->value();
1665 
1666     // make_tuple apply inputs:make_tuple, [tuple_items,]
1667     if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) {
1668       auto tuple_inputs = node->cast<CNodePtr>()->inputs();
1669       if (tuple_inputs.size() < LongToSize(index + 1L)) {
1670         MS_LOG(ERROR) << "make tuple input items node not correct! size:" << tuple_inputs.size()
1671                       << ", item index:" << index;
1672         error_ = FAILED;
1673         return node;
1674       }
1675       return GetRealOpNode(tuple_inputs[LongToSize(index + 1L)]);
1676     }
1677     return GetRealOpNode(node_inputs[1]);
1678   }
1679 
1680   // depend apply inputs: depend,output,depended_node
1681   if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
1682     auto depend_inputs = node->cast<CNodePtr>()->inputs();
1683     if (depend_inputs.size() != 3) {  // "Depend" primitive have 3 inputs
1684       MS_LOG(ERROR) << "depend input items not correct";
1685       error_ = FAILED;
1686       return node;
1687     }
1688     return GetRealOpNode(depend_inputs[1]);
1689   }
1690   return node;
1691 }
1692 
1693 // convert the anf node to corresponding operator list
ConvertDependNode(const AnfNodePtr node)1694 std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) {
1695   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
1696     std::vector<OperatorPtr> op_lists;
1697     auto node_inputs = node->cast<CNodePtr>()->inputs();
1698     for (size_t index = 1; index < node_inputs.size(); index++) {
1699       auto op = Convert(GetRealOpNode(node_inputs[index]));
1700       if (op == nullptr) {
1701         MS_LOG(ERROR) << "Convert real op node to operator failed";
1702         error_ = FAILED;
1703         return std::vector<OperatorPtr>({});
1704       }
1705       op_lists.push_back(op);
1706     }
1707     return op_lists;
1708   }
1709 
1710   auto op = Convert(GetRealOpNode(node));
1711   if (op == nullptr) {
1712     MS_LOG(ERROR) << "Convert real op node to operator failed";
1713     error_ = FAILED;
1714     return std::vector<OperatorPtr>({});
1715   }
1716   return std::vector<OperatorPtr>({op});
1717 }
1718 
CheckCNode(const std::string & name,const CNodePtr node)1719 bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
1720   // ignore apply node of return
1721   if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() ||
1722       name == prim::kPrimSwitchLayer->name() || name == prim::kPrimPartial->name()) {
1723     return false;
1724   }
1725 
1726   // Convert TopK second input from int64 to int32.
1727   if (name == prim::kPrimTopK->name()) {
1728     ConvertTopK(node);
1729     return true;
1730   }
1731 
1732   // Convert Reshape add const input to attr(shape)
1733   if (name == prim::kPrimReshape->name()) {
1734     ConvertReshape(node);
1735     return true;
1736   }
1737 
1738   // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
1739   if (name == prim::kPrimMakeTuple->name()) {
1740     ConvertMakeTuple(node);
1741     return false;
1742   }
1743 
1744   // As for nodes with multi outputs, convert tuple_getitem to OutHandle
1745   if (name == prim::kPrimTupleGetItem->name()) {
1746     ConvertTupleGetItem(node);
1747     return false;
1748   }
1749 
1750   return true;
1751 }
1752 
ConvertCNode(const CNodePtr node)1753 OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
1754   SaveParamFormat(node);
1755   std::string name = GetCNodeTargetFuncName(node);
1756   if (!CheckCNode(name, node)) {
1757     return nullptr;
1758   }
1759 
1760   // get corresponding OpAdapter
1761   OpAdapterPtr adpt = FindAdapter(node, training_);
1762   if (adpt == nullptr) {
1763     error_ = NOT_FOUND;
1764     return nullptr;
1765   }
1766 
1767   // get operator
1768   OperatorPtr op = nullptr;
1769   auto it_op = op_cache_.find(node.get());
1770   if (it_op != op_cache_.end()) {
1771     op = it_op->second;
1772   } else {
1773     op = adpt->generate(node);
1774   }
1775 
1776   // set attribute for primitive
1777   (void)adpt->setAttr(op, node);
1778 
1779   // add into cache
1780   (void)op_cache_.insert(std::make_pair(node.get(), op));
1781 
1782   DrawCNode(node, adpt);
1783 
1784   return op_cache_[node.get()];
1785 }
1786 
ConvertParameter(const AnfNodePtr node)1787 OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
1788   // convert Parameter in ANF to variable in DataFlow
1789   auto adpt = FindAdapter(node, training_);
1790   if (adpt == nullptr) {
1791     MS_LOG(EXCEPTION) << "Can not find adapter for Parameter";
1792   }
1793   auto op = adpt->generate(node);
1794   op_cache_[node.get()] = op;
1795 
1796   // build index for parameter using name
1797   std::string name = std::static_pointer_cast<Parameter>(node)->name();
1798   params_[name] = node;
1799   std::ostringstream ss;
1800   ss << "op" << node.get();
1801   op_draw_name_[node.get()] = ss.str();
1802   compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl;
1803   return op_cache_[node.get()];
1804 }
1805 
SaveParamFormat(const CNodePtr node)1806 void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
1807   AnfNodePtr op = node->input(0);
1808   if (IsValueNode<Primitive>(op)) {
1809     auto prim = GetValueNode<PrimitivePtr>(op);
1810     for (auto attr : prim->attrs()) {
1811       if (attr.first == "format") {
1812         std::string format;
1813         if (attr.second->isa<Int64Imm>()) {
1814           bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
1815           if (converted) {
1816             format = attr.second->ToString();
1817           }
1818         }
1819         if (format != "NCDHW") {
1820           break;
1821         }
1822         for (size_t i = 1; i < node->size(); i++) {
1823           auto input = node->input(i);
1824           if (input->isa<Parameter>()) {
1825             param_format_[input->DebugString()] = format;
1826             MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
1827           }
1828         }
1829       }
1830     }
1831   }
1832 }
1833 
TryConvertValueNodeToMultiConst(const ValueNodePtr node)1834 Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
1835   MS_EXCEPTION_IF_NULL(node);
1836   ValuePtr value = node->value();
1837   MS_EXCEPTION_IF_NULL(value);
1838   if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
1839     return FAILED;
1840   }
1841 
1842   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
1843   if (vec.empty()) {
1844     return FAILED;
1845   }
1846 
1847   std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1848   for (size_t i = 0; i < vec.size(); i++) {
1849     MS_EXCEPTION_IF_NULL(vec[i]);
1850     if (vec[i]->isa<MeTensor>()) {
1851       GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast<MeTensorPtr>(), kOpFormat_NCHW);
1852       auto const_op = std::make_shared<Constant>(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i));
1853       (void)const_op->set_attr_value(*ge_tensor);
1854       (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc());
1855       tuple_items->emplace_back(OutHandler(const_op, ""));
1856     } else {
1857       return FAILED;
1858     }
1859   }
1860   if (tuple_items->empty()) {
1861     return FAILED;
1862   }
1863 
1864   tuple_out_handle_cache_[node.get()] = tuple_items;
1865   return SUCCESS;
1866 }
1867 
ConvertValueNode(const ValueNodePtr node)1868 OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
1869   // convert valuenode in ANF to Const in DataFlow
1870   // find paramerte referenced by SymbolicKeyInstance of valuenode
1871   std::ostringstream ss;
1872   ss << "op" << node.get();
1873   op_draw_name_[node.get()] = ss.str();
1874   compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl;
1875 
1876   if (TryConvertValueNodeToMultiConst(node) == SUCCESS) {
1877     MS_LOG(INFO) << "Convert value node to multi Constant OP success";
1878     return nullptr;
1879   }
1880 
1881   OpAdapterPtr adpt = FindAdapter(node, training_);
1882   if (adpt == nullptr) {
1883     error_ = NOT_FOUND;
1884     return nullptr;
1885   }
1886   auto op = adpt->generate(node);
1887   // set const's attrs
1888   if (adpt->setAttr(op, "value", node->value()) != 0) {
1889     MS_LOG(WARNING) << "set attr value for const failed";
1890   }
1891 
1892   auto const_op = std::static_pointer_cast<Constant>(op);
1893   if (const_op == nullptr) {
1894     MS_LOG(ERROR) << "Get Constant operator failed";
1895     return nullptr;
1896   }
1897   auto ge_tensor = const_op->get_attr_value();
1898   auto ge_desc = ge_tensor.GetTensorDesc();
1899   (void)const_op->update_output_desc_y(ge_desc);
1900   op_cache_[node.get()] = op;
1901   return op_cache_[node.get()];
1902 }
1903 
DrawCNode(const CNodePtr node,const OpAdapterPtr adpt)1904 void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
1905   if (adpt == nullptr || node == nullptr) {
1906     MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!";
1907     return;
1908   }
1909   std::ostringstream ss;
1910   ss << "op" << node.get();
1911   op_draw_name_[node.get()] = ss.str();
1912 
1913   compute_sout_ << ss.str() << "[label=<";
1914   compute_sout_ << "<table border='1' cellborder='1'>" << endl;
1915 
1916   auto input_map = adpt->getInputMap();
1917   auto dyn_input_map = adpt->getDynInputMap();
1918   if (input_map.size() + dyn_input_map.size() > 0) {
1919     compute_sout_ << "<tr>";
1920     for (auto &it : input_map) {
1921       compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
1922     }
1923     for (auto &it : dyn_input_map) {
1924       compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
1925     }
1926     compute_sout_ << "</tr>" << endl;
1927   }
1928 
1929   compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
1930                 << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
1931 
1932   // print attrs' values
1933   auto atts = adpt->GetAttrsFromDrawGraph();
1934   for (auto &it : atts) {
1935     compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << it
1936                   << "\"</td></tr>";
1937   }
1938 
1939   adpt->clearAttrVect();
1940 
1941   compute_sout_ << "</table>> shape=plaintext]" << endl;
1942 }
RegisterAdapter(const std::string & name,OpAdapterPtr adpt)1943 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr adpt) {
1944   OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(adpt);
1945 }
RegisterAdapter(const std::string & name,OpAdapterPtr train_adpt,OpAdapterPtr infer_adpt)1946 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) {
1947   OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(train_adpt, infer_adpt);
1948 }
1949 }  // namespace transform
1950 }  // namespace mindspore
1951