1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/graph_util/node_info.h"
18
19 #include <string>
20 #include <utility>
21
22 #include "ops/sequence_ops.h"
23 #include "ops/array_ops.h"
24 #include "ops/framework_ops.h"
25 #include "ir/param_info.h"
26 #include "ir/meta_tensor.h"
27 #include "include/common/utils/python_adapter.h"
28 #include "frontend/parallel/ops_info/ops_utils.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/step_parallel_utils.h"
31
32 namespace mindspore {
33 namespace parallel {
34 const std::vector<std::string> filter_attrs = {RECOMPUTE, TARGET};
35 const uint32_t kMinInputSize = 2;
36 constexpr size_t kSize2 = 2;
ParameterName(const AnfNodePtr & node_ptr)37 std::string ParameterName(const AnfNodePtr &node_ptr) {
38 auto para_ptr = node_ptr->cast<ParameterPtr>();
39 MS_EXCEPTION_IF_NULL(para_ptr);
40 return para_ptr->name();
41 }
42
ParameterRequireGrad(const AnfNodePtr & node_ptr)43 bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
44 auto para_ptr = node_ptr->cast<ParameterPtr>();
45 if (para_ptr == nullptr) {
46 return false;
47 }
48 if (!para_ptr->has_default()) {
49 return false;
50 }
51 auto param_value = para_ptr->param_info();
52 if (param_value == nullptr) {
53 return false;
54 }
55 return param_value->requires_grad();
56 }
57
// Skip over Load and Depend wrappers by repeatedly following their input 1.
// Stops at the first node that is not a Load/Depend CNode, or as soon as a non-CNode is reached.
AnfNodePtr GetRealInput(const AnfNodePtr &input) {
  AnfNodePtr current = input;
  while (IsPrimitiveCNode(current, prim::kPrimLoad) || IsPrimitiveCNode(current, prim::kPrimDepend)) {
    current = current->cast<CNodePtr>()->input(1);
    if (!current->isa<CNode>()) {
      break;
    }
  }
  return current;
}
68
69 // Given the node, return whether each input is a parameter or a output of a operator.
70 // The returned boolean vector should be the same order of the inputs, thus its implementation
71 // is closely consistent with ExtractShape() in step_parallel.cc
ExtractInputParameterByNode(const CNodePtr & node)72 std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
73 std::vector<bool> is_parameter;
74 std::vector<AnfNodePtr> node_inputs{node->inputs()};
75 // input is a ValueList or ValueTuple, then all inputs are not parameter.
76 if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
77 IsValueSequence(node_inputs[1])) {
78 std::vector<ValuePtr> inputs_seq;
79 if (IsValueNode<ValueList>(node_inputs[1])) {
80 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
81 } else {
82 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
83 }
84 size_t inputs_seq_tensor_size = inputs_seq.size();
85 for (const auto &inputs_seq_value : inputs_seq) {
86 auto tensor = inputs_seq_value->cast<tensor::TensorPtr>();
87 if (tensor == nullptr) {
88 MS_LOG(DEBUG) << "The value not is not a tensor.";
89 inputs_seq_tensor_size = 0;
90 break;
91 }
92 }
93 return std::vector<bool>(inputs_seq_tensor_size, false);
94 }
95 if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
96 IsMakeSequence(node_inputs[1])) {
97 node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
98 }
99 for (size_t i = 1; i < node_inputs.size(); ++i) {
100 auto input = GetRealInput(node_inputs[i]);
101 if (HasAbstractMonad(input)) {
102 continue;
103 }
104 if (input->isa<Parameter>()) {
105 auto input_parameter = input->cast<ParameterPtr>();
106 is_parameter.push_back(ParameterRequireGrad(input_parameter));
107 } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
108 if (IsDynamicShapeInput(node, input)) {
109 MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
110 continue;
111 }
112 is_parameter.push_back(false);
113 }
114 }
115 return is_parameter;
116 }
117
ExtractInputParameterNameByNode(const CNodePtr & node)118 std::string ExtractInputParameterNameByNode(const CNodePtr &node) {
119 std::string param_name = "";
120 std::vector<AnfNodePtr> node_inputs{node->inputs()};
121 // input is a ValueList or ValueTuple, then all inputs are not parameter.
122 if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
123 IsValueSequence(node_inputs[1])) {
124 node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
125 }
126 for (size_t i = 1; i < node_inputs.size(); ++i) {
127 auto input = GetRealInput(node_inputs[i]);
128 if (HasAbstractMonad(input)) {
129 continue;
130 }
131 if (input->isa<Parameter>()) {
132 param_name = input->fullname_with_scope();
133 auto input_parameter = input->cast<ParameterPtr>();
134 MS_LOG(INFO) << "node name: " << node->fullname_with_scope() << "involved parameter: " << input_parameter->name();
135 }
136 }
137 return param_name;
138 }
139
140 // Given the type, return the number of bytes to represent this type
GetLengthOfDataType(const TypePtr & type)141 size_t GetLengthOfDataType(const TypePtr &type) {
142 switch (type->type_id()) {
143 case kNumberTypeBool:
144 return sizeof(bool);
145 case kNumberTypeInt8:
146 return sizeof(int8_t);
147 case kNumberTypeInt16:
148 return sizeof(int16_t);
149 case kNumberTypeInt32:
150 return sizeof(int32_t);
151 case kNumberTypeInt64:
152 return sizeof(int64_t);
153 case kNumberTypeUInt8:
154 return sizeof(uint8_t);
155 case kNumberTypeUInt16:
156 return sizeof(uint16_t);
157 case kNumberTypeUInt32:
158 return sizeof(uint32_t);
159 case kNumberTypeUInt64:
160 return sizeof(uint64_t);
161 case kNumberTypeFloat16:
162 return sizeof(float) / kSize2;
163 case kNumberTypeFloat32:
164 return sizeof(float);
165 case kNumberTypeFloat64:
166 return sizeof(double);
167 case kNumberTypeInt:
168 return sizeof(int64_t);
169 case kNumberTypeUInt:
170 return sizeof(unsigned);
171 case kNumberTypeFloat:
172 return sizeof(float);
173 case kNumberTypeBFloat16:
174 return sizeof(float) / kSize2;
175 case kNumberTypeComplex64:
176 return sizeof(float) * kSize2;
177 default:
178 MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
179 }
180 }
181
GetInputsTypeLen(const AnfNodePtr & input)182 size_t GetInputsTypeLen(const AnfNodePtr &input) {
183 MS_EXCEPTION_IF_NULL(input);
184 if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
185 MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
186 }
187
188 size_t input_type_len = 0;
189 auto type = input->Type();
190 MS_EXCEPTION_IF_NULL(type);
191 if (type->isa<mindspore::TensorType>()) {
192 auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
193 input_type_len = GetLengthOfDataType(input_element_type);
194 } else {
195 MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
196 }
197 return input_type_len;
198 }
199
ExtractInputElementLength(const CNodePtr & node,std::vector<AnfNodePtr> node_inputs)200 std::vector<size_t> ExtractInputElementLength(const CNodePtr &node, std::vector<AnfNodePtr> node_inputs) {
201 std::vector<size_t> inputs_type_len;
202 // extract input element length
203 for (auto &input : node_inputs) {
204 if (HasAbstractMonad(input)) {
205 continue;
206 }
207 if (IsValueNode<RefKey>(input)) {
208 auto func_graph = node->func_graph();
209 MS_EXCEPTION_IF_NULL(func_graph);
210 std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
211 if (parameters.size() != 1) {
212 MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
213 }
214 inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
215 } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
216 if (IsDynamicShapeInput(node, input)) {
217 MS_LOG(INFO) << "may be dynamic shape, no need to get input's shape, the node is " << node->ToString();
218 continue;
219 }
220 // extract input shape from parameter and apply node
221 inputs_type_len.push_back(GetInputsTypeLen(input));
222 }
223 }
224 return inputs_type_len;
225 }
226
extra_input_for_ifa(CNodePtr node,std::vector<AnfNodePtr> node_input)227 std::vector<AnfNodePtr> extra_input_for_ifa(CNodePtr node, std::vector<AnfNodePtr> node_input) {
228 ValueNodePtr anf_node = node->input(0)->cast<ValueNodePtr>();
229 if (!anf_node) {
230 return node_input;
231 }
232 PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>();
233 if (!prim) {
234 return node_input;
235 }
236 if (prim->name() != INCRE_FLASH_ATTENTION) {
237 return node_input;
238 }
239 for (size_t input_index = 1; input_index < node_input.size(); input_index++) {
240 if (node_input[input_index] != nullptr && IsMakeSequence(node_input[input_index])) {
241 node_input[input_index] = node_input[input_index]->cast<CNodePtr>()->inputs()[1];
242 }
243 }
244 return node_input;
245 }
246
ExtractInputTypeLengthByNode(const CNodePtr & node)247 std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
248 MS_EXCEPTION_IF_NULL(node);
249 std::vector<size_t> inputs_type_len;
250 std::vector<AnfNodePtr> node_inputs{node->inputs()};
251
252 if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
253 IsValueSequence(node_inputs[1])) {
254 std::vector<ValuePtr> inputs_seq;
255 if (IsValueNode<ValueList>(node_inputs[1])) {
256 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
257 } else {
258 inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
259 }
260 for (auto &ele : inputs_seq) {
261 auto tensor = ele->cast<tensor::TensorPtr>();
262 if (tensor == nullptr) {
263 inputs_type_len.clear();
264 return inputs_type_len;
265 }
266 inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
267 }
268 return inputs_type_len;
269 }
270
271 if ((node_inputs.size() == kMinInputSize || IsSomePrimitiveList(node, INPUT_IS_TUPLE_OR_LIST_OPS)) &&
272 IsMakeSequence(node_inputs[1])) {
273 node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
274 }
275
276 node_inputs = extra_input_for_ifa(node, node_inputs);
277 return ExtractInputElementLength(node, node_inputs);
278 }
279
ExtractOutputTypeByNode(const CNodePtr & node)280 std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
281 MS_EXCEPTION_IF_NULL(node);
282 std::vector<TypePtr> outputs_type;
283 // extract output element type
284 auto primary_output_type = node->Type();
285 MS_EXCEPTION_IF_NULL(primary_output_type);
286 if (primary_output_type->isa<mindspore::Tuple>()) {
287 // in this case, the output is a tuple
288 auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
289 auto elements = tuple_output_type->elements();
290 for (auto &ele : elements) {
291 if (ele->isa<mindspore::TensorType>()) {
292 auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
293 outputs_type.push_back(ele_element_type);
294 } else {
295 MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
296 }
297 }
298 } else {
299 // in this case, the output is a single tensor
300 if (primary_output_type->isa<mindspore::TensorType>()) {
301 auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
302 outputs_type.push_back(element_type);
303 } else {
304 MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
305 }
306 }
307 return outputs_type;
308 }
309
FindParameterByRefKeyNode(const AnfNodePtr & node,const FuncGraphPtr & func_graph)310 std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
311 MS_EXCEPTION_IF_NULL(node);
312 MS_EXCEPTION_IF_NULL(func_graph);
313 std::vector<AnfNodePtr> parameters;
314 if (!IsValueNode<RefKey>(node)) {
315 MS_LOG(ERROR) << "The node is not a ref key";
316 return parameters;
317 }
318
319 auto ref_key = GetValueNode<StringImmPtr>(node);
320 MS_EXCEPTION_IF_NULL(ref_key);
321 auto name = ref_key->value();
322
323 auto manager = func_graph->manager();
324 MS_EXCEPTION_IF_NULL(manager);
325 auto roots = manager->roots();
326 if (roots.size() != 1) {
327 MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
328 return parameters;
329 }
330
331 FuncGraphPtr root_g = roots.back();
332 MS_EXCEPTION_IF_NULL(root_g);
333 for (auto ¶m_node : root_g->parameters()) {
334 auto param = param_node->cast<ParameterPtr>();
335 if (param && (name == param->name())) {
336 parameters.push_back(param_node);
337 MS_LOG(INFO) << "The name of ref key is: " << name;
338 return parameters;
339 }
340 }
341
342 MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter";
343 return parameters;
344 }
345
AnfNodeIsPrimitive(const AnfNodePtr & anf_node,const std::string & prim_name)346 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
347 MS_EXCEPTION_IF_NULL(anf_node);
348 auto cnode = anf_node->cast<CNodePtr>();
349 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
350 return false;
351 }
352
353 auto value_node = cnode->input(0)->cast<ValueNodePtr>();
354 auto prim = GetValueNode<PrimitivePtr>(value_node);
355 MS_EXCEPTION_IF_NULL(prim);
356 if (prim->name() == prim_name) {
357 return true;
358 }
359 return false;
360 }
361
FindReshape(const CNodePtr & cnode,mindspore::HashSet<std::string> * op_cache)362 bool FindReshape(const CNodePtr &cnode, mindspore::HashSet<std::string> *op_cache) {
363 if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
364 return false;
365 }
366 if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
367 return false;
368 }
369 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
370 PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
371 MS_EXCEPTION_IF_NULL(prim);
372 if (prim->name() == RESHAPE) {
373 auto operator_info = cnode->user_data<OperatorInfo>();
374 std::string op_info_name = operator_info->name();
375 if (op_cache->find(op_info_name) != op_cache->end()) {
376 return false;
377 }
378 (void)op_cache->insert(op_info_name);
379 return true;
380 }
381 return false;
382 }
383
FindReshapePreNodeCrossParam(const AnfNodePtr & node,OperatorInfoPtr * pre_operator_info,bool * is_prev_param,int64_t * out_index,size_t curr_depth)384 bool FindReshapePreNodeCrossParam(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, bool *is_prev_param,
385 int64_t *out_index, size_t curr_depth) {
386 auto fg_map = node->func_graph()->func_graph_cnodes_index();
387 auto parameters = node->func_graph()->parameters();
388 int64_t param_index = -1;
389 for (size_t j = 0; j < parameters.size(); ++j) {
390 if (parameters[j] == node) {
391 param_index = SizeToLong(j);
392 }
393 }
394 if (fg_map.size() == 0 || param_index == -1) {
395 *is_prev_param = true;
396 return true;
397 }
398 auto temp_node = fg_map.begin()->first->first->cast<CNodePtr>();
399 auto prev_node = temp_node->input(param_index + 1);
400 return FindReshapePreNodeStraCosts(prev_node, pre_operator_info, is_prev_param, out_index, ++curr_depth);
401 }
402
403 // Find previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
FindReshapePreNodeStraCosts(const AnfNodePtr & node,OperatorInfoPtr * pre_operator_info,bool * is_prev_param,int64_t * out_index,size_t curr_depth)404 bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, bool *is_prev_param,
405 int64_t *out_index, size_t curr_depth) {
406 if (curr_depth > MAX_RECURSIVE_DEPTH) {
407 MS_LOG(WARNING) << "When finding Reshape's previous node, exceeded the max recursive depth: "
408 << MAX_RECURSIVE_DEPTH;
409 return false;
410 }
411 // if previous node is a parameter, handle it in the outsize.
412 if (node->isa<Parameter>()) {
413 return FindReshapePreNodeCrossParam(node, pre_operator_info, is_prev_param, out_index, curr_depth);
414 }
415 if (!node->isa<CNode>()) {
416 return false;
417 }
418 CNodePtr cnode = node->cast<CNodePtr>();
419 FindPreNodeCrossFuncGraph(&cnode, *out_index);
420 if (!IsValueNode<Primitive>(cnode->input(0))) {
421 return false;
422 }
423 auto node_op_info = cnode->user_data<OperatorInfo>();
424 if (IsParallelCareNode(cnode) && (node_op_info != nullptr) && !IsPrimitiveCNode(cnode, prim::kPrimReshape)) {
425 *pre_operator_info = node_op_info;
426 *out_index = 0;
427 return true;
428 }
429 ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
430 PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
431 if (prim->name() == prim::kPrimTupleGetItem->name()) {
432 *out_index = GetTupleGetItemIndex(cnode);
433 // find tuple_get_item's previous node
434 auto pre_node = cnode->input(1);
435 if (!pre_node->isa<CNode>()) {
436 MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
437 }
438 CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
439 FindPreNodeCrossFuncGraph(&pre_cnode, *out_index);
440 auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
441 if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
442 *pre_operator_info = pre_op_info;
443 return true;
444 }
445 return false;
446 }
447 for (size_t index = 0; index < cnode->size(); ++index) {
448 if (prim->name() == DEPEND && index != 1) {
449 continue;
450 }
451 if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, is_prev_param, out_index,
452 ++curr_depth)) {
453 continue;
454 }
455 return true;
456 }
457 MS_LOG(WARNING)
458 << "FindReshapePreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
459 return false;
460 }
461
462 // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
463 // if reshape's output connect to several primitive, return the first layout found
FindReshapeNextNodeStraCosts(const CNodePtr & cnode,std::vector<std::pair<OperatorInfoPtr,int64_t>> * next_ops_index,bool * is_next_reshape,size_t curr_depth)464 void FindReshapeNextNodeStraCosts(const CNodePtr &cnode,
465 std::vector<std::pair<OperatorInfoPtr, int64_t>> *next_ops_index,
466 bool *is_next_reshape, size_t curr_depth) {
467 if (curr_depth > MAX_RECURSIVE_DEPTH) {
468 MS_LOG(WARNING) << "When finding Reshape's next node, exceeded the max recursive depth: " << MAX_RECURSIVE_DEPTH;
469 return;
470 }
471 MS_EXCEPTION_IF_NULL(cnode);
472 MS_EXCEPTION_IF_NULL(cnode->func_graph());
473 FuncGraphManagerPtr manager = cnode->func_graph()->manager();
474 MS_EXCEPTION_IF_NULL(manager);
475 AnfNodeIndexSet node_set = manager->node_users()[cnode];
476 for (auto &node_pair : node_set) {
477 CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
478 if (use_apply == nullptr ||
479 !(IsValueNode<Primitive>(use_apply->input(0)) || IsValueNode<FuncGraph>(use_apply->input(0)))) {
480 continue;
481 }
482 auto pair = node_pair;
483 if (IsValueNode<FuncGraph>(use_apply->input(0))) {
484 auto sub_graph = GetValueNode<FuncGraphPtr>(use_apply->input(0));
485 auto params = sub_graph->parameters();
486 auto sub_manager = sub_graph->manager();
487 auto sub_node_set = sub_manager->node_users()[params[node_pair.second - 1]];
488 for (auto &sub_node_pair : sub_node_set) {
489 use_apply = sub_node_pair.first->cast<CNodePtr>();
490 pair = sub_node_pair;
491 break;
492 }
493 }
494 if (IsPrimitiveCNode(use_apply, prim::kPrimReshape)) {
495 *is_next_reshape = true;
496 continue;
497 }
498 ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
499 MS_EXCEPTION_IF_NULL(prim_anf_node);
500 PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
501 MS_EXCEPTION_IF_NULL(node_prim);
502 MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
503 if (node_prim->name() == DEPEND && pair.second != 1) {
504 continue;
505 }
506 auto op_info = use_apply->user_data<OperatorInfo>();
507 if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
508 MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
509 *is_next_reshape = false;
510 next_ops_index->push_back(std::make_pair(op_info, pair.second - 1));
511 continue;
512 }
513 MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
514 << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);
515
516 FindReshapeNextNodeStraCosts(use_apply, next_ops_index, is_next_reshape, ++curr_depth);
517 }
518 }
519
SetUserAttrs(const mindspore::HashMap<std::string,ValuePtr> & origin_prim_attrs,const PrimitivePtr & self_prim)520 void SetUserAttrs(const mindspore::HashMap<std::string, ValuePtr> &origin_prim_attrs, const PrimitivePtr &self_prim) {
521 MS_EXCEPTION_IF_NULL(self_prim);
522 for (auto attr_name : filter_attrs) {
523 auto iter = origin_prim_attrs.find(attr_name);
524 if (iter != origin_prim_attrs.cend()) {
525 self_prim->set_attr(attr_name, iter->second);
526 MS_LOG(INFO) << "The new prim " << self_prim << " add attr " << attr_name;
527 }
528 }
529 }
530
531 // Convert ValueTuple/ValueList to vector
TransValueSequeueToVector(const ValuePtr & input_value,std::vector<int64_t> * input)532 Status TransValueSequeueToVector(const ValuePtr &input_value, std::vector<int64_t> *input) {
533 MS_EXCEPTION_IF_NULL(input_value);
534 input->clear();
535 if (!input_value->isa<ValueSequeue>()) {
536 MS_LOG(ERROR) << "Input value must be ValueTuplePtr.";
537 return FAILED;
538 }
539 ValueSequeuePtr value_seq = input_value->cast<ValueSequeuePtr>();
540 for (auto &element : value_seq->value()) {
541 MS_EXCEPTION_IF_NULL(element);
542 if (element->isa<Int64Imm>()) {
543 int64_t value = element->cast<Int64ImmPtr>()->value();
544 input->push_back(value);
545 } else {
546 MS_LOG(ERROR) << "The value must be int64";
547 return FAILED;
548 }
549 }
550 return SUCCESS;
551 }
552
553 // Get the input of cnode, skipping DEPEND/LOAD/UPDATESTATE
RealInputNode(const CNodePtr cnode,size_t index)554 const AnfNodePtr RealInputNode(const CNodePtr cnode, size_t index) {
555 MS_EXCEPTION_IF_NULL(cnode);
556 if (cnode->size() <= index) {
557 MS_LOG(EXCEPTION) << "cnode inputs size: " << cnode->size() << " is less equal index: " << index;
558 }
559 auto input0 = cnode->input(index);
560 if (!IsPrimitiveCNode(input0)) {
561 return input0;
562 }
563 auto prim = GetCNodePrimitive(input0);
564 MS_EXCEPTION_IF_NULL(prim);
565 while (prim->name() == LOAD || prim->name() == DEPEND || prim->name() == UPDATESTATE) {
566 if (prim->name() == LOAD || prim->name() == DEPEND) {
567 input0 = input0->cast<CNodePtr>()->input(1);
568 } else {
569 input0 = input0->cast<CNodePtr>()->input(2);
570 }
571 if (!input0->isa<CNode>()) {
572 return input0;
573 }
574 prim = GetCNodePrimitive(input0);
575 MS_EXCEPTION_IF_NULL(prim);
576 }
577 return input0;
578 }
579 } // namespace parallel
580 } // namespace mindspore
581