/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/reshape_info.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/graph_util/graph_utils.h"
#include "frontend/parallel/tensor_layout/tensor_transform.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "utils/log_adapter.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"

namespace mindspore {
namespace parallel {
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

/*
 * Support a parallel degree smaller than the device number: the duplicate device dimension is set to the first
 * dimension of the device matrix.
 * Only the batch-parallel Reshape operator in ReID is supported (the batch parallel degree can be smaller than the
 * device number).
 */
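// Illustrative example: with 8 devices and a batch-parallel strategy ((4, 1)), dev_matrix_shape_ becomes [4, 1];
// the unused factor of 2 corresponds to the duplicate device dimension described above.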
Status ReshapeInfo::InferDevMatrixShape() {
  Strategies stra = strategy_->GetInputDim();
  input_strategy_ = stra.at(0);
  dev_matrix_shape_ = stra.at(0);
  return SUCCESS;
}

/*
 * there is no Parameter for Reshape Primitive, so no need to do allreduce
 */
Status ReshapeInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = input_layout_.tensor_map().array();
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    ReportError(name_ + ": Create group failed.");
    return FAILED;
  }

  OperatorVector op_for_input;
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  }
  if (!input_group.empty()) {
    op_for_input = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
    std::string group_name = input_group[0].name();
    MS_LOG(INFO) << name_ << ": Create the mirror ops for input_a success, group is " << group_name;
  }
  mirror_ops_.push_back(op_for_input);
  OperatorVector op_for_input_empty;
  mirror_ops_.push_back(op_for_input_empty);

  return SUCCESS;
}

/*
 * there is no reduction dimension for forward computation of Reshape Primitive, so no need to do allreduce
 */
Status ReshapeInfo::InferForwardCommunication() { return SUCCESS; }

std::vector<int64_t> ReshapeInfo::GetInputShape(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
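  // The dst shape input may be a constant ValueNode (e.g. (8, -1, 1024)), a MakeTuple mixing constant and dynamic
  // elements, or the output of a Shape op; dynamic MakeTuple elements are recorded as kShapeDimAny (-1).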
  Shape origin_dst_shape;
  if (shape_input_node->isa<ValueNode>()) {
    auto shape_input_value_node = shape_input_node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_value_node);
    auto shape_input_value = shape_input_value_node->value();
    MS_EXCEPTION_IF_NULL(shape_input_value);
    origin_dst_shape = GetValue<std::vector<int64_t>>(shape_input_value);
  } else if (IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    auto shape_input_cnode = shape_input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_cnode);
    for (size_t i = 1; i < shape_input_cnode->size(); ++i) {
      auto input_node = shape_input_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      if (input_node->isa<ValueNode>()) {
        auto input_value_node = input_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(input_value_node);
        origin_dst_shape.push_back(GetValue<int64_t>(input_value_node->value()));
      } else {
        // the dst shape is dynamic, if two or more dimensions are dynamic, it requires additional processing
        origin_dst_shape.push_back(abstract::Shape::kShapeDimAny);
      }
    }
  } else if (IsPrimitiveCNode(shape_input_node, prim::kPrimShape)) {
    // dynamic shape: the dst shape is Shape op
    MS_EXCEPTION_IF_NULL(input_value_[1]);
    origin_dst_shape = GetValue<std::vector<int64_t>>(input_value_[1]);
    MS_LOG(INFO) << name_ << ": the input value is Shape op, dst shape is " << origin_dst_shape;
  } else {
    MS_LOG(EXCEPTION) << name_ << ": the dst shape must be either a value Tuple, a MakeTuple cnode or a Shape op,"
                      << " but got " << shape_input_node->fullname_with_scope();
  }
  return origin_dst_shape;
}

bool OnlyOneDimDynamicShape(const Shape &shape) {
  return (std::count(shape.cbegin(), shape.cend(), DYNAMIC_DIM_VAL) == 1);
}

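// Multiply all dims of the shape; with a single dynamic (-1) dim the product is negative, e.g. (-1, 2, 3) -> -6.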
int64_t AccumulateShape(const Shape &shape) {
  return std::accumulate(shape.cbegin(), shape.cend(), static_cast<int64_t>(1), std::multiplies<int64_t>());
}

size_t DynamicShapeIndex(const Shape &shape) {
  if (!OnlyOneDimDynamicShape(shape)) {
    MS_LOG(EXCEPTION) << "The shape is not one dim dynamic: " << shape;
  }

  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == LongToInt(DYNAMIC_DIM_VAL)) {
      return i;
    }
  }
  return 0;
}

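// Illustrative example: for a dynamic reshape (-1, 6) -> (-1, 2), the shape products are -6 and -2, so the input's
// -1 is replaced by 1 and the output's -1 by 3 (their quotient); the fake static shapes (1, 6) and (3, 2) then have
// the same element count, a normal redistribution is inferred from them, and the dynamic dim is restored in the dst
// shape afterwards.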
Status ReshapeInfo::ComputeReplaceOpForDynamicShape() {
  RankList dev_list = stage_device_list();
  TensorRedistribution tensor_redistribution(!is_generating_costs_, true);

  TensorLayout fake_in;
  TensorLayout fake_out;

  // replace the dynamic (-1) dim with a concrete value so that both fake shapes keep the same element count
  Shape replace_shape_in = input_layout_.tensor_shape_origin().array();
  Shape replace_shape_out = output_layout_.tensor_shape_origin().array();
  int64_t replace_value = 1;
  auto accumulate_shape_in = AccumulateShape(replace_shape_in);
  auto accumulate_shape_out = AccumulateShape(replace_shape_out);
  if (accumulate_shape_in < accumulate_shape_out) {
    replace_value = accumulate_shape_in / accumulate_shape_out;
    (void)std::replace(replace_shape_in.begin(), replace_shape_in.end(), LongToInt(DYNAMIC_DIM_VAL), 1);
    (void)std::replace(replace_shape_out.begin(), replace_shape_out.end(), LongToInt(DYNAMIC_DIM_VAL),
                       LongToInt(replace_value));
  } else {
    replace_value = accumulate_shape_out / accumulate_shape_in;
    (void)std::replace(replace_shape_in.begin(), replace_shape_in.end(), LongToInt(DYNAMIC_DIM_VAL),
                       LongToInt(replace_value));
    (void)std::replace(replace_shape_out.begin(), replace_shape_out.end(), LongToInt(DYNAMIC_DIM_VAL), 1);
  }

  if (fake_in.InitFromVector(input_layout_.device_arrangement_origin().array(),
                             input_layout_.origin_tensor_map().array(), replace_shape_in) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Fake tensor layout for input init failed, the old input layout is "
                  << input_layout_.ToString() << ", the replace input shape is " << replace_shape_in;
    return FAILED;
  }

  if (fake_out.InitFromVector(output_layout_.device_arrangement_origin().array(),
                              output_layout_.origin_tensor_map().array(), replace_shape_out) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Fake tensor layout for output init failed, the old output layout is "
                  << output_layout_.ToString() << ", the replace output shape is " << replace_shape_out;
    return FAILED;
  }

  if (tensor_redistribution.Init(fake_in, fake_out, dev_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Redistribution init failed";
    return FAILED;
  }

  // use static shape to infer redistribution operator list
  RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
  if (redistribution_oplist_ptr == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorRedistribution failed.";
    return FAILED;
  }
  replace_op_ = redistribution_oplist_ptr->first;
  replace_op_info_ = redistribution_oplist_ptr->second;

  if (replace_op_.size() == 1 && replace_op_.front().first == RESHAPE) {
    auto dst_shape = GetValue<std::vector<int64_t>>(replace_op_.front().second.second.front().first.second);
    if (dst_shape.size() != output_layout_.tensor_shape_origin().array().size()) {
      MS_LOG(ERROR) << name_ << ": The size of the dst shape must be equal to the size of the output origin shape,"
                    << " but the dst shape is " << dst_shape << ", output origin shape is "
                    << output_layout_.tensor_shape_origin().array();
      return FAILED;
    }
    size_t index = DynamicShapeIndex(output_layout_.tensor_shape_origin().array());  // find the dynamic dimension
    dst_shape[index] = DYNAMIC_DIM_VAL;  // reset the dynamic dimension
    replace_op_.front().second.second.front().first.second = MakeValue(dst_shape);
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": This scenario is not supported, the input layout is " << input_layout_.ToString()
                << ", the output layout is " << output_layout_.ToString();
  return FAILED;
}

bool DstShapeIsConstant(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
  if (shape_input_node->isa<ValueNode>()) {
    return true;
  }

  if (IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    auto shape_input_cnode = shape_input_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(shape_input_cnode);
    for (size_t i = 1; i < shape_input_cnode->size(); ++i) {
      auto input_node = shape_input_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      if (input_node->isa<ValueNode>()) {
        continue;
      } else {
        // the dst shape is dynamic
        return false;
      }
    }
    return true;
  }

  if (IsPrimitiveCNode(shape_input_node, prim::kPrimShape)) {
    // the dst shape is dynamic
    return false;
  }

  MS_LOG(EXCEPTION) << "The dst shape must be either Tuple or MakeTuple cnode or Shape op, but got "
                    << shape_input_node->fullname_with_scope();
}

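// When redistribution is skipped but the dst shape is dynamic, the constant elements of the MakeTuple dst shape are
// divided by their shard sizes so that the local (sliced) reshape stays correct, e.g. MakeTuple(-1, 4096) with
// out_strategy (1, 8) becomes MakeTuple(-1, 512).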
void ReshapeInfo::ChangeDynamicDstShapeForSkipRedistribution(const AnfNodePtr &shape_input_node) {
  MS_EXCEPTION_IF_NULL(shape_input_node);
  if (!IsPrimitiveCNode(shape_input_node, prim::kPrimMakeTuple)) {
    return;
  }

  auto make_tuple_cnode = shape_input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(make_tuple_cnode);
  Shape out_strategy = output_layout_.shard_strategy();

  // two consecutive reshapes: op1---reshape1---reshape2---op2,
  // the in_layout and out_layout of reshape1 are both the layout of op1's output,
  // so the size of out_strategy may not equal the size of reshape1's dst shape;
  // here, compute the total shard number over the constant part of the original shape,
  // then find a constant element of the dst shape that is divisible by it and perform the division
  if (out_strategy.size() != (make_tuple_cnode->size() - 1)) {
    MS_LOG(WARNING) << name_ << ": It may be a scene of two consecutive reshapes, the out_strategy size is "
                    << out_strategy.size() << ", but the size of make_tuple's input is " << make_tuple_cnode->size() - 1
                    << ", the input shape is " << inputs_shape_[0] << ", the output shape is " << outputs_shape_[0];
    Shape input_shape = inputs_shape_[0];
    if (input_shape.size() != out_strategy.size()) {
      MS_LOG(EXCEPTION) << name_ << ": the size of input shape is not equal to the size of out_strategy";
    }
    int64_t constant_shard_num = 1;
    for (size_t i = 0; i < input_shape.size(); ++i) {
      if (input_shape[i] > 0) {
        constant_shard_num *= out_strategy[i];
      }
    }
    MS_LOG(INFO) << name_ << ": the shard num of constant shape is " << constant_shard_num;
    if (constant_shard_num <= 1) {
      return;
    }
    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      auto input_node = make_tuple_cnode->input(i);
      MS_EXCEPTION_IF_NULL(input_node);
      auto value_node = GetValueNode(input_node);
      if (value_node != nullptr && value_node->isa<Int64Imm>()) {
        int64_t origin_shape_ele = GetValue<int64_t>(value_node);
        if (origin_shape_ele > 0 && origin_shape_ele % constant_shard_num == 0) {
          int64_t replace_shape = origin_shape_ele / constant_shard_num;
          auto replace_value_ptr = MakeValue(replace_shape);
          auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
          make_tuple_cnode->set_input(i, replace_value_node);
          return;
        }
      }
    }
    MS_LOG(EXCEPTION) << name_ << ": this scenario is not supported, the output shape is " << outputs_shape_[0]
                      << ", the out strategy is " << out_strategy;
  }

  // common reshape, handle the constant part of the dst shape, div by the corresponding out_strategy
  for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
    if (out_strategy[i - 1] <= 1) {
      continue;
    }
    auto input_node = make_tuple_cnode->input(i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto value_node = GetValueNode(input_node);
    if (value_node != nullptr && value_node->isa<Int64Imm>()) {
      auto origin_shape_ele = GetValue<int64_t>(value_node);
      if (origin_shape_ele > 0) {
        if (origin_shape_ele % out_strategy[i - 1] != 0) {
          MS_LOG(EXCEPTION) << name_ << ": the origin shape is " << origin_shape_ele
                            << ", can not be div by shard size " << out_strategy[i - 1];
        }
        int64_t replace_shape = origin_shape_ele / out_strategy[i - 1];
        auto replace_value_ptr = MakeValue(replace_shape);
        auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
        make_tuple_cnode->set_input(i, replace_value_node);
      }
    }
  }
}

TensorRedistributionPtr ReshapeInfo::ReshapeRedistribution() {
  TensorRedistributionPtr tensor_redistribution = this->CreateReshapeTensorRedistribution(!is_generating_costs_, true);
  RankList dev_list = stage_device_list();
  if (tensor_redistribution->Init(input_layout_, output_layout_, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << name_ << ": tensor_redistribution init failed.";
  }
  return tensor_redistribution;
}

bool SpecialPatternInTransformer(const TensorLayout &from_layout, const TensorLayout &to_layout) {
  auto from_tensor_shape = from_layout.tensor_shape();
  auto to_tensor_shape = to_layout.tensor_shape();
  auto from_dev_mat = from_layout.device_arrangement();
  auto to_dev_mat = to_layout.device_arrangement();
  auto from_tensor_map = from_layout.tensor_map();
  auto to_tensor_map = to_layout.tensor_map();
  if (from_dev_mat.array() != to_dev_mat.array()) {
    return false;
  }
  // Reshape (bs, d1*d2, d3) -> (bs, d1, d2, d3) and only shard on 1,2 axis.
  if (from_tensor_shape.array().size() != SIZE_THREE || to_tensor_shape.array().size() != SIZE_FOUR) {
    return false;
  }
  bool is_same_batch_dim = from_tensor_shape.GetDimByIdx(0) == to_tensor_shape.GetDimByIdx(0);
  bool is_from_dyn_on_last_two_dim =
    from_tensor_shape.GetDimByIdx(INDEX_ONE) == -1 && from_tensor_shape.GetDimByIdx(INDEX_TWO) == -1;
  bool is_to_dyn_on_last_two_dim =
    to_tensor_shape.GetDimByIdx(INDEX_TWO) == -1 && to_tensor_shape.GetDimByIdx(INDEX_THREE) == -1;
  if (is_same_batch_dim && is_from_dyn_on_last_two_dim && is_to_dyn_on_last_two_dim) {
    bool same_shard_on_front_two_dim = from_tensor_map.GetDimByIdx(0) == to_tensor_map.GetDimByIdx(0) &&
                                       from_tensor_map.GetDimByIdx(1) == to_tensor_map.GetDimByIdx(1);
    return same_shard_on_front_two_dim;
  }
  return false;
}

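// Redistribution can be skipped when each layout shards exactly one axis and that axis is constant (non-dynamic) on
// both sides, e.g. (-1, 4096) sharded only on the 4096 axis reshaped to (-1, 8, 512) sharded only on the 512 axis.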
bool SkipTensorRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout) {
  // If only one axis is sharded and it is a constant axis, use the previous solution.
  size_t from_shard_axis_cnt = 0;
  size_t to_shard_axis_cnt = 0;
  size_t from_index = 0;
  size_t to_index = 0;
  for (size_t i = 0; i < from_layout.tensor_map().array().size(); ++i) {
    if (from_layout.tensor_map().GetDimByIdx(i) != -1) {
      from_shard_axis_cnt += 1;
      from_index = i;
    }
  }
  for (size_t i = 0; i < to_layout.tensor_map().array().size(); ++i) {
    if (to_layout.tensor_map().GetDimByIdx(i) != -1) {
      to_shard_axis_cnt += 1;
      to_index = i;
    }
  }
  bool only_shard_on_one_axis = from_shard_axis_cnt == 1 && to_shard_axis_cnt == 1;
  bool only_shard_on_const_axis =
    from_layout.tensor_shape().GetDimByIdx(from_index) != -1 && to_layout.tensor_shape().GetDimByIdx(to_index) != -1;
  if (only_shard_on_one_axis && only_shard_on_const_axis) {
    return true;
  }
  return false;
}

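// ComputeReplaceOp decides how this Reshape is replaced:
//   - is_skip_: emit a plain reshape to the sliced shape (constant dst shape), or only rewrite the dynamic dst shape;
//   - neither input nor output sharded: nothing to replace;
//   - rank change with more than one dynamic output dim: transformer special pattern, skip redistribution, or the
//     per-dim AllGather fallback;
//   - otherwise: normal tensor redistribution inference (optionally optimized for static shapes).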
Status ReshapeInfo::ComputeReplaceOp() {
  MS_LOG(INFO) << "Infer reshape redistribution for " << this->cnode_->fullname_with_scope() << "." << std::endl
               << "input_layout_: " << this->input_layout_.ToString() << std::endl
               << "output_layout_: " << this->output_layout_.ToString();
  if (is_skip_) {
    MS_LOG(DEBUG) << "Skip reshape redistribution for " << cnode_->fullname_with_scope() << std::endl;
    if (DstShapeIsConstant(cnode_->input(2))) {
      ConstructOperator constructor;
      replace_op_ = constructor.SkipRedisReshapeOP(output_layout_.slice_shape().array());
      replace_op_info_.clear();
      MS_LOG(INFO) << "skip reshape redistribution and reshape slice_shape is "
                   << ShapeToString(output_layout_.slice_shape().array());
    } else {
      replace_op_.clear();
      replace_op_info_.clear();
      MS_LOG(WARNING) << name_ << ": the dst shape is dynamic, skip redistribution";
      // need to modify the dst shape
      ChangeDynamicDstShapeForSkipRedistribution(cnode_->input(2));
    }
  } else {
    if (AccumulateShape(input_layout_.shard_strategy()) == 1 &&
        AccumulateShape(output_layout_.shard_strategy()) == 1) {
      // neither the input nor the output is sharded
      replace_op_.clear();
      replace_op_info_.clear();
      return SUCCESS;
    }
    auto reshape_input = this->cnode_->input(1);
    MS_EXCEPTION_IF_CHECK_FAIL(reshape_input != nullptr,
                               "input of Reshape " + this->cnode_->fullname_with_scope() + " is nullptr.");

    RankList dev_list = stage_device_list();
    TensorRedistributionPtr tensor_redistribution =
      this->CreateReshapeTensorRedistribution(!is_generating_costs_, true);
    tensor_redistribution->SetPreAndNextCNode(reshape_input, this->cnode_);
    if (tensor_redistribution->Init(input_layout_, output_layout_, dev_list) == FAILED) {
      if (is_generating_costs_) {
        MS_LOG(DEBUG) << name_ << ": tensor_redistribution init failed.";
      } else {
        MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
      }
      return FAILED;
    }
    MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
    MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
    MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
    RedistributionOpListPtr redistribution_oplist_ptr;
    Shape output_tensor_shape = this->output_layout_.tensor_shape().array();
    if (this->input_layout_.tensor_shape().array().size() != this->output_layout_.tensor_shape().array().size() &&
        std::count(output_tensor_shape.begin(), output_tensor_shape.end(), -1) > 1) {
      // If only one axis is sharded, and it's const axis, use past solution.
      if (SpecialPatternInTransformer(this->input_layout_, this->output_layout_)) {
        MS_LOG(INFO) << "Match special pattern in transformer.";
        replace_op_.clear();
        replace_op_info_.clear();
        return SUCCESS;
      }
      if (SkipTensorRedistribution(this->input_layout_, this->output_layout_)) {
        MS_LOG(WARNING) << "Skip tensor redistribution for " << this->cnode_->fullname_with_scope();
        replace_op_.clear();
        replace_op_info_.clear();
        ChangeDynamicDstShapeForSkipRedistribution(cnode_->input(2));
        return SUCCESS;
      }
      // use naive method. Do AllGather on each dim.
      tensor_redistribution->set_original_reshape_shape(this->cnode_->input(INDEX_TWO));
      MS_LOG(INFO) << this->name_ << " has more than 1 dynamic axis. shape: "
                   << this->cnode_->input(INDEX_TWO)->fullname_with_scope();
      redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorListForMultiDynamicReshape();
    } else {
      tensor_redistribution->set_original_reshape_shape(nullptr);
      redistribution_oplist_ptr = tensor_redistribution->InferTensorRedistributionOperatorList();
    }
    if (!is_generating_costs_ && !tensor_redistribution->IsAssembledStaticShape()) {
      redistribution_oplist_ptr = TensorTransform::GetInstance()->OptimizeTensorRedistributionOperatorList(
        redistribution_oplist_ptr, tensor_redistribution->input_shape());
    }
    if (redistribution_oplist_ptr == nullptr) {
      if (is_generating_costs_) {
        MS_LOG(DEBUG) << name_ << ": InferTensorRedistribution failed.";
      } else {
        MS_LOG(ERROR) << name_ << ": InferTensorRedistribution failed.";
      }
      return FAILED;
    }
    if (!redistribution_oplist_ptr->first.empty() && tensor_redistribution->original_reshape_shape() == nullptr &&
        tensor_redistribution->IsAssembledStaticShape()) {
      auto func_graph = this->cnode_->func_graph();
      tensor_redistribution->CreateAssembledDynamicMapping(this->cnode_, reshape_input, func_graph, INDEX_ONE);
    }
    replace_op_ = redistribution_oplist_ptr->first;
    replace_op_info_ = redistribution_oplist_ptr->second;
  }
  MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
  if (replace_op_.size() == 1 && replace_op_.front().first == RESHAPE) {
    ChangeDstShape();
  }
  return SUCCESS;
}

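// If the generated Reshape target shape matches the user's dst shape except for a single dynamic (-1) dim
// (e.g. generated (8, 1024) vs. original (-1, 1024)), restore the original dst shape to keep the dynamic dim.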
void ReshapeInfo::ChangeDstShape() {
  int64_t shape_dim = 2;
  auto value = replace_op_.front().second.second.front().first.second;
  Shape dst_shape = GetValue<std::vector<int64_t>>(value);
  Shape origin_dst_shape = GetInputShape(cnode_->input(LongToSize(shape_dim)));
  if (dst_shape.size() == origin_dst_shape.size()) {
    for (size_t i = 0; i < dst_shape.size(); ++i) {
      if (origin_dst_shape[i] != dst_shape[i] && origin_dst_shape[i] != DYNAMIC_DIM_VAL) {
        return;
      }
    }
    int64_t dyn_dim_cnt = std::count(origin_dst_shape.cbegin(), origin_dst_shape.cend(), DYNAMIC_DIM_VAL);
    if (dyn_dim_cnt > 1) {
      MS_LOG(DEBUG) << name_ << ": Don't need to replace reshape's target shape.";
      return;
    }
    MS_LOG(INFO) << name_ << ": The reshape would not change the target shape.";
    replace_op_.front().second.second.front().first.second = MakeValue(origin_dst_shape);
  }
}

/*
 * the input tensor map is a descending sequence mapping each input dimension to a device dimension, while every
 * output dimension is set to None,
 * only support batch parallel reshape operator in ReID (batch parallel degree can be smaller than device number)
 */
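// e.g. input shape (32, 64) -> input tensor map [1, 0]; output shape (2048,) -> output tensor map [MAP_NONE].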
Status ReshapeInfo::InferTensorMap() {
  if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
    MS_LOG(ERROR) << name_ << ": the sizes of inputs shape and outputs shape must both be 1, but they are "
                  << inputs_shape_.size() << " and " << outputs_shape_.size();
    return FAILED;
  }

  Shape tensor_map_index_input;
  for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
    tensor_map_index_input.push_back(SizeToLong(inputs_shape_[0].size() - j - 1));
  }
  inputs_tensor_map_.push_back(tensor_map_index_input);

  Shape tensor_map_index_output;
  for (size_t j = 0; j < outputs_shape_[0].size(); ++j) {
    tensor_map_index_output.push_back(MAP_NONE);
  }
  outputs_tensor_map_.push_back(tensor_map_index_output);
  return SUCCESS;
}

Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  if (inputs_layout == nullptr || outputs_layout == nullptr) {
    MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null.";
    return FAILED;
  }
  Arrangement dev_matrix;
  Status status = dev_matrix.Init(dev_matrix_shape_);
  if (status != Status::SUCCESS) {
    return status;
  }
  // infer input tensor info
  Shape shape_array_in = inputs_shape_.at(0);
  TensorMap tensor_map_array_in = inputs_tensor_map_.at(0);
  TensorLayout tensor_layout_in;
  Map tensor_map_in;
  status = tensor_map_in.Init(tensor_map_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_in;
  status = shape_in.Init(shape_array_in);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_in.Init(dev_matrix, tensor_map_in, shape_in);
  inputs_layout->push_back(tensor_layout_in);
  // infer output tensor info
  Shape shape_array_out = outputs_shape_.at(0);

  TensorMap tensor_map_array_out = outputs_tensor_map_.at(0);
  TensorLayout tensor_layout_out;
  Map tensor_map_out;
  status = tensor_map_out.Init(tensor_map_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  Arrangement shape_out;
  status = shape_out.Init(shape_array_out);
  if (status != Status::SUCCESS) {
    return status;
  }
  (void)tensor_layout_out.Init(dev_matrix, tensor_map_out, shape_out);
  outputs_layout->push_back(tensor_layout_out);

  input_layout_ = tensor_layout_in;
  output_layout_ = tensor_layout_out;
  return SUCCESS;
}

Status ReshapeInfo::InferTensorInfo() {
  // skip reshape infer if skip_redistribution is true
  if (is_skip_) {
    TensorLayout layout;
    Shape shape;
    Shape slice_shape;
    layout.set_skip_redistribution(true);
    TensorInfo tensor_info_in(layout, shape, slice_shape);
    inputs_tensor_info_.push_back(tensor_info_in);
    outputs_tensor_info_.push_back(tensor_info_in);
    MS_LOG(DEBUG) << name() << ": skip redistribution reshape InferTensorInfo";
    return SUCCESS;
  }

  TensorLayouts inputs_layout, outputs_layout;
  if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
    return FAILED;
  }
  TensorLayout tensor_layout_in = inputs_layout.at(0);
  TensorLayout tensor_layout_out = outputs_layout.at(0);
  TensorInfo tensor_info_in(tensor_layout_in);
  TensorInfo tensor_info_out(tensor_layout_out);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
  return SUCCESS;
}

void ReshapeInfo::InferTensorInfoByLayout() {
  TensorInfo tensor_info_in(input_layout_);
  TensorInfo tensor_info_out(output_layout_);
  inputs_tensor_info_.push_back(tensor_info_in);
  outputs_tensor_info_.push_back(tensor_info_out);
}

void ReshapeInfo::device_number() {
  dev_num_ = stage_device_size_;
  MS_ASSERT(dev_num_ > 0);
}

Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) {
  Shape tensor_map_index;
  for (size_t i = 0; i < shape.size(); i++) {
    tensor_map_index.push_back(MAP_NONE);
  }
  Status status = layout->InitFromVector({dev_num_}, tensor_map_index, shape);
  if (status != Status::SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferDefaultLayout failed.";
    return status;
  }
  return Status::SUCCESS;
}

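// Init flow: read the skip_redistribution attribute, then initialize either from the user strategy
// (InitWithAutoRepeatCalc) or from the pre-set / default input and output layouts; when interleaved parallel is
// detected the replace op is not computed here, otherwise ComputeReplaceOp builds the redistribution operators.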
Status ReshapeInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                         const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                         const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  MS_LOG(DEBUG) << "Init for " << this->cnode()->fullname_with_scope();
  auto reshape_skip_redis_iter = attrs_.find(SKIP_REDISTRIBUTION);
  if (reshape_skip_redis_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(reshape_skip_redis_iter->second);
    if (!reshape_skip_redis_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": skip_redistribution is not a bool.";
      return FAILED;
    }
    is_skip_ = reshape_skip_redis_iter->second->cast<BoolImmPtr>()->value();
  }

  ResetQueueMember();
  device_number();
  if (in_strategy) {
    if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Init failed.";
      return FAILED;
    }
  } else {
    if (!input_layout_set_flag_) {
      MS_ASSERT(inputs_shape_.size() == 1);
      Status status = InferDefaultLayout(inputs_shape_.at(0), &input_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer input default layout failed.";
        return status;
      }
    }
    if (!output_layout_set_flag_) {
      MS_ASSERT(outputs_shape_.size() == 1);
      Status status = InferDefaultLayout(outputs_shape_.at(0), &output_layout_);
      if (status != SUCCESS) {
        MS_LOG(ERROR) << name_ << ": infer output default layout failed.";
        return status;
      }
    }
    inputs_tensor_map_.push_back(input_layout_.tensor_map().array());
    outputs_tensor_map_.push_back(output_layout_.tensor_map().array());
    InferTensorInfoByLayout();
    // change dev_matrix_shape_ to input_layout_ device_arrangement before InferMirrorOps
    dev_matrix_shape_ = input_layout_.device_arrangement().array();
    if (InferMirrorOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
      return FAILED;
    }
    // change dev_matrix_shape_ to output_layout_ device_arrangement before InferVirtualDivOps
    dev_matrix_shape_ = output_layout_.device_arrangement().array();
    if (InferVirtualDivOps() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
      return FAILED;
    }
  }
  if (input_layout_.GetVirtualRank().size() > 1 || output_layout_.GetVirtualRank().size() > 1) {
    interleaved_parallel_ = true;
    return SUCCESS;
  }

  Status status = ComputeReplaceOp();
  if (status != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    return status;
  }
  return SUCCESS;
}

Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  return SetCostUnderStrategyBase(strategy);
}

void ReshapeInfo::SetCostForReshapeWithParameter() {
  size_t success = 0;
  for (auto &sp : sp_vector_) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated the " << GetSerialNumberString(success)
                   << " strategy: " << sp->ToString();
    }
  }
}

void ReshapeInfo::SetCostForReshape(const mindspore::parallel::StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  int64_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);

  // Breaking ties for preferring data parallelization
  BreakingTiesForPreferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);
}

std::vector<StrategyPtr> ReshapeInfo::GenerateOpStrategies(int64_t stage_id) {
  if (inputs_shape_.empty()) {
    MS_LOG(EXCEPTION) << name_ << ": The inputs shape is empty.";
  }
  Shape input0_split;
  (void)input0_split.insert(input0_split.cend(), inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  // the generated strategies are used only when the input node is a parameter;
  // otherwise, the input node's output_layout is used as this reshape's input_layout.
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector_) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }

  return sp_vector_;
}

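// For each candidate strategy of the previous operator, use its output (or input, for a parameter) layout as this
// reshape's input layout, then pair it with every candidate input layout of the following operators (or with the
// default layout when there is no follower; when the next op is itself a reshape, the same layout is reused for the
// output) and record a StrategyWithCost for each pairing.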
Status ReshapeInfo::GenerateStrategyCosts(
  const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
  std::vector<std::pair<std::vector<std::shared_ptr<StrategyWithCost>>, int64_t>> next_costs_index, int64_t out_index,
  bool is_prev_param, bool is_next_reshape) {
  is_generating_costs_ = true;
  for (auto pre_stra_cost : pre_stra_costs) {
    std::vector<TensorInfo> pre_out_tensor_infos;
    if (is_prev_param) {
      pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
    } else {
      pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
    }
    if (pre_out_tensor_infos.size() <= LongToSize(out_index)) {
      MS_LOG(ERROR) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
      return FAILED;
    }
    TensorInfo pre_out_tensor_info = pre_out_tensor_infos[LongToSize(out_index)];
    SetInputLayout(pre_out_tensor_info.tensor_layout());
    // infer pre_node output strategy from output_layout.
    Dimensions stra = pre_out_tensor_info.InferStrategy();
    if (stra.empty()) {
      MS_LOG(ERROR) << "Infer strategy by tensor_info failed";
      return FAILED;
    }
    Strategies stra_inputs = {stra};
    StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
    if (is_next_reshape) {
      SetOutputLayout(pre_out_tensor_info.tensor_layout());
      ResetQueueMember();
      InferTensorInfoByLayout();
      SetCostForReshape(reshape_stra);
    } else if (next_costs_index.empty()) {
      if (Init(nullptr, nullptr) == FAILED) {
        MS_LOG(ERROR) << "Failure: operator reshape init failed";
        return FAILED;
      }
      SetCostForReshape(reshape_stra);
      continue;
    }
    for (auto next_cost_index_pair : next_costs_index) {
      auto in_index = next_cost_index_pair.second;
      auto next_stra_costs = next_cost_index_pair.first;
      for (auto next_stra_cost : next_stra_costs) {
        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
        if (next_in_tensor_infos.size() <= LongToSize(in_index)) {
          MS_LOG(ERROR) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
          return FAILED;
        }
        TensorInfo next_in_tensor_info = next_in_tensor_infos[LongToSize(in_index)];

        SetOutputLayout(next_in_tensor_info.tensor_layout());
        ResetQueueMember();
        InferTensorInfoByLayout();
        SetCostForReshape(reshape_stra);
      }
    }
  }
  is_generating_costs_ = false;
  if (strategy_cost_.empty()) {
    return FAILED;
  }
  MS_LOG(INFO) << "Print " << name() << "'s 'strategy_cost':";
  for (auto &swc : strategy_cost_) {
    MS_LOG(INFO) << name() << "'s strategy: " << swc->strategy_ptr->ToString();
    MS_LOG(INFO) << "The corresponding cost: " << swc->cost_list[0]->computation_cost_ << ", "
                 << swc->cost_list[0]->communication_cost_ << ", "
                 << swc->cost_list[0]->communication_without_parameter_;
    MS_LOG(INFO) << "Input layout: " << swc->inputs_ptr[0].tensor_layout().ToString();
    MS_LOG(INFO) << "Output layout: " << swc->outputs_ptr[0].tensor_layout().ToString();
  }
  return SUCCESS;
}

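// The GetSWCIndexBy*Layout* helpers below select a strategy-with-cost entry matching a given input or output layout:
// the *ZeroComm variants only accept candidates with zero communication cost and pick the smallest computation cost,
// while the *MiniComm variants pick the candidate with the smallest communication cost.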
int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithZeroComm(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout &&
        fabs(swc->cost_list[0]->communication_without_parameter_ - 0.0) < DBL_EPSILON) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(WARNING) << "There is no available strategy with zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByOutputLayoutWithMiniComm(const TensorLayout &output_layout) {
  std::vector<std::pair<int64_t, double>> index_comm;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
      (void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
    }
  }
  if (index_comm.empty()) {
    MS_LOG(ERROR) << "There is no available strategy matching the output layout for reshape: " << name();
    return -1;
  }
  if (index_comm.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_comm.begin(), index_comm.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_comm[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithZeroComm(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_computation;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout &&
        fabs(swc->cost_list[0]->communication_without_parameter_ - 0.0) < DBL_EPSILON) {
      (void)index_computation.emplace_back(SizeToLong(i), swc->cost_list[0]->computation_cost_);
    }
  }
  if (index_computation.empty()) {
    MS_LOG(WARNING) << "There is no available strategy with zero communication cost for reshape: " << name();
    return -1;
  }
  if (index_computation.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_computation.begin(), index_computation.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_computation[0].first;
}

int64_t ReshapeInfo::GetSWCIndexByInputLayoutWithMiniComm(const TensorLayout &input_layout) {
  std::vector<std::pair<int64_t, double>> index_comm;
  for (size_t i = 0; i < strategy_cost_.size(); ++i) {
    const auto &swc = strategy_cost_[i];
    if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
      (void)index_comm.emplace_back(SizeToLong(i), swc->cost_list[0]->communication_without_parameter_);
    }
  }
  if (index_comm.empty()) {
    MS_LOG(ERROR) << "There is no available strategy matching the input layout for reshape: " << name();
    return -1;
  }
  if (index_comm.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies available for reshape: " << name();
  }
  std::sort(
    index_comm.begin(), index_comm.end(),
    [](const std::pair<int64_t, double> &a, const std::pair<int64_t, double> &b) { return a.second < b.second; });
  return index_comm[0].first;
}

bool ReshapeInfo::CheckStrategyConsistencyByOutputLayout(int64_t swc_index, const TensorLayout &output_layout) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  if (swc->outputs_ptr[0].tensor_layout() == output_layout) {
    return true;
  }
  MS_LOG(WARNING) << name_ << "'s desired output layout is: " << output_layout.ToString() << ", while the selected "
                  << "output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString()
                  << " and the input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString();
  return false;
}

bool ReshapeInfo::CheckStrategyConsistencyByInputLayout(int64_t swc_index, const TensorLayout &input_layout) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(ERROR) << "The strategy_index: " << swc_index << " is out of range.";
    return false;
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  if (swc->inputs_ptr[0].tensor_layout() == input_layout) {
    return true;
  }
  MS_LOG(WARNING) << name_ << "'s desired input layout is: " << input_layout.ToString() << ", while the selected "
                  << "input layout is: " << swc->inputs_ptr[0].tensor_layout().ToString()
                  << " and the output layout is: " << swc->outputs_ptr[0].tensor_layout().ToString();
  return false;
}

TensorLayout ReshapeInfo::GetInputLayoutBySWCIndex(int64_t swc_index) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  return swc->inputs_ptr[0].tensor_layout();
}

TensorLayout ReshapeInfo::GetOutputLayoutBySWCIndex(int64_t swc_index) const {
  if (swc_index == -1 || swc_index >= SizeToLong(strategy_cost_.size())) {
    MS_LOG(EXCEPTION) << "The strategy_index: " << swc_index << " is out of range.";
  }
  const auto &swc = strategy_cost_[LongToSize(swc_index)];
  return swc->outputs_ptr[0].tensor_layout();
}

StrategyPtr ReshapeInfo::get_input_shard_strategy() {
  StrategyPtr ret = nullptr;
  if (input_layout_set_flag_ && g_device_manager != nullptr) {
    Strategies strategy;
    Dimensions dim;
    int64_t stage_id = g_device_manager->stage_id();
    dim = input_layout_.shard_strategy();
    strategy.push_back(dim);
    ret = NewStrategy(stage_id, strategy);
  }
  return ret;
}

REGISTER(ReshapeInfo);
} // namespace parallel
} // namespace mindspore