1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/conv2d_info.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <cmath>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25
26 #include "frontend/parallel/device_matrix.h"
27 #include "frontend/parallel/strategy.h"
28 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
29 #include "frontend/parallel/graph_util/generate_graph.h"
30 #include "pipeline/jit/resource.h"
31
32 namespace mindspore {
33 namespace parallel {
34 namespace {
MakeListValue(const std::vector<int64_t> & v)35 ValuePtr MakeListValue(const std::vector<int64_t> &v) {
36 std::vector<ValuePtr> list;
37 (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
38 return std::make_shared<ValueSequeue>(list);
39 }
40
MakeTupleListValue(const Shapes & v)41 ValuePtr MakeTupleListValue(const Shapes &v) {
42 std::vector<ValuePtr> tuple;
43 (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
44 [](const std::vector<int64_t> &list) { return MakeListValue(list); });
45 return std::make_shared<ValueTuple>(tuple);
46 }
47 } // namespace
// Parse and validate the attributes shared by Conv2D and its variants.
// Fills format_, out_channel_, kernel_size_, mode_, pad_mode_, pad_list_,
// stride_, dilation_ and group_. Returns FAILED (with an error log) on the
// first invalid attribute, SUCCESS otherwise.
Status Conv2DInfo::GetAttrsBase() {
  // format: only the NCHW layout is supported by this parallel implementation
  format_ = GetStringAttr(FORMAT);
  if (format_ != NCHW) {
    MS_LOG(ERROR) << name_ << ": The format must be 'NCHW', but got " << format_;
    return FAILED;
  }

  // out_channel: must be a positive integer
  out_channel_ = GetIntAttr(OUT_CHANNEL);
  if (out_channel_ <= 0) {
    MS_LOG(ERROR) << name_ << ": The attr of out_channel is invalid";
    return FAILED;
  }

  // kernel_size: accepted either as a single int (expanded to (k, k))
  // or as a 2-element tuple/list (kh, kw)
  auto kernel_size_iter = attrs_.find(KERNEL_SIZE);
  if (kernel_size_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the attribution of " << KERNEL_SIZE;
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(kernel_size_iter->second);
  if (kernel_size_iter->second->isa<Int64Imm>()) {
    int64_t kernel_size = kernel_size_iter->second->cast<Int64ImmPtr>()->value();
    kernel_size_ = {kernel_size, kernel_size};
  } else if (kernel_size_iter->second->isa<ValueTuple>() || kernel_size_iter->second->isa<ValueList>()) {
    kernel_size_ = GetValue<std::vector<int64_t>>(kernel_size_iter->second);
    if (kernel_size_.size() != 2) {
      MS_LOG(ERROR) << name_ << ": The size of kernel_size'tuple must be 2, but got " << kernel_size_.size();
      return FAILED;
    }
  } else {
    MS_LOG(ERROR) << name_ << ": The kernel_size must be int or tuple";
    return FAILED;
  }

  // mode: only mode 1 is supported
  mode_ = GetIntAttr(MODE);
  if (mode_ != 1) {
    MS_LOG(ERROR) << name_ << ": The mode must be 1, but got " << mode_;
    return FAILED;
  }

  // pad_mode: 0/1/2 correspond to 'pad'/'same'/'valid' (see CheckHWStrategy)
  pad_mode_ = GetIntAttr(PAD_MODE);
  if (pad_mode_ < 0 || pad_mode_ > 2) {
    MS_LOG(ERROR) << name_ << ": The pad_mode must be in the range of [0, 2], but got " << pad_mode_;
    return FAILED;
  }

  // pad_list: 4 elements; index 2 is used later as the left (W) pad
  pad_list_ = GetTupleIntAttr(PAD_LIST);
  if (pad_list_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of pad_list must be 4, but got " << pad_list_.size();
    return FAILED;
  }

  // stride: 4 elements in NCHW order
  stride_ = GetTupleIntAttr(STRIDE);
  if (stride_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of stride must be 4, but got " << stride_.size();
    return FAILED;
  }

  // the N and C strides must stay 1; only H/W strides are meaningful here
  if (stride_[0] != 1 || stride_[1] != 1) {
    MS_LOG(ERROR) << name_ << ": The first two elements of stride must be 1, but got (" << stride_[0] << ", "
                  << stride_[1] << ")";
    return FAILED;
  }

  // dilation: 4 elements; values other than 1 are rejected later unless data parallel
  dilation_ = GetTupleIntAttr(DILATION);
  if (dilation_.size() != 4) {
    MS_LOG(ERROR) << name_ << ": The size of dilation must be 4, but got " << dilation_.size();
    return FAILED;
  }

  // group: no range check here; group != 1 is rejected later unless data parallel
  group_ = GetIntAttr(GROUP);

  MS_LOG(INFO) << name_ << ": The out channel is " << out_channel_ << ", kernel size is " << kernel_size_
               << ", mode is " << mode_ << ", pad mode is " << pad_mode_ << ", pad list is " << pad_list_
               << ", stride is " << stride_ << ", dilation is " << dilation_ << ", group is " << group_
               << ", format is " << format_;

  return SUCCESS;
}
136
GetAttrs()137 Status Conv2DInfo::GetAttrs() { return GetAttrsBase(); }
138
CheckHWStrategyBase(int64_t h_strategy,int64_t w_strategy) const139 Status Conv2DInfo::CheckHWStrategyBase(int64_t h_strategy, int64_t w_strategy) const {
140 if (outputs_shape_[0][2] % h_strategy != 0) {
141 MS_LOG(ERROR) << name_
142 << ": Do not support to split h dimension when out_shape of h dimension is not divisible by strategy "
143 "of h dimension";
144 return FAILED;
145 }
146
147 if (outputs_shape_[0][3] % w_strategy != 0) {
148 MS_LOG(ERROR) << name_
149 << ": Do not support to split w dimension when out_shape of w dimension is not divisible by strategy "
150 "of w dimension";
151 return FAILED;
152 }
153
154 return SUCCESS;
155 }
156
// Validate an (H, W) shard strategy under the 'same' pad mode.
// H may only be split when kernel <= stride (no cross-slice overlap); W may be
// split with kernel > stride only under extra divisibility/size constraints,
// since the overlap then has to be exchanged between neighbor ranks.
Status Conv2DInfo::CheckHWStrategySameMode(int64_t h_strategy, int64_t w_strategy) {
  // per-rank slice extents in H and W (divisibility was checked by CheckStrategyValue)
  int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
  int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;

  // H dimension: overlapping windows across the H split are not supported at all
  if (kernel_size_[0] > stride_[2] && h_strategy > 1) {
    MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split H when kernel_size > stride";
    return FAILED;
  }

  // even without overlap, each H slice must align with the stride
  if (h_strategy > 1 && (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0)) {
    MS_LOG(ERROR) << name_
                  << ": The 'same' mode do not support to split H when kernel_size <= stride but slice shape "
                     "is not divisible by stride ";
    return FAILED;
  }

  // W dimension: no-overlap case needs stride alignment of the slice
  if (w_strategy > 1 && (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0)) {
    MS_LOG(ERROR) << name_
                  << ": The 'same' mode do not support to split W when kernel_size <= stride but slice shape "
                     "is not divisible by stride ";
    return FAILED;
  }

  // W with overlap (kernel > stride): allowed only when the exchange is well-formed
  if (w_strategy > 1 && (kernel_size_[1] > stride_[3])) {
    // the full W extent must align with the stride
    if (inputs_shape_[0][3] % stride_[3] != 0) {
      MS_LOG(ERROR) << name_
                    << ": The 'same' mode do not support to split W when kernel_size > stride but w shape is not "
                        "divisible by stride";
      return FAILED;
    }

    // each slice must be wide enough to cover the halo it must provide
    if (w_slice_shape <= ((kernel_size_[1] - stride_[3] + 1) / 2)) {
      MS_LOG(ERROR) << name_
                    << ": The 'same' mode do not support to split W when kernel_size > stride but w slice shape is "
                        "smaller than or equal to (k - s + 1) / 2";
      return FAILED;
    }

    // k - s == 1 is excluded; presumably the asymmetric 1-element halo cannot
    // be exchanged correctly — TODO confirm against the exchange implementation
    if (kernel_size_[1] - stride_[3] == 1) {
      MS_LOG(ERROR) << name_ << ": The 'same' mode do not support to split W when kernel_size > stride but k - s == 1";
      return FAILED;
    }
  }

  return SUCCESS;
}
205
CheckHWStrategyValidMode(int64_t h_strategy,int64_t w_strategy)206 Status Conv2DInfo::CheckHWStrategyValidMode(int64_t h_strategy, int64_t w_strategy) {
207 int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
208 int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
209
210 if ((kernel_size_[0] > stride_[2] && h_strategy > 1) || (kernel_size_[1] > stride_[3] && w_strategy > 1)) {
211 MS_LOG(ERROR) << name_ << ": The 'valid' mode do not support to split H or W when kernel_size > stride";
212 return FAILED;
213 }
214
215 if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
216 MS_LOG(ERROR) << name_
217 << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
218 "not divisible by stride ";
219 return FAILED;
220 }
221
222 if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
223 MS_LOG(ERROR) << name_
224 << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
225 "not divisible by stride ";
226 return FAILED;
227 }
228
229 return SUCCESS;
230 }
231
CheckHWStrategy(int64_t h_strategy,int64_t w_strategy)232 Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
233 if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
234 return FAILED;
235 }
236
237 if (pad_mode_ == 0) { // 'pad' mode
238 MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
239 return FAILED;
240 }
241
242 if (pad_mode_ == 1) { // 'same' mode
243 return CheckHWStrategySameMode(h_strategy, w_strategy);
244 }
245
246 if (pad_mode_ == 2) { // 'valid' mode
247 return CheckHWStrategyValidMode(h_strategy, w_strategy);
248 }
249
250 return SUCCESS;
251 }
252
// Common strategy validation for Conv2D-family operators.
// Checks the strategy shape ((n,i,h,w), (o,i,1,1)), forbids splitting the
// kernel dims, records whether out_channel is sharded, and restricts
// dilation/group to 1 unless the strategy is pure data parallelism.
Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  // exactly two strategies: one for the input, one for the weight
  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.size() != 2) {
    MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
    return FAILED;
  }

  Dimensions input_strategy = stra[0];
  Dimensions weight_strategy = stra[1];
  if (input_strategy.size() != 4 || weight_strategy.size() != 4) {
    MS_LOG(ERROR) << name_
                  << ": The size of input strategy or weight strategy must be 4, but the size of input strategy is "
                  << input_strategy.size() << ", the size of weight strategy is " << weight_strategy.size();
    return FAILED;
  }

  // the kernel H/W dims of the weight can never be split
  if (weight_strategy[2] != 1 || weight_strategy[3] != 1) {
    MS_LOG(ERROR) << name_ << ": The kernel size can not be split, but the strategy for kernel size is ("
                  << weight_strategy[2] << ", " << weight_strategy[3] << ")";
    return FAILED;
  }

  // record out-channel sharding; new_out_channel_ is the per-rank out channel
  if (weight_strategy[0] > 1) {
    out_channel_shard_ = true;
    new_out_channel_ = out_channel_ / weight_strategy[0];
  } else {
    out_channel_shard_ = false;
    new_out_channel_ = out_channel_;
  }

  // pure data parallelism == nothing is split except (possibly) the input batch
  int64_t input_except_n_shards =
    std::accumulate(input_strategy.begin() + 1, input_strategy.end(), 1, std::multiplies<int64_t>());
  int64_t weight_shards =
    std::accumulate(weight_strategy.begin() + 1, weight_strategy.end(), 1, std::multiplies<int64_t>());

  bool is_data_parallel = (input_except_n_shards * weight_shards == 1);
  if (!is_data_parallel) {
    // dilation != 1 is only supported under data parallelism
    if (std::any_of(dilation_.begin(), dilation_.end(), [](int64_t value) { return value != 1; })) {
      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the value of dilation must be 1, but got " << dilation_;
      return FAILED;
    }

    // grouped convolution is only supported under data parallelism
    if (group_ != 1) {
      MS_LOG(ERROR) << name_ << ": If it is not data parallel, the group must be 1, but got " << group_;
      return FAILED;
    }
  }
  return SUCCESS;
}
308
CheckStrategy(const StrategyPtr & strategy)309 Status Conv2DInfo::CheckStrategy(const StrategyPtr &strategy) {
310 need_exchange_overlap_ = false;
311 if (CheckStrategyBase(strategy) != SUCCESS) {
312 return FAILED;
313 }
314
315 std::vector<Dimensions> stra = strategy->GetInputDim();
316 Dimensions input_strategy = stra[0];
317 Dimensions weight_strategy = stra[1];
318 if (input_strategy[1] != weight_strategy[1]) {
319 MS_LOG(ERROR) << name_ << ": The shard num of c-in for input strategy is " << input_strategy[1]
320 << ", but the shard num of c-in for weight strategy is " << weight_strategy[1];
321 return FAILED;
322 }
323
324 if (input_strategy[2] != 1 || input_strategy[3] != 1) {
325 if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
326 return FAILED;
327 }
328 }
329
330 // kernel size larger than stride and the w dimension is split, need to exchange overlap
331 if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
332 need_exchange_overlap_ = true;
333 }
334
335 return SUCCESS;
336 }
337
InferDevMatrixShape()338 Status Conv2DInfo::InferDevMatrixShape() {
339 // the strategy is ((n, i, h, w), (o, i, 1, 1))
340 // the dev matrix is (n, i, h, w, o)
341 MS_EXCEPTION_IF_NULL(strategy_);
342 std::vector<Dimensions> stra = strategy_->GetInputDim();
343 if (stra.size() != 2) {
344 MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
345 return FAILED;
346 }
347
348 dev_matrix_shape_ = stra[0];
349 dev_matrix_shape_.push_back(stra[1][0]);
350 w_dimension_shard_num_ = stra[0][3];
351 input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
352 return SUCCESS;
353 }
354
// Locate the current rank inside the W dimension of the device matrix and
// record its left/right neighbors (bias = position in the W group, id = the
// neighbor's global rank id; -1 means no such neighbor).
Status Conv2DInfo::InferRankBias() {
  // the Conv2D operator:
  // the origin dev_matrix is [n, i, h, w, o]
  // if repeated calculation and repeated num in the left of dev matrix, the dev_matrix is [repeated_num, n, i, h, w, o]
  // if repeated calculation and repeated num in the right of dev matrix, the dev_matrix is [n, i, h, w, o,
  // repeated_num]
  //
  // the Conv2DBackpropInput's origin dev_matrix is [n, o, h, w, i], w dimension's relative position is the same as
  // Conv2D, the rank_bias_ is the position of the current rank in the w dimension of the dev_matrix(have not split h
  // dimension)
  if (!need_exchange_overlap_) {
    MS_LOG(INFO) << name_ << ": No need to infer rank bias";
    return SUCCESS;
  }

  // W sits at index 3 of [n, i, h, w, o]; a repeated-num prepended on the left
  // shifts every index by one
  uint64_t w_index_in_dev_matrix = 3;
  if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    w_index_in_dev_matrix += 1;
  }

  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
    return FAILED;
  }

  // a single device along W means nothing to exchange
  if (group_devices.size() <= 1) {
    MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
                 << ", no need to infer rank bias";
    return SUCCESS;
  }

  // sanity: the W group size must match the shard num recorded from the strategy
  if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
    MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
                  << ", but the shard num of w dimension is " << w_dimension_shard_num_;
    return FAILED;
  }

  std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
  if (it == group_devices.end()) {
    MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
                  << rank << ", the device list is " << group_devices;
    return FAILED;
  }

  rank_bias_ = std::distance(group_devices.begin(), it);
  if (it == group_devices.begin()) {
    // first rank in the W group: no left neighbor
    left_rank_bias_ = -1;
    right_rank_bias_ = rank_bias_ + 1;

    left_rank_id_ = -1;
    right_rank_id_ = *(it + 1);
  } else if (it == group_devices.end() - 1) {
    // last rank in the W group: no right neighbor
    left_rank_bias_ = rank_bias_ - 1;
    right_rank_bias_ = -1;

    left_rank_id_ = *(it - 1);
    right_rank_id_ = -1;
  } else {
    // middle rank: neighbors on both sides
    left_rank_bias_ = rank_bias_ - 1;
    right_rank_bias_ = rank_bias_ + 1;

    left_rank_id_ = *(it - 1);
    right_rank_id_ = *(it + 1);
  }
  MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
               << ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
               << ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
               << ", the right rank id is " << right_rank_id_;
  return SUCCESS;
}
428
ComputeOverlapLeftSizeByRankBias(int64_t rank_bias)429 int64_t Conv2DInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
430 int64_t left_pad = pad_list_[2];
431 int64_t w_dimension_input_shape = inputs_shape_[0][3];
432 int64_t w_dimension_output_shape = outputs_shape_[0][3];
433 int64_t w_stride = stride_[3];
434
435 return left_pad +
436 (w_dimension_input_shape - w_dimension_output_shape * w_stride) * rank_bias / w_dimension_shard_num_;
437 }
438
ComputeOverlapRightSizeByRankBias(int64_t rank_bias)439 int64_t Conv2DInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
440 int64_t left_pad = pad_list_[2];
441 int64_t w_dimension_input_shape = inputs_shape_[0][3];
442 int64_t w_dimension_output_shape = outputs_shape_[0][3];
443 int64_t w_kernel_size = kernel_size_[1];
444 int64_t w_stride = stride_[3];
445
446 return (rank_bias + 1) * (w_dimension_output_shape * w_stride - w_dimension_input_shape) / w_dimension_shard_num_ +
447 w_kernel_size - w_stride - left_pad;
448 }
449
// Compute the halo (overlap) sizes for this rank and for its left/right
// neighbors in the W dimension. A missing neighbor gets zero sizes.
void Conv2DInfo::InferOverlapSize() {
  if (!need_exchange_overlap_) {
    MS_LOG(INFO) << name_ << ": No need to infer overlap size";
    return;
  }

  // halo this rank needs on each side
  overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(rank_bias_);
  overlap_right_size_ = ComputeOverlapRightSizeByRankBias(rank_bias_);

  // NOTE(review): branch order matters — when w_dimension_shard_num_ == 1 the
  // first and last conditions coincide and the first branch would win; that
  // case is presumably unreachable here since halo exchange implies a W split.
  if (rank_bias_ == 0) {  // it has not left rank
    left_rank_overlap_left_size_ = 0;
    left_rank_overlap_right_size_ = 0;
    right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
    right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // it has not right rank
    left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
    left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
    right_rank_overlap_left_size_ = 0;
    right_rank_overlap_right_size_ = 0;
  } else {  // it has left rank and right rank
    left_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(left_rank_bias_);
    left_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(left_rank_bias_);
    right_rank_overlap_left_size_ = ComputeOverlapLeftSizeByRankBias(right_rank_bias_);
    right_rank_overlap_right_size_ = ComputeOverlapRightSizeByRankBias(right_rank_bias_);
  }

  MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
               << ", the right overlap size of current rank is " << overlap_right_size_
               << ", the left overlap size of left rank is " << left_rank_overlap_left_size_
               << ", the right overlap size of left rank is " << left_rank_overlap_right_size_
               << ", the left overlap size of right rank is " << right_rank_overlap_left_size_
               << ", the right overlap size of right rank is " << right_rank_overlap_right_size_;
}
483
InferTensorMap()484 Status Conv2DInfo::InferTensorMap() {
485 // input_strategy: ((n, i, h, w), (o, i, 1, 1))
486 // output_strategy: ((n, o, h, w),)
487 // dev_matrix: (n, i, h, w, o)
488 TensorMap input_tensor_map = {4, 3, 2, 1};
489 TensorMap weight_tensor_map = {0, 3, -1, -1};
490 TensorMap output_tensor_map = {4, 0, 2, 1};
491
492 (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
493 (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
494 (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
495 return SUCCESS;
496 }
497
498 // Conv2d: dev_matrix is (n, i, h, w, o), if in channel is split, it need to insert all reduce
499 // Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i), if out channel is split, it need to insert all reduce
// Conv2d: dev_matrix is (n, i, h, w, o), if in channel is split, it need to insert all reduce
// Conv2DBackpropInputInfo: dev_matrix is (n, o, h, w, i), if out channel is split, it need to insert all reduce
Status Conv2DInfo::InferForwardCommunication() {
  forward_op_.clear();
  // the reduced (channel) dimension of the dev matrix
  size_t relevant_dim_index = IN_CHANNEL_INDEX;
  if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    // if repeated calculation and repeated num in the left of dev matrix, the index of relevant dimension should add 1
    relevant_dim_index += 1;
  }

  // channel not split -> partial sums never occur, no all-reduce needed
  if (dev_matrix_shape_[relevant_dim_index] == MIN_SLICE_NUM) {
    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
    return SUCCESS;
  }

  // build the communication group along the reduced dimension
  std::vector<Group> group_list;
  if (CreateGroupByDim(relevant_dim_index, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed";
    return FAILED;
  }

  if (group_list.empty()) {
    MS_LOG(INFO) << name_ << ": Forward all reduce is not required";
    return SUCCESS;
  }

  // sum the partial results of the split channel across the group
  Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
  forward_op_.push_back(op);
  MS_LOG(INFO) << name_ << ": The group name of forward all reduce is " << group_list[0].name();

  return SUCCESS;
}
530
InferNewPadList()531 void Conv2DInfo::InferNewPadList() {
532 new_pad_list_ = pad_list_;
533 if (rank_bias_ == 0) { // the first rank
534 new_pad_list_[3] = 0; // no need the right pad
535 } else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
536 new_pad_list_[2] = 0; // no need the left pad
537 } else { // the middle rank
538 new_pad_list_[2] = 0; // no need the left pad
539 new_pad_list_[3] = 0; // no need the right pad
540 }
541 MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_;
542 }
543
InferSendRecvFlag()544 void Conv2DInfo::InferSendRecvFlag() {
545 if (rank_bias_ == 0) { // the first rank
546 left_need_send_ = false;
547 left_need_recv_ = false;
548 right_need_send_ = (right_rank_overlap_left_size_ > 0);
549 right_need_recv_ = (overlap_right_size_ > 0); // no need the right pad
550 } else if (rank_bias_ == w_dimension_shard_num_ - 1) { // the last rank
551 left_need_send_ = (left_rank_overlap_right_size_ > 0);
552 left_need_recv_ = (overlap_left_size_ > 0);
553 right_need_send_ = false;
554 right_need_recv_ = false;
555 } else { // the middle rank
556 left_need_send_ = (left_rank_overlap_right_size_ > 0);
557 left_need_recv_ = (overlap_left_size_ > 0);
558 right_need_send_ = (right_rank_overlap_left_size_ > 0);
559 right_need_recv_ = (overlap_right_size_ > 0);
560 }
561 MS_LOG(INFO) << name_ << ": The left need send is " << left_need_send_ << ", the left need recv is "
562 << left_need_recv_ << ", the right need send is " << right_need_send_ << ", the right need recv is "
563 << right_need_recv_;
564
565 if (left_need_send_) {
566 if (left_rank_overlap_right_size_ >= input_slice_shape_[3]) {
567 MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << left_rank_overlap_right_size_
568 << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
569 }
570 send_rank_ids_.push_back(left_rank_id_);
571 }
572
573 if (right_need_send_) {
574 if (right_rank_overlap_left_size_ >= input_slice_shape_[3]) {
575 MS_LOG(EXCEPTION) << name_ << ": Do not support left overlap size(" << right_rank_overlap_left_size_
576 << ") larger than or equal to slice shape in w dimension(" << input_slice_shape_[3] << ")";
577 }
578 send_rank_ids_.push_back(right_rank_id_);
579 }
580
581 if (left_need_recv_) {
582 recv_rank_ids_.push_back(left_rank_id_);
583 }
584
585 if (right_need_recv_) {
586 recv_rank_ids_.push_back(right_rank_id_);
587 }
588
589 MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
590 }
591
InferOverlapShapes()592 void Conv2DInfo::InferOverlapShapes() {
593 if (left_need_recv_) {
594 Shape left_recv_shape = input_slice_shape_;
595 left_recv_shape[3] = overlap_left_size_;
596 recv_shapes_.push_back(left_recv_shape);
597 }
598
599 if (right_need_recv_) {
600 Shape right_recv_shape = input_slice_shape_;
601 right_recv_shape[3] = overlap_right_size_;
602 recv_shapes_.push_back(right_recv_shape);
603 }
604
605 if (left_need_send_) {
606 Shape left_send_shape = input_slice_shape_;
607 left_send_shape[3] = left_rank_overlap_right_size_;
608 send_shapes_.push_back(left_send_shape);
609 }
610
611 if (right_need_send_) {
612 Shape right_send_shape = input_slice_shape_;
613 right_send_shape[3] = right_rank_overlap_left_size_;
614 send_shapes_.push_back(right_send_shape);
615 }
616 MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
617 }
618
InferStridedSliceAttrs()619 void Conv2DInfo::InferStridedSliceAttrs() {
620 if (left_need_send_) {
621 left_strided_slice_begin_ = {0, 0, 0, 0};
622 left_strided_slice_end_ = input_slice_shape_;
623 left_strided_slice_end_[3] = left_rank_overlap_right_size_;
624 left_strided_slice_strides_ = {1, 1, 1, 1};
625 MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
626 << left_strided_slice_end_;
627 }
628
629 if (right_need_send_) {
630 right_strided_slice_begin_ = {0, 0, 0, 0};
631 right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
632 right_strided_slice_end_ = input_slice_shape_;
633 right_strided_slice_strides_ = {1, 1, 1, 1};
634 MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
635 << right_strided_slice_end_;
636 }
637 }
638
InferNewOperatorAttrs()639 void Conv2DInfo::InferNewOperatorAttrs() {
640 InferNewPadList();
641
642 InferSendRecvFlag();
643
644 InferOverlapShapes();
645
646 InferStridedSliceAttrs();
647 }
648
// Build the attribute list for the NeighborExchange operator from the
// previously inferred send/recv rank ids and shapes. The receive dtype is
// taken from the cnode's output tensor type.
OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
  auto type = cnode->Type();
  MS_EXCEPTION_IF_NULL(type);
  auto tensor_type = type->cast<mindspore::TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto dtype = tensor_type->element();
  MS_EXCEPTION_IF_NULL(dtype);
  // NOTE: the constant names carry a historical typo ("RNAK"); they are the
  // project-wide identifiers, not a bug in this function.
  Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
  Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};
  Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
  Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
  Attr recv_type = {RECV_TYPE, dtype};
  // the attr order (send ranks, recv ranks, recv shapes, send shapes, type)
  // is positional — presumably it must match the primitive's declaration;
  // verify against the NeighborExchange op definition before reordering
  OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
  return attrs;
}
664
// Build the attribute list for the replacement Conv2D / Conv2DTranspose node.
// pad_mode is forced to "pad" and the pad values come from new_pad_list_,
// since the halo exchange already accounts for the neighbor-side padding.
OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
  Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
  Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
  Attr mode = {MODE, MakeValue(mode_)};
  Attr pad_mode = {PAD_MODE, MakeValue("pad")};
  Attr pad = {PAD, MakeValue(new_pad_list_)};
  Attr stride = {STRIDE, MakeValue(stride_)};
  Attr dilation = {DILATION, MakeValue(dilation_)};
  Attr group = {GROUP, MakeValue(group_)};
  Attr data_format = {DATA_FORMAT, MakeValue(format_)};

  // the attr order is positional and differs between the two primitives
  OperatorAttrs attrs;
  if (name_.find(CONV2D_INFO) != std::string::npos) {
    attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
  } else {  // Conv2DTranspose
    // `pad` appears twice — presumably it fills both the 'pad' and 'pad_list'
    // slots of the transpose primitive; verify against the op's signature
    attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};
  }

  return attrs;
}
685
ReplaceNodeName() const686 std::string Conv2DInfo::ReplaceNodeName() const {
687 if (name_.find(CONV2D_INFO) != std::string::npos) {
688 return CONV2D;
689 }
690
691 if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) {
692 return CONV2D_BACK_PROP_INPUT;
693 }
694
695 if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) {
696 return CONV2D_TRANSPOSE;
697 }
698
699 MS_LOG(EXCEPTION) << "Invalid name: " << name_;
700 }
701
// Create the replacement convolution node, feeding it the halo-extended input
// and the remaining operands of the original cnode.
AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
  auto conv_attrs = CreateConv2DAttrs();
  auto op_name = ReplaceNodeName();

  const bool is_conv2d = (name_.find(CONV2D_INFO) != std::string::npos);
  const size_t required_size = is_conv2d ? 3 : 4;
  if (cnode->size() < required_size) {
    MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
  }

  if (is_conv2d) {
    // Conv2D: (input, weight)
    return gen_g_.PushBack({gen_g_.NewOpInst(op_name, conv_attrs), new_input, cnode->input(2)});
  }
  // Conv2DTranspose / Conv2DBackpropInput: (input, weight, output_shape)
  return gen_g_.PushBack({gen_g_.NewOpInst(op_name, conv_attrs), new_input, cnode->input(2), cnode->input(3)});
}
720
// Build the replacement subgraph for halo exchange in the W dimension:
//   1. StridedSlice the border strip(s) to send;
//   2. NeighborExchange them with the left/right ranks;
//   3. Concat the received halo(s) onto the local slice along the last axis;
//   4. feed the widened input into the re-created convolution node.
// The result is stored in replace_graph_ together with the external inputs.
void Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);

  if (gen_g_.Init(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
  }

  // the graph below assumes at least one send and at least one recv side
  if (!left_need_send_ && !right_need_send_) {
    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
  }

  if (!left_need_recv_ && !right_need_recv_) {
    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
  }

  // (node, input-index) pairs that must be wired to the original cnode's inputs
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
  // tuple of strips to send; order: left strip first (if any), then right
  std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  if (left_need_send_) {
    auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
    auto slice_left_end = CreateTuple(left_strided_slice_end_);
    auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
    auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
                                       slice_left_end, slice_left_strided});
    make_tuple_a_inputs.push_back(slice_left);
    input_nodes.push_back(std::make_pair(slice_left, 1));
  }
  if (right_need_send_) {
    auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
    auto slice_right_end = CreateTuple(right_strided_slice_end_);
    auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
    auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin,
                                        slice_right_end, slice_right_strided});
    make_tuple_a_inputs.push_back(slice_right);
    input_nodes.push_back(std::make_pair(slice_right, 1));
  }

  auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
  auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
  auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});

  AnfNodePtr conv2d;
  // halos are stitched along the last (W) axis
  Attr concat_axis = {AXIS, MakeValue(-1)};
  OperatorAttrs concat_attrs = {concat_axis};

  if (left_need_recv_) {
    // received tuple item 0 is the left halo; concat it before the local slice
    std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
                                                      CreatInt64Imm(0)};
    auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
                                                   cnode->input(1)};
    auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
    auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});

    if (right_need_recv_) {
      // item 1 is the right halo; append it after (left halo + local slice)
      std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
                                                        CreatInt64Imm(1)};
      auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
      std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
      auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
      auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
      conv2d = GenerateConv2DNode(concat_r, cnode);
    } else {
      conv2d = GenerateConv2DNode(concat_l, cnode);
    }
  } else {  // left no need recv, and right need recv
    // only the right halo was received, so it is tuple item 0
    std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
                                                        CreatInt64Imm(0)};
    auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
    std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
                                                     tuple_getitem_r_1};
    auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
    input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));

    auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
    conv2d = GenerateConv2DNode(concat_r_1, cnode);
  }

  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, conv2d));
}
802
replace_graph(const CNodePtr & cnode)803 ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
804 if (!need_exchange_overlap_) {
805 if (!out_channel_shard_) {
806 return nullptr;
807 }
808 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
809 prim->set_attr(OUT_CHANNEL, MakeValue(new_out_channel_));
810 return nullptr;
811 }
812
813 if (InferRankBias() != SUCCESS) {
814 return nullptr;
815 }
816
817 InferOverlapSize();
818
819 InferNewOperatorAttrs();
820
821 ComputeReplaceGraph(cnode);
822 return replace_graph_;
823 }
824
ReComputeBatchSplitFlagList()825 void Conv2DInfo::ReComputeBatchSplitFlagList() {
826 split_flag_list_[0] = true;
827 split_flag_list_[1] = false;
828 }
829
SetCostUnderStrategy(const StrategyPtr & strategy)830 Status Conv2DInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
831
GenerateOpStrategies(int64_t stage_id)832 std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
833 Strategys strategy = {{stage_device_size_, 1, 1, 1}, {1, 1, 1, 1}};
834 StrategyPtr sp = std::make_shared<Strategy>(stage_id, strategy);
835 std::vector<StrategyPtr> sp_vector;
836 sp_vector.push_back(sp);
837 return sp_vector;
838 }
839
Init(const StrategyPtr & strategy)840 Status Conv2DInfo::Init(const StrategyPtr &strategy) {
841 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
842 MS_LOG(ERROR) << name_ << ": Init failed.";
843 return FAILED;
844 }
845 MS_LOG(INFO) << name_ << ": Init success.";
846 return SUCCESS;
847 }
848
InitForCostModel(const StrategyPtr & strategy)849 Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
850 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
851 MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
852 return FAILED;
853 }
854
855 MS_LOG(INFO) << name_ << ": Init for cost model success.";
856 return SUCCESS;
857 }
858
GetOutShape()859 Status Conv2DBackpropInputInfo::GetOutShape() {
860 if (input_value_.size() != 3) {
861 MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();
862 return FAILED;
863 }
864
865 if (input_value_[2] == nullptr) {
866 MS_LOG(ERROR) << name_ << ": The input_value_[2] is nullptr";
867 return FAILED;
868 }
869
870 std::vector<ValuePtr> elements;
871 auto value_tuple = input_value_[2]->cast<ValueTuplePtr>();
872 if (value_tuple == nullptr) {
873 MS_LOG(ERROR) << name_ << ": Input_value_[2] must be ValueTuplePtr.";
874 return FAILED;
875 }
876 elements = value_tuple->value();
877 if (elements.size() != 4) {
878 MS_LOG(ERROR) << name_ << ": Elements size must be 4, but got " << elements.size();
879 return FAILED;
880 }
881
882 for (auto &element : elements) {
883 MS_EXCEPTION_IF_NULL(element);
884 if (element->isa<Int64Imm>()) {
885 int64_t ele_value = element->cast<Int64ImmPtr>()->value();
886 out_shape_.push_back(ele_value);
887 } else {
888 MS_LOG(ERROR) << name_ << ": The value of shape must be int";
889 return FAILED;
890 }
891 }
892
893 return SUCCESS;
894 }
895
GetAttrs()896 Status Conv2DBackpropInputInfo::GetAttrs() {
897 if (GetAttrsBase() != SUCCESS) {
898 return FAILED;
899 }
900
901 return GetOutShape();
902 }
903
CheckStrategy(const StrategyPtr & strategy)904 Status Conv2DBackpropInputInfo::CheckStrategy(const StrategyPtr &strategy) {
905 need_exchange_overlap_ = false;
906 if (CheckStrategyBase(strategy) != SUCCESS) {
907 return FAILED;
908 }
909
910 std::vector<Dimensions> stra = strategy->GetInputDim();
911 Dimensions input_strategy = stra[0];
912 Dimensions weight_strategy = stra[1];
913 if (input_strategy[1] != weight_strategy[0]) {
914 MS_LOG(ERROR) << name_ << ": The shard num of c-out for input strategy is " << input_strategy[1]
915 << ", but the shard num of c-out for weight strategy is " << weight_strategy[0];
916 return FAILED;
917 }
918
919 if (input_strategy[2] != 1 || input_strategy[3] != 1) {
920 if (CheckHWStrategy(input_strategy[2], input_strategy[3]) != SUCCESS) {
921 return FAILED;
922 }
923 }
924
925 // kernel size larger than stride and the w dimension is split, need to exchange overlap
926 if ((kernel_size_[1] > stride_[3]) && (input_strategy[3] > 1)) {
927 need_exchange_overlap_ = true;
928 }
929 return SUCCESS;
930 }
931
CheckHWStrategy(int64_t h_strategy,int64_t w_strategy)932 Status Conv2DBackpropInputInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
933 if (CheckHWStrategyBase(h_strategy, w_strategy) != SUCCESS) {
934 return FAILED;
935 }
936
937 if (pad_mode_ != 1) { // only support same mode
938 MS_LOG(ERROR) << name_ << ": Do not support the pad mode " << pad_mode_ << " when split H or W dimension";
939 return FAILED;
940 }
941
942 if (h_strategy > 1) {
943 MS_LOG(ERROR) << name_ << ": Do not support to split h dimension";
944 return FAILED;
945 }
946
947 if (w_strategy > 1 && inputs_shape_[0][3] * stride_[3] != outputs_shape_[0][3]) {
948 MS_LOG(ERROR) << name_ << ": Do not support to split w dimension when in_shape * stride != out_shape";
949 return FAILED;
950 }
951
952 return SUCCESS;
953 }
954
InferDevMatrixShape()955 Status Conv2DBackpropInputInfo::InferDevMatrixShape() {
956 // the strategy is ((n, o, h, w), (o, i, 1, 1))
957 // the dev matrix is (n, o, h, w, i)
958 MS_EXCEPTION_IF_NULL(strategy_);
959 std::vector<Dimensions> stra = strategy_->GetInputDim();
960 if (stra.size() != 2) {
961 MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << stra.size();
962 return FAILED;
963 }
964
965 dev_matrix_shape_ = stra[0];
966 dev_matrix_shape_.push_back(stra[1][1]);
967
968 Shape out_strategy = stra[0];
969 out_strategy[1] = stra[1][1];
970
971 out_slice_shape_ = out_shape_;
972 if (out_shape_.size() != out_strategy.size()) {
973 MS_LOG(ERROR) << name_ << ": The size of out shape is " << out_shape_.size()
974 << ", but the size of output strategy is " << out_strategy.size();
975 return FAILED;
976 }
977
978 for (size_t i = 0; i < out_slice_shape_.size(); ++i) {
979 if (out_slice_shape_[i] % out_strategy[i] != 0) {
980 MS_LOG(ERROR) << name_ << ": The output can not be split by strategy. The shape of output is " << out_slice_shape_
981 << ", but the strategy of output is " << out_strategy;
982 return FAILED;
983 }
984 out_slice_shape_[i] = out_slice_shape_[i] / out_strategy[i];
985 }
986
987 w_dimension_shard_num_ = stra[0][3];
988 input_slice_shape_ = GetSliceShape(inputs_shape_[0], stra[0]);
989 MS_LOG(INFO) << name_ << ": The output slice shape is " << out_slice_shape_;
990 return SUCCESS;
991 }
992
InferTensorMap()993 Status Conv2DBackpropInputInfo::InferTensorMap() {
994 // input_strategy: ((n, o, h, w), (o, i, 1, 1))
995 // output_strategy: ((n, i, h, w),)
996 // dev_matrix: (n, o, h, w, i)
997 TensorMap input_tensor_map = {4, 3, 2, 1};
998 TensorMap weight_tensor_map = {3, 0, -1, -1};
999 TensorMap output_tensor_map = {4, 0, 2, 1};
1000
1001 (void)inputs_tensor_map_.emplace_back(std::move(input_tensor_map));
1002 (void)inputs_tensor_map_.emplace_back(std::move(weight_tensor_map));
1003 (void)outputs_tensor_map_.emplace_back(std::move(output_tensor_map));
1004 return SUCCESS;
1005 }
1006
InferMirrorOps()1007 Status Conv2DBackpropInputInfo::InferMirrorOps() {
1008 mirror_ops_.clear();
1009 if (inputs_shape_.empty()) {
1010 MS_LOG(INFO) << name_ << ": The inputs size is empty";
1011 return SUCCESS;
1012 }
1013
1014 if (inputs_tensor_map_.size() != inputs_shape_.size()) {
1015 MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
1016 return FAILED;
1017 }
1018
1019 bool group_is_empty = true;
1020 for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
1021 std::vector<Group> group;
1022 if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
1023 MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
1024 mirror_ops_.clear();
1025 return FAILED;
1026 }
1027
1028 OperatorVector mirror_op;
1029 if (group.empty()) {
1030 MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
1031 mirror_ops_.push_back(mirror_op);
1032 continue;
1033 }
1034
1035 group_is_empty = false;
1036 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
1037 mirror_ops_.push_back(mirror_op);
1038 }
1039
1040 if (group_is_empty) {
1041 mirror_ops_.clear();
1042 MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
1043 return SUCCESS;
1044 }
1045
1046 OperatorVector tmp_mirror_op; // tmp mirror op for 'out_shape'
1047 mirror_ops_.push_back(tmp_mirror_op);
1048 return SUCCESS;
1049 }
1050
UpdateOutShape()1051 void Conv2DBackpropInputInfo::UpdateOutShape() {
1052 auto cnode = cnode_;
1053 MS_EXCEPTION_IF_NULL(cnode);
1054 if (cnode->size() != 4) {
1055 MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
1056 }
1057
1058 if (!IsValueNode<ValueTuple>(cnode->input(3))) {
1059 MS_LOG(EXCEPTION) << name_ << ": The cnode's input[3] is not value node";
1060 }
1061
1062 auto func_graph = cnode->func_graph();
1063 MS_EXCEPTION_IF_NULL(func_graph);
1064 auto manager = func_graph->manager();
1065 MS_EXCEPTION_IF_NULL(manager);
1066
1067 ValuePtr out_shape = MakeValue(out_slice_shape_);
1068 AnfNodePtr val = NewValueNode(out_shape);
1069 (void)manager->Replace(cnode->input(3), val);
1070 MS_LOG(INFO) << name_ << ": Update the output shape " << out_slice_shape_;
1071 }
1072
// Compute the overlap size on the left side of the W-dimension slice owned by
// rank `rank_bias` (its position among the W shards).
// Notation used in the formulas below: o = output W, w = input W,
// k = kernel W, s = stride W, x = left pad, n = W shard num, r = rank_bias.
int64_t Conv2DBackpropInputInfo::ComputeOverlapLeftSizeByRankBias(int64_t rank_bias) {
  // 1. the first rank: 0
  // 2. the last rank:
  //    size of origin data required by current rank: a = ceil((o/n + k - o + w*s - s - x)/s)
  //    data size of the current rank: b = w/n
  //    return a - b = ceil((o/n + k - o + w*s - s - x)/s) - w/n
  // 3. the middle rank:
  //    r*w/n - ceil((r*o/n - k + x + 1)/s)
  if (rank_bias == 0) {  // the first rank
    return 0;  // nothing lies to the left of the first shard
  }

  int64_t w_output_shape = outputs_shape_[0][3];
  int64_t w_input_shape = inputs_shape_[0][3];
  int64_t w_kernel_size = kernel_size_[1];
  int64_t w_stride = stride_[3];
  int64_t left_pad = pad_list_[2];
  if (rank_bias == w_dimension_shard_num_ - 1) {  // the last rank
    // ceil((o/n + k - o + w*s - s - x)/s) - w/n, computed in double to get a
    // true ceiling before converting back to int64.
    return DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size -
                                               w_output_shape + w_input_shape * w_stride - w_stride - left_pad) /
                                  LongToDouble(w_stride))) -
           w_input_shape / w_dimension_shard_num_;
  }

  // the middle rank: r*w/n - ceil((r*o/n - k + x + 1)/s)
  return rank_bias * w_input_shape / w_dimension_shard_num_ -
         DoubleToLong(
           std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
                     LongToDouble(w_stride)));
}
1103
// Compute the overlap size on the right side of the W-dimension slice owned by
// rank `rank_bias`. Notation: o = output W, w = input W, s = stride W,
// x = left pad, n = W shard num, r = rank_bias.
int64_t Conv2DBackpropInputInfo::ComputeOverlapRightSizeByRankBias(int64_t rank_bias) {
  // 1. the first rank: ceil((o/n + x)/s) - w/n
  // 2. the last rank: 0
  // 3. the middle rank: ceil((r*o/n + o/n + x)/s) - r*w/n - w/n
  int64_t w_output_shape = outputs_shape_[0][3];
  int64_t w_input_shape = inputs_shape_[0][3];
  int64_t w_stride = stride_[3];
  int64_t left_pad = pad_list_[2];

  if (rank_bias == 0) {  // the first rank
    // ceil((o/n + x)/s) - w/n
    return DoubleToLong(
             std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride))) -
           w_input_shape / w_dimension_shard_num_;
  }

  if (rank_bias == w_dimension_shard_num_ - 1) {  // the last rank
    return 0;  // nothing lies to the right of the last shard
  }

  // the middle rank: ceil((r*o/n + o/n + x)/s) - (r + 1)*w/n
  return DoubleToLong(std::ceil(LongToDouble(rank_bias * w_output_shape / w_dimension_shard_num_ +
                                             w_output_shape / w_dimension_shard_num_ + left_pad) /
                                LongToDouble(w_stride))) -
         (rank_bias + 1) * w_input_shape / w_dimension_shard_num_;
}
1129
// Recompute the pad list (new_pad_list_) for the current rank's W-dimension
// slice so the local convolution produces exactly the sliced output.
// Notation: o = output W, w = input W, k = kernel W, s = stride W,
// x = left pad, n = W shard num, r = rank_bias_.
void Conv2DBackpropInputInfo::InferNewPadList() {
  // 1. compute the size of origin data required by current rank:
  //    1) the first rank: ceil((o/n + x) / s)
  //    2) the last rank: ceil((o/n + k - o + ws - s - x) / s)
  //    3) the middle rank: ceil((r*o/n + o/n + x) / s) - ceil((r*o/n - k + x + 1) / s)
  //
  // 2. compute the real left pad
  //    1) the first rank: k - x - 1
  //    2) the last rank:
  //       if (o/n + k - o + ws - s - x) is divisible by s, real_left_pad = s - 1.
  //       otherwise, real_left_pad = (o/n + k - o + ws - s - x) % s - 1
  //    3) the middle rank:
  //       if (r*on - k + x + 1) is divisible by s, real_left_pad = 0.
  //       otherwise, real_left_pad = s - (r*on - k + x + 1) % s
  int64_t w_output_shape = outputs_shape_[0][3];
  int64_t w_input_shape = inputs_shape_[0][3];
  int64_t w_kernel_size = kernel_size_[1];
  int64_t w_stride = stride_[3];
  int64_t left_pad = pad_list_[2];
  int64_t current_rank_required_size = 0;
  int64_t real_left_pad = 0;

  if (rank_bias_ == 0) {  // the first rank
    current_rank_required_size = DoubleToLong(
      std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + left_pad) / LongToDouble(w_stride)));

    real_left_pad = w_kernel_size - left_pad - 1;
  } else if (rank_bias_ == w_dimension_shard_num_ - 1) {  // the last rank
    current_rank_required_size =
      DoubleToLong(std::ceil(LongToDouble(w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape +
                                          w_input_shape * w_stride - w_stride - left_pad) /
                             LongToDouble(w_stride)));

    // tmp = o/n + k - o + w*s - s - x; its remainder mod s decides the pad.
    int64_t tmp = w_output_shape / w_dimension_shard_num_ + w_kernel_size - w_output_shape + w_input_shape * w_stride -
                  w_stride - left_pad;
    if (tmp % w_stride == 0) {
      real_left_pad = w_stride - 1;
    } else {
      real_left_pad = tmp % w_stride - 1;
    }
  } else {  // the middle rank
    current_rank_required_size =
      DoubleToLong(std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ +
                                          w_output_shape / w_dimension_shard_num_ + left_pad) /
                             LongToDouble(w_stride))) -
      DoubleToLong(
        std::ceil(LongToDouble(rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1) /
                  LongToDouble(w_stride)));

    // tmp = r*o/n - k + x + 1; its remainder mod s decides the pad.
    int64_t tmp = rank_bias_ * w_output_shape / w_dimension_shard_num_ - w_kernel_size + left_pad + 1;
    if (tmp % w_stride == 0) {
      real_left_pad = 0;
    } else {
      real_left_pad = w_stride - tmp % w_stride;
    }
  }

  // 3. compute the pad_add: (current_rank_required_size - 1) * s + k - o/n
  int64_t pad_all =
    (current_rank_required_size - 1) * w_stride + w_kernel_size - w_output_shape / w_dimension_shard_num_;

  // 4. compute new left pad: k - real_left_pad - 1
  new_pad_list_ = pad_list_;
  new_pad_list_[2] = w_kernel_size - real_left_pad - 1;

  // 5. compute new right pad: pad_all - new_left_pad
  new_pad_list_[3] = pad_all - new_pad_list_[2];

  MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
               << current_rank_required_size << ", new pad all is " << pad_all;
}
1201
ReplaceNodeInputOrAttrs()1202 void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
1203 } // namespace parallel
1204 } // namespace mindspore
1205