/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/graph_util/node_info.h"

#include <string>
#include <utility>

#include "ops/sequence_ops.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "ir/param_info.h"
#include "ir/meta_tensor.h"
#include "include/common/utils/python_adapter.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"

namespace mindspore {
namespace parallel {
const std::vector<std::string> filter_attrs = {RECOMPUTE, TARGET};
const uint32_t kMinInputSize = 2;
constexpr size_t kSize2 = 2;
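// Return the name of the given Parameter node.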
std::string ParameterName(const AnfNodePtr &node_ptr) {
  auto para_ptr = node_ptr->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(para_ptr);
  return para_ptr->name();
}

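// Return true if the node is a Parameter that has a default value and whose param_info requires gradient.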
bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
  auto para_ptr = node_ptr->cast<ParameterPtr>();
  if (para_ptr == nullptr) {
    return false;
  }
  if (!para_ptr->has_default()) {
    return false;
  }
  auto param_value = para_ptr->param_info();
  if (param_value == nullptr) {
    return false;
  }
  return param_value->requires_grad();
}

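// Skip consecutive Load/Depend nodes and return the underlying input node.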
AnfNodePtr GetRealInput(const AnfNodePtr &input) {
  auto res = input;
  while (IsPrimitiveCNode(res, prim::kPrimLoad) || IsPrimitiveCNode(res, prim::kPrimDepend)) {
    res = res->cast<CNodePtr>()->input(1);
    if (!res->isa<CNode>()) {
      return res;
    }
  }
  return res;
}

// Given the node, return whether each input is a parameter or an output of an operator.
// The returned boolean vector follows the order of the inputs, so its implementation
// must stay consistent with ExtractShape() in step_parallel.cc.
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  std::vector<bool> is_parameter;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or ValueTuple, then none of the inputs is a parameter.
  if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
      IsValueSequence(node_inputs[1])) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    size_t inputs_seq_tensor_size = inputs_seq.size();
    for (const auto &inputs_seq_value : inputs_seq) {
      auto tensor = inputs_seq_value->cast<tensor::TensorPtr>();
      if (tensor == nullptr) {
        MS_LOG(DEBUG) << "The value is not a tensor.";
        inputs_seq_tensor_size = 0;
        break;
      }
    }
    return std::vector<bool>(inputs_seq_tensor_size, false);
  }
  if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
      IsMakeSequence(node_inputs[1])) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = GetRealInput(node_inputs[i]);
    if (HasAbstractMonad(input)) {
      continue;
    }
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      is_parameter.push_back(ParameterRequireGrad(input_parameter));
    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "May be dynamic shape, no need to get the input's shape, the node is " << node->ToString();
        continue;
      }
      is_parameter.push_back(false);
    }
  }
  return is_parameter;
}

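// Return the full scope name of the last Parameter among the node's inputs; empty string if there is none.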
std::string ExtractInputParameterNameByNode(const CNodePtr &node) {
  std::string param_name = "";
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or ValueTuple, then none of the inputs is a parameter.
  if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
      IsValueSequence(node_inputs[1])) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = GetRealInput(node_inputs[i]);
    if (HasAbstractMonad(input)) {
      continue;
    }
    if (input->isa<Parameter>()) {
      param_name = input->fullname_with_scope();
      auto input_parameter = input->cast<ParameterPtr>();
      MS_LOG(INFO) << "node name: " << node->fullname_with_scope() << ", involved parameter: " << input_parameter->name();
    }
  }
  return param_name;
}

// Given the type, return the number of bytes to represent this type
size_t GetLengthOfDataType(const TypePtr &type) {
  switch (type->type_id()) {
    case kNumberTypeBool:
      return sizeof(bool);
    case kNumberTypeInt8:
      return sizeof(int8_t);
    case kNumberTypeInt16:
      return sizeof(int16_t);
    case kNumberTypeInt32:
      return sizeof(int32_t);
    case kNumberTypeInt64:
      return sizeof(int64_t);
    case kNumberTypeUInt8:
      return sizeof(uint8_t);
    case kNumberTypeUInt16:
      return sizeof(uint16_t);
    case kNumberTypeUInt32:
      return sizeof(uint32_t);
    case kNumberTypeUInt64:
      return sizeof(uint64_t);
    case kNumberTypeFloat16:
      return sizeof(float) / kSize2;
    case kNumberTypeFloat32:
      return sizeof(float);
    case kNumberTypeFloat64:
      return sizeof(double);
    case kNumberTypeInt:
      return sizeof(int64_t);
    case kNumberTypeUInt:
      return sizeof(unsigned);
    case kNumberTypeFloat:
      return sizeof(float);
    case kNumberTypeBFloat16:
      return sizeof(float) / kSize2;
    case kNumberTypeComplex64:
      return sizeof(float) * kSize2;
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }
}

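// Return the byte length of the tensor element type of the given node.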
size_t GetInputsTypeLen(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
    MS_LOG(EXCEPTION) << "The input node is not a cnode, parameter or tensor";
  }

  size_t input_type_len = 0;
  auto type = input->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<mindspore::TensorType>()) {
    auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
    input_type_len = GetLengthOfDataType(input_element_type);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  }
  return input_type_len;
}

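// Return the element byte length of each input, resolving RefKey inputs to their parameters.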
std::vector<size_t> ExtractInputElementLength(const CNodePtr &node, std::vector<AnfNodePtr> node_inputs) {
  std::vector<size_t> inputs_type_len;
  // extract input element length
  for (auto &input : node_inputs) {
    if (HasAbstractMonad(input)) {
      continue;
    }
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
      if (IsDynamicShapeInput(node, input)) {
        MS_LOG(INFO) << "May be dynamic shape, no need to get the input's shape, the node is " << node->ToString();
        continue;
      }
      // extract input shape from parameter and apply node
      inputs_type_len.push_back(GetInputsTypeLen(input));
    }
  }
  return inputs_type_len;
}

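// For IncreFlashAttention, replace any MakeTuple/MakeList input with its first element.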
std::vector<AnfNodePtr> extra_input_for_ifa(CNodePtr node, std::vector<AnfNodePtr> node_input) {
  ValueNodePtr anf_node = node->input(0)->cast<ValueNodePtr>();
  if (!anf_node) {
    return node_input;
  }
  PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
  if (!prim) {
    return node_input;
  }
  if (prim->name() != INCRE_FLASH_ATTENTION) {
    return node_input;
  }
  for (size_t input_index = 1; input_index < node_input.size(); input_index++) {
    if (node_input[input_index] != nullptr && IsMakeSequence(node_input[input_index])) {
      node_input[input_index] = node_input[input_index]->cast<CNodePtr>()->inputs()[1];
    }
  }
  return node_input;
}

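// Return the byte length of each input's element type; return an empty vector if a value-sequence input contains a non-tensor.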
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<size_t> inputs_type_len;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};

  if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
      IsValueSequence(node_inputs[1])) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    for (auto &ele : inputs_seq) {
      auto tensor = ele->cast<tensor::TensorPtr>();
      if (tensor == nullptr) {
        inputs_type_len.clear();
        return inputs_type_len;
      }
      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
    }
    return inputs_type_len;
  }

  if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
      IsMakeSequence(node_inputs[1])) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }

  node_inputs = extra_input_for_ifa(node, node_inputs);
  return ExtractInputElementLength(node, node_inputs);
}

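// Return the element type of each output tensor; a tuple output yields one entry per element.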
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypePtr> outputs_type;
  // extract output element type
  auto primary_output_type = node->Type();
  MS_EXCEPTION_IF_NULL(primary_output_type);
  if (primary_output_type->isa<mindspore::Tuple>()) {
    // in this case, the output is a tuple
    auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
    auto elements = tuple_output_type->elements();
    for (auto &ele : elements) {
      if (ele->isa<mindspore::TensorType>()) {
        auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
        outputs_type.push_back(ele_element_type);
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
      }
    }
  } else {
    // in this case, the output is a single tensor
    if (primary_output_type->isa<mindspore::TensorType>()) {
      auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
      outputs_type.push_back(element_type);
    } else {
      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
    }
  }
  return outputs_type;
}

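// Find the root-graph parameter whose name matches the given RefKey node.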
std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> parameters;
  if (!IsValueNode<RefKey>(node)) {
    MS_LOG(ERROR) << "The node is not a ref key";
    return parameters;
  }

  auto ref_key = GetValueNode<StringImmPtr>(node);
  MS_EXCEPTION_IF_NULL(ref_key);
  auto name = ref_key->value();

  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto roots = manager->roots();
  if (roots.size() != 1) {
    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
    return parameters;
  }

  FuncGraphPtr root_g = roots.back();
  MS_EXCEPTION_IF_NULL(root_g);
  for (auto &param_node : root_g->parameters()) {
    auto param = param_node->cast<ParameterPtr>();
    if (param && (name == param->name())) {
      parameters.push_back(param_node);
      MS_LOG(INFO) << "The name of ref key is: " << name;
      return parameters;
    }
  }

  MS_LOG(ERROR) << "The name of ref key is: " << name << ", but the parameter has not been found";
  return parameters;
}

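// Return true if the node is a CNode whose primitive name equals prim_name.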
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }

  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  return prim->name() == prim_name;
}

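// Return true if the cnode is a parallel-care Reshape whose operator info has not been cached in op_cache yet.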
bool FindReshape(const CNodePtr &cnode, mindspore::HashSet<std::string> *op_cache) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    auto operator_info = cnode->user_data<OperatorInfo>();
    std::string op_info_name = operator_info->name();
    if (op_cache->find(op_info_name) != op_cache->end()) {
      return false;
    }
    (void)op_cache->insert(op_info_name);
    return true;
  }
  return false;
}

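// When the previous node is a func-graph parameter, continue the search from the corresponding argument at the call site.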
bool FindReshapePreNodeCrossParam(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, bool *is_prev_param,
                                  int64_t *out_index, size_t curr_depth) {
  auto fg_map = node->func_graph()->func_graph_cnodes_index();
  auto parameters = node->func_graph()->parameters();
  int64_t param_index = -1;
  for (size_t j = 0; j < parameters.size(); ++j) {
    if (parameters[j] == node) {
      param_index = SizeToLong(j);
    }
  }
  if (fg_map.size() == 0 || param_index == -1) {
    *is_prev_param = true;
    return true;
  }
  auto temp_node = fg_map.begin()->first->first->cast<CNodePtr>();
  auto prev_node = temp_node->input(param_index + 1);
  return FindReshapePreNodeStraCosts(prev_node, pre_operator_info, is_prev_param, out_index, ++curr_depth);
}

// Find the previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, bool *is_prev_param,
                                 int64_t *out_index, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding Reshape's previous node, exceeded the max recursive depth: "
                    << MAX_RECURSIVE_DEPTH;
    return false;
  }
  // if the previous node is a parameter, handle it outside.
  if (node->isa<Parameter>()) {
    return FindReshapePreNodeCrossParam(node, pre_operator_info, is_prev_param, out_index, curr_depth);
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  FindPreNodeCrossFuncGraph(&cnode, *out_index);
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == prim::kPrimTupleGetItem->name()) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find the previous node of tuple_get_item
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple_get_item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    FindPreNodeCrossFuncGraph(&pre_cnode, *out_index);
    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, is_prev_param, out_index,
                                     ++curr_depth)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING)
    << "FindReshapePreNodeStraCosts failed; if Reshape is not the first primitive, there must be some error";
  return false;
}

// Find the next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
// If Reshape's output connects to several primitives, return the first layout found.
void FindReshapeNextNodeStraCosts(const CNodePtr &cnode,
                                  std::vector<std::pair<OperatorInfoPtr, int64_t>> *next_ops_index,
                                  bool *is_next_reshape, size_t curr_depth) {
  if (curr_depth > MAX_RECURSIVE_DEPTH) {
    MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
    return;
  }
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr ||
        !(IsValueNode<Primitive>(use_apply->input(0)) || IsValueNode<FuncGraph>(use_apply->input(0)))) {
      continue;
    }
    auto pair = node_pair;
    if (IsValueNode<FuncGraph>(use_apply->input(0))) {
      auto sub_graph = GetValueNode<FuncGraphPtr>(use_apply->input(0));
      auto params = sub_graph->parameters();
      auto sub_manager = sub_graph->manager();
      auto sub_node_set = sub_manager->node_users()[params[node_pair.second - 1]];
      for (auto &sub_node_pair : sub_node_set) {
        use_apply = sub_node_pair.first->cast<CNodePtr>();
        pair = sub_node_pair;
        break;
      }
    }
    if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
      *is_next_reshape = true;
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && pair.second != 1) {
      continue;
    }
    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
      *is_next_reshape = false;
      next_ops_index->push_back(std::make_pair(op_info, pair.second - 1));
      continue;
    }
    MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << "  "
                  << IsParallelCareNode(use_apply) << "   " << (op_info != nullptr);

    FindReshapeNextNodeStraCosts(use_apply, next_ops_index, is_next_reshape, ++curr_depth);
  }
}

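// Copy the attributes listed in filter_attrs from the origin primitive's attrs to self_prim.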
void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim) {
  MS_EXCEPTION_IF_NULL(self_prim);
  for (const auto &attr_name : filter_attrs) {
    auto iter = origin_prim_attrs.find(attr_name);
    if (iter != origin_prim_attrs.cend()) {
      self_prim->set_attr(attr_name, iter->second);
      MS_LOG(INFO) << "The new prim " << self_prim << " add attr " << attr_name;
    }
  }
}

// Convert ValueTuple/ValueList to vector
Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input) {
  MS_EXCEPTION_IF_NULL(input_value);
  input->clear();
  if (!input_value->isa<ValueSequeue>()) {
    MS_LOG(ERROR) << "The input value must be a ValueTuple or ValueList.";
    return FAILED;
  }
  ValueSequeuePtr value_seq = input_value->cast<ValueSequeuePtr>();
  for (auto &element : value_seq->value()) {
    MS_EXCEPTION_IF_NULL(element);
    if (element->isa<Int64Imm>()) {
      int64_t value = element->cast<Int64ImmPtr>()->value();
      input->push_back(value);
    } else {
      MS_LOG(ERROR) << "The value must be int64";
      return FAILED;
    }
  }
  return SUCCESS;
}

// Get the input of cnode, skipping DEPEND/LOAD/UPDATESTATE
const AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() <= index) {
    MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less than or equal to index: " << index;
  }
  auto input0 = cnode->input(index);
  if (!IsPrimitiveCNode(input0)) {
    return input0;
  }
  auto prim = GetCNodePrimitive(input0);
  MS_EXCEPTION_IF_NULL(prim);
  while (prim->name() == LOAD || prim->name() == DEPEND || prim->name() == UPDATESTATE) {
    if (prim->name() == LOAD || prim->name() == DEPEND) {
      input0 = input0->cast<CNodePtr>()->input(1);
    } else {
      input0 = input0->cast<CNodePtr>()->input(2);
    }
    if (!input0->isa<CNode>()) {
      return input0;
    }
    prim = GetCNodePrimitive(input0);
    MS_EXCEPTION_IF_NULL(prim);
  }
  return input0;
}
}  // namespace parallel
}  // namespace mindspore