1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/convert.h"
18
19 #include <cinttypes>
20 #include <algorithm>
21 #include <stack>
22 #include "utils/utils.h"
23
24 #include "base/core_ops.h"
25 #include "frontend/operator/ops.h"
26 #include "utils/log_adapter.h"
27 #include "ir/graph_utils.h"
28 #include "utils/symbolic.h"
29 #include "utils/config_manager.h"
30 #include "utils/convert_utils.h"
31 #include "utils/ms_context.h"
32 #include "utils/check_convert_utils.h"
33 #include "transform/graph_ir/op_adapter_map.h"
34 #include "ops/state_ops.h"
35 #include "ops/array_ops.h"
36 #include "ops/elewise_calculation_ops.h"
37 #include "ops/math_ops.h"
38 #ifdef ENABLE_GE
39 #include "ops/save_ops.h"
40 #endif
41
42 namespace mindspore {
43 namespace transform {
44 using std::endl;
45
46 using ge::Operator;
47 using mindspore::kAnyValue;
48 using std::make_shared;
49 using std::shared_ptr;
50 using std::string;
51 using std::vector;
52 using Variable = ge::op::Variable;
53 using Constant = ge::op::Constant;
54 using Assign = ge::op::Assign;
55 using Data = ge::op::Data;
56
57 namespace {
GetOrderedCNodes(const FuncGraphPtr fg)58 std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
59 MS_EXCEPTION_IF_NULL(fg);
60 auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
61 auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
62 std::vector<AnfNodePtr> vecs;
63 if (node == nullptr) {
64 return vecs;
65 }
66 if (node->isa<CNode>()) {
67 auto cnode = node->cast<CNodePtr>();
68 auto &inputs = cnode->inputs();
69 // Check if free variables used.
70 for (const auto &input : inputs) {
71 auto input_fg = GetValueNode<FuncGraphPtr>(input);
72 if (input_fg) {
73 for (auto &fv : input_fg->free_variables_nodes()) {
74 if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
75 vecs.push_back(fv);
76 }
77 }
78 }
79 }
80 (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
81 }
82 return vecs;
83 };
84
85 return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
86 }
87 } // namespace
88
89 // ---------------implement of DfGraphConvertor-------------
IsCaseNode(const CNodePtr node)90 bool IsCaseNode(const CNodePtr node) {
91 MS_EXCEPTION_IF_NULL(node);
92 if (!node->inputs().empty() && node->input(0)->isa<CNode>() &&
93 GetCNodeFuncName(node->input(0)->cast<CNodePtr>()) == "switch_layer") {
94 return true;
95 }
96 return false;
97 }
98
GetCNodeTargetFuncName(const CNodePtr cnode)99 std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
100 if (IsCaseNode(cnode)) {
101 return string(kNameCase);
102 }
103 auto name = GetCNodeFuncName(cnode);
104 if (name == "switch_layer") {
105 name = "";
106 }
107 return name;
108 }
109
FindAdapter(const AnfNodePtr node,bool train)110 OpAdapterPtr DfGraphConvertor::FindAdapter(const AnfNodePtr node, bool train) {
111 MS_EXCEPTION_IF_NULL(node);
112 if (node->isa<CNode>()) {
113 auto cnode = node->cast<CNodePtr>();
114
115 std::string name = kNameCustomOp;
116 if (!IsCustomCNode(cnode)) {
117 name = GetCNodeTargetFuncName(cnode);
118 }
119
120 auto it_adpt = OpAdapterMap::get().find(name);
121 if (it_adpt != OpAdapterMap::get().end()) {
122 return it_adpt->second->Get(train);
123 }
124 MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name;
125 }
126
127 if (node->isa<ValueNode>()) {
128 return OpAdapterMap::get()[kNameConst]->Get(train);
129 }
130 if (node->isa<Parameter>()) {
131 return OpAdapterMap::get()[kNameParam]->Get(train);
132 }
133 return OpAdapterPtr(nullptr);
134 }
135
InitLoopVar(std::vector<ge::Operator> * init_input)136 void DfGraphConvertor::InitLoopVar(std::vector<ge::Operator> *init_input) {
137 MS_EXCEPTION_IF_NULL(init_input);
138 if (this->training_) {
139 GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
140 auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
141 auto var_loop_cond = std::make_shared<Variable>("npu_runconfig/loop_cond");
142 auto var_one = std::make_shared<Variable>("npu_runconfig/one");
143 auto var_zero = std::make_shared<Variable>("npu_runconfig/zero");
144 (void)var_iter_num->update_output_desc_y(desc);
145 (void)var_loop_cond->update_output_desc_y(desc);
146 (void)var_one->update_output_desc_y(desc);
147 (void)var_zero->update_output_desc_y(desc);
148 vars_["npu_runconfig/iterations_per_loop"] = var_iter_num;
149 vars_["npu_runconfig/loop_cond"] = var_loop_cond;
150 vars_["npu_runconfig/one"] = var_one;
151 vars_["npu_runconfig/zero"] = var_zero;
152
153 int64_t value = 0;
154 auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
155 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
156 value = ConfigManager::GetInstance().iter_num();
157 } else {
158 MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1";
159 value = 1;
160 ConfigManager::GetInstance().set_iter_num(value);
161 }
162 value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1
163 (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
164
165 auto const_loop_cond = std::make_shared<Constant>("const/npu_runconfig/loop_cond");
166 value = 0;
167 (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
168
169 auto const_one = std::make_shared<Constant>("const/npu_runconfig/one");
170 value = 1;
171 (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
172
173 auto const_zero = std::make_shared<Constant>("const/npu_runconfig/zero");
174 value = 0;
175 (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
176
177 (void)const_iter_num->update_output_desc_y(desc);
178 (void)const_loop_cond->update_output_desc_y(desc);
179 (void)const_one->update_output_desc_y(desc);
180 (void)const_zero->update_output_desc_y(desc);
181
182 auto assign_iter_num = std::make_shared<Assign>("assign/npu_runconfig/iterations_per_loop");
183 (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num);
184 auto assign_loop_cond = std::make_shared<Assign>("assign/npu_runconfig/loop_cond");
185 (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond);
186 auto assign_one = std::make_shared<Assign>("assign/npu_runconfig/one");
187 (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one);
188 auto assign_zero = std::make_shared<Assign>("assign/npu_runconfig/zero");
189 (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero);
190
191 init_input->push_back(*var_iter_num);
192 init_input->push_back(*var_loop_cond);
193 init_input->push_back(*var_one);
194 init_input->push_back(*var_zero);
195 init_ops_.push_back(var_iter_num);
196 init_ops_.push_back(var_loop_cond);
197 init_ops_.push_back(var_one);
198 init_ops_.push_back(var_zero);
199 init_ops_.push_back(const_iter_num);
200 init_ops_.push_back(const_loop_cond);
201 init_ops_.push_back(const_one);
202 init_ops_.push_back(const_zero);
203 init_ops_.push_back(assign_iter_num);
204 init_ops_.push_back(assign_loop_cond);
205 init_ops_.push_back(assign_one);
206 init_ops_.push_back(assign_zero);
207 }
208 }
209
FindAdapter(const std::string & name,bool train)210 OpAdapterPtr DfGraphConvertor::FindAdapter(const std::string &name, bool train) {
211 auto it = OpAdapterMap::get().find(name);
212 if (it != OpAdapterMap::get().end()) {
213 return it->second->Get(train);
214 }
215 MS_LOG(EXCEPTION) << "Can't find OpAdapter for " << name;
216 }
217
DrawParamInitSubGraph(const std::string & name,const AnfNodePtr & it)218 void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) {
219 // draw init subgraph
220 init_sout_ << "op_assign" << it.get() << "[label=<";
221 init_sout_ << "<table border='1' cellborder='1'>" << endl;
222 init_sout_ << "<tr>";
223 init_sout_ << "<td port='1'>resource</td>";
224 init_sout_ << "<td port='2'>value</td>";
225 init_sout_ << "</tr>" << endl;
226 init_sout_ << "<tr><td colspan=\"2\">"
227 << "\"assign_" << name << "\"</td></tr>" << endl;
228 init_sout_ << "</table>> shape=plaintext]" << endl;
229 init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl;
230 init_sout_ << "const" << it.get() << "[label= \"" << name << "_const"
231 << "\" shape=ellipse]" << endl;
232 init_sout_ << "param" << it.get() << "->"
233 << "op_assign" << it.get() << ":1" << endl;
234 init_sout_ << "const" << it.get() << "->"
235 << "op_assign" << it.get() << ":2" << endl;
236 }
237
SetupParamInitSubGraph(const TensorOrderMap & tensors,std::vector<ge::Operator> * init_input)238 void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) {
239 DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
240 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
241
242 for (auto &it : nodes) {
243 MS_EXCEPTION_IF_NULL(it);
244 if (it->isa<ValueNode>()) {
245 if (IsValueNode<SymbolicKeyInstance>(it)) {
246 auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
247 auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
248 auto iter = vars_.find(name); // get corresponding variable op
249 if (iter != vars_.end()) {
250 op_cache_[it.get()] = iter->second;
251 // #ifdef DRAW_GE_GRAPH
252 compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
253 << "[style=\"dotted\"]" << endl;
254 // #endif
255 }
256 } else if (IsValueNode<RefKey>(it)) {
257 auto refkey = GetValueNode<RefKeyPtr>(it);
258 MS_EXCEPTION_IF_NULL(refkey);
259 auto name = refkey->tag();
260 auto iter = vars_.find(name); // get corresponding variable op
261 if (iter != vars_.end()) {
262 op_cache_[it.get()] = iter->second;
263 compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
264 << "[style=\"dotted\"]" << endl;
265 }
266 }
267 }
268 }
269
270 for (auto &it : tensors) {
271 if (vars_.find(it.first) == vars_.end()) {
272 MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph.";
273 vars_[it.first] = nullptr;
274 }
275 }
276
277 // set up init sub graph
278 if (init_input->size()) {
279 // init sub graph needs no input
280 MS_LOG(INFO) << "Build data init subgraph.";
281 (void)init_graph->SetInputs(*init_input);
282 this->init_graph_ = init_graph;
283 } else {
284 this->init_graph_ = nullptr;
285 }
286 }
287
MakeDatasetHandler(const std::string & name,const size_t & input_idx,const AnfNodePtr & it)288 void DfGraphConvertor::MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it) {
289 MS_LOG(INFO) << "The " << name << " is the " << input_idx << "(st/nd/th) input";
290 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
291 auto getnext_idx = static_cast<int64_t>(input_idx);
292 DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
293 if (!param.input_indexes().empty() && input_idx <= param.input_indexes().size()) {
294 getnext_idx = param.input_indexes()[input_idx] - 1; // input_idx start from 0.
295 MS_LOG(INFO) << "remap input_index:" << input_idx << " to getnext_index:" << getnext_idx << ".";
296 }
297 // use iterator_getnext op with output_name instead of data op in BuildGraph.
298 if (dataset_iter_getnext_ != nullptr) {
299 out_handle_cache_[it.get()] = OutHandler(dataset_iter_getnext_, "y" + std::to_string(getnext_idx));
300 }
301 }
302 }
303
SetupBroadcast(const std::shared_ptr<HcomBroadcast> & broadcast,const std::vector<GeTensorDesc> & broadcast_desc,const DfGraphPtr & broadcast_graph,std::vector<ge::Operator> broadcast_input)304 void DfGraphConvertor::SetupBroadcast(const std::shared_ptr<HcomBroadcast> &broadcast,
305 const std::vector<GeTensorDesc> &broadcast_desc,
306 const DfGraphPtr &broadcast_graph, std::vector<ge::Operator> broadcast_input) {
307 MS_LOG(INFO) << "build broadcast subgraph";
308 if (broadcast_desc.size() != broadcast_input.size()) {
309 MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input";
310 }
311 (void)broadcast->create_dynamic_input_x(static_cast<unsigned int>(broadcast_input.size()));
312 (void)broadcast->create_dynamic_output_y(static_cast<unsigned int>(broadcast_desc.size()));
313 for (unsigned int i = 0; i < broadcast_input.size(); i++) {
314 (void)broadcast->set_dynamic_input_x(i, broadcast_input[i]);
315 (void)broadcast->update_dynamic_output_desc_y(i, broadcast_desc[i]);
316 }
317 (void)broadcast_graph->SetInputs(broadcast_input);
318 this->broadcast_graph_ = broadcast_graph;
319 }
320
InitParamWithData(const TensorOrderMap & tensors)321 void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
322 int index = 0;
323 std::vector<Operator> init_input;
324 for (auto it : tensors) {
325 std::string name = it.first;
326 auto node_itor = params_.find(name);
327 // if name not in params_, create a node in graph
328 if (node_itor == params_.end()) {
329 MS_LOG(WARNING) << name << " is not in params, and create a new node.";
330 ParameterPtr param = std::make_shared<Parameter>(nullptr);
331 name = name + "_temp";
332 param->set_name(name);
333 (void)ConvertParameter(param);
334 node_itor = params_.find(name);
335 }
336 auto node = node_itor->second;
337 auto op_itor = op_cache_.find(node.get());
338 if (op_itor == op_cache_.end()) {
339 MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << ".";
340 }
341 auto adpt = FindAdapter(kNameParam, training_);
342 if (adpt == nullptr) continue;
343 auto param_op = adpt->generate(name + "_data");
344 MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << ".";
345
346 if (!training_) {
347 auto adpt_const = FindAdapter(kNameConst, training_);
348 if (adpt_const == nullptr) continue;
349 auto const_op = adpt_const->generate(name + "_const");
350 (void)adpt_const->setAttr(const_op, "value", it.second);
351
352 auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW);
353 if (const_op_desc == nullptr) {
354 MS_LOG(WARNING) << "Create variable " << name << " output descriptor failed!";
355 continue;
356 }
357 (void)std::static_pointer_cast<Constant>(const_op)->update_output_desc_y(*const_op_desc);
358
359 vars_[name] = const_op;
360 op_itor->second = const_op;
361 continue;
362 }
363
364 // create tensor descriptor for output descriptor
365 auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW);
366 if (desc == nullptr) {
367 MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
368 continue;
369 }
370
371 // we need three variable ops for each graph with same name
372 // build init subgraph
373 if (it.second->is_init() == 0) {
374 (void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
375 auto init_var = std::make_shared<Variable>(name);
376 auto assign_op = std::make_shared<Assign>("assign_" + name);
377 (void)init_var->update_output_desc_y(*desc);
378 (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
379 init_input.push_back(*init_var);
380 init_ops_.push_back(param_op);
381 init_ops_.push_back(assign_op);
382 init_ops_.push_back(init_var);
383 }
384
385 auto variable = std::make_shared<Variable>(name);
386 (void)variable->update_output_desc_y(*desc);
387 // do not use read variable while variable sink
388 MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << ".";
389 op_itor->second = variable; // replace parameter with variable
390 vars_[name] = variable; // prevent the variable operator from being freed
391 DrawParamInitSubGraph(name, node);
392 }
393 InitLoopVar(&init_input);
394 SetupParamInitSubGraph(tensors, &init_input);
395 }
396
397 // convert all parameter need initialize to variable
InitParam(const TensorOrderMap & tensors)398 DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
399 size_t input_idx = 0;
400 if (error_ != 0) {
401 return *this;
402 }
403 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
404 error_ = INVALID_ARGUMENT;
405 MS_LOG(ERROR) << "Invalid AnfGraph in InitParam.";
406 return *this;
407 }
408
409 // Processing input with MakeDatasetHandler
410 for (auto &it : anf_graph_->parameters()) {
411 auto op_itor = op_cache_.find(it.get()); // converted node
412 if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
413 string name = std::static_pointer_cast<Parameter>(it)->name();
414 auto tensor_itor = tensors.find(name); // in init value map
415 if (tensor_itor == tensors.end()) {
416 DfGraphConvertor::MakeDatasetHandler(name, input_idx, it);
417 input_idx++;
418 }
419 }
420 }
421 InitParamWithData(tensors);
422 init_sout_ << "}" << endl;
423 return *this;
424 }
425
426 #if (defined ENABLE_GE)
BuildSaveCheckpointGraph()427 void DfGraphConvertor::BuildSaveCheckpointGraph() {
428 std::vector<Operator> graph_inputs;
429 ge::op::Save save_op("save_parms");
430 int save_op_is_active = 0;
431 size_t index = 0;
432 string name;
433
434 int32_t count_size = std::count_if(vars_.begin(), vars_.end(), [](const std::pair<std::string, OperatorPtr> &it) {
435 return (it.second == nullptr || it.first.find("/") != std::string::npos);
436 });
437
438 (void)save_op.create_dynamic_input_tensors(vars_.size() - static_cast<size_t>(count_size));
439
440 // for each "parameter" in anf graph excluding "input"
441 for (const auto &it : vars_) {
442 name = it.first;
443 if (it.second == nullptr || name.find("/") != std::string::npos) continue;
444 Variable variable(name);
445 (void)variable.update_output_desc_y(it.second->GetOutputDesc(0));
446 (void)save_op.set_dynamic_input_tensors(index++, variable);
447
448 graph_inputs.push_back(variable);
449
450 if (save_op_is_active == 0) {
451 checkpoint_sout_ << "op_save" << &save_op << "[label=<";
452 checkpoint_sout_ << "<table border='1' cellborder='1'>" << endl;
453 checkpoint_sout_ << "<tr><td port='1'>tensor</td></tr>" << endl;
454 checkpoint_sout_ << "<tr><td colspan=\"1\">"
455 << "\"saveop"
456 << "\"</td></tr>" << endl;
457 checkpoint_sout_ << "</table>> shape=plaintext]" << endl;
458 }
459
460 checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl;
461
462 checkpoint_sout_ << "param" << it.second << "->"
463 << "op_save" << &save_op << ":1" << endl;
464 save_op_is_active = 1;
465 }
466 if (save_op_is_active) {
467 std::vector<Operator> graph_output;
468 graph_output.emplace_back(save_op);
469 DfGraphPtr checkpoint_graph = std::make_shared<DfGraph>("checkpoint");
470 (void)checkpoint_graph->SetInputs(graph_inputs);
471 (void)checkpoint_graph->SetOutputs(graph_output);
472 this->save_ckp_graph_ = checkpoint_graph;
473 } else {
474 this->save_ckp_graph_ = nullptr;
475 }
476
477 checkpoint_sout_ << "}" << endl;
478 return;
479 }
480 #endif
481
GenerateBroadcastGraph(const TensorOrderMap & tensors)482 DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) {
483 if (error_ != 0) {
484 return *this;
485 }
486 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
487 error_ = INVALID_ARGUMENT;
488 MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph";
489 return *this;
490 }
491
492 DfGraphPtr broadcast_graph = std::make_shared<DfGraph>("broadcast");
493 // collect the operators create for broadcast sub graph, in order to avoid auto release
494 std::vector<Operator> broadcast_input;
495 std::vector<GeTensorDesc> broadcast_desc;
496 auto broadcast = std::make_shared<HcomBroadcast>("broadcast_parameter");
497 (void)broadcast->set_attr_root_rank(0);
498 (void)broadcast->set_attr_group("hccl_world_group");
499 broadcast_ops_.push_back(broadcast);
500
501 // find every parameter, build broadcast subgraph (or initialize the parameter with constant)
502 for (auto &it : anf_graph_->parameters()) {
503 auto op_itor = op_cache_.find(it.get()); // converted node
504 if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
505 string name = std::static_pointer_cast<Parameter>(it)->name();
506 auto tensor_itor = tensors.find(name); // in init tensor map
507 if (tensor_itor != tensors.end()) {
508 auto tensor = tensor_itor->second;
509 auto shape_ge = tensor->shape_c();
510
511 // create tensor descriptor for output descriptor
512 auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW);
513 if (desc == nullptr) {
514 MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
515 continue;
516 }
517
518 // build broadcast subgraph
519 if (distribute_) {
520 auto broadcast_var = std::make_shared<Variable>(name);
521 (void)broadcast_var->update_output_desc_y(*desc);
522 broadcast_input.push_back(*broadcast_var);
523 broadcast_desc.push_back(*desc);
524 broadcast_ops_.push_back(broadcast_var);
525 }
526 }
527 }
528 }
529
530 // set up broadcast sub graph
531 if (!broadcast_input.empty()) {
532 DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input);
533 } else {
534 this->broadcast_graph_ = nullptr;
535 }
536 return *this;
537 }
538
GenerateCheckpointGraph()539 DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
540 if (error_ != 0) {
541 MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << ".";
542 return *this;
543 }
544 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
545 error_ = INVALID_ARGUMENT;
546 MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph";
547 return *this;
548 }
549 #if (defined ENABLE_GE)
550 BuildSaveCheckpointGraph();
551 // Restoring from checkpoint file is done by pyfront, not in graph now.
552 #endif
553 return *this;
554 }
555
ConvertAllNode()556 DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
557 if (error_ != 0) {
558 return *this;
559 }
560 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
561 MS_LOG(ERROR) << "Invalid AnfGraph";
562 error_ = FAILED;
563 return *this;
564 }
565
566 compute_sout_.clear();
567 compute_sout_ << "digraph {" << endl;
568 init_sout_.clear();
569 init_sout_ << "digraph {" << endl;
570 #if (defined ENABLE_GE)
571 checkpoint_sout_.clear();
572 checkpoint_sout_ << "digraph {" << endl;
573 #endif
574 restore_checkpoint_sout_.clear();
575 restore_checkpoint_sout_ << "digraph {" << endl;
576
577 // Convert all anf node to Operator
578 MS_LOG(DEBUG) << "convert all node";
579 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
580 for (auto &it : nodes) {
581 (void)Convert(it);
582 if (this->error_ != 0) {
583 MS_LOG(ERROR) << "failed to convert node: " << it->DebugString() << ".";
584 }
585 }
586
587 // Create dataset iterator and iterator_getnext node
588 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
589 DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
590 MS_LOG(INFO) << "Dataset param is " << param.ToString() << ".";
591 // GetNext
592 auto iter_getnext_op = make_shared<ge::op::GetNext>("get_next_tmp");
593 (void)iter_getnext_op->set_attr_output_types(param.ge_types());
594 (void)iter_getnext_op->set_attr_output_shapes(param.shapes());
595 (void)iter_getnext_op->set_attr_channel_name(param.queue_name());
596
597 // save iter_getnext_op for later use
598 dataset_iter_getnext_ = iter_getnext_op;
599 }
600
601 // return the data flow graph
602 return *this;
603 }
604
TraceOutputFromTupleGetItem(const AnfNodePtr & anf_out)605 void DfGraphConvertor::TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out) {
606 auto it = out_handle_cache_.find(anf_out.get());
607 if (it != out_handle_cache_.end()) {
608 OutHandler handle = it->second;
609 auto op = handle.op;
610 if (op != nullptr) {
611 MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out;
612 graph_outputs_.emplace_back(std::make_pair(*op, handle.out));
613 } else {
614 MS_LOG(EXCEPTION) << "tuple_getitem: " << anf_out->fullname_with_scope() << " is not converted";
615 }
616 } else {
617 // invalid tuple_getitem e.g. tuple_getitem(tuple_getitem())/tuple_getitem(depend())/tuple_getitem(make_tuple())
618 MS_LOG(WARNING) << "Invalid tuple_getitem: " << anf_out->fullname_with_scope();
619 }
620 }
621
TraceOutput(const AnfNodePtr node)622 void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
623 MS_EXCEPTION_IF_NULL(node);
624 AnfNodePtr anf_out = node;
625 AnfNodePtr pre_node = nullptr;
626
627 // Trace value node
628 if (node->isa<ValueNode>()) {
629 auto op = Convert(anf_out);
630 if (op != nullptr) {
631 graph_outputs_.emplace_back(std::make_pair(*op, ""));
632 AddGraphConstInput(op);
633 }
634 return;
635 }
636
637 // Trace Parameter node
638 TraceOutputFromParameter(anf_out);
639 // Then trace cnode
640 if (!node->isa<CNode>()) {
641 return;
642 }
643
644 // trace tuple_getitem
645 while (anf_out->isa<CNode>() && IsPrimitiveCNode(anf_out, prim::kPrimTupleGetItem)) {
646 pre_node = anf_out;
647 anf_out = anf_out->cast<CNodePtr>()->input(1);
648 }
649 // trace every element of make_tuple
650 auto c = anf_out->cast<CNodePtr>();
651 std::string name = "";
652 if (anf_out->isa<CNode>()) {
653 name = GetCNodeTargetFuncName(c);
654 }
655
656 if (name == "MakeTuple") {
657 for (unsigned int i = 1; i < c->inputs().size(); i++) {
658 TraceOutput(c->input(i));
659 }
660 } else if (name == prim::kPrimDepend->name()) {
661 if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs
662 MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
663 }
664 TraceOutput(c->input(1));
665 } else if (name == prim::kTupleGetItem) {
666 TraceOutputFromTupleGetItem(anf_out);
667 } else {
668 // add outputs
669 auto op = Convert(anf_out);
670 std::string index;
671 if (op != nullptr) {
672 if ((pre_node != nullptr) && IsPrimitiveCNode(pre_node, prim::kPrimTupleGetItem)) {
673 auto item = out_handle_cache_.find(pre_node.get());
674 if (item != out_handle_cache_.end()) {
675 index = item->second.out;
676 } else {
677 MS_LOG(WARNING) << "Can't get operator: " << anf_out->fullname_with_scope() << " 's output item";
678 }
679 }
680 MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index;
681 graph_outputs_.emplace_back(make_pair(*op, index));
682 }
683 }
684 }
685
TraceOutputFromParameter(const AnfNodePtr & anf_out)686 void DfGraphConvertor::TraceOutputFromParameter(const AnfNodePtr &anf_out) {
687 MS_EXCEPTION_IF_NULL(anf_out);
688 if (anf_out->isa<Parameter>()) {
689 MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope();
690 auto it = out_handle_cache_.find(anf_out.get());
691 if (it != out_handle_cache_.end()) {
692 // For dataset graph mode, input parameter is converted to a "iterator_get_next:yn" OutHandler.
693 OutHandler handle = it->second;
694 auto op = handle.op;
695 MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType() << ", out_name: " << handle.out;
696 graph_outputs_.emplace_back(make_pair(*op, handle.out));
697 } else {
698 // common parameter case
699 auto op = Convert(anf_out);
700 if (op != nullptr) {
701 MS_LOG(INFO) << "op name: " << op->GetName() << ", op type: " << op->GetOpType();
702 graph_outputs_.emplace_back(std::make_pair(*op, ""));
703 }
704 }
705 }
706 }
707
SetupDatasetIterGetNextNode(const OperatorPtr & op)708 void SetupDatasetIterGetNextNode(const OperatorPtr &op) {
709 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
710 DatasetGraphParam param = ConfigManager::GetInstance().dataset_param();
711 size_t output_num = param.ge_types().size();
712 MS_LOG(INFO) << "Set iterator_getnext op's output num = " << output_num << ".";
713 // set iterator_getnext op's output num
714 shared_ptr<ge::op::GetNext> iter_getnext = std::static_pointer_cast<ge::op::GetNext>(op);
715 (void)iter_getnext->create_dynamic_output_y(static_cast<unsigned int>(output_num));
716
717 for (uint32_t i = 0; i < output_num; i++) {
718 ge::TensorDesc desc(GeShape(param.shapes()[i]), ge::FORMAT_NCHW, (ge::DataType)param.ge_types()[i]);
719 // we don't SetRealDimCnt here since GE do not use this output's real-dim
720 (void)iter_getnext->update_dynamic_output_desc_y((i), desc);
721 }
722 }
723 return;
724 }
725
SetSubgraph(AnfNodePtr node)726 void DfGraphConvertor::SetSubgraph(AnfNodePtr node) {
727 if (!node->isa<CNode>()) {
728 return;
729 }
730 auto cnode = node->cast<CNodePtr>();
731 if (!IsCaseNode(cnode)) {
732 return;
733 }
734 std::vector<AnfNodePtr> case_inputs;
735 for (size_t i = 1; i < cnode->inputs().size(); i++) {
736 case_inputs.emplace_back(cnode->input(i));
737 }
738 std::shared_ptr<std::vector<DfGraph>> branches = std::make_shared<std::vector<DfGraph>>();
739 auto bnode = cnode->input(0)->cast<CNodePtr>()->input(2)->cast<CNodePtr>();
740
741 for (size_t i = 1; i < bnode->inputs().size(); i++) {
742 auto branch_node = bnode->input(i)->cast<CNodePtr>();
743 for (size_t j = 2; j < branch_node->inputs().size(); j++) {
744 if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
745 case_inputs.emplace_back(branch_node->input(j));
746 }
747 }
748 }
749
750 for (size_t i = 1; i < bnode->inputs().size(); i++) {
751 ProcessSubgraph(bnode->input(i), case_inputs);
752 }
753
754 for (size_t i = 1; i < bnode->inputs().size(); i++) {
755 branches->emplace_back(branches_map_[bnode->input(i).get()]);
756 }
757
758 if (op_cache_.find(node.get()) == op_cache_.end()) {
759 return;
760 }
761
762 OpAdapterPtr adpt = FindAdapter(node, training_);
763 if (adpt == nullptr) {
764 MS_LOG(DEBUG) << "Not found adapter";
765 return;
766 }
767
768 OperatorPtr op = Convert(node);
769 adpt->setSubgraph(op, 0, branches);
770 return;
771 }
772
GetCaseNodeInput(const CNodePtr node,const CNodePtr input_node)773 void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node) {
774 std::vector<AnfNodePtr> case_inputs;
775 for (size_t i = 1; i < node->inputs().size(); i++) {
776 case_inputs.emplace_back(node->input(i));
777 }
778 auto bnode = input_node->input(2)->cast<CNodePtr>();
779 MS_EXCEPTION_IF_NULL(bnode);
780 for (size_t i = 1; i < bnode->inputs().size(); i++) {
781 auto branch_node = bnode->input(i)->cast<CNodePtr>();
782 MS_EXCEPTION_IF_NULL(branch_node);
783 for (size_t j = 2; j < branch_node->inputs().size(); j++) {
784 if (std::find(case_inputs.begin(), case_inputs.end(), branch_node->input(j)) == case_inputs.end()) {
785 case_inputs.emplace_back(branch_node->input(j));
786 }
787 }
788 }
789
790 const size_t case_index = 1;
791 const size_t make_tuple_index = 2;
792
793 AnfNodePtr case_index_iter = input_node->input(case_index);
794 AnfNodePtr make_tuple_iter = input_node->input(make_tuple_index);
795 auto make_tuple_node = make_tuple_iter->cast<CNodePtr>();
796 std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
797
798 for (size_t i = 0; i < case_inputs.size(); i++) {
799 auto item = case_inputs[i];
800 auto op = Convert(item);
801 if (op != nullptr) {
802 tuple_items->emplace_back(OutHandler(op, "", item));
803 } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
804 tuple_items->push_back(out_handle_cache_[item.get()]);
805 } else {
806 MS_LOG(DEBUG) << "Add an empty out handler: " << item->ToString();
807 tuple_items->push_back(OutHandler());
808 }
809 }
810
811 tuple_out_handle_cache_[make_tuple_node.get()] = tuple_items;
812
813 std::shared_ptr<std::vector<AnfNodePtr>> case_input_items = std::make_shared<std::vector<AnfNodePtr>>();
814 case_input_items->emplace_back(case_index_iter);
815 case_input_items->emplace_back(make_tuple_iter);
816 case_input_handle_cache_[node.get()] = case_input_items;
817 }
818
UpdateTupleOutCache()819 void DfGraphConvertor::UpdateTupleOutCache() {
820 for (auto &it : tuple_out_handle_cache_) {
821 std::size_t len = it.second->size();
822 for (std::size_t i = 0; i < len; i++) {
823 OutHandler handle = (*it.second)[i];
824 if (handle.op == nullptr) {
825 continue;
826 }
827 string name = handle.op->GetName();
828 if (vars_.count(name) && (vars_[name] != nullptr)) {
829 (*it.second)[i] = OutHandler(vars_[name], handle.out, handle.node);
830 MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name;
831 }
832 }
833 }
834 }
835
BuildGraph()836 DfGraphConvertor &DfGraphConvertor::BuildGraph() {
837 SetupDatasetIterGetNextNode(dataset_iter_getnext_);
838
839 if (error_ != 0) {
840 return *this;
841 }
842
843 // Case node set input.
844 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
845 for (auto &it : nodes) {
846 if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
847 auto node = it->cast<CNodePtr>();
848 auto input_node = node->input(0)->cast<CNodePtr>();
849 GetCaseNodeInput(node, input_node);
850 }
851 }
852
853 // update tuple_out_handle_cache_
854 UpdateTupleOutCache();
855
856 // set up dependencies
857 MS_LOG(DEBUG) << "set up dependencies";
858 nodes = GetOrderedCNodes(anf_graph_);
859 for (auto &it : nodes) {
860 SetNodeInput(it);
861 SetOpControlInput(it);
862 SetSubgraph(it);
863 UpdateOpDesc(it);
864 }
865
866 if (error_ == 0) {
867 df_graph_ = make_shared<DfGraph>(anf_graph_->ToString());
868 } else {
869 return *this;
870 }
871
872 // set graph input according to the order from anf graph
873 std::vector<Operator> inputs;
874 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
875 inputs.push_back(*dataset_iter_getnext_);
876 } else {
877 auto params = anf_graph_->parameters();
878 if (use_inputs_) {
879 params = inputs_;
880 auto anf_params = anf_graph_->parameters();
881 for (size_t i = 0; i < params.size(); i++) {
882 for (size_t j = 0; j < anf_params.size(); j++) {
883 if (params[i]->ToString() == anf_params[j]->ToString()) {
884 params[i] = anf_params[j];
885 }
886 }
887 }
888 }
889
890 int index = 0;
891 for (auto &it : params) {
892 auto name = std::static_pointer_cast<Parameter>(it)->name();
893 // the parameters which has not been converted to var
894 if (vars_.find(name) == vars_.end()) {
895 auto op = Convert(it);
896 MS_EXCEPTION_IF_NULL(op);
897 MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index;
898 if (op == nullptr) {
899 MS_LOG(ERROR) << "Convert graph failed!";
900 return *this;
901 }
902 UpdateDataOpDesc(it, op);
903 if (HasAbstractMonad(it)) {
904 MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
905 continue;
906 }
907 MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
908 (void)std::static_pointer_cast<Data>(op)->set_attr_index(index++);
909 inputs.push_back(*op);
910 } else if (vars_[name] != nullptr) {
911 MS_LOG(INFO) << "add var input " << it->ToString();
912 auto op = Convert(it);
913 MS_EXCEPTION_IF_NULL(op);
914 inputs.push_back(*op);
915 }
916 }
917 }
918
919 MS_LOG(DEBUG) << "trace output";
920 graph_outputs_.clear();
921 TraceOutput(anf_graph_->get_return()->input(1));
922
923 // Add const nodes as graph input for some operator work with constant
924 MS_LOG(INFO) << "graph const input size: " << graph_const_inputs_.size();
925 std::transform(graph_const_inputs_.begin(), graph_const_inputs_.end(), std::back_inserter(inputs),
926 [](OperatorPtr x) { return *x; });
927
928 MS_LOG(INFO) << "set graph input num: " << inputs.size();
929 (void)df_graph_->SetInputs(inputs);
930
931 // set graph output
932 // set the value of finale return apply node as the output of dataflow graph
933 MS_LOG(DEBUG) << "set output";
934 MS_LOG(INFO) << "set graph output num: " << graph_outputs_.size();
935 (void)df_graph_->SetOutputs(graph_outputs_);
936
937 compute_sout_ << "}" << endl;
938 // For the graph(e.g. eval_subgraph) whose IterNum is 1, donot set NeedIteration flag.
939 if (ConfigManager::GetInstance().iter_num() > 1) {
940 df_graph_->SetNeedIteration(true);
941 }
942 return *this;
943 }
944
UpdateDataOpDesc(const AnfNodePtr & it,const OperatorPtr & op) const945 void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
946 auto node = std::static_pointer_cast<AnfNode>(it);
947 if (node == nullptr) {
948 MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
949 return;
950 }
951
952 std::vector<int64_t> shape;
953 if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
954 shape = normal_shape_ptr->shape();
955 } else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
956 shape = {};
957 } else {
958 MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
959 return;
960 }
961
962 if (node->Type() == nullptr) {
963 MS_LOG(INFO) << "Invalid type to update data op descriptor.";
964 return;
965 }
966 TypeId me_type = node->Type()->type_id();
967 if (kObjectTypeTensorType == me_type) {
968 me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
969 }
970 std::ostringstream buf;
971 buf << "[" << shape << "]";
972 MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
973 std::string format = "NCHW";
974 if (it->isa<Parameter>()) {
975 auto param = it->cast<ParameterPtr>();
976 std::string param_name = param->DebugString();
977 auto param_format = param_format_.find(param_name);
978 if (param_format != param_format_.end()) {
979 format = param_format->second;
980 MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
981 }
982 }
983 auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
984 if (desc == nullptr) {
985 MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
986 } else {
987 (void)std::static_pointer_cast<Data>(op)->update_input_desc_x(*desc);
988 (void)std::static_pointer_cast<Data>(op)->update_output_desc_y(*desc);
989 }
990 }
991
GetComputeGraph()992 DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; }
993
GetInitGraph()994 DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; }
995
GetSaveCheckpointGraph()996 DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; }
997
GetBroadcastGraph()998 DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; }
999
IsSourceEdgeNode(const AnfNodePtr & node)1000 bool DfGraphConvertor::IsSourceEdgeNode(const AnfNodePtr &node) {
1001 if (!node->isa<CNode>()) {
1002 return false;
1003 }
1004 auto cnode = node->cast<CNodePtr>();
1005 if (!IsCustomCNode(cnode)) {
1006 std::string name = GetCNodeTargetFuncName(cnode);
1007 if (name.empty()) {
1008 return false;
1009 }
1010
1011 // Ignore apply node Depend, UpdateState, make_tuple. make_tuple in ge pipeline.
1012 if ((name == prim::kPrimDepend->name()) || (name == prim::kPrimUpdateState->name()) ||
1013 (name == prim::kPrimReturn->name()) || (name == prim::kPrimMakeTuple->name())) {
1014 return false;
1015 }
1016 }
1017 // Load and other normal primitives which contain monad node.
1018 auto has_monad = std::any_of(cnode->inputs().begin(), cnode->inputs().end(),
1019 [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); });
1020 if (has_monad) {
1021 return true;
1022 }
1023
1024 // primitive with make_tuple as input
1025 for (auto &input : cnode->inputs()) {
1026 if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
1027 auto tuple = input->cast<CNodePtr>();
1028 auto ret = std::any_of(tuple->inputs().begin(), tuple->inputs().end(),
1029 [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); });
1030 if (ret) {
1031 return true;
1032 }
1033 }
1034 }
1035
1036 return false;
1037 }
1038
IsControlEdgeNode(const AnfNodePtr & node)1039 bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) {
1040 if (!node->isa<CNode>()) {
1041 return false;
1042 }
1043 auto cnode = node->cast<CNodePtr>();
1044 if (!IsCustomCNode(cnode)) {
1045 std::string name = GetCNodeTargetFuncName(cnode);
1046 if (name.empty()) {
1047 return false;
1048 }
1049
1050 // Ignore apply node of Load, Depend, UpdateState, make_tuple, return
1051 if ((name == prim::kPrimLoad->name()) || (name == prim::kPrimDepend->name()) ||
1052 (name == prim::kPrimUpdateState->name()) || (name == prim::kPrimMakeTuple->name()) ||
1053 (name == prim::kPrimReturn->name())) {
1054 return false;
1055 }
1056 }
1057 return true;
1058 }
1059
ToOperatorPtr(const AnfNodePtr & node)1060 OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) {
1061 auto op = Convert(GetRealOpNode(node));
1062 if (op == nullptr) {
1063 MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString();
1064 error_ = FAILED;
1065 return nullptr;
1066 }
1067 return op;
1068 }
1069
AddEdgeToCache(const AnfNodePtr & src,const AnfNodePtr & dest)1070 void DfGraphConvertor::AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest) {
1071 auto item = monad_control_edge_cache_.find(src);
1072 if (item == monad_control_edge_cache_.end()) {
1073 monad_control_edge_cache_[src] = std::set<AnfNodePtr>{dest};
1074 } else {
1075 item->second.insert(dest);
1076 }
1077 }
1078
AddEdgeForLoad(const AnfNodePtr & node)1079 void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) {
1080 auto func_graph = node->func_graph();
1081 MS_EXCEPTION_IF_NULL(func_graph);
1082 auto mng = func_graph->manager();
1083 if (mng == nullptr) {
1084 mng = Manage(func_graph, true);
1085 func_graph->set_manager(mng);
1086 }
1087 auto manager = func_graph->manager();
1088 MS_EXCEPTION_IF_NULL(manager);
1089 if (manager->node_users().find(node) == manager->node_users().end()) {
1090 MS_LOG(EXCEPTION) << "Can't find node in nodes_users.";
1091 }
1092 auto &users = manager->node_users()[node];
1093 std::shared_ptr<std::vector<AnfNodePtr>> src_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1094 std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1095 for (const auto &iter : users) {
1096 auto user_node = iter.first;
1097 auto name = GetCNodeTargetFuncName(user_node->cast<CNodePtr>());
1098 if (name == prim::kPrimUpdateState->name()) {
1099 FindDestOps(user_node, dst_node_list, false);
1100 continue;
1101 }
1102 if (IsControlEdgeNode(user_node)) {
1103 src_node_list->push_back(user_node);
1104 continue;
1105 }
1106 FindDestOps(user_node, src_node_list, false);
1107 }
1108
1109 // add to cache
1110 for (auto &dest : *dst_node_list) {
1111 for (auto &src : *src_node_list) {
1112 AddEdgeToCache(src, dest);
1113 }
1114 }
1115 }
1116
FindDestOps(const AnfNodePtr & node,const std::shared_ptr<std::vector<AnfNodePtr>> & node_list,bool top)1117 void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr<std::vector<AnfNodePtr>> &node_list,
1118 bool top) {
1119 MS_EXCEPTION_IF_NULL(node);
1120 auto func_graph = node->func_graph();
1121 MS_EXCEPTION_IF_NULL(func_graph);
1122 auto mng = func_graph->manager();
1123 if (mng == nullptr) {
1124 mng = Manage(func_graph, true);
1125 func_graph->set_manager(mng);
1126 }
1127 auto manager = func_graph->manager();
1128 MS_EXCEPTION_IF_NULL(manager);
1129
1130 auto users = manager->node_users()[node];
1131 for (const auto &iter : users) {
1132 auto user_node = iter.first;
1133 if (IsControlEdgeNode(user_node)) {
1134 if (!top) {
1135 node_list->push_back(user_node);
1136 }
1137 } else {
1138 FindDestOps(user_node, node_list, false);
1139 }
1140 }
1141 }
1142
AutoMonadCollectInput(const AnfNodePtr & node)1143 void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) {
1144 if (!IsSourceEdgeNode(node)) {
1145 return;
1146 }
1147
1148 // Add control edge if contain monad input.
1149 std::string name = GetCNodeTargetFuncName(node->cast<CNodePtr>());
1150 if (name == prim::kPrimLoad->name()) {
1151 AddEdgeForLoad(node);
1152 } else {
1153 auto src_ops = ToOperatorPtr(node);
1154 if (src_ops != nullptr) {
1155 // Find dest ops list
1156 std::shared_ptr<std::vector<AnfNodePtr>> dst_node_list = std::make_shared<std::vector<AnfNodePtr>>();
1157 FindDestOps(node, dst_node_list, true);
1158 for (auto &dest : *dst_node_list) {
1159 AddEdgeToCache(node, dest);
1160 }
1161 }
1162 }
1163 }
1164
AutoMonadSetInput(const AnfNodePtr & node)1165 void DfGraphConvertor::AutoMonadSetInput(const AnfNodePtr &node) {
1166 if (monad_control_edge_cache_.find(node) == monad_control_edge_cache_.end()) {
1167 return;
1168 }
1169
1170 auto src_ops = ToOperatorPtr(node);
1171 if (src_ops != nullptr) {
1172 for (auto &dest : monad_control_edge_cache_[node]) {
1173 auto dest_ops = ToOperatorPtr(dest);
1174 if (dest_ops == nullptr) {
1175 continue;
1176 }
1177 (void)dest_ops->AddControlInput(*src_ops);
1178 #ifdef DRAW_GE_GRAPH
1179 compute_sout_ << op_draw_name_[node.get()] << " -> " << op_draw_name_[dest.get()] << "[style=\"dotted\"]" << endl;
1180 #endif
1181 }
1182 }
1183 }
1184
AutoMonadSetControlInput(const AnfNodePtr & node)1185 void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) {
1186 AutoMonadCollectInput(node);
1187 AutoMonadSetInput(node);
1188 }
1189
SetOpControlInput(const AnfNodePtr & node)1190 void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) {
1191 MS_EXCEPTION_IF_NULL(node);
1192 AutoMonadSetControlInput(node);
1193 if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) {
1194 return;
1195 }
1196
1197 std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()];
1198 if ((control_edges.empty())) {
1199 MS_LOG(ERROR) << "Get control edge node's src or dest operator failed";
1200 return;
1201 }
1202
1203 for (auto &item : control_edges) {
1204 (void)item.dest_op->AddControlInput(*item.src_op);
1205 }
1206 }
1207
1208 const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)};
1209
ParseLoadInput(const CNodePtr & cnode)1210 AnfNodePtr DfGraphConvertor::ParseLoadInput(const CNodePtr &cnode) {
1211 if (cnode->inputs().size() < 3) {
1212 MS_LOG(EXCEPTION) << "input size error, " << cnode->ToString();
1213 }
1214 const size_t para_index = 1;
1215 return cnode->input(para_index);
1216 }
1217
SetTupleOpInput(const OpAdapterPtr & adpt,const CNodePtr & node,const AnfNodePtr & pred,const OperatorPtr & src,int index)1218 void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred,
1219 const OperatorPtr &src, int index) {
1220 std::shared_ptr<std::vector<OutHandler>> handler_vec = tuple_out_handle_cache_[pred.get()];
1221 std::shared_ptr<std::vector<OutHandler>> handler_vec_without_monad = std::make_shared<std::vector<OutHandler>>();
1222 bool with_monad = false;
1223 for (auto &handler : *handler_vec) {
1224 // when tuple with monad type element, the handler operator is nullptr, should be ignored.
1225 if (handler.op == nullptr) {
1226 if ((handler.node != nullptr) && !HasAbstractMonad(handler.node)) {
1227 MS_LOG(WARNING) << "Unsupported node in tuple : " << node->ToString();
1228 }
1229 continue;
1230 }
1231 with_monad = true;
1232 handler_vec_without_monad->push_back(handler);
1233 }
1234 int ret = adpt->setInput(src, index, handler_vec_without_monad);
1235 if ((ret == 0) && pred->isa<CNode>() && (pred->cast<CNodePtr>()->inputs().size() == handler_vec->size() + 1)) {
1236 for (unsigned int j = 0; j < handler_vec_without_monad->size(); j++) {
1237 AnfNodePtr input_node = pred->cast<CNodePtr>()->input(j + 1);
1238 if (with_monad) {
1239 input_node = handler_vec_without_monad->at(j).node;
1240 }
1241 compute_sout_ << op_draw_name_[input_node.get()] << " -> " << op_draw_name_[node.get()] << ":" << index << endl;
1242 AddGraphConstInput(handler_vec_without_monad->at(j).op);
1243 }
1244 return;
1245 }
1246 MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString();
1247 }

// Resolves the node that really feeds `input` of `node`.
//
// Depend wrappers are skipped (following their input 1). Monad values, None and UpdateState yield
// nullptr, meaning "no data input". A Load node is replaced by its parameter input. As a side effect,
// when not training and `node` is one of the operators in trans_var_list, a Parameter input whose
// cached operator is a constant and which has a usable entry in vars_ is replaced by a Variable
// operator (in both op_cache_ and vars_).
//
// Returns nullptr when either argument is null or the input must be skipped.
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
  if (input == nullptr || node == nullptr) {
    return nullptr;
  }
  AnfNodePtr pred = input;
  while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
    pred = pred->cast<CNodePtr>()->input(1);
  }

  // skip input of UMonad, IOMonad
  if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
    return nullptr;
  }

  // skip input of the None, UpdateState
  if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
    return nullptr;
  }

  if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
    pred = ParseLoadInput(pred->cast<CNodePtr>());
  }

  // transform "Const" op to "Variable" op when the next node is "Assign" op.
  std::string c_name = GetCNodeTargetFuncName(node);
  auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
  if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
    std::string name = std::static_pointer_cast<Parameter>(pred)->name();
    auto op_itor = op_cache_.find(pred.get());
    if (op_itor == op_cache_.end()) {
      MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
    }
    // vars_ may hold null operators (other users of vars_ check for that), so the entry must be
    // non-null before its output descriptor is read.
    auto var_itor = vars_.find(name);
    const bool has_usable_var = (var_itor != vars_.end()) && (var_itor->second != nullptr);
    if (op_itor->second != nullptr &&
        (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && has_usable_var) {
      auto variable = std::make_shared<Variable>(name);
      auto desc = var_itor->second->GetOutputDesc("y");
      (void)variable->update_output_desc_y(desc);
      MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
      op_itor->second = variable;  // replace parameter with variable
      var_itor->second = variable;
    }
  }
  return pred;
}
1293
SetOpInput(const OpAdapterPtr & adpt,const CNodePtr & node)1294 void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
1295 OperatorPtr src = Convert(node);
1296 int case_flag = 0;
1297 auto &inputs = node->inputs();
1298 size_t input_size = inputs.size();
1299 if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
1300 case_flag = 1;
1301 input_size = case_input_handle_cache_[node.get()]->size() + 1;
1302 }
1303
1304 for (size_t i = 1; i < input_size; i++) {
1305 AnfNodePtr pred = nullptr;
1306 if (case_flag != 0) {
1307 pred = case_input_handle_cache_[node.get()]->at(i - 1);
1308 } else {
1309 pred = inputs[i];
1310 }
1311 pred = GetRealInputNode(node, pred);
1312 if (pred == nullptr) {
1313 continue;
1314 }
1315
1316 int index = SizeToInt(i);
1317 // find in out_hadnle_cache_ first
1318 auto it = out_handle_cache_.find(pred.get());
1319 if (it != out_handle_cache_.end()) {
1320 int ret = adpt->setInput(src, index, it->second);
1321 if (ret == 0) {
1322 if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kTupleGetItem) {
1323 compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()]
1324 << ":" << i << endl;
1325 } else if (pred->isa<Parameter>()) {
1326 compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
1327 } else {
1328 // don't draw anything.
1329 MS_LOG(INFO) << "DRAW_GE_GRAPH: Shouldn't have this case.";
1330 }
1331 AddGraphConstInput(it->second.op);
1332 }
1333 } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) {
1334 SetTupleOpInput(adpt, node, pred, src, index);
1335 } else {
1336 auto op = Convert(pred);
1337 int ret = adpt->setInput(src, index, op);
1338 if (ret == 0) {
1339 compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
1340 AddGraphConstInput(op);
1341 }
1342 }
1343 }
1344 }
1345
AddGraphConstInput(const OperatorPtr & op)1346 void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) {
1347 if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") {
1348 graph_const_inputs_.push_back(op);
1349 }
1350 }
1351
SetNodeInput(const AnfNodePtr node)1352 void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
1353 if (!node->isa<CNode>()) {
1354 return;
1355 }
1356 if (op_cache_.find(node.get()) == op_cache_.end()) {
1357 return;
1358 }
1359 auto cnode = node->cast<CNodePtr>();
1360 OpAdapterPtr adpt = FindAdapter(cnode, training_);
1361 if (adpt == nullptr) {
1362 error_ = NOT_FOUND;
1363 return;
1364 }
1365
1366 // get Operator from op_cache_, use adapter to set Inputs
1367 DfGraphConvertor::SetOpInput(adpt, cnode);
1368 }
1369
ProcessSubgraph(AnfNodePtr node,const std::vector<AnfNodePtr> & inputs)1370 void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNodePtr> &inputs) {
1371 if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != "Partial") {
1372 return;
1373 }
1374 auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1375 MS_EXCEPTION_IF_NULL(graph_node);
1376 FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1377 DfGraphConvertor converter(anf_graph);
1378 converter.use_inputs_ = true;
1379 converter.inputs_ = inputs;
1380 (void)converter.ConvertAllNode().BuildGraph();
1381 #ifdef ENABLE_DUMP_IR
1382 std::string name = graph_node->ToString() + "_ge_graph.dot";
1383 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
1384 converter.DrawComputeGraph(name);
1385 }
1386 #endif
1387 branches_map_[node.get()] = *(converter.df_graph_);
1388 }
1389
1390 // Update GE op's shape and type info
UpdateOpDesc(const AnfNodePtr node)1391 void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
1392 if (node == nullptr || !node->isa<CNode>()) {
1393 return;
1394 }
1395
1396 if (op_cache_.find(node.get()) == op_cache_.end()) {
1397 return;
1398 }
1399
1400 OpAdapterPtr adpt = FindAdapter(node, training_);
1401 if (adpt == nullptr) {
1402 error_ = NOT_FOUND;
1403 return;
1404 }
1405
1406 // get Operator from op_cache_
1407 OperatorPtr op = Convert(node);
1408
1409 adpt->updateOutputDesc(op, node->Shape(), node->Type(), node);
1410 }
1411
Convert(const AnfNodePtr node)1412 OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
1413 if (node == nullptr) {
1414 MS_LOG(ERROR) << "node is nullptr";
1415 error_ = NOT_FOUND;
1416 return nullptr;
1417 }
1418 // find in cache
1419 if (op_cache_.count(node.get())) {
1420 return op_cache_[node.get()];
1421 }
1422
1423 // do not convert primitive node, Load, UpdateState
1424 if (IsValueNode<Primitive>(node) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1425 IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1426 return nullptr;
1427 }
1428
1429 // convert a new one
1430 if (node->isa<CNode>()) {
1431 return ConvertCNode(node->cast<CNodePtr>());
1432 }
1433 if (node->isa<Parameter>()) {
1434 return ConvertParameter(node);
1435 }
1436 if (node->isa<ValueNode>()) {
1437 if (IsValueNode<Monad>(node)) {
1438 return nullptr;
1439 }
1440 return ConvertValueNode(node->cast<ValueNodePtr>());
1441 }
1442
1443 MS_LOG(ERROR) << "Invalid AnfNode";
1444 error_ = INVALID_ARGUMENT;
1445 return nullptr;
1446 }
1447
ConvertMakeTuple(const CNodePtr node)1448 void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) {
1449 std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1450 // convert each tuple item to a OutHandler
1451 for (size_t i = 1; i < node->inputs().size(); i++) {
1452 AnfNodePtr item = node->input(i);
1453 if (IsPrimitiveCNode(item, prim::kPrimLoad)) {
1454 item = ParseLoadInput(item->cast<CNodePtr>());
1455 }
1456 OperatorPtr op = Convert(item);
1457 if (op != nullptr) {
1458 tuple_items->emplace_back(OutHandler(op, "", item));
1459 } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) {
1460 tuple_items->push_back(out_handle_cache_[item.get()]);
1461 } else {
1462 tuple_items->push_back(OutHandler(nullptr, "", item));
1463 }
1464 }
1465
1466 MS_LOG(DEBUG) << "ConvertMakeTuple: " << node.get() << " " << tuple_items->size();
1467 tuple_out_handle_cache_[node.get()] = tuple_items;
1468 }
1469
ConvertTopK(const CNodePtr node)1470 void DfGraphConvertor::ConvertTopK(const CNodePtr node) {
1471 MS_EXCEPTION_IF_NULL(node);
1472 MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
1473 auto value_ptr = node->input(2)->cast<ValueNodePtr>();
1474 std::ostringstream ss;
1475 ss << "op" << value_ptr.get();
1476 op_draw_name_[value_ptr.get()] = ss.str();
1477 compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
1478 MS_EXCEPTION_IF_NULL(value_ptr);
1479 auto input_value = value_ptr->value();
1480 auto int64_value = GetValue<int64_t>(input_value);
1481 OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
1482 auto op = adpt->generate(value_ptr);
1483 adpt->setAttr(op, "value", static_cast<int32_t>(int64_value));
1484 op_cache_[value_ptr.get()] = op;
1485 }
1486
CastToInt(const ValuePtr & value)1487 std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) {
1488 if (value == nullptr) {
1489 MS_LOG(WARNING) << "Value ptr is nullptr.";
1490 return {};
1491 }
1492 std::vector<int64_t> cur_value = {};
1493 if (utils::isa<ValueSequeuePtr>(value)) {
1494 auto val_seq_ptr = value->cast<ValueSequeuePtr>();
1495 MS_EXCEPTION_IF_NULL(val_seq_ptr);
1496 if (!val_seq_ptr->value().empty()) {
1497 auto first_val = val_seq_ptr->value().front();
1498 MS_EXCEPTION_IF_NULL(first_val);
1499 MS_EXCEPTION_IF_NULL(first_val->type());
1500 if (first_val->type()->number_type() == kNumberTypeInt64) {
1501 cur_value = GetValue<std::vector<int64_t>>(value);
1502 } else {
1503 auto origin_value = GetValue<std::vector<int>>(value);
1504 (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
1505 [](int index) { return static_cast<int64_t>(index); });
1506 }
1507 }
1508 } else {
1509 MS_EXCEPTION_IF_NULL(value->type());
1510 if (value->type()->number_type() == kNumberTypeInt64) {
1511 cur_value.push_back(GetValue<int64_t>(value));
1512 } else {
1513 cur_value.push_back(static_cast<int64_t>(GetValue<int>(value)));
1514 }
1515 }
1516 return cur_value;
1517 }
1518
ConvertReshape(const CNodePtr node)1519 void DfGraphConvertor::ConvertReshape(const CNodePtr node) {
1520 MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
1521 const auto kInputNum = 3;
1522 if (node->size() < kInputNum) {
1523 MS_LOG(WARNING) << "Reshape must have two inputs.";
1524 return;
1525 }
1526 OpAdapterPtr adpt = FindAdapter(node, training_);
1527 if (adpt == nullptr) {
1528 return;
1529 }
1530 auto op = adpt->generate(node);
1531 MS_EXCEPTION_IF_NULL(op);
1532 // get shape form attr
1533 auto value_node = node->input(0)->cast<ValueNodePtr>();
1534 MS_EXCEPTION_IF_NULL(value_node);
1535 MS_EXCEPTION_IF_NULL(value_node->value());
1536 auto primitive = value_node->value()->cast<PrimitivePtr>();
1537 MS_EXCEPTION_IF_NULL(primitive);
1538 auto value = primitive->GetAttr("shape");
1539 std::vector<int64_t> list;
1540 list = CastToInt(value);
1541
1542 (void)op->SetAttr("shape", list);
1543 op_cache_[node.get()] = op;
1544 }
1545
TraceTupleGetItem(const CNodePtr & node,uint64_t * index)1546 AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, uint64_t *index) {
1547 const int TUPLE_GET_ITEM_INDEX = 2;
1548 if (node->inputs().size() < 3) { // "tuple_getitem" primitive must have 3 inputs
1549 MS_LOG(EXCEPTION) << "length of inputs of TupleGetItem is less than 3";
1550 }
1551 auto index_node = node->inputs()[TUPLE_GET_ITEM_INDEX];
1552 if (!index_node->isa<ValueNode>()) {
1553 error_ = INVALID_ARGUMENT;
1554 MS_LOG(EXCEPTION) << "can't convert get item with non-constant index";
1555 }
1556 *index = LongToUlong(GetValue<int64_t>(GetValueNode(index_node)));
1557 return node->inputs()[1];
1558 }
1559
TraceDepend(const CNodePtr & node)1560 AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) {
1561 auto cnode = node->cast<CNodePtr>();
1562 if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs
1563 MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3";
1564 }
1565 return cnode->inputs()[1];
1566 }
1567
TraceMakeTuple(const CNodePtr & node,uint64_t index)1568 AnfNodePtr DfGraphConvertor::TraceMakeTuple(const CNodePtr &node, uint64_t index) {
1569 if (index + 1 >= node->inputs().size()) {
1570 MS_LOG(EXCEPTION) << "length of make_tuple is less than index: " << index;
1571 }
1572 return node->inputs()[index + 1];
1573 }
1574
GetHandler(const AnfNodePtr & node,const std::stack<uint64_t> & index_stack,AnfNode * const draw_index)1575 OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node, const std::stack<uint64_t> &index_stack,
1576 AnfNode *const draw_index) {
1577 if (node == nullptr) {
1578 MS_LOG(ERROR) << "Get nullptr while trace real op";
1579 return OutHandler(nullptr, "");
1580 }
1581 std::ostringstream ss;
1582 ss << "op" << node.get();
1583 if (index_stack.empty()) {
1584 op_draw_name_[draw_index] = ss.str();
1585 return OutHandler(Convert(node), "");
1586 } else {
1587 OpAdapterPtr adpt = FindAdapter(node, training_);
1588 if (adpt == nullptr) {
1589 MS_LOG(ERROR) << "Can not get node output as adpt is nullptr!";
1590 error_ = NOT_FOUND;
1591 return OutHandler(nullptr, "");
1592 }
1593 OperatorPtr op = Convert(node);
1594 if (op == nullptr) {
1595 error_ = NOT_FOUND;
1596 MS_LOG(ERROR) << "Can not convert node for trace real op";
1597 return OutHandler(nullptr, "");
1598 }
1599 op_draw_name_[draw_index] = ss.str();
1600 return adpt->getOutput(Convert(node), UintToInt(index_stack.top()));
1601 }
1602 }
1603
1604 // get the real operator through maketuple tuple_getitem depend
TraceRealOp(AnfNodePtr node)1605 OutHandler DfGraphConvertor::TraceRealOp(AnfNodePtr node) {
1606 bool flag = IsPrimitiveCNode(node, prim::kPrimTupleGetItem) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1607 IsPrimitiveCNode(node, prim::kPrimDepend);
1608 std::stack<uint64_t> index_stack;
1609 auto draw_index = node.get();
1610 while (flag) {
1611 flag = false;
1612 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1613 uint64_t index;
1614 node = TraceTupleGetItem(node->cast<CNodePtr>(), &index);
1615 index_stack.push(index);
1616 flag = true;
1617 } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
1618 if (index_stack.empty()) {
1619 MS_LOG(ERROR) << "TraceRealOp find a make_tuple node";
1620 return OutHandler(nullptr, "");
1621 } else {
1622 node = TraceMakeTuple(node->cast<CNodePtr>(), index_stack.top());
1623 index_stack.pop();
1624 flag = true;
1625 }
1626 } else if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
1627 node = TraceDepend(node->cast<CNodePtr>());
1628 flag = true;
1629 }
1630 }
1631 return GetHandler(node, index_stack, draw_index);
1632 }
1633
ConvertTupleGetItem(const CNodePtr node)1634 void DfGraphConvertor::ConvertTupleGetItem(const CNodePtr node) {
1635 auto handle = TraceRealOp(node);
1636 if (handle.op == nullptr) {
1637 MS_LOG(ERROR) << "Failed to trace tuple get item";
1638 return;
1639 }
1640 out_handle_cache_[node.get()] = handle;
1641 }
1642
1643 // Get the real op for tuple_getitem through make tuple, or depend
GetRealOpNode(AnfNodePtr node)1644 AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) {
1645 const int TUPLE_GET_ITEM_INDEX = 2;
1646 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1647 auto node_inputs = node->cast<CNodePtr>()->inputs();
1648 if (node_inputs.size() != 3) { // "tuple_getitem" primitive must have 3 inputs
1649 MS_LOG(ERROR) << "tuple get item node not correct!";
1650 error_ = FAILED;
1651 return node;
1652 }
1653 MS_EXCEPTION_IF_NULL(node_inputs[TUPLE_GET_ITEM_INDEX]);
1654 if (!node_inputs[TUPLE_GET_ITEM_INDEX]->isa<ValueNode>()) {
1655 error_ = INVALID_ARGUMENT;
1656 MS_LOG(EXCEPTION) << "can't convert get item with non-constant index";
1657 }
1658 auto value_ptr = GetValueNode(node_inputs[TUPLE_GET_ITEM_INDEX])->cast<Int32ImmPtr>();
1659 if (value_ptr == nullptr) {
1660 MS_LOG(ERROR) << "Can not convert get item as value is nullptr!";
1661 error_ = FAILED;
1662 return node;
1663 }
1664 int64_t index = value_ptr->value();
1665
1666 // make_tuple apply inputs:make_tuple, [tuple_items,]
1667 if (IsPrimitiveCNode(node_inputs[1], prim::kPrimMakeTuple)) {
1668 auto tuple_inputs = node->cast<CNodePtr>()->inputs();
1669 if (tuple_inputs.size() < LongToSize(index + 1L)) {
1670 MS_LOG(ERROR) << "make tuple input items node not correct! size:" << tuple_inputs.size()
1671 << ", item index:" << index;
1672 error_ = FAILED;
1673 return node;
1674 }
1675 return GetRealOpNode(tuple_inputs[LongToSize(index + 1L)]);
1676 }
1677 return GetRealOpNode(node_inputs[1]);
1678 }
1679
1680 // depend apply inputs: depend,output,depended_node
1681 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
1682 auto depend_inputs = node->cast<CNodePtr>()->inputs();
1683 if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs
1684 MS_LOG(ERROR) << "depend input items not correct";
1685 error_ = FAILED;
1686 return node;
1687 }
1688 return GetRealOpNode(depend_inputs[1]);
1689 }
1690 return node;
1691 }
1692
1693 // convert the anf node to corresponding operator list
ConvertDependNode(const AnfNodePtr node)1694 std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr node) {
1695 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
1696 std::vector<OperatorPtr> op_lists;
1697 auto node_inputs = node->cast<CNodePtr>()->inputs();
1698 for (size_t index = 1; index < node_inputs.size(); index++) {
1699 auto op = Convert(GetRealOpNode(node_inputs[index]));
1700 if (op == nullptr) {
1701 MS_LOG(ERROR) << "Convert real op node to operator failed";
1702 error_ = FAILED;
1703 return std::vector<OperatorPtr>({});
1704 }
1705 op_lists.push_back(op);
1706 }
1707 return op_lists;
1708 }
1709
1710 auto op = Convert(GetRealOpNode(node));
1711 if (op == nullptr) {
1712 MS_LOG(ERROR) << "Convert real op node to operator failed";
1713 error_ = FAILED;
1714 return std::vector<OperatorPtr>({});
1715 }
1716 return std::vector<OperatorPtr>({op});
1717 }
1718
CheckCNode(const std::string & name,const CNodePtr node)1719 bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
1720 // ignore apply node of return
1721 if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() ||
1722 name == prim::kPrimSwitchLayer->name() || name == prim::kPrimPartial->name()) {
1723 return false;
1724 }
1725
1726 // Convert TopK second input from int64 to int32.
1727 if (name == prim::kPrimTopK->name()) {
1728 ConvertTopK(node);
1729 return true;
1730 }
1731
1732 // Convert Reshape add const input to attr(shape)
1733 if (name == prim::kPrimReshape->name()) {
1734 ConvertReshape(node);
1735 return true;
1736 }
1737
1738 // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers
1739 if (name == prim::kPrimMakeTuple->name()) {
1740 ConvertMakeTuple(node);
1741 return false;
1742 }
1743
1744 // As for nodes with multi outputs, convert tuple_getitem to OutHandle
1745 if (name == prim::kPrimTupleGetItem->name()) {
1746 ConvertTupleGetItem(node);
1747 return false;
1748 }
1749
1750 return true;
1751 }
1752
ConvertCNode(const CNodePtr node)1753 OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
1754 SaveParamFormat(node);
1755 std::string name = GetCNodeTargetFuncName(node);
1756 if (!CheckCNode(name, node)) {
1757 return nullptr;
1758 }
1759
1760 // get corresponding OpAdapter
1761 OpAdapterPtr adpt = FindAdapter(node, training_);
1762 if (adpt == nullptr) {
1763 error_ = NOT_FOUND;
1764 return nullptr;
1765 }
1766
1767 // get operator
1768 OperatorPtr op = nullptr;
1769 auto it_op = op_cache_.find(node.get());
1770 if (it_op != op_cache_.end()) {
1771 op = it_op->second;
1772 } else {
1773 op = adpt->generate(node);
1774 }
1775
1776 // set attribute for primitive
1777 (void)adpt->setAttr(op, node);
1778
1779 // add into cache
1780 (void)op_cache_.insert(std::make_pair(node.get(), op));
1781
1782 DrawCNode(node, adpt);
1783
1784 return op_cache_[node.get()];
1785 }
1786
ConvertParameter(const AnfNodePtr node)1787 OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
1788 // convert Parameter in ANF to variable in DataFlow
1789 auto adpt = FindAdapter(node, training_);
1790 if (adpt == nullptr) {
1791 MS_LOG(EXCEPTION) << "Can not find adapter for Parameter";
1792 }
1793 auto op = adpt->generate(node);
1794 op_cache_[node.get()] = op;
1795
1796 // build index for parameter using name
1797 std::string name = std::static_pointer_cast<Parameter>(node)->name();
1798 params_[name] = node;
1799 std::ostringstream ss;
1800 ss << "op" << node.get();
1801 op_draw_name_[node.get()] = ss.str();
1802 compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl;
1803 return op_cache_[node.get()];
1804 }
1805
SaveParamFormat(const CNodePtr node)1806 void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
1807 AnfNodePtr op = node->input(0);
1808 if (IsValueNode<Primitive>(op)) {
1809 auto prim = GetValueNode<PrimitivePtr>(op);
1810 for (auto attr : prim->attrs()) {
1811 if (attr.first == "format") {
1812 std::string format;
1813 if (attr.second->isa<Int64Imm>()) {
1814 bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &attr.second);
1815 if (converted) {
1816 format = attr.second->ToString();
1817 }
1818 }
1819 if (format != "NCDHW") {
1820 break;
1821 }
1822 for (size_t i = 1; i < node->size(); i++) {
1823 auto input = node->input(i);
1824 if (input->isa<Parameter>()) {
1825 param_format_[input->DebugString()] = format;
1826 MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
1827 }
1828 }
1829 }
1830 }
1831 }
1832 }
1833
TryConvertValueNodeToMultiConst(const ValueNodePtr node)1834 Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
1835 MS_EXCEPTION_IF_NULL(node);
1836 ValuePtr value = node->value();
1837 MS_EXCEPTION_IF_NULL(value);
1838 if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
1839 return FAILED;
1840 }
1841
1842 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
1843 if (vec.empty()) {
1844 return FAILED;
1845 }
1846
1847 std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1848 for (size_t i = 0; i < vec.size(); i++) {
1849 MS_EXCEPTION_IF_NULL(vec[i]);
1850 if (vec[i]->isa<MeTensor>()) {
1851 GeTensorPtr ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast<MeTensorPtr>(), kOpFormat_NCHW);
1852 auto const_op = std::make_shared<Constant>(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i));
1853 (void)const_op->set_attr_value(*ge_tensor);
1854 (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc());
1855 tuple_items->emplace_back(OutHandler(const_op, ""));
1856 } else {
1857 return FAILED;
1858 }
1859 }
1860 if (tuple_items->empty()) {
1861 return FAILED;
1862 }
1863
1864 tuple_out_handle_cache_[node.get()] = tuple_items;
1865 return SUCCESS;
1866 }
1867
ConvertValueNode(const ValueNodePtr node)1868 OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
1869 // convert valuenode in ANF to Const in DataFlow
1870 // find paramerte referenced by SymbolicKeyInstance of valuenode
1871 std::ostringstream ss;
1872 ss << "op" << node.get();
1873 op_draw_name_[node.get()] = ss.str();
1874 compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl;
1875
1876 if (TryConvertValueNodeToMultiConst(node) == SUCCESS) {
1877 MS_LOG(INFO) << "Convert value node to multi Constant OP success";
1878 return nullptr;
1879 }
1880
1881 OpAdapterPtr adpt = FindAdapter(node, training_);
1882 if (adpt == nullptr) {
1883 error_ = NOT_FOUND;
1884 return nullptr;
1885 }
1886 auto op = adpt->generate(node);
1887 // set const's attrs
1888 if (adpt->setAttr(op, "value", node->value()) != 0) {
1889 MS_LOG(WARNING) << "set attr value for const failed";
1890 }
1891
1892 auto const_op = std::static_pointer_cast<Constant>(op);
1893 if (const_op == nullptr) {
1894 MS_LOG(ERROR) << "Get Constant operator failed";
1895 return nullptr;
1896 }
1897 auto ge_tensor = const_op->get_attr_value();
1898 auto ge_desc = ge_tensor.GetTensorDesc();
1899 (void)const_op->update_output_desc_y(ge_desc);
1900 op_cache_[node.get()] = op;
1901 return op_cache_[node.get()];
1902 }
1903
DrawCNode(const CNodePtr node,const OpAdapterPtr adpt)1904 void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
1905 if (adpt == nullptr || node == nullptr) {
1906 MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!";
1907 return;
1908 }
1909 std::ostringstream ss;
1910 ss << "op" << node.get();
1911 op_draw_name_[node.get()] = ss.str();
1912
1913 compute_sout_ << ss.str() << "[label=<";
1914 compute_sout_ << "<table border='1' cellborder='1'>" << endl;
1915
1916 auto input_map = adpt->getInputMap();
1917 auto dyn_input_map = adpt->getDynInputMap();
1918 if (input_map.size() + dyn_input_map.size() > 0) {
1919 compute_sout_ << "<tr>";
1920 for (auto &it : input_map) {
1921 compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
1922 }
1923 for (auto &it : dyn_input_map) {
1924 compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
1925 }
1926 compute_sout_ << "</tr>" << endl;
1927 }
1928
1929 compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
1930 << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
1931
1932 // print attrs' values
1933 auto atts = adpt->GetAttrsFromDrawGraph();
1934 for (auto &it : atts) {
1935 compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << it
1936 << "\"</td></tr>";
1937 }
1938
1939 adpt->clearAttrVect();
1940
1941 compute_sout_ << "</table>> shape=plaintext]" << endl;
1942 }
RegisterAdapter(const std::string & name,OpAdapterPtr adpt)1943 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr adpt) {
1944 OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(adpt);
1945 }
RegisterAdapter(const std::string & name,OpAdapterPtr train_adpt,OpAdapterPtr infer_adpt)1946 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) {
1947 OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(train_adpt, infer_adpt);
1948 }
1949 } // namespace transform
1950 } // namespace mindspore
1951