1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/matmul_info.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "ir/value.h"
28 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
29 #include "frontend/parallel/device_manager.h"
30 #include "frontend/parallel/device_matrix.h"
31 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
32
33 namespace mindspore {
34 namespace parallel {
SetDevMatrixShape(const Dimensions & mat_a_strategy,const Dimensions & mat_b_strategy,bool transpose_b,Shape * dev_matrix_shape)35 void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b,
36 Shape *dev_matrix_shape) {
37 MS_EXCEPTION_IF_NULL(dev_matrix_shape);
38 size_t mat_a_size = mat_a_strategy.size();
39 size_t mat_b_size = mat_b_strategy.size();
40 if (mat_a_size >= mat_b_size) {
41 // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
42 // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
43
44 // [2],[4] in the example above
45 for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) {
46 dev_matrix_shape->push_back(mat_a_strategy.at(i));
47 }
48 } else {
49 // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32]
50 // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
51
52 // [2],[4] in the example above
53 for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) {
54 dev_matrix_shape->push_back(mat_b_strategy.at(i));
55 }
56 }
57
58 // [8],[16] in the example above
59 dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size)));
60 dev_matrix_shape->push_back(mat_a_strategy.back());
61
62 // [32] in the example above
63 if (!transpose_b) {
64 dev_matrix_shape->push_back(mat_b_strategy.back());
65 } else {
66 dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size)));
67 }
68 }
69
GetAttrs()70 Status MatMulBase::GetAttrs() {
71 if (attrs_.size() < MATMUL_ATTRS_SIZE) {
72 MS_LOG(ERROR) << name_ << " : The size of attrs small than 2.";
73 return FAILED;
74 }
75
76 auto transpose_a_iter = attrs_.find(TRANSPOSE_A);
77 if (transpose_a_iter != attrs_.end()) {
78 MS_EXCEPTION_IF_NULL(transpose_a_iter->second);
79 if (transpose_a_iter->second->isa<BoolImm>()) {
80 transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
81 } else {
82 MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
83 return FAILED;
84 }
85 }
86
87 auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
88 if (transpose_b_iter != attrs_.end()) {
89 MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
90 if (transpose_b_iter->second->isa<BoolImm>()) {
91 transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
92 } else {
93 MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
94 return FAILED;
95 }
96 }
97
98 auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
99 if (forward_reduce_scatter_iter != attrs_.end()) {
100 MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
101 if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
102 forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
103 } else {
104 MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
105 return FAILED;
106 }
107 }
108
109 auto field_size_iter = attrs_.find(FIELD_SIZE);
110 if (field_size_iter != attrs_.end()) {
111 MS_EXCEPTION_IF_NULL(field_size_iter->second);
112 if (field_size_iter->second->isa<Int64Imm>()) {
113 field_size_ = field_size_iter->second->cast<Int64ImmPtr>()->value();
114 } else {
115 MS_LOG(ERROR) << name_ << " : The value of field_size is not int64_t.";
116 return FAILED;
117 }
118 }
119
120 // infer inputs dimension size
121 if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
122 MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
123 return FAILED;
124 }
125 mat_a_dimension_ = inputs_shape_.at(0).size();
126 mat_b_dimension_ = inputs_shape_.at(1).size();
127
128 return SUCCESS;
129 }
130
CheckRelevantDimension(const Dimensions & long_strategy,const Dimensions & short_strategy)131 Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) {
132 size_t long_size = long_strategy.size();
133 size_t short_size = short_strategy.size();
134 if (long_size < short_size) {
135 MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is "
136 << short_size;
137 return FAILED;
138 }
139
140 size_t len_diff = long_size - short_size;
141 for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) {
142 if (long_strategy.at(len_diff + j) != short_strategy.at(j)) {
143 MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is "
144 << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy);
145 return FAILED;
146 }
147 }
148
149 return SUCCESS;
150 }
151
CheckStrategy(const StrategyPtr & strategy)152 Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
153 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
154 MS_LOG(ERROR) << name_ << " : Invalid strategy.";
155 return FAILED;
156 }
157
158 Strategys stra = strategy->GetInputDim();
159 Dimensions mat_a_strategy = stra.at(0);
160 Dimensions mat_b_strategy = stra.at(1);
161
162 size_t mat_a_size = mat_a_strategy.size();
163 size_t mat_b_size = mat_b_strategy.size();
164 if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
165 MS_LOG(ERROR) << name_ << " : The dimensions of mat_a or mat_b's strategy is wrong.";
166 return FAILED;
167 }
168
169 // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
170 // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
171 // [16] in the example above
172 if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
173 MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
174 << ", the transpose_b is false, the shard num of first input's column is " << mat_a_strategy.back()
175 << ", but the shard num of second input's row is " << mat_b_strategy.at(SECOND_FROM_END(mat_b_size));
176 return FAILED;
177 } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
178 MS_LOG(ERROR) << name_ << " : Can not do this operator in the strategy: " << StrategyToString(stra)
179 << ", the transpose_b is true, the shard num of first input's column is " << mat_a_strategy.back()
180 << ", but the shard num of second input's column is " << mat_b_strategy.back();
181 return FAILED;
182 }
183
184 if (mat_a_size >= mat_b_size) {
185 if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
186 MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
187 return FAILED;
188 }
189 } else {
190 if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
191 MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
192 return FAILED;
193 }
194 }
195
196 if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
197 MS_LOG(WARNING) << name_
198 << ": The dimension of mat a and mat b must be 2 in forward reduce scatter mode, "
199 "setting the forward reduce scatter mode to false here";
200 forward_reduce_scatter_ = false;
201 }
202
203 return SUCCESS;
204 }
205
InferDevMatrixShape()206 Status MatMulBase::InferDevMatrixShape() {
207 Strategys stra = strategy_->GetInputDim();
208 Dimensions mat_a_strategy = stra.at(0);
209 Dimensions mat_b_strategy = stra.at(1);
210
211 SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
212 origin_dev_matrix_shape_ = dev_matrix_shape_;
213 return SUCCESS;
214 }
215
InferForwardCommunication()216 Status MatMulBase::InferForwardCommunication() {
217 forward_op_.clear();
218 size_t dimension = origin_dev_matrix_shape_.size();
219 size_t relevant_dimension_index = SECOND_FROM_END(dimension);
220 // Relevant dimension is not split and all reduce is not required,
221 // need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation.
222 if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
223 MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
224 return SUCCESS;
225 }
226
227 std::vector<Group> group_list;
228 if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
229 MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed.";
230 return FAILED;
231 } else if (group_list.empty()) {
232 MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
233 return SUCCESS;
234 }
235
236 Operator op;
237 if (forward_reduce_scatter_) {
238 op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
239 } else {
240 op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
241 }
242
243 forward_op_.push_back(op);
244 MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
245 return SUCCESS;
246 }
247
InferTensorMap()248 Status MatMulBase::InferTensorMap() {
249 size_t size = dev_matrix_shape_.size();
250 if (repeated_calc_num_ > 1) {
251 // move the first dimension(repeated_calc_num_), just for the convenience of tensor-map's calculation
252 size = dev_matrix_shape_.size() - 1;
253 }
254
255 Shape tensor_map_index;
256 // such as 5: tensor_map_index [4,3,2,1,0]
257 for (size_t i = 0; i < size; ++i) {
258 tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
259 }
260
261 // infer output tensor map: [4,3,2,0], delete the second-from-end element
262 TensorMap output_tensor_map = tensor_map_index;
263 (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(size)));
264
265 // infer mat_a tensor map
266 // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1]
267 TensorMap mat_a_tensor_map = tensor_map_index;
268 // delete last one element
269 mat_a_tensor_map.pop_back();
270 // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements
271 (void)mat_a_tensor_map.erase(
272 mat_a_tensor_map.begin(),
273 mat_a_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_a_dimension_));
274
275 // infer mat_b tensor map
276 TensorMap mat_b_tensor_map = tensor_map_index;
277 // delete the third-to-last element
278 (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast<different_type>(THIRD_FROM_END(size)));
279 // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements
280 (void)mat_b_tensor_map.erase(
281 mat_b_tensor_map.begin(),
282 mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_));
283 if (transpose_b_) {
284 // swap the last two elements
285 int64_t last_value = mat_b_tensor_map.back();
286 mat_b_tensor_map.pop_back();
287 (void)mat_b_tensor_map.insert(
288 mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
289 }
290
291 if (forward_reduce_scatter_) {
292 if (dev_matrix_shape_.size() != 3) {
293 MS_LOG(WARNING) << name_
294 << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
295 "setting the forward reduce scatter mode to false here";
296 forward_reduce_scatter_ = false;
297 } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
298 MS_LOG(WARNING) << name_
299 << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
300 "forward reduce scatter mode, setting the forward reduce scatter mode to false here";
301 forward_reduce_scatter_ = false;
302 } else {
303 // the forward reduce scatter only support that the dimension of output is 2
304 output_tensor_map = {1, 0};
305 }
306 }
307
308 inputs_tensor_map_.push_back(mat_a_tensor_map);
309 inputs_tensor_map_.push_back(mat_b_tensor_map);
310 outputs_tensor_map_.push_back(output_tensor_map);
311 return SUCCESS;
312 }
313
InferTensorLayout(TensorLayouts * inputs_layout,TensorLayouts * outputs_layout)314 Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
315 Shape output_dev_matrix_shape;
316 if (forward_reduce_scatter_) {
317 if (dev_matrix_shape_.size() != 3) {
318 MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
319 return FAILED;
320 }
321 output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
322 } else {
323 output_dev_matrix_shape = dev_matrix_shape_;
324 }
325
326 TensorLayout mat_a_layout, mat_b_layout, output_layout;
327 if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
328 (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
329 (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
330 return FAILED;
331 }
332
333 if (field_size_ != 0) {
334 mat_b_layout.set_field_size(field_size_);
335 }
336
337 inputs_layout->push_back(mat_a_layout);
338 inputs_layout->push_back(mat_b_layout);
339 outputs_layout->push_back(output_layout);
340 return SUCCESS;
341 }
342
InferTensorInfo()343 Status MatMulBase::InferTensorInfo() {
344 // infer tensor layout
345 TensorLayouts inputs_layout, outputs_layout;
346 if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
347 return FAILED;
348 }
349
350 TensorLayout mat_a_layout = inputs_layout.at(0);
351 TensorLayout mat_b_layout = inputs_layout.at(1);
352 TensorLayout output_layout = outputs_layout.at(0);
353 TensorInfo mat_a_tensor_info(mat_a_layout);
354 TensorInfo mat_b_tensor_info(mat_b_layout);
355 TensorInfo output_tensor_info(output_layout);
356
357 inputs_tensor_info_.push_back(mat_a_tensor_info);
358 inputs_tensor_info_.push_back(mat_b_tensor_info);
359 outputs_tensor_info_.push_back(output_tensor_info);
360 return SUCCESS;
361 }
362
Init(const StrategyPtr & strategy)363 Status MatMulBase::Init(const StrategyPtr &strategy) {
364 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
365 MS_LOG(ERROR) << name_ << " : Init failed.";
366 return FAILED;
367 }
368
369 if (forward_reduce_scatter_) {
370 virtual_div_op_.clear();
371 MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
372 }
373
374 MS_LOG(INFO) << name_ << " : Init success.";
375 return SUCCESS;
376 }
377
InitForCostModel(const StrategyPtr & strategy)378 Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) {
379 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
380 MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
381 return FAILED;
382 }
383
384 MS_LOG(INFO) << name_ << " : Init for cost model success.";
385 return SUCCESS;
386 }
387
SwapLastTwoElements(mindspore::parallel::Shape * const input)388 Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
389 if (input->size() < 2) {
390 MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";
391 return FAILED;
392 }
393 auto last_1st_value = input->at(input->size() - 1);
394 auto last_2nd_value = input->at(input->size() - 2);
395 input->pop_back();
396 input->pop_back();
397 input->push_back(last_1st_value);
398 input->push_back(last_2nd_value);
399 return SUCCESS;
400 }
401
GenerateStrategies(int64_t stage_id)402 Status MatMulBase::GenerateStrategies(int64_t stage_id) {
403 if (GetAttrs() != SUCCESS) {
404 MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
405 return FAILED;
406 }
407 CheckGlobalDeviceManager();
408 RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
409 size_t dev_num = dev_list.size();
410 Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
411 if (transpose_a_) {
412 if (SwapLastTwoElements(&input0_shape) == FAILED) {
413 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
414 }
415 }
416 if (transpose_b_) {
417 if (SwapLastTwoElements(&input1_shape) == FAILED) {
418 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
419 }
420 }
421 // The shape of input0 (input1)
422 // E.g., input0 = [100, 200, 300], input1 = [300, 400]
423
424 // Combining the input0_shape and input1_shape
425 // E.g., combined_shape = [100, 200, 300, 400]
426 size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
427 Dimensions combined_partitions;
428 Shape combined_shape;
429 // In SwapLastTwoElements(), it is guaranteed that input0_shape.size() and input1_shape.size() are both larger than 2
430 if (input0_shape.size() >= input1_shape.size()) {
431 combined_shape = input0_shape;
432 combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
433 } else {
434 combined_shape = input1_shape;
435 combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
436 }
437 std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape,
438 &input1_shape_size, &recursive, &input0_shape_size,
439 this](uint64_t current_index, size_t n) {
440 // Finishing the recursive steps, if the strategy is valid, then calculate the cost
441 // for this operator under the strategy.
442 if (current_index == combined_shape.size()) {
443 StrategyPtr sp;
444 if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
445 FAILED) {
446 return;
447 }
448 if (this->SetCostUnderStrategy(sp) == FAILED) {
449 MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
450 return;
451 }
452 } else {
453 MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
454 << ", input1_shape_size: " << input1_shape_size;
455 for (uint64_t i = 1; i <= n; i *= 2) {
456 if (n % i == 0 && LongToSize(combined_shape[current_index]) % i == 0) {
457 combined_partitions.push_back(i);
458 recursive(current_index + 1, n / i);
459 combined_partitions.pop_back();
460 }
461 }
462 }
463 };
464 recursive(0, dev_num);
465 if (strategy_cost_.empty()) {
466 MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
467 }
468 return Status::SUCCESS;
469 }
470
GenerateOpStrategies(int64_t)471 std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t) {
472 std::vector<StrategyPtr> sp_vector;
473 return sp_vector;
474 }
475
PrepareStrategy(int64_t stage_id,size_t dev_num,mindspore::parallel::Dimensions combined_partitions,size_t input0_shape_size,size_t input1_shape_size,mindspore::parallel::StrategyPtr * const sp)476 Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
477 mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
478 size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
479 int64_t product =
480 std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
481 const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
482 if (!fully_use_device) {
483 if (LongToSize(product) > dev_num) {
484 return FAILED;
485 }
486 } else {
487 if (LongToSize(product) != dev_num) {
488 return FAILED;
489 }
490 }
491 Dimensions input0_partitions, input1_partitions;
492 if (input0_shape_size >= input1_shape_size) {
493 for (size_t i = 0; i < input0_shape_size; ++i) {
494 input0_partitions.push_back(combined_partitions[i]);
495 }
496 if (input1_shape_size == 2) {
497 input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
498 input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
499 } else {
500 // input1_shape.size() > 2
501 for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
502 if (j == combined_partitions.size() - 3) {
503 continue;
504 }
505 input1_partitions.push_back(combined_partitions[j]);
506 }
507 }
508 } else {
509 for (size_t i = 0; i < input1_shape_size; ++i) {
510 input1_partitions.push_back(combined_partitions[i]);
511 }
512 for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
513 input0_partitions.push_back(combined_partitions[j]);
514 }
515 input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
516 input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
517 }
518 if (transpose_a_) {
519 if (SwapLastTwoElements(&input0_partitions) == FAILED) {
520 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
521 }
522 }
523 if (transpose_b_) {
524 if (SwapLastTwoElements(&input1_partitions) == FAILED) {
525 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
526 }
527 }
528 Strategys stras;
529 stras.push_back(input0_partitions);
530 stras.push_back(input1_partitions);
531 (*sp) = std::make_shared<Strategy>(stage_id, stras);
532
533 return SUCCESS;
534 }
535
InitTensorInfoForCost(std::vector<TensorInfo> * relica_inputs_tensor_vector)536 void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
537 TensorLayout tly;
538 if (transpose_a_) {
539 Shape replica_input0_shape(inputs_tensor_info_[0].shape());
540 Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
541 if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
542 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
543 }
544 if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
545 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
546 }
547
548 TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
549 relica_inputs_tensor_vector->push_back(replica_input0_info);
550 } else {
551 relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
552 }
553 if (transpose_b_) {
554 Shape replica_input1_shape(inputs_tensor_info_[1].shape());
555 Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
556 if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
557 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
558 }
559 if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
560 MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
561 }
562
563 TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
564 relica_inputs_tensor_vector->push_back(replica_input1_info);
565 } else {
566 relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
567 }
568 }
569
CheckForTensorSliceValid() const570 Status MatMulBase::CheckForTensorSliceValid() const {
571 const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
572 const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
573 if (!align_enable) {
574 return SUCCESS;
575 }
576 if (inputs_tensor_info_.empty()) {
577 return FAILED;
578 }
579 for (auto &one_input_tensor : inputs_tensor_info_) {
580 auto slice_shape = one_input_tensor.slice_shape();
581 if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
582 (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
583 return FAILED;
584 }
585 }
586 return SUCCESS;
587 }
588
GenerateBatchStrategies()589 std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
590 Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
591 batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
592 Strategys strategy_v = {batch_strategy, batch_strategy};
593 return std::make_shared<Strategys>(strategy_v);
594 }
595
SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & strategy)596 Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
597 if (InitForCostModel(strategy) == FAILED) {
598 MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed.";
599 return FAILED;
600 }
601 PrintStrategy(strategy);
602 // Check whether the tensor slice of input_tensor_info is valid or not
603 if (CheckForTensorSliceValid() != SUCCESS) {
604 MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
605 return FAILED;
606 }
607 // Here, a replicated inputs_ is constructed for the transposed TensorInfo.
608 std::vector<TensorInfo> relica_inputs_tensor_vector;
609 InitTensorInfoForCost(&relica_inputs_tensor_vector);
610
611 int64_t stage_id = strategy->GetInputStage();
612 // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
613 // It does not matter whether the output tensor is transposed or not.
614 double computation_cost =
615 operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
616 double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
617 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
618 std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
619 result->communication_without_parameter_ =
620 operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
621 result->communication_with_partial_para_ =
622 result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
623
624 // Breaking ties for preferring data parallelization
625 BreakingTiesForPerferringDataParallel(strategy, result);
626 MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
627 << ", communication_cost: " << result->communication_cost_
628 << ", communication_without_parameter_: " << result->communication_without_parameter_
629 << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
630 // refine communication cost calculation for practice
631 RefineForPracticalCost(result, false);
632 result->communication_forward_ = result->communication_without_parameter_;
633
634 std::shared_ptr<StrategyWithCost> swc =
635 std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
636 swc->cost_list.push_back(result);
637 strategy_cost_.emplace_back(swc);
638
639 return SUCCESS;
640 }
641 } // namespace parallel
642 } // namespace mindspore
643