1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "transform/graph_ir/convert.h"
18
19 #include <algorithm>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "op_proto/inc/array_ops.h"
24 #include "op_proto/inc/elewise_calculation_ops.h"
25 #include "op_proto/inc/save_ops.h"
26 #include "op_proto/inc/state_ops.h"
27 #include "include/common/debug/anf_ir_dump.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "include/common/utils/config_manager.h"
30 #include "include/common/utils/utils.h"
31 #include "include/transform/graph_ir/utils.h"
32 #include "ir/graph_utils.h"
33 #include "ops/array_ops.h"
34 #include "ops/conv_pool_ops.h"
35 #include "ops/framework_ops.h"
36 #include "ops/image_ops.h"
37 #include "ops/math_op_name.h"
38 #include "ops/nn_ops.h"
39 #include "ops/nn_optimizer_ops.h"
40 #include "ops/other_ops.h"
41 #include "ops/sequence_ops.h"
42 #include "ops/structure_ops.h"
43 #include "ops/lite_ops.h"
44 #include "ops/op_def.h"
45 #include "ops/auto_generate/gen_ops_primitive.h"
46 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h"
47 #include "plugin/device/ascend/hal/hardware/dummy_ascend_collective_comm_lib.h"
48 #include "plugin/device/ascend/hal/hardware/ge_utils.h"
49 #include "plugin/device/ascend/hal/hccl_adapter/hccl_adapter.h"
50 #include "transform/graph_ir/op_adapter.h"
51 #include "transform/graph_ir/op_adapter_desc.h"
52 #include "transform/graph_ir/op_adapter_map.h"
53 #include "transform/graph_ir/storage_format_convertor.h"
54 #include "utils/anf_utils.h"
55 #include "utils/check_convert_utils.h"
56 #include "utils/log_adapter.h"
57 #include "utils/ms_context.h"
58 #include "utils/symbolic.h"
59 #include "utils/singleton.h"
60
61 namespace mindspore::transform {
62 using ::ge::Operator;
63 using mindspore::kValueAny;
64 using std::make_shared;
65 using std::shared_ptr;
66 using std::string;
67 using std::vector;
68 using Variable = ::ge::op::Variable;
69 using Constant = ::ge::op::Constant;
70 using Assign = ::ge::op::Assign;
71 using Data = ::ge::op::Data;
72 using RefData = ::ge::op::RefData;
73 using std::endl;
74 using std::static_pointer_cast;
75
76 constexpr int64_t kInputOffset = 2;
77 constexpr size_t kSwitchInputSize = 4;
78 constexpr size_t kSwitchBodyIndex = 2;
79 constexpr size_t kSwitchAfterIndex = 3;
80 constexpr size_t kAfterIndexInCache = 2;
81 constexpr size_t kCnodeInputSizeOne = 1;
82 constexpr size_t kDataInputIndex = 1;
83 constexpr size_t kInputSize2 = 2;
84 constexpr size_t kMergeInputSize = 2;
85 constexpr size_t kNoOpOptThreshold = 3;
86 constexpr auto kHcclFusionByFusionID = 2;
87 constexpr auto kHcclFusionDefault = 1;
88 constexpr auto kTypeNoOp = "NoOp";
89 constexpr auto kTypeIdentity = "Identity";
90 constexpr auto kTypeIdentityN = "IdentityN";
91 constexpr auto kTypeMerge = "Merge";
92 constexpr auto kTypeIf = "If";
93 constexpr auto kTypeVariable = "Variable";
94 constexpr auto kParallelGroup = "_parallel_group";
95 constexpr auto kParallelGroupId = "_parallel_group_id";
96 constexpr auto kTypeRefData = "RefData";
97 constexpr auto kBroadcast = "broadcast";
98 constexpr auto kInit = "init";
99 constexpr auto kTypeData = "Data";
100 constexpr auto kTypeIndex = "index";
101 constexpr auto kTypeY = "y";
102 constexpr auto kTypeX = "x";
103 constexpr auto kProcessNodeEngineID = "_process_node_engine_id";
104 constexpr auto kIsFreeVariable = "_is_free_variable";
105
106 namespace {
107 const std::map<TypeId, TypeId> kReduceRaiseMap = {{kNumberTypeInt64, kNumberTypeInt32}};
108 mindspore::HashMap<std::string, size_t> branches_repeat_times = {};
109 mindspore::HashMap<std::string, size_t> call_subgraphs_repeat_times = {};
110 // {node name | {{input_index, dst_type}...}}
111 const std::map<std::string, std::vector<std::pair<size_t, TypeId>>> kTransInputDTypeMap = {
112 {kResizeNearestNeighborGradOpName, {{2, kNumberTypeInt32}}},
113 {kResizeNearestNeighborOpName, {{2, kNumberTypeInt32}}},
114 {kResizeNearestNeighborV2OpName, {{2, kNumberTypeInt32}}},
115 {kResizeNearestNeighborV2GradOpName, {{2, kNumberTypeInt32}}},
116 {kResizeBicubicOpName, {{2, kNumberTypeInt32}}},
117 {kConv2DBackpropFilterOpName, {{3, kNumberTypeInt32}}},
118 {kConv2DBackpropInputOpName, {{3, kNumberTypeInt32}}},
119 {kOneHotOpName, {{2, kNumberTypeInt32}}},
120 {kLinSpaceOpName, {{3, kNumberTypeInt32}}},
121 {kResizeNearestNeighborV2GradOpName, {{2, kNumberTypeInt32}}},
122 {kResizeBilinearV2OpName, {{2, kNumberTypeInt32}}},
123 {kCol2ImOpName, {{2, kNumberTypeInt32}}}};
124
125 // {node name | {{attr_name, dst_type}...}}
126 const std::map<std::string, std::vector<std::pair<std::string, TypeId>>> kTransAttrDTypeMap = {
127 {kResizeBilinearOpName, {{"size", kNumberTypeInt32}}},
128 {kSpaceToBatchNDOpName, {{"block_shape", kNumberTypeInt32}}},
129 {kBatchToSpaceNDOpName, {{"block_shape", kNumberTypeInt32}}},
130 {kSplitVOpName, {{"split_dim", kNumberTypeInt32}}},
131 {kSplitVDOpName, {{"split_dim", kNumberTypeInt32}}}};
132
IsValidConversion(TypeId src_type,TypeId dst_type)133 bool IsValidConversion(TypeId src_type, TypeId dst_type) {
134 if (src_type == dst_type) {
135 MS_LOG(DEBUG) << "No need convert, src type and dst type is same, type:" << TypeIdToString(src_type);
136 return false;
137 }
138 auto iter = kReduceRaiseMap.find(src_type);
139 if (iter != kReduceRaiseMap.end() && iter->second == dst_type) {
140 MS_LOG(INFO) << "Convert data type from " << TypeIdToString(src_type) << " to " << TypeIdToString(dst_type);
141 return true;
142 }
143 MS_LOG(DEBUG) << "Unsupported conversion. src_type:" << TypeIdToString(src_type)
144 << ", dst_type:" << TypeIdToString(dst_type);
145 return false;
146 }
147
148 template <typename T>
CreateNewValue(const ValuePtr & value,const std::vector<T> & values,const TypeId & dst_type)149 ValuePtr CreateNewValue(const ValuePtr &value, const std::vector<T> &values, const TypeId &dst_type) {
150 MS_EXCEPTION_IF_NULL(value);
151 if (dst_type == kNumberTypeInt32) {
152 if (value->isa<ValueSequence>()) {
153 std::vector<int32_t> result;
154 std::for_each(values.begin(), values.end(),
155 [&result](const auto &elem) { result.emplace_back(static_cast<int32_t>(elem)); });
156 return MakeValue(result);
157 }
158 return MakeValue(static_cast<int32_t>(values[0]));
159 } else {
160 MS_LOG(EXCEPTION) << "Invalid dst type:" << TypeIdToString(dst_type);
161 }
162 return value;
163 }
164
165 template <typename T>
GetAllValues(const ValuePtr & value)166 std::vector<T> GetAllValues(const ValuePtr &value) {
167 MS_EXCEPTION_IF_NULL(value);
168 std::vector<T> result;
169 if (value->isa<ValueSequence>()) {
170 auto value_seq = value->cast<ValueSequencePtr>();
171 MS_EXCEPTION_IF_NULL(value_seq);
172 for (const auto &elem : value_seq->value()) {
173 auto value_list = GetAllValues<T>(elem);
174 std::copy(value_list.begin(), value_list.end(), std::back_inserter(result));
175 }
176 } else {
177 result.emplace_back(GetValue<T>(value));
178 }
179 return result;
180 }
181
GetElemType(const ValuePtr & value)182 TypeId GetElemType(const ValuePtr &value) {
183 MS_EXCEPTION_IF_NULL(value);
184 if (value->isa<tensor::Tensor>()) {
185 auto tensor_ptr = value->cast<tensor::TensorPtr>();
186 MS_EXCEPTION_IF_NULL(tensor_ptr);
187 return tensor_ptr->data_type();
188 }
189 if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
190 return value->type()->type_id();
191 }
192
193 auto elems = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
194 if (elems.empty()) {
195 MS_LOG(EXCEPTION) << "Value:" << value->ToString() << " is empty, check pls.";
196 }
197 return GetElemType(elems.at(0));
198 }
199
// Converts `value` (Tensor, scalar or sequence) from its element type to dst_type when IsValidConversion allows it.
// Only int64 sources are handled; a Tensor becomes a vector<int32_t> Value, other values go through
// GetAllValues/CreateNewValue. Returns nullptr when no conversion is needed or supported.
ValuePtr CastDstValue(const ValuePtr &value, const TypeId &dst_type) {
  MS_EXCEPTION_IF_NULL(value);
  const auto src_type = GetElemType(value);
  if (!IsValidConversion(src_type, dst_type)) {
    return nullptr;
  }
  if (src_type != kNumberTypeInt64) {
    MS_LOG(EXCEPTION) << "Invalid src type:" << value->type()->ToString();
  }
  if (!value->isa<tensor::Tensor>()) {
    return CreateNewValue<int64_t>(value, GetAllValues<int64_t>(value), dst_type);
  }
  auto tensor_ptr = value->cast<tensor::TensorPtr>();
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  // Size() is divided by sizeof(int64_t) to obtain the element count.
  const auto elem_count = tensor_ptr->Size() / sizeof(int64_t);
  const auto *src = static_cast<const int64_t *>(tensor_ptr->data_c());
  std::vector<int32_t> narrowed;
  narrowed.reserve(elem_count);
  for (size_t idx = 0; idx < elem_count; ++idx) {
    narrowed.push_back(LongToInt(src[idx]));
  }
  return MakeValue(narrowed);
}
225
226 // If mark_fv is true, set the kIsFreeVariable flag for all free variables and their inputs.
SuccIncludeFv(const FuncGraphPtr & fg,const AnfNodePtr & node,bool mark_fv=false)227 AnfNodeWeakPtrList SuccIncludeFv(const FuncGraphPtr &fg, const AnfNodePtr &node, bool mark_fv = false) {
228 AnfNodeWeakPtrList vecs;
229 if (node == nullptr) {
230 return vecs;
231 }
232
233 if (node->isa<CNode>()) {
234 auto cnode = node->cast<CNodePtr>();
235 bool is_fv = mark_fv && node->has_user_data(kIsFreeVariable);
236 auto &weak_inputs = cnode->weak_inputs();
237
238 // Check if free variables used.
239 for (const auto &weak_input : weak_inputs) {
240 auto input = weak_input.lock();
241 MS_EXCEPTION_IF_NULL(input);
242 if (is_fv) {
243 input->set_user_data(kIsFreeVariable, std::make_shared<bool>(true));
244 }
245 auto input_fg = GetValueNode<FuncGraphPtr>(input);
246 if (input_fg) {
247 for (auto &fv : input_fg->free_variables_nodes()) {
248 if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
249 if (mark_fv) {
250 fv->set_user_data(kIsFreeVariable, std::make_shared<bool>(true));
251 }
252 (void)vecs.emplace_back(fv);
253 }
254 }
255 }
256 }
257
258 (void)vecs.insert(vecs.end(), weak_inputs.begin(), weak_inputs.end());
259 }
260
261 return vecs;
262 }
263
GetOrderedCNodes(const FuncGraphPtr fg,const AnfNodePtr node=nullptr)264 std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg, const AnfNodePtr node = nullptr) {
265 MS_EXCEPTION_IF_NULL(fg);
266 auto succ_include_fv = [&fg](const AnfNodePtr &node) -> AnfNodeWeakPtrList { return SuccIncludeFv(fg, node); };
267
268 return (node == nullptr) ? TopoSort(fg->get_return(), succ_include_fv) : TopoSort(node, succ_include_fv);
269 }
270
GetFvNames(const FuncGraphPtr fg)271 std::set<std::string> GetFvNames(const FuncGraphPtr fg) {
272 MS_EXCEPTION_IF_NULL(fg);
273 auto succ_include_fv = [&fg](const AnfNodePtr &node) -> AnfNodeWeakPtrList { return SuccIncludeFv(fg, node, true); };
274
275 std::set<std::string> fvs;
276 auto nodes = TopoSort(fg->get_return(), succ_include_fv);
277 for (const auto &node : nodes) {
278 if (node->has_user_data(kIsFreeVariable)) {
279 node->set_user_data(kIsFreeVariable, std::shared_ptr<bool>(nullptr));
280 fvs.emplace(node->fullname_with_scope());
281 }
282 }
283
284 return fvs;
285 }
286
GetDynInputNum(const OpAdapterPtr & adpt,bool is_call,std::vector<int64_t> dyn_input_sizes,size_t real_input_idx,size_t input_size,const CNodePtr & node)287 int64_t GetDynInputNum(const OpAdapterPtr &adpt, bool is_call, std::vector<int64_t> dyn_input_sizes,
288 size_t real_input_idx, size_t input_size, const CNodePtr &node) {
289 MS_EXCEPTION_IF_NULL(adpt);
290 MS_EXCEPTION_IF_NULL(node);
291 int64_t dyn_input_num = -1;
292 if (!dyn_input_sizes.empty()) {
293 dyn_input_num = dyn_input_sizes.at(real_input_idx - 1);
294 } else if (adpt->IsDynInputOp(real_input_idx)) {
295 if (is_call) {
296 auto &input = node->inputs().back();
297 // the first input of Call node is Primitive, the second input is kernel_graph,
298 // which should not be members of input args, so the dyn_input_num need to minus 2 in default.
299 if (IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
300 // For PartitionedCall, Monod should not be a member of input args, so here dyn_input_num need to minus 3.
301 dyn_input_num = SizeToLong(input_size) - 3;
302 } else {
303 dyn_input_num = SizeToLong(input_size) - 2;
304 }
305 return dyn_input_num;
306 }
307 dyn_input_num = 1;
308 }
309 return dyn_input_num;
310 }
311
IsBranchNode(const AnfNodePtr & node)312 bool IsBranchNode(const AnfNodePtr &node) { return IsIfNode(node) || IsCaseNode(node); }
313
GetAnfCallInputs(bool is_kernel_graph,const CNodePtr & c_node)314 std::vector<AnfNodePtr> GetAnfCallInputs(bool is_kernel_graph, const CNodePtr &c_node) {
315 std::vector<AnfNodePtr> inputs;
316 if (is_kernel_graph) {
317 (void)std::copy(c_node->inputs().begin() + kInputOffset, c_node->inputs().end(), std::back_inserter(inputs));
318 } else {
319 if (c_node->input(0)->isa<CNode>()) {
320 auto in0 = c_node->input(0)->cast<CNodePtr>();
321 (void)std::copy(in0->inputs().begin() + kInputOffset, in0->inputs().end(), std::back_inserter(inputs));
322 }
323 (void)std::copy(c_node->inputs().begin() + 1, c_node->inputs().end(), std::back_inserter(inputs));
324 }
325 return inputs;
326 }
327
HasSubgraph(const std::shared_ptr<AnfGraph> & func_graph)328 bool HasSubgraph(const std::shared_ptr<AnfGraph> &func_graph) {
329 auto node_list = TopoSort(func_graph->get_return());
330 for (auto &node : node_list) {
331 if (!utils::isa<CNodePtr>(node)) {
332 continue;
333 }
334 auto sub_graph = GetCNodeFuncGraph(node);
335 if (sub_graph != nullptr) {
336 return true;
337 }
338 }
339 return false;
340 }
341
IsMakeTupleWithNullValue(const AnfNodePtr & node,const AnfNodePtr & input)342 bool IsMakeTupleWithNullValue(const AnfNodePtr &node, const AnfNodePtr &input) {
343 MS_EXCEPTION_IF_NULL(input);
344 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && input->isa<ValueNode>()) {
345 auto type = input->Type();
346 MS_EXCEPTION_IF_NULL(type);
347 if (type->isa<Tuple>()) {
348 auto tuple_type = type->cast<std::shared_ptr<Tuple>>();
349 MS_EXCEPTION_IF_NULL(tuple_type);
350 if (tuple_type->elements().empty()) {
351 return true;
352 }
353 }
354 }
355 return false;
356 }
357
IsMonad(const AnfNodePtr & node)358 bool IsMonad(const AnfNodePtr &node) {
359 return IsValueNode<UMonad>(node) || IsValueNode<IOMonad>(node) || HasAbstractMonad(node);
360 }
361
IsOverFlowNode(const AnfNodePtr & node,const AnfNodePtr & input)362 bool IsOverFlowNode(const AnfNodePtr &node, const AnfNodePtr &input) {
363 return IsPrimitiveCNode(input, prim::kPrimNPUClearFloatStatusV2) ||
364 IsPrimitiveCNode(node, prim::kPrimNPUClearFloatStatusV2) ||
365 IsPrimitiveCNode(node, prim::kPrimNPUGetFloatStatusV2);
366 }
367
SelectParamOriFormat(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)368 std::string SelectParamOriFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
369 MS_EXCEPTION_IF_NULL(manager);
370 MS_EXCEPTION_IF_NULL(node);
371 std::deque<AnfNodePtr> todo{node};
372 while (!todo.empty()) {
373 auto &curr_node = todo.front();
374 todo.pop_front();
375 const auto &nodes = manager->node_users()[curr_node];
376 for (const auto &node_pair : nodes) {
377 if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) {
378 todo.emplace_back(node_pair.first);
379 } else if (node_pair.first->isa<CNode>()) {
380 auto visited_format = GetOpIOFormat(node_pair.first);
381 if (visited_format != kOpFormat_DEFAULT) {
382 return visited_format;
383 }
384 }
385 }
386 }
387 return kOpFormat_DEFAULT;
388 }
389
GetGeTensorOrders(const mindspore::HashMap<int,int> & ge_input_to_ms_input,const std::vector<int64_t> & dyn_input_sizes,const int & ge_input_size,std::vector<int64_t> * new_dyn_input_sizes)390 std::vector<int> GetGeTensorOrders(const mindspore::HashMap<int, int> &ge_input_to_ms_input,
391 const std::vector<int64_t> &dyn_input_sizes, const int &ge_input_size,
392 std::vector<int64_t> *new_dyn_input_sizes) {
393 std::vector<int> ge_tensor_orders(ge_input_size, -1);
394 for (int ge_idx = 0; ge_idx < ge_input_size; ++ge_idx) {
395 int ms_idx = ge_input_to_ms_input.at(ge_idx);
396 new_dyn_input_sizes->at(ge_idx) = dyn_input_sizes[ms_idx];
397 int begin_idx = 0;
398 for (int i = 0; i < ms_idx; ++i) {
399 begin_idx += dyn_input_sizes[i] == -1 ? 1 : dyn_input_sizes[i];
400 }
401 ge_tensor_orders[ge_idx] = begin_idx;
402 }
403 return ge_tensor_orders;
404 }
405
IsNeedToUpdateTensorDesc(const std::string & op_type,const AnfNodePtr & node)406 bool IsNeedToUpdateTensorDesc(const std::string &op_type, const AnfNodePtr &node) {
407 // When IdentityN's input is Function or IdentityN, it can not
408 // find GEType mapping to MSType. There are ERROR logs that do not affect the result. So it no need to set OutputDesc
409 // of IdentityN, it can be inferred by GE. eg: MakeTuple-->MakeTuple. Output node should set OpDesc.
410 if (op_type == kTypeIdentityN && !IsPrimitiveCNode(node, prim::kPrimReturn)) {
411 MS_LOG(DEBUG) << "No need to set the OpDesc of Identity except return, node: " << node->fullname_with_scope();
412 return false;
413 }
414 // NoOp has not output, so it no need to set OutputDesc.
415 if (op_type == kTypeNoOp) {
416 MS_LOG(DEBUG) << "No need to set the OpDesc of NoOp, node: " << node->fullname_with_scope();
417 return false;
418 }
419 return true;
420 }
421
422 template <typename T>
SetXDataIndex(const OperatorPtr & op,T idx)423 void SetXDataIndex(const OperatorPtr &op, T idx) {
424 MS_EXCEPTION_IF_NULL(op);
425 op->SetAttr(kTypeIndex, static_cast<int64_t>(idx));
426 }
427
ParamCompare(const std::string & l,const std::string & r,const mindspore::HashMap<std::string,AnfNodePtr> & params,const NodeUsersMap & node_users)428 bool ParamCompare(const std::string &l, const std::string &r, const mindspore::HashMap<std::string, AnfNodePtr> ¶ms,
429 const NodeUsersMap &node_users) {
430 auto lpram_iter = params.find(l);
431 auto rpram_iter = params.find(r);
432 if (lpram_iter == params.end() && rpram_iter == params.end()) {
433 return l.compare(r) < 0;
434 } else if (lpram_iter == params.end()) {
435 return true;
436 } else if (rpram_iter == params.end()) {
437 return false;
438 }
439
440 bool lused_as_accum = (GetMomentumVarByAccum(lpram_iter->second, node_users) != nullptr);
441 bool rused_as_accum = (GetMomentumVarByAccum(rpram_iter->second, node_users) != nullptr);
442 if (lused_as_accum ^ rused_as_accum) {
443 return rused_as_accum;
444 }
445
446 return l.compare(r) < 0;
447 }
448
IsESNodeWithNoOutput(const AnfNodePtr & node)449 bool IsESNodeWithNoOutput(const AnfNodePtr &node) {
450 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> no_output_prims = {
451 prim::kPrimInitPartitionMap, prim::kPrimInitEmbeddingHashmap, prim::kPrimEmbeddingTableImport,
452 prim::kPrimEmbeddingComputeVarExport, prim::kPrimEmbeddingComputeVarImport, prim::kPrimEmbeddingTableExport};
453 if (IsOneOfPrimitiveCNode(node, no_output_prims)) {
454 return true;
455 }
456 return false;
457 }
458
GetEmbeddingApplyAdamOutput(const CNodePtr & node)459 std::vector<AnfNodePtr> GetEmbeddingApplyAdamOutput(const CNodePtr &node) {
460 MS_EXCEPTION_IF_NULL(node);
461 std::vector<AnfNodePtr> ret_nodes;
462 auto depend = node->input(1);
463 MS_EXCEPTION_IF_NULL(depend);
464 if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
465 MS_LOG(EXCEPTION) << "Need Depend ops, but get " << depend->fullname_with_scope();
466 }
467 auto depend_cnode = depend->cast<CNodePtr>();
468 auto tuple = depend_cnode->input(2);
469 MS_EXCEPTION_IF_NULL(tuple);
470 if (!IsPrimitiveCNode(tuple, prim::kPrimMakeTuple)) {
471 MS_LOG(EXCEPTION) << "Need MakeTuple ops, but get " << tuple->fullname_with_scope();
472 }
473 auto tuple_cnode = tuple->cast<CNodePtr>();
474 auto output_nodes = tuple_cnode->inputs();
475 ret_nodes.emplace_back(depend_cnode->input(1));
476 ret_nodes.insert(ret_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
477 return ret_nodes;
478 }
479 } // namespace
480
GenExampleGraph(const std::string & name)481 DfGraphPtr GenExampleGraph(const std::string &name) {
482 MS_LOG(INFO) << "Gen example graph name is " << name;
483 auto graph = std::make_shared<DfGraph>(name);
484 MS_EXCEPTION_IF_NULL(graph);
485 auto shape_data = std::vector<int64_t>({1, 1, 1, 1});
486 GeTensorDesc desc_data(ge::Shape(shape_data), ge::FORMAT_ND, ge::DT_FLOAT16);
487 auto data = ge::op::Data("data");
488 data.set_attr_index(0);
489 data.update_input_desc_x(desc_data);
490 data.update_output_desc_y(desc_data);
491 auto abs = ge::op::Abs("abs").set_input_x(data);
492 std::vector<Operator> inputs{data};
493 std::vector<Operator> outputs{abs};
494 graph->SetInputs(inputs);
495 graph->SetOutputs(outputs);
496 return graph;
497 }
498
// ---------------implementation of DfGraphConvertor-------------
500
IsDynamicShapeNode(const AnfNodePtr node)501 bool IsDynamicShapeNode(const AnfNodePtr node) {
502 auto shape = node->Shape();
503 if (shape == nullptr) {
504 return false;
505 }
506 if (!shape->isa<abstract::Shape>()) { // do not accept tuple shape as call node input
507 return false;
508 }
509 if (shape->IsDynamic()) {
510 return true;
511 }
512 return false;
513 }
514
InitLoopVar(std::vector<::ge::Operator> * init_input)515 bool DfGraphConvertor::InitLoopVar(std::vector<::ge::Operator> *init_input) {
516 MS_EXCEPTION_IF_NULL(init_input);
517 if (!this->training_) {
518 return false;
519 }
520 bool is_sink_size_repeat = false;
521 auto ms_context = MsContext::GetInstance();
522 MS_EXCEPTION_IF_NULL(ms_context);
523 int64_t value = 0;
524 if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
525 static int64_t sink_size = 0;
526 if (!ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
527 return false;
528 }
529 value = ConfigManager::GetInstance().iter_num();
530 if (sink_size == value) {
531 is_sink_size_repeat = true;
532 }
533 sink_size = value;
534 } else {
535 MS_LOG(INFO) << "Run with normal(non-sink) mode, the iterator number will always be 1";
536 ConfigManager::GetInstance().ResetIterNum();
537 return false;
538 }
539 GeTensorDesc desc(GeShape(), ::ge::FORMAT_NCHW, ::ge::DT_INT64);
540 auto var_iter_num = std::make_shared<Variable>("npu_runconfig/iterations_per_loop");
541 auto var_loop_cond = std::make_shared<Variable>("npu_runconfig/loop_cond");
542 auto var_one = std::make_shared<Variable>("npu_runconfig/one");
543 auto var_zero = std::make_shared<Variable>("npu_runconfig/zero");
544 (void)var_iter_num->update_output_desc_y(desc);
545 (void)var_loop_cond->update_output_desc_y(desc);
546 (void)var_one->update_output_desc_y(desc);
547 (void)var_zero->update_output_desc_y(desc);
548 vars_["npu_runconfig/iterations_per_loop"] = var_iter_num;
549 vars_["npu_runconfig/loop_cond"] = var_loop_cond;
550 vars_["npu_runconfig/one"] = var_one;
551 vars_["npu_runconfig/zero"] = var_zero;
552 auto const_iter_num = std::make_shared<Constant>("const/npu_runconfig/iterations_per_loop");
553 value -= 1; // iteration start from 0, the max iteration number for n loop should be n-1
554 (void)const_iter_num->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
555
556 auto const_loop_cond = std::make_shared<Constant>("const/npu_runconfig/loop_cond");
557 value = 0;
558 (void)const_loop_cond->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
559
560 auto const_one = std::make_shared<Constant>("const/npu_runconfig/one");
561 value = 1;
562 (void)const_one->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
563
564 auto const_zero = std::make_shared<Constant>("const/npu_runconfig/zero");
565 value = 0;
566 (void)const_zero->set_attr_value(GeTensor(desc, reinterpret_cast<uint8_t *>(&value), sizeof(int64_t)));
567
568 (void)const_iter_num->update_output_desc_y(desc);
569 (void)const_loop_cond->update_output_desc_y(desc);
570 (void)const_one->update_output_desc_y(desc);
571 (void)const_zero->update_output_desc_y(desc);
572
573 auto assign_iter_num = std::make_shared<Assign>("assign/npu_runconfig/iterations_per_loop");
574 (void)assign_iter_num->set_input_ref(*var_iter_num).set_input_value(*const_iter_num);
575 auto assign_loop_cond = std::make_shared<Assign>("assign/npu_runconfig/loop_cond");
576 (void)assign_loop_cond->set_input_ref(*var_loop_cond).set_input_value(*const_loop_cond);
577 auto assign_one = std::make_shared<Assign>("assign/npu_runconfig/one");
578 (void)assign_one->set_input_ref(*var_one).set_input_value(*const_one);
579 auto assign_zero = std::make_shared<Assign>("assign/npu_runconfig/zero");
580 (void)assign_zero->set_input_ref(*var_zero).set_input_value(*const_zero);
581
582 init_input->emplace_back(*var_iter_num);
583 init_input->emplace_back(*var_loop_cond);
584 init_input->emplace_back(*var_one);
585 init_input->emplace_back(*var_zero);
586 init_ops_.emplace_back(var_iter_num);
587 init_ops_.emplace_back(var_loop_cond);
588 init_ops_.emplace_back(var_one);
589 init_ops_.emplace_back(var_zero);
590 init_ops_.emplace_back(const_iter_num);
591 init_ops_.emplace_back(const_loop_cond);
592 init_ops_.emplace_back(const_one);
593 init_ops_.emplace_back(const_zero);
594 init_ops_.emplace_back(assign_iter_num);
595 init_ops_.emplace_back(assign_loop_cond);
596 init_ops_.emplace_back(assign_one);
597 init_ops_.emplace_back(assign_zero);
598 return is_sink_size_repeat;
599 }
600
DrawParamInitSubGraph(const std::string & name,const AnfNodePtr & it)601 void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it) {
602 // draw init subgraph
603 init_sout_ << "op_assign" << it.get() << "[label=<";
604 init_sout_ << "<table border='1' cellborder='1'>" << endl;
605 init_sout_ << "<tr>";
606 init_sout_ << "<td port='1'>resource</td>";
607 init_sout_ << "<td port='2'>value</td>";
608 init_sout_ << "</tr>" << endl;
609 init_sout_ << "<tr><td colspan=\"2\">"
610 << "\"assign_" << name << "\"</td></tr>" << endl;
611 init_sout_ << "</table>> shape=plaintext]" << endl;
612 init_sout_ << "param" << it.get() << "[shape=octagon, label=\"" << name << "\"]" << endl;
613 init_sout_ << "const" << it.get() << "[label= \"" << name << "_const"
614 << "\" shape=ellipse]" << endl;
615 init_sout_ << "param" << it.get() << "->"
616 << "op_assign" << it.get() << ":1" << endl;
617 init_sout_ << "const" << it.get() << "->"
618 << "op_assign" << it.get() << ":2" << endl;
619 }
620
SetupParamInitSubGraph(const TensorOrderMap & tensors,const std::vector<::ge::Operator> * const init_input,bool is_sink_size_repeat)621 void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors,
622 const std::vector<::ge::Operator> *const init_input,
623 bool is_sink_size_repeat) {
624 DfGraphPtr init_graph = std::make_shared<DfGraph>(kInit);
625 MS_EXCEPTION_IF_NULL(init_graph);
626 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
627
628 for (auto &it : nodes) {
629 MS_EXCEPTION_IF_NULL(it);
630 if (it->isa<ValueNode>()) {
631 if (IsValueNode<SymbolicKeyInstance>(it)) {
632 auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
633 auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
634 auto iter = vars_.find(name); // get corresponding variable op
635 if (iter != vars_.end()) {
636 op_cache_[it.get()] = iter->second;
637 // #ifdef DRAW_GE_GRAPH
638 compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
639 << "[style=\"dotted\"]" << endl;
640 // #endif
641 }
642 } else if (IsValueNode<RefKey>(it)) {
643 auto refkey = GetValueNode<StringImmPtr>(it);
644 MS_EXCEPTION_IF_NULL(refkey);
645 auto name = refkey->value();
646 auto iter = vars_.find(name); // get corresponding variable op
647 if (iter != vars_.end()) {
648 op_cache_[it.get()] = iter->second;
649 compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()]
650 << "[style=\"dotted\"]" << endl;
651 }
652 }
653 }
654 }
655
656 for (auto &it : tensors) {
657 if (vars_.find(it.first) == vars_.end()) {
658 MS_LOG(WARNING) << "Init parameter " << it.first << " didn't appear in graph.";
659 vars_[it.first] = nullptr;
660 }
661 }
662
663 // set up init sub graph
664 MS_EXCEPTION_IF_NULL(init_input);
665 if (!init_input->empty() && !is_sink_size_repeat) {
666 // init sub graph needs no input
667 MS_LOG(INFO) << "Build data init subgraph.";
668 (void)init_graph->SetInputs(*init_input);
669 this->init_graph_ = init_graph;
670 } else {
671 this->init_graph_ = nullptr;
672 }
673 }
674
SetupParamInitSubGraph()675 void DfGraphConvertor::SetupParamInitSubGraph() {
676 DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
677 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
678 MS_EXCEPTION_IF_NULL(init_graph);
679
680 for (auto &it : nodes) {
681 MS_EXCEPTION_IF_NULL(it);
682 if (it->isa<ValueNode>()) {
683 if (IsValueNode<SymbolicKeyInstance>(it)) {
684 auto symbolic = GetValueNode<SymbolicKeyInstancePtr>(it);
685 MS_EXCEPTION_IF_NULL(symbolic);
686 MS_EXCEPTION_IF_NULL(std::static_pointer_cast<Parameter>(symbolic->node()));
687 auto name = std::static_pointer_cast<Parameter>(symbolic->node())->name();
688 auto iter = vars_.find(name); // get corresponding variable op
689 if (iter != vars_.end()) {
690 op_cache_[it.get()] = iter->second;
691 }
692 } else if (IsValueNode<RefKey>(it)) {
693 auto refkey = GetValueNode<StringImmPtr>(it);
694 MS_EXCEPTION_IF_NULL(refkey);
695 auto name = refkey->value();
696 auto iter = vars_.find(name); // get corresponding variable op
697 if (iter != vars_.end()) {
698 op_cache_[it.get()] = iter->second;
699 }
700 }
701 }
702 }
703
704 // set up init sub graph
705 std::vector<::ge::Operator> init_input;
706 bool is_sink_size_repeat = InitLoopVar(&init_input);
707 if (!init_input.empty() && !is_sink_size_repeat) {
708 // init sub graph needs no input
709 MS_LOG(INFO) << "Build data init subgraph.";
710 (void)init_graph->SetInputs(init_input);
711 this->init_graph_ = init_graph;
712 } else {
713 this->init_graph_ = nullptr;
714 }
715 }
716
SetupBroadcast(const OperatorPtr & broadcast,const std::vector<GeTensorDesc> & broadcast_desc,const DfGraphPtr & broadcast_graph,std::vector<::ge::Operator> broadcast_input)717 void DfGraphConvertor::SetupBroadcast(const OperatorPtr &broadcast, const std::vector<GeTensorDesc> &broadcast_desc,
718 const DfGraphPtr &broadcast_graph, std::vector<::ge::Operator> broadcast_input) {
719 MS_LOG(INFO) << "build broadcast subgraph";
720 if (broadcast_desc.size() != broadcast_input.size()) {
721 MS_LOG(EXCEPTION) << "Desc number of BroadCast is not equal to number of Input";
722 }
723 MS_EXCEPTION_IF_NULL(broadcast);
724 (void)broadcast->DynamicInputRegister(kTypeX, (static_cast<unsigned int>(broadcast_input.size())));
725 (void)broadcast->DynamicOutputRegister(kTypeY, static_cast<unsigned int>(broadcast_desc.size()));
726 for (unsigned int i = 0; i < broadcast_input.size(); i++) {
727 (void)broadcast->SetInput(kTypeX, i, broadcast_input[i]);
728 (void)broadcast->UpdateDynamicOutputDesc(kTypeY, i, broadcast_desc[i]);
729 }
730 MS_EXCEPTION_IF_NULL(broadcast_graph);
731 (void)broadcast_graph->SetInputs(broadcast_input);
732 this->broadcast_graph_ = broadcast_graph;
733 }
734
NodeInputKeepUpdate(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)735 bool DfGraphConvertor::NodeInputKeepUpdate(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
736 if (manager == nullptr || node == nullptr) {
737 MS_LOG(ERROR) << "Input argument manager or node is nullptr";
738 return false;
739 }
740 if (offline_convert_) {
741 return false;
742 }
743 if (std::find(extra_variables_names_.begin(), extra_variables_names_.end(), node->fullname_with_scope()) !=
744 extra_variables_names_.end()) {
745 return true;
746 }
747 const auto &node_users = manager->node_users();
748 std::vector<PrimitivePtr> vec{
749 prim::kPrimAssign, prim::kPrimKVCacheMgr, prim::kPrimScatterUpdate, prim::kPrimScatterNdUpdate,
750 prim::kPrimPromptKVCache, prim::kPrimDecoderKVCache, prim::kPrimKVCacheScatterUpdate};
751 auto user_it = node_users.find(node);
752 if (user_it != node_users.end()) {
753 auto &users = user_it->second;
754 for (auto &user_node : users) {
755 auto &node_use = user_node.first;
756 if (node_use && std::any_of(vec.begin(), vec.end(),
757 [&node_use](const PrimitivePtr &prim) { return IsPrimitiveCNode(node_use, prim); })) {
758 return true;
759 }
760 // check if node is ReshapeAndKVCache which is fused by akg.
761 if (IsPrimitiveCNode(node_use, prim::kPrimCustom)) {
762 auto prim_custom = GetCNodePrimitive(node_use);
763 const std::string kAttrNameInfoPath = "info_path";
764
765 if (!prim_custom->HasAttr(kAttrNameInfoPath)) {
766 continue;
767 }
768 auto info_path_attr_node = prim_custom->GetAttr(kAttrNameInfoPath);
769 if (info_path_attr_node == nullptr) {
770 MS_LOG(EXCEPTION) << "attr node '" << kAttrNameInfoPath << "' is null";
771 }
772 std::string info_path = GetValue<std::string>(info_path_attr_node);
773 const std::string kOpReshapeAndCache = "ReshapeAndCache";
774 if (info_path.find(kOpReshapeAndCache) == std::string::npos) {
775 continue;
776 }
777
778 MS_LOG(INFO) << "found ReshapeAndCache, make use inpu keep update";
779 return true;
780 }
781 }
782 }
783 return false;
784 }
785
JudgeParamTransType(const bool & node_will_update,bool * as_ref_data,bool * as_constant) const786 void DfGraphConvertor::JudgeParamTransType(const bool &node_will_update, bool *as_ref_data, bool *as_constant) const {
787 if (ref_mode_) {
788 if ((ref_mode_type_ == RefModeFlag::kRefModeAll || node_will_update) && !export_air_) {
789 *as_ref_data = true;
790 } else { // When only variable will be treated as RefData, constant Parameter will be treated as Constant
791 *as_constant = true;
792 }
793 } else if (!training_ && !node_will_update) {
794 // parameter will be updated, lite inference mode will treat as variables
795 *as_constant = true;
796 }
797 }
798
InitParamWithData(const TensorOrderMap & tensors)799 void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
800 int index = 0;
801 std::vector<Operator> init_input;
802 MS_EXCEPTION_IF_NULL(graph_manager_);
803 // The format of Momentum's accum is updated according to format of Momentum's var, so here sort tensors to put
804 // Momentum's accum parameter at last
805 auto cmp = std::bind(ParamCompare, std::placeholders::_1, std::placeholders::_2, std::cref(params_),
806 graph_manager_->node_users());
807 std::map<std::string, std::pair<int, tensor::TensorPtr>, decltype(cmp)> ordered_tensors(cmp);
808 // NOTE: the sequence of parameters of init DfGraph is calculated by TensorOrderMap, see method `GetInputTensors`
809 // defined in `mindspore/ccsrc/plugin/device/ascend/hal/hardware/ge_graph_executor.cc`
810 for (auto &it : tensors) {
811 ordered_tensors.insert({it.first, {index++, it.second}});
812 }
813 for (const auto &itor : ordered_tensors) {
814 std::string name = itor.first;
815 auto &it = itor.second;
816 auto node_itor = params_.find(name);
817 // if name not in params_, create a node in graph
818 if (node_itor == params_.end()) {
819 // In lite, param maybe not exist.
820 MS_LOG(WARNING) << name << " is not in params, and create a new node.";
821 ParameterPtr param = std::make_shared<Parameter>(nullptr);
822 MS_EXCEPTION_IF_NULL(param);
823 if (!ref_mode_) {
824 name += "_temp";
825 }
826 param->set_name(name);
827 (void)ConvertParameter(param);
828 node_itor = params_.find(name);
829 }
830 auto node = node_itor->second;
831 MS_EXCEPTION_IF_NULL(node);
832 auto op_itor = op_cache_.find(node.get());
833 if (op_itor == op_cache_.end()) {
834 MS_LOG(EXCEPTION) << "Can not find op for node " << node->ToString() << ".";
835 }
836
837 MS_EXCEPTION_IF_NULL(it.second);
838 bool as_ref_data = false;
839 bool as_constant = false;
840 auto node_will_update = NodeInputKeepUpdate(graph_manager_, node);
841 JudgeParamTransType(node_will_update, &as_ref_data, &as_constant);
842
843 auto shape = it.second->shape_c();
844 if (as_ref_data && dyn_ref_data_func_ != nullptr) {
845 shape = dyn_ref_data_func_(node, shape);
846 }
847 auto desc =
848 TransformUtil::GetGeTensorDesc(shape, it.second->data_type(), SelectParamOriFormat(graph_manager_, node));
849 if (desc == nullptr) {
850 MS_LOG(WARNING) << "Create const " << name << " output descriptor failed!";
851 continue;
852 }
853 if (as_ref_data) {
854 StorageFormatConvertor::SetupStorageFormat(anf_graph_, node, desc);
855 auto ref_data = std::make_shared<RefData>(name);
856 MS_EXCEPTION_IF_NULL(ref_data);
857 (void)ref_data->update_output_desc_y(*desc);
858 (void)ref_data->update_input_desc_x(*desc);
859 (void)ref_data->set_attr_index(SizeToInt(ref_datas_.size()));
860 (void)ref_datas_.emplace_back(ref_data);
861 ref_data_names_.emplace_back(name);
862 // do not use read ref_data while ref_data sink
863 MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << ref_data->GetName() << ".";
864 op_itor->second = ref_data; // replace parameter with ref_data
865 vars_[name] = ref_data; // prevent the ref_data operator from being freed
866 } else if (as_constant) {
867 auto adpt_const = FindAdapter(kNameConst, training_);
868 if (adpt_const == nullptr) {
869 continue;
870 }
871 auto const_op = adpt_const->generate(name + "_const");
872 (void)adpt_const->setAttr(const_op, "value", it.second);
873 const_op->UpdateOutputDesc(kTypeY, *desc);
874 const_op_to_value_[const_op] = it.second;
875 vars_[name] = const_op;
876 op_itor->second = const_op;
877 } else {
878 auto &infer_need_update_parameter_names =
879 Singleton<mindspore::device::ascend::InferNeedUpdateParaNames>::Instance().GetInferParameterNames();
880 // we need three variable ops for each graph with same name
881 // build init subgraph
882 auto adpt = FindAdapter(kNameParam, training_);
883 if (adpt == nullptr) {
884 continue;
885 }
886 auto param_op = adpt->generate(name + "_data");
887 if (it.second->is_init() == 0) {
888 SetXDataIndex(param_op, it.first);
889 ProcessInputData(&init_input, &infer_need_update_parameter_names, param_op, name, desc);
890 }
891
892 auto variable = std::make_shared<Variable>(name);
893 MS_EXCEPTION_IF_NULL(variable);
894 (void)variable->update_output_desc_y(*desc);
895 // do not use read variable while variable sink
896 MS_LOG(DEBUG) << "InitParam, op_name = " << name << ", var = " << variable->GetName() << ".";
897 op_itor->second = variable; // replace parameter with variable
898 vars_[name] = variable; // prevent the variable operator from being freed
899 DrawParamInitSubGraph(name, node);
900 }
901 }
902 ReplaceAllParameterToRefData();
903 if (ref_mode_) {
904 SetupParamInitSubGraph();
905 } else {
906 bool is_sink_size_repeat = InitLoopVar(&init_input);
907 SetupParamInitSubGraph(tensors, &init_input, is_sink_size_repeat);
908 }
909 }
910
ReplaceAllParameterToRefData()911 void DfGraphConvertor::ReplaceAllParameterToRefData() {
912 if (ref_mode_ && (ref_mode_type_ == RefModeFlag::kRefModeAll) && !export_air_) {
913 MS_LOG(INFO) << "Graph abs ref tenor to ref data, " << anf_graph_->ToString();
914 auto parameters = anf_graph_->parameters();
915 int64_t idx = 0;
916 for (const auto ¶m : parameters) {
917 auto op_itor = op_cache_.find(param.get());
918 if (op_itor != op_cache_.end() && op_itor->second->GetOpType() == kTypeRefData) {
919 MS_LOG(INFO) << "This process param has default, have been change to RefData: " << param->fullname_with_scope();
920 continue;
921 }
922 auto para = param->cast<ParameterPtr>();
923 MS_EXCEPTION_IF_NULL(para);
924 auto abs = para->abstract();
925 MS_EXCEPTION_IF_NULL(abs);
926 if (!abs->isa<abstract::AbstractRefTensor>()) {
927 continue;
928 }
929 MS_EXCEPTION_IF_NULL(abs->BuildShape());
930 auto shape = abs->BuildShape()->GetShapeVector();
931 auto type = abs->BuildType()->type_id();
932 if (type == kObjectTypeTensorType) {
933 type = dyn_cast<TensorType>(abs->BuildType())->element()->type_id();
934 }
935 auto name = para->name();
936 if (name.empty()) {
937 name = "RefData_NULL_" + std::to_string(idx++);
938 }
939 auto ref_data = std::make_shared<RefData>(name);
940 MS_EXCEPTION_IF_NULL(ref_data);
941 auto desc = TransformUtil::GetGeTensorDesc(shape, type, SelectParamOriFormat(graph_manager_, para));
942 if (!desc) {
943 MS_LOG(ERROR) << "Create ge node desc failed, node name:" << name << ", shape: " << shape << ", type: " << type;
944 continue;
945 }
946 (void)ref_data->update_output_desc_y(*desc);
947 (void)ref_data->update_input_desc_x(*desc);
948 (void)ref_data->set_attr_index(SizeToInt(ref_datas_.size()));
949 (void)ref_datas_.emplace_back(ref_data);
950 ref_data_names_.emplace_back(name);
951 // do not use read ref_data while ref_data sink
952 MS_LOG(INFO) << "Change no default param: " << name << " to ref data. ";
953 op_itor->second = ref_data; // replace parameter with ref_data
954 vars_[name] = ref_data; // prevent the ref_data operator from being freed
955 }
956 }
957 }
958
ProcessInputData(vector<Operator> * init_input,std::unordered_set<std::string> * infer_need_update_parameter_names,const OperatorPtr & param_op,const string & name,const std::shared_ptr<GeTensorDesc> & desc)959 void DfGraphConvertor::ProcessInputData(vector<Operator> *init_input,
960 std::unordered_set<std::string> *infer_need_update_parameter_names,
961 const OperatorPtr ¶m_op, const string &name,
962 const std::shared_ptr<GeTensorDesc> &desc) {
963 MS_EXCEPTION_IF_NULL(init_input);
964 MS_EXCEPTION_IF_NULL(infer_need_update_parameter_names);
965 auto init_var = std::make_shared<Variable>(name);
966 auto assign_op = std::make_shared<Assign>("assign_" + name);
967 MS_EXCEPTION_IF_NULL(init_var);
968 MS_EXCEPTION_IF_NULL(assign_op);
969 (void)init_var->update_output_desc_y(*desc);
970 (void)assign_op->set_input_ref(*init_var).set_input_value(*param_op);
971 init_input->emplace_back(*init_var);
972 this->init_ops_.emplace_back(param_op);
973 this->init_ops_.emplace_back(assign_op);
974 this->init_ops_.emplace_back(init_var);
975 this->init_data_names_.emplace_back(name);
976 infer_need_update_parameter_names->insert(name);
977 }
978
// Convert all parameters that need initialization into variables.
InitParam(const TensorOrderMap & tensors)980 DfGraphConvertor &DfGraphConvertor::InitParam(const TensorOrderMap &tensors) {
981 if (error_ != SUCCESS) {
982 return *this;
983 }
984 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
985 error_ = INVALID_ARGUMENT;
986 MS_LOG(ERROR) << "Invalid AnfGraph in InitParam.";
987 return *this;
988 }
989
990 InitParamWithData(tensors);
991 init_sout_ << "}" << endl;
992 return *this;
993 }
994
995 #if (defined ENABLE_D)
BuildSaveCheckpointGraph()996 void DfGraphConvertor::BuildSaveCheckpointGraph() {
997 std::vector<Operator> graph_inputs;
998 ::ge::op::Save save_op("save_parms");
999 int save_op_is_active = 0;
1000 size_t index = 0;
1001 string name;
1002
1003 auto count_size = std::count_if(vars_.begin(), vars_.end(), [](const auto &it) {
1004 return LongToUlong(it.second == nullptr || it.first.find("/") != std::string::npos);
1005 });
1006
1007 (void)save_op.create_dynamic_input_tensors(static_cast<uint32_t>(vars_.size() - static_cast<size_t>(count_size)));
1008
1009 // for each "parameter" in anf graph excluding "input"
1010 for (const auto &it : vars_) {
1011 name = it.first;
1012 if (it.second == nullptr || name.find("/") != std::string::npos) {
1013 continue;
1014 }
1015 Variable variable(name);
1016 (void)variable.update_output_desc_y(it.second->GetOutputDesc(0));
1017 (void)save_op.set_dynamic_input_tensors(static_cast<uint32_t>(index++), variable);
1018
1019 graph_inputs.emplace_back(variable);
1020
1021 if (save_op_is_active == 0) {
1022 checkpoint_sout_ << "op_save" << &save_op << "[label=<";
1023 checkpoint_sout_ << "<table border='1' cellborder='1'>" << endl;
1024 checkpoint_sout_ << "<tr><td port='1'>tensor</td></tr>" << endl;
1025 checkpoint_sout_ << "<tr><td colspan=\"1\">"
1026 << "\"saveop"
1027 << "\"</td></tr>" << endl;
1028 checkpoint_sout_ << "</table>> shape=plaintext]" << endl;
1029 }
1030
1031 checkpoint_sout_ << "param" << it.second << "[shape=octagon, label=\"" << name << "\"]" << endl;
1032
1033 checkpoint_sout_ << "param" << it.second << "->"
1034 << "op_save" << &save_op << ":1" << endl;
1035 save_op_is_active = 1;
1036 }
1037 if (save_op_is_active != 0) {
1038 std::vector<Operator> graph_output;
1039 (void)graph_output.emplace_back(save_op);
1040 DfGraphPtr checkpoint_graph = std::make_shared<DfGraph>("checkpoint");
1041 (void)checkpoint_graph->SetInputs(graph_inputs);
1042 (void)checkpoint_graph->SetOutputs(graph_output);
1043 this->save_ckp_graph_ = checkpoint_graph;
1044 } else {
1045 this->save_ckp_graph_ = nullptr;
1046 }
1047
1048 checkpoint_sout_ << "}" << endl;
1049 return;
1050 }
1051 #endif
1052
GenerateBroadcastGraph(const TensorOrderMap & tensors)1053 DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap &tensors) {
1054 if (error_ != SUCCESS) {
1055 return *this;
1056 }
1057 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1058 error_ = INVALID_ARGUMENT;
1059 MS_LOG(ERROR) << "Invalid AnfGraph in generate broadcast graph";
1060 return *this;
1061 }
1062
1063 DfGraphPtr broadcast_graph = std::make_shared<DfGraph>(kBroadcast);
1064 // collect the operators create for broadcast sub graph, in order to avoid auto release
1065 std::vector<Operator> broadcast_input;
1066 std::vector<GeTensorDesc> broadcast_desc;
1067 auto adpt = FindAdapter(kNameBroadcast);
1068 if (!adpt) {
1069 MS_LOG(EXCEPTION) << "Get adpt failed, node type: HcomBroadcast";
1070 }
1071 auto broadcast = adpt->generate("broadcast_parameter");
1072 const int64_t root_rank_v = 0;
1073 (void)broadcast->SetAttr("root_rank", root_rank_v);
1074 (void)broadcast->SetAttr("group", "hccl_world_group");
1075 broadcast_ops_.emplace_back(broadcast);
1076
1077 // find every parameter, build broadcast subgraph (or initialize the parameter with constant)
1078 for (auto &it : anf_graph_->parameters()) {
1079 auto op_itor = op_cache_.find(it.get()); // converted node
1080 if (it->isa<Parameter>() && op_itor != op_cache_.end()) {
1081 string name = std::static_pointer_cast<Parameter>(it)->name();
1082 auto tensor_itor = tensors.find(name); // in init tensor map
1083 if (tensor_itor != tensors.end()) {
1084 auto tensor = tensor_itor->second;
1085 auto shape_ge = tensor->shape_c();
1086
1087 // create tensor descriptor for output descriptor
1088 auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_DEFAULT);
1089 if (desc == nullptr) {
1090 MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!";
1091 continue;
1092 }
1093
1094 // build broadcast subgraph
1095 if (distribute_) {
1096 auto broadcast_var = std::make_shared<Variable>(name);
1097 (void)broadcast_var->update_output_desc_y(*desc);
1098 broadcast_input.emplace_back(*broadcast_var);
1099 broadcast_desc.emplace_back(*desc);
1100 broadcast_ops_.emplace_back(broadcast_var);
1101 }
1102 }
1103 }
1104 }
1105
1106 // set up broadcast sub graph
1107 if (!broadcast_input.empty()) {
1108 DfGraphConvertor::SetupBroadcast(broadcast, broadcast_desc, broadcast_graph, broadcast_input);
1109 } else {
1110 this->broadcast_graph_ = nullptr;
1111 }
1112 return *this;
1113 }
1114
GenerateCheckpointGraph()1115 DfGraphConvertor &DfGraphConvertor::GenerateCheckpointGraph() {
1116 if (error_ != SUCCESS) {
1117 MS_LOG(ERROR) << "Generate checkpoint graph failed, found error code " << error_ << ".";
1118 if (!unsupported_ops_names_.empty()) {
1119 MS_LOG(ERROR) << "===========================================";
1120 MS_LOG(ERROR) << unsupported_ops_names_.size() << " Operator(s) cannot be converted:";
1121 std::string unsupported_ops_list;
1122 for (const auto &unsupported_ops : unsupported_ops_names_) {
1123 if (!unsupported_ops_list.empty()) {
1124 unsupported_ops_list += ", ";
1125 }
1126 unsupported_ops_list += unsupported_ops;
1127 }
1128 MS_LOG(ERROR) << "Unsupported op type list: " << unsupported_ops_list;
1129 MS_LOG(ERROR) << "===========================================";
1130 }
1131 return *this;
1132 }
1133 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1134 error_ = INVALID_ARGUMENT;
1135 MS_LOG(ERROR) << "Invalid AnfGraph in GenerateCheckpointGraph";
1136 return *this;
1137 }
1138 #ifdef ENABLE_D
1139 auto ms_context = MsContext::GetInstance();
1140 MS_EXCEPTION_IF_NULL(ms_context);
1141 if (ms_context->backend_policy() == "ge") {
1142 BuildSaveCheckpointGraph();
1143 // Restoring from checkpoint file is done by pyfront, not in graph now.
1144 }
1145 #endif
1146 return *this;
1147 }
1148
ConvertAllNode()1149 DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
1150 if (error_ != SUCCESS) {
1151 return *this;
1152 }
1153 if (anf_graph_ == nullptr || anf_graph_->output() == nullptr) {
1154 MS_LOG(ERROR) << "Invalid AnfGraph";
1155 error_ = FAILED;
1156 return *this;
1157 }
1158
1159 compute_sout_.clear();
1160 compute_sout_ << "digraph {" << endl;
1161 init_sout_.clear();
1162 init_sout_ << "digraph {" << endl;
1163 #ifdef ENABLE_D
1164 auto ms_context = MsContext::GetInstance();
1165 MS_EXCEPTION_IF_NULL(ms_context);
1166 if (ms_context->backend_policy() == "ge") {
1167 checkpoint_sout_.clear();
1168 checkpoint_sout_ << "digraph {" << endl;
1169 }
1170 #endif
1171 restore_checkpoint_sout_.clear();
1172 restore_checkpoint_sout_ << "digraph {" << endl;
1173 // Trans data_type for some specific nodes' inputs and attr.
1174 TransDataType(anf_graph_);
1175 // Convert all anf node to Operator
1176 MS_LOG(INFO) << "Convert all node, graph: " << anf_graph_->ToString();
1177 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_, while_cond_node_);
1178 if (ref_mode_) {
1179 // Ref mode need build all node(cnode && parameter).
1180 for (auto &p : anf_graph_->parameters()) {
1181 if (std::find(nodes.begin(), nodes.end(), p) == nodes.end()) {
1182 MS_LOG(INFO) << "Parameter " << p->DebugString() << " can not found in topo sort lists.";
1183 nodes.emplace_back(p);
1184 }
1185 }
1186 }
1187 for (auto &it : nodes) {
1188 if (IsSubGraph() && it->isa<Parameter>()) {
1189 continue;
1190 }
1191 if (IsSubGraph() && (IsPartialSuccNode(it) || IsPartialCNode(it))) {
1192 continue;
1193 }
1194 (void)Convert(it);
1195 if (this->error_ != SUCCESS) {
1196 MS_LOG(ERROR) << "Failed to convert node: " << it->DebugString() << ".";
1197 }
1198 }
1199
1200 // return the data flow graph
1201 return *this;
1202 }
1203
CacheWhileGraph(const CNodePtr & cnode)1204 void DfGraphConvertor::CacheWhileGraph(const CNodePtr &cnode) {
1205 if (while_graph_cache_.find(cnode) != while_graph_cache_.end()) {
1206 return;
1207 }
1208 ValueNodePtr graph_node = nullptr;
1209 if (is_kernel_graph_) {
1210 graph_node = cnode->input(1)->cast<ValueNodePtr>();
1211 } else {
1212 if (cnode->input(0)->isa<ValueNode>()) {
1213 graph_node = cnode->input(0)->cast<ValueNodePtr>();
1214 } else {
1215 auto partial_node = cnode->input(0);
1216 graph_node = partial_node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1217 }
1218 }
1219
1220 MS_EXCEPTION_IF_NULL(graph_node);
1221 FuncGraphPtr cond_graph = graph_node->value()->cast<FuncGraphPtr>();
1222 MS_EXCEPTION_IF_NULL(cond_graph);
1223 const auto &cond_set = cond_graph->nodes();
1224 for (auto beg = cond_set.begin(); beg != cond_set.end(); ++beg) {
1225 if (!((*beg)->isa<CNode>())) {
1226 continue;
1227 }
1228 auto c_beg = (*beg)->cast<CNodePtr>();
1229 if (GetCNodeFuncName(c_beg) == prim::kPrimSwitch->name()) {
1230 while_graph_cache_[cnode] = {c_beg->input(1), c_beg->input(kSwitchBodyIndex), c_beg->input(kSwitchAfterIndex)};
1231 }
1232 }
1233 }
1234
GetWhileBodyOutputs()1235 std::vector<Operator> DfGraphConvertor::GetWhileBodyOutputs() {
1236 std::vector<Operator> outputs;
1237
1238 const auto &node = anf_graph_->get_return()->input(1);
1239 AnfNodePtr real_ret = node;
1240 MS_EXCEPTION_IF_NULL(real_ret);
1241 while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1242 real_ret = real_ret->cast<CNodePtr>()->input(1);
1243 }
1244
1245 // skip input of UMonad, IOMonad
1246 if (HasAbstractUMonad(real_ret) || HasAbstractIOMonad(real_ret)) {
1247 return outputs;
1248 }
1249
1250 // skip input of the None, UpdateState
1251 if (IsValueNode<None>(real_ret) || IsPrimitiveCNode(real_ret, prim::kPrimUpdateState)) {
1252 return outputs;
1253 }
1254
1255 if (IsPrimitiveCNode(real_ret, prim::kPrimLoad)) {
1256 real_ret = ParseLoadInput(real_ret->cast<CNodePtr>());
1257 }
1258
1259 if (!real_ret->isa<CNode>()) {
1260 return outputs;
1261 }
1262
1263 auto c_node = real_ret->cast<CNodePtr>();
1264 std::vector<AnfNodePtr> inputs = GetAnfCallInputs(is_kernel_graph_, c_node);
1265 for (size_t i = 0; i < inputs.size(); i++) {
1266 auto j = inputs[i];
1267 MS_EXCEPTION_IF_NULL(j);
1268 if (!IsDataInput(c_node, j, 0)) {
1269 continue;
1270 }
1271 if (j->isa<Parameter>()) {
1272 int64_t idx = find(inputs_.begin(), inputs_.end(), j) - inputs_.begin();
1273 auto idx_cond = body_cond_map_[idx];
1274 if (while_used_input_index_.find(idx_cond) == while_used_input_index_.end() ||
1275 while_const_input_index_.find(idx_cond) != while_const_input_index_.end()) {
1276 continue;
1277 }
1278 outputs.emplace_back(*(subgraph_input_cache_[idx]));
1279 } else {
1280 outputs.emplace_back(*Convert(j));
1281 }
1282 }
1283 MS_LOG(DEBUG) << "get while body outputs size: " << outputs.size();
1284 return outputs;
1285 }
1286
GetWhileSubGraphInput()1287 std::shared_ptr<std::vector<Operator>> DfGraphConvertor::GetWhileSubGraphInput() {
1288 std::shared_ptr<std::vector<Operator>> graph_in = std::make_shared<std::vector<Operator>>();
1289 subgraph_input_cache_.clear();
1290 size_t i = 0;
1291 OperatorPtr op = nullptr;
1292 ParamIndexMap cond_body;
1293 std::string name_app = "_in_cond";
1294 if (graph_type_ == GraphType::kBody) {
1295 name_app = "_in_body";
1296 for (auto &p : body_cond_map_) {
1297 cond_body[p.second] = p.first;
1298 }
1299 }
1300 for (auto &idx : while_used_input_index_) {
1301 if (while_const_input_index_.find(idx) == while_const_input_index_.end()) {
1302 op = std::make_shared<Data>();
1303 MS_EXCEPTION_IF_NULL(op);
1304 SetXDataIndex(op, i);
1305 i++;
1306 } else {
1307 // No need to process ge tensor desc
1308 auto temp = while_const_input_index_[idx].op;
1309 auto name = temp->GetName();
1310 auto value = const_op_to_value_[temp];
1311 MS_EXCEPTION_IF_NULL(value);
1312 auto adpt_const = FindAdapter(kNameConst, training_);
1313 if (adpt_const == nullptr) {
1314 continue;
1315 }
1316 name += name_app;
1317 auto const_op = adpt_const->generate(name);
1318 (void)adpt_const->setAttr(const_op, "value", value);
1319 auto const_op_desc = TransformUtil::GetGeTensorDesc(value->shape_c(), value->data_type(), kOpFormat_DEFAULT);
1320 if (const_op_desc == nullptr) {
1321 MS_LOG(WARNING) << "Create variable " << name << " output descriptor failed!";
1322 continue;
1323 }
1324 const_op->UpdateOutputDesc(kTypeY, *const_op_desc);
1325 op = const_op;
1326 }
1327 graph_in->emplace_back(*op);
1328 if (graph_type_ == GraphType::kCond) {
1329 subgraph_input_cache_[idx] = op;
1330 } else if (graph_type_ == GraphType::kBody) {
1331 subgraph_input_cache_[cond_body[idx]] = op;
1332 }
1333 }
1334 MS_LOG(DEBUG) << "created " << subgraph_input_cache_.size() << " data node "
1335 << " in graph: " << anf_graph_->ToString();
1336 return graph_in;
1337 }
1338
BuildWhileSubGraph()1339 void DfGraphConvertor::BuildWhileSubGraph() {
1340 // set up dependencies
1341
1342 std::vector<Operator> graph_in = *GetWhileSubGraphInput();
1343 auto nodes = GetOrderedCNodes(anf_graph_, while_cond_node_);
1344
1345 AnfNodePtr real_ret = anf_graph_->get_return()->input(1);
1346 while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1347 real_ret = real_ret->cast<CNodePtr>()->input(1);
1348 }
1349 for (auto &it : nodes) {
1350 if (IsBranchNode(it)) {
1351 auto node = it->cast<CNodePtr>();
1352 GetBranchNodeInput(node);
1353 }
1354 }
1355
1356 for (auto &it : nodes) {
1357 if (it == real_ret || HasAbstractMonad(it)) {
1358 continue;
1359 }
1360 SetNodeInput(it);
1361 SetSubgraph(it);
1362 UpdateOpDesc(it);
1363 }
1364 std::vector<Operator> graph_out;
1365 auto graph_name = TransformUtil::NormOpName(cur_while_node_->fullname_with_scope());
1366 if (graph_type_ == GraphType::kCond) {
1367 if (op_cache_.find(while_cond_node_.get()) == op_cache_.end()) {
1368 return;
1369 }
1370 graph_name += "_cond_graph";
1371 graph_out.emplace_back(*(op_cache_[while_cond_node_.get()]));
1372 } else {
1373 graph_name += "_body_graph";
1374 graph_out = GetWhileBodyOutputs();
1375 }
1376 if (error_ == SUCCESS) {
1377 if (df_graph_->GetName() != graph_name) {
1378 MS_LOG(DEBUG) << "convert anf graph name : " << df_graph_->GetName() << " to df graph name: " << graph_name;
1379 }
1380 df_graph_ = make_shared<DfGraph>(graph_name);
1381 } else {
1382 return;
1383 }
1384 MS_LOG(DEBUG) << "Set while sub graph input num: " << graph_in.size();
1385 MS_LOG(DEBUG) << "Set while sub graph output num: " << graph_out.size();
1386
1387 compute_sout_ << "}" << endl;
1388 (void)df_graph_->SetInputs(graph_in).SetOutputs(graph_out);
1389 IdentityOptimization();
1390 }
1391
BuildWhileAfterSubGraph()1392 void DfGraphConvertor::BuildWhileAfterSubGraph() {
1393 size_t i = 0;
1394 prev_cond_to_while_out_index_.clear();
1395 for (auto n : prev_while_used_input_index_) {
1396 if (prev_while_const_input_index_.find(n) == prev_while_const_input_index_.end()) {
1397 prev_cond_to_while_out_index_[n] = i;
1398 i++;
1399 }
1400 }
1401 GetCallNodeInputs(cur_while_node_);
1402 auto nodes = GetOrderedCNodes(anf_graph_);
1403 for (auto &it : nodes) {
1404 SetNodeInput(it);
1405 SetSubgraph(it);
1406 UpdateOpDesc(it);
1407 }
1408 if (graph_outputs_.empty()) {
1409 SetGraphOutputs();
1410 }
1411 compute_sout_ << "}" << endl;
1412 return;
1413 }
1414
ConvertWhileBody(const AnfNodePtr & node)1415 void DfGraphConvertor::ConvertWhileBody(const AnfNodePtr &node) {
1416 if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != prim::kPrimPartial->name()) {
1417 return;
1418 }
1419 auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1420 MS_EXCEPTION_IF_NULL(graph_node);
1421 FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1422 MS_EXCEPTION_IF_NULL(anf_graph);
1423 DfGraphConvertor converter(anf_graph, phase_prefix_);
1424 converter.use_inputs_ = true;
1425
1426 const auto ¶ms = anf_graph->parameters();
1427 converter.inputs_ = params;
1428
1429 converter.graph_type_ = GraphType::kBody;
1430 converter.cur_while_node_ = cur_while_node_;
1431 converter.body_cond_map_ = body_cond_map_;
1432 converter.while_const_input_index_ = while_const_input_index_;
1433 converter.while_used_input_index_ = while_used_input_index_;
1434 converter.const_op_to_value_ = const_op_to_value_;
1435 converter.ConvertAllNode().BuildWhileSubGraph();
1436 while_dfgraph_cache_[cur_while_node_]->emplace_back(*(converter.df_graph_));
1437 std::string name = graph_node->ToString() + "_ge_graph.dot";
1438 auto context = MsContext::GetInstance();
1439 MS_EXCEPTION_IF_NULL(context);
1440 if (context->CanDump(kFully)) {
1441 converter.DrawComputeGraph(name);
1442 }
1443 return;
1444 }
1445
GetWhileUsedInputIndex(const std::vector<AnfNodePtr> & graphs)1446 void DfGraphConvertor::GetWhileUsedInputIndex(const std::vector<AnfNodePtr> &graphs) {
1447 if (!while_used_input_index_.empty()) {
1448 return;
1449 }
1450
1451 auto cond_graph_node = graphs.at(0);
1452 auto graph = cond_graph_node->func_graph();
1453 MS_EXCEPTION_IF_NULL(graph);
1454 const auto &cond_params = graph->parameters();
1455 auto nodes = GetOrderedCNodes(graph, cond_graph_node);
1456
1457 std::set<size_t> used_params_index;
1458 for (auto &n : nodes) {
1459 if (!n->isa<CNode>()) {
1460 continue;
1461 }
1462 auto c = n->cast<CNodePtr>();
1463 auto inputs = c->inputs();
1464 for (size_t idx = 1; idx < inputs.size(); idx++) {
1465 auto &i = inputs[idx];
1466 if (!i->isa<Parameter>() || HasAbstractMonad(i) || IsDynamicShapeNode(i)) {
1467 continue;
1468 }
1469 auto idx_cond = std::find(cond_params.begin(), cond_params.end(), i) - cond_params.begin();
1470 (void)used_params_index.insert(idx_cond);
1471 }
1472 }
1473
1474 auto body_graph_node_in_cond = graphs.at(1)->cast<CNodePtr>();
1475 auto body_graph_node = body_graph_node_in_cond->input(1)->cast<ValueNodePtr>();
1476 MS_EXCEPTION_IF_NULL(body_graph_node);
1477 graph = body_graph_node->value()->cast<FuncGraphPtr>();
1478 const auto &body_params = graph->parameters();
1479
1480 auto real_ret = graph->get_return()->input(1);
1481 while (real_ret->isa<CNode>() && GetCNodeTargetFuncName(real_ret->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
1482 real_ret = real_ret->cast<CNodePtr>()->input(1);
1483 }
1484
1485 nodes = GetOrderedCNodes(graph);
1486 for (auto &n : nodes) {
1487 if (!n->isa<CNode>()) {
1488 continue;
1489 }
1490 auto c = n->cast<CNodePtr>();
1491 if (c == real_ret || c == real_ret->cast<CNodePtr>()->input(0)) {
1492 continue;
1493 }
1494 auto inputs = c->inputs();
1495 for (size_t idx = 1; idx < inputs.size(); idx++) {
1496 auto &i = inputs[idx];
1497 if (!i->isa<Parameter>() || HasAbstractMonad(i) || IsDynamicShapeNode(i)) {
1498 continue;
1499 }
1500 auto idx_body = std::find(body_params.begin(), body_params.end(), i) - body_params.begin();
1501 auto p = body_graph_node_in_cond->input(static_cast<size_t>(idx_body + kInputOffset));
1502 auto idx_cond = std::find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1503 (void)used_params_index.insert(idx_cond);
1504 }
1505 }
1506 while_used_input_index_ = used_params_index;
1507 }
1508
SetParamIndexMap(const std::vector<AnfNodePtr> & graphs)1509 void DfGraphConvertor::SetParamIndexMap(const std::vector<AnfNodePtr> &graphs) {
1510 auto cond_graph_node = graphs.at(0);
1511 MS_EXCEPTION_IF_NULL(cond_graph_node);
1512 auto cond_graph = cond_graph_node->func_graph();
1513 MS_EXCEPTION_IF_NULL(cond_graph);
1514 const auto &cond_params = cond_graph->parameters();
1515
1516 auto body_graph_node = graphs.at(1);
1517 MS_EXCEPTION_IF_NULL(body_graph_node);
1518 if (!body_graph_node->isa<CNode>()) {
1519 return;
1520 }
1521 MS_EXCEPTION_IF_NULL(body_graph_node->cast<CNodePtr>());
1522 auto body_graph_node_inputs = body_graph_node->cast<CNodePtr>()->inputs();
1523 std::vector<AnfNodePtr> body_params;
1524 for (auto it = body_graph_node_inputs.begin() + kInputOffset; it != body_graph_node_inputs.end(); ++it) {
1525 body_params.emplace_back(*it);
1526 }
1527
1528 for (size_t i = 0; i < body_params.size(); i++) {
1529 auto p = body_params[i];
1530 int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1531 body_cond_map_[i] = static_cast<size_t>(idx);
1532 MS_LOG(DEBUG) << "body_cond_map_'s key: " << i << " value: " << idx;
1533 }
1534
1535 auto after_graph_node = graphs.at(kSwitchBodyIndex);
1536 MS_EXCEPTION_IF_NULL(after_graph_node);
1537 if (!after_graph_node->isa<CNode>()) {
1538 return;
1539 }
1540 MS_EXCEPTION_IF_NULL(after_graph_node->cast<CNodePtr>());
1541 auto after_graph_node_inputs = after_graph_node->cast<CNodePtr>()->inputs();
1542 std::vector<AnfNodePtr> after_params;
1543 for (auto it = after_graph_node_inputs.begin() + 2; it != after_graph_node_inputs.end(); ++it) {
1544 after_params.emplace_back(*it);
1545 }
1546
1547 for (size_t i = 0; i < after_params.size(); i++) {
1548 auto p = after_params[i];
1549 int64_t idx = find(cond_params.begin(), cond_params.end(), p) - cond_params.begin();
1550 after_cond_map_[i] = static_cast<size_t>(idx);
1551 MS_LOG(DEBUG) << "after_cond_map_'s key: " << i << " value: " << idx;
1552 }
1553 return;
1554 }
1555
ConvertWhileCond(const AnfNodePtr & node)1556 void DfGraphConvertor::ConvertWhileCond(const AnfNodePtr &node) {
1557 MS_LOG(DEBUG) << "begin to convert while node cond graph";
1558 auto func_graph = node->func_graph();
1559 MS_EXCEPTION_IF_NULL(func_graph);
1560
1561 DfGraphConvertor converter(func_graph, phase_prefix_);
1562 converter.use_inputs_ = true;
1563
1564 converter.inputs_ = func_graph->parameters();
1565
1566 converter.graph_type_ = GraphType::kCond;
1567 converter.cur_while_node_ = cur_while_node_;
1568 converter.while_cond_node_ = node;
1569 converter.while_const_input_index_ = while_const_input_index_;
1570 converter.while_used_input_index_ = while_used_input_index_;
1571 converter.const_op_to_value_ = const_op_to_value_;
1572 converter.ConvertAllNode().BuildWhileSubGraph();
1573 MS_EXCEPTION_IF_NULL(while_dfgraph_cache_[cur_while_node_]);
1574 while_dfgraph_cache_[cur_while_node_]->emplace_back(*(converter.df_graph_));
1575 std::string name = func_graph->ToString() + "_ge_graph.dot";
1576 auto context = MsContext::GetInstance();
1577 MS_EXCEPTION_IF_NULL(context);
1578 if (context->CanDump(kFully)) {
1579 converter.DrawComputeGraph(name);
1580 }
1581
1582 MS_LOG(DEBUG) << "convert while node cond graph end";
1583 }
1584
SetWhileOutputHandle(const OperatorPtr & prev_while_op)1585 void DfGraphConvertor::SetWhileOutputHandle(const OperatorPtr &prev_while_op) {
1586 if (while_output_handle_cache_.find(prev_while_node_) != while_output_handle_cache_.end()) {
1587 return;
1588 }
1589 auto out_handler = std::make_shared<std::vector<OutHandler>>();
1590 MS_EXCEPTION_IF_NULL(out_handler);
1591 string str = "output";
1592 for (size_t i = 0; i < prev_while_node_out_size_; i++) {
1593 (void)out_handler->emplace_back(prev_while_op, str + std::to_string(i), prev_while_node_);
1594 }
1595 while_output_handle_cache_[prev_while_node_] = out_handler;
1596 return;
1597 }
1598
ConvertWhileAfter(const AnfNodePtr & node)1599 void DfGraphConvertor::ConvertWhileAfter(const AnfNodePtr &node) {
1600 MS_EXCEPTION_IF_NULL(node);
1601 if (!node->isa<CNode>() || GetCNodeFuncName(node->cast<CNodePtr>()) != prim::kPrimPartial->name()) {
1602 return;
1603 }
1604 MS_LOG(DEBUG) << "begin to convert while node after graph";
1605 MS_EXCEPTION_IF_NULL(node->cast<CNodePtr>()->input(1));
1606 auto graph_node = node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
1607 MS_EXCEPTION_IF_NULL(graph_node);
1608 MS_EXCEPTION_IF_NULL(graph_node->value());
1609 FuncGraphPtr anf_graph = graph_node->value()->cast<FuncGraphPtr>();
1610 MS_EXCEPTION_IF_NULL(anf_graph);
1611 DfGraphConvertor converter(anf_graph, phase_prefix_);
1612 converter.use_inputs_ = true;
1613
1614 const auto ¶ms = anf_graph->parameters();
1615 converter.inputs_ = params;
1616
1617 converter.graph_type_ = GraphType::kAfter;
1618 converter.prev_after_cond_map_ = after_cond_map_;
1619 converter.prev_while_node_ = cur_while_node_;
1620 converter.prev_while_node_out_size_ = cur_while_node_out_size_;
1621 converter.bypass_node_prev_handle_cache_ = bypass_node_handle_cache_;
1622 converter.prev_while_used_input_index_ = while_used_input_index_;
1623 converter.prev_while_const_input_index_ = while_const_input_index_;
1624 converter.const_op_to_value_ = const_op_to_value_;
1625
1626 auto while_op = Convert(converter.prev_while_node_);
1627 converter.SetWhileOutputHandle(while_op);
1628 converter.ConvertAllNode().BuildWhileAfterSubGraph();
1629 std::string name = graph_node->ToString() + "_ge_graph.dot";
1630 auto context = MsContext::GetInstance();
1631 MS_EXCEPTION_IF_NULL(context);
1632 if (context->CanDump(kFully)) {
1633 converter.DrawComputeGraph(name);
1634 }
1635 MS_LOG(DEBUG) << "add while after graph " << converter.graph_const_inputs_.size()
1636 << " const inputs to main graph const inputs";
1637 (void)std::transform(converter.graph_const_inputs_.begin(), converter.graph_const_inputs_.end(),
1638 std::back_inserter(graph_const_inputs_), [](OperatorPtr x) { return x; });
1639
1640 graph_outputs_ = converter.graph_outputs_;
1641 MS_LOG(DEBUG) << "convert while node after graph end";
1642 return;
1643 }
1644
ConvertWhileNode(const CNodePtr & node)1645 void DfGraphConvertor::ConvertWhileNode(const CNodePtr &node) {
1646 if (IsSubGraph()) {
1647 return;
1648 }
1649
1650 auto while_graph = while_graph_cache_[node];
1651 cur_while_node_ = node;
1652
1653 auto &while_inputs = *(call_input_handle_cache_[node]);
1654 cur_while_node_out_size_ = while_inputs.size();
1655 while_dfgraph_cache_[node] = std::make_shared<std::vector<DfGraph>>();
1656 // convert cond graph
1657 auto cond_graph_node = while_graph[0];
1658 ConvertWhileCond(cond_graph_node);
1659
1660 // convert body graph
1661 auto body_graph_node = while_graph[1];
1662 ConvertWhileBody(body_graph_node);
1663
1664 OpAdapterPtr adpt = FindAdapter(node, training_);
1665 if (adpt == nullptr) {
1666 MS_LOG(DEBUG) << "Not found adapter";
1667 return;
1668 }
1669
1670 OperatorPtr op = Convert(node);
1671 auto graphs = while_dfgraph_cache_[node];
1672 adpt->setSubgraph(op, graphs);
1673
1674 // convert after graph
1675 auto after_graph_node = while_graph[kAfterIndexInCache];
1676 ConvertWhileAfter(after_graph_node);
1677 return;
1678 }
1679
BuildBranchGraphs(const CNodePtr & cnode)1680 std::shared_ptr<std::vector<DfGraph>> DfGraphConvertor::BuildBranchGraphs(const CNodePtr &cnode) {
1681 MS_EXCEPTION_IF_NULL(cnode);
1682 bool is_case = IsCaseNode(cnode);
1683 std::shared_ptr<std::vector<DfGraph>> df_branches = std::make_shared<std::vector<DfGraph>>();
1684 MS_EXCEPTION_IF_NULL(df_branches);
1685 if (IsNormalGraph() || IsBodyGraph() || IsBranchGraph()) {
1686 size_t branch_call_input_size = 0;
1687 size_t node_input_index = 0;
1688 if (!is_kernel_graph_) {
1689 for (size_t i = 1; i < cnode->size(); i++) {
1690 auto pred = cnode->input(i);
1691 if (!IsDataInput(cnode, pred, 0)) {
1692 continue;
1693 }
1694 node_input_index++;
1695 branch_call_input_size++;
1696 }
1697 }
1698 MS_EXCEPTION_IF_NULL(cnode->input(0));
1699 CNodePtr input_node = is_kernel_graph_ ? cnode : cnode->input(0)->cast<CNodePtr>();
1700 MS_EXCEPTION_IF_NULL(input_node);
1701 MS_EXCEPTION_IF_NULL(input_node->input(kInputOffset));
1702 auto bnode = is_case ? input_node->input(kInputOffset)->cast<CNodePtr>() : input_node->cast<CNodePtr>();
1703 MS_EXCEPTION_IF_NULL(bnode);
1704 const size_t init_i = is_case ? 1 : 2;
1705
1706 for (size_t i = init_i; i < bnode->size(); i++) {
1707 ParamIndexMap branch_to_parent_node_map;
1708 size_t branch_index = 0; // branch graph input's index
1709 if (bnode->input(i)->isa<CNode>()) {
1710 auto branch_node = bnode->input(i)->cast<CNodePtr>();
1711 MS_EXCEPTION_IF_NULL(branch_node);
1712 for (size_t j = kInputOffset; j < branch_node->size(); j++) {
1713 auto pred = branch_node->input(j);
1714 if (!IsDataInput(cnode, pred, 0)) {
1715 continue;
1716 }
1717 branch_to_parent_node_map[branch_index] = node_input_index;
1718 node_input_index++;
1719 branch_index++;
1720 }
1721 }
1722 if (!is_kernel_graph_) {
1723 for (size_t k = 0; k < branch_call_input_size; k++) {
1724 branch_to_parent_node_map[branch_index] = k;
1725 branch_index++;
1726 }
1727 }
1728 ProcessSubgraph(cnode, bnode->input(i), branch_to_parent_node_map);
1729 (void)(df_branches->emplace_back(branches_map_[bnode->input(i).get()]));
1730 }
1731 }
1732 return df_branches;
1733 }
1734
BuildCallSubGraphs(const AnfNodePtr & node)1735 void DfGraphConvertor::BuildCallSubGraphs(const AnfNodePtr &node) {
1736 MS_EXCEPTION_IF_NULL(node);
1737 auto cnode = node->cast<CNodePtr>();
1738 MS_EXCEPTION_IF_NULL(cnode);
1739 MS_EXCEPTION_IF_NULL(cnode->input(1));
1740 ValueNodePtr graph_node = cnode->input(1)->cast<ValueNodePtr>();
1741 MS_EXCEPTION_IF_NULL(graph_node);
1742 MS_EXCEPTION_IF_NULL(graph_node->value());
1743 auto anf_graph = graph_node->value()->cast<AnfGraphPtr>();
1744 MS_EXCEPTION_IF_NULL(anf_graph);
1745 DfGraphConvertor converter(anf_graph, phase_prefix_);
1746 converter.graph_type_ = GraphType::kNormal;
1747 converter.use_inputs_ = true;
1748 converter.inputs_ = anf_graph->parameters();
1749 std::string graph_name = anf_graph->ToString();
1750 auto iter = call_subgraphs_repeat_times.find(graph_name);
1751 if (iter == call_subgraphs_repeat_times.end()) {
1752 call_subgraphs_repeat_times[graph_name] = 1;
1753 } else {
1754 iter->second += 1;
1755 graph_name = graph_name + "_call_" + std::to_string(iter->second);
1756 }
1757 (void)converter.ConvertAllNode().BuildGraph(graph_name);
1758
1759 call_dfgraph_cache_[node] = std::make_shared<std::vector<DfGraph>>();
1760 MS_EXCEPTION_IF_NULL(call_dfgraph_cache_[node]);
1761 call_dfgraph_cache_[node]->emplace_back(*(converter.df_graph_));
1762 MS_LOG(INFO) << "build call subgraph end.";
1763 }
1764
SetSubgraph(const AnfNodePtr & node)1765 void DfGraphConvertor::SetSubgraph(const AnfNodePtr &node) {
1766 MS_EXCEPTION_IF_NULL(node);
1767 if (!node->isa<CNode>()) {
1768 return;
1769 }
1770 auto cnode = node->cast<CNodePtr>();
1771 if (IsWhileNode(cnode)) {
1772 MS_LOG(DEBUG) << "Start to set while's sub graph.";
1773 CacheWhileGraph(cnode);
1774 ConvertWhileNode(cnode);
1775 MS_LOG(DEBUG) << "Set while's sub graph end.";
1776 return;
1777 }
1778
1779 if (IsBranchNode(cnode)) {
1780 MS_LOG(DEBUG) << "Start to set if/case's sub graph.";
1781 std::shared_ptr<std::vector<DfGraph>> df_branches = BuildBranchGraphs(cnode);
1782 if (op_cache_.find(node.get()) == op_cache_.end()) {
1783 return;
1784 }
1785
1786 OpAdapterPtr adpt = FindAdapter(node, training_);
1787 if (adpt == nullptr) {
1788 MS_LOG(DEBUG) << "Not found adapter";
1789 return;
1790 }
1791
1792 OperatorPtr op = Convert(node);
1793 bool is_case = IsCaseNode(node);
1794 if (is_case) {
1795 adpt->setSubgraph(op, 0, df_branches);
1796 } else {
1797 adpt->setSubgraph(op, df_branches);
1798 }
1799 MS_LOG(DEBUG) << "Set if/case's sub graph end.";
1800 return;
1801 }
1802
1803 if (IsCallNode(cnode)) {
1804 MS_LOG(DEBUG) << "Start to set call's sub graph.";
1805 BuildCallSubGraphs(node);
1806 if (op_cache_.find(node.get()) == op_cache_.end()) {
1807 return;
1808 }
1809 OpAdapterPtr adpt = FindAdapter(node, training_);
1810 if (adpt == nullptr) {
1811 MS_LOG(EXCEPTION) << "Not found adapter";
1812 return;
1813 }
1814 OperatorPtr op = Convert(node);
1815 auto df_graphs = call_dfgraph_cache_[node];
1816 adpt->setSubgraph(op, df_graphs);
1817 MS_LOG(DEBUG) << "Set call's sub graph end.";
1818 }
1819 return;
1820 }
1821
GetBranchNodeInput(const CNodePtr node)1822 void DfGraphConvertor::GetBranchNodeInput(const CNodePtr node) {
1823 if (branch_input_handle_cache_.find(node.get()) != branch_input_handle_cache_.end()) {
1824 return;
1825 }
1826 bool is_case = IsCaseNode(node);
1827 std::vector<AnfNodePtr> branch_inputs;
1828 const size_t branch_index = 1;
1829
1830 MS_EXCEPTION_IF_NULL(node);
1831 MS_EXCEPTION_IF_NULL(node->input(0));
1832 CNodePtr sw_node = is_kernel_graph_ ? node : node->input(0)->cast<CNodePtr>();
1833 MS_EXCEPTION_IF_NULL(sw_node);
1834 AnfNodePtr branch_index_iter = sw_node->input(branch_index);
1835 AnfNodePtr branch_dyn_input_node = nullptr;
1836 const size_t make_tuple_index = 2;
1837 AnfNodePtr make_tuple_iter = sw_node->input(make_tuple_index);
1838 branch_dyn_input_node = make_tuple_iter; // switch node's 2nd input as dyn input
1839
1840 std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
1841 MS_EXCEPTION_IF_NULL(tuple_items);
1842
1843 CNodePtr input_node = node;
1844 if (!is_kernel_graph_) {
1845 for (size_t i = 1; i < node->size(); i++) {
1846 auto pred = node->input(i);
1847 (void)(branch_inputs.emplace_back(pred));
1848 }
1849 input_node = node->input(0)->cast<CNodePtr>();
1850 }
1851 MS_EXCEPTION_IF_NULL(input_node);
1852 auto bnode = is_case ? input_node->input(make_tuple_index)->cast<CNodePtr>() : input_node;
1853 MS_EXCEPTION_IF_NULL(bnode);
1854 const size_t init_i = is_case ? 1 : 2;
1855 for (size_t i = init_i; i < bnode->size(); ++i) {
1856 const auto &bnode_input = bnode->input(i);
1857 MS_EXCEPTION_IF_NULL(bnode_input);
1858 if (!bnode_input->isa<CNode>()) {
1859 continue;
1860 }
1861 auto branch_node = bnode_input->cast<CNodePtr>();
1862 MS_EXCEPTION_IF_NULL(branch_node);
1863 for (size_t j = 2; j < branch_node->size(); ++j) {
1864 auto pred = branch_node->input(j);
1865 (void)(branch_inputs.emplace_back(pred));
1866 }
1867 }
1868 std::vector<AnfNodePtr> branch_control_input;
1869 for (size_t i = 0; i < branch_inputs.size(); i++) {
1870 auto item = branch_inputs[i];
1871 if (!IsDataInput(node, item, 0)) {
1872 branch_control_input.emplace_back(item);
1873 continue;
1874 }
1875 if (IsBodyGraph() && item->isa<Parameter>()) {
1876 auto idx = std::find(inputs_.begin(), inputs_.end(), item) - inputs_.begin();
1877 (void)(tuple_items->emplace_back(subgraph_input_cache_[idx], "", item));
1878 } else {
1879 auto hd = GetHandler(item);
1880 tuple_items->emplace_back(hd);
1881 }
1882 }
1883 tuple_out_handle_cache_[branch_dyn_input_node.get()] = tuple_items;
1884
1885 std::shared_ptr<std::vector<AnfNodePtr>> branch_input_items = std::make_shared<std::vector<AnfNodePtr>>();
1886 MS_EXCEPTION_IF_NULL(branch_input_items);
1887 (void)branch_input_items->emplace_back(branch_index_iter);
1888 (void)branch_input_items->emplace_back(branch_dyn_input_node);
1889
1890 (void)std::copy(branch_control_input.begin(), branch_control_input.end(), std::back_inserter(*branch_input_items));
1891 branch_input_handle_cache_[node.get()] = branch_input_items;
1892 return;
1893 }
1894
GetCallNodeInputs(const CNodePtr & node)1895 void DfGraphConvertor::GetCallNodeInputs(const CNodePtr &node) {
1896 if (node == nullptr) {
1897 return;
1898 }
1899 if (call_input_handle_cache_.find(node) != call_input_handle_cache_.end()) {
1900 return;
1901 }
1902
1903 auto call_input_items = std::make_shared<std::vector<OutHandler>>();
1904 MS_EXCEPTION_IF_NULL(call_input_items);
1905 std::vector<AnfNodePtr> inputs = GetAnfCallInputs(is_kernel_graph_, node);
1906
1907 auto ¶ms = anf_graph_->parameters();
1908 auto while_op = Convert(node);
1909
1910 while_const_input_index_.clear();
1911 std::set<size_t> while_input_node_index;
1912 for (auto iter = while_used_input_index_.begin(); iter != while_used_input_index_.end(); ++iter) {
1913 auto n = inputs[*iter];
1914 MS_EXCEPTION_IF_NULL(n);
1915 OutHandler out_handler;
1916 if (IsAfterGraph() && n->isa<Parameter>()) {
1917 auto idx = std::find(params.begin(), params.end(), n) - params.begin();
1918 auto idx_cond = prev_after_cond_map_[idx];
1919 if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
1920 out_handler = bypass_node_prev_handle_cache_[idx_cond];
1921 } else {
1922 auto idx_out = prev_cond_to_while_out_index_[idx_cond];
1923 out_handler = while_output_handle_cache_[prev_while_node_]->at(idx_out);
1924 }
1925 } else {
1926 out_handler = GetHandler(inputs[*iter]);
1927 }
1928 MS_EXCEPTION_IF_NULL(out_handler.op);
1929 if ((out_handler.op->GetOpType() == "Const" || out_handler.op->GetOpType() == "Constant") &&
1930 const_op_to_value_.find(out_handler.op) != const_op_to_value_.end()) {
1931 while_const_input_index_[*iter] = out_handler;
1932 } else {
1933 (void)while_input_node_index.insert(*iter);
1934 call_input_items->emplace_back(out_handler);
1935 }
1936 }
1937 cur_while_node_out_size_ = call_input_items->size();
1938 bypass_node_handle_cache_.clear();
1939
1940 for (size_t i = 0; i < inputs.size(); i++) {
1941 if (while_input_node_index.find(i) == while_input_node_index.end()) {
1942 auto n = inputs[i];
1943 MS_EXCEPTION_IF_NULL(n);
1944 if (HasAbstractMonad(n)) {
1945 continue;
1946 }
1947 if (IsAfterGraph() && n->isa<Parameter>()) {
1948 auto idx = std::find(params.begin(), params.end(), n) - params.begin();
1949 auto idx_cond = prev_after_cond_map_[idx];
1950 if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
1951 bypass_node_handle_cache_[i] = bypass_node_prev_handle_cache_[idx_cond];
1952 } else {
1953 auto idx_out = prev_cond_to_while_out_index_[idx_cond];
1954 bypass_node_handle_cache_[i] = while_output_handle_cache_[prev_while_node_]->at(idx_out);
1955 }
1956 } else {
1957 bypass_node_handle_cache_[i] = GetHandler(n);
1958 }
1959 }
1960 }
1961
1962 auto op = Convert(node);
1963 auto adpt = FindAdapter(node, training_);
1964 MS_EXCEPTION_IF_NULL(adpt);
1965 adpt->setDynamicOutputNum(op, cur_while_node_out_size_);
1966 call_input_handle_cache_[node] = call_input_items;
1967 return;
1968 }
1969
SetGraphInputs(std::vector<Operator> * inputs)1970 void DfGraphConvertor::SetGraphInputs(std::vector<Operator> *inputs) {
1971 if (IsNormalGraph() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE) {
1972 auto ms_context = MsContext::GetInstance();
1973 MS_EXCEPTION_IF_NULL(ms_context);
1974 std::vector<PrimitivePtr> input_prims;
1975 if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
1976 input_prims = {prim::kPrimQueueData};
1977 } else {
1978 input_prims = {prim::kPrimGetNext, prim::kPrimDynamicGetNextV2};
1979 }
1980
1981 OperatorPtr input;
1982 auto nodes = GetOrderedCNodes(anf_graph_);
1983 for (auto &it : nodes) {
1984 if (std::any_of(input_prims.begin(), input_prims.end(),
1985 [&it](const PrimitivePtr &prim) { return IsPrimitiveCNode(it, prim); })) {
1986 auto it_op = op_cache_.find(it.get());
1987 if (it_op != op_cache_.end()) {
1988 input = it_op->second;
1989 break;
1990 } else {
1991 MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
1992 }
1993 }
1994 }
1995 if (input == nullptr) {
1996 MS_LOG(EXCEPTION) << "Can not find the GetNext node in graph in sink_mode, please check.";
1997 }
1998 inputs->emplace_back(*input);
1999
2000 MS_EXCEPTION_IF_NULL(anf_graph_);
2001 anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2002 } else {
2003 auto params = anf_graph_->parameters();
2004 int index = 0;
2005 for (auto &it : params) {
2006 auto param = it->cast<ParameterPtr>();
2007 MS_EXCEPTION_IF_NULL(param);
2008 auto name = param->name();
2009 if (std::find(init_data_names_.begin(), init_data_names_.end(), name) == init_data_names_.end()) {
2010 const auto ¶m_shape = param->Shape();
2011 MS_EXCEPTION_IF_NULL(param_shape);
2012 const auto &shape = param_shape->cast<abstract::ShapePtr>();
2013 if (shape != nullptr) {
2014 const auto &sv = shape->shape();
2015 if (IsDynamic(sv)) {
2016 dynamic_shape_inputs_ = true;
2017 }
2018 input_shapes_.emplace_back(sv);
2019 }
2020 }
2021 // the parameters which has not been converted to var
2022 if (vars_.find(name) == vars_.end()) {
2023 if (HasAbstractMonad(it)) {
2024 MS_LOG(INFO) << it->DebugString() << " is a monad parameter, skip.";
2025 continue;
2026 }
2027 auto op = Convert(it);
2028 MS_EXCEPTION_IF_NULL(op);
2029 MS_LOG(INFO) << "add not var input " << it->ToString() << ", index " << index;
2030 if (op == nullptr) {
2031 MS_LOG(ERROR) << "Convert graph failed!";
2032 return;
2033 }
2034 UpdateDataOpDesc(it, op);
2035
2036 if (IsNormalGraph()) {
2037 MS_LOG(INFO) << "add input " << it->ToString() << ", index " << index;
2038 SetXDataIndex(op, index);
2039 index++;
2040 }
2041 inputs->emplace_back(*op);
2042 } else if (vars_[name] != nullptr) {
2043 MS_LOG(INFO) << "add var input " << it->ToString();
2044 auto op = Convert(it);
2045 MS_EXCEPTION_IF_NULL(op);
2046 UpdateConstOpDesc(it, vars_[name]);
2047 inputs->emplace_back(*op);
2048 }
2049 }
2050 }
2051 }
2052
IsConstantOp(const OperatorPtr & op) const2053 bool DfGraphConvertor::IsConstantOp(const OperatorPtr &op) const {
2054 if (op == nullptr) {
2055 return false;
2056 }
2057 return (op->GetOpType() == "Constant" || op->GetOpType() == "Const");
2058 }
2059
SetGraphInputsForNotVar(const AnfNodePtr & it,int64_t * index,std::vector<Operator> * inputs)2060 OperatorPtr DfGraphConvertor::SetGraphInputsForNotVar(const AnfNodePtr &it, int64_t *index,
2061 std::vector<Operator> *inputs) {
2062 MS_EXCEPTION_IF_NULL(index);
2063 MS_EXCEPTION_IF_NULL(inputs);
2064 auto op = Convert(it);
2065 if (op == nullptr) {
2066 MS_LOG(EXCEPTION) << "Convert graph failed!";
2067 }
2068 UpdateDataOpDesc(it, op);
2069 if (IsNormalGraph()) {
2070 MS_LOG(INFO) << "add input " << it->ToString() << ", index " << *index;
2071 auto op_type = op->GetOpType();
2072 if (op_type == kTypeData || op_type == kTypeRefData) {
2073 SetXDataIndex(op, (*index));
2074 (*index)++;
2075 } else {
2076 auto name = std::static_pointer_cast<Parameter>(it)->name();
2077 MS_LOG(EXCEPTION) << "Op " << name << " is invalid type " << op->GetOpType() << " as graph input.";
2078 }
2079 }
2080 inputs->push_back(*op);
2081 return op;
2082 }
2083
SetGraphInputs(std::vector<Operator> * inputs,AnfNodeWeakPtrList * ge_inputs)2084 void DfGraphConvertor::SetGraphInputs(std::vector<Operator> *inputs, AnfNodeWeakPtrList *ge_inputs) {
2085 MS_EXCEPTION_IF_NULL(inputs);
2086 MS_EXCEPTION_IF_NULL(ge_inputs);
2087 MS_LOG(INFO) << "IsNormalGraph=" << IsNormalGraph() << ", dataset_mode"
2088 << ConfigManager::GetInstance().dataset_mode();
2089 AddInputInDataSink(inputs);
2090 auto params = anf_graph_->parameters();
2091 MS_LOG(INFO) << "Parameters size " << params.size();
2092 int64_t index = 0;
2093 std::set<std::string> name_records = {};
2094 for (auto &it : params) {
2095 auto name = std::static_pointer_cast<Parameter>(it)->name();
2096 OperatorPtr op;
2097 // the parameters which has not been converted to var
2098 if (vars_.find(name) == vars_.end()) {
2099 auto abs = it->abstract();
2100 MS_EXCEPTION_IF_NULL(abs);
2101 if (HasAbstractMonad(it) || abs->isa<abstract::AbstractSequence>()) {
2102 MS_LOG(INFO) << it->DebugString() << " is a monad or tuple/list parameter, skip.";
2103 continue;
2104 }
2105 op = SetGraphInputsForNotVar(it, &index, inputs);
2106 } else if (vars_[name] != nullptr) {
2107 MS_LOG(INFO) << "add var input " << it->ToString() << ", index " << index;
2108 op = Convert(it);
2109 MS_EXCEPTION_IF_NULL(op);
2110 if (name_records.count(name) != 0) {
2111 // two parameters have same ref_key
2112 MS_LOG(INFO) << "var input " << it->ToString() << " is already added";
2113 continue;
2114 }
2115 (void)name_records.insert(name);
2116 UpdateConstOpDesc(it, vars_[name]);
2117 auto op_type = op->GetOpType();
2118 if (op_type == kTypeRefData) {
2119 SetXDataIndex(op, index);
2120 index++;
2121 } else if (IsConstantOp(op)) {
2122 continue;
2123 } else {
2124 MS_LOG(EXCEPTION) << "Op " << name << " is invalid type " << op->GetOpType() << " as graph input.";
2125 }
2126 inputs->push_back(*op);
2127 }
2128 (void)ge_inputs->emplace_back(AnfNodeWeakPtr(it));
2129 }
2130 MS_LOG(INFO) << "Input size " << inputs->size();
2131 }
2132
AddInputInDataSink(vector<Operator> * inputs)2133 void DfGraphConvertor::AddInputInDataSink(vector<Operator> *inputs) {
2134 MS_EXCEPTION_IF_NULL(inputs);
2135 auto ms_context = MsContext::GetInstance();
2136 MS_EXCEPTION_IF_NULL(ms_context);
2137 std::vector<PrimitivePtr> input_prims;
2138 if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
2139 input_prims = {prim::kPrimQueueData};
2140 } else {
2141 input_prims = {prim::kPrimGetNext, prim::kPrimDynamicGetNextV2};
2142 }
2143 OperatorPtr input = nullptr;
2144 auto nodes = GetOrderedCNodes(anf_graph_);
2145 for (auto &it : nodes) {
2146 if (std::any_of(input_prims.begin(), input_prims.end(),
2147 [&it](const PrimitivePtr &prim) { return IsPrimitiveCNode(it, prim); })) {
2148 auto it_op = op_cache_.find(it.get());
2149 if (it_op != op_cache_.end()) {
2150 input = it_op->second;
2151 break;
2152 } else {
2153 MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
2154 }
2155 }
2156 }
2157 if (IsNormalGraph() && ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && input != nullptr) {
2158 (void)inputs->emplace_back(*input);
2159 MS_EXCEPTION_IF_NULL(anf_graph_);
2160 anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2161 }
2162 }
2163
BuildInitDataGraph(const std::string & name)2164 void DfGraphConvertor::BuildInitDataGraph(const std::string &name) {
2165 MS_LOG(INFO) << "Start BuildInitDataGraph.";
2166
2167 // If MS_CTX_ENABLE_GE_HETEROGENOUS is true, no need InitData graph
2168 auto ms_context = MsContext::GetInstance();
2169 MS_EXCEPTION_IF_NULL(ms_context);
2170 if (ms_context->get_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS)) {
2171 df_graph_ = nullptr;
2172 return;
2173 }
2174
2175 AnfNodePtr init_dataset_queue_node = nullptr;
2176 auto nodes = GetOrderedCNodes(anf_graph_);
2177 for (auto &it : nodes) {
2178 if (IsInitDataSetQueueNode(it)) {
2179 init_dataset_queue_node = it;
2180 break;
2181 }
2182 }
2183 OperatorPtr init_data_op = Convert(init_dataset_queue_node);
2184 MS_EXCEPTION_IF_NULL(init_data_op);
2185 if (error_ != SUCCESS) {
2186 return;
2187 }
2188 std::vector<::ge::Operator> inputs{*init_data_op};
2189 std::vector<::ge::Operator> outputs{*init_data_op};
2190 df_graph_ = make_shared<DfGraph>(name);
2191 (void)df_graph_->SetInputs(inputs);
2192 (void)df_graph_->SetOutputs(outputs);
2193 MS_LOG(INFO) << "End BuildInitDataGraph.";
2194 }
2195
FillEmptyInputsWithNoInputOp(std::vector<Operator> * inputs)2196 void DfGraphConvertor::FillEmptyInputsWithNoInputOp(std::vector<Operator> *inputs) {
2197 MS_EXCEPTION_IF_NULL(inputs);
2198 MS_LOG(INFO) << "Fill empty graph inputs with cnode whose inputs are empty.";
2199 auto nodes = GetOrderedCNodes(anf_graph_);
2200 for (auto &it : nodes) {
2201 if (!it->isa<CNode>()) {
2202 continue;
2203 }
2204 std::string name = common::AnfAlgo::GetCNodeName(it);
2205 if (name == prim::kPrimSwitch->name() || name == prim::kPrimSwitchLayer->name() ||
2206 name == prim::kPrimPartial->name()) {
2207 continue;
2208 }
2209 auto adpt = FindAdapter(it, training_);
2210 if (adpt == nullptr) {
2211 continue;
2212 }
2213 if (adpt->getInputMap().empty() && adpt->getAttrInputMap().empty()) {
2214 auto cnode_op = op_cache_.find(it.get());
2215 if (cnode_op != op_cache_.end()) {
2216 (void)inputs->emplace_back(*(cnode_op->second));
2217 break;
2218 } else {
2219 MS_LOG(EXCEPTION) << "Can not find the operator of node: " << it->fullname_with_scope();
2220 }
2221 }
2222 }
2223 }
2224
SetupInputFormat(const FuncGraphManagerPtr & manager,const AnfNodePtr & node)2225 void DfGraphConvertor::SetupInputFormat(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
2226 if (!node->isa<Parameter>()) {
2227 return;
2228 }
2229 auto para = node->cast<ParameterPtr>();
2230 std::vector<int64_t> shape;
2231 TypeId type;
2232 std::string format = kOpFormat_DEFAULT;
2233 if (para->has_default()) {
2234 auto value = para->default_param();
2235 MS_EXCEPTION_IF_NULL(value);
2236 auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
2237 MS_EXCEPTION_IF_NULL(tensor);
2238 shape = tensor->shape_c();
2239 type = tensor->data_type();
2240 format = SelectParamOriFormat(manager, para);
2241 } else {
2242 if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(para->Shape()); normal_shape_ptr != nullptr) {
2243 shape = normal_shape_ptr->shape();
2244 } else if (!dyn_cast<abstract::NoShape>(para->Shape())) {
2245 MS_LOG(INFO) << "Invalid shape.";
2246 return;
2247 }
2248 if (para->Type()) {
2249 type = para->Type()->type_id();
2250 if (type == kObjectTypeTensorType) {
2251 type = dyn_cast<TensorType>(para->Type())->element()->type_id();
2252 }
2253 } else {
2254 MS_LOG(INFO) << "Invalid shape.";
2255 return;
2256 }
2257 }
2258 std::string param_debug_info = para->DebugString();
2259 auto param_format = param_format_.find(param_debug_info);
2260 if (param_format != param_format_.end()) {
2261 format = param_format->second;
2262 MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
2263 }
2264 auto desc = TransformUtil::GetGeTensorDesc(shape, type, format);
2265 StorageFormatConvertor::SetupStorageFormat(anf_graph_, node, desc);
2266 }
2267
GenFakeGraphInRefMode()2268 void DfGraphConvertor::GenFakeGraphInRefMode() {
2269 const auto &nodes = GetOrderedCNodes(anf_graph_);
2270 for (const auto &node : nodes) {
2271 if (!node->isa<CNode>()) {
2272 continue;
2273 }
2274 SaveParamFormat(node->cast<CNodePtr>());
2275 }
2276 auto manager = Manage(anf_graph_, true);
2277 MS_EXCEPTION_IF_NULL(manager);
2278 std::vector<AnfNodeWeakPtr> ge_input_nodes = {};
2279 const auto ¶ms = anf_graph_->parameters();
2280 for (auto &node : params) {
2281 MS_EXCEPTION_IF_NULL(node);
2282 auto abs = node->abstract();
2283 MS_EXCEPTION_IF_NULL(abs);
2284 if (HasAbstractMonad(node) || abs->isa<abstract::AbstractSequence>()) {
2285 continue;
2286 }
2287 SetupInputFormat(manager, node);
2288 (void)ge_input_nodes.emplace_back(AnfNodeWeakPtr(node));
2289 }
2290 auto input_name_list = std::make_shared<GEInputList>();
2291 input_name_list->ge_inputs = ge_input_nodes;
2292 anf_graph_->set_user_data(input_name_list);
2293 for (auto &anf_node : params) {
2294 MS_EXCEPTION_IF_NULL(anf_node);
2295 auto para = anf_node->cast<ParameterPtr>();
2296 MS_EXCEPTION_IF_NULL(para);
2297 auto name = para->name();
2298 if (std::find(init_data_names_.begin(), init_data_names_.end(), name) == init_data_names_.end()) {
2299 const auto ¶m_shape = para->Shape();
2300 MS_EXCEPTION_IF_NULL(param_shape);
2301 const auto &shape = param_shape->cast<abstract::ShapePtr>();
2302 if (shape != nullptr) {
2303 const auto &sv = shape->shape();
2304 if (IsDynamic(sv)) {
2305 dynamic_shape_inputs_ = true;
2306 }
2307 input_shapes_.push_back(sv);
2308 }
2309 }
2310 }
2311
2312 auto ms_context = MsContext::GetInstance();
2313 MS_EXCEPTION_IF_NULL(ms_context);
2314 // set up init sub graph
2315 static bool is_inited = false;
2316 init_graph_ = nullptr;
2317 bool sink_mode = ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE;
2318 if (training_ && sink_mode && ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && !is_inited) {
2319 init_graph_ = GenExampleGraph(kInit);
2320 is_inited = true;
2321 }
2322 }
2323
GenFakeGraph(const std::string & name)2324 void DfGraphConvertor::GenFakeGraph(const std::string &name) {
2325 MS_LOG(INFO) << "Gen fake compute graph " << name;
2326 df_graph_ = GenExampleGraph(name);
2327 MS_EXCEPTION_IF_NULL(df_graph_);
2328 bool sink_mode = ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE;
2329 if (IsNormalGraph() && sink_mode) {
2330 MS_EXCEPTION_IF_NULL(anf_graph_);
2331 anf_graph_->set_flag(kGraphFlagHasGetNext, true);
2332 }
2333 const auto ¶ms = anf_graph_->parameters();
2334 bool has_weight = std::any_of(params.begin(), params.end(), [](const auto ¶) {
2335 auto parameter = para->template cast<ParameterPtr>();
2336 MS_EXCEPTION_IF_NULL(parameter);
2337 return parameter->has_default();
2338 });
2339 if (distribute_ && has_weight) {
2340 this->broadcast_graph_ = GenExampleGraph(kBroadcast);
2341 }
2342 if (!ref_mode_) {
2343 return;
2344 }
2345 GenFakeGraphInRefMode();
2346 }
2347
BuildGraph(const std::string & name)2348 DfGraphConvertor &DfGraphConvertor::BuildGraph(const std::string &name) {
2349 MS_LOG(INFO) << "Start BuildGraph, graph: " << anf_graph_->ToString();
2350
2351 if (error_ != SUCCESS) {
2352 return *this;
2353 }
2354
2355 GetCallNodeInputs(cur_while_node_);
2356 // branch node set input.
2357 bool is_initdata_graph = false;
2358 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
2359 for (auto &it : nodes) {
2360 if (IsBranchNode(it)) {
2361 auto node = it->cast<CNodePtr>();
2362 GetBranchNodeInput(node);
2363 }
2364 if (IsInitDataSetQueueNode(it)) {
2365 is_initdata_graph = true;
2366 }
2367 }
2368 auto manager = anf_graph_->manager();
2369 if (manager == nullptr) {
2370 auto new_manager = MakeManager({anf_graph_});
2371 MS_EXCEPTION_IF_NULL(new_manager);
2372 new_manager->AddFuncGraph(anf_graph_);
2373 anf_graph_->set_manager(new_manager);
2374 }
2375
2376 if (is_initdata_graph) {
2377 BuildInitDataGraph(name);
2378 return *this;
2379 }
2380 nodes = GetOrderedCNodes(anf_graph_);
2381 for (auto &it : nodes) {
2382 SetNodeInput(it);
2383 SetSubgraph(it);
2384 UpdateOpDesc(it);
2385 }
2386
2387 if (error_ == SUCCESS) {
2388 df_graph_ = make_shared<DfGraph>(name);
2389 } else {
2390 return *this;
2391 }
2392
2393 // set graph input according to the order from anf graph
2394 std::vector<Operator> inputs;
2395 std::vector<AnfNodeWeakPtr> ge_input_nodes = {};
2396 if (ref_mode_ && !export_air_) {
2397 SetGraphInputs(&inputs, &ge_input_nodes);
2398 } else {
2399 SetGraphInputs(&inputs);
2400 }
2401
2402 // Add const nodes as graph input for some operator work with constant
2403 MS_LOG(INFO) << "Graph const input size: " << graph_const_inputs_.size();
2404 auto fv_names = GetFvNames(anf_graph_);
2405 for (auto &input : graph_const_inputs_) {
2406 if (fv_names.find(input->GetName()) == fv_names.end()) {
2407 inputs.emplace_back(*input);
2408 }
2409 }
2410
2411 FillEmptyInputsWithNoInputOp(&inputs);
2412
2413 MS_LOG(INFO) << "Set graph input num: " << inputs.size();
2414 (void)df_graph_->SetInputs(inputs);
2415
2416 SetGraphOutputs(true);
2417 (void)df_graph_->SetOutputs(graph_outputs_);
2418
2419 IdentityOptimization();
2420 NoOpOptimization();
2421 if (has_es_node_) {
2422 ESOptimization();
2423 }
2424
2425 compute_sout_ << "}" << endl;
2426 // For the graph(e.g. eval_subgraph) whose IterNum is 1, do not set NeedIteration flag.
2427 auto ms_context = MsContext::GetInstance();
2428 MS_EXCEPTION_IF_NULL(ms_context);
2429 if (ConfigManager::GetInstance().iter_num() > 1 && ms_context->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK)) {
2430 df_graph_->SetNeedIteration(true);
2431 anf_graph_->set_flag(kGraphNeedIteration, true);
2432 }
2433 if (ref_mode_) {
2434 std::sort(ref_datas_.begin(), ref_datas_.end(), [](const OperatorPtr &left, const OperatorPtr &right) -> bool {
2435 int64_t left_idx;
2436 int64_t right_idx;
2437 left->GetAttr(kTypeIndex, left_idx);
2438 right->GetAttr(kTypeIndex, right_idx);
2439 return left_idx < right_idx;
2440 });
2441 auto input_name_list = std::make_shared<GEInputList>();
2442 MS_EXCEPTION_IF_NULL(input_name_list);
2443 input_name_list->ge_inputs = ge_input_nodes;
2444 anf_graph_->set_user_data(input_name_list);
2445 }
2446 MS_LOG(INFO) << "End BuildGraph, graph: " << anf_graph_->ToString();
2447 return *this;
2448 }
2449
SetGraphOutputs(bool is_main_graph)2450 void DfGraphConvertor::SetGraphOutputs(bool is_main_graph) {
2451 if (cur_while_node_ == nullptr) {
2452 graph_outputs_.clear();
2453 std::vector<AnfNodePtr> return_nodes;
2454 auto ret_node = anf_graph_->get_return();
2455 MS_EXCEPTION_IF_NULL(ret_node);
2456 auto output_nodes = ret_node->inputs();
2457 if (has_es_node_) {
2458 return_nodes = GetEmbeddingApplyAdamOutput(ret_node);
2459 } else if (((!HasSubgraph(anf_graph_) && is_main_graph)) ||
2460 (output_nodes.size() > 1 && IsESNodeWithNoOutput(output_nodes[1]))) {
2461 // replace return node with graph output node.
2462 return_nodes.insert(return_nodes.end(), output_nodes.begin() + 1, output_nodes.end());
2463 } else {
2464 return_nodes.emplace_back(ret_node);
2465 }
2466 for (const auto &output_node : return_nodes) {
2467 MS_EXCEPTION_IF_NULL(output_node);
2468 auto adpt = FindAdapter(output_node, training_);
2469 MS_EXCEPTION_IF_NULL(adpt);
2470 auto op_ptr = Convert(output_node);
2471 std::vector<OutHandler> handles;
2472 if (op_ptr != nullptr) {
2473 handles = adpt->getOutputs(op_ptr);
2474 } else if (tuple_out_handle_cache_.count(output_node.get()) > 0) {
2475 handles = *tuple_out_handle_cache_[output_node.get()];
2476 } else {
2477 MS_LOG(EXCEPTION) << "Can not find matched handles for node " << output_node->ToString();
2478 }
2479
2480 for (const auto &handle : handles) {
2481 (void)graph_outputs_.emplace_back(std::make_pair(*handle.op, handle.out));
2482 }
2483 }
2484 }
2485
2486 MS_LOG(INFO) << "Set graph " << anf_graph_->ToString() << " output, num: " << graph_outputs_.size();
2487 for (size_t i = 0; i < graph_outputs_.size(); i++) {
2488 MS_LOG(INFO) << "Graph output " << i << ": node: " << graph_outputs_[i].first.GetName()
2489 << ", out: " << graph_outputs_[i].second;
2490 }
2491 }
2492
UpdateConstOpDesc(const AnfNodePtr & it,const OperatorPtr & op) const2493 void DfGraphConvertor::UpdateConstOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
2494 if (!it->isa<Parameter>()) {
2495 MS_LOG(DEBUG) << "It is not parameter, name: " << it->DebugString();
2496 return;
2497 }
2498 auto para = it->cast<ParameterPtr>();
2499 MS_EXCEPTION_IF_NULL(para);
2500 std::string format = SelectParamOriFormat(graph_manager_, it);
2501 std::string param_debug_info = para->DebugString();
2502 auto param_format = param_format_.find(param_debug_info);
2503 if (param_format != param_format_.end()) {
2504 format = param_format->second;
2505 MS_LOG(DEBUG) << "Parameter debug info: " << param_debug_info << ", format is " << format;
2506 }
2507 if (format == kOpFormat_DEFAULT || format == kOpFormat_NCHW) {
2508 MS_LOG(DEBUG) << "Format is not changed, no need to update op desc, name: " << param_debug_info;
2509 return;
2510 }
2511 if (!para->has_default()) {
2512 MS_LOG(DEBUG) << "Parameter has no default, no need to update op desc, name: " << param_debug_info;
2513 return;
2514 }
2515 auto value = para->default_param();
2516 MS_EXCEPTION_IF_NULL(value);
2517 auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
2518 MS_EXCEPTION_IF_NULL(tensor);
2519 auto const_op_desc = TransformUtil::GetGeTensorDesc(tensor->shape_c(), tensor->data_type(), format);
2520 StorageFormatConvertor::SetupStorageFormat(anf_graph_, it, const_op_desc, format);
2521 if (const_op_desc == nullptr) {
2522 MS_LOG(WARNING) << "Create parameter " << para->name() << " output descriptor failed!";
2523 return;
2524 }
2525 (void)op->UpdateOutputDesc(kTypeY, *const_op_desc);
2526 }
2527
UpdateDataOpDesc(const AnfNodePtr & it,const OperatorPtr & op) const2528 void DfGraphConvertor::UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const {
2529 auto node = std::static_pointer_cast<AnfNode>(it);
2530 MS_EXCEPTION_IF_NULL(node);
2531 if (node == nullptr) {
2532 MS_LOG(ERROR) << "Update data op descriptor failed! Invalid node.";
2533 return;
2534 }
2535 std::vector<int64_t> shape;
2536 if (auto normal_shape_ptr = dyn_cast<abstract::Shape>(node->Shape()); normal_shape_ptr != nullptr) {
2537 shape = normal_shape_ptr->shape();
2538 } else if (auto no_shape_ptr = dyn_cast<abstract::NoShape>(node->Shape()); no_shape_ptr != nullptr) {
2539 shape = {};
2540 } else {
2541 MS_LOG(INFO) << "Invalid shape to update data op descriptor.";
2542 return;
2543 }
2544 if (node->Type() == nullptr) {
2545 MS_LOG(INFO) << "Invalid type to update data op descriptor.";
2546 return;
2547 }
2548 TypeId me_type = node->Type()->type_id();
2549 if (kObjectTypeTensorType == me_type) {
2550 me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
2551 }
2552 std::ostringstream buf;
2553 buf << "[" << shape << "]";
2554 MS_LOG(INFO) << "input shape is " << buf.str() << ", type is " << me_type;
2555 std::string format = SelectParamOriFormat(graph_manager_, it);
2556 if (it->isa<Parameter>()) {
2557 auto param = it->cast<ParameterPtr>();
2558 MS_EXCEPTION_IF_NULL(param);
2559 std::string param_name = param->DebugString();
2560 auto param_format = param_format_.find(param_name);
2561 if (param_format != param_format_.end()) {
2562 format = param_format->second;
2563 MS_LOG(DEBUG) << "parameter: " << param_name << ", format is " << format;
2564 }
2565 }
2566 auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
2567 StorageFormatConvertor::SetupStorageFormat(anf_graph_, it, desc, format);
2568 if (desc == nullptr) {
2569 MS_LOG(ERROR) << "Update data op descriptor failed! TensorDesc is null.";
2570 } else {
2571 (void)op->UpdateInputDesc(kTypeX, *desc);
2572 (void)op->UpdateOutputDesc(kTypeY, *desc);
2573 }
2574 }
2575
GetComputeGraph()2576 DfGraphPtr DfGraphConvertor::GetComputeGraph() { return df_graph_; }
2577
GetInitGraph()2578 DfGraphPtr DfGraphConvertor::GetInitGraph() { return init_graph_; }
2579
GetSaveCheckpointGraph()2580 DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; }
2581
GetBroadcastGraph()2582 DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; }
2583
2584 const std::vector<std::string> trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)};
2585
ParseLoadInput(const CNodePtr & cnode) const2586 AnfNodePtr DfGraphConvertor::ParseLoadInput(const CNodePtr &cnode) const {
2587 MS_EXCEPTION_IF_NULL(cnode);
2588 size_t min_inputs_size = 3;
2589 if (cnode->size() < min_inputs_size) {
2590 MS_LOG(EXCEPTION) << "input size error, " << cnode->ToString();
2591 }
2592 const size_t para_index = 1;
2593 return cnode->input(para_index);
2594 }
2595
TransformConstOp(const CNodePtr & node,const AnfNodePtr & pred)2596 void DfGraphConvertor::TransformConstOp(const CNodePtr &node, const AnfNodePtr &pred) {
2597 // transform "Const" op to "Variable" op when the next node is "Assign" op.
2598 std::string c_name = GetCNodeTargetFuncName(node);
2599 auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
2600 if (!training_ && !IsSubGraph() && pos != trans_var_list.end() && pred->isa<Parameter>()) {
2601 std::string name = std::static_pointer_cast<Parameter>(pred)->name();
2602 auto op_itor = op_cache_.find(pred.get());
2603 if (op_itor == op_cache_.end()) {
2604 MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
2605 }
2606 if (op_itor->second != nullptr &&
2607 (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
2608 vars_.find(name) != vars_.end()) {
2609 MS_EXCEPTION_IF_NULL(vars_[name]);
2610 if (ref_mode_) {
2611 auto variable = std::make_shared<RefData>(name);
2612 MS_EXCEPTION_IF_NULL(variable);
2613 auto desc = vars_[name]->GetOutputDesc(kTypeY);
2614 (void)variable->update_output_desc_y(desc);
2615 (void)variable->update_input_desc_x(desc);
2616 (void)variable->set_attr_index(ref_datas_.size());
2617 (void)ref_datas_.emplace_back(variable);
2618 MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
2619 op_itor->second = variable; // replace parameter with variable
2620 vars_[name] = variable;
2621 } else {
2622 auto variable = std::make_shared<Variable>(name);
2623 MS_EXCEPTION_IF_NULL(variable);
2624 auto desc = vars_[name]->GetOutputDesc(kTypeY);
2625 (void)variable->update_output_desc_y(desc);
2626 MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
2627 op_itor->second = variable; // replace parameter with variable
2628 vars_[name] = variable;
2629 }
2630 }
2631 }
2632 }
2633
// Resolves the node that really feeds data into `node` through `input`.
// Depend wrappers are looked through (their input(1)) and Load is replaced by its loaded value. When a
// Depend appears behind a Load (e.g. Depend->Load->TensorMove), resolution starts over from that Depend.
// Monads, None and UpdateState yield nullptr, as does a null `node`/`input`. The resolved node is passed
// to TransformConstOp before being returned.
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
  if (node == nullptr || input == nullptr) {
    return nullptr;
  }
  AnfNodePtr real = input;
  while (true) {
    while (real->isa<CNode>() && GetCNodeTargetFuncName(real->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
      real = real->cast<CNodePtr>()->input(1);
    }
    // skip input of UMonad, IOMonad, abstract monads, None and UpdateState
    const bool is_monad = IsValueNode<UMonad>(real) || IsValueNode<IOMonad>(real) || HasAbstractMonad(real);
    if (is_monad || IsValueNode<None>(real) || IsPrimitiveCNode(real, prim::kPrimUpdateState)) {
      return nullptr;
    }
    if (!IsPrimitiveCNode(real, prim::kPrimLoad)) {
      break;
    }
    real = ParseLoadInput(real->cast<CNodePtr>());
    // for scenario like: Depend->Load->TensorMove
    if (!IsPrimitiveCNode(real, prim::kPrimDepend)) {
      break;
    }
  }
  TransformConstOp(node, real);
  return real;
}
2666
IsDataInput(const AnfNodePtr & node,const AnfNodePtr & input,size_t input_index)2667 bool DfGraphConvertor::IsDataInput(const AnfNodePtr &node, const AnfNodePtr &input, size_t input_index) {
2668 if (node == nullptr || input == nullptr) {
2669 MS_LOG(ERROR) << "Node or input is null.";
2670 return false;
2671 }
2672 // Ignore the null ValueTupe in MakeTuple
2673 if (IsMakeTupleWithNullValue(node, input)) {
2674 return false;
2675 }
2676
2677 // skip NoOp
2678 auto op = Convert(node);
2679 if (op != nullptr && op->GetOpType() == kTypeNoOp) {
2680 return false;
2681 }
2682
2683 // skip input of UMonad, IOMonad
2684 if (IsMonad(input)) {
2685 return false;
2686 }
2687
2688 // skip input of the None, UpdateState
2689 if (IsValueNode<None>(input) || IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
2690 return false;
2691 }
2692
2693 const PrimitiveSet has_control_node = {prim::kPrimLoad, prim::kPrimDepend, prim::kPrimTupleGetItem};
2694 if (input_index != kDataInputIndex && IsOneOfPrimitiveCNode(node, has_control_node)) {
2695 return false;
2696 }
2697
2698 // Ge Operator of HcomReceive has no input.
2699 if (IsPrimitiveCNode(node, prim::kPrimReceive)) {
2700 return false;
2701 }
2702
2703 // The NPUClearFloatStatusV2 of GE has no input and output, and the NPUGetFloatStatusV2 has no input.
2704 // The extra data edges of MindSpore need to be converted to control edges of GE.
2705 if (IsOverFlowNode(node, input)) {
2706 return false;
2707 }
2708
2709 if (IsESNodeWithNoOutput(input)) {
2710 return false;
2711 }
2712
2713 return true;
2714 }
2715
GetNormalOpInput(const AnfNodePtr & node,const AnfNodePtr & pred)2716 OutHandler DfGraphConvertor::GetNormalOpInput(const AnfNodePtr &node, const AnfNodePtr &pred) {
2717 MS_EXCEPTION_IF_NULL(node);
2718 MS_EXCEPTION_IF_NULL(pred);
2719 OutHandler out_handler;
2720 if (IsSubGraph() && pred->isa<Parameter>()) {
2721 auto idx = std::find(inputs_.begin(), inputs_.end(), pred) - inputs_.begin();
2722 OperatorPtr op = subgraph_input_cache_[idx];
2723 out_handler.op = op;
2724 return out_handler;
2725 }
2726
2727 if (IsAfterGraph() && pred->isa<Parameter>()) {
2728 auto idx = std::find(inputs_.begin(), inputs_.end(), pred) - inputs_.begin();
2729 auto idx_cond = prev_after_cond_map_[idx];
2730 if (bypass_node_prev_handle_cache_.find(idx_cond) != bypass_node_prev_handle_cache_.end()) {
2731 out_handler = bypass_node_prev_handle_cache_[idx_cond];
2732 } else {
2733 auto idx_out = prev_cond_to_while_out_index_[idx_cond];
2734 MS_EXCEPTION_IF_NULL(while_output_handle_cache_[prev_while_node_]);
2735 out_handler = while_output_handle_cache_[prev_while_node_]->at(idx_out);
2736 }
2737 return out_handler;
2738 }
2739
2740 if (out_handle_cache_.find(pred.get()) != out_handle_cache_.end()) {
2741 return out_handle_cache_[pred.get()];
2742 }
2743 auto op = Convert(pred);
2744 if (op == nullptr) {
2745 MS_LOG(WARNING) << "Convert input node failed, input node: " << pred->fullname_with_scope()
2746 << ", node: " << node->fullname_with_scope() << ", graph: " << anf_graph_->ToString()
2747 << ". Please check whether the node is Partial node or successor node of Partial in sub-graph.";
2748 }
2749 out_handler.op = op;
2750 out_handler.node = pred;
2751 return out_handler;
2752 }
2753
DrawOpInput(const AnfNodePtr & node,const AnfNodePtr & pred,size_t i)2754 void DfGraphConvertor::DrawOpInput(const AnfNodePtr &node, const AnfNodePtr &pred, size_t i) {
2755 MS_EXCEPTION_IF_NULL(pred);
2756 if (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == mindspore::kTupleGetItemOpName) {
2757 MS_EXCEPTION_IF_NULL(pred->cast<CNodePtr>());
2758 MS_EXCEPTION_IF_NULL(pred->cast<CNodePtr>()->input(1));
2759 compute_sout_ << op_draw_name_[pred->cast<CNodePtr>()->input(1).get()] << " -> " << op_draw_name_[node.get()] << ":"
2760 << i << endl;
2761 } else if (pred->isa<Parameter>()) {
2762 compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl;
2763 }
2764 return;
2765 }
2766
GetInputHandles(const AnfNodePtr & node,const AnfNodePtr & input)2767 std::vector<OutHandler> DfGraphConvertor::GetInputHandles(const AnfNodePtr &node, const AnfNodePtr &input) {
2768 MS_EXCEPTION_IF_NULL(node);
2769 MS_EXCEPTION_IF_NULL(input);
2770 std::vector<OutHandler> handles;
2771 auto cache_ret = tuple_out_handle_cache_.find(input.get());
2772 if (cache_ret != tuple_out_handle_cache_.end()) {
2773 handles = *(cache_ret->second);
2774 } else if (IsWhileNode(input)) {
2775 // While node in subgraph does not convert.
2776 // Output handle of While node is inconsistent with MS.
2777 MS_LOG(WARNING) << "Input node is while node, input node: " << input->fullname_with_scope()
2778 << ", node: " << node->fullname_with_scope() << ", graph: " << anf_graph_->ToString();
2779 std::transform(graph_outputs_.begin(), graph_outputs_.end(), std::back_inserter(handles), [](const auto output) {
2780 return OutHandler(std::make_shared<::ge::Operator>(output.first), output.second);
2781 });
2782 } else {
2783 auto pred_adpt = FindAdapter(input, training_);
2784 MS_EXCEPTION_IF_NULL(pred_adpt);
2785 // When node's output is dynamic or node has multiple output, it need to get all handles.
2786 // TupleGetItem's input is dynamic output(eg:MakeTuple), but it only need to get one handle.
2787 if ((pred_adpt->IsDyOutputOp(0) || pred_adpt->IsMultipleOutputOp(input))) {
2788 MS_EXCEPTION_IF_NULL(Convert(input));
2789 handles = pred_adpt->getOutputs(Convert(input));
2790 } else {
2791 auto handle = GetNormalOpInput(node, input);
2792 if (handle.op != nullptr) {
2793 handles.emplace_back(handle);
2794 }
2795 }
2796 }
2797
2798 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
2799 std::vector<OutHandler> return_handles;
2800 CNodePtr cnode = node->cast<CNodePtr>();
2801 MS_EXCEPTION_IF_NULL(cnode);
2802 size_t tuplegetitem_idx = common::AnfAlgo::GetTupleGetItemOutIndex(cnode);
2803 if (tuplegetitem_idx >= handles.size()) {
2804 MS_LOG(EXCEPTION) << "Node output index " << tuplegetitem_idx << " is out of range [0," << handles.size()
2805 << "), node: " << node->fullname_with_scope()
2806 << ", input node: " << input->fullname_with_scope();
2807 } else {
2808 return_handles.emplace_back(handles[tuplegetitem_idx]);
2809 return return_handles;
2810 }
2811 }
2812
2813 return handles;
2814 }
2815
SetDynamicInputHandleByMultiInput(const OpAdapterPtr & adpt,const CNodePtr & node,const CNodePtr & from_node_input)2816 void DfGraphConvertor::SetDynamicInputHandleByMultiInput(const OpAdapterPtr &adpt, const CNodePtr &node,
2817 const CNodePtr &from_node_input) {
2818 MS_EXCEPTION_IF_NULL(adpt);
2819 MS_EXCEPTION_IF_NULL(node);
2820 MS_EXCEPTION_IF_NULL(from_node_input);
2821 auto inputs = from_node_input->inputs();
2822 std::vector<OutHandler> handles;
2823 for (size_t i = 1; i < inputs.size(); i++) {
2824 auto input = inputs[i];
2825 if (!IsDataInput(from_node_input, input, i)) {
2826 SetNodeControlInput(node, input);
2827 continue;
2828 }
2829 TransformConstOp(from_node_input, input);
2830 auto input_handles = GetInputHandles(from_node_input, input);
2831 handles.insert(handles.end(), input_handles.begin(), input_handles.end());
2832 if (input_handles.empty()) {
2833 MS_LOG(INFO) << "input handles is empty, node: " << from_node_input->fullname_with_scope()
2834 << ", input node: " << input->fullname_with_scope();
2835 continue;
2836 }
2837 AddGraphConstInput(input_handles[0].op);
2838 DrawOpInput(node, input, i);
2839 }
2840
2841 auto ret = adpt->setInput(Convert(node), 1, std::make_shared<std::vector<OutHandler>>(handles));
2842 if (ret != SUCCESS) {
2843 MS_LOG(EXCEPTION) << "Set node input handle failed, node:" << node->fullname_with_scope();
2844 }
2845 }
2846
IsMergeOrSwitchLayerInput(const CNodePtr & node) const2847 bool DfGraphConvertor::IsMergeOrSwitchLayerInput(const CNodePtr &node) const {
2848 auto manager = anf_graph_->manager();
2849 if (manager == nullptr) {
2850 auto new_manager = MakeManager({anf_graph_});
2851 MS_EXCEPTION_IF_NULL(new_manager);
2852 new_manager->AddFuncGraph(anf_graph_);
2853 anf_graph_->set_manager(new_manager);
2854 manager = new_manager;
2855 }
2856 auto node_users = manager->node_users()[node];
2857
2858 return (node_users.size() == 1 && std::find_if(node_users.begin(), node_users.end(), [](const auto &node_user) {
2859 return IsPrimitiveCNode(node_user.first, prim::kPrimMerge) ||
2860 IsPrimitiveCNode(node_user.first, prim::kPrimSwitchLayer);
2861 }) != node_users.end());
2862 }
2863
SetMakeTupleInput(const OpAdapterPtr & adpt,const CNodePtr & make_tuple_node)2864 void DfGraphConvertor::SetMakeTupleInput(const OpAdapterPtr &adpt, const CNodePtr &make_tuple_node) {
2865 MS_EXCEPTION_IF_NULL(adpt);
2866 MS_EXCEPTION_IF_NULL(make_tuple_node);
2867 MS_LOG(DEBUG) << "Set MakeTuple input handle: " << make_tuple_node->fullname_with_scope();
2868 // Skip MakeTuple make_tuple_node before Merge. Two branches(true/false) should not be merged before Merge, which
2869 // will lead to assign stream error in GE. Skip MakeTuple node before switch_layer, switch_layer's inputs will be
2870 // set in control flow process
2871 if (IsMergeOrSwitchLayerInput(make_tuple_node)) {
2872 MS_LOG(INFO) << "Skip make_tuple_node " << make_tuple_node->fullname_with_scope() << ", not set input handle.";
2873 return;
2874 }
2875 SetDynamicInputHandleByMultiInput(adpt, make_tuple_node, make_tuple_node);
2876 }
2877
SetMergeInput(const OpAdapterPtr & adpt,const CNodePtr & merge_node)2878 void DfGraphConvertor::SetMergeInput(const OpAdapterPtr &adpt, const CNodePtr &merge_node) {
2879 MS_EXCEPTION_IF_NULL(adpt);
2880 MS_EXCEPTION_IF_NULL(merge_node);
2881 auto inputs = merge_node->inputs();
2882 if (inputs.size() != kMergeInputSize) {
2883 MS_LOG(EXCEPTION) << "Merge input size should be " << kMergeInputSize << ", but is " << inputs.size()
2884 << ", node: " << merge_node->fullname_with_scope();
2885 }
2886 auto make_tuple = inputs[1];
2887 MS_EXCEPTION_IF_NULL(make_tuple);
2888 if (!IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) {
2889 MS_LOG(EXCEPTION) << "Merge input is not MakeTuple, but is " << make_tuple->fullname_with_scope()
2890 << ", node: " << merge_node->fullname_with_scope();
2891 }
2892 SetDynamicInputHandleByMultiInput(adpt, merge_node, make_tuple->cast<CNodePtr>());
2893 }
2894
SetNodeControlInput(const AnfNodePtr & node,const AnfNodePtr & input)2895 void DfGraphConvertor::SetNodeControlInput(const AnfNodePtr &node, const AnfNodePtr &input) {
2896 MS_EXCEPTION_IF_NULL(node);
2897 MS_EXCEPTION_IF_NULL(input);
2898 if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && input->isa<ValueNode>()) {
2899 return;
2900 }
2901 if (input->isa<Parameter>() && HasAbstractMonad(input)) {
2902 MS_LOG(DEBUG) << "Node input is monad node, do not add control edge. node:" << node->fullname_with_scope()
2903 << ", input: " << input->ToString();
2904 return;
2905 }
2906 auto dst = Convert(node);
2907 MS_EXCEPTION_IF_NULL(dst);
2908 auto src = Convert(input);
2909 if (src != nullptr) {
2910 dst->AddControlInput(*src);
2911 }
2912 }
2913
IsDynamicInputBeforeNormalInput(const OpAdapterPtr & adpt,int * ge_input_size,mindspore::HashMap<int,int> * ge_input_to_ms_input)2914 bool DfGraphConvertor::IsDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, int *ge_input_size,
2915 mindspore::HashMap<int, int> *ge_input_to_ms_input) {
2916 MS_EXCEPTION_IF_NULL(adpt);
2917 const auto &input_map = adpt->getInputMap();
2918 const auto &dyn_input_map = adpt->getDynInputMap();
2919
2920 // If adpt has no dynamic input, return false.
2921 if (dyn_input_map.empty()) {
2922 return false;
2923 }
2924
2925 // If dynamic input is behind the normal input, return false
2926 int min_dynamic_idx = std::numeric_limits<int>::max();
2927 int max_normal_idx = -1;
2928 for (const auto &iter : dyn_input_map) {
2929 int ms_order = iter.first - kIndex1;
2930 int ge_order = iter.second.index;
2931 min_dynamic_idx = std::min(min_dynamic_idx, ge_order);
2932 *ge_input_size = std::max(*ge_input_size, ge_order + 1);
2933 (*ge_input_to_ms_input)[ge_order] = ms_order;
2934 }
2935 for (const auto &iter : input_map) {
2936 int ms_order = iter.first - kIndex1;
2937 int ge_order = iter.second.index;
2938 max_normal_idx = std::max(max_normal_idx, ge_order);
2939 *ge_input_size = std::max(*ge_input_size, ge_order + 1);
2940 (*ge_input_to_ms_input)[ge_order] = ms_order;
2941 }
2942 if (min_dynamic_idx == std::numeric_limits<int>::max() || max_normal_idx == -1 || min_dynamic_idx > max_normal_idx) {
2943 return false;
2944 }
2945 return true;
2946 }
2947
SetDynamicInputBeforeNormalInput(const OpAdapterPtr & adpt,const CNodePtr & node,const std::vector<AnfNodePtr> & inputs,const int & ge_input_size,const mindspore::HashMap<int,int> & ge_input_to_ms_input,std::vector<int64_t> * dyn_input_sizes)2948 void DfGraphConvertor::SetDynamicInputBeforeNormalInput(const OpAdapterPtr &adpt, const CNodePtr &node,
2949 const std::vector<AnfNodePtr> &inputs, const int &ge_input_size,
2950 const mindspore::HashMap<int, int> &ge_input_to_ms_input,
2951 std::vector<int64_t> *dyn_input_sizes) {
2952 // If dynamic input is ahead of the normal input, use 'create_dynamic_input_by_index_name' to create dynamic input,
2953 // and this func must be called before set normal input.
2954 OperatorPtr src = Convert(node);
2955 MS_EXCEPTION_IF_NULL(adpt);
2956 const auto &dyn_input_map = adpt->getDynInputMap();
2957 MS_EXCEPTION_IF_NULL(dyn_input_sizes);
2958 if (dyn_input_sizes->empty()) {
2959 *dyn_input_sizes = std::vector<int64_t>(ge_input_size, -1);
2960 for (const auto &iter : dyn_input_map) {
2961 dyn_input_sizes->at(iter.first - kIndex1) = 1;
2962 }
2963 }
2964 std::vector<int64_t> new_dyn_input_sizes(ge_input_size, -1);
2965 std::vector<int> ge_tensor_orders =
2966 GetGeTensorOrders(ge_input_to_ms_input, *dyn_input_sizes, ge_input_size, &new_dyn_input_sizes);
2967
2968 std::vector<size_t> ms_control_inputs;
2969 for (size_t i = 1; i < inputs.size(); ++i) {
2970 if (HasAbstractMonad(inputs[i])) {
2971 ms_control_inputs.emplace_back(i);
2972 }
2973 }
2974
2975 MS_LOG(INFO) << "Adjust the dyn input order and use create_dynamic_input_byindex_name for node: "
2976 << node->fullname_with_scope();
2977 // ge_input_idx: the real ge input order
2978 // ge_tensor_orders: the tensor input order
2979 // ge_input_to_ms_input: the relationship between ge input order and ms input order
2980 // new_dyn_input_sizes: tensor size of dynamic input
2981 for (int ge_input_idx = 0; ge_input_idx < ge_input_size; ++ge_input_idx) {
2982 int ms_input_idx = ge_input_to_ms_input.at(ge_input_idx) + kIndex1;
2983 // ge_tensor_idx: the ge input idx of unfold mindspore inputs
2984 int ge_tensor_idx = ge_tensor_orders[ge_input_idx] + kIndex1;
2985 if (ge_tensor_idx >= static_cast<int>(inputs.size())) {
2986 MS_LOG(INFO) << "ge tensor index is more than ms inputs size, ge_tensor_idx:" << ge_tensor_idx
2987 << ", input size: " << inputs.size();
2988 continue;
2989 }
2990 AnfNodePtr pred = inputs[ge_tensor_idx];
2991 MS_EXCEPTION_IF_NULL(pred);
2992 if (!IsDataInput(node, pred, ge_input_idx)) {
2993 SetNodeControlInput(node, pred);
2994 continue;
2995 }
2996 auto handles = GetInputHandles(node, pred);
2997 if (handles.empty()) {
2998 MS_LOG(INFO) << "Input handles is empty, input node: " << pred->fullname_with_scope()
2999 << ", node: " << node->fullname_with_scope() << ", index: " << ms_input_idx;
3000 continue;
3001 }
3002 int ret;
3003 int64_t dyn_input_num = new_dyn_input_sizes[ge_input_idx];
3004 if (dyn_input_num != -1) {
3005 for (size_t dyn_input_idx = 1; dyn_input_idx < LongToSize(dyn_input_num); ++dyn_input_idx) {
3006 auto dyn_input_handle = GetInputHandles(node, inputs[ge_tensor_idx + dyn_input_idx]);
3007 handles.insert(handles.end(), dyn_input_handle.begin(), dyn_input_handle.end());
3008 }
3009 size_t dyn_input_begin_idx = 0;
3010 for (size_t i = 0; i < IntToSize(ge_input_idx); ++i) {
3011 dyn_input_begin_idx += new_dyn_input_sizes[i] == -1 ? 1 : LongToSize(new_dyn_input_sizes[i]);
3012 }
3013 ret = adpt->setInput(src, SizeToInt(ms_input_idx), std::make_shared<std::vector<OutHandler>>(handles), true,
3014 dyn_input_begin_idx);
3015 } else {
3016 if (handles.size() != 1 && pred->isa<ValueNode>()) {
3017 handles.clear();
3018 auto handle = GetNormalOpInput(node, pred);
3019 handles.emplace_back(handle);
3020 }
3021 if (handles.size() != 1) {
3022 MS_LOG(EXCEPTION) << "Input handles size " << handles.size() << " is not equal to 1, "
3023 << node->fullname_with_scope() << ", input node: " << pred->fullname_with_scope()
3024 << ", index: " << ms_input_idx;
3025 }
3026 ret = adpt->setInput(src, SizeToInt(ms_input_idx), handles[0]);
3027 }
3028 if (ret != SUCCESS) {
3029 MS_LOG(DEBUG) << "Set node input handle failed, node:" << node->fullname_with_scope()
3030 << ", input node: " << pred->fullname_with_scope() << ", index: " << ms_input_idx;
3031 } else {
3032 DrawOpInput(node, pred, ge_input_idx);
3033 AddGraphConstInput(handles[0].op);
3034 }
3035 }
3036
3037 for (size_t ms_control_input : ms_control_inputs) {
3038 AnfNodePtr pred = inputs[ms_control_input];
3039 SetNodeControlInput(node, pred);
3040 }
3041
3042 // Set input from attr.
3043 SetOpAttrToInput(adpt, node);
3044 return;
3045 }
3046
AddInputAttrsForESNode(const CNodePtr & node,const AnfNodePtr & input)3047 void DfGraphConvertor::AddInputAttrsForESNode(const CNodePtr &node, const AnfNodePtr &input) {
3048 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> es_need_add_attr = {
3049 prim::kPrimInitPartitionMap, prim::kPrimInitEmbeddingHashmap, prim::kPrimEmbeddingTableImport,
3050 prim::kPrimEmbeddingTableExport, prim::kPrimEmbeddingComputeVarImport, prim::kPrimEmbeddingComputeVarExport,
3051 prim::kPrimEmbeddingApplyAdam, prim::kPrimEmbeddingApplyAdamW, prim::kPrimEmbeddingApplyAdaGrad,
3052 prim::kPrimEmbeddingApplyFtrl,
3053 };
3054 if (!IsOneOfPrimitiveCNode(node, es_need_add_attr)) {
3055 return;
3056 }
3057 auto real = GetRealInputNode(node, input);
3058 MS_EXCEPTION_IF_NULL(real);
3059 auto op = Convert(real);
3060 MS_EXCEPTION_IF_NULL(real);
3061 if (!real->isa<ValueNode>()) {
3062 return;
3063 }
3064 (void)op->SetAttr(kProcessNodeEngineID, "PS");
3065 }
3066
SetOpInput(const OpAdapterPtr & adpt,const CNodePtr & node)3067 void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
3068 MS_EXCEPTION_IF_NULL(adpt);
3069 MS_EXCEPTION_IF_NULL(node);
3070 OperatorPtr src = Convert(node);
3071 bool branch_flag = false;
3072 auto &inputs = node->inputs();
3073 size_t input_size = inputs.size();
3074 if (branch_input_handle_cache_.find(node.get()) != branch_input_handle_cache_.end()) {
3075 branch_flag = true;
3076 MS_EXCEPTION_IF_NULL(branch_input_handle_cache_[node.get()]);
3077 input_size = branch_input_handle_cache_[node.get()]->size() + 1;
3078 } else if (!IsSubGraph() && call_input_handle_cache_.find(node) != call_input_handle_cache_.end()) {
3079 auto &handles = call_input_handle_cache_[node];
3080 MS_EXCEPTION_IF_NULL(handles);
3081 MS_LOG(DEBUG) << "call node input size: " << handles->size();
3082 adpt->setInput(src, 1, handles);
3083 return;
3084 }
3085
3086 MS_LOG(DEBUG) << "Set op input for node: " << node->fullname_with_scope();
3087 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
3088 SetMakeTupleInput(adpt, node);
3089 return;
3090 }
3091
3092 if (IsPrimitiveCNode(node, prim::kPrimMerge)) {
3093 SetMergeInput(adpt, node);
3094 return;
3095 }
3096 bool is_call = IsCallNode(node);
3097 std::vector<int64_t> dyn_input_sizes;
3098 if (common::AnfAlgo::HasNodeAttr(kAttrDynInputSizes, node)) {
3099 dyn_input_sizes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrDynInputSizes);
3100 }
3101
3102 int ge_input_size = 1;
3103 mindspore::HashMap<int, int> ge_input_to_ms_input;
3104 if (IsDynamicInputBeforeNormalInput(adpt, &ge_input_size, &ge_input_to_ms_input)) {
3105 SetDynamicInputBeforeNormalInput(adpt, node, inputs, ge_input_size, ge_input_to_ms_input, &dyn_input_sizes);
3106 return;
3107 }
3108 // For call node, the first input is kernel_graph, which should not be added to input args.
3109 size_t input_idx = is_call ? 2 : 1;
3110 size_t real_input_idx = 1;
3111 while (input_idx < input_size) {
3112 AnfNodePtr pred = branch_flag ? branch_input_handle_cache_[node.get()]->at(input_idx - 1) : inputs[input_idx];
3113 MS_EXCEPTION_IF_NULL(pred);
3114 if (!IsDataInput(node, pred, real_input_idx)) {
3115 SetNodeControlInput(node, pred);
3116 input_idx += 1;
3117 real_input_idx += 1;
3118 continue;
3119 }
3120 TransformConstOp(node, pred);
3121 auto handles = GetInputHandles(node, pred);
3122 if (handles.empty()) {
3123 MS_LOG(INFO) << "Input handles is empty, input node: " << pred->fullname_with_scope()
3124 << ", node: " << node->fullname_with_scope() << ", index: " << real_input_idx;
3125 input_idx += 1;
3126 real_input_idx += 1;
3127 continue;
3128 }
3129
3130 int ret;
3131 int64_t dyn_input_num = GetDynInputNum(adpt, is_call, dyn_input_sizes, real_input_idx, input_size, node);
3132 if (dyn_input_num != -1) {
3133 for (size_t dyn_input_idx = 1; dyn_input_idx < LongToSize(dyn_input_num); ++dyn_input_idx) {
3134 auto dyn_input_handle = GetInputHandles(node, inputs[input_idx + dyn_input_idx]);
3135 handles.insert(handles.end(), dyn_input_handle.begin(), dyn_input_handle.end());
3136 }
3137 ret = adpt->setInput(src, SizeToInt(real_input_idx), std::make_shared<std::vector<OutHandler>>(handles));
3138 input_idx += LongToSize(dyn_input_num);
3139 } else {
3140 if (handles.size() != 1 && pred->isa<ValueNode>()) {
3141 handles.clear();
3142 auto handle = GetNormalOpInput(node, pred);
3143 handles.emplace_back(handle);
3144 }
3145 if (handles.size() != 1) {
3146 MS_LOG(EXCEPTION) << "Input handles size " << handles.size() << " is not equal to 1, "
3147 << node->fullname_with_scope() << ", input node: " << pred->fullname_with_scope()
3148 << ", index: " << real_input_idx;
3149 }
3150 ret = adpt->setInput(src, SizeToInt(real_input_idx), handles[0]);
3151 input_idx += 1;
3152 }
3153 if (ret != SUCCESS) {
3154 MS_LOG(DEBUG) << "Set node input handle failed, node:" << node->fullname_with_scope()
3155 << ", input node: " << pred->fullname_with_scope() << ", index: " << real_input_idx;
3156 } else {
3157 DrawOpInput(node, pred, real_input_idx);
3158 AddGraphConstInput(handles[0].op);
3159 }
3160 AddInputAttrsForESNode(node, pred);
3161 real_input_idx += 1;
3162 }
3163 // Set input from attr.
3164 SetOpAttrToInput(adpt, node);
3165 }
3166
SetOpAttrToInput(const OpAdapterPtr & adpt,const CNodePtr & node)3167 void DfGraphConvertor::SetOpAttrToInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
3168 OperatorPtr src = Convert(node);
3169 auto &inputs = node->inputs();
3170 size_t input_size = inputs.size();
3171 const auto &primitive = GetCNodePrimitive(node);
3172 MS_EXCEPTION_IF_NULL(primitive);
3173 const auto monad_size = std::count_if(inputs.begin() + kIndex1, inputs.end(), [](const AnfNodePtr &input) {
3174 return input->isa<ValueNode>() && HasAbstractMonad(input);
3175 });
3176 const auto &attr_input_map = adpt->getAttrInputMap();
3177 const auto &input_map = adpt->getInputMap();
3178 if (input_map.size() != attr_input_map.size() + input_size - monad_size - kIndex1) {
3179 MS_LOG(DEBUG) << "For node: " << node->DebugString()
3180 << ", the size of real input:" << input_size - monad_size - kIndex1
3181 << " + the size of attr_input_map: " << attr_input_map.size()
3182 << " != the size of input_map:" << input_map.size()
3183 << ", so do not convert input from attr any more.";
3184 return;
3185 }
3186 MS_EXCEPTION_IF_NULL(anf_graph_);
3187 for (auto &it : attr_input_map) {
3188 // Get attr from node.
3189 auto value = primitive->GetAttr(it.first);
3190 if (value == nullptr) {
3191 MS_LOG(INFO) << "Node: " << node->DebugString() << " has no attr: " << it.first;
3192 continue;
3193 }
3194 // Create input node for attr value.
3195 auto input_node = NewValueNode(value);
3196 input_node->set_abstract(value->ToAbstract());
3197 anf_graph_->manager()->AddEdge(node, input_node);
3198 auto new_input_op = Convert(input_node);
3199 // Get input desc.
3200 auto input_name = it.second;
3201 auto input_desc = std::find_if(input_map.begin(), input_map.end(),
3202 [input_name](const auto &item) { return item.second.name == input_name; });
3203 if (input_desc == input_map.end()) {
3204 MS_LOG(WARNING) << "Node: " << node->DebugString() << " has no input :" << input_name;
3205 continue;
3206 }
3207 MS_LOG(INFO) << "Set input from attr:" << it.first << " for node: " << node->DebugString()
3208 << ", new value node:" << input_node->DebugString();
3209 input_desc->second.set_op(src, new_input_op);
3210 // Input idx may be wrong.
3211 DrawOpInput(node, input_node, static_cast<size_t>(input_desc->first));
3212 AddGraphConstInput(new_input_op);
3213 }
3214 }
3215
AddGraphConstInput(const OperatorPtr & op)3216 void DfGraphConvertor::AddGraphConstInput(const OperatorPtr &op) {
3217 if (op == nullptr) {
3218 return;
3219 }
3220 if (IsSubGraph()) {
3221 return;
3222 }
3223
3224 if (op->GetOpType() == "Constant" || op->GetOpType() == "Const") {
3225 graph_const_inputs_.emplace_back(op);
3226 }
3227 }
3228
SetNodeInput(const AnfNodePtr node)3229 void DfGraphConvertor::SetNodeInput(const AnfNodePtr node) {
3230 if (!node->isa<CNode>()) {
3231 return;
3232 }
3233 if (op_cache_.find(node.get()) == op_cache_.end()) {
3234 return;
3235 }
3236 auto cnode = node->cast<CNodePtr>();
3237 MS_EXCEPTION_IF_NULL(cnode);
3238 OpAdapterPtr adpt = FindAdapter(cnode, training_);
3239 if (adpt == nullptr) {
3240 error_ = NOT_FOUND;
3241 return;
3242 }
3243
3244 // get Operator from op_cache_, use adapter to set Inputs
3245 DfGraphConvertor::SetOpInput(adpt, cnode);
3246 }
3247
GetGNodeName(const::ge::GNode & node) const3248 std::string DfGraphConvertor::GetGNodeName(const ::ge::GNode &node) const {
3249 ::ge::AscendString name;
3250 auto ret = node.GetName(name);
3251 if (ret == ::ge::GRAPH_SUCCESS) {
3252 return std::string(name.GetString());
3253 } else {
3254 MS_LOG(WARNING) << "Get GNode name failed, ret: " << ret;
3255 return std::string();
3256 }
3257 }
3258
GetGNodeType(const::ge::GNode & node) const3259 std::string DfGraphConvertor::GetGNodeType(const ::ge::GNode &node) const {
3260 ::ge::AscendString node_type;
3261 auto ret = node.GetType(node_type);
3262 if (ret == ::ge::GRAPH_SUCCESS) {
3263 return std::string(node_type.GetString());
3264 } else {
3265 MS_LOG(WARNING) << "Get GNode type failed, ret: " << ret;
3266 return std::string();
3267 }
3268 }
3269
3270 // 1) Identity or IdentityN is the input of Merge, not delete
3271 // 2) Identity or IdentityN is the subgraph(If) input, not delete
3272 // 3) Identity or IdentityN it the output, not delete
3273 // 4) Identity or IdentityN has multiple users, not delete
3274 // 5) Nodes with control edges, temporarily not delete
IsIdentityRedundant(const::ge::GNode & node) const3275 bool DfGraphConvertor::IsIdentityRedundant(const ::ge::GNode &node) const {
3276 auto node_type = GetGNodeType(node);
3277 if (node_type != kTypeIdentityN && node_type != kTypeIdentity) {
3278 MS_LOG(DEBUG) << "Node is not Identity or IdentityN, but is " << node_type << ", node name: " << GetGNodeName(node);
3279 return false;
3280 }
3281
3282 auto node_name = GetGNodeName(node);
3283 auto ret = std::find_if(graph_outputs_.begin(), graph_outputs_.end(),
3284 [&node_name](const auto &output) { return output.first.GetName() == node_name; });
3285 if (ret != graph_outputs_.end()) {
3286 return false;
3287 }
3288
3289 for (size_t output_index = 0; output_index < node.GetOutputsSize(); output_index++) {
3290 auto output_nodes = node.GetOutDataNodesAndPortIndexs(static_cast<int32_t>(output_index));
3291 if (!output_nodes.empty() && has_es_node_) {
3292 return true;
3293 }
3294 if (output_nodes.size() != 1) {
3295 return false;
3296 }
3297
3298 auto output_node_type = GetGNodeType(*(output_nodes.begin()->first));
3299 if (output_node_type == kTypeMerge || output_node_type == kTypeIf) {
3300 return false;
3301 }
3302 }
3303
3304 if (!node.GetOutControlNodes().empty()) {
3305 return false;
3306 }
3307
3308 return true;
3309 }
3310
RemoveIdentity(::ge::GNode identity_node)3311 void DfGraphConvertor::RemoveIdentity(::ge::GNode identity_node) {
3312 MS_LOG(INFO) << "Start Remove Identity or IdentityN, identity_node: " << GetGNodeName(identity_node);
3313 auto node_type = GetGNodeType(identity_node);
3314 if (node_type != kTypeIdentity && node_type != kTypeIdentityN) {
3315 MS_LOG(EXCEPTION) << "Node is not Identity or IdentityN, but is " << node_type
3316 << ", identity_node name: " << GetGNodeName(identity_node);
3317 return;
3318 }
3319 if (identity_node.GetInputsSize() != identity_node.GetOutputsSize()) {
3320 MS_LOG(EXCEPTION) << "Node output size " << identity_node.GetOutputsSize() << " is not equal to input size "
3321 << identity_node.GetInputsSize() << ", identity_node: " << GetGNodeName(identity_node);
3322 return;
3323 }
3324
3325 ::ge::graphStatus ret;
3326 for (size_t output_index = 0; output_index < identity_node.GetOutputsSize(); output_index++) {
3327 auto output_nodes = identity_node.GetOutDataNodesAndPortIndexs(static_cast<int>(output_index));
3328 if (output_nodes.size() != 1 && !has_es_node_) {
3329 return;
3330 }
3331
3332 // 1. Set identity_node data edge
3333 for (size_t i = 0; i < output_nodes.size(); i++) {
3334 auto node_output = output_nodes[i];
3335 auto input_index = output_index;
3336 auto node_input = identity_node.GetInDataNodesAndPortIndexs(static_cast<int32_t>(input_index));
3337 ret = df_graph_->RemoveEdge(identity_node, static_cast<int32_t>(output_index), *node_output.first,
3338 node_output.second);
3339 if (ret != ::ge::GRAPH_SUCCESS) {
3340 MS_LOG(EXCEPTION) << "Remove edge failed, src identity_node: " << GetGNodeName(identity_node)
3341 << ", index: " << output_index << ", dst identity_node: " << GetGNodeName(*node_output.first)
3342 << ", index: " << node_output.second << ", ret: " << ret;
3343 return;
3344 }
3345 ret = df_graph_->AddDataEdge(*node_input.first, node_input.second, *node_output.first, node_output.second);
3346 if (ret != ::ge::GRAPH_SUCCESS) {
3347 MS_LOG(EXCEPTION) << "Add data edge failed, src identity_node: " << GetGNodeName(*node_input.first)
3348 << ", index: "
3349 << ", dst identity_node: " << GetGNodeName(*node_output.first)
3350 << ", index: " << node_output.second << ", ret: " << ret;
3351 return;
3352 }
3353
3354 // 2. Set identity_node control edge
3355 auto node_control = identity_node.GetInControlNodes();
3356 for (const auto &item : node_control) {
3357 ret = df_graph_->AddControlEdge(*item, *node_output.first);
3358 if (ret != ::ge::GRAPH_SUCCESS) {
3359 MS_LOG(EXCEPTION) << "Add control edge failed, src identity_node: " << GetGNodeName(*item)
3360 << ", dst identity_node: " << GetGNodeName(*node_output.first) << ", ret: " << ret;
3361 return;
3362 }
3363 }
3364 }
3365 }
3366
3367 // 3. Remove identity
3368 ret = df_graph_->RemoveNode(identity_node);
3369 if (ret != ::ge::GRAPH_SUCCESS) {
3370 MS_LOG(EXCEPTION) << "Remove identity_node failed, identity_node: " << GetGNodeName(identity_node)
3371 << ", ret: " << ret;
3372 return;
3373 }
3374 }
3375
IdentityOptimization()3376 void DfGraphConvertor::IdentityOptimization() {
3377 MS_LOG(INFO) << "Start IdentityOptimization, graph: " << anf_graph_->ToString();
3378 MS_EXCEPTION_IF_NULL(df_graph_);
3379 auto all_nodes = df_graph_->GetDirectNode();
3380 for (const auto &node : all_nodes) {
3381 if (IsIdentityRedundant(node)) {
3382 RemoveIdentity(node);
3383 }
3384 }
3385 MS_LOG(INFO) << "End IdentityOptimization, graph: " << anf_graph_->ToString();
3386 }
3387
NoOpOptimization()3388 void DfGraphConvertor::NoOpOptimization() {
3389 MS_LOG(INFO) << "Start NoOpOptimization, graph:" << anf_graph_->ToString();
3390 MS_EXCEPTION_IF_NULL(df_graph_);
3391 auto all_nodes = df_graph_->GetDirectNode();
3392 for (const auto &node : all_nodes) {
3393 if (IsNoOpRedundant(node)) {
3394 RemoveNoOp(node);
3395 }
3396 }
3397 MS_LOG(INFO) << "End NoopOptimization, graph:" << anf_graph_->ToString();
3398 }
3399
ESOptimization()3400 void DfGraphConvertor::ESOptimization() {
3401 MS_LOG(INFO) << "Start ESOptimization, graph:" << anf_graph_->ToString();
3402 MS_EXCEPTION_IF_NULL(df_graph_);
3403 auto all_nodes = df_graph_->GetDirectNode();
3404 ::ge::GNode no_op;
3405 bool not_remove = false;
3406 for (const auto &node : all_nodes) {
3407 node.GetAttr(kAttrNotRemove, not_remove);
3408 if (not_remove) {
3409 no_op = node;
3410 break;
3411 }
3412 }
3413 if (not_remove) {
3414 auto output_control_node = no_op.GetOutControlNodes();
3415 if (output_control_node.empty()) {
3416 return;
3417 }
3418 RemoveIdentityForES(*output_control_node[0]);
3419 }
3420 }
3421
RemoveIdentityForES(::ge::GNode node)3422 void DfGraphConvertor::RemoveIdentityForES(::ge::GNode node) {
3423 ::ge::graphStatus ret;
3424 auto out_control_node = node.GetOutControlNodes();
3425 for (size_t input_index = 0; input_index < node.GetInputsSize(); input_index++) {
3426 auto node_input = node.GetInDataNodesAndPortIndexs(static_cast<int32_t>(input_index));
3427 ret = df_graph_->RemoveEdge(*node_input.first, node_input.second, node, input_index);
3428 if (ret != ::ge::GRAPH_SUCCESS) {
3429 MS_LOG(EXCEPTION) << "Remove edge failed, src node: " << GetGNodeName(*node_input.first)
3430 << ", index: " << node_input.second << ", dst identity_node: " << GetGNodeName(node)
3431 << ", index: " << input_index << ", ret: " << ret;
3432 return;
3433 }
3434 }
3435 ret = df_graph_->RemoveNode(node);
3436 if (ret != ::ge::GRAPH_SUCCESS) {
3437 MS_LOG(EXCEPTION) << "Remove node failed, node: " << GetGNodeName(node);
3438 }
3439 if (out_control_node.empty()) {
3440 return;
3441 }
3442 auto output_node = out_control_node[0];
3443 MS_EXCEPTION_IF_NULL(output_node);
3444 RemoveIdentityForES(*output_node);
3445 }
3446
IsNoOpRedundant(const::ge::GNode & node) const3447 bool DfGraphConvertor::IsNoOpRedundant(const ::ge::GNode &node) const {
3448 auto node_type = GetGNodeType(node);
3449 if (node_type != kTypeNoOp) {
3450 return false;
3451 }
3452 if (!training_) {
3453 return true;
3454 }
3455
3456 bool not_remove = false;
3457 node.GetAttr(kAttrNotRemove, not_remove);
3458 if (not_remove) {
3459 return false;
3460 }
3461
3462 auto out_control_node = node.GetOutControlNodes();
3463 auto in_control_node = node.GetInControlNodes();
3464 if (out_control_node.size() == 1 || in_control_node.size() == 1) {
3465 return true;
3466 }
3467 if (out_control_node.size() > kNoOpOptThreshold || in_control_node.size() > kNoOpOptThreshold) {
3468 return false;
3469 }
3470 return true;
3471 }
RemoveNoOp(::ge::GNode noop)3472 void DfGraphConvertor::RemoveNoOp(::ge::GNode noop) {
3473 MS_LOG(INFO) << "Start Remove NoOp, node:" << GetGNodeName(noop);
3474 auto node_type = GetGNodeType(noop);
3475 if (node_type != kTypeNoOp) {
3476 MS_LOG(EXCEPTION) << "Node is not NoOp, but is: " << GetGNodeName(noop);
3477 }
3478
3479 auto in_control_nodes = noop.GetInControlNodes();
3480 auto out_control_nodes = noop.GetOutControlNodes();
3481 auto ret = df_graph_->RemoveNode(noop);
3482 if (ret != ::ge::GRAPH_SUCCESS) {
3483 MS_LOG(EXCEPTION) << "Remove node failed, node: " << GetGNodeName(noop);
3484 }
3485 for (auto src_node : in_control_nodes) {
3486 for (auto dst_node : out_control_nodes) {
3487 ret = df_graph_->AddControlEdge(*src_node, *dst_node);
3488 if (ret != ::ge::GRAPH_SUCCESS) {
3489 MS_LOG(EXCEPTION) << "Add control edge failed, src node: " << GetGNodeName(*src_node)
3490 << ", dst node:" << GetGNodeName(*dst_node);
3491 }
3492 }
3493 }
3494 MS_LOG(INFO) << "End Remove Noop, node: " << GetGNodeName(noop);
3495 }
3496
ProcessSubgraph(const AnfNodePtr & node,const AnfNodePtr & branch_node,ParamIndexMap & branch_to_parent_node_map)3497 void DfGraphConvertor::ProcessSubgraph(const AnfNodePtr &node, const AnfNodePtr &branch_node,
3498 ParamIndexMap &branch_to_parent_node_map) {
3499 MS_LOG(INFO) << "ProcessSubgraph begin.";
3500 ValueNodePtr graph_node = nullptr;
3501 if (branch_node->isa<CNode>()) {
3502 graph_node = branch_node->cast<CNodePtr>()->input(1)->cast<ValueNodePtr>();
3503 } else if (branch_node->isa<ValueNode>()) {
3504 graph_node = branch_node->cast<ValueNodePtr>();
3505 } else {
3506 return;
3507 }
3508
3509 MS_EXCEPTION_IF_NULL(graph_node);
3510 auto anf_graph = graph_node->value()->cast<AnfGraphPtr>();
3511 MS_EXCEPTION_IF_NULL(anf_graph);
3512 DfGraphConvertor converter(anf_graph, phase_prefix_);
3513 converter.graph_type_ = GraphType::kBranch;
3514
3515 auto ¶ms = anf_graph->parameters();
3516 if (ref_mode_) {
3517 for (size_t i = 0; i < params.size(); i++) {
3518 auto ¶m = params[i];
3519 if (branch_to_parent_node_map.find(i) != branch_to_parent_node_map.end()) {
3520 size_t parent_index = branch_to_parent_node_map[i];
3521 OperatorPtr op = nullptr;
3522 op = std::make_shared<Data>();
3523 MS_EXCEPTION_IF_NULL(op);
3524 SetXDataIndex(op, parent_index);
3525 converter.op_cache_[param.get()] = op;
3526 } else if (!HasAbstractMonad(param)) {
3527 MS_LOG(EXCEPTION) << "Branch graph input index to parent node dyn input index error, "
3528 << "branch graph: " << anf_graph->ToString() << "'s " << i << "(st/nd/rd/st)"
3529 << " input can not find the corresponding parent node input index.";
3530 }
3531 }
3532 } else {
3533 auto &dyn_input = branch_input_handle_cache_[node.get()];
3534 MS_EXCEPTION_IF_NULL(dyn_input);
3535 auto &inputs = tuple_out_handle_cache_[dyn_input->at(1).get()];
3536 MS_EXCEPTION_IF_NULL(inputs);
3537 for (size_t i = 0; i < params.size(); i++) {
3538 auto ¶m = params[i];
3539 if (branch_to_parent_node_map.find(i) != branch_to_parent_node_map.end()) {
3540 size_t parent_index = branch_to_parent_node_map[i];
3541 auto &parent_handle = inputs->at(parent_index);
3542 OperatorPtr op = nullptr;
3543 MS_EXCEPTION_IF_NULL(parent_handle.op);
3544 if (parent_handle.op->GetOpType() == kTypeVariable) {
3545 auto name = parent_handle.op->GetName();
3546 op = std::make_shared<Variable>(name);
3547 MS_EXCEPTION_IF_NULL(op);
3548 SetXDataIndex(op, parent_index);
3549 } else {
3550 op = std::make_shared<Data>();
3551 MS_EXCEPTION_IF_NULL(op);
3552 SetXDataIndex(op, parent_index);
3553 }
3554 converter.op_cache_[param.get()] = op;
3555 } else if (!HasAbstractMonad(param)) {
3556 MS_LOG(EXCEPTION) << "Branch graph input index to parent node dyn input index error, "
3557 << "branch graph: " << anf_graph->ToString() << "'s " << i << "(st/nd/rd/st)"
3558 << " input can not find the corresponding parent node input index.";
3559 }
3560 }
3561 }
3562
3563 std::string graph_name = anf_graph->ToString();
3564 auto iter = branches_repeat_times.find(graph_name);
3565 if (iter == branches_repeat_times.end()) {
3566 branches_repeat_times[graph_name] = 1;
3567 } else {
3568 iter->second += 1;
3569 graph_name = graph_name + "_" + std::to_string(iter->second);
3570 }
3571 (void)converter.ConvertAllNode().BuildGraph(graph_name);
3572 #ifdef ENABLE_DUMP_IR
3573 std::string name = graph_node->ToString() + "_ge_graph.dot";
3574 auto context = MsContext::GetInstance();
3575 MS_EXCEPTION_IF_NULL(context);
3576 if (context->CanDump(kFully)) {
3577 converter.DrawComputeGraph(name);
3578 }
3579 #endif
3580 branches_map_[branch_node.get()] = *(converter.df_graph_);
3581 MS_LOG(INFO) << "ProcessSubgraph end.";
3582 }
3583
3584 // Update GE op's shape and type info
UpdateOpDesc(const AnfNodePtr node)3585 void DfGraphConvertor::UpdateOpDesc(const AnfNodePtr node) {
3586 MS_EXCEPTION_IF_NULL(node);
3587 if (node == nullptr || !node->isa<CNode>()) {
3588 return;
3589 }
3590
3591 if (op_cache_.find(node.get()) == op_cache_.end()) {
3592 return;
3593 }
3594
3595 OpAdapterPtr adpt = FindAdapter(node, training_);
3596 if (adpt == nullptr) {
3597 error_ = NOT_FOUND;
3598 return;
3599 }
3600
3601 // get Operator from op_cache_
3602 OperatorPtr op = Convert(node);
3603 MS_EXCEPTION_IF_NULL(op);
3604 std::string op_type = op->GetOpType();
3605 if (!IsNeedToUpdateTensorDesc(op_type, node)) {
3606 MS_LOG(INFO) << "No need to set the opDesc of node: " << node->fullname_with_scope() << ", op type is " << op_type;
3607 return;
3608 }
3609
3610 adpt->updateOutputDesc(op, node->Shape(), node->Type(), node);
3611 }
3612
Convert(const AnfNodePtr node)3613 OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) {
3614 if (node == nullptr) {
3615 MS_LOG(ERROR) << "node is nullptr";
3616 error_ = NOT_FOUND;
3617 return nullptr;
3618 }
3619 // find in cache
3620 if (op_cache_.count(node.get()) != 0) {
3621 MS_LOG(DEBUG) << "Get op from cache: " << op_cache_[node.get()]->GetName();
3622 return op_cache_[node.get()];
3623 }
3624
3625 // do not convert primitive node
3626 if (IsValueNode<Primitive>(node)) {
3627 return nullptr;
3628 }
3629 // convert a new one
3630 if (node->isa<CNode>()) {
3631 auto cnode = node->cast<CNodePtr>();
3632 if (IsSubGraph() && IsWhileNode(cnode)) {
3633 return nullptr;
3634 }
3635 if (!IsSubGraph() && IsWhileNode(cnode)) {
3636 CacheWhileGraph(cnode);
3637 auto &graphs = while_graph_cache_[cnode];
3638 GetWhileUsedInputIndex(graphs);
3639 SetParamIndexMap(graphs);
3640 cur_while_node_ = cnode;
3641 }
3642 return ConvertCNode(cnode);
3643 }
3644
3645 if (node->isa<Parameter>() && IsSubGraph()) {
3646 return nullptr;
3647 }
3648
3649 if (node->isa<Parameter>()) {
3650 return ConvertParameter(node);
3651 }
3652 if (node->isa<ValueNode>()) {
3653 if (IsValueNode<Monad>(node)) {
3654 return nullptr;
3655 }
3656 return ConvertValueNode(node->cast<ValueNodePtr>());
3657 }
3658
3659 MS_LOG(ERROR) << "Invalid AnfNode";
3660 error_ = INVALID_ARGUMENT;
3661 return nullptr;
3662 }
3663
ConvertTopK(const CNodePtr & node)3664 void DfGraphConvertor::ConvertTopK(const CNodePtr &node) {
3665 MS_EXCEPTION_IF_NULL(node);
3666 auto value_ptr = node->input(kIndex2)->cast<ValueNodePtr>();
3667 if (value_ptr == nullptr) {
3668 // input is not const valuenode, cannot convert to int32, throw exception when input k is int64 since cann
3669 // has precision problem, can be deleted after cann support int64 for input k
3670 if (common::AnfAlgo::GetPrevNodeOutputInferDataType(node, kIndex1) == kNumberTypeInt64) {
3671 MS_LOG(EXCEPTION) << "Op TopK(" << node->fullname_with_scope() << ")'s second input k is an int64 mutable "
3672 << "tensor/scalar, which is not supported in ascend, please use int32.";
3673 }
3674 return;
3675 }
3676 MS_LOG(INFO) << "Convert TopK second input's type from int64 to int32.";
3677 auto input_value = value_ptr->value();
3678 MS_EXCEPTION_IF_NULL(input_value);
3679 std::ostringstream ss;
3680 ss << "op" << value_ptr.get();
3681 op_draw_name_[value_ptr.get()] = ss.str();
3682 compute_sout_ << ss.str() << "[label= \"" << value_ptr->value()->ToString() << "\" shape=ellipse]" << endl;
3683 int32_t k_value;
3684 if (input_value->isa<tensor::Tensor>()) {
3685 auto input_tensor = input_value->cast<tensor::TensorPtr>();
3686 if (input_tensor->data_type() == kNumberTypeInt32) {
3687 k_value = *static_cast<int32_t *>(input_tensor->data_c());
3688 } else {
3689 k_value = LongToInt(*static_cast<int64_t *>(input_tensor->data_c()));
3690 }
3691 } else {
3692 k_value = LongToInt(GetValue<int64_t>(input_value));
3693 }
3694 OpAdapterPtr adpt = FindAdapter(value_ptr, training_);
3695 MS_EXCEPTION_IF_NULL(adpt);
3696 auto op = adpt->generate(value_ptr);
3697 (void)adpt->setAttr(op, "value", k_value);
3698 op_cache_[value_ptr.get()] = op;
3699 }
3700
// Build a Cast CNode converting `input` to `dst_type`, inside `input`'s own func graph.
// The new node's abstract is a tensor of `dst_type` with `input`'s shape. The node is not inserted into any user;
// the caller must wire it in.
AnfNodePtr DfGraphConvertor::CreateCast(const AnfNodePtr &input, const TypePtr &dst_type) const {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(dst_type);
  auto func_graph = input->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  AnfNodePtrList inputs = {NewValueNode(prim::kPrimCast), input,
                           NewValueNode(static_cast<int64_t>(dst_type->type_id()))};
  auto cnode = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dst_type, input->Shape());
  cnode->set_abstract(abs_tensor);
  return cnode;
}
3712
CastToInt(const ValuePtr & value) const3713 std::vector<int64_t> DfGraphConvertor::CastToInt(const ValuePtr &value) const {
3714 if (value == nullptr) {
3715 return {};
3716 }
3717 std::vector<int64_t> cur_value = {};
3718 if (utils::isa<ValueSequencePtr>(value)) {
3719 auto val_seq_ptr = value->cast<ValueSequencePtr>();
3720 MS_EXCEPTION_IF_NULL(val_seq_ptr);
3721 if (!val_seq_ptr->value().empty()) {
3722 auto first_val = val_seq_ptr->value().front();
3723 MS_EXCEPTION_IF_NULL(first_val);
3724 MS_EXCEPTION_IF_NULL(first_val->type());
3725 if (first_val->type()->number_type() == kNumberTypeInt64) {
3726 cur_value = GetValue<std::vector<int64_t>>(value);
3727 } else {
3728 auto origin_value = GetValue<std::vector<int>>(value);
3729 (void)std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(cur_value),
3730 [](int index) { return static_cast<int64_t>(index); });
3731 }
3732 }
3733 } else {
3734 MS_EXCEPTION_IF_NULL(value->type());
3735 if (value->type()->number_type() == kNumberTypeInt64) {
3736 cur_value.emplace_back(GetValue<int64_t>(value));
3737 } else {
3738 cur_value.emplace_back(static_cast<int64_t>(GetValue<int>(value)));
3739 }
3740 }
3741 return cur_value;
3742 }
3743
TransInputDataType(const CNodePtr & node,const std::string & node_name) const3744 void DfGraphConvertor::TransInputDataType(const CNodePtr &node, const std::string &node_name) const {
3745 auto iter = kTransInputDTypeMap.find(node_name);
3746 if (iter == kTransInputDTypeMap.end()) {
3747 return;
3748 }
3749 MS_EXCEPTION_IF_NULL(node);
3750 MS_LOG(DEBUG) << "Trans input data type of node:" << node->DebugString();
3751 for (auto &item : iter->second) {
3752 auto input_node = node->input(item.first);
3753 TypeId dst_type = item.second;
3754 MS_EXCEPTION_IF_NULL(input_node);
3755 if (input_node->isa<CNode>() || input_node->isa<Parameter>()) {
3756 auto src_type = input_node->Type()->type_id();
3757 if (kObjectTypeTensorType == src_type) {
3758 src_type = dyn_cast<TensorType>(input_node->Type())->element()->type_id();
3759 }
3760 if (!IsValidConversion(src_type, dst_type)) {
3761 continue;
3762 }
3763 auto new_cast = CreateCast(input_node, TypeIdToType(dst_type));
3764 node->set_input(item.first, new_cast);
3765 } else if (input_node->isa<ValueNode>()) {
3766 auto input_value_node = input_node->cast<ValueNodePtr>();
3767 MS_EXCEPTION_IF_NULL(input_value_node);
3768 auto value = input_value_node->value();
3769 ValuePtr new_value = CastDstValue(value, dst_type);
3770 if (new_value == nullptr) {
3771 continue;
3772 }
3773 auto new_value_node = std::make_shared<ValueNode>(new_value);
3774 MS_EXCEPTION_IF_NULL(new_value_node);
3775 new_value_node->set_abstract(new_value->ToAbstract());
3776 node->set_input(item.first, new_value_node);
3777 }
3778 }
3779 MS_LOG(DEBUG) << "Finish to trans input data type of node:" << node->DebugString();
3780 }
3781
TransAttrDataType(const CNodePtr & node,const std::string & node_name) const3782 void DfGraphConvertor::TransAttrDataType(const CNodePtr &node, const std::string &node_name) const {
3783 auto iter = kTransAttrDTypeMap.find(node_name);
3784 if (iter == kTransAttrDTypeMap.end()) {
3785 return;
3786 }
3787 MS_EXCEPTION_IF_NULL(node);
3788 MS_LOG(DEBUG) << "Trans attr data type of node:" << node->DebugString();
3789 auto prim = common::AnfAlgo::GetCNodePrimitive(node);
3790 MS_EXCEPTION_IF_NULL(prim);
3791 for (auto &item : iter->second) {
3792 std::string attr_name = item.first;
3793 TypeId dst_type = item.second;
3794 if (!prim->HasAttr(attr_name)) {
3795 MS_LOG(EXCEPTION) << "Please check kTransAttrDTypeMap, node:" << node->DebugString()
3796 << " has no attr:" << attr_name;
3797 }
3798 auto attr_value = prim->GetAttr(attr_name);
3799 auto new_attr_value = CastDstValue(attr_value, dst_type);
3800 if (new_attr_value == nullptr) {
3801 continue;
3802 }
3803 prim->set_attr(attr_name, new_attr_value);
3804 }
3805 MS_LOG(DEBUG) << "Finish to trans attr data type of node:" << node->DebugString();
3806 }
3807
TransDataType(const FuncGraphPtr & anf_graph) const3808 void DfGraphConvertor::TransDataType(const FuncGraphPtr &anf_graph) const {
3809 MS_EXCEPTION_IF_NULL(anf_graph);
3810 MS_LOG(DEBUG) << "TransDataType begin. graph:" << anf_graph->ToString();
3811 std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph);
3812 for (auto &it : nodes) {
3813 if (it->isa<CNode>()) {
3814 auto node = it->cast<CNodePtr>();
3815 MS_EXCEPTION_IF_NULL(node);
3816 std::string name = GetCNodeTargetFuncName(node);
3817 TransInputDataType(node, name);
3818 TransAttrDataType(node, name);
3819 }
3820 }
3821 MS_LOG(DEBUG) << "TransDataType end. graph:" << anf_graph->ToString();
3822 }
3823
ConvertReshape(const CNodePtr & node)3824 void DfGraphConvertor::ConvertReshape(const CNodePtr &node) {
3825 MS_LOG(INFO) << "Convert the second input of reshape to op attr.";
3826 const auto kInputNum = 3;
3827 if (node->size() < kInputNum) {
3828 MS_LOG(WARNING) << "Reshape must have two inputs.";
3829 return;
3830 }
3831 OpAdapterPtr adpt = FindAdapter(node, training_);
3832 if (adpt == nullptr) {
3833 return;
3834 }
3835 auto op = adpt->generate(node);
3836 MS_EXCEPTION_IF_NULL(op);
3837 // get shape form attr
3838 auto primitive = GetCNodePrimitive(node);
3839 MS_EXCEPTION_IF_NULL(primitive);
3840 if (primitive->HasAttr("shape")) {
3841 auto value = primitive->GetAttr("shape");
3842 auto list = CastToInt(value);
3843 (void)op->SetAttr("shape", list);
3844 }
3845 if (primitive->HasAttr("allowzero")) {
3846 auto value = primitive->GetAttr("allowzero");
3847 auto list = CastToInt(value);
3848 if (list.size() == 1) {
3849 (void)op->SetAttr("allowzero", list[0]);
3850 }
3851 }
3852 op_cache_[node.get()] = op;
3853 }
3854
ConvertDynamicStitch(const CNodePtr & node)3855 void DfGraphConvertor::ConvertDynamicStitch(const CNodePtr &node) {
3856 MS_LOG(INFO) << "Convert and set 'N' attr of DynamicStitch.";
3857 OpAdapterPtr adpt = FindAdapter(node, training_);
3858 if (adpt == nullptr) {
3859 return;
3860 }
3861 auto op = adpt->generate(node);
3862 MS_EXCEPTION_IF_NULL(op);
3863 int64_t input_length = 0;
3864 auto indices = node->input(1);
3865 MS_EXCEPTION_IF_NULL(indices);
3866 if (indices->isa<CNode>()) {
3867 input_length = SizeToLong(indices->cast<CNodePtr>()->size()) - 1;
3868 } else if (IsValueNode<ValueSequence>(indices)) {
3869 const auto tuple = GetValueNode<ValueSequencePtr>(indices);
3870 MS_EXCEPTION_IF_NULL(tuple);
3871 input_length = SizeToLong(tuple->size());
3872 } else {
3873 MS_LOG(EXCEPTION) << "Input 1 of DynamicStitch is neither CNode nor ValueNode contains ValueSequence, but "
3874 << indices->ToString() << ", can not set 'N' attr.";
3875 }
3876
3877 (void)op->SetAttr("N", input_length);
3878 MS_LOG(INFO) << "Set 'N' attr of DynamicStitch to " << input_length;
3879 op_cache_[node.get()] = op;
3880 }
3881
ConvertParallelGroupToHcom(const CNodePtr & node)3882 void DfGraphConvertor::ConvertParallelGroupToHcom(const CNodePtr &node) {
3883 auto group_name = common::AnfAlgo::GetNodeAttr<std::string>(node, kParallelGroup);
3884 OpAdapterPtr adpt = FindAdapter(node, training_);
3885 if (adpt == nullptr) {
3886 return;
3887 }
3888
3889 // get operator
3890 OperatorPtr op = nullptr;
3891 auto it_op = op_cache_.find(node.get());
3892 if (it_op != op_cache_.end()) {
3893 op = it_op->second;
3894 } else {
3895 op = adpt->generate(node);
3896 }
3897 MS_EXCEPTION_IF_NULL(op);
3898 (void)op->SetAttr(kParallelGroup, group_name);
3899 op_cache_[node.get()] = op;
3900 }
3901
ConvertParallelGroupIdToHcom(const CNodePtr & node)3902 void DfGraphConvertor::ConvertParallelGroupIdToHcom(const CNodePtr &node) {
3903 auto parallel_group_id_value = node->GetAttr(kParallelGroupId);
3904 auto parallel_group_id = GetValue<uint32_t>(parallel_group_id_value);
3905 OpAdapterPtr adpt = FindAdapter(node, training_);
3906 if (adpt == nullptr) {
3907 return;
3908 }
3909
3910 // get operator
3911 OperatorPtr op = nullptr;
3912 auto it_op = op_cache_.find(node.get());
3913 if (it_op != op_cache_.end()) {
3914 op = it_op->second;
3915 } else {
3916 op = adpt->generate(node);
3917 op_cache_[node.get()] = op;
3918 }
3919 MS_EXCEPTION_IF_NULL(op);
3920 (void)op->SetAttr(kParallelGroupId, parallel_group_id);
3921 MS_LOG(DEBUG) << "Successfully convert _parallel_group_id: " << parallel_group_id << " to ge op: " << op->GetName();
3922 }
3923
ConvertHcomFusionId(const CNodePtr & node)3924 void DfGraphConvertor::ConvertHcomFusionId(const CNodePtr &node) {
3925 MS_EXCEPTION_IF_NULL(node);
3926 MS_LOG(INFO) << "Add Hcom fusion_id";
3927 OpAdapterPtr adpt = FindAdapter(node, training_);
3928 if (adpt == nullptr) {
3929 return;
3930 }
3931 auto op = adpt->generate(node);
3932 MS_EXCEPTION_IF_NULL(op);
3933 // get shape form attr
3934 auto primitive = GetCNodePrimitive(node);
3935 MS_EXCEPTION_IF_NULL(primitive);
3936 auto fusion_value = primitive->GetAttr("fusion");
3937 if (fusion_value == nullptr) {
3938 MS_LOG(WARNING) << "Failed to get attr fusion for gather node " << node->fullname_with_scope();
3939 return;
3940 }
3941 int64_t fusion = 0;
3942 if (fusion_value->isa<Int64Imm>()) {
3943 fusion = GetValue<int64_t>(fusion_value);
3944 } else if (fusion_value->isa<Int32Imm>()) {
3945 fusion = GetValue<int32_t>(fusion_value);
3946 } else {
3947 MS_LOG(WARNING) << "Attr fusion is not int64/int32 type, real type " << fusion_value->type_name()
3948 << ", gather node " << node->fullname_with_scope();
3949 return;
3950 }
3951 int64_t fusion_id = -1;
3952
3953 // fusion 0: no fusion; 1(default): fusion; 2: fusion the ops by fusion id.
3954 if (fusion >= 1) {
3955 fusion_id = fusion;
3956 fusion = kHcclFusionByFusionID;
3957 } else if (fusion < 0) {
3958 fusion = kHcclFusionDefault;
3959 }
3960
3961 auto context = MsContext::GetInstance();
3962 MS_EXCEPTION_IF_NULL(context);
3963 if (context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
3964 MS_LOG(INFO) << "cell reuse not support all fusion";
3965 fusion = 0;
3966 }
3967 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
3968 auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
3969 if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_TASK_OPT) &&
3970 (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel)) {
3971 fusion_id = 0;
3972 fusion = 0;
3973 }
3974 (void)op->SetAttr("fusion_id", fusion_id);
3975 (void)op->SetAttr("fusion", fusion);
3976 AddCommAttrForHcclNode(node, op);
3977 op_cache_[node.get()] = op;
3978 }
3979
ConvertAllToAllv(const CNodePtr & node)3980 void DfGraphConvertor::ConvertAllToAllv(const CNodePtr &node) {
3981 OpAdapterPtr adpt = FindAdapter(node, training_);
3982 if (adpt == nullptr) {
3983 return;
3984 }
3985 auto op = adpt->generate(node);
3986 MS_EXCEPTION_IF_NULL(op);
3987 op_cache_[node.get()] = op;
3988 AddCommAttrForHcclNode(node, op);
3989 // set _is_inserted_by_ge attr to avoid mistaken delete
3990 auto primitive = GetCNodePrimitive(node);
3991 MS_EXCEPTION_IF_NULL(primitive);
3992 auto is_inserted_value = primitive->GetAttr("is_inserted_by_ge");
3993 if (is_inserted_value == nullptr) {
3994 return;
3995 }
3996 auto is_inserted = GetValue<bool>(is_inserted_value);
3997 (void)op->SetAttr("_is_inserted_by_ge", is_inserted);
3998 }
3999
ConvertUniformReal(const CNodePtr & node)4000 void DfGraphConvertor::ConvertUniformReal(const CNodePtr &node) {
4001 OpAdapterPtr adpt = FindAdapter(node, training_);
4002 if (adpt == nullptr) {
4003 return;
4004 }
4005 auto op = adpt->generate(node);
4006 MS_EXCEPTION_IF_NULL(op);
4007 op_cache_[node.get()] = op;
4008 (void)op->SetAttr("dtype", ::ge::DataType::DT_FLOAT);
4009 }
4010
ConvertUpdateState(const CNodePtr & node)4011 void DfGraphConvertor::ConvertUpdateState(const CNodePtr &node) {
4012 OpAdapterPtr adpt = FindAdapter(node, training_);
4013 if (adpt == nullptr) {
4014 return;
4015 }
4016 auto op = adpt->generate(node);
4017 MS_EXCEPTION_IF_NULL(op);
4018 op_cache_[node.get()] = op;
4019 if (common::AnfAlgo::HasNodeAttr(kAttrNotRemove, node)) {
4020 bool not_remove = common::AnfAlgo::GetNodeAttr<bool>(node, kAttrNotRemove);
4021 (void)op->SetAttr(kProcessNodeEngineID, "PS");
4022 (void)op->SetAttr(kAttrNotRemove, not_remove);
4023 has_es_node_ = true;
4024 }
4025 }
4026
ConvertHcclNode(const CNodePtr & node)4027 void DfGraphConvertor::ConvertHcclNode(const CNodePtr &node) {
4028 OpAdapterPtr adpt = FindAdapter(node, training_);
4029 if (adpt == nullptr) {
4030 return;
4031 }
4032 auto op = adpt->generate(node);
4033 MS_EXCEPTION_IF_NULL(op);
4034 AddCommAttrForHcclNode(node, op);
4035 op_cache_[node.get()] = op;
4036 }
4037
AddCommAttrForHcclNode(const CNodePtr & node,const OperatorPtr & converted_op) const4038 void DfGraphConvertor::AddCommAttrForHcclNode(const CNodePtr &node, const OperatorPtr &converted_op) const {
4039 MS_EXCEPTION_IF_NULL(node);
4040 MS_EXCEPTION_IF_NULL(converted_op);
4041 if (!common::AnfAlgo::HasNodeAttr(kAttrGroup, node)) {
4042 MS_LOG(WARNING) << "Node " << node->fullname_with_scope() << " does not have attr " << kAttrGroup << " skip.";
4043 return;
4044 }
4045 std::string group = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
4046 (void)converted_op->SetAttr("group", group);
4047 #ifdef ENABLE_D
4048 if (!common::GetEnv(kSimulationLevel).empty()) {
4049 auto hccl_inner_comm_name = device::DummyAscendCollectiveCommLib::GetInstance().HcclInnerCommName(group);
4050 MS_LOG(INFO) << "Set comm handle and comm group name of the hccl node: " << node->fullname_with_scope()
4051 << "comm name:" << hccl_inner_comm_name;
4052 (void)converted_op->SetAttr("group", hccl_inner_comm_name);
4053 return;
4054 }
4055 if (common::GetEnv(kSimulationLevel).empty() && !common::IsNeedProfileMemory()) {
4056 if (common::UseHostCollective() && !hccl::HcclAdapter::GetInstance().UseHcclCM()) {
4057 // For HcclCommInitRootInfo manner, set 'group' and 'comm' attrs. 'group' attr value should be hccl's inner comm
4058 // name.
4059 auto comm = device::ascend::AscendCollectiveCommLib::GetInstance().HcclCommunicator(group);
4060 auto hccl_inner_comm_name = device::ascend::AscendCollectiveCommLib::GetInstance().HcclInnerCommName(group);
4061 MS_LOG(INFO) << "Set comm handle and comm group name of the hccl node: " << node->fullname_with_scope()
4062 << ". Comm handle: " << comm << ", comm name:" << hccl_inner_comm_name;
4063 MS_EXCEPTION_IF_NULL(comm);
4064 (void)converted_op->SetAttr("comm", reinterpret_cast<int64_t>(comm));
4065 (void)converted_op->SetAttr("group", hccl_inner_comm_name);
4066 } else {
4067 // For rank_table manner, 'group' attr should be user set group name.
4068 MS_LOG(INFO) << "Set group name for ranktable manner: " << group;
4069 (void)converted_op->SetAttr("group", group);
4070 }
4071 }
4072 #endif
4073 }
4074
ConvertConv2D(const CNodePtr & node)4075 void DfGraphConvertor::ConvertConv2D(const CNodePtr &node) {
4076 MS_LOG(INFO) << "Convert and set 'padding' attr for Conv2D-like op.";
4077 MS_EXCEPTION_IF_NULL(node);
4078 OpAdapterPtr adpt = FindAdapter(node, training_);
4079 if (adpt == nullptr) {
4080 return;
4081 }
4082 auto op = adpt->generate(node);
4083 MS_EXCEPTION_IF_NULL(op);
4084 op_cache_[node.get()] = op;
4085 auto primitive = GetCNodePrimitive(node);
4086 MS_EXCEPTION_IF_NULL(primitive);
4087 std::string pad_mode;
4088 if (auto pad_value = primitive->GetAttr("padding"); pad_value != nullptr) {
4089 pad_mode = GetValue<std::string>(pad_value);
4090 } else if (auto value = primitive->GetAttr("pad_mode"); value != nullptr) {
4091 // Get 'pad_mode' attr and set it to 'padding' attr for ge
4092 const mindspore::HashMap<int64_t, std::string> pad_mode_map{{1, "SAME"}, {2, "VALID"}};
4093 if (value->isa<StringImm>()) {
4094 pad_mode = GetValue<std::string>(value);
4095 (void)std::transform(pad_mode.cbegin(), pad_mode.cend(), pad_mode.begin(), toupper);
4096 if (pad_mode != "SAME" && pad_mode != "VALID") {
4097 return;
4098 }
4099 } else if (auto it = pad_mode_map.find(GetValue<int64_t>(value)); it != pad_mode_map.cend()) {
4100 // 'pad_mode' attr could be an enumeration
4101 pad_mode = it->second;
4102 } else {
4103 return;
4104 }
4105 } else {
4106 MS_LOG(INFO) << "Node: " << node->fullname_with_scope() << " has no 'padding' or 'pad_mode' attr";
4107 return;
4108 }
4109 MS_LOG(INFO) << "Set 'padding' attr of node: " << node->fullname_with_scope() << " to " << pad_mode;
4110 (void)op->SetAttr("padding", pad_mode);
4111 }
4112
ConvertOCRRecPreHandle(const CNodePtr & node)4113 void DfGraphConvertor::ConvertOCRRecPreHandle(const CNodePtr &node) {
4114 MS_LOG(INFO) << "Add OCRRecognitionPreHandle _op_max_shape attr";
4115 OpAdapterPtr adpt = FindAdapter(node, training_);
4116 if (adpt == nullptr) {
4117 return;
4118 }
4119 auto op = adpt->generate(node);
4120 MS_EXCEPTION_IF_NULL(op);
4121 // get shape form attr
4122 auto primitive = GetCNodePrimitive(node);
4123 MS_EXCEPTION_IF_NULL(primitive);
4124 auto value = primitive->GetAttr("_op_max_shape");
4125 if (value == nullptr) {
4126 return;
4127 }
4128 auto op_max_shape = GetValue<std::string>(value);
4129 (void)op->SetAttr("_op_max_shape", op_max_shape);
4130 op_cache_[node.get()] = op;
4131 }
4132
GetHandler(const AnfNodePtr & node)4133 OutHandler DfGraphConvertor::GetHandler(const AnfNodePtr &node) {
4134 if (node == nullptr) {
4135 MS_LOG(ERROR) << "Get nullptr while getting handler from node";
4136 return OutHandler(nullptr, "");
4137 }
4138 if (out_handle_cache_.find(node.get()) != out_handle_cache_.end()) {
4139 return out_handle_cache_[node.get()];
4140 }
4141 auto op = Convert(node);
4142 if (op != nullptr) {
4143 auto name = op->GetName();
4144 if ((vars_.count(name) != 0) && vars_[name] != nullptr) {
4145 op = vars_[name];
4146 MS_LOG(DEBUG) << "update tuple_out_handle_cache_ " << name;
4147 }
4148 return OutHandler(op, "", node);
4149 } else {
4150 MS_LOG(DEBUG) << "Add an empty out handler: " << node->ToString();
4151 return OutHandler();
4152 }
4153 }
4154
CheckCNode(const std::string & name,const CNodePtr node)4155 bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
4156 // ignore apply node of return
4157 if (name == "" || name == prim::kPrimSwitch->name() || name == prim::kPrimSwitchLayer->name() ||
4158 name == prim::kPrimPartial->name()) {
4159 return false;
4160 }
4161
4162 const mindspore::HashMap<std::string, std::function<void(decltype(this), const CNodePtr &)>>
4163 auxiliary_node_converters{
4164 // Convert TopK second input from int64 to int32.
4165 {prim::kPrimTopK->name(), &DfGraphConvertor::ConvertTopK},
4166 // Convert Reshape add const input to attr(shape)
4167 {prim::kPrimReshape->name(), &DfGraphConvertor::ConvertReshape},
4168 {prim::kPrimOCRRecognitionPreHandle->name(), &DfGraphConvertor::ConvertOCRRecPreHandle},
4169 // Add attr 'pad_mode' to Conv2D-like op
4170 {prim::kPrimConv2D->name(), &DfGraphConvertor::ConvertConv2D},
4171 {prim::kPrimDepthwiseConv2dNative->name(), &DfGraphConvertor::ConvertConv2D},
4172 {kNameConv2DBackpropInputV2, &DfGraphConvertor::ConvertConv2D},
4173 {prim::kPrimConv2DBackpropInput->name(), &DfGraphConvertor::ConvertConv2D},
4174 {prim::kPrimConv2DBackpropFilter->name(), &DfGraphConvertor::ConvertConv2D},
4175 // Add attr 'N' to DynamicStitch
4176 {prim::kPrimDynamicStitch->name(), &DfGraphConvertor::ConvertDynamicStitch},
4177 // Convert hccl op for comm handle
4178 {prim::kPrimAllReduce->name(), &DfGraphConvertor::ConvertHcomFusionId},
4179 {prim::kPrimAllGather->name(), &DfGraphConvertor::ConvertHcomFusionId},
4180 {prim::kPrimReduceScatter->name(), &DfGraphConvertor::ConvertHcomFusionId},
4181 {prim::kPrimBroadcast->name(), &DfGraphConvertor::ConvertHcclNode},
4182 {prim::kPrimReduceScatter->name(), &DfGraphConvertor::ConvertHcclNode},
4183 {prim::kPrimSend->name(), &DfGraphConvertor::ConvertHcclNode},
4184 {prim::kPrimReceive->name(), &DfGraphConvertor::ConvertHcclNode},
4185 {prim::kPrimAllToAllv->name(), &DfGraphConvertor::ConvertAllToAllv},
4186 {prim::kPrimUniformReal->name(), &DfGraphConvertor::ConvertUniformReal},
4187 {prim::kPrimMatmulReduceScatter->name(), &DfGraphConvertor::ConvertHcclNode},
4188 {prim::kPrimAllGatherMatmul->name(), &DfGraphConvertor::ConvertHcclNode},
4189 {prim::kPrimUpdateState->name(), &DfGraphConvertor::ConvertUpdateState},
4190 };
4191
4192 if (const auto it = auxiliary_node_converters.find(name); it != auxiliary_node_converters.cend()) {
4193 it->second(this, node);
4194 }
4195 if (common::AnfAlgo::HasNodeAttr(kParallelGroup, node)) {
4196 ConvertParallelGroupToHcom(node);
4197 }
4198 if (node->HasAttr(kParallelGroupId)) {
4199 ConvertParallelGroupIdToHcom(node);
4200 }
4201
4202 return true;
4203 }
4204
CheckAndAddScopeAttrInt(const OperatorPtr op,const PrimitivePtr primitive,const std::string & attr_name)4205 void CheckAndAddScopeAttrInt(const OperatorPtr op, const PrimitivePtr primitive, const std::string &attr_name) {
4206 auto attr_value = primitive->GetAttr(attr_name);
4207 if (attr_value != nullptr) {
4208 auto value = GetValue<int64_t>(attr_value);
4209 (void)op->SetAttr(attr_name, value);
4210 }
4211 }
4212
CheckAndAddScopeAttrString(const OperatorPtr op,const PrimitivePtr primitive,const std::string & attr_name)4213 void CheckAndAddScopeAttrString(const OperatorPtr op, const PrimitivePtr primitive, const std::string &attr_name) {
4214 auto attr_value = primitive->GetAttr(attr_name);
4215 if (attr_value != nullptr) {
4216 auto value = GetValue<std::string>(attr_value);
4217 (void)op->SetAttr(attr_name, value);
4218 }
4219 }
4220
4221 // If node does not have abstract, it will fail when the node is generated to operator.
SetNodeAbstract(const CNodePtr & node) const4222 void DfGraphConvertor::SetNodeAbstract(const CNodePtr &node) const {
4223 MS_EXCEPTION_IF_NULL(node);
4224 if (node->abstract() != nullptr) {
4225 return;
4226 }
4227 if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
4228 auto inputs = node->inputs();
4229 AbstractBasePtrList elem;
4230 std::transform(inputs.begin() + 1, inputs.end(), std::back_inserter(elem),
4231 [](const AnfNodePtr &node) { return node->abstract(); });
4232 node->set_abstract(std::make_shared<abstract::AbstractTuple>(elem));
4233 return;
4234 }
4235 if (IsPrimitiveCNode(node, prim::kPrimReturn) || IsPrimitiveCNode(node, prim::kPrimDepend)) {
4236 auto inputs = node->inputs();
4237 if (inputs.size() < kInputSize2) {
4238 MS_LOG(EXCEPTION) << "node input size " << inputs.size() << " less than 2, node: " << node->fullname_with_scope();
4239 }
4240 auto input = inputs[1];
4241 MS_EXCEPTION_IF_NULL(input);
4242 node->set_abstract(input->abstract());
4243 return;
4244 }
4245 MS_LOG(WARNING) << "Node has not abstract:" << node->fullname_with_scope() << ", DebugString: " << node->ToString();
4246 }
4247
ConvertCNode(const CNodePtr node)4248 OperatorPtr DfGraphConvertor::ConvertCNode(const CNodePtr node) {
4249 SaveParamFormat(node);
4250 std::string name = GetCNodeTargetFuncName(node);
4251 if (!CheckCNode(name, node)) {
4252 return nullptr;
4253 }
4254
4255 // get corresponding OpAdapter
4256 OpAdapterPtr adpt = FindAdapter(node, training_);
4257 if (adpt == nullptr) {
4258 MS_LOG(ERROR) << "Cannot get adapter for " << node->fullname_with_scope();
4259 unsupported_ops_names_.insert(name);
4260 error_ = NOT_FOUND;
4261 return nullptr;
4262 }
4263 SetNodeAbstract(node);
4264 // get operator
4265 OperatorPtr op = nullptr;
4266 auto it_op = op_cache_.find(node.get());
4267 if (it_op != op_cache_.end()) {
4268 op = it_op->second;
4269 } else {
4270 if (cur_while_node_ == node) {
4271 op = adpt->generateDynOutputOp(node);
4272 } else {
4273 op = adpt->generate(node);
4274 }
4275 }
4276
4277 // set attribute for primitive
4278 (void)adpt->setAttr(op, node);
4279 auto value_node = node->input(0)->cast<ValueNodePtr>();
4280 if (value_node != nullptr && value_node->value()->cast<PrimitivePtr>() != nullptr) {
4281 MS_LOG(DEBUG) << "Set attr for subgraph multi dims";
4282 auto primitive = value_node->value()->cast<PrimitivePtr>();
4283 CheckAndAddScopeAttrInt(op, primitive, "_subgraph_multi_dims_index");
4284 CheckAndAddScopeAttrString(op, primitive, "_subgraph_multi_dims_input_dims");
4285 CheckAndAddScopeAttrString(op, primitive, "_subgraph_multi_dims_input_shape");
4286 }
4287
4288 // add into cache
4289 (void)op_cache_.emplace(node.get(), op);
4290
4291 DrawCNode(node, adpt);
4292
4293 return op_cache_[node.get()];
4294 }
4295
ConvertParameter(const AnfNodePtr node)4296 OperatorPtr DfGraphConvertor::ConvertParameter(const AnfNodePtr node) {
4297 // convert Parameter in ANF to variable in DataFlow
4298 auto adpt = FindAdapter(node, training_);
4299 if (adpt == nullptr) {
4300 MS_LOG(EXCEPTION) << "Can not find adapter for Parameter";
4301 }
4302 auto op = adpt->generate(node);
4303 op_cache_[node.get()] = op;
4304
4305 // build index for parameter using name
4306 std::string name = std::static_pointer_cast<Parameter>(node)->name();
4307 params_[name] = node;
4308 std::ostringstream ss;
4309 ss << "op" << node.get();
4310 op_draw_name_[node.get()] = ss.str();
4311 compute_sout_ << ss.str() << "[shape=octagon, label=\"" << name << "\"]" << endl;
4312 return op_cache_[node.get()];
4313 }
4314
SaveParamFormat(const CNodePtr node)4315 void DfGraphConvertor::SaveParamFormat(const CNodePtr node) {
4316 AnfNodePtr op = node->input(0);
4317 if (IsValueNode<Primitive>(op)) {
4318 auto prim = GetValueNode<PrimitivePtr>(op);
4319 std::string format;
4320 auto op_def = ops::GetOpDef(prim->name());
4321 if (op_def) {
4322 for (size_t index = 0; index < op_def->args_.size() && index < node->size() - 1; index++) {
4323 auto arg = op_def->args_[index];
4324 if (arg.as_init_arg_ && (arg.arg_name_ == ops::kFormat || arg.arg_name_ == ops::kDataFormat)) {
4325 auto value_ptr = node->input(index + 1)->cast<ValueNodePtr>();
4326 if (value_ptr == nullptr) {
4327 break;
4328 }
4329 auto input_value = value_ptr->value();
4330 MS_EXCEPTION_IF_NULL(input_value);
4331 auto format_id = GetValue<int64_t>(input_value);
4332 format = FormatEnumToString(static_cast<Format>(format_id));
4333 }
4334 }
4335 }
4336 auto value_ptr = prim->GetAttr(ops::kFormat);
4337 if (value_ptr) {
4338 if (value_ptr->isa<Int64Imm>()) {
4339 bool converted = CheckAndConvertUtils::ConvertAttrValueToString(prim->name(), "format", &value_ptr);
4340 if (converted) {
4341 format = value_ptr->ToString();
4342 } else {
4343 CheckAndConvertUtils::GetFormatStringVal(prim, &format);
4344 }
4345 } else if (value_ptr->isa<StringImm>()) {
4346 format = value_ptr->ToString();
4347 }
4348 }
4349
4350 if (format == "NCDHW" || format == "NHWC") {
4351 for (size_t i = 1; i < node->size(); i++) {
4352 auto input = node->input(i);
4353 if (input->isa<Parameter>()) {
4354 param_format_[input->DebugString()] = format;
4355 MS_LOG(DEBUG) << "Save Param " << input->DebugString() << " format: " << format;
4356 }
4357 }
4358 }
4359 }
4360 }
4361
TryConvertValueNodeToMultiConst(const ValueNodePtr node)4362 Status DfGraphConvertor::TryConvertValueNodeToMultiConst(const ValueNodePtr node) {
4363 MS_EXCEPTION_IF_NULL(node);
4364 ValuePtr value = node->value();
4365 MS_EXCEPTION_IF_NULL(value);
4366 if (!value->isa<ValueList>() && !value->isa<ValueTuple>()) {
4367 return FAILED;
4368 }
4369
4370 auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
4371 if (vec.empty()) {
4372 return FAILED;
4373 }
4374
4375 std::shared_ptr<std::vector<OutHandler>> tuple_items = std::make_shared<std::vector<OutHandler>>();
4376 // if the the sequence has only one element which is a scalar, it should be convert to a 1-D Tensor rather than a
4377 // 0-D Scalar.
4378 if (vec.size() == 1 && !vec[0]->isa<MeTensor>()) {
4379 return FAILED;
4380 }
4381 for (size_t i = 0; i < vec.size(); i++) {
4382 MS_EXCEPTION_IF_NULL(vec[i]);
4383 GeTensorPtr ge_tensor = nullptr;
4384 if (vec[i]->isa<MeTensor>()) {
4385 ge_tensor = transform::TransformUtil::ConvertTensor(vec[i]->cast<MeTensorPtr>(), kOpFormat_DEFAULT);
4386 MS_EXCEPTION_IF_NULL(ge_tensor);
4387 } else {
4388 ge_tensor = transform::TransformUtil::ConvertScalar(vec[i]);
4389 if (ge_tensor == nullptr) {
4390 return FAILED;
4391 }
4392 }
4393 auto const_op = std::make_shared<Constant>(node->fullname_with_scope() + "/const/inputs/" + std::to_string(i));
4394 AddGraphConstInput(const_op);
4395 (void)const_op->set_attr_value(*ge_tensor);
4396 (void)const_op->update_output_desc_y(ge_tensor->GetTensorDesc());
4397 (void)tuple_items->emplace_back(OutHandler(const_op, ""));
4398 }
4399 if (tuple_items->empty()) {
4400 return FAILED;
4401 }
4402
4403 tuple_out_handle_cache_[node.get()] = tuple_items;
4404 if (!vec[0]->isa<MeTensor>()) {
4405 return FAILED;
4406 }
4407 return SUCCESS;
4408 }
4409
ConvertValueNode(const ValueNodePtr node)4410 OperatorPtr DfGraphConvertor::ConvertValueNode(const ValueNodePtr node) {
4411 // convert valuenode in ANF to Const in DataFlow
4412 // find paramerte referenced by SymbolicKeyInstance of valuenode
4413 std::ostringstream ss;
4414 ss << "op" << node.get();
4415 op_draw_name_[node.get()] = ss.str();
4416 compute_sout_ << ss.str() << "[label= \"" << node->value()->ToString() << "\" shape=ellipse]" << endl;
4417
4418 if (TryConvertValueNodeToMultiConst(node) == SUCCESS) {
4419 MS_LOG(INFO) << "Convert value node to multi Constant OP success";
4420 return nullptr;
4421 }
4422
4423 OpAdapterPtr adpt = FindAdapter(node, training_);
4424 if (adpt == nullptr) {
4425 error_ = NOT_FOUND;
4426 return nullptr;
4427 }
4428 auto op = adpt->generate(node);
4429 // set const's attrs
4430 if (adpt->setAttr(op, "value", node->value()) != 0) {
4431 MS_LOG(WARNING) << "set attr value for const failed";
4432 }
4433
4434 if (op->GetOpType() != "Constant" && op->GetOpType() != "Const") {
4435 MS_LOG(ERROR) << "Get Constant operator failed, ge node type: " << op->GetOpType()
4436 << ", ms node info: " << node->ToString() << ", is train: " << training_;
4437 return nullptr;
4438 }
4439 ::ge::Tensor ge_tensor;
4440 (void)op->GetAttr("value", ge_tensor);
4441 auto ge_desc = ge_tensor.GetTensorDesc();
4442 (void)op->UpdateOutputDesc(kTypeY, ge_desc);
4443
4444 op_cache_[node.get()] = op;
4445 return op_cache_[node.get()];
4446 }
4447
DrawCNode(const CNodePtr node,const OpAdapterPtr adpt)4448 void DfGraphConvertor::DrawCNode(const CNodePtr node, const OpAdapterPtr adpt) {
4449 if (adpt == nullptr || node == nullptr) {
4450 MS_LOG(ERROR) << "Failed to draw apply node as adpt or node is nullptr!";
4451 return;
4452 }
4453 std::ostringstream ss;
4454 ss << "op" << node.get();
4455 op_draw_name_[node.get()] = ss.str();
4456
4457 compute_sout_ << ss.str() << "[label=<";
4458 compute_sout_ << "<table border='1' cellborder='1'>" << endl;
4459
4460 auto input_map = adpt->getInputMap();
4461 auto dyn_input_map = adpt->getDynInputMap();
4462 if (input_map.size() + dyn_input_map.size() > 0) {
4463 compute_sout_ << "<tr>";
4464 for (auto &it : input_map) {
4465 compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
4466 }
4467 for (auto &it : dyn_input_map) {
4468 compute_sout_ << "<td port='" << it.first << "'>" << it.second.name << "</td>";
4469 }
4470 compute_sout_ << "</tr>" << endl;
4471 }
4472
4473 compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << node->ToString()
4474 << ":" << GetCNodeTargetFuncName(node) << "\"</td></tr>" << endl;
4475
4476 // print attrs' values
4477 auto atts = adpt->GetAttrsFromDrawGraph();
4478 for (auto &it : atts) {
4479 compute_sout_ << "<tr><td colspan=\"" << (input_map.size() + dyn_input_map.size()) << "\">\"" << it
4480 << "\"</td></tr>";
4481 }
4482
4483 adpt->clearAttrVect();
4484
4485 compute_sout_ << "</table>> shape=plaintext]" << endl;
4486 }
RegisterAdapter(const std::string & name,OpAdapterPtr adpt)4487 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr adpt) {
4488 OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(adpt);
4489 }
RegisterAdapter(const std::string & name,OpAdapterPtr train_adpt,OpAdapterPtr infer_adpt)4490 void DfGraphConvertor::RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) {
4491 OpAdapterMap::get()[name] = std::make_shared<OpAdapterDesc>(train_adpt, infer_adpt);
4492 }
4493
GetAttrAndValue(const AnfNodePtr & node,const bool training=true)4494 std::map<std::string, ValuePtr> GeOpConvertor::GetAttrAndValue(const AnfNodePtr &node, const bool training = true) {
4495 MS_EXCEPTION_IF_NULL(node);
4496 std::map<std::string, ValuePtr> attr_list;
4497 if (!node->isa<CNode>()) {
4498 MS_LOG(INFO) << "Current node isn't a cnode! node info:" << node->DebugString();
4499 return attr_list;
4500 }
4501
4502 OpAdapterPtr adpt = FindAdapter(node, training);
4503 if (adpt == nullptr) {
4504 MS_LOG(INFO) << "Current node can't find adpt! node info:" << node->DebugString();
4505 return attr_list;
4506 }
4507
4508 attr_list = adpt->GetNormalOpAttrList(node);
4509 return attr_list;
4510 }
4511
GetOpType(const AnfNodePtr & node,const bool training=true)4512 std::string GeOpConvertor::GetOpType(const AnfNodePtr &node, const bool training = true) {
4513 MS_EXCEPTION_IF_NULL(node);
4514 OpAdapterPtr adpt = FindAdapter(node, training);
4515 if (adpt == nullptr) {
4516 MS_LOG(INFO) << "Current node can't find adpt! node info:" << node->DebugString();
4517 return "";
4518 }
4519 return adpt->getOpType();
4520 }
4521
GetTensorDesc(const ShapeVector & dev_shape,const TypeId & dev_type,const std::string & dev_format,const ShapeVector & ori_shape,const std::string & ori_format)4522 std::shared_ptr<GeTensorDesc> GeOpConvertor::GetTensorDesc(const ShapeVector &dev_shape, const TypeId &dev_type,
4523 const std::string &dev_format, const ShapeVector &ori_shape,
4524 const std::string &ori_format) {
4525 auto tensor_desc = transform::TransformUtil::GetGeTensorDesc(dev_shape, dev_type, dev_format, ori_shape, ori_format);
4526 MS_EXCEPTION_IF_NULL(tensor_desc);
4527 return tensor_desc;
4528 }
4529
GetNeedAddInput(const AnfNodePtr & node,const bool training)4530 mindspore::HashMap<std::string, std::string> GeOpConvertor::GetNeedAddInput(const AnfNodePtr &node,
4531 const bool training) {
4532 MS_EXCEPTION_IF_NULL(node);
4533 OpAdapterPtr adpt = FindAdapter(node, training);
4534 if (adpt == nullptr) {
4535 MS_LOG(INFO) << "Current node can't find adpt! node info:" << node->DebugString();
4536 return {};
4537 }
4538
4539 return adpt->getAttrInputMap();
4540 }
4541
IsDynamicInput(const AnfNodePtr & node,const size_t idx)4542 bool GeOpConvertor::IsDynamicInput(const AnfNodePtr &node, const size_t idx) {
4543 MS_EXCEPTION_IF_NULL(node);
4544 OpAdapterPtr adapterPtr = FindAdapter(node, true);
4545 if (adapterPtr == nullptr) {
4546 MS_LOG(INFO) << "Can't find a adapter for op:" << node->DebugString();
4547 return false;
4548 }
4549 return adapterPtr->IsDynInputOp(idx);
4550 }
4551
GetAclInputNames(const AnfNodePtr & node)4552 std::map<int, std::string> GeOpConvertor::GetAclInputNames(const AnfNodePtr &node) {
4553 MS_EXCEPTION_IF_NULL(node);
4554 OpAdapterPtr adapterPtr = FindAdapter(node, true);
4555 if (adapterPtr == nullptr) {
4556 MS_LOG(EXCEPTION) << "Can't find a adapter for op:" << node->DebugString();
4557 }
4558
4559 std::map<int, std::string> input_names;
4560 for (const auto &[k, v] : adapterPtr->getInputMap()) {
4561 input_names.emplace(k, v.name);
4562 }
4563 // dynamic input
4564 for (const auto &[k, v] : adapterPtr->getDynInputMap()) {
4565 input_names.emplace(k, v.name);
4566 }
4567 return input_names;
4568 }
4569
GetAclOutputNames(const AnfNodePtr & node)4570 std::map<int, std::string> GeOpConvertor::GetAclOutputNames(const AnfNodePtr &node) {
4571 MS_EXCEPTION_IF_NULL(node);
4572 OpAdapterPtr adapterPtr = FindAdapter(node, true);
4573 if (adapterPtr == nullptr) {
4574 MS_LOG(EXCEPTION) << "Can't find a adapter for op:" << node->DebugString();
4575 }
4576
4577 std::map<int, std::string> output_names;
4578 for (const auto &[k, v] : adapterPtr->getOutputMap()) {
4579 output_names.emplace(k, v.name);
4580 }
4581
4582 // dynamic output
4583 for (const auto &[k, v] : adapterPtr->getDynOutputMap()) {
4584 output_names.emplace(k, v.name);
4585 }
4586 return output_names;
4587 }
4588
GetAclDynamicInputNames(const AnfNodePtr & node)4589 std::map<int, std::string> GeOpConvertor::GetAclDynamicInputNames(const AnfNodePtr &node) {
4590 MS_EXCEPTION_IF_NULL(node);
4591 OpAdapterPtr adapterPtr = FindAdapter(node, true);
4592 if (adapterPtr == nullptr) {
4593 MS_LOG(EXCEPTION) << "Can't find a adapter for op:" << node->DebugString();
4594 }
4595 std::map<int, std::string> dyn_input_names;
4596 for (const auto &[k, v] : adapterPtr->getDynInputMap()) {
4597 dyn_input_names.emplace(k, v.name);
4598 }
4599 return dyn_input_names;
4600 }
4601
GetAclDynamicOutputNames(const AnfNodePtr & node)4602 std::map<int, std::string> GeOpConvertor::GetAclDynamicOutputNames(const AnfNodePtr &node) {
4603 MS_EXCEPTION_IF_NULL(node);
4604 OpAdapterPtr adapterPtr = FindAdapter(node, true);
4605 if (adapterPtr == nullptr) {
4606 MS_LOG(EXCEPTION) << "Can't find a adapter for op:" << node->DebugString();
4607 }
4608 std::map<int, std::string> dyn_output_names;
4609 for (const auto &[k, v] : adapterPtr->getDynOutputMap()) {
4610 dyn_output_names.emplace(k, v.name);
4611 }
4612 return dyn_output_names;
4613 }
4614 } // namespace mindspore::transform
4615