1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "transform/graph_ir/convert.h"
18 
19 #include <algorithm>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "op_proto/inc/array_ops.h"
24 #include "op_proto/inc/elewise_calculation_ops.h"
25 #include "op_proto/inc/save_ops.h"
26 #include "op_proto/inc/state_ops.h"
27 #include "include/common/debug/anf_ir_dump.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "include/common/utils/config_manager.h"
30 #include "include/common/utils/utils.h"
31 #include "include/transform/graph_ir/utils.h"
32 #include "ir/graph_utils.h"
33 #include "ops/array_ops.h"
34 #include "ops/conv_pool_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/image_ops.h"
37 #include "ops/math_op_name.h"
38 #include "ops/nn_ops.h"
39 #include "ops/nn_optimizer_ops.h"
40 #include "ops/other_ops.h"
41 #include "ops/sequence_ops.h"
42 #include "ops/structure_ops.h"
43 #include "ops/lite_ops.h"
44 #include "ops/op_def.h"
45 #include "ops/auto_generate/gen_ops_primitive.h"
46 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h"
47 #include "plugin/device/ascend/hal/hardware/dummy_ascend_collective_comm_lib.h"
48 #include "plugin/device/ascend/hal/hardware/ge_utils.h"
49 #include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
50 #include "transform/graph_ir/op_adapter.h"
51 #include "transform/graph_ir/op_adapter_desc.h"
52 #include "transform/graph_ir/op_adapter_map.h"
53 #include "transform/graph_ir/storage_format_convertor.h"
54 #include "utils/anf_utils.h"
55 #include "utils/check_convert_utils.h"
56 #include "utils/log_adapter.h"
57 #include "utils/ms_context.h"
58 #include "utils/symbolic.h"
59 #include "utils/singleton.h"
60 
61 namespace mindspore::transform {
62 using ::ge::Operator;
63 using mindspore::kValueAny;
64 using std::make_shared;
65 using std::shared_ptr;
66 using std::string;
67 using std::vector;
68 using Variable = ::ge::op::Variable;
69 using Constant = ::ge::op::Constant;
70 using Assign = ::ge::op::Assign;
71 using Data = ::ge::op::Data;
72 using RefData = ::ge::op::RefData;
73 using std::endl;
74 using std::static_pointer_cast;
75 
76 constexpr int64_t kInputOffset = 2;
77 constexpr size_t kSwitchInputSize = 4;
78 constexpr size_t kSwitchBodyIndex = 2;
79 constexpr size_t kSwitchAfterIndex = 3;
80 constexpr size_t kAfterIndexInCache = 2;
81 constexpr size_t kCnodeInputSizeOne = 1;
82 constexpr size_t kDataInputIndex = 1;
83 constexpr size_t kInputSize2 = 2;
84 constexpr size_t kMergeInputSize = 2;
85 constexpr size_t kNoOpOptThreshold = 3;
86 constexpr auto kHcclFusionByFusionID = 2;
87 constexpr auto kHcclFusionDefault = 1;
88 constexpr auto kTypeNoOp = "NoOp";
89 constexpr auto kTypeIdentity = "Identity";
90 constexpr auto kTypeIdentityN = "IdentityN";
91 constexpr auto kTypeMerge = "Merge";
92 constexpr auto kTypeIf = "If";
93 constexpr auto kTypeVariable = "Variable";
94 constexpr auto kParallelGroup = "_parallel_group";
95 constexpr auto kParallelGroupId = "_parallel_group_id";
96 constexpr auto kTypeRefData = "RefData";
97 constexpr auto kBroadcast = "broadcast";
98 constexpr auto kInit = "init";
99 constexpr auto kTypeData = "Data";
100 constexpr auto kTypeIndex = "index";
101 constexpr auto kTypeY = "y";
102 constexpr auto kTypeX = "x";
103 constexpr auto kProcessNodeEngineID = "_process_node_engine_id";
104 constexpr auto kIsFreeVariable = "_is_free_variable";
105 
106 namespace {
107 const std::map<TypeId, TypeId> kReduceRaiseMap = {{kNumberTypeInt64, kNumberTypeInt32}};
108 mindspore::HashMap<std::string, size_t> branches_repeat_times = {};
109 mindspore::HashMap<std::string, size_t> call_subgraphs_repeat_times = {};
110 // {node name | {{input_index, dst_type}...}}
111 const std::map<std::string, std::vector<std::pair<size_t, TypeId>>> kTransInputDTypeMap = {
112   {kResizeNearestNeighborGradOpName, {{2, kNumberTypeInt32}}},
113   {kResizeNearestNeighborOpName, {{2, kNumberTypeInt32}}},
114   {kResizeNearestNeighborV2OpName, {{2, kNumberTypeInt32}}},
115   {kResizeNearestNeighborV2GradOpName, {{2, kNumberTypeInt32}}},
116   {kResizeBicubicOpName, {{2, kNumberTypeInt32}}},
117   {kConv2DBackpropFilterOpName, {{3, kNumberTypeInt32}}},
118   {kConv2DBackpropInputOpName, {{3, kNumberTypeInt32}}},
119   {kOneHotOpName, {{2, kNumberTypeInt32}}},
120   {kLinSpaceOpName, {{3, kNumberTypeInt32}}},
121   {kResizeNearestNeighborV2GradOpName, {{2, kNumberTypeInt32}}},
122   {kResizeBilinearV2OpName, {{2, kNumberTypeInt32}}},
123   {kCol2ImOpName, {{2, kNumberTypeInt32}}}};
124 
125 // {node name | {{attr_name, dst_type}...}}
126 const std::map<std::string, std::vector<std::pair<std::string, TypeId>>> kTransAttrDTypeMap = {
127   {kResizeBilinearOpName, {{"size", kNumberTypeInt32}}},
128   {kSpaceToBatchNDOpName, {{"block_shape", kNumberTypeInt32}}},
129   {kBatchToSpaceNDOpName, {{"block_shape", kNumberTypeInt32}}},
130   {kSplitVOpName, {{"split_dim", kNumberTypeInt32}}},
131   {kSplitVDOpName, {{"split_dim", kNumberTypeInt32}}}};
132 
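// Check whether converting src_type to dst_type is supported, i.e. the pair appears in kReduceRaiseMap
// (currently only int64 -> int32).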
133 bool IsValidConversion(TypeId src_type, TypeId dst_type) {
134   if (src_type == dst_type) {
135     MS_LOG(DEBUG) << "No need to convert, src type and dst type are the same, type:" << TypeIdToString(src_type);
136     return false;
137   }
138   auto iter = kReduceRaiseMap.find(src_type);
139   if (iter != kReduceRaiseMap.end() && iter->second == dst_type) {
140     MS_LOG(INFO) << "Convert data type from " << TypeIdToString(src_type) << " to " << TypeIdToString(dst_type);
141     return true;
142   }
143   MS_LOG(DEBUG) << "Unsupported conversion. src_type:" << TypeIdToString(src_type)
144                 << ", dst_type:" << TypeIdToString(dst_type);
145   return false;
146 }
147 
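// Build a new Value of dst_type (only int32 is supported) from the already-flattened source values:
// a ValueSequence becomes a vector of int32, a scalar becomes a single int32.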
148 template <typename T>
149 ValuePtr CreateNewValue(const ValuePtr &value, const std::vector<T> &values, const TypeId &dst_type) {
150   MS_EXCEPTION_IF_NULL(value);
151   if (dst_type == kNumberTypeInt32) {
152     if (value->isa<ValueSequence>()) {
153       std::vector<int32_t> result;
154       std::for_each(values.begin(), values.end(),
155                     [&result](const auto &elem) { result.emplace_back(static_cast<int32_t>(elem)); });
156       return MakeValue(result);
157     }
158     return MakeValue(static_cast<int32_t>(values[0]));
159   } else {
160     MS_LOG(EXCEPTION) << "Invalid dst type:" << TypeIdToString(dst_type);
161   }
162   return value;
163 }
164 
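// Recursively flatten a (possibly nested) ValueSequence into a flat vector of scalar values of type T.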
165 template <typename T>
166 std::vector<T> GetAllValues(const ValuePtr &value) {
167   MS_EXCEPTION_IF_NULL(value);
168   std::vector<T> result;
169   if (value->isa<ValueSequence>()) {
170     auto value_seq = value->cast<ValueSequencePtr>();
171     MS_EXCEPTION_IF_NULL(value_seq);
172     for (const auto &elem : value_seq->value()) {
173       auto value_list = GetAllValues<T>(elem);
174       std::copy(value_list.begin(), value_list.end(), std::back_inserter(result));
175     }
176   } else {
177     result.emplace_back(GetValue<T>(value));
178   }
179   return result;
180 }
181 
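// Get the element TypeId of a value: the data type of a Tensor, the type id of a scalar,
// or the type of the first element of a ValueTuple/ValueList.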
182 TypeId GetElemType(const ValuePtr &value) {
183   MS_EXCEPTION_IF_NULL(value);
184   if (value->isa<tensor::Tensor>()) {
185     auto tensor_ptr = value->cast<tensor::TensorPtr>();
186     MS_EXCEPTION_IF_NULL(tensor_ptr);
187     return tensor_ptr->data_type();
188   }
189   if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
190     return value->type()->type_id();
191   }
192 
193   auto elems = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
194   if (elems.empty()) {
195     MS_LOG(EXCEPTION) << "Value:" << value->ToString() << " is empty, please check.";
196   }
197   return GetElemType(elems.at(0));
198 }
199 
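// Cast an int64-based value (tensor, sequence or scalar) to dst_type; returns nullptr when the conversion
// is unnecessary or unsupported.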
200 ValuePtr CastDstValue(const ValuePtr &value, const TypeId &dst_type) {
201   MS_EXCEPTION_IF_NULL(value);
202   auto src_type = GetElemType(value);
203   if (!IsValidConversion(src_type, dst_type)) {
204     return nullptr;
205   }
206   if (src_type == kNumberTypeInt64) {
207     if (value->isa<tensor::Tensor>()) {
208       auto tensor_ptr = value->cast<tensor::TensorPtr>();
209       MS_EXCEPTION_IF_NULL(tensor_ptr);
210       auto tensor_size = tensor_ptr->Size() / sizeof(int64_t);
211       int64_t *data = static_cast<int64_t *>(tensor_ptr->data_c());
212       std::vector<int32_t> v;
213       for (size_t i = 0; i < tensor_size; i++) {
214         (void)v.emplace_back(LongToInt(data[i]));
215       }
216       return MakeValue(v);
217     }
218     auto values = GetAllValues<int64_t>(value);
219     return CreateNewValue<int64_t>(value, values, dst_type);
220   } else {
221     MS_LOG(EXCEPTION) << "Invalid src type:" << value->type()->ToString();
222   }
223   return value;
224 }
225 
226 // If mark_fv is true, set the kIsFreeVariable flag for all free variables and their inputs.
227 AnfNodeWeakPtrList SuccIncludeFv(const FuncGraphPtr &fg, const AnfNodePtr &node, bool mark_fv = false) {
228   AnfNodeWeakPtrList vecs;
229   if (node == nullptr) {
230     return vecs;
231   }
232 
233   if (node->isa<CNode>()) {
234     auto cnode = node->cast<CNodePtr>();
235     bool is_fv = mark_fv && node->has_user_data(kIsFreeVariable);
236     auto &weak_inputs = cnode->weak_inputs();
237 
238     // Check whether free variables are used.
239     for (const auto &weak_input : weak_inputs) {
240       auto input = weak_input.lock();
241       MS_EXCEPTION_IF_NULL(input);
242       if (is_fv) {
243         input->set_user_data(kIsFreeVariable, std::make_shared<bool>(true));
244       }
245       auto input_fg = GetValueNode<FuncGraphPtr>(input);
246       if (input_fg) {
247         for (auto &fv : input_fg->free_variables_nodes()) {
248           if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
249             if (mark_fv) {
250               fv->set_user_data(kIsFreeVariable, std::make_shared<bool>(true));
251             }
252             (void)vecs.emplace_back(fv);
253           }
254         }
255       }
256     }
257 
258     (void)vecs.insert(vecs.end(), weak_inputs.begin(), weak_inputs.end());
259   }
260 
261   return vecs;
262 }
263 
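// Topologically sort the nodes of fg, starting from node (or from the graph's return node),
// treating free variables as successors.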
264 std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg, const AnfNodePtr node = nullptr) {
265   MS_EXCEPTION_IF_NULL(fg);
266   auto succ_include_fv = [&fg](const AnfNodePtr &node) -> AnfNodeWeakPtrList { return SuccIncludeFv(fg, node); };
267 
268   return (node == nullptr) ? TopoSort(fg->get_return(), succ_include_fv) : TopoSort(node, succ_include_fv);
269 }
270 
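// Collect the fullname_with_scope of every free variable reachable from the return node of fg,
// resetting the temporary kIsFreeVariable marks along the way.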
271 std::set<std::string> GetFvNames(const FuncGraphPtr fg) {
272   MS_EXCEPTION_IF_NULL(fg);
273   auto succ_include_fv = [&fg](const AnfNodePtr &node) -> AnfNodeWeakPtrList { return SuccIncludeFv(fg, node, true); };
274 
275   std::set<std::string> fvs;
276   auto nodes = TopoSort(fg->get_return(), succ_include_fv);
277   for (const auto &node : nodes) {
278     if (node->has_user_data(kIsFreeVariable)) {
279       node->set_user_data(kIsFreeVariable, std::shared_ptr<bool>(nullptr));
280       fvs.emplace(node->fullname_with_scope());
281     }
282   }
283 
284   return fvs;
285 }
286 
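// Compute the dynamic input number of a real input: taken from dyn_input_sizes when provided; for a dynamic
// Call input it is the input count minus the Primitive/kernel_graph (and optional monad) inputs;
// otherwise 1, or -1 when the input is not dynamic.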
287 int64_t GetDynInputNum(const OpAdapterPtr &adpt, bool is_call, std::vector<int64_t> dyn_input_sizes,
288                        size_t real_input_idx, size_t input_size, const CNodePtr &node) {
289   MS_EXCEPTION_IF_NULL(adpt);
290   MS_EXCEPTION_IF_NULL(node);
291   int64_t dyn_input_num = -1;
292   if (!dyn_input_sizes.empty()) {
293     dyn_input_num = dyn_input_sizes.at(real_input_idx - 1);
294   } else if (adpt->IsDynInputOp(real_input_idx)) {
295     if (is_call) {
296       auto &input = node->inputs().back();
297       // The first input of a Call node is the Primitive and the second is the kernel_graph; neither is a
298       // member of the input args, so dyn_input_num is reduced by 2 by default.
299       if (IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
300         // For PartitionedCall, the Monad should not be a member of the input args either, so dyn_input_num is reduced by 3 here.
301         dyn_input_num = SizeToLong(input_size) - 3;
302       } else {
303         dyn_input_num = SizeToLong(input_size) - 2;
304       }
305       return dyn_input_num;
306     }
307     dyn_input_num = 1;
308   }
309   return dyn_input_num;
310 }
311 
312 bool IsBranchNode(const AnfNodePtr &node) { return IsIfNode(node) || IsCaseNode(node); }
313 
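// Collect the argument inputs of a call node: for a kernel graph skip the first two inputs; otherwise expand
// the inputs of a CNode in position 0 and append the remaining call inputs.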
314 std::vector<AnfNodePtr> GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c_node) {
315   std::vector<AnfNodePtr> inputs;
316   if (is_kernel_graph) {
317     (void)std::copy(c_node->inputs().begin() + kInputOffset, c_node->inputs().end(), std::back_inserter(inputs));
318   } else {
319     if (c_node->input(0)->isa<CNode>()) {
320       auto in0 = c_node->input(0)->cast<CNodePtr>();
321       (void)std::copy(in0->inputs().begin() + kInputOffset, in0->inputs().end(), std::back_inserter(inputs));
322     }
323     (void)std::copy(c_node->inputs().begin() + 1, c_node->inputs().end(), std::back_inserter(inputs));
324   }
325   return inputs;
326 }
327 
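// Return true if any CNode in func_graph calls a subgraph.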
328 bool HasSubgraph(const std::shared_ptr<AnfGraph> &func_graph) {
329   auto node_list = TopoSort(func_graph->get_return());
330   for (auto &node : node_list) {
331     if (!utils::isa<CNodePtr>(node)) {
332       continue;
333     }
334     auto sub_graph = GetCNodeFuncGraph(node);
335     if (sub_graph != nullptr) {
336       return true;
337     }
338   }
339   return false;
340 }
341 
342 bool IsMakeTupleWithNullValue(const AnfNodePtr &node, const AnfNodePtr &input) {
343   MS_EXCEPTION_IF_NULL(input);
344   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && input->isa<ValueNode>()) {
345     auto type = input->Type();
346     MS_EXCEPTION_IF_NULL(type);
347     if (type->isa<Tuple>()) {
348       auto tuple_type = type->cast<std::shared_ptr<Tuple>>();
349       MS_EXCEPTION_IF_NULL(tuple_type);
350       if (tuple_type->elements().empty()) {
351         return true;
352       }
353     }
354   }
355   return false;
356 }
357 
358 bool IsMonad(const AnfNodePtr &node) {
359   return IsValueNode<UMonad>(node) || IsValueNode<IOMonad>(node) || HasAbstractMonad(node);
360 }
361 
362 bool IsOverFlowNode(const AnfNodePtr &node, const AnfNodePtr &input) {
363   return IsPrimitiveCNode(input, prim::kPrimNPUClearFloatStatusV2) ||
364          IsPrimitiveCNode(node, prim::kPrimNPUClearFloatStatusV2) ||
365          IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2);
366 }
367 
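// Choose the original format of a parameter by walking its users (through Load nodes) and returning the first
// non-default IO format found, falling back to the default format.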
368 std::string SelectParamOriFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
369   MS_EXCEPTION_IF_NULL(manager);
370   MS_EXCEPTION_IF_NULL(node);
371   std::deque<AnfNodePtr> todo{node};
372   while (!todo.empty()) {
373     auto &curr_node = todo.front();
374     todo.pop_front();
375     const auto &nodes = manager->node_users()[curr_node];
376     for (const auto &node_pair : nodes) {
377       if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
378         todo.emplace_back(node_pair.first);
379       } else if (node_pair.first->isa<CNode>()) {
380         auto visited_format = GetOpIOFormat(node_pair.first);
381         if (visited_format != kOpFormat_DEFAULT) {
382           return visited_format;
383         }
384       }
385     }
386   }
387   return kOpFormat_DEFAULT;
388 }
389 
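// For each GE input index, look up the corresponding MS input, record its dyn_input size in
// new_dyn_input_sizes, and return the starting MS tensor offset of every GE input.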
390 std::vector<int> GetGeTensorOrders(const mindspore::HashMap<int, int> &ge_input_to_ms_input,
391                                    const std::vector<int64_t> &dyn_input_sizes, const int &ge_input_size,
392                                    std::vector<int64_t> *new_dyn_input_sizes) {
393   std::vector<int> ge_tensor_orders(ge_input_size, -1);
394   for (int ge_idx = 0; ge_idx < ge_input_size; ++ge_idx) {
395     int ms_idx = ge_input_to_ms_input.at(ge_idx);
396     new_dyn_input_sizes->at(ge_idx) = dyn_input_sizes[ms_idx];
397     int begin_idx = 0;
398     for (int i = 0; i < ms_idx; ++i) {
399       begin_idx += dyn_input_sizes[i] == -1 ? 1 : dyn_input_sizes[i];
400     }
401     ge_tensor_orders[ge_idx] = begin_idx;
402   }
403   return ge_tensor_orders;
404 }
405 
406 bool IsNeedToUpdateTensorDesc(const std::string &op_type, const AnfNodePtr &node) {
407   // When IdentityN's input is a Function or another IdentityN, the GE type that maps to the MS type cannot be
408   // found, which produces ERROR logs that do not affect the result. So there is no need to set the OutputDesc
409   // of IdentityN; GE can infer it, e.g. MakeTuple-->MakeTuple. The output node should still set its OpDesc.
410   if (op_type == kTypeIdentityN && !IsPrimitiveCNode(node, prim::kPrimReturn)) {
411     MS_LOG(DEBUG) << "No need to set the OpDesc of Identity except return, node: " << node->fullname_with_scope();
412     return false;
413   }
414   // NoOp has no output, so there is no need to set its OutputDesc.
415   if (op_type == kTypeNoOp) {
416     MS_LOG(DEBUG) << "No need to set the OpDesc of NoOp, node: " << node->fullname_with_scope();
417     return false;
418   }
419   return true;
420 }
421 
422 template <typename T>
423 void SetXDataIndex(const OperatorPtr &op, T idx) {
424   MS_EXCEPTION_IF_NULL(op);
425   op->SetAttr(kTypeIndex, static_cast<int64_t>(idx));
426 }
427 
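// Ordering used to sort parameters: names missing from params come first, Momentum accum parameters come last,
// and remaining ties are broken lexicographically.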
428 bool ParamCompare(const std::string &l, const std::string &r, const mindspore::HashMap<std::string, AnfNodePtr> &params,
429                   const NodeUsersMap &node_users) {
430   auto lpram_iter = params.find(l);
431   auto rpram_iter = params.find(r);
432   if (lpram_iter == params.end() && rpram_iter == params.end()) {
433     return l.compare(r) < 0;
434   } else if (lpram_iter == params.end()) {
435     return true;
436   } else if (rpram_iter == params.end()) {
437     return false;
438   }
439 
440   bool lused_as_accum = (GetMomentumVarByAccum(lpram_iter->second, node_users) != nullptr);
441   bool rused_as_accum = (GetMomentumVarByAccum(rpram_iter->second, node_users) != nullptr);
442   if (lused_as_accum ^ rused_as_accum) {
443     return rused_as_accum;
444   }
445 
446   return l.compare(r) < 0;
447 }
448 
449 bool IsESNodeWithNoOutput(const AnfNodePtr &node) {
450   const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> no_output_prims = {
451     prim::kPrimInitPartitionMap,          prim::kPrimInitEmbeddingHashmap,      prim::kPrimEmbeddingTableImport,
452     prim::kPrimEmbeddingComputeVarExport, prim::kPrimEmbeddingComputeVarImport, prim::kPrimEmbeddingTableExport};
453   if (IsOneOfPrimitiveCNode(node, no_output_prims)) {
454     return true;
455   }
456   return false;
457 }
458 
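// Extract the outputs of an EmbeddingApplyAdam pattern: input(1) must be Depend(real_output, MakeTuple(...));
// returns the real output followed by the MakeTuple elements.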
459 std::vector<AnfNodePtr> GetEmbeddingApplyAdamOutput(const CNodePtr &node) {
460   MS_EXCEPTION_IF_NULL(node);
461   std::vector<AnfNodePtr> ret_nodes;
462   auto depend = node->input(1);
463   MS_EXCEPTION_IF_NULL(depend);
464   if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
465     MS_LOG(EXCEPTION) << "Need Depend ops, but get " << depend->fullname_with_scope();
466   }
467   auto depend_cnode = depend->cast<CNodePtr>();
468   auto tuple = depend_cnode->input(2);
469   MS_EXCEPTION_IF_NULL(tuple);
470   if (!IsPrimitiveCNode(tuple, prim::kPrimMakeTuple)) {
471     MS_LOG(EXCEPTION) << "Need MakeTuple ops, but get " << tuple->fullname_with_scope();
472   }
473   auto tuple_cnode = tuple->cast<CNodePtr>();
474   auto output_nodes = tuple_cnode->inputs();
475   ret_nodes.emplace_back(depend_cnode->input(1));
476   ret_nodes.insert(ret_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
477   return ret_nodes;
478 }
479 }  // namespace
480 
481 DfGraphPtr GenExampleGraph(const std::string &name) {
482   MS_LOG(INFO) << "Gen example graph name is " << name;
483   auto graph = std::make_shared<DfGraph>(name);
484   MS_EXCEPTION_IF_NULL(graph);
485   auto shape_data = std::vector<int64_t>({1, 1, 1, 1});
486   GeTensorDesc desc_data(ge::Shape(shape_data), ge::FORMAT_ND, ge::DT_FLOAT16);
487   auto data = ge::op::Data("data");
488   data.set_attr_index(0);
489   data.update_input_desc_x(desc_data);
490   data.update_output_desc_y(desc_data);
491   auto abs = ge::op::Abs("abs").set_input_x(data);
492   std::vector<Operator> inputs{data};
493   std::vector<Operator> outputs{abs};
494   graph->SetInputs(inputs);
495   graph->SetOutputs(outputs);
496   return graph;
497 }
498 
499 // --------------- implementation of DfGraphConvertor -------------
500 
501 bool IsDynamicShapeNode(const AnfNodePtr node) {
502   auto shape = node->Shape();
503   if (shape == nullptr) {
504     return false;
505   }
506   if (!shape->isa<abstract::Shape>()) {  // do not accept tuple shape as call node input
507     return false;
508   }
509   if (shape->IsDynamic()) {
510     return true;
511   }
512   return false;
513 }
514 
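// Create the npu_runconfig loop variables (iterations_per_loop, loop_cond, one, zero) together with their
// constants and assign ops for training in dataset sink mode; returns true when the sink size repeats the
// previous value.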
515 bool DfGraphConvertor::InitLoopVar(std::vector<::ge::Operator> *init_input) {
516   MS_EXCEPTION_IF_NULL(init_input);
517   if (!this->training_) {
518     return false;
519   }
520   bool is_sink_size_repeat = false;
521   auto ms_context = MsContext::GetInstance();
522   MS_EXCEPTION_IF_NULL(ms_context);
523   int64_t value = 0;
524   if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
525     static int64_t sink_size = 0;
526     if (!ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
527       return false;
528     }
529     value = ConfigManager::GetInstance().iter_num();
530     if (sink_size == value) {
531       is_sink_size_repeat = true;
532     }
533     sink_size = value;
534   } else {
535     MS_LOG(INFO) << "Run with normal (non-sink) mode, the iteration number will always be 1";
536     ConfigManager::GetInstance().ResetIterNum();
537     return false;
538   }
539   GeTensorDesc desc(GeShape(), ::ge::FORMAT_NCHW, ::ge::DT_INT64);
540   auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
541   auto var_loop_cond = std::make_shared<Variable>("npu_runconfig/loop_cond");
542   auto var_one = std::make_shared<Variable>("npu_runconfig/one");
543   auto var_zero = std::make_shared<Variable>("npu_runconfig/zero");
544   (void)var_iter_num->update_output_desc_y(desc);
545   (void)var_loop_cond->update_output_desc_y(desc);
546   (void)var_one->update_output_desc_y(desc);
547   (void)var_zero->update_output_desc_y(desc);
548   vars_["npu_runconfig/iterations_per_loop"] = var_iter_num;
549   vars_["npu_runconfig/loop_cond"] = var_loop_cond;
550   vars_["npu_runconfig/one"] = var_one;
551   vars_["npu_runconfig/zero"] = var_zero;
552   auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
553   value -= 1;  // iterations start from 0, so the max iteration number for n loops is n-1
554   (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
555 
556   auto const_loop_cond = std::make_shared<Constant>("const/npu_runconfig/loop_cond");
557   value = 0;
558   (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
559 
560   auto const_one = std::make_shared<Constant>("const/npu_runconfig/one");
561   value = 1;
562   (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
563 
564   auto const_zero = std::make_shared<Constant>("const/npu_runconfig/zero");
565   value = 0;
566   (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
567 
568   (void)const_iter_num->update_output_desc_y(desc);
569   (void)const_loop_cond->update_output_desc_y(desc);
570   (void)const_one->update_output_desc_y(desc);
571   (void)const_zero->update_output_desc_y(desc);
572 
573   auto assign_iter_num = std::make_shared<Assign>("assign/npu_runconfig/iterations_per_loop");
574   (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num);
575   auto assign_loop_cond = std::make_shared<Assign>("assign/npu_runconfig/loop_cond");
576   (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond);
577   auto assign_one = std::make_shared<Assign>("assign/npu_runconfig/one");
578   (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one);
579   auto assign_zero = std::make_shared<Assign>("assign/npu_runconfig/zero");
580   (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero);
581 
582   init_input->emplace_back(*var_iter_num);
583   init_input->emplace_back(*var_loop_cond);
584   init_input->emplace_back(*var_one);
585   init_input->emplace_back(*var_zero);
586   init_ops_.emplace_back(var_iter_num);
587   init_ops_.emplace_back(var_loop_cond);
588   init_ops_.emplace_back(var_one);
589   init_ops_.emplace_back(var_zero);
590   init_ops_.emplace_back(const_iter_num);
591   init_ops_.emplace_back(const_loop_cond);
592   init_ops_.emplace_back(const_one);
593   init_ops_.emplace_back(const_zero);
594   init_ops_.emplace_back(assign_iter_num);
595   init_ops_.emplace_back(assign_loop_cond);
596   init_ops_.emplace_back(assign_one);
597   init_ops_.emplace_back(assign_zero);
598   return is_sink_size_repeat;
599 }
600 
601 void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) {
602   // draw init subgraph
603   init_sout_ << "op_assign" << it.get() << "[label=<";
604   init_sout_ << "<table border='1' cellborder='1'>" << endl;
605   init_sout_ << "<tr>";
606   init_sout_ << "<td port='1'>resource</td>";
607   init_sout_ << "<td port='2'>value</td>";
608   init_sout_ << "</tr>" << endl;
609   init_sout_ << "<tr><td colspan=\"2\">"
610              << "\"assign_" << name << "\"</td></tr>" << endl;
611   init_sout_ << "</table>> shape=plaintext]" << endl;
612   init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl;
613   init_sout_ << "const" << it.get() << "[label= \"" << name << "_const"
614              << "\" shape=ellipse]" << endl;
615   init_sout_ << "param" << it.get() << "->"
616              << "op_assign" << it.get() << ":1" << endl;
617   init_sout_ << "const" << it.get() << "->"
618              << "op_assign" << it.get() << ":2" << endl;
619 }
620 
621 void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors,
622                                               const std::vector<::ge::Operator> *const init_input,
623                                               bool is_sink_size_repeat) {
624   DfGraphPtr init_graph = std::make_shared<DfGraph>(kInit);
625   MS_EXCEPTION_IF_NULL(init_graph);
626   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
627 
628   for (auto &it : nodes) {
629     MS_EXCEPTION_IF_NULL(it);
630     if (it->isa<ValueNode>()) {
631       if (IsValueNode<SymbolicKeyInstance>(it)) {
632         auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
633         auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
634         auto iter = vars_.find(name);  // get corresponding variable op
635         if (iter != vars_.end()) {
636           op_cache_[it.get()] = iter->second;
637           // #ifdef DRAW_GE_GRAPH
638           compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
639                         << "[style=\"dotted\"]" << endl;
640           // #endif
641         }
642       } else if (IsValueNode<RefKey>(it)) {
643         auto refkey = GetValueNode<StringImmPtr>(it);
644         MS_EXCEPTION_IF_NULL(refkey);
645         auto name = refkey->value();
646         auto iter = vars_.find(name);  // get corresponding variable op
647         if (iter != vars_.end()) {
648           op_cache_[it.get()] = iter->second;
649           compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
650                         << "[style=\"dotted\"]" << endl;
651         }
652       }
653     }
654   }
655 
656   for (auto &it : tensors) {
657     if (vars_.find(it.first) == vars_.end()) {
658       MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph.";
659       vars_[it.first] = nullptr;
660     }
661   }
662 
663   // set up init sub graph
664   MS_EXCEPTION_IF_NULL(init_input);
665   if (!init_input->empty() && !is_sink_size_repeat) {
666     // init sub graph needs no input
667     MS_LOG(INFO) << "Build data init subgraph.";
668     (void)init_graph->SetInputs(*init_input);
669     this->init_graph_ = init_graph;
670   } else {
671     this->init_graph_ = nullptr;
672   }
673 }
674 
675 void DfGraphConvertor::SetupParamInitSubGraph() {
676   DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
677   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
678   MS_EXCEPTION_IF_NULL(init_graph);
679 
680   for (auto &it : nodes) {
681     MS_EXCEPTION_IF_NULL(it);
682     if (it->isa<ValueNode>()) {
683       if (IsValueNode<SymbolicKeyInstance>(it)) {
684         auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
685         MS_EXCEPTION_IF_NULL(symbolic);
686         MS_EXCEPTION_IF_NULL(std::static_pointer_cast<Parameter>(symbolic->node()));
687         auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
688         auto iter = vars_.find(name);  // get corresponding variable op
689         if (iter != vars_.end()) {
690           op_cache_[it.get()] = iter->second;
691         }
692       } else if (IsValueNode<RefKey>(it)) {
693         auto refkey = GetValueNode<StringImmPtr>(it);
694         MS_EXCEPTION_IF_NULL(refkey);
695         auto name = refkey->value();
696         auto iter = vars_.find(name);  // get corresponding variable op
697         if (iter != vars_.end()) {
698           op_cache_[it.get()] = iter->second;
699         }
700       }
701     }
702   }
703 
704   // set up init sub graph
705   std::vector<::ge::Operator> init_input;
706   bool is_sink_size_repeat = InitLoopVar(&init_input);
707   if (!init_input.empty() && !is_sink_size_repeat) {
708     // init sub graph needs no input
709     MS_LOG(INFO) << "Build data init subgraph.";
710     (void)init_graph->SetInputs(init_input);
711     this->init_graph_ = init_graph;
712   } else {
713     this->init_graph_ = nullptr;
714   }
715 }
716 
717 void DfGraphConvertor::SetupBroadcast(const OperatorPtr &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
718                                       const DfGraphPtr &broadcast_graph, std::vector<::ge::Operator> broadcast_input) {
719   MS_LOG(INFO) << "build broadcast subgraph";
720   if (broadcast_desc.size() != broadcast_input.size()) {
721     MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to the number of inputs";
722   }
723   MS_EXCEPTION_IF_NULL(broadcast);
724   (void)broadcast->DynamicInputRegister(kTypeX, (static_cast<unsigned int>(broadcast_input.size())));
725   (void)broadcast->DynamicOutputRegister(kTypeY, static_cast<unsigned int>(broadcast_desc.size()));
726   for (unsigned int i = 0; i < broadcast_input.size(); i++) {
727     (void)broadcast->SetInput(kTypeX, i, broadcast_input[i]);
728     (void)broadcast->UpdateDynamicOutputDesc(kTypeY, i, broadcast_desc[i]);
729   }
730   MS_EXCEPTION_IF_NULL(broadcast_graph);
731   (void)broadcast_graph->SetInputs(broadcast_input);
732   this->broadcast_graph_ = broadcast_graph;
733 }
734 
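// Check whether the device data of a node must stay updatable: true for extra variables and for nodes consumed
// by in-place ops such as Assign, ScatterUpdate and the KV-cache ops (including the AKG-fused ReshapeAndCache
// custom op).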
735 bool DfGraphConvertor::NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
736   if (manager == nullptr || node == nullptr) {
737     MS_LOG(ERROR) << "Input argument manager or node is nullptr";
738     return false;
739   }
740   if (offline_convert_) {
741     return false;
742   }
743   if (std::find(extra_variables_names_.begin(), extra_variables_names_.end(), node->fullname_with_scope()) !=
744       extra_variables_names_.end()) {
745     return true;
746   }
747   const auto &node_users = manager->node_users();
748   std::vector<PrimitivePtr> vec{
749     prim::kPrimAssign,        prim::kPrimKVCacheMgr,     prim::kPrimScatterUpdate,       prim::kPrimScatterNdUpdate,
750     prim::kPrimPromptKVCache, prim::kPrimDecoderKVCache, prim::kPrimKVCacheScatterUpdate};
751   auto user_it = node_users.find(node);
752   if (user_it != node_users.end()) {
753     auto &users = user_it->second;
754     for (auto &user_node : users) {
755       auto &node_use = user_node.first;
756       if (node_use && std::any_of(vec.begin(), vec.end(),
757                                   [&node_use](const PrimitivePtr &prim) { return IsPrimitiveCNode(node_use, prim); })) {
758         return true;
759       }
760       // check if node is ReshapeAndKVCache which is fused by akg.
761       if (IsPrimitiveCNode(node_use, prim::kPrimCustom)) {
762         auto prim_custom = GetCNodePrimitive(node_use);
763         const std::string kAttrNameInfoPath = "info_path";
764 
765         if (!prim_custom->HasAttr(kAttrNameInfoPath)) {
766           continue;
767         }
768         auto info_path_attr_node = prim_custom->GetAttr(kAttrNameInfoPath);
769         if (info_path_attr_node == nullptr) {
770           MS_LOG(EXCEPTION) << "attr node '" << kAttrNameInfoPath << "' is null";
771         }
772         std::string info_path = GetValue<std::string>(info_path_attr_node);
773         const std::string kOpReshapeAndCache = "ReshapeAndCache";
774         if (info_path.find(kOpReshapeAndCache) == std::string::npos) {
775           continue;
776         }
777 
778         MS_LOG(INFO) << "Found ReshapeAndCache, keep the node input updated";
779         return true;
780       }
781     }
782   }
783   return false;
784 }
785 
786 void DfGraphConvertor::JudgeParamTransType(const bool &node_will_update, bool *as_ref_data, bool *as_constant) const {
787   if (ref_mode_) {
788     if ((ref_mode_type_ == RefModeFlag::kRefModeAll || node_will_update) && !export_air_) {
789       *as_ref_data = true;
790     } else {  // When only variables are treated as RefData, a constant Parameter is treated as a Constant
791       *as_constant = true;
792     }
793   } else if (!training_ && !node_will_update) {
794     // A parameter that will be updated is treated as a variable in lite inference mode; otherwise it is treated as a constant.
795     *as_constant = true;
796   }
797 }
798 
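// Convert each init parameter to a RefData, Constant or Variable operator (depending on ref mode and whether the
// parameter will be updated), and build the init subgraph inputs for parameters that need initialization.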
799 void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
800   int index = 0;
801   std::vector<Operator> init_input;
802   MS_EXCEPTION_IF_NULL(graph_manager_);
803   // The format of Momentum's accum is updated according to the format of Momentum's var, so tensors are sorted
804   // here to put Momentum's accum parameters last.
805   auto cmp = std::bind(ParamCompare, std::placeholders::_1, std::placeholders::_2, std::cref(params_),
806                        graph_manager_->node_users());
807   std::map<std::string, std::pair<int, tensor::TensorPtr>, decltype(cmp)> ordered_tensors(cmp);
808   // NOTE: the sequence of parameters of init DfGraph is calculated by TensorOrderMap, see method `GetInputTensors`
809   // defined in `mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_graph_executor.cc`
810   for (auto &it : tensors) {
811     ordered_tensors.insert({it.first, {index++, it.second}});
812   }
813   for (const auto &itor : ordered_tensors) {
814     std::string name = itor.first;
815     auto &it = itor.second;
816     auto node_itor = params_.find(name);
817     // if name not in params_, create a node in graph
818     if (node_itor == params_.end()) {
819       // In lite, the param may not exist.
820       MS_LOG(WARNING) << name << " is not in params, creating a new node.";
821       ParameterPtr param = std::make_shared<Parameter>(nullptr);
822       MS_EXCEPTION_IF_NULL(param);
823       if (!ref_mode_) {
824         name += "_temp";
825       }
826       param->set_name(name);
827       (void)ConvertParameter(param);
828       node_itor = params_.find(name);
829     }
830     auto node = node_itor->second;
831     MS_EXCEPTION_IF_NULL(node);
832     auto op_itor = op_cache_.find(node.get());
833     if (op_itor == op_cache_.end()) {
834       MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << ".";
835     }
836 
837     MS_EXCEPTION_IF_NULL(it.second);
838     bool as_ref_data = false;
839     bool as_constant = false;
840     auto node_will_update = NodeInputKeepUpdate(graph_manager_, node);
841     JudgeParamTransType(node_will_update, &as_ref_data, &as_constant);
842 
843     auto shape = it.second->shape_c();
844     if (as_ref_data && dyn_ref_data_func_ != nullptr) {
845       shape = dyn_ref_data_func_(node, shape);
846     }
847     auto desc =
848       TransformUtil::GetGeTensorDesc(shape, it.second->data_type(), SelectParamOriFormat(graph_manager_, node));
849     if (desc == nullptr) {
850       MS_LOG(WARNING) << "Create const " << name << " output descriptor failed!";
851       continue;
852     }
853     if (as_ref_data) {
854       StorageFormatConvertor::SetupStorageFormat(anf_graph_, node, desc);
855       auto ref_data = std::make_shared<RefData>(name);
856       MS_EXCEPTION_IF_NULL(ref_data);
857       (void)ref_data->update_output_desc_y(*desc);
858       (void)ref_data->update_input_desc_x(*desc);
859       (void)ref_data->set_attr_index(SizeToInt(ref_datas_.size()));
860       (void)ref_datas_.emplace_back(ref_data);
861       ref_data_names_.emplace_back(name);
862       // do not use read ref_data while ref_data sink
863       MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << ref_data->GetName() << ".";
864       op_itor->second = ref_data;  // replace parameter with ref_data
865       vars_[name] = ref_data;      // prevent the ref_data operator from being freed
866     } else if (as_constant) {
867       auto adpt_const = FindAdapter(kNameConst, training_);
868       if (adpt_const == nullptr) {
869         continue;
870       }
871       auto const_op = adpt_const->generate(name + "_const");
872       (void)adpt_const->setAttr(const_op, "value", it.second);
873       const_op->UpdateOutputDesc(kTypeY, *desc);
874       const_op_to_value_[const_op] = it.second;
875       vars_[name] = const_op;
876       op_itor->second = const_op;
877     } else {
878       auto &infer_need_update_parameter_names =
879         Singleton<mindspore::device::ascend::InferNeedUpdateParaNames>::Instance().GetInferParameterNames();
880       // we need three variable ops for each graph with same name
881       // build init subgraph
882       auto adpt = FindAdapter(kNameParam, training_);
883       if (adpt == nullptr) {
884         continue;
885       }
886       auto param_op = adpt->generate(name + "_data");
887       if (it.second->is_init() == 0) {
888         SetXDataIndex(param_op, it.first);
889         ProcessInputData(&init_input, &infer_need_update_parameter_names, param_op, name, desc);
890       }
891 
892       auto variable = std::make_shared<Variable>(name);
893       MS_EXCEPTION_IF_NULL(variable);
894       (void)variable->update_output_desc_y(*desc);
895       // do not use read variable while variable sink
896       MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << ".";
897       op_itor->second = variable;  // replace parameter with variable
898       vars_[name] = variable;      // prevent the variable operator from being freed
899       DrawParamInitSubGraph(name, node);
900     }
901   }
902   ReplaceAllParameterToRefData();
903   if (ref_mode_) {
904     SetupParamInitSubGraph();
905   } else {
906     bool is_sink_size_repeat = InitLoopVar(&init_input);
907     SetupParamInitSubGraph(tensors, &init_input, is_sink_size_repeat);
908   }
909 }
910 
911 void DfGraphConvertor::ReplaceAllParameterToRefData() {
912   if (ref_mode_ && (ref_mode_type_ == RefModeFlag::kRefModeAll) && !export_air_) {
913     MS_LOG(INFO) << "Graph abstract ref tensor to ref data, " << anf_graph_->ToString();
914     auto parameters = anf_graph_->parameters();
915     int64_t idx = 0;
916     for (const auto &param : parameters) {
917       auto op_itor = op_cache_.find(param.get());
918       if (op_itor != op_cache_.end() && op_itor->second->GetOpType() == kTypeRefData) {
919         MS_LOG(INFO) << "This param has a default value and has already been changed to RefData: " << param->fullname_with_scope();
920         continue;
921       }
922       auto para = param->cast<ParameterPtr>();
923       MS_EXCEPTION_IF_NULL(para);
924       auto abs = para->abstract();
925       MS_EXCEPTION_IF_NULL(abs);
926       if (!abs->isa<abstract::AbstractRefTensor>()) {
927         continue;
928       }
929       MS_EXCEPTION_IF_NULL(abs->BuildShape());
930       auto shape = abs->BuildShape()->GetShapeVector();
931       auto type = abs->BuildType()->type_id();
932       if (type == kObjectTypeTensorType) {
933         type = dyn_cast<TensorType>(abs->BuildType())->element()->type_id();
934       }
935       auto name = para->name();
936       if (name.empty()) {
937         name = "RefData_NULL_" + std::to_string(idx++);
938       }
939       auto ref_data = std::make_shared<RefData>(name);
940       MS_EXCEPTION_IF_NULL(ref_data);
941       auto desc = TransformUtil::GetGeTensorDesc(shape, type, SelectParamOriFormat(graph_manager_, para));
942       if (!desc) {
943         MS_LOG(ERROR) << "Create ge node desc failed, node name:" << name << ", shape: " << shape << ", type: " << type;
944         continue;
945       }
946       (void)ref_data->update_output_desc_y(*desc);
947       (void)ref_data->update_input_desc_x(*desc);
948       (void)ref_data->set_attr_index(SizeToInt(ref_datas_.size()));
949       (void)ref_datas_.emplace_back(ref_data);
950       ref_data_names_.emplace_back(name);
951       // do not use read ref_data while ref_data sink
952       MS_LOG(INFO) << "Change param without default value: " << name << " to ref data.";
953       op_itor->second = ref_data;  // replace parameter with ref_data
954       vars_[name] = ref_data;      // prevent the ref_data operator from being freed
955     }
956   }
957 }
958 
959 void DfGraphConvertor::ProcessInputData(vector<Operator> *init_input,
960                                         std::unordered_set<std::string> *infer_need_update_parameter_names,
961                                         const OperatorPtr &param_op, const string &name,
962                                         const std::shared_ptr<GeTensorDesc> &desc) {
963   MS_EXCEPTION_IF_NULL(init_input);
964   MS_EXCEPTION_IF_NULL(infer_need_update_parameter_names);
965   auto init_var = std::make_shared<Variable>(name);
966   auto assign_op = std::make_shared<Assign>("assign_" + name);
967   MS_EXCEPTION_IF_NULL(init_var);
968   MS_EXCEPTION_IF_NULL(assign_op);
969   (void)init_var->update_output_desc_y(*desc);
970   (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
971   init_input->emplace_back(*init_var);
972   this->init_ops_.emplace_back(param_op);
973   this->init_ops_.emplace_back(assign_op);
974   this->init_ops_.emplace_back(init_var);
975   this->init_data_names_.emplace_back(name);
976   infer_need_update_parameter_names->insert(name);
977 }
978 
979 // Convert all parameters that need initialization to variables.
980 DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
981   if (error_ != SUCCESS) {
982     return *this;
983   }
984   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
985     error_ = INVALID_ARGUMENT;
986     MS_LOG(ERROR) << "Invalid AnfGraph in InitParam.";
987     return *this;
988   }
989 
990   InitParamWithData(tensors);
991   init_sout_ << "}" << endl;
992   return *this;
993 }
994 
995 #if (defined ENABLE_D)
996 void DfGraphConvertor::BuildSaveCheckpointGraph() {
997   std::vector<Operator> graph_inputs;
998   ::ge::op::Save save_op("save_parms");
999   int save_op_is_active = 0;
1000   size_t index = 0;
1001   string name;
1002 
1003   auto count_size = std::count_if(vars_.begin(), vars_.end(), [](const auto &it) {
1004     return LongToUlong(it.second == nullptr || it.first.find("/") != std::string::npos);
1005   });
1006 
1007   (void)save_op.create_dynamic_input_tensors(static_cast<uint32_t>(vars_.size() - static_cast<size_t>(count_size)));
1008 
1009   // for each "parameter" in anf graph excluding "input"
1010   for (const auto &it : vars_) {
1011     name = it.first;
1012     if (it.second == nullptr || name.find("/") != std::string::npos) {
1013       continue;
1014     }
1015     Variable variable(name);
1016     (void)variable.update_output_desc_y(it.second->GetOutputDesc(0));
1017     (void)save_op.set_dynamic_input_tensors(static_cast<uint32_t>(index++), variable);
1018 
1019     graph_inputs.emplace_back(variable);
1020 
1021     if (save_op_is_active == 0) {
1022       checkpoint_sout_ << "op_save" << &save_op << "[label=<";
1023       checkpoint_sout_ << "<table border='1' cellborder='1'>" << endl;
1024       checkpoint_sout_ << "<tr><td port='1'>tensor</td></tr>" << endl;
1025       checkpoint_sout_ << "<tr><td colspan=\"1\">"
1026                        << "\"saveop"
1027                        << "\"</td></tr>" << endl;
1028       checkpoint_sout_ << "</table>> shape=plaintext]" << endl;
1029     }
1030 
1031     checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl;
1032 
1033     checkpoint_sout_ << "param" << it.second << "->"
1034                      << "op_save" << &save_op << ":1" << endl;
1035     save_op_is_active = 1;
1036   }
1037   if (save_op_is_active != 0) {
1038     std::vector<Operator> graph_output;
1039     (void)graph_output.emplace_back(save_op);
1040     DfGraphPtr checkpoint_graph = std::make_shared<DfGraph>("checkpoint");
1041     (void)checkpoint_graph->SetInputs(graph_inputs);
1042     (void)checkpoint_graph->SetOutputs(graph_output);
1043     this->save_ckp_graph_ = checkpoint_graph;
1044   } else {
1045     this->save_ckp_graph_ = nullptr;
1046   }
1047 
1048   checkpoint_sout_ << "}" << endl;
1049   return;
1050 }
1051 #endif
1052 
1053 DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) {
1054   if (error_ != SUCCESS) {
1055     return *this;
1056   }
1057   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1058     error_ = INVALID_ARGUMENT;
1059     MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph";
1060     return *this;
1061   }
1062 
1063   DfGraphPtr broadcast_graph = std::make_shared<DfGraph>(kBroadcast);
1064   // collect the operators created for the broadcast subgraph, in order to avoid auto release
1065   std::vector<Operator> broadcast_input;
1066   std::vector<GeTensorDesc> broadcast_desc;
1067   auto adpt = FindAdapter(kNameBroadcast);
1068   if (!adpt) {
1069     MS_LOG(EXCEPTION) << "Get adpt failed, node type: HcomBroadcast";
1070   }
1071   auto broadcast = adpt->generate("broadcast_parameter");
1072   const int64_t root_rank_v = 0;
1073   (void)broadcast->SetAttr("root_rank", root_rank_v);
1074   (void)broadcast->SetAttr("group", "hccl_world_group");
1075   broadcast_ops_.emplace_back(broadcast);
1076 
1077   // find every parameter, build broadcast subgraph (or initialize the parameter with constant)
1078   for (auto &it : anf_graph_->parameters()) {
1079     auto op_itor = op_cache_.find(it.get());  // converted node
1080     if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
1081       string name = std::static_pointer_cast<Parameter>(it)->name();
1082       auto tensor_itor = tensors.find(name);  // in init tensor map
1083       if (tensor_itor != tensors.end()) {
1084         auto tensor = tensor_itor->second;
1085         auto shape_ge = tensor->shape_c();
1086 
1087         // create tensor descriptor for output descriptor
1088         auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_DEFAULT);
1089         if (desc == nullptr) {
1090           MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
1091           continue;
1092         }
1093 
1094         // build broadcast subgraph
1095         if (distribute_) {
1096           auto broadcast_var = std::make_shared<Variable>(name);
1097           (void)broadcast_var->update_output_desc_y(*desc);
1098           broadcast_input.emplace_back(*broadcast_var);
1099           broadcast_desc.emplace_back(*desc);
1100           broadcast_ops_.emplace_back(broadcast_var);
1101         }
1102       }
1103     }
1104   }
1105 
1106   // set up broadcast sub graph
1107   if (!broadcast_input.empty()) {
1108     DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input);
1109   } else {
1110     this->broadcast_graph_ = nullptr;
1111   }
1112   return *this;
1113 }
1114 
1115 DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
1116   if (error_ != SUCCESS) {
1117     MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << ".";
1118     if (!unsupported_ops_names_.empty()) {
1119       MS_LOG(ERROR) << "===========================================";
1120       MS_LOG(ERROR) << unsupported_ops_names_.size() << " Operator(s) cannot be converted:";
1121       std::string unsupported_ops_list;
1122       for (const auto &unsupported_ops : unsupported_ops_names_) {
1123         if (!unsupported_ops_list.empty()) {
1124           unsupported_ops_list += ", ";
1125         }
1126         unsupported_ops_list += unsupported_ops;
1127       }
1128       MS_LOG(ERROR) << "Unsupported op type list: " << unsupported_ops_list;
1129       MS_LOG(ERROR) << "===========================================";
1130     }
1131     return *this;
1132   }
1133   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1134     error_ = INVALID_ARGUMENT;
1135     MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph";
1136     return *this;
1137   }
1138 #ifdef ENABLE_D
1139   auto ms_context = MsContext::GetInstance();
1140   MS_EXCEPTION_IF_NULL(ms_context);
1141   if (ms_context->backend_policy() == "ge") {
1142     BuildSaveCheckpointGraph();
1143     // Restoring from checkpoint file is done by pyfront, not in graph now.
1144   }
1145 #endif
1146   return *this;
1147 }
1148 
1149 DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
1150   if (error_ != SUCCESS) {
1151     return *this;
1152   }
1153   if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1154     MS_LOG(ERROR) << "Invalid AnfGraph";
1155     error_ = FAILED;
1156     return *this;
1157   }
1158 
1159   compute_sout_.clear();
1160   compute_sout_ << "digraph {" << endl;
1161   init_sout_.clear();
1162   init_sout_ << "digraph {" << endl;
1163 #ifdef ENABLE_D
1164   auto ms_context = MsContext::GetInstance();
1165   MS_EXCEPTION_IF_NULL(ms_context);
1166   if (ms_context->backend_policy() == "ge") {
1167     checkpoint_sout_.clear();
1168     checkpoint_sout_ << "digraph {" << endl;
1169   }
1170 #endif
1171   restore_checkpoint_sout_.clear();
1172   restore_checkpoint_sout_ << "digraph {" << endl;
1173   // Trans data_type for some specific nodes' inputs and attr.
1174   TransDataType(anf_graph_);
1175   // Convert all anf node to Operator
1176   MS_LOG(INFO) << "Convert all node, graph: " << anf_graph_->ToString();
1177   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_, while_cond_node_);
1178   if (ref_mode_) {
1179     // Ref mode need build all node(cnode && parameter).
1180     for (auto &p : anf_graph_->parameters()) {
1181       if (std::find(nodes.begin(), nodes.end(), p) == nodes.end()) {
1182         MS_LOG(INFO) << "Parameter " << p->DebugString() << " cannot be found in the topo sort list.";
1183         nodes.emplace_back(p);
1184       }
1185     }
1186   }
1187   for (auto &it : nodes) {
1188     if (IsSubGraph() && it->isa<Parameter>()) {
1189       continue;
1190     }
1191     if (IsSubGraph() && (IsPartialSuccNode(it) || IsPartialCNode(it))) {
1192       continue;
1193     }
1194     (void)Convert(it);
1195     if (this->error_ != SUCCESS) {
1196       MS_LOG(ERROR) << "Failed to convert node: " << it->DebugString() << ".";
1197     }
1198   }
1199 
1200   // return the data flow graph
1201   return *this;
1202 }
1203 
1204 void DfGraphConvertor::CacheWhileGraph(const CNodePtr &cnode) {
1205   if (while_graph_cache_.find(cnode) != while_graph_cache_.end()) {
1206     return;
1207   }
1208   ValueNodePtr graph_node = nullptr;
1209   if (is_kernel_graph_) {
1210     graph_node = cnode->input(1)->cast<ValueNodePtr>();
1211   } else {
1212     if (cnode->input(0)->isa<ValueNode>()) {
1213       graph_node = cnode->input(0)->cast<ValueNodePtr>();
1214     } else {
1215       auto partial_node = cnode->input(0);
1216       graph_node = partial_node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1217     }
1218   }
1219 
1220   MS_EXCEPTION_IF_NULL(graph_node);
1221   FuncGraphPtr cond_graph = graph_node->value()->cast<FuncGraphPtr>();
1222   MS_EXCEPTION_IF_NULL(cond_graph);
1223   const auto &cond_set = cond_graph->nodes();
1224   for (auto beg = cond_set.begin(); beg != cond_set.end(); ++beg) {
1225     if (!((*beg)->isa<CNode>())) {
1226       continue;
1227     }
1228     auto c_beg = (*beg)->cast<CNodePtr>();
1229     if (GetCNodeFuncName(c_beg) == prim::kPrimSwitch->name()) {
1230       while_graph_cache_[cnode] = {c_beg->input(1), c_beg->input(kSwitchBodyIndex), c_beg->input(kSwitchAfterIndex)};
1231     }
1232   }
1233 }
1234 
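// Collect the GE operators that form the outputs of the while-body subgraph, mapping Parameter inputs back to
// the cached subgraph inputs and skipping monads, None and constant while inputs.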
1235 std::vector<Operator> DfGraphConvertor::GetWhileBodyOutputs() {
1236   std::vector<Operator> outputs;
1237 
1238   const auto &node = anf_graph_->get_return()->input(1);
1239   AnfNodePtr real_ret = node;
1240   MS_EXCEPTION_IF_NULL(real_ret);
1241   while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1242     real_ret = real_ret->cast<CNodePtr>()->input(1);
1243   }
1244 
1245   // skip input of UMonad, IOMonad
1246   if (HasAbstractUMonad(real_ret) || HasAbstractIOMonad(real_ret)) {
1247     return outputs;
1248   }
1249 
1250   // skip input of the None, UpdateState
1251   if (IsValueNode<None>(real_ret) || IsPrimitiveCNode(real_ret, prim::kPrimUpdateState)) {
1252     return outputs;
1253   }
1254 
1255   if (IsPrimitiveCNode(real_ret, prim::kPrimLoad)) {
1256     real_ret = ParseLoadInput(real_ret->cast<CNodePtr>());
1257   }
1258 
1259   if (!real_ret->isa<CNode>()) {
1260     return outputs;
1261   }
1262 
1263   auto c_node = real_ret->cast<CNodePtr>();
1264   std::vector<AnfNodePtr> inputs = GetAnfCallInputs(is_kernel_graph_, c_node);
1265   for (size_t i = 0; i < inputs.size(); i++) {
1266     auto j = inputs[i];
1267     MS_EXCEPTION_IF_NULL(j);
1268     if (!IsDataInput(c_node, j, 0)) {
1269       continue;
1270     }
1271     if (j->isa<Parameter>()) {
1272       int64_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
1273       auto idx_cond = body_cond_map_[idx];
1274       if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
1275           while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
1276         continue;
1277       }
1278       outputs.emplace_back(*(subgraph_input_cache_[idx]));
1279     } else {
1280       outputs.emplace_back(*Convert(j));
1281     }
1282   }
1283   MS_LOG(DEBUG) << "get while body outputs size: " << outputs.size();
1284   return outputs;
1285 }
1286 
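// Build the input operator list for a while cond/body subgraph. Each used, non-const while input
// becomes a Data op with an increasing index; const inputs are re-created as Const ops inside the
// subgraph (name suffixed with "_in_cond"/"_in_body"). The ops are also cached in
// subgraph_input_cache_, keyed by the input's index in the cond graph.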
1287 std::shared_ptr<std::vector<Operator>> DfGraphConvertor::GetWhileSubGraphInput() {
1288   std::shared_ptr<std::vector<Operator>> graph_in = std::make_shared<std::vector<Operator>>();
1289   subgraph_input_cache_.clear();
1290   size_t i = 0;
1291   OperatorPtr op = nullptr;
1292   ParamIndexMap cond_body;
1293   std::string name_app = "_in_cond";
1294   if (graph_type_ == GraphType::kBody) {
1295     name_app = "_in_body";
1296     for (auto &p : body_cond_map_) {
1297       cond_body[p.second] = p.first;
1298     }
1299   }
1300   for (auto &idx : while_used_input_index_) {
1301     if (while_const_input_index_.find(idx) == while_const_input_index_.end()) {
1302       op = std::make_shared<Data>();
1303       MS_EXCEPTION_IF_NULL(op);
1304       SetXDataIndex(op, i);
1305       i++;
1306     } else {
1307       // No need to process ge tensor desc
1308       auto temp = while_const_input_index_[idx].op;
1309       auto name = temp->GetName();
1310       auto value = const_op_to_value_[temp];
1311       MS_EXCEPTION_IF_NULL(value);
1312       auto adpt_const = FindAdapter(kNameConst, training_);
1313       if (adpt_const == nullptr) {
1314         continue;
1315       }
1316       name += name_app;
1317       auto const_op = adpt_const->generate(name);
1318       (void)adpt_const->setAttr(const_op, "value", value);
1319       auto const_op_desc = TransformUtil::GetGeTensorDesc(value->shape_c(), value->data_type(), kOpFormat_DEFAULT);
1320       if (const_op_desc == nullptr) {
1321         MS_LOG(WARNING) << "Create const " << name << " output descriptor failed!";
1322         continue;
1323       }
1324       const_op->UpdateOutputDesc(kTypeY, *const_op_desc);
1325       op = const_op;
1326     }
1327     graph_in->emplace_back(*op);
1328     if (graph_type_ == GraphType::kCond) {
1329       subgraph_input_cache_[idx] = op;
1330     } else if (graph_type_ == GraphType::kBody) {
1331       subgraph_input_cache_[cond_body[idx]] = op;
1332     }
1333   }
1334   MS_LOG(DEBUG) << "created " << subgraph_input_cache_.size() << " data nodes"
1335                 << " in graph: " << anf_graph_->ToString();
1336   return graph_in;
1337 }
1338 
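// Build the DfGraph of a while cond/body subgraph: create its inputs, wire every node's inputs,
// subgraphs and descriptors, then set the outputs (the condition op for the cond graph, the result
// of GetWhileBodyOutputs for the body graph) and run IdentityOptimization.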
1339 void DfGraphConvertor::BuildWhileSubGraph() {
1340   // set up dependencies
1341 
1342   std::vector<Operator> graph_in = *GetWhileSubGraphInput();
1343   auto nodes = GetOrderedCNodes(anf_graph_, while_cond_node_);
1344 
1345   AnfNodePtr real_ret = anf_graph_->get_return()->input(1);
1346   while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1347     real_ret = real_ret->cast<CNodePtr>()->input(1);
1348   }
1349   for (auto &it : nodes) {
1350     if (IsBranchNode(it)) {
1351       auto node = it->cast<CNodePtr>();
1352       GetBranchNodeInput(node);
1353     }
1354   }
1355 
1356   for (auto &it : nodes) {
1357     if (it == real_ret || HasAbstractMonad(it)) {
1358       continue;
1359     }
1360     SetNodeInput(it);
1361     SetSubgraph(it);
1362     UpdateOpDesc(it);
1363   }
1364   std::vector<Operator> graph_out;
1365   auto graph_name = TransformUtil::NormOpName(cur_while_node_->fullname_with_scope());
1366   if (graph_type_ == GraphType::kCond) {
1367     if (op_cache_.find(while_cond_node_.get()) == op_cache_.end()) {
1368       return;
1369     }
1370     graph_name += "_cond_graph";
1371     graph_out.emplace_back(*(op_cache_[while_cond_node_.get()]));
1372   } else {
1373     graph_name += "_body_graph";
1374     graph_out = GetWhileBodyOutputs();
1375   }
1376   if (error_ == SUCCESS) {
1377     if (df_graph_->GetName() != graph_name) {
1378       MS_LOG(DEBUG) << "convert anf graph name: " << df_graph_->GetName() << " to df graph name: " << graph_name;
1379     }
1380     df_graph_ = make_shared<DfGraph>(graph_name);
1381   } else {
1382     return;
1383   }
1384   MS_LOG(DEBUG) << "Set while sub graph input num: " << graph_in.size();
1385   MS_LOG(DEBUG) << "Set while sub graph output num: " << graph_out.size();
1386 
1387   compute_sout_ << "}" << endl;
1388   (void)df_graph_->SetInputs(graph_in).SetOutputs(graph_out);
1389   IdentityOptimization();
1390 }
1391 
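// Build the "after" part of a while loop inside the current graph: map the used, non-const while
// inputs to While output indices, resolve the call node's inputs, wire every node, and set the
// graph outputs if they have not been set yet.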
1392 void DfGraphConvertor::BuildWhileAfterSubGraph() {
1393   size_t i = 0;
1394   prev_cond_to_while_out_index_.clear();
1395   for (auto n : prev_while_used_input_index_) {
1396     if (prev_while_const_input_index_.find(n) == prev_while_const_input_index_.end()) {
1397       prev_cond_to_while_out_index_[n] = i;
1398       i++;
1399     }
1400   }
1401   GetCallNodeInputs(cur_while_node_);
1402   auto nodes = GetOrderedCNodes(anf_graph_);
1403   for (auto &it : nodes) {
1404     SetNodeInput(it);
1405     SetSubgraph(it);
1406     UpdateOpDesc(it);
1407   }
1408   if (graph_outputs_.empty()) {
1409     SetGraphOutputs();
1410   }
1411   compute_sout_ << "}" << endl;
1412   return;
1413 }
1414 
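// Convert the body branch of a while loop. The Partial node carries the body FuncGraph, which is
// converted by a nested DfGraphConvertor configured as GraphType::kBody and sharing the while
// bookkeeping maps; the resulting DfGraph is appended to while_dfgraph_cache_.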
1415 void DfGraphConvertor::ConvertWhileBody(const AnfNodePtr &node) {
1416   if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != prim::kPrimPartial->name()) {
1417     return;
1418   }
1419   auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1420   MS_EXCEPTION_IF_NULL(graph_node);
1421   FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1422   MS_EXCEPTION_IF_NULL(anf_graph);
1423   DfGraphConvertor converter(anf_graph, phase_prefix_);
1424   converter.use_inputs_ = true;
1425 
1426   const auto &params = anf_graph->parameters();
1427   converter.inputs_ = params;
1428 
1429   converter.graph_type_ = GraphType::kBody;
1430   converter.cur_while_node_ = cur_while_node_;
1431   converter.body_cond_map_ = body_cond_map_;
1432   converter.while_const_input_index_ = while_const_input_index_;
1433   converter.while_used_input_index_ = while_used_input_index_;
1434   converter.const_op_to_value_ = const_op_to_value_;
1435   converter.ConvertAllNode().BuildWhileSubGraph();
1436   while_dfgraph_cache_[cur_while_node_]->emplace_back(*(converter.df_graph_));
1437   std::string name = graph_node->ToString() + "_ge_graph.dot";
1438   auto context = MsContext::GetInstance();
1439   MS_EXCEPTION_IF_NULL(context);
1440   if (context->CanDump(kFully)) {
1441     converter.DrawComputeGraph(name);
1442   }
1443   return;
1444 }
1445 
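// Compute which while-loop parameters are actually used. Parameters referenced by CNodes of the
// cond graph are recorded directly; parameters referenced in the body graph are mapped back to
// cond-graph indices through the body Partial node's arguments. The union is stored in
// while_used_input_index_.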
1446 void DfGraphConvertor::GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs) {
1447   if (!while_used_input_index_.empty()) {
1448     return;
1449   }
1450 
1451   auto cond_graph_node = graphs.at(0);
1452   auto graph = cond_graph_node->func_graph();
1453   MS_EXCEPTION_IF_NULL(graph);
1454   const auto &cond_params = graph->parameters();
1455   auto nodes = GetOrderedCNodes(graph, cond_graph_node);
1456 
1457   std::set<size_t> used_params_index;
1458   for (auto &n : nodes) {
1459     if (!n->isa<CNode>()) {
1460       continue;
1461     }
1462     auto c = n->cast<CNodePtr>();
1463     auto inputs = c->inputs();
1464     for (size_t idx = 1; idx < inputs.size(); idx++) {
1465       auto &i = inputs[idx];
1466       if (!i->isa<Parameter>() || HasAbstractMonad(i) || IsDynamicShapeNode(i)) {
1467         continue;
1468       }
1469       auto idx_cond = std::find(cond_params.begin(), cond_params.end(), i) - cond_params.begin();
1470       (void)used_params_index.insert(idx_cond);
1471     }
1472   }
1473 
1474   auto body_graph_node_in_cond = graphs.at(1)->cast<CNodePtr>();
1475   auto body_graph_node = body_graph_node_in_cond->input(1)->cast<ValueNodePtr>();
1476   MS_EXCEPTION_IF_NULL(body_graph_node);
1477   graph = body_graph_node->value()->cast<FuncGraphPtr>();
1478   const auto &body_params = graph->parameters();
1479 
1480   auto real_ret = graph->get_return()->input(1);
1481   while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1482     real_ret = real_ret->cast<CNodePtr>()->input(1);
1483   }
1484 
1485   nodes = GetOrderedCNodes(graph);
1486   for (auto &n : nodes) {
1487     if (!n->isa<CNode>()) {
1488       continue;
1489     }
1490     auto c = n->cast<CNodePtr>();
1491     if (c == real_ret || c == real_ret->cast<CNodePtr>()->input(0)) {
1492       continue;
1493     }
1494     auto inputs = c->inputs();
1495     for (size_t idx = 1; idx < inputs.size(); idx++) {
1496       auto &i = inputs[idx];
1497       if (!i->isa<Parameter>() || HasAbstractMonad(i) || IsDynamicShapeNode(i)) {
1498         continue;
1499       }
1500       auto idx_body = std::find(body_params.begin(), body_params.end(), i) - body_params.begin();
1501       auto p = body_graph_node_in_cond->input(static_cast<size_t>(idx_body + kInputOffset));
1502       auto idx_cond = std::find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1503       (void)used_params_index.insert(idx_cond);
1504     }
1505   }
1506   while_used_input_index_ = used_params_index;
1507 }
1508 
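// Build body_cond_map_ (body parameter index -> cond parameter index) and after_cond_map_
// (after parameter index -> cond parameter index) by matching the arguments of the body/after
// Partial nodes against the cond graph's parameter list.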
1509 void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
1510   auto cond_graph_node = graphs.at(0);
1511   MS_EXCEPTION_IF_NULL(cond_graph_node);
1512   auto cond_graph = cond_graph_node->func_graph();
1513   MS_EXCEPTION_IF_NULL(cond_graph);
1514   const auto &cond_params = cond_graph->parameters();
1515 
1516   auto body_graph_node = graphs.at(1);
1517   MS_EXCEPTION_IF_NULL(body_graph_node);
1518   if (!body_graph_node->isa<CNode>()) {
1519     return;
1520   }
1521   MS_EXCEPTION_IF_NULL(body_graph_node->cast<CNodePtr>());
1522   auto body_graph_node_inputs = body_graph_node->cast<CNodePtr>()->inputs();
1523   std::vector<AnfNodePtr> body_params;
1524   for (auto it = body_graph_node_inputs.begin() + kInputOffset; it != body_graph_node_inputs.end(); ++it) {
1525     body_params.emplace_back(*it);
1526   }
1527 
1528   for (size_t i = 0; i < body_params.size(); i++) {
1529     auto p = body_params[i];
1530     int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1531     body_cond_map_[i] = static_cast<size_t>(idx);
1532     MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
1533   }
1534 
1535   auto after_graph_node = graphs.at(kSwitchBodyIndex);
1536   MS_EXCEPTION_IF_NULL(after_graph_node);
1537   if (!after_graph_node->isa<CNode>()) {
1538     return;
1539   }
1540   MS_EXCEPTION_IF_NULL(after_graph_node->cast<CNodePtr>());
1541   auto after_graph_node_inputs = after_graph_node->cast<CNodePtr>()->inputs();
1542   std::vector<AnfNodePtr> after_params;
1543   for (auto it = after_graph_node_inputs.begin() + kInputOffset; it != after_graph_node_inputs.end(); ++it) {
1544     after_params.emplace_back(*it);
1545   }
1546 
1547   for (size_t i = 0; i < after_params.size(); i++) {
1548     auto p = after_params[i];
1549     int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1550     after_cond_map_[i] = static_cast<size_t>(idx);
1551     MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
1552   }
1553   return;
1554 }
1555 
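// Convert the condition branch of a while loop with a nested converter (GraphType::kCond) and
// append the resulting DfGraph to while_dfgraph_cache_.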
1556 void DfGraphConvertor::ConvertWhileCond(const AnfNodePtr &node) {
1557   MS_LOG(DEBUG) << "begin to convert while node cond graph";
1558   auto func_graph = node->func_graph();
1559   MS_EXCEPTION_IF_NULL(func_graph);
1560 
1561   DfGraphConvertor converter(func_graph, phase_prefix_);
1562   converter.use_inputs_ = true;
1563 
1564   converter.inputs_ = func_graph->parameters();
1565 
1566   converter.graph_type_ = GraphType::kCond;
1567   converter.cur_while_node_ = cur_while_node_;
1568   converter.while_cond_node_ = node;
1569   converter.while_const_input_index_ = while_const_input_index_;
1570   converter.while_used_input_index_ = while_used_input_index_;
1571   converter.const_op_to_value_ = const_op_to_value_;
1572   converter.ConvertAllNode().BuildWhileSubGraph();
1573   MS_EXCEPTION_IF_NULL(while_dfgraph_cache_[cur_while_node_]);
1574   while_dfgraph_cache_[cur_while_node_]->emplace_back(*(converter.df_graph_));
1575   std::string name = func_graph->ToString() + "_ge_graph.dot";
1576   auto context = MsContext::GetInstance();
1577   MS_EXCEPTION_IF_NULL(context);
1578   if (context->CanDump(kFully)) {
1579     converter.DrawComputeGraph(name);
1580   }
1581 
1582   MS_LOG(DEBUG) << "convert while node cond graph end";
1583 }
1584 
1585 void DfGraphConvertor::SetWhileOutputHandle(const OperatorPtr &prev_while_op) {
1586   if (while_output_handle_cache_.find(prev_while_node_) != while_output_handle_cache_.end()) {
1587     return;
1588   }
1589   auto out_handler = std::make_shared<std::vector<OutHandler>>();
1590   MS_EXCEPTION_IF_NULL(out_handler);
1591   string str = "output";
1592   for (size_t i = 0; i < prev_while_node_out_size_; i++) {
1593     (void)out_handler->emplace_back(prev_while_op, str + std::to_string(i), prev_while_node_);
1594   }
1595   while_output_handle_cache_[prev_while_node_] = out_handler;
1596   return;
1597 }
1598 
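// Convert the "after" branch of a while loop. A nested converter (GraphType::kAfter) receives the
// previous While node's output handles and bookkeeping, builds the after graph into the current
// graph, and its const inputs and graph outputs are merged back into this converter.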
1599 void DfGraphConvertor::ConvertWhileAfter(const AnfNodePtr &node) {
1600   MS_EXCEPTION_IF_NULL(node);
1601   if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != prim::kPrimPartial->name()) {
1602     return;
1603   }
1604   MS_LOG(DEBUG) << "begin to convert while node after graph";
1605   MS_EXCEPTION_IF_NULL(node->cast<CNodePtr>()->input(1));
1606   auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1607   MS_EXCEPTION_IF_NULL(graph_node);
1608   MS_EXCEPTION_IF_NULL(graph_node->value());
1609   FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1610   MS_EXCEPTION_IF_NULL(anf_graph);
1611   DfGraphConvertor converter(anf_graph, phase_prefix_);
1612   converter.use_inputs_ = true;
1613 
1614   const auto &params = anf_graph->parameters();
1615   converter.inputs_ = params;
1616 
1617   converter.graph_type_ = GraphType::kAfter;
1618   converter.prev_after_cond_map_ = after_cond_map_;
1619   converter.prev_while_node_ = cur_while_node_;
1620   converter.prev_while_node_out_size_ = cur_while_node_out_size_;
1621   converter.bypass_node_prev_handle_cache_ = bypass_node_handle_cache_;
1622   converter.prev_while_used_input_index_ = while_used_input_index_;
1623   converter.prev_while_const_input_index_ = while_const_input_index_;
1624   converter.const_op_to_value_ = const_op_to_value_;
1625 
1626   auto while_op = Convert(converter.prev_while_node_);
1627   converter.SetWhileOutputHandle(while_op);
1628   converter.ConvertAllNode().BuildWhileAfterSubGraph();
1629   std::string name = graph_node->ToString() + "_ge_graph.dot";
1630   auto context = MsContext::GetInstance();
1631   MS_EXCEPTION_IF_NULL(context);
1632   if (context->CanDump(kFully)) {
1633     converter.DrawComputeGraph(name);
1634   }
1635   MS_LOG(DEBUG) << "add while after graph " << converter.graph_const_inputs_.size()
1636                 << " const inputs to main graph const inputs";
1637   (void)std::transform(converter.graph_const_inputs_.begin(), converter.graph_const_inputs_.end(),
1638                        std::back_inserter(graph_const_inputs_), [](OperatorPtr x) { return x; });
1639 
1640   graph_outputs_ = converter.graph_outputs_;
1641   MS_LOG(DEBUG) << "convert while node after graph end";
1642   return;
1643 }
1644 
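// Convert a while call node in the top-level graph: convert the cond and body branches into
// subgraphs, attach them to the While operator through its adapter, then convert the after branch.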
1645 void DfGraphConvertor::ConvertWhileNode(const CNodePtr &node) {
1646   if (IsSubGraph()) {
1647     return;
1648   }
1649 
1650   auto while_graph = while_graph_cache_[node];
1651   cur_while_node_ = node;
1652 
1653   auto &while_inputs = *(call_input_handle_cache_[node]);
1654   cur_while_node_out_size_ = while_inputs.size();
1655   while_dfgraph_cache_[node] = std::make_shared<std::vector<DfGraph>>();
1656   // convert cond graph
1657   auto cond_graph_node = while_graph[0];
1658   ConvertWhileCond(cond_graph_node);
1659 
1660   // convert body graph
1661   auto body_graph_node = while_graph[1];
1662   ConvertWhileBody(body_graph_node);
1663 
1664   OpAdapterPtr adpt = FindAdapter(node, training_);
1665   if (adpt == nullptr) {
1666     MS_LOG(DEBUG) << "Not found adapter";
1667     return;
1668   }
1669 
1670   OperatorPtr op = Convert(node);
1671   auto graphs = while_dfgraph_cache_[node];
1672   adpt->setSubgraph(op, graphs);
1673 
1674   // convert after graph
1675   auto after_graph_node = while_graph[kAfterIndexInCache];
1676   ConvertWhileAfter(after_graph_node);
1677   return;
1678 }
1679 
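// Build one DfGraph per branch of an If/Case node. For every branch Partial, a mapping from the
// branch graph's input index to the parent node's input index is computed and passed to
// ProcessSubgraph; the converted branches are collected from branches_map_.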
1680 std::shared_ptr<std::vector<DfGraph>> DfGraphConvertor::BuildBranchGraphs(const CNodePtr &cnode) {
1681   MS_EXCEPTION_IF_NULL(cnode);
1682   bool is_case = IsCaseNode(cnode);
1683   std::shared_ptr<std::vector<DfGraph>> df_branches = std::make_shared<std::vector<DfGraph>>();
1684   MS_EXCEPTION_IF_NULL(df_branches);
1685   if (IsNormalGraph() || IsBodyGraph() || IsBranchGraph()) {
1686     size_t branch_call_input_size = 0;
1687     size_t node_input_index = 0;
1688     if (!is_kernel_graph_) {
1689       for (size_t i = 1; i < cnode->size(); i++) {
1690         auto pred = cnode->input(i);
1691         if (!IsDataInput(cnode, pred, 0)) {
1692           continue;
1693         }
1694         node_input_index++;
1695         branch_call_input_size++;
1696       }
1697     }
1698     MS_EXCEPTION_IF_NULL(cnode->input(0));
1699     CNodePtr input_node = is_kernel_graph_ ? cnode : cnode->input(0)->cast<CNodePtr>();
1700     MS_EXCEPTION_IF_NULL(input_node);
1701     MS_EXCEPTION_IF_NULL(input_node->input(kInputOffset));
1702     auto bnode = is_case ? input_node->input(kInputOffset)->cast<CNodePtr>() : input_node->cast<CNodePtr>();
1703     MS_EXCEPTION_IF_NULL(bnode);
1704     const size_t init_i = is_case ? 1 : 2;
1705 
1706     for (size_t i = init_i; i < bnode->size(); i++) {
1707       ParamIndexMap branch_to_parent_node_map;
1708       size_t branch_index = 0;  //  branch graph input's index
1709       if (bnode->input(i)->isa<CNode>()) {
1710         auto branch_node = bnode->input(i)->cast<CNodePtr>();
1711         MS_EXCEPTION_IF_NULL(branch_node);
1712         for (size_t j = kInputOffset; j < branch_node->size(); j++) {
1713           auto pred = branch_node->input(j);
1714           if (!IsDataInput(cnode, pred, 0)) {
1715             continue;
1716           }
1717           branch_to_parent_node_map[branch_index] = node_input_index;
1718           node_input_index++;
1719           branch_index++;
1720         }
1721       }
1722       if (!is_kernel_graph_) {
1723         for (size_t k = 0; k < branch_call_input_size; k++) {
1724           branch_to_parent_node_map[branch_index] = k;
1725           branch_index++;
1726         }
1727       }
1728       ProcessSubgraph(cnode, bnode->input(i), branch_to_parent_node_map);
1729       (void)(df_branches->emplace_back(branches_map_[bnode->input(i).get()]));
1730     }
1731   }
1732   return df_branches;
1733 }
1734 
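// Convert the FuncGraph referenced by a call node with a nested converter. Repeated calls to the
// same graph get a "_call_N" suffix to keep graph names unique; the result is cached in
// call_dfgraph_cache_.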
1735 void DfGraphConvertor::BuildCallSubGraphs(const AnfNodePtr &node) {
1736   MS_EXCEPTION_IF_NULL(node);
1737   auto cnode = node->cast<CNodePtr>();
1738   MS_EXCEPTION_IF_NULL(cnode);
1739   MS_EXCEPTION_IF_NULL(cnode->input(1));
1740   ValueNodePtr graph_node = cnode->input(1)->cast<ValueNodePtr>();
1741   MS_EXCEPTION_IF_NULL(graph_node);
1742   MS_EXCEPTION_IF_NULL(graph_node->value());
1743   auto anf_graph = graph_node->value()->cast<AnfGraphPtr>();
1744   MS_EXCEPTION_IF_NULL(anf_graph);
1745   DfGraphConvertor converter(anf_graph, phase_prefix_);
1746   converter.graph_type_ = GraphType::kNormal;
1747   converter.use_inputs_ = true;
1748   converter.inputs_ = anf_graph->parameters();
1749   std::string graph_name = anf_graph->ToString();
1750   auto iter = call_subgraphs_repeat_times.find(graph_name);
1751   if (iter == call_subgraphs_repeat_times.end()) {
1752     call_subgraphs_repeat_times[graph_name] = 1;
1753   } else {
1754     iter->second += 1;
1755     graph_name = graph_name + "_call_" + std::to_string(iter->second);
1756   }
1757   (void)converter.ConvertAllNode().BuildGraph(graph_name);
1758 
1759   call_dfgraph_cache_[node] = std::make_shared<std::vector<DfGraph>>();
1760   MS_EXCEPTION_IF_NULL(call_dfgraph_cache_[node]);
1761   call_dfgraph_cache_[node]->emplace_back(*(converter.df_graph_));
1762   MS_LOG(INFO) << "build call subgraph end.";
1763 }
1764 
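// Attach subgraphs to control-flow nodes: While nodes are handled by CacheWhileGraph +
// ConvertWhileNode, If/Case nodes get their branch graphs from BuildBranchGraphs, and call nodes
// get the converted callee graph from BuildCallSubGraphs.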
1765 void DfGraphConvertor::SetSubgraph(const AnfNodePtr &node) {
1766   MS_EXCEPTION_IF_NULL(node);
1767   if (!node->isa<CNode>()) {
1768     return;
1769   }
1770   auto cnode = node->cast<CNodePtr>();
1771   if (IsWhileNode(cnode)) {
1772     MS_LOG(DEBUG) << "Start to set while's sub graph.";
1773     CacheWhileGraph(cnode);
1774     ConvertWhileNode(cnode);
1775     MS_LOG(DEBUG) << "Set while's sub graph end.";
1776     return;
1777   }
1778 
1779   if (IsBranchNode(cnode)) {
1780     MS_LOG(DEBUG) << "Start to set if/case's sub graph.";
1781     std::shared_ptr<std::vector<DfGraph>> df_branches = BuildBranchGraphs(cnode);
1782     if (op_cache_.find(node.get()) == op_cache_.end()) {
1783       return;
1784     }
1785 
1786     OpAdapterPtr adpt = FindAdapter(node, training_);
1787     if (adpt == nullptr) {
1788       MS_LOG(DEBUG) << "Not found adapter";
1789       return;
1790     }
1791 
1792     OperatorPtr op = Convert(node);
1793     bool is_case = IsCaseNode(node);
1794     if (is_case) {
1795       adpt->setSubgraph(op, 0, df_branches);
1796     } else {
1797       adpt->setSubgraph(op, df_branches);
1798     }
1799     MS_LOG(DEBUG) << "Set if/case's sub graph end.";
1800     return;
1801   }
1802 
1803   if (IsCallNode(cnode)) {
1804     MS_LOG(DEBUG) << "Start to set call's sub graph.";
1805     BuildCallSubGraphs(node);
1806     if (op_cache_.find(node.get()) == op_cache_.end()) {
1807       return;
1808     }
1809     OpAdapterPtr adpt = FindAdapter(node, training_);
1810     if (adpt == nullptr) {
1811       MS_LOG(EXCEPTION) << "Not found adapter";
1812       return;
1813     }
1814     OperatorPtr op = Convert(node);
1815     auto df_graphs = call_dfgraph_cache_[node];
1816     adpt->setSubgraph(op, df_graphs);
1817     MS_LOG(DEBUG) << "Set call's sub graph end.";
1818   }
1819   return;
1820 }
1821 
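// Collect the inputs of an If/Case node: the branch index, a dynamic input assembled from all
// branch Partial arguments (data inputs become OutHandlers; body-graph parameters are resolved
// through subgraph_input_cache_), and the remaining control inputs. The result is cached in
// branch_input_handle_cache_.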
1822 void DfGraphConvertor::GetBranchNodeInput(const CNodePtr node) {
1823   if (branch_input_handle_cache_.find(node.get()) != branch_input_handle_cache_.end()) {
1824     return;
1825   }
1826   bool is_case = IsCaseNode(node);
1827   std::vector<AnfNodePtr> branch_inputs;
1828   const size_t branch_index = 1;
1829 
1830   MS_EXCEPTION_IF_NULL(node);
1831   MS_EXCEPTION_IF_NULL(node->input(0));
1832   CNodePtr sw_node = is_kernel_graph_ ? node : node->input(0)->cast<CNodePtr>();
1833   MS_EXCEPTION_IF_NULL(sw_node);
1834   AnfNodePtr branch_index_iter = sw_node->input(branch_index);
1835   AnfNodePtr branch_dyn_input_node = nullptr;
1836   const size_t make_tuple_index = 2;
1837   AnfNodePtr make_tuple_iter = sw_node->input(make_tuple_index);
1838   branch_dyn_input_node = make_tuple_iter;  // switch node's 2nd input as dyn input
1839 
1840   std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1841   MS_EXCEPTION_IF_NULL(tuple_items);
1842 
1843   CNodePtr input_node = node;
1844   if (!is_kernel_graph_) {
1845     for (size_t i = 1; i < node->size(); i++) {
1846       auto pred = node->input(i);
1847       (void)(branch_inputs.emplace_back(pred));
1848     }
1849     input_node = node->input(0)->cast<CNodePtr>();
1850   }
1851   MS_EXCEPTION_IF_NULL(input_node);
1852   auto bnode = is_case ? input_node->input(make_tuple_index)->cast<CNodePtr>() : input_node;
1853   MS_EXCEPTION_IF_NULL(bnode);
1854   const size_t init_i = is_case ? 1 : 2;
1855   for (size_t i = init_i; i < bnode->size(); ++i) {
1856     const auto &bnode_input = bnode->input(i);
1857     MS_EXCEPTION_IF_NULL(bnode_input);
1858     if (!bnode_input->isa<CNode>()) {
1859       continue;
1860     }
1861     auto branch_node = bnode_input->cast<CNodePtr>();
1862     MS_EXCEPTION_IF_NULL(branch_node);
1863     for (size_t j = kInputOffset; j < branch_node->size(); ++j) {
1864       auto pred = branch_node->input(j);
1865       (void)(branch_inputs.emplace_back(pred));
1866     }
1867   }
1868   std::vector<AnfNodePtr> branch_control_input;
1869   for (size_t i = 0; i < branch_inputs.size(); i++) {
1870     auto item = branch_inputs[i];
1871     if (!IsDataInput(node, item, 0)) {
1872       branch_control_input.emplace_back(item);
1873       continue;
1874     }
1875     if (IsBodyGraph() && item->isa<Parameter>()) {
1876       auto idx = std::find(inputs_.begin(), inputs_.end(), item) - inputs_.begin();
1877       (void)(tuple_items->emplace_back(subgraph_input_cache_[idx], "", item));
1878     } else {
1879       auto hd = GetHandler(item);
1880       tuple_items->emplace_back(hd);
1881     }
1882   }
1883   tuple_out_handle_cache_[branch_dyn_input_node.get()] = tuple_items;
1884 
1885   std::shared_ptr<std::vector<AnfNodePtr>> branch_input_items = std::make_shared<std::vector<AnfNodePtr>>();
1886   MS_EXCEPTION_IF_NULL(branch_input_items);
1887   (void)branch_input_items->emplace_back(branch_index_iter);
1888   (void)branch_input_items->emplace_back(branch_dyn_input_node);
1889 
1890   (void)std::copy(branch_control_input.begin(), branch_control_input.end(), std::back_inserter(*branch_input_items));
1891   branch_input_handle_cache_[node.get()] = branch_input_items;
1892   return;
1893 }
1894 
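// Resolve the inputs of a while call node. Used inputs that are constants are recorded in
// while_const_input_index_; the remaining used inputs become the While op's data inputs and
// determine its dynamic output count. Handles for unused inputs are kept in
// bypass_node_handle_cache_ so the after graph can read them directly.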
1895 void DfGraphConvertor::GetCallNodeInputs(const CNodePtr &node) {
1896   if (node == nullptr) {
1897     return;
1898   }
1899   if (call_input_handle_cache_.find(node) != call_input_handle_cache_.end()) {
1900     return;
1901   }
1902 
1903   auto call_input_items = std::make_shared<std::vector<OutHandler>>();
1904   MS_EXCEPTION_IF_NULL(call_input_items);
1905   std::vector<AnfNodePtr> inputs = GetAnfCallInputs(is_kernel_graph_, node);
1906 
1907   auto &params = anf_graph_->parameters();
1908   auto while_op = Convert(node);
1909 
1910   while_const_input_index_.clear();
1911   std::set<size_t> while_input_node_index;
1912   for (auto iter = while_used_input_index_.begin(); iter != while_used_input_index_.end(); ++iter) {
1913     auto n = inputs[*iter];
1914     MS_EXCEPTION_IF_NULL(n);
1915     OutHandler out_handler;
1916     if (IsAfterGraph() && n->isa<Parameter>()) {
1917       auto idx = std::find(params.begin(), params.end(), n) - params.begin();
1918       auto idx_cond = prev_after_cond_map_[idx];
1919       if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
1920         out_handler = bypass_node_prev_handle_cache_[idx_cond];
1921       } else {
1922         auto idx_out = prev_cond_to_while_out_index_[idx_cond];
1923         out_handler = while_output_handle_cache_[prev_while_node_]->at(idx_out);
1924       }
1925     } else {
1926       out_handler = GetHandler(inputs[*iter]);
1927     }
1928     MS_EXCEPTION_IF_NULL(out_handler.op);
1929     if ((out_handler.op->GetOpType() == "Const" || out_handler.op->GetOpType() == "Constant") &&
1930         const_op_to_value_.find(out_handler.op) != const_op_to_value_.end()) {
1931       while_const_input_index_[*iter] = out_handler;
1932     } else {
1933       (void)while_input_node_index.insert(*iter);
1934       call_input_items->emplace_back(out_handler);
1935     }
1936   }
1937   cur_while_node_out_size_ = call_input_items->size();
1938   bypass_node_handle_cache_.clear();
1939 
1940   for (size_t i = 0; i < inputs.size(); i++) {
1941     if (while_input_node_index.find(i) == while_input_node_index.end()) {
1942       auto n = inputs[i];
1943       MS_EXCEPTION_IF_NULL(n);
1944       if (HasAbstractMonad(n)) {
1945         continue;
1946       }
1947       if (IsAfterGraph() && n->isa<Parameter>()) {
1948         auto idx = std::find(params.begin(), params.end(), n) - params.begin();
1949         auto idx_cond = prev_after_cond_map_[idx];
1950         if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
1951           bypass_node_handle_cache_[i] = bypass_node_prev_handle_cache_[idx_cond];
1952         } else {
1953           auto idx_out = prev_cond_to_while_out_index_[idx_cond];
1954           bypass_node_handle_cache_[i] = while_output_handle_cache_[prev_while_node_]->at(idx_out);
1955         }
1956       } else {
1957         bypass_node_handle_cache_[i] = GetHandler(n);
1958       }
1959     }
1960   }
1961 
1962   auto op = Convert(node);
1963   auto adpt = FindAdapter(node, training_);
1964   MS_EXCEPTION_IF_NULL(adpt);
1965   adpt->setDynamicOutputNum(op, cur_while_node_out_size_);
1966   call_input_handle_cache_[node] = call_input_items;
1967   return;
1968 }
1969 
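// Collect the graph inputs (non ref-mode path). In dataset sink mode the GetNext/QueueData
// operator is the only graph input; otherwise every parameter is converted, its shape recorded
// (marking dynamic-shape inputs), and added either as a Data input or as a variable input.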
1970 void DfGraphConvertor::SetGraphInputs(std::vector<Operator> *inputs) {
1971   if (IsNormalGraph() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
1972     auto ms_context = MsContext::GetInstance();
1973     MS_EXCEPTION_IF_NULL(ms_context);
1974     std::vector<PrimitivePtr> input_prims;
1975     if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
1976       input_prims = {prim::kPrimQueueData};
1977     } else {
1978       input_prims = {prim::kPrimGetNext, prim::kPrimDynamicGetNextV2};
1979     }
1980 
1981     OperatorPtr input;
1982     auto nodes = GetOrderedCNodes(anf_graph_);
1983     for (auto &it : nodes) {
1984       if (std::any_of(input_prims.begin(), input_prims.end(),
1985                       [&it](const PrimitivePtr &prim) { return IsPrimitiveCNode(it, prim); })) {
1986         auto it_op = op_cache_.find(it.get());
1987         if (it_op != op_cache_.end()) {
1988           input = it_op->second;
1989           break;
1990         } else {
1991           MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
1992         }
1993       }
1994     }
1995     if (input == nullptr) {
1996       MS_LOG(EXCEPTION) << "Can not find the GetNext node in graph in sink_mode, please check.";
1997     }
1998     inputs->emplace_back(*input);
1999 
2000     MS_EXCEPTION_IF_NULL(anf_graph_);
2001     anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2002   } else {
2003     auto params = anf_graph_->parameters();
2004     int index = 0;
2005     for (auto &it : params) {
2006       auto param = it->cast<ParameterPtr>();
2007       MS_EXCEPTION_IF_NULL(param);
2008       auto name = param->name();
2009       if (std::find(init_data_names_.begin(), init_data_names_.end(), name) == init_data_names_.end()) {
2010         const auto &param_shape = param->Shape();
2011         MS_EXCEPTION_IF_NULL(param_shape);
2012         const auto &shape = param_shape->cast<abstract::ShapePtr>();
2013         if (shape != nullptr) {
2014           const auto &sv = shape->shape();
2015           if (IsDynamic(sv)) {
2016             dynamic_shape_inputs_ = true;
2017           }
2018           input_shapes_.emplace_back(sv);
2019         }
2020       }
2021       //  the parameters which have not been converted to var
2022       if (vars_.find(name) == vars_.end()) {
2023         if (HasAbstractMonad(it)) {
2024           MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
2025           continue;
2026         }
2027         auto op = Convert(it);
2028         MS_EXCEPTION_IF_NULL(op);
2029         MS_LOG(INFO) << "add non-var input " << it->ToString() << ", index " << index;
2030         if (op == nullptr) {
2031           MS_LOG(ERROR) << "Convert graph failed!";
2032           return;
2033         }
2034         UpdateDataOpDesc(it, op);
2035 
2036         if (IsNormalGraph()) {
2037           MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
2038           SetXDataIndex(op, index);
2039           index++;
2040         }
2041         inputs->emplace_back(*op);
2042       } else if (vars_[name] != nullptr) {
2043         MS_LOG(INFO) << "add var input " << it->ToString();
2044         auto op = Convert(it);
2045         MS_EXCEPTION_IF_NULL(op);
2046         UpdateConstOpDesc(it, vars_[name]);
2047         inputs->emplace_back(*op);
2048       }
2049     }
2050   }
2051 }
2052 
2053 bool DfGraphConvertor::IsConstantOp(const OperatorPtr &op) const {
2054   if (op == nullptr) {
2055     return false;
2056   }
2057   return (op->GetOpType() == "Constant" || op->GetOpType() == "Const");
2058 }
2059 
2060 OperatorPtr DfGraphConvertor::SetGraphInputsForNotVar(const AnfNodePtr &it, int64_t *index,
2061                                                       std::vector<Operator> *inputs) {
2062   MS_EXCEPTION_IF_NULL(index);
2063   MS_EXCEPTION_IF_NULL(inputs);
2064   auto op = Convert(it);
2065   if (op == nullptr) {
2066     MS_LOG(EXCEPTION) << "Convert graph failed!";
2067   }
2068   UpdateDataOpDesc(it, op);
2069   if (IsNormalGraph()) {
2070     MS_LOG(INFO) << "add input " << it->ToString() << ", index " << *index;
2071     auto op_type = op->GetOpType();
2072     if (op_type == kTypeData || op_type == kTypeRefData) {
2073       SetXDataIndex(op, (*index));
2074       (*index)++;
2075     } else {
2076       auto name = std::static_pointer_cast<Parameter>(it)->name();
2077       MS_LOG(EXCEPTION) << "Op " << name << " is invalid type " << op->GetOpType() << " as graph input.";
2078     }
2079   }
2080   inputs->push_back(*op);
2081   return op;
2082 }
2083 
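// Collect the graph inputs in ref mode: skip monad and tuple parameters, convert the rest to
// Data/RefData inputs (deduplicating parameters that share a ref_key, skipping constants), and
// record the corresponding AnfNodes in ge_inputs for later binding.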
2084 void DfGraphConvertor::SetGraphInputs(std::vector<Operator> *inputs, AnfNodeWeakPtrList *ge_inputs) {
2085   MS_EXCEPTION_IF_NULL(inputs);
2086   MS_EXCEPTION_IF_NULL(ge_inputs);
2087   MS_LOG(INFO) << "IsNormalGraph=" << IsNormalGraph() << ", dataset_mode="
2088                << ConfigManager::GetInstance().dataset_mode();
2089   AddInputInDataSink(inputs);
2090   auto params = anf_graph_->parameters();
2091   MS_LOG(INFO) << "Parameters size " << params.size();
2092   int64_t index = 0;
2093   std::set<std::string> name_records = {};
2094   for (auto &it : params) {
2095     auto name = std::static_pointer_cast<Parameter>(it)->name();
2096     OperatorPtr op;
2097     //  the parameters which have not been converted to var
2098     if (vars_.find(name) == vars_.end()) {
2099       auto abs = it->abstract();
2100       MS_EXCEPTION_IF_NULL(abs);
2101       if (HasAbstractMonad(it) || abs->isa<abstract::AbstractSequence>()) {
2102         MS_LOG(INFO) << it->DebugString() << " is a monad or tuple/list parameter, skip.";
2103         continue;
2104       }
2105       op = SetGraphInputsForNotVar(it, &index, inputs);
2106     } else if (vars_[name] != nullptr) {
2107       MS_LOG(INFO) << "add var input " << it->ToString() << ", index " << index;
2108       op = Convert(it);
2109       MS_EXCEPTION_IF_NULL(op);
2110       if (name_records.count(name) != 0) {
2111         // two parameters have same ref_key
2112         MS_LOG(INFO) << "var input " << it->ToString() << " is already added";
2113         continue;
2114       }
2115       (void)name_records.insert(name);
2116       UpdateConstOpDesc(it, vars_[name]);
2117       auto op_type = op->GetOpType();
2118       if (op_type == kTypeRefData) {
2119         SetXDataIndex(op, index);
2120         index++;
2121       } else if (IsConstantOp(op)) {
2122         continue;
2123       } else {
2124         MS_LOG(EXCEPTION) << "Op " << name << " is invalid type " << op->GetOpType() << " as graph input.";
2125       }
2126       inputs->push_back(*op);
2127     }
2128     (void)ge_inputs->emplace_back(AnfNodeWeakPtr(it));
2129   }
2130   MS_LOG(INFO) << "Input size " << inputs->size();
2131 }
2132 
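// Find the converted GetNext/DynamicGetNextV2 (or QueueData when GE heterogeneous is enabled)
// operator and, in dataset sink mode, add it as a graph input and flag the graph accordingly.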
2133 void DfGraphConvertor::AddInputInDataSink(vector<Operator> *inputs) {
2134   MS_EXCEPTION_IF_NULL(inputs);
2135   auto ms_context = MsContext::GetInstance();
2136   MS_EXCEPTION_IF_NULL(ms_context);
2137   std::vector<PrimitivePtr> input_prims;
2138   if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
2139     input_prims = {prim::kPrimQueueData};
2140   } else {
2141     input_prims = {prim::kPrimGetNext, prim::kPrimDynamicGetNextV2};
2142   }
2143   OperatorPtr input = nullptr;
2144   auto nodes = GetOrderedCNodes(anf_graph_);
2145   for (auto &it : nodes) {
2146     if (std::any_of(input_prims.begin(), input_prims.end(),
2147                     [&it](const PrimitivePtr &prim) { return IsPrimitiveCNode(it, prim); })) {
2148       auto it_op = op_cache_.find(it.get());
2149       if (it_op != op_cache_.end()) {
2150         input = it_op->second;
2151         break;
2152       } else {
2153         MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
2154       }
2155     }
2156   }
2157   if (IsNormalGraph() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && input != nullptr) {
2158     (void)inputs->emplace_back(*input);
2159     MS_EXCEPTION_IF_NULL(anf_graph_);
2160     anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2161   }
2162 }
2163 
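// Build a standalone graph that only contains the InitDataSetQueue node as both its input and its
// output. Skipped (df_graph_ reset to nullptr) when GE heterogeneous mode is enabled.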
2164 void DfGraphConvertor::BuildInitDataGraph(const std::string &name) {
2165   MS_LOG(INFO) << "Start BuildInitDataGraph.";
2166 
2167   // If MS_CTX_ENABLE_GE_HETEROGENOUS is true, there is no need for an InitData graph.
2168   auto ms_context = MsContext::GetInstance();
2169   MS_EXCEPTION_IF_NULL(ms_context);
2170   if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
2171     df_graph_ = nullptr;
2172     return;
2173   }
2174 
2175   AnfNodePtr init_dataset_queue_node = nullptr;
2176   auto nodes = GetOrderedCNodes(anf_graph_);
2177   for (auto &it : nodes) {
2178     if (IsInitDataSetQueueNode(it)) {
2179       init_dataset_queue_node = it;
2180       break;
2181     }
2182   }
2183   OperatorPtr init_data_op = Convert(init_dataset_queue_node);
2184   MS_EXCEPTION_IF_NULL(init_data_op);
2185   if (error_ != SUCCESS) {
2186     return;
2187   }
2188   std::vector<::ge::Operator> inputs{*init_data_op};
2189   std::vector<::ge::Operator> outputs{*init_data_op};
2190   df_graph_ = make_shared<DfGraph>(name);
2191   (void)df_graph_->SetInputs(inputs);
2192   (void)df_graph_->SetOutputs(outputs);
2193   MS_LOG(INFO) << "End BuildInitDataGraph.";
2194 }
2195 
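// Append the first converted CNode whose adapter declares no inputs (and no attr inputs) to the
// graph inputs, so a graph built only from such no-input operators still has a valid input anchor.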
2196 void DfGraphConvertor::FillEmptyInputsWithNoInputOp(std::vector<Operator> *inputs) {
2197   MS_EXCEPTION_IF_NULL(inputs);
2198   MS_LOG(INFO) << "Fill empty graph inputs with cnode whose inputs are empty.";
2199   auto nodes = GetOrderedCNodes(anf_graph_);
2200   for (auto &it : nodes) {
2201     if (!it->isa<CNode>()) {
2202       continue;
2203     }
2204     std::string name = common::AnfAlgo::GetCNodeName(it);
2205     if (name == prim::kPrimSwitch->name() || name == prim::kPrimSwitchLayer->name() ||
2206         name == prim::kPrimPartial->name()) {
2207       continue;
2208     }
2209     auto adpt = FindAdapter(it, training_);
2210     if (adpt == nullptr) {
2211       continue;
2212     }
2213     if (adpt->getInputMap().empty() && adpt->getAttrInputMap().empty()) {
2214       auto cnode_op = op_cache_.find(it.get());
2215       if (cnode_op != op_cache_.end()) {
2216         (void)inputs->emplace_back(*(cnode_op->second));
2217         break;
2218       } else {
2219         MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
2220       }
2221     }
2222   }
2223 }
2224 
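// Derive shape, data type and format for a Parameter (from its default tensor if present,
// otherwise from its abstract shape/type), apply any recorded per-parameter format override, and
// register the storage format through StorageFormatConvertor.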
2225 void DfGraphConvertor::SetupInputFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
2226   if (!node->isa<Parameter>()) {
2227     return;
2228   }
2229   auto para = node->cast<ParameterPtr>();
2230   std::vector<int64_t> shape;
2231   TypeId type;
2232   std::string format = kOpFormat_DEFAULT;
2233   if (para->has_default()) {
2234     auto value = para->default_param();
2235     MS_EXCEPTION_IF_NULL(value);
2236     auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
2237     MS_EXCEPTION_IF_NULL(tensor);
2238     shape = tensor->shape_c();
2239     type = tensor->data_type();
2240     format = SelectParamOriFormat(manager, para);
2241   } else {
2242     if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(para->Shape()); normal_shape_ptr != nullptr) {
2243       shape = normal_shape_ptr->shape();
2244     } else if (!dyn_cast<abstract::NoShape>(para->Shape())) {
2245       MS_LOG(INFO) << "Invalid shape.";
2246       return;
2247     }
2248     if (para->Type()) {
2249       type = para->Type()->type_id();
2250       if (type == kObjectTypeTensorType) {
2251         type = dyn_cast<TensorType>(para->Type())->element()->type_id();
2252       }
2253     } else {
2254       MS_LOG(INFO) << "Invalid type.";
2255       return;
2256     }
2257   }
2258   std::string param_debug_info = para->DebugString();
2259   auto param_format = param_format_.find(param_debug_info);
2260   if (param_format != param_format_.end()) {
2261     format = param_format->second;
2262     MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
2263   }
2264   auto desc = TransformUtil::GetGeTensorDesc(shape, type, format);
2265   StorageFormatConvertor::SetupStorageFormat(anf_graph_, node, desc);
2266 }
2267 
2268 void DfGraphConvertor::GenFakeGraphInRefMode() {
2269   const auto &nodes = GetOrderedCNodes(anf_graph_);
2270   for (const auto &node : nodes) {
2271     if (!node->isa<CNode>()) {
2272       continue;
2273     }
2274     SaveParamFormat(node->cast<CNodePtr>());
2275   }
2276   auto manager = Manage(anf_graph_, true);
2277   MS_EXCEPTION_IF_NULL(manager);
2278   std::vector<AnfNodeWeakPtr> ge_input_nodes = {};
2279   const auto &params = anf_graph_->parameters();
2280   for (auto &node : params) {
2281     MS_EXCEPTION_IF_NULL(node);
2282     auto abs = node->abstract();
2283     MS_EXCEPTION_IF_NULL(abs);
2284     if (HasAbstractMonad(node) || abs->isa<abstract::AbstractSequence>()) {
2285       continue;
2286     }
2287     SetupInputFormat(manager, node);
2288     (void)ge_input_nodes.emplace_back(AnfNodeWeakPtr(node));
2289   }
2290   auto input_name_list = std::make_shared<GEInputList>();
2291   input_name_list->ge_inputs = ge_input_nodes;
2292   anf_graph_->set_user_data(input_name_list);
2293   for (auto &anf_node : params) {
2294     MS_EXCEPTION_IF_NULL(anf_node);
2295     auto para = anf_node->cast<ParameterPtr>();
2296     MS_EXCEPTION_IF_NULL(para);
2297     auto name = para->name();
2298     if (std::find(init_data_names_.begin(), init_data_names_.end(), name) == init_data_names_.end()) {
2299       const auto &param_shape = para->Shape();
2300       MS_EXCEPTION_IF_NULL(param_shape);
2301       const auto &shape = param_shape->cast<abstract::ShapePtr>();
2302       if (shape != nullptr) {
2303         const auto &sv = shape->shape();
2304         if (IsDynamic(sv)) {
2305           dynamic_shape_inputs_ = true;
2306         }
2307         input_shapes_.push_back(sv);
2308       }
2309     }
2310   }
2311 
2312   auto ms_context = MsContext::GetInstance();
2313   MS_EXCEPTION_IF_NULL(ms_context);
2314   // set up init sub graph
2315   static bool is_inited = false;
2316   init_graph_ = nullptr;
2317   bool sink_mode = ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE;
2318   if (training_ && sink_mode && ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && !is_inited) {
2319     init_graph_ = GenExampleGraph(kInit);
2320     is_inited = true;
2321   }
2322 }
2323 
2324 void DfGraphConvertor::GenFakeGraph(const std::string &name) {
2325   MS_LOG(INFO) << "Gen fake compute graph " << name;
2326   df_graph_ = GenExampleGraph(name);
2327   MS_EXCEPTION_IF_NULL(df_graph_);
2328   bool sink_mode = ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE;
2329   if (IsNormalGraph() && sink_mode) {
2330     MS_EXCEPTION_IF_NULL(anf_graph_);
2331     anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2332   }
2333   const auto &params = anf_graph_->parameters();
2334   bool has_weight = std::any_of(params.begin(), params.end(), [](const auto &para) {
2335     auto parameter = para->template cast<ParameterPtr>();
2336     MS_EXCEPTION_IF_NULL(parameter);
2337     return parameter->has_default();
2338   });
2339   if (distribute_ && has_weight) {
2340     this->broadcast_graph_ = GenExampleGraph(kBroadcast);
2341   }
2342   if (!ref_mode_) {
2343     return;
2344   }
2345   GenFakeGraphInRefMode();
2346 }
2347 
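// Build the final DfGraph after ConvertAllNode: resolve branch/call node inputs, wire every
// node's inputs and subgraphs, set the graph inputs (ref-mode or legacy path) and outputs, then
// run Identity/NoOp/ES optimizations and the loop-sink/ref-data bookkeeping.
// Typical call chain, as a sketch of the pattern used at the call sites in this file:
//   DfGraphConvertor converter(anf_graph, phase_prefix);
//   (void)converter.ConvertAllNode().BuildGraph(name);
//   DfGraphPtr df_graph = converter.GetComputeGraph();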
2348 DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
2349   MS_LOG(INFO) << "Start BuildGraph, graph: " << anf_graph_->ToString();
2350 
2351   if (error_ != SUCCESS) {
2352     return *this;
2353   }
2354 
2355   GetCallNodeInputs(cur_while_node_);
2356   // branch node set input.
2357   bool is_initdata_graph = false;
2358   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
2359   for (auto &it : nodes) {
2360     if (IsBranchNode(it)) {
2361       auto node = it->cast<CNodePtr>();
2362       GetBranchNodeInput(node);
2363     }
2364     if (IsInitDataSetQueueNode(it)) {
2365       is_initdata_graph = true;
2366     }
2367   }
2368   auto manager = anf_graph_->manager();
2369   if (manager == nullptr) {
2370     auto new_manager = MakeManager({anf_graph_});
2371     MS_EXCEPTION_IF_NULL(new_manager);
2372     new_manager->AddFuncGraph(anf_graph_);
2373     anf_graph_->set_manager(new_manager);
2374   }
2375 
2376   if (is_initdata_graph) {
2377     BuildInitDataGraph(name);
2378     return *this;
2379   }
2380   nodes = GetOrderedCNodes(anf_graph_);
2381   for (auto &it : nodes) {
2382     SetNodeInput(it);
2383     SetSubgraph(it);
2384     UpdateOpDesc(it);
2385   }
2386 
2387   if (error_ == SUCCESS) {
2388     df_graph_ = make_shared<DfGraph>(name);
2389   } else {
2390     return *this;
2391   }
2392 
2393   // set graph input according to the order from anf graph
2394   std::vector<Operator> inputs;
2395   std::vector<AnfNodeWeakPtr> ge_input_nodes = {};
2396   if (ref_mode_ && !export_air_) {
2397     SetGraphInputs(&inputs, &ge_input_nodes);
2398   } else {
2399     SetGraphInputs(&inputs);
2400   }
2401 
2402   // Add const nodes as graph inputs for operators that work with constants.
2403   MS_LOG(INFO) << "Graph const input size: " << graph_const_inputs_.size();
2404   auto fv_names = GetFvNames(anf_graph_);
2405   for (auto &input : graph_const_inputs_) {
2406     if (fv_names.find(input->GetName()) == fv_names.end()) {
2407       inputs.emplace_back(*input);
2408     }
2409   }
2410 
2411   FillEmptyInputsWithNoInputOp(&inputs);
2412 
2413   MS_LOG(INFO) << "Set graph input num: " << inputs.size();
2414   (void)df_graph_->SetInputs(inputs);
2415 
2416   SetGraphOutputs(true);
2417   (void)df_graph_->SetOutputs(graph_outputs_);
2418 
2419   IdentityOptimization();
2420   NoOpOptimization();
2421   if (has_es_node_) {
2422     ESOptimization();
2423   }
2424 
2425   compute_sout_ << "}" << endl;
2426   // For a graph (e.g. eval_subgraph) whose IterNum is 1, do not set the NeedIteration flag.
2427   auto ms_context = MsContext::GetInstance();
2428   MS_EXCEPTION_IF_NULL(ms_context);
2429   if (ConfigManager::GetInstance().iter_num() > 1 && ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
2430     df_graph_->SetNeedIteration(true);
2431     anf_graph_->set_flag(kGraphNeedIteration, true);
2432   }
2433   if (ref_mode_) {
2434     std::sort(ref_datas_.begin(), ref_datas_.end(), [](const OperatorPtr &left, const OperatorPtr &right) -> bool {
2435       int64_t left_idx;
2436       int64_t right_idx;
2437       left->GetAttr(kTypeIndex, left_idx);
2438       right->GetAttr(kTypeIndex, right_idx);
2439       return left_idx < right_idx;
2440     });
2441     auto input_name_list = std::make_shared<GEInputList>();
2442     MS_EXCEPTION_IF_NULL(input_name_list);
2443     input_name_list->ge_inputs = ge_input_nodes;
2444     anf_graph_->set_user_data(input_name_list);
2445   }
2446   MS_LOG(INFO) << "End BuildGraph, graph: " << anf_graph_->ToString();
2447   return *this;
2448 }
2449 
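// Collect graph_outputs_ when this converter is not building a while subgraph: depending on the
// graph, outputs come from the return node's inputs or from the return node itself, and ES
// (embedding service) graphs use a dedicated output list. Each output node's adapter provides the
// OutHandlers, falling back to tuple_out_handle_cache_ when the node has no converted operator.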
2450 void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) {
2451   if (cur_while_node_ == nullptr) {
2452     graph_outputs_.clear();
2453     std::vector<AnfNodePtr> return_nodes;
2454     auto ret_node = anf_graph_->get_return();
2455     MS_EXCEPTION_IF_NULL(ret_node);
2456     auto output_nodes = ret_node->inputs();
2457     if (has_es_node_) {
2458       return_nodes = GetEmbeddingApplyAdamOutput(ret_node);
2459     } else if (((!HasSubgraph(anf_graph_) && is_main_graph)) ||
2460                (output_nodes.size() > 1 && IsESNodeWithNoOutput(output_nodes[1]))) {
2461       // replace return node with graph output node.
2462       return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
2463     } else {
2464       return_nodes.emplace_back(ret_node);
2465     }
2466     for (const auto &output_node : return_nodes) {
2467       MS_EXCEPTION_IF_NULL(output_node);
2468       auto adpt = FindAdapter(output_node, training_);
2469       MS_EXCEPTION_IF_NULL(adpt);
2470       auto op_ptr = Convert(output_node);
2471       std::vector<OutHandler> handles;
2472       if (op_ptr != nullptr) {
2473         handles = adpt->getOutputs(op_ptr);
2474       } else if (tuple_out_handle_cache_.count(output_node.get()) > 0) {
2475         handles = *tuple_out_handle_cache_[output_node.get()];
2476       } else {
2477         MS_LOG(EXCEPTION) << "Can not find matched handles for node " << output_node->ToString();
2478       }
2479 
2480       for (const auto &handle : handles) {
2481         (void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
2482       }
2483     }
2484   }
2485 
2486   MS_LOG(INFO) << "Set graph " << anf_graph_->ToString() << " output, num: " << graph_outputs_.size();
2487   for (size_t i = 0; i < graph_outputs_.size(); i++) {
2488     MS_LOG(INFO) << "Graph output " << i << ": node: " << graph_outputs_[i].first.GetName()
2489                  << ", out: " << graph_outputs_[i].second;
2490   }
2491 }
2492 
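// For a Parameter converted to a Const/Variable whose format differs from the default/NCHW,
// rebuild the GE tensor descriptor from its default tensor and update the op's output descriptor.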
2493 void DfGraphConvertor::UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
2494   if (!it->isa<Parameter>()) {
2495     MS_LOG(DEBUG) << "It is not parameter, name: " << it->DebugString();
2496     return;
2497   }
2498   auto para = it->cast<ParameterPtr>();
2499   MS_EXCEPTION_IF_NULL(para);
2500   std::string format = SelectParamOriFormat(graph_manager_, it);
2501   std::string param_debug_info = para->DebugString();
2502   auto param_format = param_format_.find(param_debug_info);
2503   if (param_format != param_format_.end()) {
2504     format = param_format->second;
2505     MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
2506   }
2507   if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
2508     MS_LOG(DEBUG) << "Format is not changed, no need to update op desc, name: " << param_debug_info;
2509     return;
2510   }
2511   if (!para->has_default()) {
2512     MS_LOG(DEBUG) << "Parameter has no default, no need to update op desc, name: " << param_debug_info;
2513     return;
2514   }
2515   auto value = para->default_param();
2516   MS_EXCEPTION_IF_NULL(value);
2517   auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
2518   MS_EXCEPTION_IF_NULL(tensor);
2519   auto const_op_desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
2520   StorageFormatConvertor::SetupStorageFormat(anf_graph_, it, const_op_desc, format);
2521   if (const_op_desc == nullptr) {
2522     MS_LOG(WARNING) << "Create parameter " << para->name() << " output descriptor failed!";
2523     return;
2524   }
2525   (void)op->UpdateOutputDesc(kTypeY, *const_op_desc);
2526 }
2527 
2528 void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
2529   auto node = std::static_pointer_cast<AnfNode>(it);
2530   MS_EXCEPTION_IF_NULL(node);
2531   if (node == nullptr) {
2532     MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
2533     return;
2534   }
2535   std::vector<int64_t> shape;
2536   if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
2537     shape = normal_shape_ptr->shape();
2538   } else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
2539     shape = {};
2540   } else {
2541     MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
2542     return;
2543   }
2544   if (node->Type() == nullptr) {
2545     MS_LOG(INFO) << "Invalid type to update data op descriptor.";
2546     return;
2547   }
2548   TypeId me_type = node->Type()->type_id();
2549   if (kObjectTypeTensorType == me_type) {
2550     me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
2551   }
2552   std::ostringstream buf;
2553   buf << "[" << shape << "]";
2554   MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
2555   std::string format = SelectParamOriFormat(graph_manager_, it);
2556   if (it->isa<Parameter>()) {
2557     auto param = it->cast<ParameterPtr>();
2558     MS_EXCEPTION_IF_NULL(param);
2559     std::string param_name = param->DebugString();
2560     auto param_format = param_format_.find(param_name);
2561     if (param_format != param_format_.end()) {
2562       format = param_format->second;
2563       MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
2564     }
2565   }
2566   auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
2567   StorageFormatConvertor::SetupStorageFormat(anf_graph_, it, desc, format);
2568   if (desc == nullptr) {
2569     MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
2570   } else {
2571     (void)op->UpdateInputDesc(kTypeX, *desc);
2572     (void)op->UpdateOutputDesc(kTypeY, *desc);
2573   }
2574 }
2575 
2576 DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; }
2577 
2578 DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; }
2579 
2580 DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; }
2581 
2582 DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; }
2583 
2584 const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)};
2585 
2586 AnfNodePtr DfGraphConvertor::ParseLoadInput(const CNodePtr &cnode) const {
2587   MS_EXCEPTION_IF_NULL(cnode);
2588   size_t min_inputs_size = 3;
2589   if (cnode->size() < min_inputs_size) {
2590     MS_LOG(EXCEPTION) << "input size error, " << cnode->ToString();
2591   }
2592   const size_t para_index = 1;
2593   return cnode->input(para_index);
2594 }
2595 
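// In inference graphs (not training, not a subgraph), a parameter that was converted to a Const/Constant op
// but is written by Assign/AssignAdd/AssignSub must be writable on the GE side: replace the cached op with a
// Variable (or a RefData when ref_mode_ is enabled) that reuses the constant's output descriptor.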
2596 void DfGraphConvertor::TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred) {
2597   // transform "Const" op to "Variable" op when the next node is "Assign" op.
2598   std::string c_name = GetCNodeTargetFuncName(node);
2599   auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
2600   if (!training_ && !IsSubGraph() && pos != trans_var_list.end() && pred->isa<Parameter>()) {
2601     std::string name = std::static_pointer_cast<Parameter>(pred)->name();
2602     auto op_itor = op_cache_.find(pred.get());
2603     if (op_itor == op_cache_.end()) {
2604       MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
2605     }
2606     if (op_itor->second != nullptr &&
2607         (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
2608         vars_.find(name) != vars_.end()) {
2609       MS_EXCEPTION_IF_NULL(vars_[name]);
2610       if (ref_mode_) {
2611         auto variable = std::make_shared<RefData>(name);
2612         MS_EXCEPTION_IF_NULL(variable);
2613         auto desc = vars_[name]->GetOutputDesc(kTypeY);
2614         (void)variable->update_output_desc_y(desc);
2615         (void)variable->update_input_desc_x(desc);
2616         (void)variable->set_attr_index(ref_datas_.size());
2617         (void)ref_datas_.emplace_back(variable);
2618         MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
2619         op_itor->second = variable;  // replace parameter with variable
2620         vars_[name] = variable;
2621       } else {
2622         auto variable = std::make_shared<Variable>(name);
2623         MS_EXCEPTION_IF_NULL(variable);
2624         auto desc = vars_[name]->GetOutputDesc(kTypeY);
2625         (void)variable->update_output_desc_y(desc);
2626         MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
2627         op_itor->second = variable;  // replace parameter with variable
2628         vars_[name] = variable;
2629       }
2630     }
2631   }
2632 }
2633 
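// Resolve the effective data-producing input of a CNode: walk through Depend chains, return nullptr for
// monad/None/UpdateState inputs, unwrap Load to its real parameter, and apply the Const->Variable transformation.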
2634 AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
2635   if (input == nullptr || node == nullptr) {
2636     return nullptr;
2637   }
2638   AnfNodePtr pred = input;
2639   while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
2640     pred = pred->cast<CNodePtr>()->input(1);
2641   }
2642 
2643   // skip input of UMonad, IOMonad
2644   if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
2645     return nullptr;
2646   }
2647   if (HasAbstractMonad(pred)) {
2648     return nullptr;
2649   }
2650 
2651   // skip input of None, UpdateState
2652   if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
2653     return nullptr;
2654   }
2655 
2656   if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
2657     pred = ParseLoadInput(pred->cast<CNodePtr>());
2658     // for scenario like: Depend->Load->TensorMove
2659     if (IsPrimitiveCNode(pred, prim::kPrimDepend)) {
2660       return GetRealInputNode(node, pred);
2661     }
2662   }
2663   TransformConstOp(node, pred);
2664   return pred;
2665 }
2666 
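// Decide whether the edge (input -> node) should become a GE data edge. Monads, None, UpdateState, NoOp targets,
// the extra inputs of Load/Depend/TupleGetItem, Receive inputs and overflow-status nodes are excluded; those
// edges are handled as control edges (or dropped) by the callers.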
2667 bool DfGraphConvertor::IsDataInput(const AnfNodePtr &node, const AnfNodePtr &input, size_t input_index) {
2668   if (node == nullptr || input == nullptr) {
2669     MS_LOG(ERROR) << "Node or input is null.";
2670     return false;
2671   }
2672   // Ignore the null ValueTuple in MakeTuple
2673   if (IsMakeTupleWithNullValue(node, input)) {
2674     return false;
2675   }
2676 
2677   // skip NoOp
2678   auto op = Convert(node);
2679   if (op != nullptr && op->GetOpType() == kTypeNoOp) {
2680     return false;
2681   }
2682 
2683   // skip input of UMonad, IOMonad
2684   if (IsMonad(input)) {
2685     return false;
2686   }
2687 
2688   // skip input of None, UpdateState
2689   if (IsValueNode<None>(input) || IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
2690     return false;
2691   }
2692 
2693   const PrimitiveSet has_control_node = {prim::kPrimLoad, prim::kPrimDepend, prim::kPrimTupleGetItem};
2694   if (input_index != kDataInputIndex && IsOneOfPrimitiveCNode(node, has_control_node)) {
2695     return false;
2696   }
2697 
2698   // Ge Operator of HcomReceive has no input.
2699   if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
2700     return false;
2701   }
2702 
2703   // The NPUClearFloatStatusV2 of GE has no input and output, and the NPUGetFloatStatusV2 has no input.
2704   // The extra data edges of MindSpore need to be converted to control edges of GE.
2705   if (IsOverFlowNode(node, input)) {
2706     return false;
2707   }
2708 
2709   if (IsESNodeWithNoOutput(input)) {
2710     return false;
2711   }
2712 
2713   return true;
2714 }
2715 
2716 OutHandler DfGraphConvertor::GetNormalOpInput(const AnfNodePtr &node, const AnfNodePtr &pred) {
2717   MS_EXCEPTION_IF_NULL(node);
2718   MS_EXCEPTION_IF_NULL(pred);
2719   OutHandler out_handler;
2720   if (IsSubGraph() && pred->isa<Parameter>()) {
2721     auto idx = std::find(inputs_.begin(), inputs_.end(), pred) - inputs_.begin();
2722     OperatorPtr op = subgraph_input_cache_[idx];
2723     out_handler.op = op;
2724     return out_handler;
2725   }
2726 
2727   if (IsAfterGraph() && pred->isa<Parameter>()) {
2728     auto idx = std::find(inputs_.begin(), inputs_.end(), pred) - inputs_.begin();
2729     auto idx_cond = prev_after_cond_map_[idx];
2730     if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
2731       out_handler = bypass_node_prev_handle_cache_[idx_cond];
2732     } else {
2733       auto idx_out = prev_cond_to_while_out_index_[idx_cond];
2734       MS_EXCEPTION_IF_NULL(while_output_handle_cache_[prev_while_node_]);
2735       out_handler = while_output_handle_cache_[prev_while_node_]->at(idx_out);
2736     }
2737     return out_handler;
2738   }
2739 
2740   if (out_handle_cache_.find(pred.get()) != out_handle_cache_.end()) {
2741     return out_handle_cache_[pred.get()];
2742   }
2743   auto op = Convert(pred);
2744   if (op == nullptr) {
2745     MS_LOG(WARNING) << "Convert input node failed, input node: " << pred->fullname_with_scope()
2746                     << ", node: " << node->fullname_with_scope() << ", graph: " << anf_graph_->ToString()
2747                     << ". Please check whether the node is Partial node or successor node of Partial in sub-graph.";
2748   }
2749   out_handler.op = op;
2750   out_handler.node = pred;
2751   return out_handler;
2752 }
2753 
2754 void DfGraphConvertor::DrawOpInput(const AnfNodePtr &node, const AnfNodePtr &pred, size_t i) {
2755   MS_EXCEPTION_IF_NULL(pred);
2756   if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == mindspore::kTupleGetItemOpName) {
2757     MS_EXCEPTION_IF_NULL(pred->cast<CNodePtr>());
2758     MS_EXCEPTION_IF_NULL(pred->cast<CNodePtr>()->input(1));
2759     compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] << ":"
2760                   << i << endl;
2761   } else if (pred->isa<Parameter>()) {
2762     compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
2763   }
2764   return;
2765 }
2766 
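// Collect the output handles of `input` that feed `node`. Handles come from the tuple cache, from the graph
// outputs when the input is a While node, or from the op adapter (all outputs for dynamic/multi-output ops,
// a single handle otherwise). For TupleGetItem only the indexed handle is returned.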
2767 std::vector<OutHandler> DfGraphConvertor::GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input) {
2768   MS_EXCEPTION_IF_NULL(node);
2769   MS_EXCEPTION_IF_NULL(input);
2770   std::vector<OutHandler> handles;
2771   auto cache_ret = tuple_out_handle_cache_.find(input.get());
2772   if (cache_ret != tuple_out_handle_cache_.end()) {
2773     handles = *(cache_ret->second);
2774   } else if (IsWhileNode(input)) {
2775     // A While node in a subgraph is not converted.
2776     // The output handles of a While node are inconsistent with MS.
2777     MS_LOG(WARNING) << "Input node is while node, input node: " << input->fullname_with_scope()
2778                     << ", node: " << node->fullname_with_scope() << ", graph: " << anf_graph_->ToString();
2779     std::transform(graph_outputs_.begin(), graph_outputs_.end(), std::back_inserter(handles), [](const auto output) {
2780       return OutHandler(std::make_shared<::ge::Operator>(output.first), output.second);
2781     });
2782   } else {
2783     auto pred_adpt = FindAdapter(input, training_);
2784     MS_EXCEPTION_IF_NULL(pred_adpt);
2785     // When the node's output is dynamic or the node has multiple outputs, all handles need to be fetched.
2786     // TupleGetItem's input is a dynamic output (e.g. MakeTuple), but only one handle needs to be fetched.
2787     if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp(input))) {
2788       MS_EXCEPTION_IF_NULL(Convert(input));
2789       handles = pred_adpt->getOutputs(Convert(input));
2790     } else {
2791       auto handle = GetNormalOpInput(node, input);
2792       if (handle.op != nullptr) {
2793         handles.emplace_back(handle);
2794       }
2795     }
2796   }
2797 
2798   if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
2799     std::vector<OutHandler> return_handles;
2800     CNodePtr cnode = node->cast<CNodePtr>();
2801     MS_EXCEPTION_IF_NULL(cnode);
2802     size_t tuplegetitem_idx = common::AnfAlgo::GetTupleGetItemOutIndex(cnode);
2803     if (tuplegetitem_idx >= handles.size()) {
2804       MS_LOG(EXCEPTION) << "Node output index " << tuplegetitem_idx << " is out of range [0," << handles.size()
2805                         << "), node: " << node->fullname_with_scope()
2806                         << ", input node: " << input->fullname_with_scope();
2807     } else {
2808       return_handles.emplace_back(handles[tuplegetitem_idx]);
2809       return return_handles;
2810     }
2811   }
2812 
2813   return handles;
2814 }
2815 
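// Flatten the data inputs of `from_node_input` into one handle vector and bind it to the first (dynamic) GE
// input of `node`; inputs that are not data inputs are attached as control edges instead.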
2816 void DfGraphConvertor::SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
2817                                                          const CNodePtr &from_node_input) {
2818   MS_EXCEPTION_IF_NULL(adpt);
2819   MS_EXCEPTION_IF_NULL(node);
2820   MS_EXCEPTION_IF_NULL(from_node_input);
2821   auto inputs = from_node_input->inputs();
2822   std::vector<OutHandler> handles;
2823   for (size_t i = 1; i < inputs.size(); i++) {
2824     auto input = inputs[i];
2825     if (!IsDataInput(from_node_input, input, i)) {
2826       SetNodeControlInput(node, input);
2827       continue;
2828     }
2829     TransformConstOp(from_node_input, input);
2830     auto input_handles = GetInputHandles(from_node_input, input);
2831     handles.insert(handles.end(), input_handles.begin(), input_handles.end());
2832     if (input_handles.empty()) {
2833       MS_LOG(INFO) << "input handles is empty, node: " << from_node_input->fullname_with_scope()
2834                    << ", input node: " << input->fullname_with_scope();
2835       continue;
2836     }
2837     AddGraphConstInput(input_handles[0].op);
2838     DrawOpInput(node, input, i);
2839   }
2840 
2841   auto ret = adpt->setInput(Convert(node), 1, std::make_shared<std::vector<OutHandler>>(handles));
2842   if (ret != SUCCESS) {
2843     MS_LOG(EXCEPTION) << "Set node input handle failed, node:" << node->fullname_with_scope();
2844   }
2845 }
2846 
2847 bool DfGraphConvertor::IsMergeOrSwitchLayerInput(const CNodePtr &node) const {
2848   auto manager = anf_graph_->manager();
2849   if (manager == nullptr) {
2850     auto new_manager = MakeManager({anf_graph_});
2851     MS_EXCEPTION_IF_NULL(new_manager);
2852     new_manager->AddFuncGraph(anf_graph_);
2853     anf_graph_->set_manager(new_manager);
2854     manager = new_manager;
2855   }
2856   auto node_users = manager->node_users()[node];
2857 
2858   return (node_users.size() == 1 && std::find_if(node_users.begin(), node_users.end(), [](const auto &node_user) {
2859                                       return IsPrimitiveCNode(node_user.first, prim::kPrimMerge) ||
2860                                              IsPrimitiveCNode(node_user.first, prim::kPrimSwitchLayer);
2861                                     }) != node_users.end());
2862 }
2863 
2864 void DfGraphConvertor::SetMakeTupleInput(const OpAdapterPtr &adpt, const CNodePtr &make_tuple_node) {
2865   MS_EXCEPTION_IF_NULL(adpt);
2866   MS_EXCEPTION_IF_NULL(make_tuple_node);
2867   MS_LOG(DEBUG) << "Set MakeTuple input handle: " << make_tuple_node->fullname_with_scope();
2868   // Skip a MakeTuple node before Merge: the two branches (true/false) must not be merged before Merge, which
2869   // would lead to an assign-stream error in GE. Also skip a MakeTuple node before switch_layer, since
2870   // switch_layer's inputs will be set in the control flow process.
2871   if (IsMergeOrSwitchLayerInput(make_tuple_node)) {
2872     MS_LOG(INFO) << "Skip make_tuple_node " << make_tuple_node->fullname_with_scope() << ", not set input handle.";
2873     return;
2874   }
2875   SetDynamicInputHandleByMultiInput(adpt, make_tuple_node, make_tuple_node);
2876 }
2877 
2878 void DfGraphConvertor::SetMergeInput(const OpAdapterPtr &adpt, const CNodePtr &merge_node) {
2879   MS_EXCEPTION_IF_NULL(adpt);
2880   MS_EXCEPTION_IF_NULL(merge_node);
2881   auto inputs = merge_node->inputs();
2882   if (inputs.size() != kMergeInputSize) {
2883     MS_LOG(EXCEPTION) << "Merge input size should be " << kMergeInputSize << ", but is " << inputs.size()
2884                       << ", node: " << merge_node->fullname_with_scope();
2885   }
2886   auto make_tuple = inputs[1];
2887   MS_EXCEPTION_IF_NULL(make_tuple);
2888   if (!IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) {
2889     MS_LOG(EXCEPTION) << "Merge input is not MakeTuple, but is " << make_tuple->fullname_with_scope()
2890                       << ", node: " << merge_node->fullname_with_scope();
2891   }
2892   SetDynamicInputHandleByMultiInput(adpt, merge_node, make_tuple->cast<CNodePtr>());
2893 }
2894 
2895 void DfGraphConvertor::SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input) {
2896   MS_EXCEPTION_IF_NULL(node);
2897   MS_EXCEPTION_IF_NULL(input);
2898   if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && input->isa<ValueNode>()) {
2899     return;
2900   }
2901   if (input->isa<Parameter>() && HasAbstractMonad(input)) {
2902     MS_LOG(DEBUG) << "Node input is monad node, do not add control edge. node:" << node->fullname_with_scope()
2903                   << ", input: " << input->ToString();
2904     return;
2905   }
2906   auto dst = Convert(node);
2907   MS_EXCEPTION_IF_NULL(dst);
2908   auto src = Convert(input);
2909   if (src != nullptr) {
2910     dst->AddControlInput(*src);
2911   }
2912 }
2913 
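// Check whether the adapter places a dynamic input before a normal input in the GE input layout. Fills
// ge_input_size with the total number of GE inputs and ge_input_to_ms_input with the GE-index -> MS-index map.
// Illustrative example (assuming a GE prototype whose dynamic input sits at index 0 and whose normal input sits
// at index 1, while MindSpore passes the normal input first): the map becomes {0 -> 1, 1 -> 0} and the function
// returns true, so the inputs are wired by SetDynamicInputBeforeNormalInput below.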
2914 bool DfGraphConvertor::IsDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, int *ge_input_size,
2915                                                        mindspore::HashMap<int, int> *ge_input_to_ms_input) {
2916   MS_EXCEPTION_IF_NULL(adpt);
2917   const auto &input_map = adpt->getInputMap();
2918   const auto &dyn_input_map = adpt->getDynInputMap();
2919 
2920   // If adpt has no dynamic input, return false.
2921   if (dyn_input_map.empty()) {
2922     return false;
2923   }
2924 
2925   // If dynamic input is behind the normal input, return false
2926   int min_dynamic_idx = std::numeric_limits<int>::max();
2927   int max_normal_idx = -1;
2928   for (const auto &iter : dyn_input_map) {
2929     int ms_order = iter.first - kIndex1;
2930     int ge_order = iter.second.index;
2931     min_dynamic_idx = std::min(min_dynamic_idx, ge_order);
2932     *ge_input_size = std::max(*ge_input_size, ge_order + 1);
2933     (*ge_input_to_ms_input)[ge_order] = ms_order;
2934   }
2935   for (const auto &iter : input_map) {
2936     int ms_order = iter.first - kIndex1;
2937     int ge_order = iter.second.index;
2938     max_normal_idx = std::max(max_normal_idx, ge_order);
2939     *ge_input_size = std::max(*ge_input_size, ge_order + 1);
2940     (*ge_input_to_ms_input)[ge_order] = ms_order;
2941   }
2942   if (min_dynamic_idx == std::numeric_limits<int>::max() || max_normal_idx == -1 || min_dynamic_idx > max_normal_idx) {
2943     return false;
2944   }
2945   return true;
2946 }
2947 
2948 void DfGraphConvertor::SetDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, const CNodePtr &node,
2949                                                         const std::vector<AnfNodePtr> &inputs, const int &ge_input_size,
2950                                                         const mindspore::HashMap<int, int> &ge_input_to_ms_input,
2951                                                         std::vector<int64_t> *dyn_input_sizes) {
2952   // If a dynamic input is ahead of the normal inputs, use 'create_dynamic_input_by_index_name' to create the
2953   // dynamic input; this function must be called before setting the normal inputs.
2954   OperatorPtr src = Convert(node);
2955   MS_EXCEPTION_IF_NULL(adpt);
2956   const auto &dyn_input_map = adpt->getDynInputMap();
2957   MS_EXCEPTION_IF_NULL(dyn_input_sizes);
2958   if (dyn_input_sizes->empty()) {
2959     *dyn_input_sizes = std::vector<int64_t>(ge_input_size, -1);
2960     for (const auto &iter : dyn_input_map) {
2961       dyn_input_sizes->at(iter.first - kIndex1) = 1;
2962     }
2963   }
2964   std::vector<int64_t> new_dyn_input_sizes(ge_input_size, -1);
2965   std::vector<int> ge_tensor_orders =
2966     GetGeTensorOrders(ge_input_to_ms_input, *dyn_input_sizes, ge_input_size, &new_dyn_input_sizes);
2967 
2968   std::vector<size_t> ms_control_inputs;
2969   for (size_t i = 1; i < inputs.size(); ++i) {
2970     if (HasAbstractMonad(inputs[i])) {
2971       ms_control_inputs.emplace_back(i);
2972     }
2973   }
2974 
2975   MS_LOG(INFO) << "Adjust the dyn input order and use create_dynamic_input_byindex_name for node: "
2976                << node->fullname_with_scope();
2977   // ge_input_idx: the real ge input order
2978   // ge_tensor_orders: the tensor input order
2979   // ge_input_to_ms_input: the relationship between ge input order and ms input order
2980   // new_dyn_input_sizes:  tensor size of dynamic input
2981   for (int ge_input_idx = 0; ge_input_idx < ge_input_size; ++ge_input_idx) {
2982     int ms_input_idx = ge_input_to_ms_input.at(ge_input_idx) + kIndex1;
2983     // ge_tensor_idx: the ge input idx of the unfolded mindspore inputs
2984     int ge_tensor_idx = ge_tensor_orders[ge_input_idx] + kIndex1;
2985     if (ge_tensor_idx >= static_cast<int>(inputs.size())) {
2986       MS_LOG(INFO) << "ge tensor index is more than ms inputs size, ge_tensor_idx:" << ge_tensor_idx
2987                    << ", input size: " << inputs.size();
2988       continue;
2989     }
2990     AnfNodePtr pred = inputs[ge_tensor_idx];
2991     MS_EXCEPTION_IF_NULL(pred);
2992     if (!IsDataInput(node, pred, ge_input_idx)) {
2993       SetNodeControlInput(node, pred);
2994       continue;
2995     }
2996     auto handles = GetInputHandles(node, pred);
2997     if (handles.empty()) {
2998       MS_LOG(INFO) << "Input handles is empty, input node: " << pred->fullname_with_scope()
2999                    << ", node: " << node->fullname_with_scope() << ", index: " << ms_input_idx;
3000       continue;
3001     }
3002     int ret;
3003     int64_t dyn_input_num = new_dyn_input_sizes[ge_input_idx];
3004     if (dyn_input_num != -1) {
3005       for (size_t dyn_input_idx = 1; dyn_input_idx < LongToSize(dyn_input_num); ++dyn_input_idx) {
3006         auto dyn_input_handle = GetInputHandles(node, inputs[ge_tensor_idx + dyn_input_idx]);
3007         handles.insert(handles.end(), dyn_input_handle.begin(), dyn_input_handle.end());
3008       }
3009       size_t dyn_input_begin_idx = 0;
3010       for (size_t i = 0; i < IntToSize(ge_input_idx); ++i) {
3011         dyn_input_begin_idx += new_dyn_input_sizes[i] == -1 ? 1 : LongToSize(new_dyn_input_sizes[i]);
3012       }
3013       ret = adpt->setInput(src, SizeToInt(ms_input_idx), std::make_shared<std::vector<OutHandler>>(handles), true,
3014                            dyn_input_begin_idx);
3015     } else {
3016       if (handles.size() != 1 && pred->isa<ValueNode>()) {
3017         handles.clear();
3018         auto handle = GetNormalOpInput(node, pred);
3019         handles.emplace_back(handle);
3020       }
3021       if (handles.size() != 1) {
3022         MS_LOG(EXCEPTION) << "Input handles size " << handles.size() << " is not equal to 1, "
3023                           << node->fullname_with_scope() << ", input node: " << pred->fullname_with_scope()
3024                           << ", index: " << ms_input_idx;
3025       }
3026       ret = adpt->setInput(src, SizeToInt(ms_input_idx), handles[0]);
3027     }
3028     if (ret != SUCCESS) {
3029       MS_LOG(DEBUG) << "Set node input handle failed, node:" << node->fullname_with_scope()
3030                     << ", input node: " << pred->fullname_with_scope() << ", index: " << ms_input_idx;
3031     } else {
3032       DrawOpInput(node, pred, ge_input_idx);
3033       AddGraphConstInput(handles[0].op);
3034     }
3035   }
3036 
3037   for (size_t ms_control_input : ms_control_inputs) {
3038     AnfNodePtr pred = inputs[ms_control_input];
3039     SetNodeControlInput(node, pred);
3040   }
3041 
3042   // Set input from attr.
3043   SetOpAttrToInput(adpt, node);
3044   return;
3045 }
3046 
3047 void DfGraphConvertor::AddInputAttrsForESNode(const CNodePtr &node, const AnfNodePtr &input) {
3048   const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> es_need_add_attr = {
3049     prim::kPrimInitPartitionMap,     prim::kPrimInitEmbeddingHashmap,      prim::kPrimEmbeddingTableImport,
3050     prim::kPrimEmbeddingTableExport, prim::kPrimEmbeddingComputeVarImport, prim::kPrimEmbeddingComputeVarExport,
3051     prim::kPrimEmbeddingApplyAdam,   prim::kPrimEmbeddingApplyAdamW,       prim::kPrimEmbeddingApplyAdaGrad,
3052     prim::kPrimEmbeddingApplyFtrl,
3053   };
3054   if (!IsOneOfPrimitiveCNode(node, es_need_add_attr)) {
3055     return;
3056   }
3057   auto real = GetRealInputNode(node, input);
3058   MS_EXCEPTION_IF_NULL(real);
3059   auto op = Convert(real);
3060   MS_EXCEPTION_IF_NULL(op);
3061   if (!real->isa<ValueNode>()) {
3062     return;
3063   }
3064   (void)op->SetAttr(kProcessNodeEngineID, "PS");
3065 }
3066 
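// Main input-wiring entry for a converted CNode. Branch and call nodes take their inputs from the prepared
// handle caches, MakeTuple/Merge have dedicated paths, and ops whose dynamic input precedes a normal input are
// delegated to SetDynamicInputBeforeNormalInput. Otherwise each input is set as a data edge (packing dynamic
// inputs according to kAttrDynInputSizes) or attached as a control edge.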
3067 void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
3068   MS_EXCEPTION_IF_NULL(adpt);
3069   MS_EXCEPTION_IF_NULL(node);
3070   OperatorPtr src = Convert(node);
3071   bool branch_flag = false;
3072   auto &inputs = node->inputs();
3073   size_t input_size = inputs.size();
3074   if (branch_input_handle_cache_.find(node.get()) != branch_input_handle_cache_.end()) {
3075     branch_flag = true;
3076     MS_EXCEPTION_IF_NULL(branch_input_handle_cache_[node.get()]);
3077     input_size = branch_input_handle_cache_[node.get()]->size() + 1;
3078   } else if (!IsSubGraph() && call_input_handle_cache_.find(node) != call_input_handle_cache_.end()) {
3079     auto &handles = call_input_handle_cache_[node];
3080     MS_EXCEPTION_IF_NULL(handles);
3081     MS_LOG(DEBUG) << "call node input size: " << handles->size();
3082     adpt->setInput(src, 1, handles);
3083     return;
3084   }
3085 
3086   MS_LOG(DEBUG) << "Set op input for node: " << node->fullname_with_scope();
3087   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
3088     SetMakeTupleInput(adpt, node);
3089     return;
3090   }
3091 
3092   if (IsPrimitiveCNode(node, prim::kPrimMerge)) {
3093     SetMergeInput(adpt, node);
3094     return;
3095   }
3096   bool is_call = IsCallNode(node);
3097   std::vector<int64_t> dyn_input_sizes;
3098   if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node)) {
3099     dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrDynInputSizes);
3100   }
3101 
3102   int ge_input_size = 1;
3103   mindspore::HashMap<int, int> ge_input_to_ms_input;
3104   if (IsDynamicInputBeforeNormalInput(adpt, &ge_input_size, &ge_input_to_ms_input)) {
3105     SetDynamicInputBeforeNormalInput(adpt, node, inputs, ge_input_size, ge_input_to_ms_input, &dyn_input_sizes);
3106     return;
3107   }
3108   // For call node, the first input is kernel_graph, which should not be added to input args.
3109   size_t input_idx = is_call ? 2 : 1;
3110   size_t real_input_idx = 1;
3111   while (input_idx < input_size) {
3112     AnfNodePtr pred = branch_flag ? branch_input_handle_cache_[node.get()]->at(input_idx - 1) : inputs[input_idx];
3113     MS_EXCEPTION_IF_NULL(pred);
3114     if (!IsDataInput(node, pred, real_input_idx)) {
3115       SetNodeControlInput(node, pred);
3116       input_idx += 1;
3117       real_input_idx += 1;
3118       continue;
3119     }
3120     TransformConstOp(node, pred);
3121     auto handles = GetInputHandles(node, pred);
3122     if (handles.empty()) {
3123       MS_LOG(INFO) << "Input handles is empty, input node: " << pred->fullname_with_scope()
3124                    << ", node: " << node->fullname_with_scope() << ", index: " << real_input_idx;
3125       input_idx += 1;
3126       real_input_idx += 1;
3127       continue;
3128     }
3129 
3130     int ret;
3131     int64_t dyn_input_num = GetDynInputNum(adpt, is_call, dyn_input_sizes, real_input_idx, input_size, node);
3132     if (dyn_input_num != -1) {
3133       for (size_t dyn_input_idx = 1; dyn_input_idx < LongToSize(dyn_input_num); ++dyn_input_idx) {
3134         auto dyn_input_handle = GetInputHandles(node, inputs[input_idx + dyn_input_idx]);
3135         handles.insert(handles.end(), dyn_input_handle.begin(), dyn_input_handle.end());
3136       }
3137       ret = adpt->setInput(src, SizeToInt(real_input_idx), std::make_shared<std::vector<OutHandler>>(handles));
3138       input_idx += LongToSize(dyn_input_num);
3139     } else {
3140       if (handles.size() != 1 && pred->isa<ValueNode>()) {
3141         handles.clear();
3142         auto handle = GetNormalOpInput(node, pred);
3143         handles.emplace_back(handle);
3144       }
3145       if (handles.size() != 1) {
3146         MS_LOG(EXCEPTION) << "Input handles size " << handles.size() << " is not equal to 1, "
3147                           << node->fullname_with_scope() << ", input node: " << pred->fullname_with_scope()
3148                           << ", index: " << real_input_idx;
3149       }
3150       ret = adpt->setInput(src, SizeToInt(real_input_idx), handles[0]);
3151       input_idx += 1;
3152     }
3153     if (ret != SUCCESS) {
3154       MS_LOG(DEBUG) << "Set node input handle failed, node:" << node->fullname_with_scope()
3155                     << ", input node: " << pred->fullname_with_scope() << ", index: " << real_input_idx;
3156     } else {
3157       DrawOpInput(node, pred, real_input_idx);
3158       AddGraphConstInput(handles[0].op);
3159     }
3160     AddInputAttrsForESNode(node, pred);
3161     real_input_idx += 1;
3162   }
3163   // Set input from attr.
3164   SetOpAttrToInput(adpt, node);
3165 }
3166 
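// For adapters that map MindSpore attributes to GE inputs: when the counts line up, materialize each attribute
// value as a new ValueNode, convert it, and plug the resulting const into the GE input with the matching name.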
3167 void DfGraphConvertor::SetOpAttrToInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
3168   OperatorPtr src = Convert(node);
3169   auto &inputs = node->inputs();
3170   size_t input_size = inputs.size();
3171   const auto &primitive = GetCNodePrimitive(node);
3172   MS_EXCEPTION_IF_NULL(primitive);
3173   const auto monad_size = std::count_if(inputs.begin() + kIndex1, inputs.end(), [](const AnfNodePtr &input) {
3174     return input->isa<ValueNode>() && HasAbstractMonad(input);
3175   });
3176   const auto &attr_input_map = adpt->getAttrInputMap();
3177   const auto &input_map = adpt->getInputMap();
3178   if (input_map.size() != attr_input_map.size() + input_size - monad_size - kIndex1) {
3179     MS_LOG(DEBUG) << "For node: " << node->DebugString()
3180                   << ", the size of real input:" << input_size - monad_size - kIndex1
3181                   << " + the size of attr_input_map: " << attr_input_map.size()
3182                   << " != the size of input_map:" << input_map.size()
3183                   << ", so do not convert input from attr any more.";
3184     return;
3185   }
3186   MS_EXCEPTION_IF_NULL(anf_graph_);
3187   for (auto &it : attr_input_map) {
3188     // Get attr from node.
3189     auto value = primitive->GetAttr(it.first);
3190     if (value == nullptr) {
3191       MS_LOG(INFO) << "Node: " << node->DebugString() << " has no attr: " << it.first;
3192       continue;
3193     }
3194     // Create input node for attr value.
3195     auto input_node = NewValueNode(value);
3196     input_node->set_abstract(value->ToAbstract());
3197     anf_graph_->manager()->AddEdge(node, input_node);
3198     auto new_input_op = Convert(input_node);
3199     // Get input desc.
3200     auto input_name = it.second;
3201     auto input_desc = std::find_if(input_map.begin(), input_map.end(),
3202                                    [input_name](const auto &item) { return item.second.name == input_name; });
3203     if (input_desc == input_map.end()) {
3204       MS_LOG(WARNING) << "Node: " << node->DebugString() << " has no input :" << input_name;
3205       continue;
3206     }
3207     MS_LOG(INFO) << "Set input from attr:" << it.first << " for node: " << node->DebugString()
3208                  << ", new value node:" << input_node->DebugString();
3209     input_desc->second.set_op(src, new_input_op);
3210     // Input idx may be wrong.
3211     DrawOpInput(node, input_node, static_cast<size_t>(input_desc->first));
3212     AddGraphConstInput(new_input_op);
3213   }
3214 }
3215 
3216 void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) {
3217   if (op == nullptr) {
3218     return;
3219   }
3220   if (IsSubGraph()) {
3221     return;
3222   }
3223 
3224   if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") {
3225     graph_const_inputs_.emplace_back(op);
3226   }
3227 }
3228 
3229 void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
3230   if (!node->isa<CNode>()) {
3231     return;
3232   }
3233   if (op_cache_.find(node.get()) == op_cache_.end()) {
3234     return;
3235   }
3236   auto cnode = node->cast<CNodePtr>();
3237   MS_EXCEPTION_IF_NULL(cnode);
3238   OpAdapterPtr adpt = FindAdapter(cnode, training_);
3239   if (adpt == nullptr) {
3240     error_ = NOT_FOUND;
3241     return;
3242   }
3243 
3244   // get Operator from op_cache_, use adapter to set Inputs
3245   DfGraphConvertor::SetOpInput(adpt, cnode);
3246 }
3247 
3248 std::string DfGraphConvertor::GetGNodeName(const ::ge::GNode &node) const {
3249   ::ge::AscendString name;
3250   auto ret = node.GetName(name);
3251   if (ret == ::ge::GRAPH_SUCCESS) {
3252     return std::string(name.GetString());
3253   } else {
3254     MS_LOG(WARNING) << "Get GNode name failed, ret: " << ret;
3255     return std::string();
3256   }
3257 }
3258 
3259 std::string DfGraphConvertor::GetGNodeType(const ::ge::GNode &node) const {
3260   ::ge::AscendString node_type;
3261   auto ret = node.GetType(node_type);
3262   if (ret == ::ge::GRAPH_SUCCESS) {
3263     return std::string(node_type.GetString());
3264   } else {
3265     MS_LOG(WARNING) << "Get GNode type failed, ret: " << ret;
3266     return std::string();
3267   }
3268 }
3269 
3270 // 1) Identity or IdentityN is the input of Merge, not delete
3271 // 2) Identity or IdentityN is the subgraph(If) input, not delete
3272 // 3) Identity or IdentityN is the output, not delete
3273 // 4) Identity or IdentityN has multiple users, not delete
3274 // 5) Nodes with control edges, temporarily not delete
3275 bool DfGraphConvertor::IsIdentityRedundant(const ::ge::GNode &node) const {
3276   auto node_type = GetGNodeType(node);
3277   if (node_type != kTypeIdentityN && node_type != kTypeIdentity) {
3278     MS_LOG(DEBUG) << "Node is not Identity or IdentityN, but is " << node_type << ", node name: " << GetGNodeName(node);
3279     return false;
3280   }
3281 
3282   auto node_name = GetGNodeName(node);
3283   auto ret = std::find_if(graph_outputs_.begin(), graph_outputs_.end(),
3284                           [&node_name](const auto &output) { return output.first.GetName() == node_name; });
3285   if (ret != graph_outputs_.end()) {
3286     return false;
3287   }
3288 
3289   for (size_t output_index = 0; output_index < node.GetOutputsSize(); output_index++) {
3290     auto output_nodes = node.GetOutDataNodesAndPortIndexs(static_cast<int32_t>(output_index));
3291     if (!output_nodes.empty() && has_es_node_) {
3292       return true;
3293     }
3294     if (output_nodes.size() != 1) {
3295       return false;
3296     }
3297 
3298     auto output_node_type = GetGNodeType(*(output_nodes.begin()->first));
3299     if (output_node_type == kTypeMerge || output_node_type == kTypeIf) {
3300       return false;
3301     }
3302   }
3303 
3304   if (!node.GetOutControlNodes().empty()) {
3305     return false;
3306   }
3307 
3308   return true;
3309 }
3310 
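// Bypass a redundant Identity/IdentityN: for every output port, reconnect the producer directly to the consumer,
// re-attach the incoming control edges to the consumer, and finally remove the identity node from the GE graph.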
3311 void DfGraphConvertor::RemoveIdentity(::ge::GNode identity_node) {
3312   MS_LOG(INFO) << "Start Remove Identity or IdentityN, identity_node: " << GetGNodeName(identity_node);
3313   auto node_type = GetGNodeType(identity_node);
3314   if (node_type != kTypeIdentity && node_type != kTypeIdentityN) {
3315     MS_LOG(EXCEPTION) << "Node is not Identity or IdentityN, but is " << node_type
3316                       << ", identity_node name: " << GetGNodeName(identity_node);
3317     return;
3318   }
3319   if (identity_node.GetInputsSize() != identity_node.GetOutputsSize()) {
3320     MS_LOG(EXCEPTION) << "Node output size " << identity_node.GetOutputsSize() << " is not equal to input size "
3321                       << identity_node.GetInputsSize() << ", identity_node: " << GetGNodeName(identity_node);
3322     return;
3323   }
3324 
3325   ::ge::graphStatus ret;
3326   for (size_t output_index = 0; output_index < identity_node.GetOutputsSize(); output_index++) {
3327     auto output_nodes = identity_node.GetOutDataNodesAndPortIndexs(static_cast<int>(output_index));
3328     if (output_nodes.size() != 1 && !has_es_node_) {
3329       return;
3330     }
3331 
3332     // 1. Set identity_node data edge
3333     for (size_t i = 0; i < output_nodes.size(); i++) {
3334       auto node_output = output_nodes[i];
3335       auto input_index = output_index;
3336       auto node_input = identity_node.GetInDataNodesAndPortIndexs(static_cast<int32_t>(input_index));
3337       ret = df_graph_->RemoveEdge(identity_node, static_cast<int32_t>(output_index), *node_output.first,
3338                                   node_output.second);
3339       if (ret != ::ge::GRAPH_SUCCESS) {
3340         MS_LOG(EXCEPTION) << "Remove edge failed, src identity_node: " << GetGNodeName(identity_node)
3341                           << ", index: " << output_index << ", dst identity_node: " << GetGNodeName(*node_output.first)
3342                           << ", index: " << node_output.second << ", ret: " << ret;
3343         return;
3344       }
3345       ret = df_graph_->AddDataEdge(*node_input.first, node_input.second, *node_output.first, node_output.second);
3346       if (ret != ::ge::GRAPH_SUCCESS) {
3347         MS_LOG(EXCEPTION) << "Add data edge failed, src identity_node: " << GetGNodeName(*node_input.first)
3348                           << ", index: " << node_input.second
3349                           << ", dst identity_node: " << GetGNodeName(*node_output.first)
3350                           << ", index: " << node_output.second << ", ret: " << ret;
3351         return;
3352       }
3353 
3354       // 2. Set identity_node control edge
3355       auto node_control = identity_node.GetInControlNodes();
3356       for (const auto &item : node_control) {
3357         ret = df_graph_->AddControlEdge(*item, *node_output.first);
3358         if (ret != ::ge::GRAPH_SUCCESS) {
3359           MS_LOG(EXCEPTION) << "Add control edge failed, src identity_node: " << GetGNodeName(*item)
3360                             << ", dst identity_node: " << GetGNodeName(*node_output.first) << ", ret: " << ret;
3361           return;
3362         }
3363       }
3364     }
3365   }
3366 
3367   // 3. Remove identity
3368   ret = df_graph_->RemoveNode(identity_node);
3369   if (ret != ::ge::GRAPH_SUCCESS) {
3370     MS_LOG(EXCEPTION) << "Remove identity_node failed, identity_node: " << GetGNodeName(identity_node)
3371                       << ", ret: " << ret;
3372     return;
3373   }
3374 }
3375 
3376 void DfGraphConvertor::IdentityOptimization() {
3377   MS_LOG(INFO) << "Start IdentityOptimization, graph: " << anf_graph_->ToString();
3378   MS_EXCEPTION_IF_NULL(df_graph_);
3379   auto all_nodes = df_graph_->GetDirectNode();
3380   for (const auto &node : all_nodes) {
3381     if (IsIdentityRedundant(node)) {
3382       RemoveIdentity(node);
3383     }
3384   }
3385   MS_LOG(INFO) << "End IdentityOptimization, graph: " << anf_graph_->ToString();
3386 }
3387 
3388 void DfGraphConvertor::NoOpOptimization() {
3389   MS_LOG(INFO) << "Start NoOpOptimization, graph:" << anf_graph_->ToString();
3390   MS_EXCEPTION_IF_NULL(df_graph_);
3391   auto all_nodes = df_graph_->GetDirectNode();
3392   for (const auto &node : all_nodes) {
3393     if (IsNoOpRedundant(node)) {
3394       RemoveNoOp(node);
3395     }
3396   }
3397   MS_LOG(INFO) << "End NoOpOptimization, graph:" << anf_graph_->ToString();
3398 }
3399 
3400 void DfGraphConvertor::ESOptimization() {
3401   MS_LOG(INFO) << "Start ESOptimization, graph:" << anf_graph_->ToString();
3402   MS_EXCEPTION_IF_NULL(df_graph_);
3403   auto all_nodes = df_graph_->GetDirectNode();
3404   ::ge::GNode no_op;
3405   bool not_remove = false;
3406   for (const auto &node : all_nodes) {
3407     node.GetAttr(kAttrNotRemove, not_remove);
3408     if (not_remove) {
3409       no_op = node;
3410       break;
3411     }
3412   }
3413   if (not_remove) {
3414     auto output_control_node = no_op.GetOutControlNodes();
3415     if (output_control_node.empty()) {
3416       return;
3417     }
3418     RemoveIdentityForES(*output_control_node[0]);
3419   }
3420 }
3421 
3422 void DfGraphConvertor::RemoveIdentityForES(::ge::GNode node) {
3423   ::ge::graphStatus ret;
3424   auto out_control_node = node.GetOutControlNodes();
3425   for (size_t input_index = 0; input_index < node.GetInputsSize(); input_index++) {
3426     auto node_input = node.GetInDataNodesAndPortIndexs(static_cast<int32_t>(input_index));
3427     ret = df_graph_->RemoveEdge(*node_input.first, node_input.second, node, input_index);
3428     if (ret != ::ge::GRAPH_SUCCESS) {
3429       MS_LOG(EXCEPTION) << "Remove edge failed, src node: " << GetGNodeName(*node_input.first)
3430                         << ", index: " << node_input.second << ", dst identity_node: " << GetGNodeName(node)
3431                         << ", index: " << input_index << ", ret: " << ret;
3432       return;
3433     }
3434   }
3435   ret = df_graph_->RemoveNode(node);
3436   if (ret != ::ge::GRAPH_SUCCESS) {
3437     MS_LOG(EXCEPTION) << "Remove node failed, node: " << GetGNodeName(node);
3438   }
3439   if (out_control_node.empty()) {
3440     return;
3441   }
3442   auto output_node = out_control_node[0];
3443   MS_EXCEPTION_IF_NULL(output_node);
3444   RemoveIdentityForES(*output_node);
3445 }
3446 
3447 bool DfGraphConvertor::IsNoOpRedundant(const ::ge::GNode &node) const {
3448   auto node_type = GetGNodeType(node);
3449   if (node_type != kTypeNoOp) {
3450     return false;
3451   }
3452   if (!training_) {
3453     return true;
3454   }
3455 
3456   bool not_remove = false;
3457   node.GetAttr(kAttrNotRemove, not_remove);
3458   if (not_remove) {
3459     return false;
3460   }
3461 
3462   auto out_control_node = node.GetOutControlNodes();
3463   auto in_control_node = node.GetInControlNodes();
3464   if (out_control_node.size() == 1 || in_control_node.size() == 1) {
3465     return true;
3466   }
3467   if (out_control_node.size() > kNoOpOptThreshold || in_control_node.size() > kNoOpOptThreshold) {
3468     return false;
3469   }
3470   return true;
3471 }
3472 void DfGraphConvertor::RemoveNoOp(::ge::GNode noop) {
3473   MS_LOG(INFO) << "Start Remove NoOp, node:" << GetGNodeName(noop);
3474   auto node_type = GetGNodeType(noop);
3475   if (node_type != kTypeNoOp) {
3476     MS_LOG(EXCEPTION) << "Node is not NoOp, but is: " << GetGNodeName(noop);
3477   }
3478 
3479   auto in_control_nodes = noop.GetInControlNodes();
3480   auto out_control_nodes = noop.GetOutControlNodes();
3481   auto ret = df_graph_->RemoveNode(noop);
3482   if (ret != ::ge::GRAPH_SUCCESS) {
3483     MS_LOG(EXCEPTION) << "Remove node failed, node: " << GetGNodeName(noop);
3484   }
3485   for (auto src_node : in_control_nodes) {
3486     for (auto dst_node : out_control_nodes) {
3487       ret = df_graph_->AddControlEdge(*src_node, *dst_node);
3488       if (ret != ::ge::GRAPH_SUCCESS) {
3489         MS_LOG(EXCEPTION) << "Add control edge failed, src node: " << GetGNodeName(*src_node)
3490                           << ", dst node:" << GetGNodeName(*dst_node);
3491       }
3492     }
3493   }
3494   MS_LOG(INFO) << "End Remove Noop, node: " << GetGNodeName(noop);
3495 }
3496 
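// Convert a branch graph (the value node of an If/Case branch) with a nested converter. Each branch parameter is
// mapped to a Data (or Variable) op indexed by its position in the parent node's dynamic input, the sub-graph is
// built under a de-duplicated name, and the resulting GE graph is stored in branches_map_.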
3497 void DfGraphConvertor::ProcessSubgraph(const AnfNodePtr &node, const AnfNodePtr &branch_node,
3498                                        ParamIndexMap &branch_to_parent_node_map) {
3499   MS_LOG(INFO) << "ProcessSubgraph begin.";
3500   ValueNodePtr graph_node = nullptr;
3501   if (branch_node->isa<CNode>()) {
3502     graph_node = branch_node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
3503   } else if (branch_node->isa<ValueNode>()) {
3504     graph_node = branch_node->cast<ValueNodePtr>();
3505   } else {
3506     return;
3507   }
3508 
3509   MS_EXCEPTION_IF_NULL(graph_node);
3510   auto anf_graph = graph_node->value()->cast<AnfGraphPtr>();
3511   MS_EXCEPTION_IF_NULL(anf_graph);
3512   DfGraphConvertor converter(anf_graph, phase_prefix_);
3513   converter.graph_type_ = GraphType::kBranch;
3514 
3515   auto &params = anf_graph->parameters();
3516   if (ref_mode_) {
3517     for (size_t i = 0; i < params.size(); i++) {
3518       auto &param = params[i];
3519       if (branch_to_parent_node_map.find(i) != branch_to_parent_node_map.end()) {
3520         size_t parent_index = branch_to_parent_node_map[i];
3521         OperatorPtr op = nullptr;
3522         op = std::make_shared<Data>();
3523         MS_EXCEPTION_IF_NULL(op);
3524         SetXDataIndex(op, parent_index);
3525         converter.op_cache_[param.get()] = op;
3526       } else if (!HasAbstractMonad(param)) {
3527         MS_LOG(EXCEPTION) << "Branch graph input index to parent node dyn input index error, "
3528                           << "branch graph: " << anf_graph->ToString() << "'s " << i << "(st/nd/rd/th)"
3529                           << " input can not find the corresponding parent node input index.";
3530       }
3531     }
3532   } else {
3533     auto &dyn_input = branch_input_handle_cache_[node.get()];
3534     MS_EXCEPTION_IF_NULL(dyn_input);
3535     auto &inputs = tuple_out_handle_cache_[dyn_input->at(1).get()];
3536     MS_EXCEPTION_IF_NULL(inputs);
3537     for (size_t i = 0; i < params.size(); i++) {
3538       auto &param = params[i];
3539       if (branch_to_parent_node_map.find(i) != branch_to_parent_node_map.end()) {
3540         size_t parent_index = branch_to_parent_node_map[i];
3541         auto &parent_handle = inputs->at(parent_index);
3542         OperatorPtr op = nullptr;
3543         MS_EXCEPTION_IF_NULL(parent_handle.op);
3544         if (parent_handle.op->GetOpType() == kTypeVariable) {
3545           auto name = parent_handle.op->GetName();
3546           op = std::make_shared<Variable>(name);
3547           MS_EXCEPTION_IF_NULL(op);
3548           SetXDataIndex(op, parent_index);
3549         } else {
3550           op = std::make_shared<Data>();
3551           MS_EXCEPTION_IF_NULL(op);
3552           SetXDataIndex(op, parent_index);
3553         }
3554         converter.op_cache_[param.get()] = op;
3555       } else if (!HasAbstractMonad(param)) {
3556         MS_LOG(EXCEPTION) << "Branch graph input index to parent node dyn input index error, "
3557                           << "branch graph: " << anf_graph->ToString() << "'s " << i << "(st/nd/rd/th)"
3558                           << " input can not find the corresponding parent node input index.";
3559       }
3560     }
3561   }
3562 
3563   std::string graph_name = anf_graph->ToString();
3564   auto iter = branches_repeat_times.find(graph_name);
3565   if (iter == branches_repeat_times.end()) {
3566     branches_repeat_times[graph_name] = 1;
3567   } else {
3568     iter->second += 1;
3569     graph_name = graph_name + "_" + std::to_string(iter->second);
3570   }
3571   (void)converter.ConvertAllNode().BuildGraph(graph_name);
3572 #ifdef ENABLE_DUMP_IR
3573   std::string name = graph_node->ToString() + "_ge_graph.dot";
3574   auto context = MsContext::GetInstance();
3575   MS_EXCEPTION_IF_NULL(context);
3576   if (context->CanDump(kFully)) {
3577     converter.DrawComputeGraph(name);
3578   }
3579 #endif
3580   branches_map_[branch_node.get()] = *(converter.df_graph_);
3581   MS_LOG(INFO) << "ProcessSubgraph end.";
3582 }
3583 
3584 // Update GE op's shape and type info
3585 void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
3586   MS_EXCEPTION_IF_NULL(node);
3587   if (node == nullptr || !node->isa<CNode>()) {
3588     return;
3589   }
3590 
3591   if (op_cache_.find(node.get()) == op_cache_.end()) {
3592     return;
3593   }
3594 
3595   OpAdapterPtr adpt = FindAdapter(node, training_);
3596   if (adpt == nullptr) {
3597     error_ = NOT_FOUND;
3598     return;
3599   }
3600 
3601   // get Operator from op_cache_
3602   OperatorPtr op = Convert(node);
3603   MS_EXCEPTION_IF_NULL(op);
3604   std::string op_type = op->GetOpType();
3605   if (!IsNeedToUpdateTensorDesc(op_type, node)) {
3606     MS_LOG(INFO) << "No need to set the opDesc of node: " << node->fullname_with_scope() << ", op type is " << op_type;
3607     return;
3608   }
3609 
3610   adpt->updateOutputDesc(op, node->Shape(), node->Type(), node);
3611 }
3612 
3613 OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
3614   if (node == nullptr) {
3615     MS_LOG(ERROR) << "node is nullptr";
3616     error_ = NOT_FOUND;
3617     return nullptr;
3618   }
3619   // find in cache
3620   if (op_cache_.count(node.get()) != 0) {
3621     MS_LOG(DEBUG) << "Get op from cache: " << op_cache_[node.get()]->GetName();
3622     return op_cache_[node.get()];
3623   }
3624 
3625   // do not convert primitive node
3626   if (IsValueNode<Primitive>(node)) {
3627     return nullptr;
3628   }
3629   // convert a new one
3630   if (node->isa<CNode>()) {
3631     auto cnode = node->cast<CNodePtr>();
3632     if (IsSubGraph() && IsWhileNode(cnode)) {
3633       return nullptr;
3634     }
3635     if (!IsSubGraph() && IsWhileNode(cnode)) {
3636       CacheWhileGraph(cnode);
3637       auto &graphs = while_graph_cache_[cnode];
3638       GetWhileUsedInputIndex(graphs);
3639       SetParamIndexMap(graphs);
3640       cur_while_node_ = cnode;
3641     }
3642     return ConvertCNode(cnode);
3643   }
3644 
3645   if (node->isa<Parameter>() && IsSubGraph()) {
3646     return nullptr;
3647   }
3648 
3649   if (node->isa<Parameter>()) {
3650     return ConvertParameter(node);
3651   }
3652   if (node->isa<ValueNode>()) {
3653     if (IsValueNode<Monad>(node)) {
3654       return nullptr;
3655     }
3656     return ConvertValueNode(node->cast<ValueNodePtr>());
3657   }
3658 
3659   MS_LOG(ERROR) << "Invalid AnfNode";
3660   error_ = INVALID_ARGUMENT;
3661   return nullptr;
3662 }
3663 
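// GE TopK expects an int32 k on Ascend. When the second input is a const ValueNode, read k (tensor or scalar,
// int32 or int64), regenerate the value op and store k as its int32 "value" attribute; a mutable int64 k is
// rejected with an exception.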
3664 void DfGraphConvertor::ConvertTopK(const CNodePtr &node) {
3665   MS_EXCEPTION_IF_NULL(node);
3666   auto value_ptr = node->input(kIndex2)->cast<ValueNodePtr>();
3667   if (value_ptr == nullptr) {
3668     // The input is not a const ValueNode, so it cannot be converted to int32. Throw an exception when input k is
3669     // int64, since CANN has a precision problem; this check can be deleted once CANN supports int64 for input k.
3670     if (common::AnfAlgo::GetPrevNodeOutputInferDataType(node, kIndex1) == kNumberTypeInt64) {
3671       MS_LOG(EXCEPTION) << "Op TopK(" << node->fullname_with_scope() << ")'s second input k is an int64 mutable "
3672                         << "tensor/scalar, which is not supported in ascend, please use int32.";
3673     }
3674     return;
3675   }
3676   MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
3677   auto input_value = value_ptr->value();
3678   MS_EXCEPTION_IF_NULL(input_value);
3679   std::ostringstream ss;
3680   ss << "op" << value_ptr.get();
3681   op_draw_name_[value_ptr.get()] = ss.str();
3682   compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
3683   int32_t k_value;
3684   if (input_value->isa<tensor::Tensor>()) {
3685     auto input_tensor = input_value->cast<tensor::TensorPtr>();
3686     if (input_tensor->data_type() == kNumberTypeInt32) {
3687       k_value = *static_cast<int32_t *>(input_tensor->data_c());
3688     } else {
3689       k_value = LongToInt(*static_cast<int64_t *>(input_tensor->data_c()));
3690     }
3691   } else {
3692     k_value = LongToInt(GetValue<int64_t>(input_value));
3693   }
3694   OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
3695   MS_EXCEPTION_IF_NULL(adpt);
3696   auto op = adpt->generate(value_ptr);
3697   (void)adpt->setAttr(op, "value", k_value);
3698   op_cache_[value_ptr.get()] = op;
3699 }
3700 
3701 AnfNodePtr DfGraphConvertor::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const {
3702   auto func_graph = input->func_graph();
3703   MS_EXCEPTION_IF_NULL(func_graph);
3704   AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input,
3705                            NewValueNode(static_cast<int64_t>(dst_type->type_id()))};
3706   auto cnode = func_graph->NewCNode(inputs);
3707   MS_EXCEPTION_IF_NULL(cnode);
3708   auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dst_type, input->Shape());
3709   cnode->set_abstract(abs_tensor);
3710   return cnode;
3711 }
3712 
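// Normalize an integer ValuePtr (scalar or sequence, int32 or int64) into std::vector<int64_t>.
// E.g. an int32 scalar 2 becomes {2}, and an int64 tuple (2, 3) becomes {2, 3}.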
3713 std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) const {
3714   if (value == nullptr) {
3715     return {};
3716   }
3717   std::vector<int64_t> cur_value = {};
3718   if (utils::isa<ValueSequencePtr>(value)) {
3719     auto val_seq_ptr = value->cast<ValueSequencePtr>();
3720     MS_EXCEPTION_IF_NULL(val_seq_ptr);
3721     if (!val_seq_ptr->value().empty()) {
3722       auto first_val = val_seq_ptr->value().front();
3723       MS_EXCEPTION_IF_NULL(first_val);
3724       MS_EXCEPTION_IF_NULL(first_val->type());
3725       if (first_val->type()->number_type() == kNumberTypeInt64) {
3726         cur_value = GetValue<std::vector<int64_t>>(value);
3727       } else {
3728         auto origin_value = GetValue<std::vector<int>>(value);
3729         (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
3730                              [](int index) { return static_cast<int64_t>(index); });
3731       }
3732     }
3733   } else {
3734     MS_EXCEPTION_IF_NULL(value->type());
3735     if (value->type()->number_type() == kNumberTypeInt64) {
3736       cur_value.emplace_back(GetValue<int64_t>(value));
3737     } else {
3738       cur_value.emplace_back(static_cast<int64_t>(GetValue<int>(value)));
3739     }
3740   }
3741   return cur_value;
3742 }
3743 
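// For nodes listed in kTransInputDTypeMap, force the given input to the target data type: CNode/Parameter inputs
// get a Cast node inserted, and const inputs are rewritten to a new ValueNode holding the casted value.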
3744 void DfGraphConvertor::TransInputDataType(const CNodePtr &node, const std::string &node_name) const {
3745   auto iter = kTransInputDTypeMap.find(node_name);
3746   if (iter == kTransInputDTypeMap.end()) {
3747     return;
3748   }
3749   MS_EXCEPTION_IF_NULL(node);
3750   MS_LOG(DEBUG) << "Trans input data type of node:" << node->DebugString();
3751   for (auto &item : iter->second) {
3752     auto input_node = node->input(item.first);
3753     TypeId dst_type = item.second;
3754     MS_EXCEPTION_IF_NULL(input_node);
3755     if (input_node->isa<CNode>() || input_node->isa<Parameter>()) {
3756       auto src_type = input_node->Type()->type_id();
3757       if (kObjectTypeTensorType == src_type) {
3758         src_type = dyn_cast<TensorType>(input_node->Type())->element()->type_id();
3759       }
3760       if (!IsValidConversion(src_type, dst_type)) {
3761         continue;
3762       }
3763       auto new_cast = CreateCast(input_node, TypeIdToType(dst_type));
3764       node->set_input(item.first, new_cast);
3765     } else if (input_node->isa<ValueNode>()) {
3766       auto input_value_node = input_node->cast<ValueNodePtr>();
3767       MS_EXCEPTION_IF_NULL(input_value_node);
3768       auto value = input_value_node->value();
3769       ValuePtr new_value = CastDstValue(value, dst_type);
3770       if (new_value == nullptr) {
3771         continue;
3772       }
3773       auto new_value_node = std::make_shared<ValueNode>(new_value);
3774       MS_EXCEPTION_IF_NULL(new_value_node);
3775       new_value_node->set_abstract(new_value->ToAbstract());
3776       node->set_input(item.first, new_value_node);
3777     }
3778   }
3779   MS_LOG(DEBUG) << "Finish to trans input data type of node:" << node->DebugString();
3780 }
3781 
3782 void DfGraphConvertor::TransAttrDataType(const CNodePtr &node, const std::string &node_name) const {
3783   auto iter = kTransAttrDTypeMap.find(node_name);
3784   if (iter == kTransAttrDTypeMap.end()) {
3785     return;
3786   }
3787   MS_EXCEPTION_IF_NULL(node);
3788   MS_LOG(DEBUG) << "Trans attr data type of node:" << node->DebugString();
3789   auto prim = common::AnfAlgo::GetCNodePrimitive(node);
3790   MS_EXCEPTION_IF_NULL(prim);
3791   for (auto &item : iter->second) {
3792     std::string attr_name = item.first;
3793     TypeId dst_type = item.second;
3794     if (!prim->HasAttr(attr_name)) {
3795       MS_LOG(EXCEPTION) << "Please check kTransAttrDTypeMap, node:" << node->DebugString()
3796                         << " has no attr:" << attr_name;
3797     }
3798     auto attr_value = prim->GetAttr(attr_name);
3799     auto new_attr_value = CastDstValue(attr_value, dst_type);
3800     if (new_attr_value == nullptr) {
3801       continue;
3802     }
3803     prim->set_attr(attr_name, new_attr_value);
3804   }
3805   MS_LOG(DEBUG) << "Finish to trans attr data type of node:" << node->DebugString();
3806 }
3807 
3808 void DfGraphConvertor::TransDataType(const FuncGraphPtr &anf_graph) const {
3809   MS_EXCEPTION_IF_NULL(anf_graph);
3810   MS_LOG(DEBUG) << "TransDataType begin. graph:" << anf_graph->ToString();
3811   std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph);
3812   for (auto &it : nodes) {
3813     if (it->isa<CNode>()) {
3814       auto node = it->cast<CNodePtr>();
3815       MS_EXCEPTION_IF_NULL(node);
3816       std::string name = GetCNodeTargetFuncName(node);
3817       TransInputDataType(node, name);
3818       TransAttrDataType(node, name);
3819     }
3820   }
3821   MS_LOG(DEBUG) << "TransDataType end. graph:" << anf_graph->ToString();
3822 }
3823 
3824 void DfGraphConvertor::ConvertReshape(const CNodePtr &node) {
3825   MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
3826   const auto kInputNum = 3;
3827   if (node->size() < kInputNum) {
3828     MS_LOG(WARNING) << "Reshape must have two inputs.";
3829     return;
3830   }
3831   OpAdapterPtr adpt = FindAdapter(node, training_);
3832   if (adpt == nullptr) {
3833     return;
3834   }
3835   auto op = adpt->generate(node);
3836   MS_EXCEPTION_IF_NULL(op);
3837   // get shape from attr
3838   auto primitive = GetCNodePrimitive(node);
3839   MS_EXCEPTION_IF_NULL(primitive);
3840   if (primitive->HasAttr("shape")) {
3841     auto value = primitive->GetAttr("shape");
3842     auto list = CastToInt(value);
3843     (void)op->SetAttr("shape", list);
3844   }
3845   if (primitive->HasAttr("allowzero")) {
3846     auto value = primitive->GetAttr("allowzero");
3847     auto list = CastToInt(value);
3848     if (list.size() == 1) {
3849       (void)op->SetAttr("allowzero", list[0]);
3850     }
3851   }
3852   op_cache_[node.get()] = op;
3853 }
3854 
3855 void DfGraphConvertor::ConvertDynamicStitch(const CNodePtr &node) {
3856   MS_LOG(INFO) << "Convert and set 'N' attr of DynamicStitch.";
3857   OpAdapterPtr adpt = FindAdapter(node, training_);
3858   if (adpt == nullptr) {
3859     return;
3860   }
3861   auto op = adpt->generate(node);
3862   MS_EXCEPTION_IF_NULL(op);
3863   int64_t input_length = 0;
3864   auto indices = node->input(1);
3865   MS_EXCEPTION_IF_NULL(indices);
3866   if (indices->isa<CNode>()) {
3867     input_length = SizeToLong(indices->cast<CNodePtr>()->size()) - 1;
3868   } else if (IsValueNode<ValueSequence>(indices)) {
3869     const auto tuple = GetValueNode<ValueSequencePtr>(indices);
3870     MS_EXCEPTION_IF_NULL(tuple);
3871     input_length = SizeToLong(tuple->size());
3872   } else {
3873     MS_LOG(EXCEPTION) << "Input 1 of DynamicStitch is neither a CNode nor a ValueNode containing a ValueSequence, but "
3874                       << indices->ToString() << ", cannot set 'N' attr.";
3875   }
3876 
3877   (void)op->SetAttr("N", input_length);
3878   MS_LOG(INFO) << "Set 'N' attr of DynamicStitch to " << input_length;
3879   op_cache_[node.get()] = op;
3880 }
3881 
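// Copy the node's _parallel_group attr onto the generated (or cached) GE operator.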
3882 void DfGraphConvertor::ConvertParallelGroupToHcom(const CNodePtr &node) {
3883   auto group_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kParallelGroup);
3884   OpAdapterPtr adpt = FindAdapter(node, training_);
3885   if (adpt == nullptr) {
3886     return;
3887   }
3888 
3889   // get operator
3890   OperatorPtr op = nullptr;
3891   auto it_op = op_cache_.find(node.get());
3892   if (it_op != op_cache_.end()) {
3893     op = it_op->second;
3894   } else {
3895     op = adpt->generate(node);
3896   }
3897   MS_EXCEPTION_IF_NULL(op);
3898   (void)op->SetAttr(kParallelGroup, group_name);
3899   op_cache_[node.get()] = op;
3900 }
3901 
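// Copy the node's _parallel_group_id attr onto the generated (or cached) GE operator.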
3902 void DfGraphConvertor::ConvertParallelGroupIdToHcom(const CNodePtr &node) {
3903   auto parallel_group_id_value = node->GetAttr(kParallelGroupId);
3904   auto parallel_group_id = GetValue<uint32_t>(parallel_group_id_value);
3905   OpAdapterPtr adpt = FindAdapter(node, training_);
3906   if (adpt == nullptr) {
3907     return;
3908   }
3909 
3910   // get operator
3911   OperatorPtr op = nullptr;
3912   auto it_op = op_cache_.find(node.get());
3913   if (it_op != op_cache_.end()) {
3914     op = it_op->second;
3915   } else {
3916     op = adpt->generate(node);
3917     op_cache_[node.get()] = op;
3918   }
3919   MS_EXCEPTION_IF_NULL(op);
3920   (void)op->SetAttr(kParallelGroupId, parallel_group_id);
3921   MS_LOG(DEBUG) << "Successfully convert _parallel_group_id: " << parallel_group_id << " to ge op: " << op->GetName();
3922 }
3923 
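// Map the 'fusion' attr of a communication op to the GE 'fusion'/'fusion_id' attrs; fusion is disabled under
// cell reuse, or when task opt is off in (semi-)auto parallel mode.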
3924 void DfGraphConvertor::ConvertHcomFusionId(const CNodePtr &node) {
3925   MS_EXCEPTION_IF_NULL(node);
3926   MS_LOG(INFO) << "Add Hcom fusion_id";
3927   OpAdapterPtr adpt = FindAdapter(node, training_);
3928   if (adpt == nullptr) {
3929     return;
3930   }
3931   auto op = adpt->generate(node);
3932   MS_EXCEPTION_IF_NULL(op);
3933   // get fusion attr from primitive
3934   auto primitive = GetCNodePrimitive(node);
3935   MS_EXCEPTION_IF_NULL(primitive);
3936   auto fusion_value = primitive->GetAttr("fusion");
3937   if (fusion_value == nullptr) {
3938     MS_LOG(WARNING) << "Failed to get attr fusion for hccl node " << node->fullname_with_scope();
3939     return;
3940   }
3941   int64_t fusion = 0;
3942   if (fusion_value->isa<Int64Imm>()) {
3943     fusion = GetValue<int64_t>(fusion_value);
3944   } else if (fusion_value->isa<Int32Imm>()) {
3945     fusion = GetValue<int32_t>(fusion_value);
3946   } else {
3947     MS_LOG(WARNING) << "Attr fusion is not int64/int32 type, real type " << fusion_value->type_name()
3948                     << ", hccl node " << node->fullname_with_scope();
3949     return;
3950   }
3951   int64_t fusion_id = -1;
3952 
3953   // fusion 0: no fusion; 1 (default): fusion; 2: fuse the ops by fusion id.
3954   if (fusion > 1) {
3955     fusion_id = fusion;
3956     fusion = kHcclFusionByFusionID;
3957   } else if (fusion < 0) {
3958     fusion = kHcclFusionDefault;
3959   }
3960 
3961   auto context = MsContext::GetInstance();
3962   MS_EXCEPTION_IF_NULL(context);
3963   if (context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
3964     MS_LOG(INFO) << "Cell reuse does not support hccl op fusion, set fusion to 0.";
3965     fusion = 0;
3966   }
3967   MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
3968   auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
3969   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_TASK_OPT) &&
3970       (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel)) {
3971     fusion_id = 0;
3972     fusion = 0;
3973   }
3974   (void)op->SetAttr("fusion_id", fusion_id);
3975   (void)op->SetAttr("fusion", fusion);
3976   AddCommAttrForHcclNode(node, op);
3977   op_cache_[node.get()] = op;
3978 }
3979 
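// Generate the AllToAllv operator, attach its communication attrs, and forward the 'is_inserted_by_ge' flag
// as '_is_inserted_by_ge' when the primitive carries it.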
3980 void DfGraphConvertor::ConvertAllToAllv(const CNodePtr &node) {
3981   OpAdapterPtr adpt = FindAdapter(node, training_);
3982   if (adpt == nullptr) {
3983     return;
3984   }
3985   auto op = adpt->generate(node);
3986   MS_EXCEPTION_IF_NULL(op);
3987   op_cache_[node.get()] = op;
3988   AddCommAttrForHcclNode(node, op);
3989   // set _is_inserted_by_ge attr to avoid mistaken delete
3990   auto primitive = GetCNodePrimitive(node);
3991   MS_EXCEPTION_IF_NULL(primitive);
3992   auto is_inserted_value = primitive->GetAttr("is_inserted_by_ge");
3993   if (is_inserted_value == nullptr) {
3994     return;
3995   }
3996   auto is_inserted = GetValue<bool>(is_inserted_value);
3997   (void)op->SetAttr("_is_inserted_by_ge", is_inserted);
3998 }
3999 
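// Generate UniformReal and fix its GE 'dtype' attr to DT_FLOAT.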
4000 void DfGraphConvertor::ConvertUniformReal(const CNodePtr &node) {
4001   OpAdapterPtr adpt = FindAdapter(node, training_);
4002   if (adpt == nullptr) {
4003     return;
4004   }
4005   auto op = adpt->generate(node);
4006   MS_EXCEPTION_IF_NULL(op);
4007   op_cache_[node.get()] = op;
4008   (void)op->SetAttr("dtype", ::ge::DataType::DT_FLOAT);
4009 }
4010 
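// Generate UpdateState; if the node carries kAttrNotRemove, tag it with the PS engine id and record that the
// graph contains an ES node.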
4011 void DfGraphConvertor::ConvertUpdateState(const CNodePtr &node) {
4012   OpAdapterPtr adpt = FindAdapter(node, training_);
4013   if (adpt == nullptr) {
4014     return;
4015   }
4016   auto op = adpt->generate(node);
4017   MS_EXCEPTION_IF_NULL(op);
4018   op_cache_[node.get()] = op;
4019   if (common::AnfAlgo::HasNodeAttr(kAttrNotRemove, node)) {
4020     bool not_remove = common::AnfAlgo::GetNodeAttr<bool>(node, kAttrNotRemove);
4021     (void)op->SetAttr(kProcessNodeEngineID, "PS");
4022     (void)op->SetAttr(kAttrNotRemove, not_remove);
4023     has_es_node_ = true;
4024   }
4025 }
4026 
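// Generate a communication operator and attach its group/comm attrs.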
4027 void DfGraphConvertor::ConvertHcclNode(const CNodePtr &node) {
4028   OpAdapterPtr adpt = FindAdapter(node, training_);
4029   if (adpt == nullptr) {
4030     return;
4031   }
4032   auto op = adpt->generate(node);
4033   MS_EXCEPTION_IF_NULL(op);
4034   AddCommAttrForHcclNode(node, op);
4035   op_cache_[node.get()] = op;
4036 }
4037 
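// Set the 'group' attr (and, for the HcclCommInitRootInfo manner, the 'comm' handle) of a converted
// communication operator, using hccl's inner comm name where required.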
4038 void DfGraphConvertor::AddCommAttrForHcclNode(const CNodePtr &node, const OperatorPtr &converted_op) const {
4039   MS_EXCEPTION_IF_NULL(node);
4040   MS_EXCEPTION_IF_NULL(converted_op);
4041   if (!common::AnfAlgo::HasNodeAttr(kAttrGroup, node)) {
4042     MS_LOG(WARNING) << "Node " << node->fullname_with_scope() << " does not have attr " << kAttrGroup << ", skip.";
4043     return;
4044   }
4045   std::string group = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
4046   (void)converted_op->SetAttr("group", group);
4047 #ifdef ENABLE_D
4048   if (!common::GetEnv(kSimulationLevel).empty()) {
4049     auto hccl_inner_comm_name = device::DummyAscendCollectiveCommLib::GetInstance().HcclInnerCommName(group);
4050     MS_LOG(INFO) << "Set comm group name of the hccl node: " << node->fullname_with_scope()
4051                  << ", comm name: " << hccl_inner_comm_name;
4052     (void)converted_op->SetAttr("group", hccl_inner_comm_name);
4053     return;
4054   }
4055   if (common::GetEnv(kSimulationLevel).empty() && !common::IsNeedProfileMemory()) {
4056     if (common::UseHostCollective() && !hccl::HcclAdapter::GetInstance().UseHcclCM()) {
4057       // For HcclCommInitRootInfo manner, set 'group' and 'comm' attrs. 'group' attr value should be hccl's inner comm
4058       // name.
4059       auto comm = device::ascend::AscendCollectiveCommLib::GetInstance().HcclCommunicator(group);
4060       auto hccl_inner_comm_name = device::ascend::AscendCollectiveCommLib::GetInstance().HcclInnerCommName(group);
4061       MS_LOG(INFO) << "Set comm handle and comm group name of the hccl node: " << node->fullname_with_scope()
4062                    << ". Comm handle: " << comm << ", comm name:" << hccl_inner_comm_name;
4063       MS_EXCEPTION_IF_NULL(comm);
4064       (void)converted_op->SetAttr("comm", reinterpret_cast<int64_t>(comm));
4065       (void)converted_op->SetAttr("group", hccl_inner_comm_name);
4066     } else {
4067       // For rank_table manner, 'group' attr should be user set group name.
4068       MS_LOG(INFO) << "Set group name for ranktable manner: " << group;
4069       (void)converted_op->SetAttr("group", group);
4070     }
4071   }
4072 #endif
4073 }
4074 
4075 void DfGraphConvertor::ConvertConv2D(const CNodePtr &node) {
4076   MS_LOG(INFO) << "Convert and set 'padding' attr for Conv2D-like op.";
4077   MS_EXCEPTION_IF_NULL(node);
4078   OpAdapterPtr adpt = FindAdapter(node, training_);
4079   if (adpt == nullptr) {
4080     return;
4081   }
4082   auto op = adpt->generate(node);
4083   MS_EXCEPTION_IF_NULL(op);
4084   op_cache_[node.get()] = op;
4085   auto primitive = GetCNodePrimitive(node);
4086   MS_EXCEPTION_IF_NULL(primitive);
4087   std::string pad_mode;
4088   if (auto pad_value = primitive->GetAttr("padding"); pad_value != nullptr) {
4089     pad_mode = GetValue<std::string>(pad_value);
4090   } else if (auto value = primitive->GetAttr("pad_mode"); value != nullptr) {
4091     // Get 'pad_mode' attr and set it to 'padding' attr for ge
4092     const mindspore::HashMap<int64_t, std::string> pad_mode_map{{1, "SAME"}, {2, "VALID"}};
4093     if (value->isa<StringImm>()) {
4094       pad_mode = GetValue<std::string>(value);
4095       (void)std::transform(pad_mode.cbegin(), pad_mode.cend(), pad_mode.begin(), toupper);
4096       if (pad_mode != "SAME" && pad_mode != "VALID") {
4097         return;
4098       }
4099     } else if (auto it = pad_mode_map.find(GetValue<int64_t>(value)); it != pad_mode_map.cend()) {
4100       // 'pad_mode' attr could be an enumeration
4101       pad_mode = it->second;
4102     } else {
4103       return;
4104     }
4105   } else {
4106     MS_LOG(INFO) << "Node: " << node->fullname_with_scope() << " has no 'padding' or 'pad_mode' attr";
4107     return;
4108   }
4109   MS_LOG(INFO) << "Set 'padding' attr of node: " << node->fullname_with_scope() << " to " << pad_mode;
4110   (void)op->SetAttr("padding", pad_mode);
4111 }
4112 
4113 void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr &node) {
4114   MS_LOG(INFO) << "Add OCRRecognitionPreHandle _op_max_shape attr";
4115   OpAdapterPtr adpt = FindAdapter(node, training_);
4116   if (adpt == nullptr) {
4117     return;
4118   }
4119   auto op = adpt->generate(node);
4120   MS_EXCEPTION_IF_NULL(op);
4121   // get _op_max_shape from attr
4122   auto primitive = GetCNodePrimitive(node);
4123   MS_EXCEPTION_IF_NULL(primitive);
4124   auto value = primitive->GetAttr("_op_max_shape");
4125   if (value == nullptr) {
4126     return;
4127   }
4128   auto op_max_shape = GetValue<std::string>(value);
4129   (void)op->SetAttr("_op_max_shape", op_max_shape);
4130   op_cache_[node.get()] = op;
4131 }
4132 
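// Return the cached out handler of a node, or build one from its converted operator (preferring the variable
// operator registered under the same name).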
4133 OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node) {
4134   if (node == nullptr) {
4135     MS_LOG(ERROR) << "Get nullptr while getting handler from node";
4136     return OutHandler(nullptr, "");
4137   }
4138   if (out_handle_cache_.find(node.get()) != out_handle_cache_.end()) {
4139     return out_handle_cache_[node.get()];
4140   }
4141   auto op = Convert(node);
4142   if (op != nullptr) {
4143     auto name = op->GetName();
4144     if ((vars_.count(name) != 0) && vars_[name] != nullptr) {
4145       op = vars_[name];
4146       MS_LOG(DEBUG) << "update tuple_out_handle_cache_ " << name;
4147     }
4148     return OutHandler(op, "", node);
4149   } else {
4150     MS_LOG(DEBUG) << "Add an empty out handler: " << node->ToString();
4151     return OutHandler();
4152   }
4153 }
4154 
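// Skip Switch/SwitchLayer/Partial nodes and run per-op auxiliary conversions (attr fixups, hccl attrs) before
// the common conversion path.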
4155 bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
4156   // ignore apply node of return
4157   if (name == "" || name == prim::kPrimSwitch->name() || name == prim::kPrimSwitchLayer->name() ||
4158       name == prim::kPrimPartial->name()) {
4159     return false;
4160   }
4161 
4162   const mindspore::HashMap<std::string, std::function<void(decltype(this), const CNodePtr &)>>
4163     auxiliary_node_converters{
4164       // Convert TopK second input from int64 to int32.
4165       {prim::kPrimTopK->name(), &DfGraphConvertor::ConvertTopK},
4166       // Convert Reshape's const shape input to attr 'shape'
4167       {prim::kPrimReshape->name(), &DfGraphConvertor::ConvertReshape},
4168       {prim::kPrimOCRRecognitionPreHandle->name(), &DfGraphConvertor::ConvertOCRRecPreHandle},
4169       // Set 'padding' attr for Conv2D-like ops
4170       {prim::kPrimConv2D->name(), &DfGraphConvertor::ConvertConv2D},
4171       {prim::kPrimDepthwiseConv2dNative->name(), &DfGraphConvertor::ConvertConv2D},
4172       {kNameConv2DBackpropInputV2, &DfGraphConvertor::ConvertConv2D},
4173       {prim::kPrimConv2DBackpropInput->name(), &DfGraphConvertor::ConvertConv2D},
4174       {prim::kPrimConv2DBackpropFilter->name(), &DfGraphConvertor::ConvertConv2D},
4175       // Add attr 'N' to DynamicStitch
4176       {prim::kPrimDynamicStitch->name(), &DfGraphConvertor::ConvertDynamicStitch},
4177       // Convert hccl op for comm handle
4178       {prim::kPrimAllReduce->name(), &DfGraphConvertor::ConvertHcomFusionId},
4179       {prim::kPrimAllGather->name(), &DfGraphConvertor::ConvertHcomFusionId},
4180       {prim::kPrimReduceScatter->name(), &DfGraphConvertor::ConvertHcomFusionId},
4181       {prim::kPrimBroadcast->name(), &DfGraphConvertor::ConvertHcclNode},
4182       {prim::kPrimReduceScatter->name(), &DfGraphConvertor::ConvertHcclNode},
4183       {prim::kPrimSend->name(), &DfGraphConvertor::ConvertHcclNode},
4184       {prim::kPrimReceive->name(), &DfGraphConvertor::ConvertHcclNode},
4185       {prim::kPrimAllToAllv->name(), &DfGraphConvertor::ConvertAllToAllv},
4186       {prim::kPrimUniformReal->name(), &DfGraphConvertor::ConvertUniformReal},
4187       {prim::kPrimMatmulReduceScatter->name(), &DfGraphConvertor::ConvertHcclNode},
4188       {prim::kPrimAllGatherMatmul->name(), &DfGraphConvertor::ConvertHcclNode},
4189       {prim::kPrimUpdateState->name(), &DfGraphConvertor::ConvertUpdateState},
4190     };
4191 
4192   if (const auto it = auxiliary_node_converters.find(name); it != auxiliary_node_converters.cend()) {
4193     it->second(this, node);
4194   }
4195   if (common::AnfAlgo::HasNodeAttr(kParallelGroup, node)) {
4196     ConvertParallelGroupToHcom(node);
4197   }
4198   if (node->HasAttr(kParallelGroupId)) {
4199     ConvertParallelGroupIdToHcom(node);
4200   }
4201 
4202   return true;
4203 }
4204 
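// Helpers that copy an optional int64/string primitive attr onto the GE operator when it is present.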
4205 void CheckAndAddScopeAttrInt(const OperatorPtr op, const PrimitivePtr primitive, const std::string &attr_name) {
4206   auto attr_value = primitive->GetAttr(attr_name);
4207   if (attr_value != nullptr) {
4208     auto value = GetValue<int64_t>(attr_value);
4209     (void)op->SetAttr(attr_name, value);
4210   }
4211 }
4212 
4213 void CheckAndAddScopeAttrString(const OperatorPtr op, const PrimitivePtr primitive, const std::string &attr_name) {
4214   auto attr_value = primitive->GetAttr(attr_name);
4215   if (attr_value != nullptr) {
4216     auto value = GetValue<std::string>(attr_value);
4217     (void)op->SetAttr(attr_name, value);
4218   }
4219 }
4220 
4221 // If a node has no abstract, generating the corresponding operator will fail.
4222 void DfGraphConvertor::SetNodeAbstract(const CNodePtr &node) const {
4223   MS_EXCEPTION_IF_NULL(node);
4224   if (node->abstract() != nullptr) {
4225     return;
4226   }
4227   if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
4228     auto inputs = node->inputs();
4229     AbstractBasePtrList elem;
4230     std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(elem),
4231                    [](const AnfNodePtr &node) { return node->abstract(); });
4232     node->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
4233     return;
4234   }
4235   if (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
4236     auto inputs = node->inputs();
4237     if (inputs.size() < kInputSize2) {
4238       MS_LOG(EXCEPTION) << "Node input size " << inputs.size() << " is less than 2, node: " << node->fullname_with_scope();
4239     }
4240     auto input = inputs[1];
4241     MS_EXCEPTION_IF_NULL(input);
4242     node->set_abstract(input->abstract());
4243     return;
4244   }
4245   MS_LOG(WARNING) << "Node has no abstract: " << node->fullname_with_scope() << ", DebugString: " << node->ToString();
4246 }
4247 
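// Convert a CNode to a GE operator: find its adapter, generate (or reuse) the operator, set its attrs and
// cache the result.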
4248 OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
4249   SaveParamFormat(node);
4250   std::string name = GetCNodeTargetFuncName(node);
4251   if (!CheckCNode(name, node)) {
4252     return nullptr;
4253   }
4254 
4255   // get corresponding OpAdapter
4256   OpAdapterPtr adpt = FindAdapter(node, training_);
4257   if (adpt == nullptr) {
4258     MS_LOG(ERROR) << "Cannot get adapter for " << node->fullname_with_scope();
4259     unsupported_ops_names_.insert(name);
4260     error_ = NOT_FOUND;
4261     return nullptr;
4262   }
4263   SetNodeAbstract(node);
4264   // get operator
4265   OperatorPtr op = nullptr;
4266   auto it_op = op_cache_.find(node.get());
4267   if (it_op != op_cache_.end()) {
4268     op = it_op->second;
4269   } else {
4270     if (cur_while_node_ == node) {
4271       op = adpt->generateDynOutputOp(node);
4272     } else {
4273       op = adpt->generate(node);
4274     }
4275   }
4276 
4277   // set attribute for primitive
4278   (void)adpt->setAttr(op, node);
4279   auto value_node = node->input(0)->cast<ValueNodePtr>();
4280   if (value_node != nullptr && value_node->value()->cast<PrimitivePtr>() != nullptr) {
4281     MS_LOG(DEBUG) << "Set attr for subgraph multi dims";
4282     auto primitive = value_node->value()->cast<PrimitivePtr>();
4283     CheckAndAddScopeAttrInt(op, primitive, "_subgraph_multi_dims_index");
4284     CheckAndAddScopeAttrString(op, primitive, "_subgraph_multi_dims_input_dims");
4285     CheckAndAddScopeAttrString(op, primitive, "_subgraph_multi_dims_input_shape");
4286   }
4287 
4288   // add into cache
4289   (void)op_cache_.emplace(node.get(), op);
4290 
4291   DrawCNode(node, adpt);
4292 
4293   return op_cache_[node.get()];
4294 }
4295 
4296 OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
4297   // convert Parameter in ANF to variable in DataFlow
4298   auto adpt = FindAdapter(node, training_);
4299   if (adpt == nullptr) {
4300     MS_LOG(EXCEPTION) << "Can not find adapter for Parameter";
4301   }
4302   auto op = adpt->generate(node);
4303   op_cache_[node.get()] = op;
4304 
4305   // build index for parameter using name
4306   std::string name = std::static_pointer_cast<Parameter>(node)->name();
4307   params_[name] = node;
4308   std::ostringstream ss;
4309   ss << "op" << node.get();
4310   op_draw_name_[node.get()] = ss.str();
4311   compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl;
4312   return op_cache_[node.get()];
4313 }
4314 
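// Record the format of Parameter inputs when the op's format/data_format argument or 'format' attr is NCDHW
// or NHWC.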
4315 void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
4316   AnfNodePtr op = node->input(0);
4317   if (IsValueNode<Primitive>(op)) {
4318     auto prim = GetValueNode<PrimitivePtr>(op);
4319     std::string format;
4320     auto op_def = ops::GetOpDef(prim->name());
4321     if (op_def) {
4322       for (size_t index = 0; index < op_def->args_.size() && index < node->size() - 1; index++) {
4323         auto arg = op_def->args_[index];
4324         if (arg.as_init_arg_ && (arg.arg_name_ == ops::kFormat || arg.arg_name_ == ops::kDataFormat)) {
4325           auto value_ptr = node->input(index + 1)->cast<ValueNodePtr>();
4326           if (value_ptr == nullptr) {
4327             break;
4328           }
4329           auto input_value = value_ptr->value();
4330           MS_EXCEPTION_IF_NULL(input_value);
4331           auto format_id = GetValue<int64_t>(input_value);
4332           format = FormatEnumToString(static_cast<Format>(format_id));
4333         }
4334       }
4335     }
4336     auto value_ptr = prim->GetAttr(ops::kFormat);
4337     if (value_ptr) {
4338       if (value_ptr->isa<Int64Imm>()) {
4339         bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &value_ptr);
4340         if (converted) {
4341           format = value_ptr->ToString();
4342         } else {
4343           CheckAndConvertUtils::GetFormatStringVal(prim, &format);
4344         }
4345       } else if (value_ptr->isa<StringImm>()) {
4346         format = value_ptr->ToString();
4347       }
4348     }
4349 
4350     if (format == "NCDHW" || format == "NHWC") {
4351       for (size_t i = 1; i < node->size(); i++) {
4352         auto input = node->input(i);
4353         if (input->isa<Parameter>()) {
4354           param_format_[input->DebugString()] = format;
4355           MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
4356         }
4357       }
4358     }
4359   }
4360 }
4361 
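// Convert a ValueTuple/ValueList into one Constant operator per element (tensor or scalar) and cache them as
// tuple out handlers.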
4362 Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
4363   MS_EXCEPTION_IF_NULL(node);
4364   ValuePtr value = node->value();
4365   MS_EXCEPTION_IF_NULL(value);
4366   if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
4367     return FAILED;
4368   }
4369 
4370   auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
4371   if (vec.empty()) {
4372     return FAILED;
4373   }
4374 
4375   std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
4376   // If the sequence has only one element which is a scalar, it should be converted to a 1-D Tensor rather than a
4377   // 0-D Scalar.
4378   if (vec.size() == 1 && !vec[0]->isa<MeTensor>()) {
4379     return FAILED;
4380   }
4381   for (size_t i = 0; i < vec.size(); i++) {
4382     MS_EXCEPTION_IF_NULL(vec[i]);
4383     GeTensorPtr ge_tensor = nullptr;
4384     if (vec[i]->isa<MeTensor>()) {
4385       ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast<MeTensorPtr>(), kOpFormat_DEFAULT);
4386       MS_EXCEPTION_IF_NULL(ge_tensor);
4387     } else {
4388       ge_tensor = transform::TransformUtil::ConvertScalar(vec[i]);
4389       if (ge_tensor == nullptr) {
4390         return FAILED;
4391       }
4392     }
4393     auto const_op = std::make_shared<Constant>(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i));
4394     AddGraphConstInput(const_op);
4395     (void)const_op->set_attr_value(*ge_tensor);
4396     (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc());
4397     (void)tuple_items->emplace_back(OutHandler(const_op, ""));
4398   }
4399   if (tuple_items->empty()) {
4400     return FAILED;
4401   }
4402 
4403   tuple_out_handle_cache_[node.get()] = tuple_items;
4404   if (!vec[0]->isa<MeTensor>()) {
4405     return FAILED;
4406   }
4407   return SUCCESS;
4408 }
4409 
4410 OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
4411   // convert valuenode in ANF to Const in DataFlow
4412   // find parameter referenced by SymbolicKeyInstance of valuenode
4413   std::ostringstream ss;
4414   ss << "op" << node.get();
4415   op_draw_name_[node.get()] = ss.str();
4416   compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl;
4417 
4418   if (TryConvertValueNodeToMultiConst(node) == SUCCESS) {
4419     MS_LOG(INFO) << "Convert value node to multi Constant OP success";
4420     return nullptr;
4421   }
4422 
4423   OpAdapterPtr adpt = FindAdapter(node, training_);
4424   if (adpt == nullptr) {
4425     error_ = NOT_FOUND;
4426     return nullptr;
4427   }
4428   auto op = adpt->generate(node);
4429   // set const's attrs
4430   if (adpt->setAttr(op, "value", node->value()) != 0) {
4431     MS_LOG(WARNING) << "set attr value for const failed";
4432   }
4433 
4434   if (op->GetOpType() != "Constant" && op->GetOpType() != "Const") {
4435     MS_LOG(ERROR) << "Get Constant operator failed, ge node type: " << op->GetOpType()
4436                   << ", ms node info: " << node->ToString() << ", is train: " << training_;
4437     return nullptr;
4438   }
4439   ::ge::Tensor ge_tensor;
4440   (void)op->GetAttr("value", ge_tensor);
4441   auto ge_desc = ge_tensor.GetTensorDesc();
4442   (void)op->UpdateOutputDesc(kTypeY, ge_desc);
4443 
4444   op_cache_[node.get()] = op;
4445   return op_cache_[node.get()];
4446 }
4447 
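// Emit the graphviz description of a converted CNode (input ports, node name and attrs) for debugging.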
4448 void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
4449   if (adpt == nullptr || node == nullptr) {
4450     MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!";
4451     return;
4452   }
4453   std::ostringstream ss;
4454   ss << "op" << node.get();
4455   op_draw_name_[node.get()] = ss.str();
4456 
4457   compute_sout_ << ss.str() << "[label=<";
4458   compute_sout_ << "<table border='1' cellborder='1'>" << endl;
4459 
4460   auto input_map = adpt->getInputMap();
4461   auto dyn_input_map = adpt->getDynInputMap();
4462   if (input_map.size() + dyn_input_map.size() > 0) {
4463     compute_sout_ << "<tr>";
4464     for (auto &it : input_map) {
4465       compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
4466     }
4467     for (auto &it : dyn_input_map) {
4468       compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
4469     }
4470     compute_sout_ << "</tr>" << endl;
4471   }
4472 
4473   compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
4474                 << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
4475 
4476   // print attrs' values
4477   auto atts = adpt->GetAttrsFromDrawGraph();
4478   for (auto &it : atts) {
4479     compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << it
4480                   << "\"</td></tr>";
4481   }
4482 
4483   adpt->clearAttrVect();
4484 
4485   compute_sout_ << "</table>> shape=plaintext]" << endl;
4486 }
4487 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr adpt) {
4488   OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(adpt);
4489 }
4490 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) {
4491   OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(train_adpt, infer_adpt);
4492 }
4493 
4494 std::map<std::string, ValuePtr> GeOpConvertor::GetAttrAndValue(const AnfNodePtr &node, const bool training = true) {
4495   MS_EXCEPTION_IF_NULL(node);
4496   std::map<std::string, ValuePtr> attr_list;
4497   if (!node->isa<CNode>()) {
4498     MS_LOG(INFO) << "Current node isn't a cnode! node info:" << node->DebugString();
4499     return attr_list;
4500   }
4501 
4502   OpAdapterPtr adpt = FindAdapter(node, training);
4503   if (adpt == nullptr) {
4504     MS_LOG(INFO) << "Can't find adapter for current node! node info:" << node->DebugString();
4505     return attr_list;
4506   }
4507 
4508   attr_list = adpt->GetNormalOpAttrList(node);
4509   return attr_list;
4510 }
4511 
4512 std::string GeOpConvertor::GetOpType(const AnfNodePtr &node, const bool training = true) {
4513   MS_EXCEPTION_IF_NULL(node);
4514   OpAdapterPtr adpt = FindAdapter(node, training);
4515   if (adpt == nullptr) {
4516     MS_LOG(INFO) << "Can't find adapter for current node! node info:" << node->DebugString();
4517     return "";
4518   }
4519   return adpt->getOpType();
4520 }
4521 
4522 std::shared_ptr<GeTensorDesc> GeOpConvertor::GetTensorDesc(const ShapeVector &dev_shape, const TypeId &dev_type,
4523                                                            const std::string &dev_format, const ShapeVector &ori_shape,
4524                                                            const std::string &ori_format) {
4525   auto tensor_desc = transform::TransformUtil::GetGeTensorDesc(dev_shape, dev_type, dev_format, ori_shape, ori_format);
4526   MS_EXCEPTION_IF_NULL(tensor_desc);
4527   return tensor_desc;
4528 }
4529 
4530 mindspore::HashMap<std::string, std::string> GeOpConvertor::GetNeedAddInput(const AnfNodePtr &node,
4531                                                                             const bool training) {
4532   MS_EXCEPTION_IF_NULL(node);
4533   OpAdapterPtr adpt = FindAdapter(node, training);
4534   if (adpt == nullptr) {
4535     MS_LOG(INFO) << "Can't find adapter for current node! node info:" << node->DebugString();
4536     return {};
4537   }
4538 
4539   return adpt->getAttrInputMap();
4540 }
4541 
4542 bool GeOpConvertor::IsDynamicInput(const AnfNodePtr &node, const size_t idx) {
4543   MS_EXCEPTION_IF_NULL(node);
4544   OpAdapterPtr adapterPtr = FindAdapter(node, true);
4545   if (adapterPtr == nullptr) {
4546     MS_LOG(INFO) << "Can't find an adapter for op:" << node->DebugString();
4547     return false;
4548   }
4549   return adapterPtr->IsDynInputOp(idx);
4550 }
4551 
4552 std::map<int, std::string> GeOpConvertor::GetAclInputNames(const AnfNodePtr &node) {
4553   MS_EXCEPTION_IF_NULL(node);
4554   OpAdapterPtr adapterPtr = FindAdapter(node, true);
4555   if (adapterPtr == nullptr) {
4556     MS_LOG(EXCEPTION) << "Can't find an adapter for op:" << node->DebugString();
4557   }
4558 
4559   std::map<int, std::string> input_names;
4560   for (const auto &[k, v] : adapterPtr->getInputMap()) {
4561     input_names.emplace(k, v.name);
4562   }
4563   // dynamic input
4564   for (const auto &[k, v] : adapterPtr->getDynInputMap()) {
4565     input_names.emplace(k, v.name);
4566   }
4567   return input_names;
4568 }
4569 
4570 std::map<int, std::string> GeOpConvertor::GetAclOutputNames(const AnfNodePtr &node) {
4571   MS_EXCEPTION_IF_NULL(node);
4572   OpAdapterPtr adapterPtr = FindAdapter(node, true);
4573   if (adapterPtr == nullptr) {
4574     MS_LOG(EXCEPTION) << "Can't find an adapter for op:" << node->DebugString();
4575   }
4576 
4577   std::map<int, std::string> output_names;
4578   for (const auto &[k, v] : adapterPtr->getOutputMap()) {
4579     output_names.emplace(k, v.name);
4580   }
4581 
4582   // dynamic output
4583   for (const auto &[k, v] : adapterPtr->getDynOutputMap()) {
4584     output_names.emplace(k, v.name);
4585   }
4586   return output_names;
4587 }
4588 
4589 std::map<int, std::string> GeOpConvertor::GetAclDynamicInputNames(const AnfNodePtr &node) {
4590   MS_EXCEPTION_IF_NULL(node);
4591   OpAdapterPtr adapterPtr = FindAdapter(node, true);
4592   if (adapterPtr == nullptr) {
4593     MS_LOG(EXCEPTION) << "Can't find an adapter for op:" << node->DebugString();
4594   }
4595   std::map<int, std::string> dyn_input_names;
4596   for (const auto &[k, v] : adapterPtr->getDynInputMap()) {
4597     dyn_input_names.emplace(k, v.name);
4598   }
4599   return dyn_input_names;
4600 }
4601 
4602 std::map<int, std::string> GeOpConvertor::GetAclDynamicOutputNames(const AnfNodePtr &node) {
4603   MS_EXCEPTION_IF_NULL(node);
4604   OpAdapterPtr adapterPtr = FindAdapter(node, true);
4605   if (adapterPtr == nullptr) {
4606     MS_LOG(EXCEPTION) << "Can't find an adapter for op:" << node->DebugString();
4607   }
4608   std::map<int, std::string> dyn_output_names;
4609   for (const auto &[k, v] : adapterPtr->getDynOutputMap()) {
4610     dyn_output_names.emplace(k, v.name);
4611   }
4612   return dyn_output_names;
4613 }
4614 }  // namespace mindspore::transform
4615