/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/reshape_info.h"

#include <memory>
#include <vector>
#include <utility>
#include <functional>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "utils/log_adapter.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"

namespace mindspore {
namespace parallel {
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

/*
 * Supports a parallel degree smaller than the device number by setting the duplicate device dimension to the first
 * dimension of the device matrix.
 * Only the batch-parallel reshape operator in ReID is supported (the batch parallel degree can be smaller than the
 * device number).
 */
Status ReshapeInfo::InferDevMatrixShape() {
  Strategies stra = strategy_->GetInputDim();
  input_strategy_ = stra.at(0);
  dev_matrix_shape_ = stra.at(0);
  return SUCCESS;
}

/*
 * There is no Parameter for the Reshape primitive, so there is no need to do allreduce.
 */
Status ReshapeInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = input_layout_.tensor_map().array();
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    ReportError(name_ + ": Create group failed.");
    return FAILED;
  }

  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  }
  OperatorVector op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
  MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << input_group[0].name();
  mirror_ops_.push_back(op_for_input);
  OperatorVector op_for_input_empty;
  mirror_ops_.push_back(op_for_input_empty);

  return SUCCESS;
}

/*
 * There is no reduction dimension in the forward computation of the Reshape primitive, so there is no need to do
 * allreduce.
 */
Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; }

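/*
 * Parse the destination shape from the shape input of Reshape: a ValueNode holds a constant tuple; a MakeTuple
 * cnode may mix constant and dynamic dimensions (dynamic ones are recorded as kShapeDimAny); a Shape op means the
 * destination shape comes from the stored input value.
 */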
std::vector<int64_t> ReshapeInfo::GetInputShape(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
  Shape origin_dst_shape;
  if (shape_input_node->isa<ValueNode>()) {
    auto shape_input_value_node = shape_input_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_value_node);
    auto shape_input_value = shape_input_value_node->value();
    MS_EXCEPTION_IF_NULL(shape_input_value);
    origin_dst_shape = GetValue<std::vector<int64_t>>(shape_input_value);
  } else if (IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    auto shape_input_cnode = shape_input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_cnode);
    for (size_t i = 1; i < shape_input_cnode->size(); ++i) {
      auto input_node = shape_input_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      if (input_node->isa<ValueNode>()) {
        auto input_value_node = input_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(input_value_node);
        origin_dst_shape.push_back(GetValue<int64_t>(input_value_node->value()));
      } else {
        // the dst shape is dynamic; if two or more dimensions are dynamic, additional processing is required
        origin_dst_shape.push_back(abstract::Shape::kShapeDimAny);
      }
    }
  } else if (IsPrimitiveCNode(shape_input_node, prim::kPrimShape)) {
    // dynamic shape: the dst shape is a Shape op
    MS_EXCEPTION_IF_NULL(input_value_[1]);
    origin_dst_shape = GetValue<std::vector<int64_t>>(input_value_[1]);
    MS_LOG(INFO) << name_ << ": the input value is a Shape op, dst shape is " << origin_dst_shape;
  } else {
    MS_LOG(EXCEPTION) << name_ << ": input shape must be either a Tuple, a MakeTuple cnode or a Shape op, but got "
                      << shape_input_node->fullname_with_scope();
  }
  return origin_dst_shape;
}

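/*
 * Helpers for a destination shape that contains exactly one dynamic (-1) dimension: detect that case, accumulate
 * the product of all dimensions, and locate the index of the dynamic dimension.
 */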
bool OnlyOneDimDynamicShape(const Shape &shape) {
  return (std::count(shape.cbegin(), shape.cend(), DYNAMIC_DIM_VAL) == 1);
}

int64_t AccumulateShape(const Shape &shape) {
  // use an int64_t initial value so the accumulator type is not deduced as int
  return std::accumulate(shape.cbegin(), shape.cend(), int64_t(1), std::multiplies<int64_t>());
}

size_t DynamicShapeIndex(const Shape &shape) {
  if (!OnlyOneDimDynamicShape(shape)) {
    MS_LOG(EXCEPTION) << "The shape is not one-dim dynamic: " << shape;
  }

  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == LongToInt(DYNAMIC_DIM_VAL)) {
      return i;
    }
  }
  return 0;
}

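/*
 * Build the replacement op list when exactly one output dimension is dynamic: temporarily substitute the -1 with a
 * concrete value so static-shape redistribution inference can run, then restore -1 in the inferred dst shape.
 */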
Status ReshapeInfo::ComputeReplaceOpForDynamicShape() {
  RankList dev_list = stage_device_list();
  TensorRedistribution tensor_redistribution(!is_generating_costs_, true);

  TensorLayout fake_in;
  TensorLayout fake_out;

  // replace the -1 dimension with a concrete value
  Shape replace_shape_in = input_layout_.tensor_shape_origin().array();
  Shape replace_shape_out = output_layout_.tensor_shape_origin().array();
  int64_t replace_value = 1;
  auto accumulate_shape_in = AccumulateShape(replace_shape_in);
  auto accumulate_shape_out = AccumulateShape(replace_shape_out);
  if (accumulate_shape_in < accumulate_shape_out) {
    replace_value = accumulate_shape_in / accumulate_shape_out;
    (void)std::replace(replace_shape_in.begin(), replace_shape_in.end(), LongToInt(DYNAMIC_DIM_VAL), 1);
    (void)std::replace(replace_shape_out.begin(), replace_shape_out.end(), LongToInt(DYNAMIC_DIM_VAL),
                       LongToInt(replace_value));
  } else {
    replace_value = accumulate_shape_out / accumulate_shape_in;
    (void)std::replace(replace_shape_in.begin(), replace_shape_in.end(), LongToInt(DYNAMIC_DIM_VAL),
                       LongToInt(replace_value));
    (void)std::replace(replace_shape_out.begin(), replace_shape_out.end(), LongToInt(DYNAMIC_DIM_VAL), 1);
  }

  if (fake_in.InitFromVector(input_layout_.device_arrangement_origin().array(),
                             input_layout_.origin_tensor_map().array(), replace_shape_in) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Fake tensor layout for input init failed, the old input layout is "
                  << input_layout_.ToString() << ", the replace input shape is " << replace_shape_in;
    return FAILED;
  }

  if (fake_out.InitFromVector(output_layout_.device_arrangement_origin().array(),
                              output_layout_.origin_tensor_map().array(), replace_shape_out) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Fake tensor layout for output init failed, the old output layout is "
                  << output_layout_.ToString() << ", the replace output shape is " << replace_shape_out;
    return FAILED;
  }

  if (tensor_redistribution.Init(fake_in, fake_out, dev_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Redistribution init failed";
    return FAILED;
  }

  // use the static shape to infer the redistribution operator list
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorRedistribution failed.";
    return FAILED;
  }
  replace_op_ = redistribution_oplist_ptr->first;
  replace_op_info_ = redistribution_oplist_ptr->second;

  if (replace_op_.size() == 1 && replace_op_.front().first == RESHAPE) {
    auto dst_shape = GetValue<std::vector<int64_t>>(replace_op_.front().second.second.front().first.second);
    if (dst_shape.size() != output_layout_.tensor_shape_origin().array().size()) {
      MS_LOG(ERROR) << name_ << ": The size of the dst shape must be equal to the size of the output origin shape, "
                    << "but the dst shape is " << dst_shape << ", output origin shape is "
                    << output_layout_.tensor_shape_origin().array();
      return FAILED;
    }
    size_t index = DynamicShapeIndex(output_layout_.tensor_shape_origin().array());  // find the dynamic dimension
    dst_shape[index] = DYNAMIC_DIM_VAL;                                              // reset the dynamic dimension
    replace_op_.front().second.second.front().first.second = MakeValue(dst_shape);
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": This scenario is not supported, the input layout is " << input_layout_.ToString()
                << ", the output layout is " << output_layout_.ToString();
  return FAILED;
}

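/*
 * Return true when every element of the destination shape is a compile-time constant: a ValueNode always is, a
 * MakeTuple is constant only if all of its inputs are ValueNodes, and a Shape op is always dynamic.
 */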
bool DstShapeIsConstant(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
  if (shape_input_node->isa<ValueNode>()) {
    return true;
  }

  if (IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    auto shape_input_cnode = shape_input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_cnode);
    for (size_t i = 1; i < shape_input_cnode->size(); ++i) {
      auto input_node = shape_input_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      if (!input_node->isa<ValueNode>()) {
        // the dst shape is dynamic
        return false;
      }
    }
    return true;
  }

  if (IsPrimitiveCNode(shape_input_node, prim::kPrimShape)) {
    // the dst shape is dynamic
    return false;
  }

  MS_LOG(EXCEPTION) << "The dst shape must be either a Tuple, a MakeTuple cnode or a Shape op, but got "
                    << shape_input_node->fullname_with_scope();
}

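/*
 * When redistribution is skipped but the destination shape is dynamic, rewrite the constant elements of the
 * MakeTuple dst shape so that each one is divided by its corresponding output shard size.
 */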
void ReshapeInfo::ChangeDynamicDstShapeForSkipRedistribution(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
  if (!IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    return;
  }

  auto make_tuple_cnode = shape_input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple_cnode);
  Shape out_strategy = output_layout_.shard_strategy();

  // two consecutive reshapes: op1---reshape1---reshape2---op2,
  // the in_layout and out_layout of reshape1 are the layout of op1's output,
  // but the size of out_strategy may not be equal to the size of the dst shape for reshape1,
  // so find the total shard num over the constant part of the input shape, then find a constant element of the
  // dst shape that is divisible by it and perform the division
  if (out_strategy.size() != (make_tuple_cnode->size() - 1)) {
    MS_LOG(WARNING) << name_ << ": It may be a scene of two consecutive reshapes, the out_strategy size is "
                    << out_strategy.size() << ", but the size of make_tuple's input is " << make_tuple_cnode->size() - 1
                    << ", the input shape is " << inputs_shape_[0] << ", the output shape is " << outputs_shape_[0];
    Shape input_shape = inputs_shape_[0];
    if (input_shape.size() != out_strategy.size()) {
      MS_LOG(EXCEPTION) << name_ << ": the size of the input shape is not equal to the size of out_strategy";
    }
    int64_t constant_shard_num = 1;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      if (input_shape[i] > 0) {
        constant_shard_num *= out_strategy[i];
      }
    }
    MS_LOG(INFO) << name_ << ": the shard num of the constant shape is " << constant_shard_num;
    if (constant_shard_num <= 1) {
      return;
    }
    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      auto input_node = make_tuple_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      auto value_node = GetValueNode(input_node);
      if (value_node != nullptr && value_node->isa<Int64Imm>()) {
        int64_t origin_shape_ele = GetValue<int64_t>(value_node);
        if (origin_shape_ele > 0 && origin_shape_ele % constant_shard_num == 0) {
          int64_t replace_shape = origin_shape_ele / constant_shard_num;
          auto replace_value_ptr = MakeValue(replace_shape);
          auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
          make_tuple_cnode->set_input(i, replace_value_node);
          return;
        }
      }
    }
    MS_LOG(EXCEPTION) << name_ << ": this scenario is not supported, the output shape is " << outputs_shape_[0]
                      << ", the out strategy is " << out_strategy;
  }

  // common reshape: handle the constant part of the dst shape, dividing by the corresponding out_strategy
  for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
    if (out_strategy[i - 1] <= 1) {
      continue;
    }
    auto input_node = make_tuple_cnode->input(i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto value_node = GetValueNode(input_node);
    if (value_node != nullptr && value_node->isa<Int64Imm>()) {
      auto origin_shape_ele = GetValue<int64_t>(value_node);
      if (origin_shape_ele > 0) {
        if (origin_shape_ele % out_strategy[i - 1] != 0) {
          MS_LOG(EXCEPTION) << name_ << ": the origin shape is " << origin_shape_ele
                            << ", which cannot be divided by the shard size " << out_strategy[i - 1];
        }
        int64_t replace_shape = origin_shape_ele / out_strategy[i - 1];
        auto replace_value_ptr = MakeValue(replace_shape);
        auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
        make_tuple_cnode->set_input(i, replace_value_node);
      }
    }
  }
}

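/*
 * Create and initialize a TensorRedistribution between the input and output layouts of this reshape; raises an
 * exception on initialization failure.
 */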
TensorRedistributionPtr ReshapeInfo::ReshapeRedistribution() {
  TensorRedistributionPtr tensor_redistribution = this->CreateReshapeTensorRedistribution(!is_generating_costs_, true);
  RankList dev_list = stage_device_list();
  if (tensor_redistribution->Init(input_layout_, output_layout_, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << name_ << ": tensor_redistribution init failed.";
  }
  return tensor_redistribution;
}

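/*
 * Match the transformer pattern Reshape (bs, d1*d2, d3) -> (bs, d1, d2, d3) where both sides share the same device
 * matrix, the batch dimension is unchanged, the trailing dimensions are dynamic, and the first two dimensions are
 * sharded identically; in that case no redistribution is needed.
 */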
bool SpecialPatternInTransformer(const TensorLayout &from_layout, const TensorLayout &to_layout) {
  auto from_tensor_shape = from_layout.tensor_shape();
  auto to_tensor_shape = to_layout.tensor_shape();
  auto from_dev_mat = from_layout.device_arrangement();
  auto to_dev_mat = to_layout.device_arrangement();
  auto from_tensor_map = from_layout.tensor_map();
  auto to_tensor_map = to_layout.tensor_map();
  if (from_dev_mat.array() != to_dev_mat.array()) {
    return false;
  }
  // Reshape (bs, d1*d2, d3) -> (bs, d1, d2, d3) and only shard on the 1,2 axes.
  if (from_tensor_shape.array().size() != SIZE_THREE || to_tensor_shape.array().size() != SIZE_FOUR) {
    return false;
  }
  bool is_same_batch_dim = from_tensor_shape.GetDimByIdx(0) == to_tensor_shape.GetDimByIdx(0);
  bool is_from_dyn_on_last_two_dim =
    from_tensor_shape.GetDimByIdx(INDEX_ONE) == -1 && from_tensor_shape.GetDimByIdx(INDEX_TWO) == -1;
  bool is_to_dyn_on_last_two_dim =
    to_tensor_shape.GetDimByIdx(INDEX_TWO) == -1 && to_tensor_shape.GetDimByIdx(INDEX_THREE) == -1;
  if (is_same_batch_dim && is_from_dyn_on_last_two_dim && is_to_dyn_on_last_two_dim) {
    bool same_shard_on_front_two_dim = from_tensor_map.GetDimByIdx(0) == to_tensor_map.GetDimByIdx(0) &&
                                       from_tensor_map.GetDimByIdx(1) == to_tensor_map.GetDimByIdx(1);
    return same_shard_on_front_two_dim;
  }
  return false;
}

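/*
 * Fall back to the original (static) handling when exactly one axis is sharded on each side and that axis has a
 * constant dimension in both layouts.
 */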
bool SkipTensorRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout) {
  // If only one axis is sharded, and it's a const axis, use the past solution.
  size_t from_shard_axis_cnt = 0;
  size_t to_shard_axis_cnt = 0;
  size_t from_index = 0;
  size_t to_index = 0;
  for (size_t i = 0; i < from_layout.tensor_map().array().size(); ++i) {
    if (from_layout.tensor_map().GetDimByIdx(i) != -1) {
      from_shard_axis_cnt += 1;
      from_index = i;
    }
  }
  for (size_t i = 0; i < to_layout.tensor_map().array().size(); ++i) {
    if (to_layout.tensor_map().GetDimByIdx(i) != -1) {
      to_shard_axis_cnt += 1;
      to_index = i;
    }
  }
  bool only_shard_on_one_axis = from_shard_axis_cnt == 1 && to_shard_axis_cnt == 1;
  bool only_shard_on_const_axis =
    from_layout.tensor_shape().GetDimByIdx(from_index) != -1 && to_layout.tensor_shape().GetDimByIdx(to_index) != -1;
  return only_shard_on_one_axis && only_shard_on_const_axis;
}

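/*
 * Compute the operator list that replaces this reshape. If skip_redistribution is set, either emit a plain reshape
 * to the sliced shape (constant dst shape) or rewrite the dynamic dst shape in place. Otherwise run tensor
 * redistribution inference, with special handling when the output has more than one dynamic axis.
 */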
Status ReshapeInfo::ComputeReplaceOp() {
  MS_LOG(INFO) << "Infer reshape redistribution for " << this->cnode_->fullname_with_scope() << "." << std::endl
               << "input_layout_: " << this->input_layout_.ToString() << std::endl
               << "output_layout_: " << this->output_layout_.ToString();
  if (is_skip_) {
    MS_LOG(DEBUG) << "Skip reshape redistribution for " << cnode_->fullname_with_scope() << std::endl;
    if (DstShapeIsConstant(cnode_->input(2))) {
      ConstructOperator constructor;
      replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
      replace_op_info_.clear();
      MS_LOG(INFO) << "skip reshape redistribution and the reshape slice_shape is "
                   << ShapeToString(output_layout_.slice_shape().array());
    } else {
      replace_op_.clear();
      replace_op_info_.clear();
      MS_LOG(WARNING) << name_ << ": the dst shape is dynamic, skipping redistribution";
      // need to modify the dst shape
      ChangeDynamicDstShapeForSkipRedistribution(cnode_->input(2));
    }
  } else {
    if (AccumulateShape(input_layout_.shard_strategy()) == 1 && AccumulateShape(output_layout_.shard_strategy()) == 1) {
      // neither input nor output is sharded
      replace_op_.clear();
      replace_op_info_.clear();
      return SUCCESS;
    }
    auto reshape_input = this->cnode_->input(1);
    MS_EXCEPTION_IF_CHECK_FAIL(reshape_input != nullptr,
                               "input of Reshape " + this->cnode_->fullname_with_scope() + " is nullptr.");

    RankList dev_list = stage_device_list();
    TensorRedistributionPtr tensor_redistribution =
      this->CreateReshapeTensorRedistribution(!is_generating_costs_, true);
    tensor_redistribution->SetPreAndNextCNode(reshape_input, this->cnode_);
    if (tensor_redistribution->Init(input_layout_, output_layout_, dev_list) == FAILED) {
      if (is_generating_costs_) {
        MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed.";
      } else {
        MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
      }
      return FAILED;
    }
    MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
    MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
    MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
    RedistributionOpListPtr redistribution_oplist_ptr;
    Shape output_tensor_shape = this->output_layout_.tensor_shape().array();
    if (this->input_layout_.tensor_shape().array().size() != this->output_layout_.tensor_shape().array().size() &&
        std::count(output_tensor_shape.begin(), output_tensor_shape.end(), -1) > 1) {
      // check the patterns that allow skipping redistribution entirely
      if (SpecialPatternInTransformer(this->input_layout_, this->output_layout_)) {
        MS_LOG(INFO) << "Match special pattern in transformer.";
        replace_op_.clear();
        replace_op_info_.clear();
        return SUCCESS;
      }
      if (SkipTensorRedistribution(this->input_layout_, this->output_layout_)) {
        MS_LOG(WARNING) << "Skip tensor redistribution for " << this->cnode_->fullname_with_scope();
        replace_op_.clear();
        replace_op_info_.clear();
        ChangeDynamicDstShapeForSkipRedistribution(cnode_->input(2));
        return SUCCESS;
      }
      // use the naive method: do AllGather on each dim.
      tensor_redistribution->set_original_reshape_shape(this->cnode_->input(INDEX_TWO));
      MS_LOG(INFO) << this->name_
                   << " has more than 1 dynamic axis. shape: " << this->cnode_->input(INDEX_TWO)->fullname_with_scope();
      redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorListForMultiDynamicReshape();
    } else {
      tensor_redistribution->set_original_reshape_shape(nullptr);
      redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorList();
    }
    if (!is_generating_costs_ && !tensor_redistribution->IsAssembledStaticShape()) {
      redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
        redistribution_oplist_ptr, tensor_redistribution->input_shape());
    }
    if (redistribution_oplist_ptr == nullptr) {
      if (is_generating_costs_) {
        MS_LOG(DEBUG) << name_ << ": InferTensorRedistribution failed.";
      } else {
        MS_LOG(ERROR) << name_ << ": InferTensorRedistribution failed.";
      }
      return FAILED;
    }
    if (!redistribution_oplist_ptr->first.empty() && tensor_redistribution->original_reshape_shape() == nullptr &&
        tensor_redistribution->IsAssembledStaticShape()) {
      auto func_graph = this->cnode_->func_graph();
      tensor_redistribution->CreateAssembledDynamicMapping(this->cnode_, reshape_input, func_graph, INDEX_ONE);
    }
    replace_op_ = redistribution_oplist_ptr->first;
    replace_op_info_ = redistribution_oplist_ptr->second;
  }
  MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
  if (replace_op_.size() == 1 && replace_op_.front().first == RESHAPE) {
    ChangeDstShape();
  }
  return SUCCESS;
}

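/*
 * If the inferred replacement reshape would not actually change the target shape (every constant dimension matches
 * and at most one dimension is dynamic), restore the original destination shape so the dynamic dimension is kept.
 */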
void ReshapeInfo::ChangeDstShape() {
  int64_t shape_dim = 2;
  auto value = replace_op_.front().second.second.front().first.second;
  Shape dst_shape = GetValue<std::vector<int64_t>>(value);
  Shape origin_dst_shape = GetInputShape(cnode_->input(LongToSize(shape_dim)));
  if (dst_shape.size() == origin_dst_shape.size()) {
    for (size_t i = 0; i < dst_shape.size(); ++i) {
      if (origin_dst_shape[i] != dst_shape[i] && origin_dst_shape[i] != DYNAMIC_DIM_VAL) {
        return;
      }
    }
    int64_t dyn_dim_cnt = std::count(origin_dst_shape.cbegin(), origin_dst_shape.cend(), DYNAMIC_DIM_VAL);
    if (dyn_dim_cnt > 1) {
      MS_LOG(DEBUG) << name_ << ": Don't need to replace reshape's target shape.";
      return;
    }
    MS_LOG(INFO) << name_ << ": The reshape would not change the target shape.";
    replace_op_.front().second.second.front().first.second = MakeValue(origin_dst_shape);
  }
}

/*
 * The first dimension of the input tensor map and the output tensor map is set to the last dimension of the device
 * arrangement; all other dimensions are set to None.
 * Only the batch-parallel reshape operator in ReID is supported (the batch parallel degree can be smaller than the
 * device number).
 */
Status ReshapeInfo::InferTensorMap() {
  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
    MS_LOG(ERROR) << name_ << ": the sizes of inputs shape and outputs shape must both be 1, but they are "
                  << inputs_shape_.size() << " and " << outputs_shape_.size();
    return FAILED;
  }

  Shape tensor_map_index_input;
  for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
    tensor_map_index_input.push_back(SizeToLong(inputs_shape_[0].size() - j - 1));
  }
  inputs_tensor_map_.push_back(tensor_map_index_input);

  Shape tensor_map_index_output;
  for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
    tensor_map_index_output.push_back(MAP_NONE);
  }
  outputs_tensor_map_.push_back(tensor_map_index_output);
  return SUCCESS;
}

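/*
 * Build the input and output tensor layouts from the device matrix, the inferred tensor maps, and the tensor
 * shapes, and cache them in input_layout_ / output_layout_.
 */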
Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  if (inputs_layout == nullptr || outputs_layout == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
    return FAILED;
  }
  Arrangement dev_matrix;
  Status status = dev_matrix.Init(dev_matrix_shape_);
  if (status != Status::SUCCESS) {
    return status;
  }
  // infer input tensor info
  Shape shape_array_in = inputs_shape_.at(0);
  TensorMap tensor_map_array_in = inputs_tensor_map_.at(0);
  TensorLayout tensor_layout_in;
  Map tensor_map_in;
  status = tensor_map_in.Init(tensor_map_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_in;
  status = shape_in.Init(shape_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in);
  inputs_layout->push_back(tensor_layout_in);
  // infer output tensor info
  Shape shape_array_out = outputs_shape_.at(0);

  TensorMap tensor_map_array_out = outputs_tensor_map_.at(0);
  TensorLayout tensor_layout_out;
  Map tensor_map_out;
  status = tensor_map_out.Init(tensor_map_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_out;
  status = shape_out.Init(shape_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out);
  outputs_layout->push_back(tensor_layout_out);

  input_layout_ = tensor_layout_in;
  output_layout_ = tensor_layout_out;
  return SUCCESS;
}

Status ReshapeInfo::InferTensorInfo() {
  // skip reshape infer if skip_redistribution is true
  if (is_skip_) {
    TensorLayout layout;
    Shape shape;
    Shape slice_shape;
    layout.set_skip_redistribution(true);
    TensorInfo tensor_info_in(layout, shape, slice_shape);
    inputs_tensor_info_.push_back(tensor_info_in);
    outputs_tensor_info_.push_back(tensor_info_in);
    MS_LOG(DEBUG) << name() << ": skip redistribution reshape InferTensorInfo";
    return SUCCESS;
  }

  TensorLayouts inputs_layout, outputs_layout;
  if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
    return FAILED;
  }
  TensorLayout tensor_layout_in = inputs_layout.at(0);
  TensorLayout tensor_layout_out = outputs_layout.at(0);
  TensorInfo tensor_info_in(tensor_layout_in);
  TensorInfo tensor_info_out(tensor_layout_out);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
  return SUCCESS;
}

void ReshapeInfo::InferTensorInfoByLayout() {
  TensorInfo tensor_info_in(input_layout_);
  TensorInfo tensor_info_out(output_layout_);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
}

void ReshapeInfo::device_number() {
  dev_num_ = stage_device_size_;
  MS_ASSERT(dev_num_ > 0);
}

Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) {
  Shape tensor_map_index;
  for (size_t i = 0; i < shape.size(); i++) {
    tensor_map_index.push_back(MAP_NONE);
  }
  Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape);
  if (status != Status::SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed.";
    return status;
  }
  return Status::SUCCESS;
}

Status ReshapeInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                         const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                         const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  MS_LOG(DEBUG) << "Init for " << this->cnode()->fullname_with_scope();
  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
  if (reshape_skip_redis_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
    if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
      return FAILED;
    }
    is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
  }

  ResetQueueMember();
  device_number();
  if (in_strategy) {
    if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Init failed.";
      return FAILED;
    }
  } else {
    if (!input_layout_set_flag_) {
      MS_ASSERT(inputs_shape_.size() == 1);
      Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer input default layout failed.";
        return status;
      }
    }
    if (!output_layout_set_flag_) {
      MS_ASSERT(outputs_shape_.size() == 1);
      Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer output default layout failed.";
        return status;
      }
    }
    inputs_tensor_map_.push_back(input_layout_.tensor_map().array());
    outputs_tensor_map_.push_back(output_layout_.tensor_map().array());
    InferTensorInfoByLayout();
    // change dev_matrix_shape_ to input_layout_'s device_arrangement before InferMirrorOps
    dev_matrix_shape_ = input_layout_.device_arrangement().array();
    if (InferMirrorOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
      return FAILED;
    }
    // change dev_matrix_shape_ to output_layout_'s device_arrangement before InferVirtualDivOps
    dev_matrix_shape_ = output_layout_.device_arrangement().array();
    if (InferVirtualDivOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
      return FAILED;
    }
  }
  if (input_layout_.GetVirtualRank().size() > 1 || output_layout_.GetVirtualRank().size() > 1) {
    interleaved_parallel_ = true;
    return SUCCESS;
  }

  Status status = ComputeReplaceOp();
  if (status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    return status;
  }
  return SUCCESS;
}

Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  return SetCostUnderStrategyBase(strategy);
}

void ReshapeInfo::SetCostForReshapeWithParameter() {
  size_t success = 0;
  for (auto &sp : sp_vector_) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated the " << GetSerialNumberString(success)
                   << " strategy: " << sp->ToString();
    }
  }
}

void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  int64_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);

  // Breaking ties for preferring data parallelization
  BreakingTiesForPreferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
}

std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t stage_id) {
  if (inputs_shape_.empty()) {
    MS_LOG(EXCEPTION) << name_ << ": Inputs shape is empty";
  }
  Shape input0_split;
  (void)input0_split.insert(input0_split.cend(), inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  // this strategy is used only when the input node is a parameter;
  // otherwise, the input node's output_layout is used as the input_layout.
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }

  return sp_vector_;
}

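/*
 * Generate candidate strategies and costs for this reshape by pairing each predecessor strategy (its output layout
 * becomes this reshape's input layout) with each successor strategy (its input layout becomes this reshape's
 * output layout).
 */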
Status ReshapeInfo::GenerateStrategyCosts(
  const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
  std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index, int64_t out_index,
  bool is_prev_param, bool is_next_reshape) {
  is_generating_costs_ = true;
  for (auto pre_stra_cost : pre_stra_costs) {
    std::vector<TensorInfo> pre_out_tensor_infos;
    if (is_prev_param) {
      pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
    } else {
      pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
    }
    if (pre_out_tensor_infos.size() <= LongToSize(out_index)) {
      MS_LOG(ERROR) << "out_index is out of range of the tensor_infos when setting reshape's input_layout";
      return FAILED;
    }
    TensorInfo pre_out_tensor_info = pre_out_tensor_infos[LongToSize(out_index)];
    SetInputLayout(pre_out_tensor_info.tensor_layout());
    // infer the pre_node output strategy from the output_layout.
    Dimensions stra = pre_out_tensor_info.InferStrategy();
    if (stra.empty()) {
      MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
      return FAILED;
    }
    Strategies stra_inputs = {stra};
    StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
    if (is_next_reshape) {
      SetOutputLayout(pre_out_tensor_info.tensor_layout());
      ResetQueueMember();
      InferTensorInfoByLayout();
      SetCostForReshape(reshape_stra);
    } else if (next_costs_index.empty()) {
      if (Init(nullptr, nullptr) == FAILED) {
        MS_LOG(ERROR) << "Failure: operator reshape init failed";
        return FAILED;
      }
      SetCostForReshape(reshape_stra);
      continue;
    }
    for (auto next_cost_index_pair : next_costs_index) {
      auto in_index = next_cost_index_pair.second;
      auto next_stra_costs = next_cost_index_pair.first;
      for (auto next_stra_cost : next_stra_costs) {
        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
        if (next_in_tensor_infos.size() <= LongToSize(in_index)) {
          MS_LOG(ERROR) << "in_index is out of range of the tensor_infos when setting reshape's output_layout";
          return FAILED;
        }
        TensorInfo next_in_tensor_info = next_in_tensor_infos[LongToSize(in_index)];

        SetOutputLayout(next_in_tensor_info.tensor_layout());
        ResetQueueMember();
        InferTensorInfoByLayout();
        SetCostForReshape(reshape_stra);
      }
    }
  }
  is_generating_costs_ = false;
  if (strategy_cost_.empty()) {
    return FAILED;
  }
  MS_LOG(INFO) << "Print " << name() << "'s 'strategy_cost':";
  for (auto &swc : strategy_cost_) {
    MS_LOG(INFO) << name() << "'s strategy: " << swc->strategy_ptr->ToString();
    MS_LOG(INFO) << "The corresponding cost: " << swc->cost_list[0]->computation_cost_ << ", "
                 << swc->cost_list[0]->communication_cost_ << ", "
                 << swc->cost_list[0]->communication_without_parameter_;
    MS_LOG(INFO) << "Input layout: " << swc->inputs_ptr[0].tensor_layout().ToString();
    MS_LOG(INFO) << "Output layout: " << swc->outputs_ptr[0].tensor_layout().ToString();
  }
  return SUCCESS;
}

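/*
 * The four selectors below pick a strategy-with-cost (SWC) index by matching either the output layout or the input
 * layout: the "ZeroComm" variants only consider strategies with zero communication cost (breaking ties by the
 * lowest computation cost), while the "MiniComm" variants pick the strategy with the minimum communication cost.
 */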
int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout &&
        fabs(swc->cost_list[0]->communication_without_parameter_ - 0.0) < DBL_EPSILON) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(WARNING) << "There is no available strategy with zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_comm;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
      (void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
    }
  }
  if (index_comm.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for reshape: " << name();
    return -1;
  }
  if (index_comm.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_comm.begin(), index_comm.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_comm[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout &&
        fabs(swc->cost_list[0]->communication_without_parameter_ - 0.0) < DBL_EPSILON) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(WARNING) << "There is no available strategy with zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_comm;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
      (void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
    }
  }
  if (index_comm.empty()) {
    MS_LOG(ERROR) << "There is no available strategy for reshape: " << name();
    return -1;
  }
  if (index_comm.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_comm.begin(), index_comm.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_comm[0].first;
}

bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
    return true;
  }
  MS_LOG(WARNING) << name_ << "'s desired output layout is: " << output_layout.ToString() << ", while the selected "
                  << "output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString()
                  << " and the input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString();
  return false;
}

bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
    return true;
  }
  MS_LOG(WARNING) << name_ << "'s desired input layout is: " << input_layout.ToString() << ", while the selected "
                  << "input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString()
                  << " and the output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString();
  return false;
}

TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  return swc->inputs_ptr[0].tensor_layout();
}

TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  return swc->outputs_ptr[0].tensor_layout();
}

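/*
 * Derive a Strategy object from the user-set input layout; returns nullptr when no input layout has been set or
 * the device manager is unavailable.
 */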
StrategyPtr ReshapeInfo::get_input_shard_strategy() {
  StrategyPtr ret = nullptr;
  if (input_layout_set_flag_ && g_device_manager != nullptr) {
    Strategies strategy;
    Dimensions dim;
    int64_t stage_id = g_device_manager->stage_id();
    dim = input_layout_.shard_strategy();
    strategy.push_back(dim);
    ret = NewStrategy(stage_id, strategy);
  }
  return ret;
}

REGISTER(ReshapeInfo);
}  // namespace parallel
}  // namespace mindspore