1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/conv2d_info.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <cmath>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 
26 #include "frontend/parallel/device_matrix.h"
27 #include "frontend/parallel/strategy.h"
28 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
29 #include "frontend/parallel/graph_util/generate_graph.h"
30 #include "pipeline/jit/resource.h"
31 
32 namespace mindspore {
33 namespace parallel {
34 namespace {
35 ValuePtr MakeListValue(const std::vector<int64_t> &v) {
36   std::vector<ValuePtr> list;
37   (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
38   return std::make_shared<ValueSequeue>(list);
39 }
40 
41 ValuePtr MakeTupleListValue(const Shapes &v) {
42   std::vector<ValuePtr> tuple;
43   (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
44                        [](const std::vector<int64_t> &list) { return MakeListValue(list); });
45   return std::make_shared<ValueTuple>(tuple);
46 }
47 }  // namespace
48 Status Conv2DInfo::GetAttrsBase() {
49   // format
50   format_ = GetStringAttr(FORMAT);
51   if (format_ != NCHW) {
52     MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
53     return FAILED;
54   }
55 
56   // out_channel
57   out_channel_ = GetIntAttr(OUT_CHANNEL);
58   if (out_channel_ <= 0) {
59     MS_LOG(ERROR) << name_ << ": The attr of out_channel is invalid";
60     return FAILED;
61   }
62 
63   // kernel_size
64   auto kernel_size_iter = attrs_.find(KERNEL_SIZE);
65   if (kernel_size_iter == attrs_.end()) {
66     MS_LOG(ERROR) << name_ << ": Can not find the attribute of " << KERNEL_SIZE;
67     return FAILED;
68   }
69 
70   MS_EXCEPTION_IF_NULL(kernel_size_iter->second);
71   if (kernel_size_iter->second->isa<Int64Imm>()) {
72     int64_t kernel_size = kernel_size_iter->second->cast<Int64ImmPtr>()->value();
73     kernel_size_ = {kernel_size, kernel_size};
74   } else if (kernel_size_iter->second->isa<ValueTuple>() || kernel_size_iter->second->isa<ValueList>()) {
75     kernel_size_ = GetValue<std::vector<int64_t>>(kernel_size_iter->second);
76     if (kernel_size_.size() != 2) {
77       MS_LOG(ERROR) << name_ << ": The size of the kernel_size tuple must be 2, but got " << kernel_size_.size();
78       return FAILED;
79     }
80   } else {
81     MS_LOG(ERROR) << name_ << ": The kernel_size must be int or tuple";
82     return FAILED;
83   }
84 
85   // mode
86   mode_ = GetIntAttr(MODE);
87   if (mode_ != 1) {
88     MS_LOG(ERROR) << name_ << ": The mode must be 1, but got " << mode_;
89     return FAILED;
90   }
91 
92   // pad_mode
93   pad_mode_ = GetIntAttr(PAD_MODE);
94   if (pad_mode_ < 0 || pad_mode_ > 2) {
95     MS_LOG(ERROR) << name_ << ": The pad_mode must be in the range of [0, 2], but got " << pad_mode_;
96     return FAILED;
97   }
98 
99   // pad_list
100   pad_list_ = GetTupleIntAttr(PAD_LIST);
101   if (pad_list_.size() != 4) {
102     MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
103     return FAILED;
104   }
105 
106   // stride
107   stride_ = GetTupleIntAttr(STRIDE);
108   if (stride_.size() != 4) {
109     MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
110     return FAILED;
111   }
112 
113   if (stride_[0] != 1 || stride_[1] != 1) {
114     MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
115                   << stride_[1] << ")";
116     return FAILED;
117   }
118 
119   // dilation
120   dilation_ = GetTupleIntAttr(DILATION);
121   if (dilation_.size() != 4) {
122     MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
123     return FAILED;
124   }
125 
126   // group
127   group_ = GetIntAttr(GROUP);
128 
129   MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
130                << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
131                << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
132                << ", format is " << format_;
133 
134   return SUCCESS;
135 }
136 
137 Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
138 
139 Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const {
140   if (outputs_shape_[0][2] % h_strategy != 0) {
141     MS_LOG(ERROR) << name_
142                   << ": Splitting the h dimension is not supported when the output shape of the h dimension is not "
143                      "divisible by the h-dimension strategy";
144     return FAILED;
145   }
146 
147   if (outputs_shape_[0][3] % w_strategy != 0) {
148     MS_LOG(ERROR) << name_
149                   << ": Splitting the w dimension is not supported when the output shape of the w dimension is not "
150                      "divisible by the w-dimension strategy";
151     return FAILED;
152   }
153 
154   return SUCCESS;
155 }
156 
157 Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
158   int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
159   int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
160 
161   // H dimension
162   if (kernel_size_[0] > stride_[2] && h_strategy > 1) {
163     MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting H when kernel_size > stride";
164     return FAILED;
165   }
166 
167   if (h_strategy > 1 && (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0)) {
168     MS_LOG(ERROR) << name_
169                   << ": The 'same' mode does not support splitting H when kernel_size <= stride but the slice shape "
170                      "is not divisible by the stride";
171     return FAILED;
172   }
173 
174   // W dimension
175   if (w_strategy > 1 && (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0)) {
176     MS_LOG(ERROR) << name_
177                   << ": The 'same' mode does not support splitting W when kernel_size <= stride but the slice shape "
178                      "is not divisible by the stride";
179     return FAILED;
180   }
181 
182   if (w_strategy > 1 && (kernel_size_[1] > stride_[3])) {
183     if (inputs_shape_[0][3] % stride_[3] != 0) {
184       MS_LOG(ERROR) << name_
185                     << ": The 'same' mode does not support splitting W when kernel_size > stride but the w shape is "
186                        "not divisible by the stride";
187       return FAILED;
188     }
189 
190     if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
191       MS_LOG(ERROR) << name_
192                     << ": The 'same' mode does not support splitting W when kernel_size > stride but the w slice "
193                        "shape is smaller than or equal to (k - s + 1) / 2";
194       return FAILED;
195     }
196 
197     if (kernel_size_[1] - stride_[3] == 1) {
198       MS_LOG(ERROR) << name_ << ": The 'same' mode does not support splitting W when kernel_size > stride but k - s == 1";
199       return FAILED;
200     }
201   }
202 
203   return SUCCESS;
204 }
205 
206 Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy) {
207   int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
208   int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
209 
210   if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) {
211     MS_LOG(ERROR) << name_ << ": The 'valid' mode does not support splitting H or W when kernel_size > stride";
212     return FAILED;
213   }
214 
215   if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
216     MS_LOG(ERROR) << name_
217                   << ": The 'valid' mode does not support splitting H when kernel_size <= stride but the slice shape "
218                      "is not divisible by the stride";
219     return FAILED;
220   }
221 
222   if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
223     MS_LOG(ERROR) << name_
224                   << ": The 'valid' mode does not support splitting W when kernel_size <= stride but the slice shape "
225                      "is not divisible by the stride";
226     return FAILED;
227   }
228 
229   return SUCCESS;
230 }
231 
232 Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
233   if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
234     return FAILED;
235   }
236 
237   if (pad_mode_ == 0) {  // 'pad' mode
238     MS_LOG(ERROR) << name_ << ": The 'pad' mode does not support splitting H or W";
239     return FAILED;
240   }
241 
242   if (pad_mode_ == 1) {  // 'same' mode
243     return CheckHWStrategySameMode(h_strategy, w_strategy);
244   }
245 
246   if (pad_mode_ == 2) {  // 'valid' mode
247     return CheckHWStrategyValidMode(h_strategy, w_strategy);
248   }
249 
250   return SUCCESS;
251 }
252 
253 Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
254   MS_EXCEPTION_IF_NULL(strategy);
255   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
256     MS_LOG(ERROR) << name_ << ": Invalid strategy";
257     return FAILED;
258   }
259 
260   std::vector<Dimensions> stra = strategy->GetInputDim();
261   if (stra.size() != 2) {
262     MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
263     return FAILED;
264   }
265 
266   Dimensions input_strategy = stra[0];
267   Dimensions weight_strategy = stra[1];
268   if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
269     MS_LOG(ERROR) << name_
270                   << ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
271                   << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
272     return FAILED;
273   }
274 
275   if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
276     MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
277                   << weight_strategy[2] << ", " << weight_strategy[3] << ")";
278     return FAILED;
279   }
280 
281   if (weight_strategy[0] > 1) {
282     out_channel_shard_ = true;
283     new_out_channel_ = out_channel_ / weight_strategy[0];
284   } else {
285     out_channel_shard_ = false;
286     new_out_channel_ = out_channel_;
287   }
288 
289   int64_t input_except_n_shards =
290     std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
291   int64_t weight_shards =
292     std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());
293 
294   bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
295   if (!is_data_parallel) {
296     if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
297       MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
298       return FAILED;
299     }
300 
301     if (group_ != 1) {
302       MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
303       return FAILED;
304     }
305   }
306   return SUCCESS;
307 }
308 
309 Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
310   need_exchange_overlap_ = false;
311   if (CheckStrategyBase(strategy) != SUCCESS) {
312     return FAILED;
313   }
314 
315   std::vector<Dimensions> stra = strategy->GetInputDim();
316   Dimensions input_strategy = stra[0];
317   Dimensions weight_strategy = stra[1];
318   if (input_strategy[1] != weight_strategy[1]) {
319     MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
320                   << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
321     return FAILED;
322   }
323 
324   if (input_strategy[2] != 1 || input_strategy[3] != 1) {
325     if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
326       return FAILED;
327     }
328   }
329 
330   // if the kernel size is larger than the stride and the w dimension is split, overlap needs to be exchanged
331   if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
332     need_exchange_overlap_ = true;
333   }
334 
335   return SUCCESS;
336 }
337 
338 Status Conv2DInfo::InferDevMatrixShape() {
339   // the strategy is ((n, i, h, w), (o, i, 1, 1))
340   // the dev matrix is (n, i, h, w, o)
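  // e.g., strategy ((2, 2, 1, 4), (2, 2, 1, 1)) gives dev_matrix_shape_ = (2, 2, 1, 4, 2) and
  // w_dimension_shard_num_ = 4 (illustrative numbers only)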
341   MS_EXCEPTION_IF_NULL(strategy_);
342   std::vector<Dimensions> stra = strategy_->GetInputDim();
343   if (stra.size() != 2) {
344     MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
345     return FAILED;
346   }
347 
348   dev_matrix_shape_ = stra[0];
349   dev_matrix_shape_.push_back(stra[1][0]);
350   w_dimension_shard_num_ = stra[0][3];
351   input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
352   return SUCCESS;
353 }
354 
355 Status Conv2DInfo::InferRankBias() {
356   // the Conv2D operator:
357   // the origin dev_matrix is [n, i, h, w, o]
358   // if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
359   // if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
360   // repeated_num]
361   //
362   // the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
363   // Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
364   // dimension)
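  //
  // e.g. (illustration): with 4 devices (global ranks 0..3) and dev_matrix (1, 1, 1, 4, 1), the w-dimension
  // group is [0, 1, 2, 3]; for global rank 2 this yields rank_bias_ = 2, left_rank_id_ = 1, right_rank_id_ = 3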
365   if (!need_exchange_overlap_) {
366     MS_LOG(INFO) << name_ << ": No need to infer rank bias";
367     return SUCCESS;
368   }
369 
370   uint64_t w_index_in_dev_matrix = 3;
371   if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
372     w_index_in_dev_matrix += 1;
373   }
374 
375   CheckGlobalDeviceManager();
376   int64_t rank = g_device_manager->global_rank();
377   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
378   RankList group_devices;
379   if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
380     return FAILED;
381   }
382 
383   if (group_devices.size() <= 1) {
384     MS_LOG(INFO) << name_ << ": The number of devices in the w dimension is " << group_devices.size()
385                  << ", no need to infer rank bias";
386     return SUCCESS;
387   }
388 
389   if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
390     MS_LOG(ERROR) << name_ << ": The number of devices in the w dimension is " << group_devices.size()
391                   << ", but the shard num of the w dimension is " << w_dimension_shard_num_;
392     return FAILED;
393   }
394 
395   std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
396   if (it == group_devices.end()) {
397     MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
398                   << rank << ", the device list is " << group_devices;
399     return FAILED;
400   }
401 
402   rank_bias_ = std::distance(group_devices.begin(), it);
403   if (it == group_devices.begin()) {
404     left_rank_bias_ = -1;
405     right_rank_bias_ = rank_bias_ + 1;
406 
407     left_rank_id_ = -1;
408     right_rank_id_ = *(it + 1);
409   } else if (it == group_devices.end() - 1) {
410     left_rank_bias_ = rank_bias_ - 1;
411     right_rank_bias_ = -1;
412 
413     left_rank_id_ = *(it - 1);
414     right_rank_id_ = -1;
415   } else {
416     left_rank_bias_ = rank_bias_ - 1;
417     right_rank_bias_ = rank_bias_ + 1;
418 
419     left_rank_id_ = *(it - 1);
420     right_rank_id_ = *(it + 1);
421   }
422   MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
423                << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
424                << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
425                << ", the right rank id is " << right_rank_id_;
426   return SUCCESS;
427 }
428 
429 int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
430   int64_t left_pad = pad_list_[2];
431   int64_t w_dimension_input_shape = inputs_shape_[0][3];
432   int64_t w_dimension_output_shape = outputs_shape_[0][3];
433   int64_t w_stride = stride_[3];
434 
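  // e.g. (illustration): for a 'same' mode conv with w_in = w_out = 16, kernel = 3, stride = 1, left pad = 1
  // and 4 w-shards, this evaluates to 1 for every rank_bias, i.e. each slice needs one extra column on its
  // left (rank 0 gets that column from the pad instead of a neighbour)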
435   return left_pad +
436          (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias / w_dimension_shard_num_;
437 }
438 
439 int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
440   int64_t left_pad = pad_list_[2];
441   int64_t w_dimension_input_shape = inputs_shape_[0][3];
442   int64_t w_dimension_output_shape = outputs_shape_[0][3];
443   int64_t w_kernel_size = kernel_size_[1];
444   int64_t w_stride = stride_[3];
445 
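  // e.g. (illustration): for the same case (w_in = w_out = 16, kernel = 3, stride = 1, left pad = 1,
  // 4 w-shards), this evaluates to 0 + 3 - 1 - 1 = 1, i.e. one extra column is needed on the right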
446   return (rank_bias + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
447          w_kernel_size - w_stride - left_pad;
448 }
449 
450 void Conv2DInfo::InferOverlapSize() {
451   if (!need_exchange_overlap_) {
452     MS_LOG(INFO) << name_ << ": No need to infer overlap size";
453     return;
454   }
455 
456   overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_);
457   overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_);
458 
459   if (rank_bias_ == 0) {  // it has no left rank
460     left_rank_overlap_left_size_ = 0;
461     left_rank_overlap_right_size_ = 0;
462     right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
463     right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
464   } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // it has no right rank
465     left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
466     left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
467     right_rank_overlap_left_size_ = 0;
468     right_rank_overlap_right_size_ = 0;
469   } else {  // it has left rank and right rank
470     left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
471     left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
472     right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
473     right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
474   }
475 
476   MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
477                << ", the right overlap size of current rank is " << overlap_right_size_
478                << ", the left overlap size of left rank is " << left_rank_overlap_left_size_
479                << ", the right overlap size of left rank is " << left_rank_overlap_right_size_
480                << ", the left overlap size of right rank is " << right_rank_overlap_left_size_
481                << ", the right overlap size of right rank is " << right_rank_overlap_right_size_;
482 }
483 
484 Status Conv2DInfo::InferTensorMap() {
485   // input_strategy: ((n, i, h, w), (o, i, 1, 1))
486   // output_strategy: ((n, o, h, w),)
487   // dev_matrix: (n, i, h, w, o)
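  // a tensor-map value indexes the dev matrix from its rightmost dimension (0 -> o, 1 -> w, 2 -> h,
  // 3 -> i, 4 -> n), and -1 means the corresponding tensor dimension is not split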
488   TensorMap input_tensor_map = {4, 3, 2, 1};
489   TensorMap weight_tensor_map = {0, 3, -1, -1};
490   TensorMap output_tensor_map = {4, 0, 2, 1};
491 
492   (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
493   (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
494   (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
495   return SUCCESS;
496 }
497 
498 // Conv2d: dev_matrix is (n, i, h, w, o); if the in-channel is split, an all-reduce needs to be inserted
499 // Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i); if the out-channel is split, an all-reduce needs to be inserted
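// e.g. (illustration): with strategy ((2, 4, 1, 1), (1, 4, 1, 1)) the c-in dimension has 4 slices, so an
// AllReduce(SUM) over the corresponding 4-device group is appended to forward_op_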
500 Status Conv2DInfo::InferForwardCommunication() {
501   forward_op_.clear();
502   size_t relevant_dim_index = IN_CHANNEL_INDEX;
503   if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
504     // if repeated calculation is used and the repeated num is on the left of the dev matrix, the relevant dimension index is increased by 1
505     relevant_dim_index += 1;
506   }
507 
508   if (dev_matrix_shape_[relevant_dim_index] == MIN_SLICE_NUM) {
509     MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
510     return SUCCESS;
511   }
512 
513   std::vector<Group> group_list;
514   if (CreateGroupByDim(relevant_dim_index, &group_list) != SUCCESS) {
515     MS_LOG(ERROR) << name_ << ": Create group failed";
516     return FAILED;
517   }
518 
519   if (group_list.empty()) {
520     MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
521     return SUCCESS;
522   }
523 
524   Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
525   forward_op_.push_back(op);
526   MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();
527 
528   return SUCCESS;
529 }
530 
531 void Conv2DInfo::InferNewPadList() {
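  // e.g., with pad_list (1, 1, 1, 1) and 4 w-shards: rank 0 keeps (1, 1, 1, 0), the last rank keeps
  // (1, 1, 0, 1), and the middle ranks keep (1, 1, 0, 0)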
532   new_pad_list_ = pad_list_;
533   if (rank_bias_ == 0) {                                  // the first rank
534     new_pad_list_[3] = 0;                                 // the right pad is not needed
535   } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
536     new_pad_list_[2] = 0;                                 // the left pad is not needed
537   } else {                                                // the middle rank
538     new_pad_list_[2] = 0;                                 // the left pad is not needed
539     new_pad_list_[3] = 0;                                 // the right pad is not needed
540   }
541   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
542 }
543 
544 void Conv2DInfo::InferSendRecvFlag() {
545   if (rank_bias_ == 0) {  // the first rank
546     left_need_send_ = false;
547     left_need_recv_ = false;
548     right_need_send_ = (right_rank_overlap_left_size_ > 0);
549     right_need_recv_ = (overlap_right_size_ > 0);
550   } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
551     left_need_send_ = (left_rank_overlap_right_size_ > 0);
552     left_need_recv_ = (overlap_left_size_ > 0);
553     right_need_send_ = false;
554     right_need_recv_ = false;
555   } else {  // the middle rank
556     left_need_send_ = (left_rank_overlap_right_size_ > 0);
557     left_need_recv_ = (overlap_left_size_ > 0);
558     right_need_send_ = (right_rank_overlap_left_size_ > 0);
559     right_need_recv_ = (overlap_right_size_ > 0);
560   }
561   MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
562                << left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
563                << right_need_recv_;
564 
565   if (left_need_send_) {
566     if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
567       MS_LOG(EXCEPTION) << name_ << ": Do not support a left overlap size (" << left_rank_overlap_right_size_
568                         << ") larger than or equal to the slice shape in the w dimension (" << input_slice_shape_[3] << ")";
569     }
570     send_rank_ids_.push_back(left_rank_id_);
571   }
572 
573   if (right_need_send_) {
574     if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
575       MS_LOG(EXCEPTION) << name_ << ": Do not support a right overlap size (" << right_rank_overlap_left_size_
576                         << ") larger than or equal to the slice shape in the w dimension (" << input_slice_shape_[3] << ")";
577     }
578     send_rank_ids_.push_back(right_rank_id_);
579   }
580 
581   if (left_need_recv_) {
582     recv_rank_ids_.push_back(left_rank_id_);
583   }
584 
585   if (right_need_recv_) {
586     recv_rank_ids_.push_back(right_rank_id_);
587   }
588 
589   MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
590 }
591 
592 void Conv2DInfo::InferOverlapShapes() {
593   if (left_need_recv_) {
594     Shape left_recv_shape = input_slice_shape_;
595     left_recv_shape[3] = overlap_left_size_;
596     recv_shapes_.push_back(left_recv_shape);
597   }
598 
599   if (right_need_recv_) {
600     Shape right_recv_shape = input_slice_shape_;
601     right_recv_shape[3] = overlap_right_size_;
602     recv_shapes_.push_back(right_recv_shape);
603   }
604 
605   if (left_need_send_) {
606     Shape left_send_shape = input_slice_shape_;
607     left_send_shape[3] = left_rank_overlap_right_size_;
608     send_shapes_.push_back(left_send_shape);
609   }
610 
611   if (right_need_send_) {
612     Shape right_send_shape = input_slice_shape_;
613     right_send_shape[3] = right_rank_overlap_left_size_;
614     send_shapes_.push_back(right_send_shape);
615   }
616   MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
617 }
618 
619 void Conv2DInfo::InferStridedSliceAttrs() {
620   if (left_need_send_) {
621     left_strided_slice_begin_ = {0, 0, 0, 0};
622     left_strided_slice_end_ = input_slice_shape_;
623     left_strided_slice_end_[3] = left_rank_overlap_right_size_;
624     left_strided_slice_strides_ = {1, 1, 1, 1};
625     MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
626                  << left_strided_slice_end_;
627   }
628 
629   if (right_need_send_) {
630     right_strided_slice_begin_ = {0, 0, 0, 0};
631     right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
632     right_strided_slice_end_ = input_slice_shape_;
633     right_strided_slice_strides_ = {1, 1, 1, 1};
634     MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
635                  << right_strided_slice_end_;
636   }
637 }
638 
639 void Conv2DInfo::InferNewOperatorAttrs() {
640   InferNewPadList();
641 
642   InferSendRecvFlag();
643 
644   InferOverlapShapes();
645 
646   InferStridedSliceAttrs();
647 }
648 
649 OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
650   auto type = cnode->Type();
651   MS_EXCEPTION_IF_NULL(type);
652   auto tensor_type = type->cast<mindspore::TensorTypePtr>();
653   MS_EXCEPTION_IF_NULL(tensor_type);
654   auto dtype = tensor_type->element();
655   MS_EXCEPTION_IF_NULL(dtype);
656   Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
657   Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};
658   Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
659   Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
660   Attr recv_type = {RECV_TYPE, dtype};
661   OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
662   return attrs;
663 }
664 
665 OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
666   Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
667   Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
668   Attr mode = {MODE, MakeValue(mode_)};
669   Attr pad_mode = {PAD_MODE, MakeValue("pad")};
670   Attr pad = {PAD, MakeValue(new_pad_list_)};
671   Attr stride = {STRIDE, MakeValue(stride_)};
672   Attr dilation = {DILATION, MakeValue(dilation_)};
673   Attr group = {GROUP, MakeValue(group_)};
674   Attr data_format = {DATA_FORMAT, MakeValue(format_)};
675 
676   OperatorAttrs attrs;
677   if (name_.find(CONV2D_INFO) != std::string::npos) {
678     attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
679   } else {  // Conv2DTranspose
680     attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};
681   }
682 
683   return attrs;
684 }
685 
686 std::string Conv2DInfo::ReplaceNodeName() const {
687   if (name_.find(CONV2D_INFO) != std::string::npos) {
688     return CONV2D;
689   }
690 
691   if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) {
692     return CONV2D_BACK_PROP_INPUT;
693   }
694 
695   if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) {
696     return CONV2D_TRANSPOSE;
697   }
698 
699   MS_LOG(EXCEPTION) << "Invalid name: " << name_;
700 }
701 
702 AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
703   auto conv2d_attrs = CreateConv2DAttrs();
704   auto node_name = ReplaceNodeName();
705 
706   // conv2d
707   if (name_.find(CONV2D_INFO) != std::string::npos) {
708     if (cnode->size() < 3) {
709       MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
710     }
711     return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2)});
712   }
713 
714   // conv2dtranspose
715   if (cnode->size() < 4) {
716     MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
717   }
718   return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2), cnode->input(3)});
719 }
720 
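// A sketch of the replacement subgraph built below: the local input is StridedSlice-d to extract the halo
// columns for each neighbor, the slices are packed by MakeTuple and exchanged via NeighborExchange, and the
// received halos are Concat-ed (axis -1) onto the local input before it is fed into a new
// Conv2D/Conv2DTranspose that runs in 'pad' mode with new_pad_list_.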
721 void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
722   auto graph = cnode->func_graph();
723   MS_EXCEPTION_IF_NULL(graph);
724 
725   if (gen_g_.Init(cnode) != SUCCESS) {
726     MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
727   }
728 
729   if (!left_need_send_ && !right_need_send_) {
730     MS_LOG(EXCEPTION) << name_ << ": The case where neither the left nor the right rank needs to send is not supported yet";
731   }
732 
733   if (!left_need_recv_ && !right_need_recv_) {
734     MS_LOG(EXCEPTION) << name_ << ": The case where neither the left nor the right rank needs to recv is not supported yet";
735   }
736 
737   std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
738   std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
739   if (left_need_send_) {
740     auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
741     auto slice_left_end = CreateTuple(left_strided_slice_end_);
742     auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
743     auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
744                                        slice_left_end, slice_left_strided});
745     make_tuple_a_inputs.push_back(slice_left);
746     input_nodes.push_back(std::make_pair(slice_left, 1));
747   }
748   if (right_need_send_) {
749     auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
750     auto slice_right_end = CreateTuple(right_strided_slice_end_);
751     auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
752     auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin,
753                                         slice_right_end, slice_right_strided});
754     make_tuple_a_inputs.push_back(slice_right);
755     input_nodes.push_back(std::make_pair(slice_right, 1));
756   }
757 
758   auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
759   auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
760   auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
761 
762   AnfNodePtr conv2d;
763   Attr concat_axis = {AXIS, MakeValue(-1)};
764   OperatorAttrs concat_attrs = {concat_axis};
765 
766   if (left_need_recv_) {
767     std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
768                                                       CreatInt64Imm(0)};
769     auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
770     std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
771                                                    cnode->input(1)};
772     auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
773     auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
774 
775     if (right_need_recv_) {
776       std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
777                                                         CreatInt64Imm(1)};
778       auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
779       std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
780       auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
781       auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
782       conv2d = GenerateConv2DNode(concat_r, cnode);
783     } else {
784       conv2d = GenerateConv2DNode(concat_l, cnode);
785     }
786   } else {  // left no need recv, and right need recv
787     std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
788                                                         CreatInt64Imm(0)};
789     auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
790     std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
791                                                      tuple_getitem_r_1};
792     auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
793     input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
794 
795     auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
796     conv2d = GenerateConv2DNode(concat_r_1, cnode);
797   }
798 
799   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
800     std::make_pair(input_nodes, conv2d));
801 }
802 
803 ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
804   if (!need_exchange_overlap_) {
805     if (!out_channel_shard_) {
806       return nullptr;
807     }
808     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
809     prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
810     return nullptr;
811   }
812 
813   if (InferRankBias() != SUCCESS) {
814     return nullptr;
815   }
816 
817   InferOverlapSize();
818 
819   InferNewOperatorAttrs();
820 
821   ComputeReplaceGraph(cnode);
822   return replace_graph_;
823 }
824 
825 void Conv2DInfo::ReComputeBatchSplitFlagList() {
826   split_flag_list_[0] = true;
827   split_flag_list_[1] = false;
828 }
829 
830 Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
831 
832 std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
833   Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
834   StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
835   std::vector<StrategyPtr> sp_vector;
836   sp_vector.push_back(sp);
837   return sp_vector;
838 }
839 
840 Status Conv2DInfo::Init(const StrategyPtr &strategy) {
841   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
842     MS_LOG(ERROR) << name_ << ": Init failed.";
843     return FAILED;
844   }
845   MS_LOG(INFO) << name_ << ": Init success.";
846   return SUCCESS;
847 }
848 
849 Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
850   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
851     MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
852     return FAILED;
853   }
854 
855   MS_LOG(INFO) << name_ << ": Init for cost model success.";
856   return SUCCESS;
857 }
858 
859 Status Conv2DBackpropInputInfo::GetOutShape() {
860   if (input_value_.size() != 3) {
861     MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();
862     return FAILED;
863   }
864 
865   if (input_value_[2] == nullptr) {
866     MS_LOG(ERROR) << name_ << ": The input_value_[2] is nullptr";
867     return FAILED;
868   }
869 
870   std::vector<ValuePtr> elements;
871   auto value_tuple = input_value_[2]->cast<ValueTuplePtr>();
872   if (value_tuple == nullptr) {
873     MS_LOG(ERROR) << name_ << ": Input_value_[2] must be ValueTuplePtr.";
874     return FAILED;
875   }
876   elements = value_tuple->value();
877   if (elements.size() != 4) {
878     MS_LOG(ERROR) << name_ << ": Elements size must be 4, but got " << elements.size();
879     return FAILED;
880   }
881 
882   for (auto &element : elements) {
883     MS_EXCEPTION_IF_NULL(element);
884     if (element->isa<Int64Imm>()) {
885       int64_t ele_value = element->cast<Int64ImmPtr>()->value();
886       out_shape_.push_back(ele_value);
887     } else {
888       MS_LOG(ERROR) << name_ << ": The value of shape must be int";
889       return FAILED;
890     }
891   }
892 
893   return SUCCESS;
894 }
895 
896 Status Conv2DBackpropInputInfo::GetAttrs() {
897   if (GetAttrsBase() != SUCCESS) {
898     return FAILED;
899   }
900 
901   return GetOutShape();
902 }
903 
904 Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
905   need_exchange_overlap_ = false;
906   if (CheckStrategyBase(strategy) != SUCCESS) {
907     return FAILED;
908   }
909 
910   std::vector<Dimensions> stra = strategy->GetInputDim();
911   Dimensions input_strategy = stra[0];
912   Dimensions weight_strategy = stra[1];
913   if (input_strategy[1] != weight_strategy[0]) {
914     MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
915                   << ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
916     return FAILED;
917   }
918 
919   if (input_strategy[2] != 1 || input_strategy[3] != 1) {
920     if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
921       return FAILED;
922     }
923   }
924 
925   // if the kernel size is larger than the stride and the w dimension is split, overlap needs to be exchanged
926   if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
927     need_exchange_overlap_ = true;
928   }
929   return SUCCESS;
930 }
931 
932 Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
933   if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
934     return FAILED;
935   }
936 
937   if (pad_mode_ != 1) {  // only the 'same' mode is supported
938     MS_LOG(ERROR) << name_ << ": The pad mode " << pad_mode_ << " is not supported when splitting the H or W dimension";
939     return FAILED;
940   }
941 
942   if (h_strategy > 1) {
943     MS_LOG(ERROR) << name_ << ": Splitting the h dimension is not supported";
944     return FAILED;
945   }
946 
947   if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
948     MS_LOG(ERROR) << name_ << ": Splitting the w dimension is not supported when in_shape * stride != out_shape";
949     return FAILED;
950   }
951 
952   return SUCCESS;
953 }
954 
955 Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
956   // the strategy is ((n, o, h, w), (o, i, 1, 1))
957   // the dev matrix is (n, o, h, w, i)
958   MS_EXCEPTION_IF_NULL(strategy_);
959   std::vector<Dimensions> stra = strategy_->GetInputDim();
960   if (stra.size() != 2) {
961     MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
962     return FAILED;
963   }
964 
965   dev_matrix_shape_ = stra[0];
966   dev_matrix_shape_.push_back(stra[1][1]);
967 
968   Shape out_strategy = stra[0];
969   out_strategy[1] = stra[1][1];
970 
971   out_slice_shape_ = out_shape_;
972   if (out_shape_.size() != out_strategy.size()) {
973     MS_LOG(ERROR) << name_ << ": The size of out shape is " << out_shape_.size()
974                   << ", but the size of output strategy is " << out_strategy.size();
975     return FAILED;
976   }
977 
978   for (size_t i = 0; i < out_slice_shape_.size(); ++i) {
979     if (out_slice_shape_[i] % out_strategy[i] != 0) {
980       MS_LOG(ERROR) << name_ << ": The output can not be split by strategy. The shape of output is " << out_slice_shape_
981                     << ", but the strategy of output is " << out_strategy;
982       return FAILED;
983     }
984     out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
985   }
986 
987   w_dimension_shard_num_ = stra[0][3];
988   input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
989   MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
990   return SUCCESS;
991 }
992 
993 Status Conv2DBackpropInputInfo::InferTensorMap() {
994   // input_strategy: ((n, o, h, w), (o, i, 1, 1))
995   // output_strategy: ((n, i, h, w),)
996   // dev_matrix: (n, o, h, w, i)
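  // here the dev matrix is (n, o, h, w, i), so a tensor-map value of 0 refers to i and 3 refers to o;
  // -1 again means the tensor dimension is not split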
997   TensorMap input_tensor_map = {4, 3, 2, 1};
998   TensorMap weight_tensor_map = {3, 0, -1, -1};
999   TensorMap output_tensor_map = {4, 0, 2, 1};
1000 
1001   (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
1002   (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
1003   (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
1004   return SUCCESS;
1005 }
1006 
1007 Status Conv2DBackpropInputInfo::InferMirrorOps() {
1008   mirror_ops_.clear();
1009   if (inputs_shape_.empty()) {
1010     MS_LOG(INFO) << name_ << ": The inputs size is empty";
1011     return SUCCESS;
1012   }
1013 
1014   if (inputs_tensor_map_.size() != inputs_shape_.size()) {
1015     MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
1016     return FAILED;
1017   }
1018 
1019   bool group_is_empty = true;
1020   for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
1021     std::vector<Group> group;
1022     if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
1023       MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
1024       mirror_ops_.clear();
1025       return FAILED;
1026     }
1027 
1028     OperatorVector mirror_op;
1029     if (group.empty()) {
1030       MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
1031       mirror_ops_.push_back(mirror_op);
1032       continue;
1033     }
1034 
1035     group_is_empty = false;
1036     mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
1037     mirror_ops_.push_back(mirror_op);
1038   }
1039 
1040   if (group_is_empty) {
1041     mirror_ops_.clear();
1042     MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
1043     return SUCCESS;
1044   }
1045 
1046   OperatorVector tmp_mirror_op;  // tmp mirror op for 'out_shape'
1047   mirror_ops_.push_back(tmp_mirror_op);
1048   return SUCCESS;
1049 }
1050 
1051 void Conv2DBackpropInputInfo::UpdateOutShape() {
1052   auto cnode = cnode_;
1053   MS_EXCEPTION_IF_NULL(cnode);
1054   if (cnode->size() != 4) {
1055     MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
1056   }
1057 
1058   if (!IsValueNode<ValueTuple>(cnode->input(3))) {
1059     MS_LOG(EXCEPTION) << name_ << ": The cnode's input[3] is not value node";
1060   }
1061 
1062   auto func_graph = cnode->func_graph();
1063   MS_EXCEPTION_IF_NULL(func_graph);
1064   auto manager = func_graph->manager();
1065   MS_EXCEPTION_IF_NULL(manager);
1066 
1067   ValuePtr out_shape = MakeValue(out_slice_shape_);
1068   AnfNodePtr val = NewValueNode(out_shape);
1069   (void)manager->Replace(cnode->input(3), val);
1070   MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
1071 }
1072 
1073 int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
1074   // 1. the first rank: 0
1075   // 2. the last rank:
1076   //    size of origin data required by current rank: a = ceil((o/n + k - o + w*s - s - x)/s)
1077   //    data size of the current rank: b = w/n
1078   //    return a - b = ceil((o/n + k - o + w*s - s - x)/s) - w/n
1079   // 3. the middle rank:
1080   //    r*w/n - ceil((r*o/n - k + x + 1)/s)
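  //
  // (symbols: o = output W, w = input W, k = kernel size, s = stride, x = left pad, n = w shard num, r = rank bias)
  // e.g. (illustration): with w = 8, o = 16, k = 4, s = 2, x = 1 and n = 4, every rank except the first
  // needs a left halo of 1 column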
1081   if (rank_bias == 0) {  // the first rank
1082     return 0;
1083   }
1084 
1085   int64_t w_output_shape = outputs_shape_[0][3];
1086   int64_t w_input_shape = inputs_shape_[0][3];
1087   int64_t w_kernel_size = kernel_size_[1];
1088   int64_t w_stride = stride_[3];
1089   int64_t left_pad = pad_list_[2];
1090   if (rank_bias == w_dimension_shard_num_ - 1) {  // the last rank
1091     return DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size -
1092                                                w_output_shape + w_input_shape * w_stride - w_stride - left_pad) /
1093                                   LongToDouble(w_stride))) -
1094            w_input_shape / w_dimension_shard_num_;
1095   }
1096 
1097   // the middle rank
1098   return rank_bias * w_input_shape / w_dimension_shard_num_ -
1099          DoubleToLong(
1100            std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
1101                      LongToDouble(w_stride)));
1102 }
1103 
1104 int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
1105   // 1. the first rank: ceil((o/n + x)/s) - w/n
1106   // 2. the last rank: 0
1107   // 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*w/n - w/n
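  //
  // e.g. (illustration): with w = 8, o = 16, k = 4, s = 2, x = 1 and n = 4 (as above), every rank except
  // the last needs a right halo of 1 column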
1108   int64_t w_output_shape = outputs_shape_[0][3];
1109   int64_t w_input_shape = inputs_shape_[0][3];
1110   int64_t w_stride = stride_[3];
1111   int64_t left_pad = pad_list_[2];
1112 
1113   if (rank_bias == 0) {  // the first rank
1114     return DoubleToLong(
1115              std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride))) -
1116            w_input_shape / w_dimension_shard_num_;
1117   }
1118 
1119   if (rank_bias == w_dimension_shard_num_ - 1) {  // the last rank
1120     return 0;
1121   }
1122 
1123   // the middle rank
1124   return DoubleToLong(std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ +
1125                                              w_output_shape / w_dimension_shard_num_ + left_pad) /
1126                                 LongToDouble(w_stride))) -
1127          (rank_bias + 1) * w_input_shape / w_dimension_shard_num_;
1128 }
1129 
1130 void Conv2DBackpropInputInfo::InferNewPadList() {
1131   // 1. compute the size of origin data required by current rank:
1132   //    1) the first rank: ceil((o/n + x) / s)
1133   //    2) the last rank: ceil((o/n + k - o + w*s - s - x) / s)
1134   //    3) the middle rank: ceil((r*o/n + o/n + x) / s) - ceil((r*o/n - k + x + 1) / s)
1135   //
1136   // 2. compute the real left pad
1137   //    1) the first rank: k - x - 1
1138   //    2) the last rank:
1139   //       if (o/n + k - o + w*s - s - x) is divisible by s, real_left_pad = s - 1.
1140   //       otherwise, real_left_pad = (o/n + k - o + w*s - s - x) % s - 1
1141   //    3) the middle rank:
1142   //       if (r*o/n - k + x + 1) is divisible by s, real_left_pad = 0.
1143   //       otherwise, real_left_pad = s - (r*o/n - k + x + 1) % s
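  //
  // e.g. (illustration): with w = 8, o = 16, k = 4, s = 2, x = 1 and n = 4 (the example above), the resulting
  // (new left pad, new right pad) is (1, 3) on rank 0, (3, 3) on the middle ranks and (3, 1) on the last rank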
1144   int64_t w_output_shape = outputs_shape_[0][3];
1145   int64_t w_input_shape = inputs_shape_[0][3];
1146   int64_t w_kernel_size = kernel_size_[1];
1147   int64_t w_stride = stride_[3];
1148   int64_t left_pad = pad_list_[2];
1149   int64_t current_rank_required_size = 0;
1150   int64_t real_left_pad = 0;
1151 
1152   if (rank_bias_ == 0) {  // the first rank
1153     current_rank_required_size = DoubleToLong(
1154       std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride)));
1155 
1156     real_left_pad = w_kernel_size - left_pad - 1;
1157   } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
1158     current_rank_required_size =
1159       DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape +
1160                                           w_input_shape * w_stride - w_stride - left_pad) /
1161                              LongToDouble(w_stride)));
1162 
1163     int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride -
1164                   w_stride - left_pad;
1165     if (tmp % w_stride == 0) {
1166       real_left_pad = w_stride - 1;
1167     } else {
1168       real_left_pad = tmp % w_stride - 1;
1169     }
1170   } else {  // the middle rank
1171     current_rank_required_size =
1172       DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ +
1173                                           w_output_shape / w_dimension_shard_num_ + left_pad) /
1174                              LongToDouble(w_stride))) -
1175       DoubleToLong(
1176         std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
1177                   LongToDouble(w_stride)));
1178 
1179     int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1;
1180     if (tmp % w_stride == 0) {
1181       real_left_pad = 0;
1182     } else {
1183       real_left_pad = w_stride - tmp % w_stride;
1184     }
1185   }
1186 
1187   // 3. compute the pad_all: (current_rank_required_size - 1) * s + k - o/n
1188   int64_t pad_all =
1189     (current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_;
1190 
1191   // 4. compute new left pad: k - real_left_pad - 1
1192   new_pad_list_ = pad_list_;
1193   new_pad_list_[2] = w_kernel_size - real_left_pad - 1;
1194 
1195   // 5. compute new right pad: pad_all - new_left_pad
1196   new_pad_list_[3] = pad_all - new_pad_list_[2];
1197 
1198   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
1199                << current_rank_required_size << ", new pad all is " << pad_all;
1200 }
1201 
1202 void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
1203 }  // namespace parallel
1204 }  // namespace mindspore
1205