1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <map>
18 #include <vector>
19 #include <string>
20 #include <memory>
21 #include <utility>
22 #include "frontend/parallel/graph_util/graph_utils.h"
23 #include "frontend/parallel/ops_info/ops_utils.h"
24 #include "frontend/parallel/step_parallel_utils.h"
25 #include "frontend/parallel/parameter_manager.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/graph_util/graph_info.h"
28 #include "frontend/parallel/tensor_layout/prime_generator.h"
29 #include "mindspore/core/ir/primitive.h"
30 #include "mindspore/core/ir/func_graph.h"
31 #include "include/common/utils/anfalgo.h"
32
33 namespace mindspore::parallel {
34 int64_t GetPrimeFactor(int64_t value) {
35 static const std::vector<int64_t> prime_table = PrimeGenerator::GetInstance()->GetPrimeTable();
36 for (const auto &prime : prime_table) {
37 if (prime > value) {
38 return -1;
39 }
40 if (value % prime == 0) {
41 return prime;
42 }
43 }
44 return -1;
45 }
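// Behavior sketch (the real prime table comes from PrimeGenerator and is assumed here to be sorted
// ascending, which the early-exit check above relies on):
//   - returns the first table prime that divides `value`;
//   - returns -1 once a table prime exceeds `value`, or when no table prime divides it.
// e.g. with a hypothetical table {2, 3, 5}: GetPrimeFactor(35) == 5, GetPrimeFactor(49) == -1.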
46
47 CNodePtr CreateShape(const AnfNodePtr &pre_cnode, const FuncGraphPtr &func_graph, const std::string &inst_name) {
48 auto prim = std::make_shared<Primitive>(SHAPE_OP);
49 prim->set_instance_name(inst_name);
50 AnfNodePtrList shape_node_inputs(SIZE_TWO);
51 shape_node_inputs[0] = NewValueNode(prim);
52 shape_node_inputs[1] = pre_cnode;
53 auto shape_cnode = func_graph->NewCNode(shape_node_inputs);
54 return shape_cnode;
55 }
56
57 inline bool IsTargetOp(const CNodePtr &cnode, const std::string &target) { return GetPrimName(cnode) == target; }
58
59 bool IsTupleGetItem(const CNodePtr &cnode) { return IsTargetOp(cnode, TUPLE_GETITEM_OP); }
60
61 bool IsReshapeOp(const CNodePtr &cnode) { return IsTargetOp(cnode, RESHAPE); }
62
63 bool IsShapeOp(const CNodePtr &cnode) { return IsTargetOp(cnode, SHAPE_OP); }
64
65 TensorRedistributionPtr GetTensorRedistributionFromCNode(const CNodePtr &node) {
66 OperatorInfoPtr distribute_operator = GetDistributeOperator(node);
67 if (distribute_operator == nullptr) {
68 MS_LOG(WARNING) << node->fullname_with_scope() << " has no OperatorInfo.";
69 return nullptr;
70 }
71 if (IsReshapeOp(node)) {
72 return distribute_operator->reshape_tensor_redistribution();
73 }
74 return distribute_operator->tensor_redistribution();
75 }
76
77 bool IsDynamicOp(const CNodePtr &node) {
78 TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(node);
79 if (tensor_redistribution == nullptr) {
80 return false;
81 }
82 return tensor_redistribution->IsAssembledStaticShape();
83 }
84
85 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const std::vector<AnfNodePtr> &root_all_nodes) {
86 // J->CNode->Graph
87 std::set<FuncGraphPtr> graph_set;
88 for (auto &node : root_all_nodes) {
89 MS_EXCEPTION_IF_NULL(node);
90 if (!node->isa<CNode>()) {
91 continue;
92 }
93
94 auto cnode = node->cast<CNodePtr>();
95 if ((cnode->size() < SIZE_TWO) || !IsValueNode<Primitive>(cnode->input(0))) {
96 continue;
97 }
98 auto expect_prim = GetValueNode<PrimitivePtr>(cnode->input(0));
99 if (expect_prim->name() != J && expect_prim->name() != SHARD) {
100 continue;
101 }
102 if (IsValueNode<FuncGraph>(cnode->input(1))) {
103 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
104 MS_LOG(DEBUG) << "Find the forward graph success";
105 (void)graph_set.insert(graph);
106 auto manager = graph->manager();
107 MS_EXCEPTION_IF_NULL(manager);
108 auto graph_used = manager->func_graphs_used_total(graph);
109 for (auto iter = graph_used.cbegin(); iter != graph_used.cend(); ++iter) {
110 (void)graph_set.insert(*iter);
111 }
112 }
113 }
114 return graph_set;
115 }
116
117 AnfNodePtr GetAccuGrad(const std::vector<AnfNodePtr> &parameters, const std::string &weight_name) {
118 for (auto &param : parameters) {
119 if (!ParameterIsCloned(param)) {
120 continue;
121 }
122
123 auto param_ptr = param->cast<ParameterPtr>();
124 MS_EXCEPTION_IF_NULL(param_ptr);
125 auto accu_grads_name = std::string(ACCU_GRADS) + "." + weight_name;
126 if (param_ptr->name() == accu_grads_name) {
127 MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
128 return param;
129 }
130 }
131 return nullptr;
132 }
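// Naming convention this lookup relies on (illustrative name, not from this file):
//   weight_name = "network.fc.weight"  ->  matches the cloned parameter "accu_grads.network.fc.weight".
// Only cloned parameters (ParameterIsCloned) are considered as candidates.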
133
134 std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
135 const std::string &instance_name, const std::string &weight_name) {
136 MS_EXCEPTION_IF_NULL(root);
137 MS_EXCEPTION_IF_NULL(node);
138 MS_EXCEPTION_IF_NULL(root->manager());
139
140 std::string op_name = op.first;
141 OperatorArgs arg_forward = op.second;
142 AnfNodePtr grad_accu = nullptr;
143
144 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
145 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
146 if (grad_accumulation_step > 1 || split_stage_num > 1) {
147 auto parameters = root->parameters();
148 grad_accu = GetAccuGrad(parameters, weight_name);
149 if (!grad_accu && op_name == MICRO_STEP_ALL_GATHER) {
150 MS_LOG(EXCEPTION) << "You should define `accu_grads` when using " << op_name << ", parameter: " << weight_name;
151 }
152 }
153
154 OperatorParams params = arg_forward.second;
155
156 std::vector<AnfNodePtr> new_node_input;
157 if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER ||
158 op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) {
159 MS_EXCEPTION_IF_NULL(grad_accu);
160 new_node_input = {node, grad_accu};
161 MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input";
162 } else {
163 new_node_input = {node};
164 }
165
166 if (!params.empty()) {
167 for (auto &param : params) {
168 AnfNodePtr val = NewValueNode(param.first.second);
169 MS_EXCEPTION_IF_NULL(val);
170 int64_t position = param.second;
171 (void)new_node_input.insert(new_node_input.cbegin() + position - 1, val);
172 }
173 }
174
175 new_node_input = ConvertToRealInputs(op_name, instance_name, new_node_input, arg_forward.first);
176 // if the op have 'group' attr, set the rank list name for the op
177 SetCommunicationOpGroupLabel(new_node_input);
178 return new_node_input;
179 }
180
181 CNodePtr CreateMakeTuple(const std::vector<AnfNodePtr> &tuple_inputs, const FuncGraphPtr &func_graph,
182 const std::string &instance_name = "") {
183 MS_EXCEPTION_IF_NULL(func_graph);
184 std::vector<AnfNodePtr> make_tuple_inputs(tuple_inputs.size() + 1);
185 auto prim = std::make_shared<Primitive>(MAKE_TUPLE);
186 if (!instance_name.empty()) {
187 prim->set_instance_name(instance_name);
188 }
189 make_tuple_inputs[0] = NewValueNode(prim);
190 for (size_t i = 0; i < tuple_inputs.size(); ++i) {
191 make_tuple_inputs[i + 1] = tuple_inputs[i];
192 }
193 auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
194 return make_tuple;
195 }
196
197 CNodePtr CreateSplit(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph,
198 const std::string &inst_name) {
199 MS_EXCEPTION_IF_NULL(func_graph);
200 MS_EXCEPTION_IF_CHECK_FAIL(inputs.size() == SIZE_THREE, "CreateSplit requires exactly 3 inputs.");
201 auto prim = std::make_shared<Primitive>(SPLIT);
202 if (!inst_name.empty()) {
203 prim->set_instance_name(inst_name);
204 }
205 std::vector<AnfNodePtr> split_inputs(SIZE_FOUR);
206 split_inputs[INDEX_ZERO] = NewValueNode(prim);
207 split_inputs[INDEX_ONE] = inputs[INDEX_ZERO]; // split_input
208 split_inputs[INDEX_TWO] = inputs[INDEX_ONE]; // split_axis
209 split_inputs[INDEX_THREE] = inputs[INDEX_TWO]; // split_size
210 auto split = func_graph->NewCNode(split_inputs);
211 return split;
212 }
213
214 CNodePtr CreateCast(const AnfNodePtr &cast_input, const ValueNodePtr &dest_type, const FuncGraphPtr &func_graph) {
215 auto cast_prim = NewValueNode(prim::kPrimScalarCast);
216 auto cast = func_graph->NewCNode({cast_prim, cast_input, dest_type});
217 return cast;
218 }
219
220 AnfNodePtr CreateDiv(const AnfNodePtr &input_node, int64_t divisor, const FuncGraphPtr &func_graph, bool to_long,
221 const std::string &inst_name) {
222 MS_EXCEPTION_IF_NULL(input_node);
223 MS_EXCEPTION_IF_NULL(func_graph);
224 MS_EXCEPTION_IF_ZERO("div_divisor", divisor);
225 if (divisor == 1) {
226 return input_node;
227 }
228 auto prim = std::make_shared<Primitive>(SCALAR_FLOOR_DIV);
229 if (!inst_name.empty()) {
230 prim->set_instance_name(inst_name);
231 }
232 std::vector<AnfNodePtr> inputs(SIZE_THREE);
233 inputs[INDEX_ZERO] = NewValueNode(prim);
234 inputs[INDEX_ONE] = input_node;
235 inputs[INDEX_TWO] = CreatInt64Imm(divisor);
236 auto div = func_graph->NewCNode(inputs);
237 if (to_long) {
238 auto type_id = NewValueNode(MakeValue(static_cast<int64_t>(kInt64->type_id())));
239 return CreateCast(div, type_id, func_graph);
240 }
241 return div;
242 }
243
244 CNodePtr CreateMul(const AnfNodePtr &input_node, const int64_t factor, const FuncGraphPtr &func_graph,
245 bool to_long = false, const std::string &inst_name = "") {
246 MS_EXCEPTION_IF_NULL(input_node);
247 MS_EXCEPTION_IF_NULL(func_graph);
248 MS_EXCEPTION_IF_ZERO("mul_factor", factor);
249 auto prim = std::make_shared<Primitive>(SCALAR_MUL);
250 if (!inst_name.empty()) {
251 prim->set_instance_name(inst_name);
252 }
253 std::vector<AnfNodePtr> inputs(SIZE_THREE);
254 inputs[INDEX_ZERO] = NewValueNode(prim);
255 inputs[INDEX_ONE] = input_node;
256 inputs[INDEX_TWO] = CreatInt64Imm(factor);
257 auto mul = func_graph->NewCNode(inputs);
258 if (to_long) {
259 auto type_id = NewValueNode(MakeValue(static_cast<int64_t>(kInt64->type_id())));
260 return CreateCast(mul, type_id, func_graph);
261 }
262 return mul;
263 }
264
265 bool MatchWithPrime(const AssembledDynamicDimsMapping &dyn_dims_mapping, int64_t prime) {
266 for (const auto &iter : dyn_dims_mapping) {
267 int64_t prime_base = GetPrimeFactor(iter.first);
268 if (prime_base == prime) {
269 return true;
270 }
271 }
272 return false;
273 }
274
275 inline bool IsSameRank(const Shape &shape_vec, const Shape &target_shape_vec) {
276 return shape_vec.size() == target_shape_vec.size();
277 }
278
279 bool HasAssebledDynamicDim(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
280 const TensorRedistributionPtr &tensor_redistribution, bool is_same_rank) {
281 for (int64_t dim : shape_vec) {
282 auto iter = dyn_dims_mapping.find(dim);
283 if (iter != dyn_dims_mapping.end()) {
284 return true;
285 }
286 int64_t prime_base = dim;
287 while (prime_base > 1) {
288 int64_t prime_of_dim = GetPrimeFactor(prime_base);
289 if (prime_of_dim == -1) {
290 break;
291 }
292 if (MatchWithPrime(dyn_dims_mapping, prime_of_dim)) {
293 return true;
294 }
295 prime_base /= prime_of_dim;
296 }
297 }
298 return false;
299 }
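// Sketch of the matching rule above (values are hypothetical): suppose dyn_dims_mapping holds an
// assembled fake dim 7 whose prime factor is 7, and the prime table contains {2, 3, 5, 7}.
// For a target dim 28 = 2 * 2 * 7, the direct lookup misses, so the dim is peeled prime by prime:
//   28 -> prime 2 (no mapping value with prime 2) -> 14 -> prime 2 -> 7 -> prime 7 -> match, return true.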
300
301 void MatchingAccordingToPrime(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
302 const TensorRedistributionPtr &tensor_redistribution, const FuncGraphPtr &func_graph,
303 std::vector<AnfNodePtr> *shape_input,
304 enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
305 MS_LOG(INFO) << "Match with prime, shape_vec=" << shape_vec << ", reshape_mode=" << reshape_mode;
306 MS_EXCEPTION_IF_NULL(shape_input);
307 // If the shape not changed, it means not reshape.
308 // So the dynamic dim can be matched according to index.
309 std::string instance_name = std::string(REDISTRIBUTION_OP) + "_" + "assemble_shape";
310 for (size_t i = 0; i < shape_vec.size(); ++i) {
311 int64_t dim = shape_vec[i];
312 // TODO(liuchongming): dim could have more than one prime factor, have to get all prime factors of dim.
313 int64_t dim_prime = GetPrimeFactor(dim);
314 bool found = false;
315 if (dim != -1 && dim_prime != -1) {
316 for (const auto &iter : dyn_dims_mapping) {
317 int64_t dim_value_in_graph = iter.first;
318 AnfNodePtr tuple_getitem = iter.second.second;
319 int64_t dyn_prime = GetPrimeFactor(dim_value_in_graph);
320 if (dyn_prime != dim_prime) {
321 continue;
322 }
323 MS_LOG(INFO) << "i=" << i << ", dim_value_in_graph=" << dim_value_in_graph << ", dim_prime=" << dim_prime
324 << ", dim=" << dim;
325 if (dim_value_in_graph > dim) {
326 int64_t divisor = dim_value_in_graph / dim;
327 AnfNodePtr div_op = CreateDiv(tuple_getitem, divisor, func_graph, false, instance_name);
328 (void)shape_input->emplace_back(div_op);
329 found = true;
330 break;
331 } else if (dim_value_in_graph < dim) {
332 int64_t divisor = dim / dim_value_in_graph;
333 AnfNodePtr mul_op = CreateMul(tuple_getitem, divisor, func_graph, false, instance_name);
334 (void)shape_input->emplace_back(mul_op);
335 found = true;
336 break;
337 } else {
338 (void)shape_input->emplace_back(tuple_getitem);
339 found = true;
340 break;
341 }
342 }
343 }
344 if (!found) {
345 MS_LOG(INFO) << "Cannot find " << dim << " in shape param.";
346 AnfNodePtr val = CreatInt64Imm(dim);
347 (void)shape_input->emplace_back(val);
348 }
349 }
350 }
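// Worked trace of the prime matching (all values hypothetical; assumes the placeholder prime table
// contains only the generated primes, e.g. 1117):
//   dyn_dims_mapping = { 1117 -> TupleGetItem(shape, k) }   // assembled fake dim
//   shape_vec        = ( 64, 2234 )                          // 2234 == 2 * 1117
//   dim 64  : GetPrimeFactor(64) == -1, no match -> the constant 64 is kept.
//   dim 2234: prime 1117 matches the mapping entry; 1117 < 2234 with factor 2, so
//             ScalarMul(TupleGetItem(shape, k), 2) is emitted instead of the constant.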
351
352 void MatchingAccordingToIndex(const Shape &shape_vec, const AssembledDynamicDimsMapping &dyn_dims_mapping,
353 const TensorRedistributionPtr &tensor_redistribution, const FuncGraphPtr &func_graph,
354 std::vector<AnfNodePtr> *shape_input,
355 enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
356 MS_LOG(INFO) << "Match with index, shape_vec=" << shape_vec;
357 MS_EXCEPTION_IF_NULL(shape_input);
358 TensorLayout to_layout = tensor_redistribution->layout_transfer().to_in();
359 TensorLayout from_layout = tensor_redistribution->layout_transfer().from_in();
360 // If the shape not changed, it means not reshape.
361 // So the dynamic dim can be matched according to index.
362 // {index, {prime_dim, AnfNode}}
363 std::map<size_t, std::pair<int64_t, AnfNodePtr>> mapping_table;
364 for (const auto &iter : dyn_dims_mapping) {
365 mapping_table.insert({iter.second.first, {iter.first, iter.second.second}});
366 }
367 for (size_t i = 0; i < shape_vec.size(); ++i) {
368 int64_t dim = shape_vec[i];
369 if (dim != -1 && mapping_table.find(i) != mapping_table.end()) {
370 std::pair<int64_t, AnfNodePtr> tuple_getitem_input_pair = mapping_table[i];
371 int64_t dim_value_in_graph = tuple_getitem_input_pair.first;
372 int64_t dim_prime = GetPrimeFactor(dim);
373 int64_t tuple_getitem_prime = GetPrimeFactor(tuple_getitem_input_pair.first);
374 if (dim_prime != tuple_getitem_prime) {
375 MS_LOG(EXCEPTION) << "Prime in dim and dynamic input are not matched, " << dim_prime << " for " << dim
376 << " and " << tuple_getitem_prime << " for " << tuple_getitem_input_pair.first;
377 }
378 // After matching with prime, fetch the real dim value in graph and
379 // calculate whether it needs mul/div.
380 if (dim_value_in_graph > dim) {
381 int64_t divisor = dim_value_in_graph / dim;
382 AnfNodePtr div_op =
383 CreateDiv(tuple_getitem_input_pair.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
384 (void)shape_input->emplace_back(div_op);
385 continue;
386 }
387 if (dim_value_in_graph < dim) {
388 int64_t divisor = dim / dim_value_in_graph;
389 AnfNodePtr mul_op =
390 CreateMul(tuple_getitem_input_pair.second, divisor, func_graph, true, "assemble_dynamic_shape_op");
391 (void)shape_input->emplace_back(mul_op);
392 continue;
393 }
394 (void)shape_input->emplace_back(tuple_getitem_input_pair.second);
395 continue;
396 }
397 MS_LOG(INFO) << "Cannot find " << dim << " in shape param.";
398 AnfNodePtr val = CreatInt64Imm(dim);
399 (void)shape_input->emplace_back(val);
400 }
401 }
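// Index-based variant (hypothetical values): when the rank is unchanged, position i of shape_vec is
// matched against the mapping entry recorded for index i and then compensated the same way:
//   mapping_table = { 1 -> (1117, TupleGetItem(shape, 1)) }
//   shape_vec     = ( 8, 4468 )                              // 4468 == 4 * 1117
//   i = 0: no entry for index 0 -> the constant 8 is kept.
//   i = 1: primes match (1117), 1117 < 4468 with factor 4 -> ScalarMul(TupleGetItem(shape, 1), 4),
//          cast to int64 because to_long is true here.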
402
403 int64_t CountDynamicAxis(const AnfNodePtrList &shape_input) {
404 int64_t dyn_axis_cnt = 0;
405 for (size_t i = 0; i < shape_input.size(); ++i) {
406 if (shape_input[i]->isa<ValueNode>()) {
407 auto val_node = shape_input[i]->cast<ValueNodePtr>();
408 MS_EXCEPTION_IF_NULL(val_node->value());
409 int64_t index = GetValue<int64_t>(val_node->value());
410 if (index == -1) {
411 dyn_axis_cnt += 1;
412 }
413 } else {
414 dyn_axis_cnt += 1;
415 }
416 }
417 return dyn_axis_cnt;
418 }
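// An axis counts as dynamic if its entry is not a ValueNode (e.g. a TupleGetItem or ScalarMul result)
// or if it is the constant -1. For example, { ScalarMul(...), ValueNode(4096), ValueNode(-1) } yields 2.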
419
420 inline bool WhetherIsValueNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
421
422 AnfNodePtr ConvertConstParamToDynamic(const TensorRedistributionPtr &tensor_redistribution, const Param &param,
423 const FuncGraphPtr &func_graph, bool is_reshape,
424 enum ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE) {
425 // Only ConvertReshapeInputs will use this function.
426 MS_EXCEPTION_IF_NULL(tensor_redistribution);
427 AssembledDynamicDimsMapping dyn_dims_mapping = tensor_redistribution->GetDynamicDimsMapping();
428 if (dyn_dims_mapping.empty()) {
429 MS_LOG(ERROR) << "Doesn't have dynamic dims mapping.";
430 return nullptr;
431 }
432 std::vector<int64_t> shape_vec = GetValue<std::vector<int64_t>>(param.first.second);
433 if (shape_vec.empty()) {
434 MS_LOG(ERROR) << "Cannot get shape from param.";
435 return nullptr;
436 }
437
438 // After refactor, dyn_dims_mapping is generated according to origin_from_shape.
439 // Reshape has 3 scenes:
440 // 1. from_origin_->from_layout.from: when the shape is squeezed, leading or trailing 1s are removed from from_origin.
441 // 2. to_layout.to->to_origin_: when shape is unified, it could be expanded.
442 // 3. User's reshape: written in user's scripts.
443 Shape origin_from_shape = tensor_redistribution->from_origin_layout().tensor_shape().array();
444 Shape origin_slice_from_shape = tensor_redistribution->from_origin_layout().slice_shape().array();
445 Shape from_shape = tensor_redistribution->from_layout().tensor_shape().array();
446 Shape unified_from_shape = tensor_redistribution->layout_transfer().from_in().tensor_shape().array();
447 Shape unified_slice_from_shape = tensor_redistribution->layout_transfer().from_in().slice_shape().array();
448 MS_LOG(INFO) << "reshape_mode=" << reshape_mode << ", shape_vec: " << shape_vec
449 << ", origin_from_shape: " << origin_from_shape
450 << ", \norigin_slice_from_shape: " << origin_slice_from_shape << ", \nfrom_shape: " << from_shape
451 << ", \nunified_from_shape: " << unified_from_shape
452 << ", \nunified_slice_from_shape:" << unified_slice_from_shape;
453 // The rank should be compared between shape_vec and origin_from_shape, because
454 // the mapping is generated according to origin_from_shape.
455 bool is_same_rank = IsSameRank(shape_vec, origin_from_shape);
456 if (!HasAssebledDynamicDim(shape_vec, dyn_dims_mapping, tensor_redistribution, is_same_rank)) {
457 // If the shape_vec is (-1, dim_1) and dim_1 is not a fake value generated by tensor redistribution,
458 // then it doesn't have to be matched.
459 AnfNodePtr val = NewValueNode(param.first.second);
460 MS_EXCEPTION_IF_NULL(val);
461 val->set_abstract(param.first.second->ToAbstract());
462 return val;
463 }
464 if (shape_vec.size() == 1) {
465 std::vector<int64_t> const_shape{-1};
466 AnfNodePtr val = NewValueNode(const_shape);
467 val->set_abstract(param.first.second->ToAbstract());
468 return val;
469 }
470 std::vector<AnfNodePtr> shape_input;
471 if (reshape_mode == ReshapeMode::FROM_ORIGIN_SLICE_TO_FROM_LAYOUT_SLICE ||
472 reshape_mode == ReshapeMode::TO_ORIGIN_SLICE_TO_TO_LAYOUT_SLICE ||
473 reshape_mode == ReshapeMode::FROM_ORIGIN_BASE_SLICE_TO_TO_ORIGIN_BASE_SLICE) {
474 MatchingAccordingToPrime(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
475 reshape_mode);
476 } else {
477 if (is_same_rank) {
478 MatchingAccordingToIndex(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
479 reshape_mode);
480 } else {
481 MatchingAccordingToPrime(shape_vec, dyn_dims_mapping, tensor_redistribution, func_graph, &shape_input,
482 reshape_mode);
483 }
484 }
485 if (shape_input.size() != shape_vec.size()) {
486 MS_LOG(ERROR) << "shape_input size is not equal to shape_vec size.";
487 return nullptr;
488 }
489
490 if (is_reshape) {
491 // If only has one dynamic axis, then set it to -1.
492 size_t dyn_axis_cnt = LongToSize(CountDynamicAxis(shape_input));
493 MS_LOG(INFO) << "For shape_vec=" << shape_vec << ", has " << dyn_axis_cnt << " dynamic axis.";
494 if (dyn_axis_cnt == 1) {
495 constexpr int64_t unknown = -1;
496 for (size_t i = 0; i < shape_input.size(); ++i) {
497 if (shape_input[i]->isa<CNode>()) {
498 shape_input[i] = NewValueNode(MakeValue(unknown));
499 MS_LOG(INFO) << "change index " << i << " to -1.";
500 break;
501 }
502 }
503 }
504 }
505 if (std::all_of(shape_input.begin(), shape_input.end(), &WhetherIsValueNode)) {
506 std::vector<int64_t> const_shape(shape_input.size());
507 for (size_t i = 0; i < shape_input.size(); ++i) {
508 auto val_node = shape_input[i]->cast<ValueNodePtr>();
509 MS_EXCEPTION_IF_NULL(val_node->value());
510 int64_t value = GetValue<int64_t>(val_node->value());
511 const_shape[i] = value;
512 }
513 return NewValueNode(const_shape);
514 }
515 auto make_tuple = CreateMakeTuple(shape_input, func_graph, REDISTRIBUTION_OP);
516 return make_tuple;
517 }
518
519 Status ConvertStridedSliceInputs(const OperatorParams &params,
520 const TensorRedistributionPtr &tensor_redistribution_from_cnode,
521 const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
522 for (auto &param : params) {
523 if (param.first.first == BEGIN_MASK || param.first.first == END_MASK || param.first.first == ELLIPSIS_MASK ||
524 param.first.first == NEW_AXIS_MASK || param.first.first == SHRINK_AXIS_MASK) {
525 int64_t value = GetValue<int64_t>(param.first.second);
526 MS_LOG(INFO) << "STRIDEDSLICE: param=" << param.first.first << ", param.second=" << value;
527 AnfNodePtr val = NewValueNode(value);
528 val->set_abstract(param.first.second->ToAbstract());
529 (void)new_node_input->emplace_back(val);
530 continue;
531 }
532 Shape shape_vec = GetValue<Shape>(param.first.second);
533 MS_LOG(INFO) << "STRIDEDSLICE: param=" << param.first.first << ", " << shape_vec;
534 if (param.first.first == END) {
535 auto dynamic_input = ConvertConstParamToDynamic(tensor_redistribution_from_cnode, param, func_graph, false);
536 MS_ERROR_IF_NULL_W_RET_VAL(dynamic_input, FAILED);
537 new_node_input->emplace_back(dynamic_input);
538 continue;
539 }
540 AnfNodePtr val = NewValueNode(shape_vec);
541 MS_ERROR_IF_NULL_W_RET_VAL(val, FAILED);
542 val->set_abstract(param.first.second->ToAbstract());
543 (void)new_node_input->emplace_back(val);
544 }
545 return SUCCESS;
546 }
547
548 bool WhetherMatchingIsNeededForReshape(const Shape &shape_vec, const TensorRedistributionPtr &tensor_redistribution) {
549 size_t user_specific_dynamic_dim_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
550 TensorLayout to_layout = tensor_redistribution->layout_transfer().to_in();
551 Shape to_shape_in_layout = to_layout.slice_shape().array();
552 MS_LOG(INFO) << "shape_vec=" << shape_vec << ", to_shape_in_layout=" << to_shape_in_layout;
553 if (user_specific_dynamic_dim_cnt == 1 && shape_vec.size() == to_shape_in_layout.size()) {
554 size_t dyn_index = static_cast<size_t>(std::find(shape_vec.begin(), shape_vec.end(), -1) - shape_vec.begin());
555 for (size_t i = 0; i < shape_vec.size(); ++i) {
556 if (i != dyn_index && shape_vec[i] != to_shape_in_layout[i]) {
557 return true;
558 }
559 }
560 MS_LOG(INFO) << "No need to do matching for shape: " << shape_vec << ", to_shape_in_layout: " << to_shape_in_layout;
561 return false;
562 }
563 return true;
564 }
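// Intuition (hypothetical shapes): if the user wrote exactly one -1 and every other dim already equals
// the corresponding dim of to_layout's slice shape, e.g. shape_vec = (-1, 128) and
// to_shape_in_layout = (64, 128), the -1 can simply be inferred by Reshape and no matching is needed;
// any other mismatch (or a different rank) keeps the matching path.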
565
566 inline bool HasOnlyOneDynamicAxis(const Shape &shape_vec,
567 const TensorRedistributionPtr &tensor_redistribution_from_cnode) {
568 Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
569 Shape origin_to_no_assembled_slice = tensor_redistribution_from_cnode->to_origin_no_assembled().slice_shape().array();
570 bool has_only_one_dynamic_axis = std::count(origin_to_no_assembled.begin(), origin_to_no_assembled.end(), -1) == 1;
571 MS_LOG(INFO) << "shape_vec: " << shape_vec << ", origin_to_no_assembled: " << origin_to_no_assembled
572 << ", origin_to_no_assembled_slice: " << origin_to_no_assembled_slice;
573 return (origin_to_no_assembled.size() == shape_vec.size()) && has_only_one_dynamic_axis;
574 }
575
576 void ReplaceDynamicAxisToNegOne(const TensorRedistributionPtr &tensor_redistribution_from_cnode, Shape *shape_vec) {
577 Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
578 for (size_t i = 0; i < origin_to_no_assembled.size(); ++i) {
579 if (origin_to_no_assembled[i] == -1) {
580 (*shape_vec)[i] = -1;
581 }
582 }
583 }
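// Example (hypothetical values): with to_origin_no_assembled = (-1, 4096) and *shape_vec = (1117, 4096),
// the assembled fake dim at position 0 is rewritten to -1, giving (-1, 4096), which Reshape can infer
// at runtime.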
584
585 Status ConvertReshapeInputs(const OperatorParams &params,
586 const TensorRedistributionPtr &tensor_redistribution_from_cnode,
587 const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
588 Param shape_param;
589 bool use_origin_shape = false;
590 ReshapeMode reshape_mode = ReshapeMode::NO_RESHAPE;
591 for (auto &param : params) {
592 if (param.first.first == SHAPE) {
593 shape_param = param;
594 continue;
595 }
596 if (param.first.first == USE_ORIGIN_SHAPE) {
597 use_origin_shape = GetValue<bool>(param.first.second);
598 MS_LOG(INFO) << "Has USE_ORIGIN_SHAPE = " << use_origin_shape;
599 continue;
600 }
601 if (param.first.first == REDISTRIBUTION_RESHAPE_MODE) {
602 reshape_mode = static_cast<ReshapeMode>(GetValue<int64_t>(param.first.second));
603 MS_LOG(INFO) << "Has REDISTRIBUTION_RESHAPE_MODE = " << reshape_mode;
604 continue;
605 }
606 }
607 Shape shape_vec = GetValue<Shape>(shape_param.first.second);
608 if (shape_vec.size() == 1) {
609 std::vector<int64_t> const_shape{-1};
610 AnfNodePtr val = NewValueNode(const_shape);
611 (void)new_node_input->emplace_back(val);
612 return SUCCESS;
613 }
614 if (use_origin_shape && tensor_redistribution_from_cnode->original_reshape_shape() != nullptr) {
615 // Only reshape in user's code should be in this branch.
616 // original_reshape_shape could be ValueNode, MakeTuple, Shape.
617 (void)new_node_input->emplace_back(tensor_redistribution_from_cnode->original_reshape_shape());
618 return SUCCESS;
619 }
620 size_t dynamic_axis_cnt = std::count(shape_vec.begin(), shape_vec.end(), -1);
621 if (shape_vec.size() > 1 && dynamic_axis_cnt >= SIZE_TWO) {
622 MS_LOG(WARNING) << "The shape of Reshape op has more than one -1, which is not supported for now.";
623 }
624 Shape origin_to_no_assembled = tensor_redistribution_from_cnode->to_origin_no_assembled().tensor_shape().array();
625 Shape origin_to_no_assembled_slice = tensor_redistribution_from_cnode->to_origin_no_assembled().slice_shape().array();
626 MS_LOG(INFO) << "shape_vec: " << shape_vec << ", reshape_mode: " << reshape_mode
627 << ", origin_to_no_assembled: " << origin_to_no_assembled
628 << ", origin_to_no_assembled_slice: " << origin_to_no_assembled_slice;
629 // if only has one dynamic axis, then replace it with -1 simply.
630 if (reshape_mode == ReshapeMode::NO_RESHAPE && HasOnlyOneDynamicAxis(shape_vec, tensor_redistribution_from_cnode)) {
631 // After the HasOnlyOneDynamicAxis check, shape_vec has exactly one dynamic axis and it must be the prime axis.
632 Shape new_shape_vec(shape_vec);
633 ReplaceDynamicAxisToNegOne(tensor_redistribution_from_cnode, &new_shape_vec);
634 MS_LOG(INFO) << "Replace shape: " << shape_vec << " to new_shape_vec: " << new_shape_vec;
635 AnfNodePtr val = NewValueNode(new_shape_vec);
636 (void)new_node_input->emplace_back(val);
637 return SUCCESS;
638 }
639 if (!WhetherMatchingIsNeededForReshape(shape_vec, tensor_redistribution_from_cnode)) {
640 MS_LOG(INFO) << "No need to do matching for " << shape_vec;
641 AnfNodePtr val = NewValueNode(shape_param.first.second);
642 val->set_abstract(shape_param.first.second->ToAbstract());
643 (void)new_node_input->emplace_back(val);
644 return SUCCESS;
645 }
646 auto dynamic_input =
647 ConvertConstParamToDynamic(tensor_redistribution_from_cnode, shape_param, func_graph, true, reshape_mode);
648 MS_ERROR_IF_NULL_W_RET_VAL(dynamic_input, FAILED);
649 (void)new_node_input->emplace_back(dynamic_input);
650 return SUCCESS;
651 }
652
653 Status ConvertSplitInputs(const OperatorParams &params, const FuncGraphPtr &func_graph,
654 std::vector<AnfNodePtr> *new_node_input) {
655 MS_EXCEPTION_IF_CHECK_FAIL(new_node_input->size() == 1,
656 "new_node_input must contain exactly one element: the input of Split.");
657 auto split_target = (*new_node_input)[0];
658 std::vector<AnfNodePtr> split_inputs = {split_target};
659 ValuePtr output_index;
660 for (auto &param : params) {
661 if (param.first.first == SPLIT_OUTPUT_INDEX) {
662 output_index = param.first.second;
663 continue;
664 }
665 AnfNodePtr val = NewValueNode(param.first.second);
666 MS_EXCEPTION_IF_NULL(val);
667 val->set_abstract(param.first.second->ToAbstract());
668 (void)split_inputs.emplace_back(val);
669 }
670 constexpr char tag[] = "redistribution_allsplit";
671 auto split_op = CreateSplit(split_inputs, func_graph, tag);
672 auto split_output_index = NewValueNode(output_index);
673 auto tuple_get_item_prim = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
674 auto prim_value_node = NewValueNode(tuple_get_item_prim);
675 tuple_get_item_prim->set_instance_name(tag);
676 new_node_input->resize(SIZE_THREE);
677 (*new_node_input)[INDEX_ZERO] = prim_value_node;
678 (*new_node_input)[INDEX_ONE] = split_op;
679 (*new_node_input)[INDEX_TWO] = split_output_index;
680 return SUCCESS;
681 }
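// Resulting structure written back into *new_node_input (a sketch, node names illustrative):
//   split  = Split(split_target, split_axis, split_size)      // built by CreateSplit above
//   result = { TupleGetItem primitive, split, output_index }
// The caller then builds the final TupleGetItem CNode from these three inputs, so only the requested
// output of the Split is kept.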
682
683 bool IsToBeInsertedSplitOp(const Operator &op) {
684 // if split op has attr SPLIT_INSERT_LATER, then skip it in OptimizeTensorRedistributionOperatorList stage,
685 // and insert it in CreateInputs
686 if (op.first != SPLIT) {
687 return false;
688 }
689 OperatorAttrs op_attrs = op.second.first;
690 auto is_skip_func = [](const Attr &attr) -> bool {
691 return attr.first == SPLIT_INSERT_LATER && GetValue<bool>(attr.second);
692 };
693 return std::any_of(op_attrs.begin(), op_attrs.end(), is_skip_func);
694 }
695
696 Status ConvertParamsToInputs(const Operator &op, const TensorRedistributionPtr &tensor_redistribution_from_cnode,
697 const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *new_node_input) {
698 MS_ERROR_IF_NULL_W_RET_VAL(tensor_redistribution_from_cnode, FAILED);
699 MS_EXCEPTION_IF_NULL(func_graph);
700 OperatorArgs arg_forward = op.second;
701 OperatorParams params = arg_forward.second;
702
703 if (op.first == RESHAPE) {
704 if (ConvertReshapeInputs(params, tensor_redistribution_from_cnode, func_graph, new_node_input) != SUCCESS) {
705 return FAILED;
706 }
707 } else if (op.first == STRIDEDSLICE) {
708 if (ConvertStridedSliceInputs(params, tensor_redistribution_from_cnode, func_graph, new_node_input) != SUCCESS) {
709 return FAILED;
710 }
711 } else if (IsToBeInsertedSplitOp(op)) {
712 if (ConvertSplitInputs(params, func_graph, new_node_input) != SUCCESS) {
713 return FAILED;
714 }
715 } else {
716 MS_LOG(DEBUG) << op.first << " is not supported.";
717 return FAILED;
718 }
719 return SUCCESS;
720 }
721
722 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &pre_node, const std::string &instance_name,
723 const CNodePtr &cur_cnode) {
724 MS_EXCEPTION_IF_NULL(pre_node);
725 OperatorArgs arg_forward = op.second;
726 OperatorParams params = arg_forward.second;
727
728 std::vector<AnfNodePtr> new_node_input = {pre_node};
729 MS_LOG(INFO) << "CreateInput param.empty=" << params.empty() << ", pre_node=" << pre_node->fullname_with_scope()
730 << ", op=" << op.first;
731 bool is_done = false;
732 if (cur_cnode != nullptr) {
733 TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(cur_cnode);
734 // 1. Only deal with Reshape in user scripts.
735 // 2. Deal with non-user Reshape. If there is only StrideSliceD, Concat and Split cannot reach here.
736 if (tensor_redistribution != nullptr && tensor_redistribution->IsAssembledStaticShape()) {
737 MS_LOG(DEBUG) << cur_cnode->fullname_with_scope() << " distribute_operator is not nullptr";
738 if (ConvertParamsToInputs(op, tensor_redistribution, cur_cnode->func_graph(), &new_node_input) == SUCCESS) {
739 is_done = true;
740 } else {
741 MS_LOG(DEBUG) << "Convert params to inputs failed.";
742 }
743 } else {
744 MS_LOG(INFO) << "cur_cnode=" << cur_cnode->fullname_with_scope() << " is not a dynamic node.";
745 }
746 }
747
748 if (IsToBeInsertedSplitOp(op) && !is_done && cur_cnode != nullptr) {
749 // it means Split on static shape scene.
750 auto ret = ConvertSplitInputs(params, cur_cnode->func_graph(), &new_node_input);
751 MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "Insert split op failed.");
752 is_done = true;
753 }
754
755 if (!is_done && !params.empty()) {
756 for (const auto &param : params) {
757 AnfNodePtr val = NewValueNode(param.first.second);
758 MS_EXCEPTION_IF_NULL(val);
759 val->set_abstract(param.first.second->ToAbstract());
760 int64_t position = param.second;
761 (void)new_node_input.insert(new_node_input.cbegin() + position - 1, val);
762 }
763 }
764
765 if (!IsToBeInsertedSplitOp(op)) {
766 new_node_input = ConvertToRealInputs(op.first, instance_name, new_node_input, arg_forward.first);
767 }
768 // if the op have 'group' attr, set the rank list name for the op
769 SetCommunicationOpGroupLabel(new_node_input);
770 return new_node_input;
771 }
772
773 std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
774 const CNodePtr &node) {
775 MS_EXCEPTION_IF_NULL(node);
776 MS_EXCEPTION_IF_NULL(node->func_graph());
777 OperatorArgs arg_replace_op = replace_op.second;
778 OperatorParams params = arg_replace_op.second;
779 if (node->size() < SIZE_TWO) {
780 // GetNext operator does not have inputs
781 if (node->size() == 1) {
782 return ConvertToRealInputs(replace_op.first, instance_name, AnfNodePtrList{}, arg_replace_op.first);
783 }
784 MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
785 }
786 std::vector<AnfNodePtr> replace_input = {node->input(1)};
787
788 if (replace_op.first == EMBEDDING_LOOKUP) {
789 replace_input = {node->input(1), node->input(2)};
790 }
791 if (!params.empty() && replace_op.first != SYNC_BATCH_NORM) {
792 Param param_first = *(params.begin());
793 int64_t first_position = param_first.second;
794 if (first_position == 1) {
795 replace_input.pop_back();
796 }
797 }
798 bool is_done = false;
799 bool to_be_converted = replace_op.first == SPLIT || replace_op.first == STRIDEDSLICE || replace_op.first == RESHAPE;
800 if (!params.empty() && to_be_converted && IsDynamicOp(node)) {
801 TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(node);
802 auto ret = ConvertParamsToInputs(replace_op, tensor_redistribution, node->func_graph(), &replace_input);
803 MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "ConvertStridedSliceInputs failed.");
804 is_done = true;
805 } else if (!params.empty() && !IsToBeInsertedSplitOp(replace_op)) {
806 for (auto &param : params) {
807 AnfNodePtr val = NewValueNode(param.first.second);
808 if (val == nullptr) {
809 MS_LOG(EXCEPTION) << "Failure: val is nullptr";
810 }
811 int64_t position = param.second;
812 (void)replace_input.insert(replace_input.cbegin() + position - 1, val);
813 }
814 } else if (replace_op.first == SYNC_BATCH_NORM) {
815 for (size_t i = 2; i < node->size(); ++i) {
816 replace_input.push_back(node->input(i));
817 }
818 }
819
820 if (!IsToBeInsertedSplitOp(replace_op)) {
821 replace_input = ConvertToRealInputs(replace_op.first, instance_name, replace_input, arg_replace_op.first);
822 } else if (IsToBeInsertedSplitOp(replace_op) && !is_done) {
823 // it means Split on static shape scene.
824 auto ret = ConvertSplitInputs(params, node->func_graph(), &replace_input);
825 MS_EXCEPTION_IF_CHECK_FAIL(ret == SUCCESS, "Insert split op failed.");
826 }
827 SetCommunicationOpGroupLabel(replace_input);
828 return replace_input;
829 }
830
831 CNodePtr InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node,
832 const FuncGraphPtr &func_graph, const std::string &instance_name, const std::string &param_name,
833 const FuncGraphPtr &root, const TensorRedistributionPtr &tensor_redistribution) {
834 // insert new node before the node
835 MS_EXCEPTION_IF_NULL(node);
836 MS_EXCEPTION_IF_NULL(func_graph);
837 FuncGraphManagerPtr manager = func_graph->manager();
838 MS_EXCEPTION_IF_NULL(manager);
839 ScopePtr scope = node->scope();
840 MS_EXCEPTION_IF_NULL(scope);
841 std::vector<AnfNodePtr> node_input;
842
843 if (root && !param_name.empty()) {
844 node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
845 } else {
846 node_input = CreateInput(op, pre_node, instance_name, node);
847 }
848
849 CNodePtr new_node = func_graph->NewCNode(node_input);
850 MS_EXCEPTION_IF_NULL(new_node);
851 if (instance_name.find(SPLIT_SENS) == std::string::npos) {
852 new_node->set_in_forward_flag(true); // mark forward flag
853 }
854 auto new_node_value = node_input[0]->cast<ValueNodePtr>();
855 MS_EXCEPTION_IF_NULL(new_node_value);
856 auto new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
857 new_node_prim->set_instance_name(instance_name);
858 new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
859 if (instance_name.find(NOT_RECOMPUTE) != std::string::npos) {
860 new_node_prim->set_attr("recompute", MakeValue(false));
861 } else if (instance_name.find(RECOMPUTE) != std::string::npos) {
862 new_node_prim->set_attr("recompute", MakeValue(true));
863 }
864
865 auto primitive = common::AnfAlgo::GetCNodePrimitive(new_node);
866 MS_EXCEPTION_IF_NULL(primitive);
867 if (node->HasPrimalAttr(SEGMENT)) {
868 primitive->AddAttr(SEGMENT, node->GetPrimalAttr(SEGMENT));
869 new_node->AddPrimalAttr(SEGMENT, node->GetPrimalAttr(SEGMENT));
870 }
871 if (node->HasPrimalAttr(MICRO)) {
872 new_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
873 }
874 new_node->set_scope(scope);
875 node_input[0]->set_scope(scope);
876 if (instance_name.find(REDISTRIBUTION_OP) != std::string::npos) {
877 new_node->AddPrimalAttr(kPrimalAttrForwardCommNodeUniqueId, MakeValue<std::string>(new_node->UniqueId()));
878 if (node->HasPrimalAttr(MICRO)) {
879 new_node->AddPrimalAttr(MICRO, node->GetPrimalAttr(MICRO));
880 }
881 }
882 manager->SetEdge(node, SizeToInt(index), new_node);
883 MS_LOG(INFO) << "Insert " << instance_name << " success";
884 return new_node;
885 }
886
887 bool IsRootNode(const CNodePtr &cnode, const AnfNodePtr &root_node) {
888 // cnode is TupleGetItem.
889 // Check whether the first input of the TupleGetItem is a Shape op whose first input is the root node.
890 // Sometimes the Reshape's first input may not be the same as the Shape's first input, so only a warning is logged in that case.
891 auto first_input_of_tuple_getitem = cnode->input(1)->cast<CNodePtr>();
892 if (!IsTargetOp(first_input_of_tuple_getitem, SHAPE_OP)) {
893 return false;
894 }
895 auto first_input_of_shape = first_input_of_tuple_getitem->input(1);
896 if (first_input_of_shape == root_node) {
897 return true;
898 } else {
899 MS_LOG(WARNING) << "Shape's first input is not the same as the root node.";
900 }
901 return true;
902 }
903
904 std::pair<CNodePtr, int64_t> FindPreviousNodeAndSkipTupleGetItem(const CNodePtr &current, int32_t depth = 0) {
905 // current is TupleGetItem
906 if (depth == MAX_RECURSIVE_DEPTH) {
907 return {nullptr, -1};
908 }
909 auto prev = current->input(1);
910 auto cnode = prev->cast<CNodePtr>();
911 if (IsTupleGetItem(cnode)) {
912 return FindPreviousNodeAndSkipTupleGetItem(cnode, depth + 1);
913 }
914 int64_t index = GetTupleGetItemIndex(current);
915 return {cnode, index};
916 }
917
918 bool ModifyGraph(const CNodePtr &current_cnode, const CNodePtr &previous_tuple_getitem_cnode, size_t input_index) {
919 /**
920 * This function must be called after IsRootNode() has been called and returned true.
921 *
922 * TupleGetItem(tensor, index)
923 * ->
924 * ScalarMul(scalar)
925 * ->
926 * current_cnode
927 */
928 int64_t index = GetTupleGetItemIndex(previous_tuple_getitem_cnode);
929 auto root_node = previous_tuple_getitem_cnode->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>();
930 if (IsTupleGetItem(root_node)) {
931 // keep search the previous node.
932 auto output = FindPreviousNodeAndSkipTupleGetItem(root_node);
933 root_node = output.first;
934 }
935 // Get tensor layout from root_node.
936 if (!root_node->has_user_data<OperatorInfo>()) {
937 // Default/TupleGetItem-op0 has no operator info.
938 MS_LOG(INFO) << root_node->fullname_with_scope() << " has no operator info.";
939 return true;
940 }
941 OperatorInfoPtr distribute_operator = GetDistributeOperator(root_node);
942 MS_EXCEPTION_IF_NULL(distribute_operator);
943 std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
944 if (root_tensor_info.size() != 1) {
945 MS_LOG(ERROR) << "Outputs number cannot be larger than 1.";
946 return false;
947 }
948 TensorInfo tensor_info = root_tensor_info[0];
949 Map tensor_map = tensor_info.tensor_layout().tensor_map();
950 Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
951 if (LongToSize(index) >= tensor_map.GetDimSize()) {
952 MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
953 return false;
954 }
955 int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(index));
956 // Create ValueNode for scalar->Create Mul Cnode->Modify inputs and edges
957 Operator scalar_mul_op = CreateScalarMulOp(scalar);
958 InsertNode(scalar_mul_op, // to be inserted op
959 current_cnode, // current node
960 input_index, // input index of current_node
961 previous_tuple_getitem_cnode, // insert scalar_mul_op between previous and current
962 current_cnode->func_graph(), // current func_graph
963 "instance_name", "", nullptr);
964 MS_LOG(DEBUG) << tensor_info.tensor_layout().ToString() << ", " << previous_tuple_getitem_cnode->fullname_with_scope()
965 << " index: " << index << ", scalar: " << scalar;
966 return true;
967 }
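// Example of the inserted scaling (hypothetical layout): with device_arrangement = (2, 4) and
// tensor_map = (1, 0), the TupleGetItem reads a sliced dim, so for index 0 the inserted ScalarMul
// multiplies it by dev_arr.GetDimByReverseIdx(1) == 2, the shard factor of that tensor dimension.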
968
969 Status UpdateShapeToRootPath(const CNodePtr &cnode, const AnfNodePtr &root_node, int32_t depth = 0) {
970 if (depth == MAX_RECURSIVE_DEPTH) {
971 return REACH_MAX_RECURSIVE_DEPTH;
972 }
973 auto value_node = cnode->input(0)->cast<ValueNodePtr>();
974 auto prim = value_node->value()->cast<PrimitivePtr>();
975 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
976 auto input = cnode->input(i)->cast<CNodePtr>();
977 if (input == nullptr) {
978 continue;
979 }
980 if (IsTupleGetItem(input) && IsRootNode(input, root_node)) {
981 // Modify this graph path.
982 if (!ModifyGraph(cnode, input, i)) {
983 MS_LOG(ERROR) << "Failed to modify graph.";
984 return Status::FAILED;
985 }
986 return Status::SUCCESS;
987 }
988 // Keep traceback.
989 Status ret = UpdateShapeToRootPath(input, root_node, depth + 1);
990 if (ret != Status::SUCCESS) {
991 return Status::FAILED;
992 }
993 }
994 return Status::SUCCESS;
995 }
996
997 Status UpdatePartialShape(const CNodePtr &cnode) {
998 // Traceback shape_of_reshape input of Reshape Op.
999 MS_EXCEPTION_IF_NULL(cnode);
1000 MS_EXCEPTION_IF_CHECK_FAIL(cnode->inputs().size() == RESHAPE_INPUT_SIZE,
1001 "Reshape op must have " + std::to_string(RESHAPE_INPUT_SIZE) + " inputs.");
1002 // Step1. Get second input of Reshape op which represent shape_of_reshape.
1003 // Step2. Visit shape_of_reshape and trace back to dynamic axis.
1004 auto input_of_reshape = cnode->input(RESHAPE_INPUT_SIZE - 2);
1005 auto shape_of_reshape = cnode->input(RESHAPE_INPUT_SIZE - 1);
1006 auto shape_cnode = shape_of_reshape->cast<CNodePtr>(); // MakeTuple
1007 if (shape_cnode == nullptr) {
1008 return Status::SUCCESS;
1009 }
1010 for (const auto &input : shape_cnode->inputs()) {
1011 auto cnode_input = input->cast<CNodePtr>();
1012 if (cnode_input == nullptr) {
1013 continue;
1014 }
1015 if (UpdateShapeToRootPath(cnode_input, input_of_reshape) != Status::SUCCESS) {
1016 MS_LOG(ERROR) << "Update " << cnode->fullname_with_scope() << " previous shape failed.";
1017 return Status::FAILED;
1018 }
1019 }
1020 return Status::SUCCESS;
1021 }
1022
1023 CNodePtr FindPreviousCareNode(const CNodePtr &current, int32_t depth = 0) {
1024 if (depth == MAX_RECURSIVE_DEPTH) {
1025 return nullptr;
1026 }
1027 auto prev = current->input(1);
1028 // If prev is a Parameter, there may be a problem here.
1029 auto cnode = prev->cast<CNodePtr>();
1030 if (cnode == nullptr) {
1031 MS_LOG(INFO) << "Input of node is not a cnode: " << prev->fullname_with_scope();
1032 return nullptr;
1033 }
1034 if (!IsParallelCareNode(cnode) && (IsTargetOp(cnode, "Cast") || IsTupleGetItem(cnode))) {
1035 return FindPreviousCareNode(cnode, depth + 1);
1036 }
1037 return cnode;
1038 }
1039
1040 Status GetDistributeOperatorFromCNode(const CNodePtr &cnode, TensorInfo *tensor_info) {
1041 MS_EXCEPTION_IF_NULL(cnode);
1042 CNodePtr target_cnode = cnode;
1043 if (!IsParallelCareNode(cnode)) {
1044 // keep search the previous node.
1045 target_cnode = FindPreviousCareNode(cnode);
1046 }
1047 if (target_cnode == nullptr) {
1048 return Status::FAILED;
1049 }
1050 if (!target_cnode->has_user_data<OperatorInfo>()) {
1051 MS_LOG(EXCEPTION) << "Found " << cnode->fullname_with_scope() << " previous node is "
1052 << target_cnode->fullname_with_scope() << " and it has no operator info.";
1053 }
1054
1055 OperatorInfoPtr distribute_operator = GetDistributeOperator(target_cnode);
1056 MS_EXCEPTION_IF_NULL(distribute_operator);
1057 std::vector<TensorInfo> root_tensor_info = distribute_operator->outputs_tensor_info();
1058 if (root_tensor_info.size() != 1) {
1059 if (IsTupleGetItem(cnode)) {
1060 int64_t output_index = GetTupleGetItemIndex(cnode);
1061 MS_EXCEPTION_IF_CHECK_FAIL(
1062 (output_index >= 0 && output_index < SizeToLong(root_tensor_info.size())),
1063 "TupleGetItem index is not matched with its input length, TupleGetItem is " + cnode->fullname_with_scope());
1064 MS_LOG(INFO) << "Replace tensor info use " << target_cnode->fullname_with_scope() << " with index "
1065 << output_index;
1066 (*tensor_info) = root_tensor_info[output_index];
1067 return Status::SUCCESS;
1068 }
1069 MS_LOG(WARNING) << "Outputs number cannot be larger than 1, but " << target_cnode->fullname_with_scope() << " has "
1070 << root_tensor_info.size() << " outputs.";
1071 }
1072 (*tensor_info) = root_tensor_info[0];
1073 return Status::SUCCESS;
1074 }
1075
1076 Status UpdateTupleGetItemShapeValue(const CNodePtr &tuple_getitem, const TensorInfo &tensor_info,
1077 const FuncGraphPtr &func_graph) {
1078 MS_LOG(INFO) << "into UpdateTupleGetItemShapeValue";
1079 Map tensor_map = tensor_info.tensor_layout().tensor_map();
1080 Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1081 auto manager = func_graph->manager();
1082 MS_EXCEPTION_IF_NULL(manager);
1083 auto node_users_map = manager->node_users();
1084
1085 int64_t index = GetTupleGetItemIndex(tuple_getitem);
1086 if (LongToSize(index) >= tensor_map.GetDimSize()) {
1087 MS_LOG(ERROR) << "Index cannot be larger than tensor_map size.";
1088 return Status::FAILED;
1089 }
1090 if (tensor_map.GetDimByIdx(index) < 0) {
1091 MS_LOG(DEBUG) << "Skip index " << index << ", because it's " << tensor_map.GetDimByIdx(index);
1092 return Status::SUCCESS;
1093 }
1094 int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(index));
1095 for (const auto &next_node : node_users_map[tuple_getitem]) {
1096 auto tuple_getitem_user = next_node.first->cast<CNodePtr>();
1097 if (tuple_getitem_user == nullptr) {
1098 MS_LOG(DEBUG) << "tuple_getitem_user is nullptr";
1099 continue;
1100 }
1101 MS_LOG(INFO) << tuple_getitem->input(1)->fullname_with_scope() << "->" << tuple_getitem->fullname_with_scope()
1102 << "->ScalarMul(" << scalar << ")->" << next_node.first->fullname_with_scope() << "["
1103 << next_node.second << "]" << std::endl;
1104 Operator scalar_mul_op = CreateScalarMulOp(scalar);
1105 (void)InsertNode(scalar_mul_op, // to be inserted op
1106 tuple_getitem_user, // current node
1107 next_node.second, // tuple_getitem_user[input_index] = scalar_mul_op
1108 tuple_getitem, // insert scalar_mul_op between previous and current
1109 tuple_getitem_user->func_graph(), // current func_graph
1110 "update_partial_shape", "", nullptr);
1111 }
1112 return Status::SUCCESS;
1113 }
1114
1115 Status UpdateReshapeShapeValue(const CNodePtr &reshape_cnode, const CNodePtr &shape_cnode, const Shape &shape,
1116 const TensorInfo &tensor_info, const FuncGraphPtr &func_graph) {
1117 // Replace shape to MakeTuple(shape[0]*factor0, shape[1]*factor1,...)
1118 MS_LOG(INFO) << "into UpdateReshapeShapeValue: " << shape;
1119 MS_EXCEPTION_IF_NULL(reshape_cnode);
1120 MS_EXCEPTION_IF_NULL(shape_cnode);
1121 MS_EXCEPTION_IF_NULL(func_graph);
1122 FuncGraphManagerPtr manager = func_graph->manager();
1123 MS_EXCEPTION_IF_NULL(manager);
1124 Map tensor_map = tensor_info.tensor_layout().tensor_map();
1125 Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1126 TensorRedistributionPtr tensor_redistribution = GetTensorRedistributionFromCNode(reshape_cnode);
1127
1128 std::vector<AnfNodePtr> make_tuple_inputs;
1129 std::string instance_name = std::string(REDISTRIBUTION_OP) + "_replace_reshape";
1130 for (size_t i = 0; i < shape.size(); ++i) {
1131 if (shape[i] > 0) {
1132 // Get const value and set to make_tuple_inputs.
1133 auto const_val_node = NewValueNode(MakeValue(shape[i]));
1134 make_tuple_inputs.emplace_back(const_val_node);
1135 MS_LOG(INFO) << "Create ValueNode " << shape[i];
1136 continue;
1137 }
1138 // Get shape from shape node.
1139 auto prim_tuple_get_item = std::make_shared<Primitive>(TUPLE_GETITEM_OP);
1140 AnfNodePtrList inputs{NewValueNode(prim_tuple_get_item), shape_cnode, NewValueNode(MakeValue(SizeToLong(i)))};
1141 auto tuple_get_item_cnode = func_graph->NewCNode(inputs);
1142 tuple_get_item_cnode->set_fullname_with_scope("tuple_getitem_replace_reshape");
1143 prim_tuple_get_item->set_instance_name(instance_name);
1144 make_tuple_inputs.emplace_back(tuple_get_item_cnode);
1145 MS_LOG(INFO) << "Create TupleGetItem for " << i;
1146 }
1147 auto make_tuple = CreateMakeTuple(make_tuple_inputs, func_graph, instance_name);
1148 make_tuple->set_in_forward_flag(true);
1149 std::string fullname = shape_cnode->fullname_with_scope() + "_replace";
1150 make_tuple->set_fullname_with_scope(fullname);
1151 manager->SetEdge(reshape_cnode, INDEX_TWO, make_tuple);
1152 MS_LOG(INFO) << shape_cnode->fullname_with_scope() << "->" << make_tuple->fullname_with_scope() << "->"
1153 << reshape_cnode->fullname_with_scope();
1154 MS_LOG(INFO) << "reshape shape is : " << shape;
1155 MS_LOG(INFO) << "reshape tensor_map is : " << tensor_map.array();
1156 MS_LOG(INFO) << "reshape dev_arr is : " << dev_arr.array();
1157 for (size_t i = 0; i < tensor_map.array().size(); ++i) {
1158 if (tensor_map.GetDimByIdx(i) == -1) {
1159 continue;
1160 }
1161 if (make_tuple_inputs[i]->isa<ValueNode>()) {
1162 continue;
1163 }
1164 int64_t scalar = dev_arr.GetDimByReverseIdx(tensor_map.GetDimByIdx(i));
1165 Operator scalar_mul_op = CreateScalarMulOp(scalar);
1166 (void)InsertNode(scalar_mul_op, // to be inserted op
1167 make_tuple, // current node
1168 i + 1, // make_tuple[input_index] = scalar_mul_op
1169 make_tuple->input(i + 1), // insert scalar_mul_op between previous and current
1170 func_graph, // current func_graph
1171 "update_partial_shape", "", nullptr);
1172 }
1173 if (tensor_redistribution != nullptr && tensor_redistribution->original_reshape_shape() != nullptr) {
1174 tensor_redistribution->set_original_reshape_shape(make_tuple);
1175 MS_LOG(INFO) << "Change original_reshape_shape";
1176 }
1177 return Status::SUCCESS;
1178 }
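// Net effect (a sketch with hypothetical layout values): for shape = (-1, 4096),
// device_arrangement = (2, 4) and tensor_map = (1, -1), the Reshape's shape operand is replaced by
//   MakeTuple(ScalarMul(TupleGetItem(Shape(x), 0), 2), 4096)
// i.e. the dynamic dim read from the Shape node is scaled by its shard factor while the static dim
// stays a constant.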
1179
1180 bool SkipSupplyForReshape(const CNodePtr &cnode) {
1181 if (!IsReshapeOp(cnode)) {
1182 return false;
1183 }
1184 auto prim = GetCNodePrimitive(cnode);
1185 if (prim->HasAttr(SKIP_REDISTRIBUTION)) {
1186 bool skip_redistribution = GetValue<bool>(prim->GetAttr(SKIP_REDISTRIBUTION));
1187 return skip_redistribution;
1188 }
1189 return false;
1190 }
1191
1192 Status UpdateShapeNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph) {
1193 MS_EXCEPTION_IF_NULL(cnode);
1194 // Step1. Get shape input tensor layout. cnode is Shape op.
1195 auto input_of_shape = cnode->input(1);
1196 auto input_cnode = input_of_shape->cast<CNodePtr>();
1197 if (input_cnode == nullptr) {
1198 return Status::SUCCESS;
1199 }
1200 if (SkipSupplyForReshape(input_cnode)) {
1201 MS_LOG(INFO) << "Skip " << cnode->fullname_with_scope() << ", because its input is reshape.";
1202 return Status::SUCCESS;
1203 }
1204 if (IsValueNode<FuncGraph>(input_cnode->input(0))) {
1205 // It means it's a sub-graph call node.
1206 MS_LOG(WARNING) << "If the input of shape is subgraph, and it's outputs sharding strategy "
1207 "is not all 1, it could be problem.";
1208 return Status::SUCCESS;
1209 }
1210 TensorInfo tensor_info;
1211 if (GetDistributeOperatorFromCNode(input_cnode, &tensor_info) != Status::SUCCESS) {
1212 return Status::SUCCESS;
1213 }
1214 Map tensor_map = tensor_info.tensor_layout().tensor_map();
1215 Arrangement dev_arr = tensor_info.tensor_layout().device_arrangement();
1216
1217 // Step2. Get shape node users.
1218 auto node_users_map = func_graph->manager()->node_users();
1219 auto shape_node_users = node_users_map[cnode];
1220 for (const auto &node_user : shape_node_users) {
1221 MS_EXCEPTION_IF_NULL(node_user.first);
1222 auto shape_user = node_user.first->cast<CNodePtr>();
1223 if (IsReshapeOp(shape_user)) {
1224 std::vector<Shape> input_shapes = GetNodeShape(input_of_shape);
1225 if (input_shapes.size() != 1) {
1226 MS_LOG(EXCEPTION) << "Shape's input size is illegal.";
1227 }
1228 if (UpdateReshapeShapeValue(shape_user, cnode, input_shapes[0], tensor_info, func_graph) != Status::SUCCESS) {
1229 MS_LOG(EXCEPTION) << "Update reshape shape value failed.";
1230 }
1231 continue;
1232 }
1233 if (shape_user == nullptr || IsTargetOp(shape_user, ZEROS)) {
1234 MS_LOG(ERROR) << "won't supply shape for " << shape_user->fullname_with_scope();
1235 continue;
1236 }
1237 MS_EXCEPTION_IF_CHECK_FAIL(IsTupleGetItem(shape_user),
1238 "Only support TupleGetItem here, but got " + GetPrimName(shape_user));
1239 if (IsTupleGetItem(shape_user) &&
1240 UpdateTupleGetItemShapeValue(shape_user, tensor_info, func_graph) != Status::SUCCESS) {
1241 MS_LOG(EXCEPTION) << "Update tuple get item shape value failed.";
1242 }
1243 }
1244 return Status::SUCCESS;
1245 }
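// For every MakeTuple input whose dimension index appears in factor_mapping, insert ScalarDiv (by the device
// factor) followed by ScalarCast(kInt64), turning an entire-shape dimension into the local, per-device one.
// Illustrative example (hypothetical values): with factor_mapping = {{0, 4}} and a MakeTuple carrying the full
// shape (8, 1024), dim 0 becomes 8 / 4 = 2, so the Reshape fed by this tuple produces the local shape (2, 1024).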
Status UpdateMakeTupleShapeValue(const CNodePtr &make_tuple, const std::map<size_t, int64_t> &factor_mapping,
                                 const FuncGraphPtr &func_graph) {
  for (size_t i = 1; i < make_tuple->inputs().size(); ++i) {
    if (factor_mapping.find(i - 1) == factor_mapping.end()) {
      continue;
    }
    auto make_tuple_input = make_tuple->input(i);
    if (make_tuple_input->isa<ValueNode>()) {
      auto val_node = make_tuple_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(val_node->value());
      auto dim_value = GetValue<int64_t>(val_node->value());
      if (dim_value == -1) {
        continue;
      }
    }
    Operator scalar_div_op = CreateScalarDivOp(factor_mapping.at(i - 1));
    // TODO(liuchongming): If make_tuple_input is a Mul op, consider merging the two ops.
    auto div_cnode = InsertNode(scalar_div_op,     // op to be inserted
                                make_tuple,        // current node
                                i,                 // make_tuple->input(i) = scalar_div_op
                                make_tuple_input,  // insert scalar_div_op between this input and make_tuple
                                func_graph,        // current func_graph
                                "segment_partial_shape", "", nullptr);
    Operator cast_op = CreateScalarCastOp(kInt64);
    (void)InsertNode(cast_op,     // op to be inserted
                     make_tuple,  // current node
                     i,           // make_tuple->input(i) = cast_op
                     div_cnode,   // make_tuple_input->scalar_div_op->scalar_cast_op->make_tuple
                     func_graph,  // current func_graph
                     "segment_partial_shape", "", nullptr);
  }
  return Status::SUCCESS;
}
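// For a dynamic-shape Reshape, build a dim-index -> device-factor map from the origin output layout and, when the
// shape operand is a MakeTuple, divide the sharded dimensions so the Reshape emits the partial (per-device) shape.
// Illustrative example (hypothetical values): tensor_map (0, -1) with device arrangement (2, 4) maps dim 0 to the
// reversed device dimension 0, i.e. factor 4, so factor_mapping = {{0, 4}} and only dim 0 of the MakeTuple is
// rewritten.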
Status SegmentEntireShapeToPartialForDynamic(const CNodePtr &reshape_node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(reshape_node);
  // reshape_node is a Reshape node.
  // Step1. Get the tensor layout required by reshape_node's user.
  // Step2. Shard reshape_node's second input (only when it is a MakeTuple).
  auto tensor_redistribution = GetTensorRedistributionFromCNode(reshape_node);
  if (tensor_redistribution == nullptr) {
    MS_LOG(WARNING) << "Cannot find layout in " << reshape_node->fullname_with_scope();
    return Status::FAILED;
  }
  if (!tensor_redistribution->is_dynamic_shape()) {
    MS_LOG(INFO) << reshape_node->fullname_with_scope() << " is static shape.";
    return Status::SUCCESS;
  }
  TensorLayout out_layout = tensor_redistribution->to_origin_no_assembled();
  auto tensor_map = out_layout.tensor_map();
  auto dev_mat = out_layout.device_arrangement();
  std::map<size_t, int64_t> factor_mapping;
  for (size_t i = 0; i < tensor_map.array().size(); ++i) {
    if (tensor_map.GetDimByIdx(i) != -1) {
      factor_mapping.insert({i, dev_mat.GetDimByReverseIdx(tensor_map.GetDimByIdx(i))});
    }
  }
  auto shape_input = reshape_node->input(INDEX_TWO);
  if (!shape_input->isa<CNode>()) {
    MS_LOG(DEBUG) << "Reshape's second input is not a CNode.";
    return Status::SUCCESS;
  }
  auto shape_input_cnode = shape_input->cast<CNodePtr>();
  if (IsTargetOp(shape_input_cnode, MAKE_TUPLE)) {
    if (UpdateMakeTupleShapeValue(shape_input_cnode, factor_mapping, func_graph) != Status::SUCCESS) {
      MS_LOG(ERROR) << "Update MakeTuple shape value failed for " << reshape_node->fullname_with_scope();
      return Status::FAILED;
    }
  }
  return Status::SUCCESS;
}
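// Entry point for dynamic-shape graphs: locate the forward graph(s), or fall back to every graph under the root
// manager, then visit each Shape CNode and call UpdateShapeNode so its users receive the entire (global) shape
// instead of the sharded, per-device one.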
Status MergeEntireShapeForDynamic(const FuncGraphPtr &root) {
  MS_LOG(INFO) << "Into MergeEntireShapeForDynamic";
  MS_EXCEPTION_IF_NULL(root);
  // Step1. Judge whether the graph is dynamic shape.
  // Step2. Find all Shape nodes and get their sharding factors.
  // Step3. Multiply the factors from Step2 into the Shape users (TupleGetItem).
  // Step4. Modify the nodes following TupleGetItem.
  auto ret_node = root->get_return();
  MS_EXCEPTION_IF_NULL(ret_node);
  auto all_nodes = DeepScopedGraphSearch(ret_node);
  std::reverse(all_nodes.begin(), all_nodes.end());
  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);

  if (graph_set.empty()) {
    MS_LOG(INFO) << "Cannot find the forward graph, so update the Shape ops in all graphs under the root manager.";
    auto fgs = root->manager()->func_graphs();
    for (auto fg = fgs.cbegin(); fg != fgs.cend(); ++fg) {
      // Traverse all nodes and find the Shape ops.
      auto fg_nodes_set = (*fg)->nodes();
      for (auto const &node : fg_nodes_set) {
        if (!node->isa<CNode>()) {
          continue;
        }
        auto cnode = node->cast<CNodePtr>();
        if (IsShapeOp(cnode)) {
          (void)UpdateShapeNode(cnode, *fg);
        }
      }
    }
  } else {
    MS_LOG(INFO) << "The number of sub-graphs used by root is " << root->func_graphs_used().size();
    for (auto func_graph = graph_set.cbegin(); func_graph != graph_set.cend(); ++func_graph) {
      auto return_node = (*func_graph)->get_return();
      MS_EXCEPTION_IF_NULL(return_node);
      std::vector<AnfNodePtr> all_dfs_nodes = DeepLinkedGraphSearch(return_node);
      for (const auto &node : all_dfs_nodes) {
        if (!node->isa<CNode>()) {
          continue;
        }
        auto cnode = node->cast<CNodePtr>();
        if (IsShapeOp(cnode)) {
          (void)UpdateShapeNode(cnode, *func_graph);
        }
      }
    }
  }
  return Status::SUCCESS;
}
}  // namespace mindspore::parallel