/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/matmul_info.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"

namespace mindspore {
namespace parallel {
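// Builds the device matrix from the two input strategies: the batch dimensions come
// from the longer strategy, followed by the (row, column) split of mat_a and the
// output-column split of mat_b. For example (from the comments below), with
// mat_a_strategy [2,4,8,16], mat_b_strategy [4,16,32] and transpose_b == false,
// the resulting dev_matrix_shape is [2,4,8,16,32].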
void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b,
                       Shape *dev_matrix_shape) {
  MS_EXCEPTION_IF_NULL(dev_matrix_shape);
  size_t mat_a_size = mat_a_strategy.size();
  size_t mat_b_size = mat_b_strategy.size();
  if (mat_a_size >= mat_b_size) {
    // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
    // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)

    // [2],[4] in the example above
    for (size_t i = 0; i < SECOND_FROM_END(mat_a_size); ++i) {
      dev_matrix_shape->push_back(mat_a_strategy.at(i));
    }
  } else {
    // for example: mat_a_strategy:[8,16], mat_b_strategy:[2,4,16,32]
    // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)

    // [2],[4] in the example above
    for (size_t i = 0; i < SECOND_FROM_END(mat_b_size); ++i) {
      dev_matrix_shape->push_back(mat_b_strategy.at(i));
    }
  }

  // [8],[16] in the example above
  dev_matrix_shape->push_back(mat_a_strategy.at(SECOND_FROM_END(mat_a_size)));
  dev_matrix_shape->push_back(mat_a_strategy.back());

  // [32] in the example above
  if (!transpose_b) {
    dev_matrix_shape->push_back(mat_b_strategy.back());
  } else {
    dev_matrix_shape->push_back(mat_b_strategy.at(SECOND_FROM_END(mat_b_size)));
  }
}

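// Parses the operator attributes (transpose_a, transpose_b, forward reduce scatter,
// field size) and records the dimensionality of the two inputs.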
Status MatMulBase::GetAttrs() {
  if (attrs_.size() < MATMUL_ATTRS_SIZE) {
    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 2.";
    return FAILED;
  }

  auto transpose_a_iter = attrs_.find(TRANSPOSE_A);
  if (transpose_a_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(transpose_a_iter->second);
    if (transpose_a_iter->second->isa<BoolImm>()) {
      transpose_a_ = transpose_a_iter->second->cast<BoolImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of transpose_a is not bool.";
      return FAILED;
    }
  }

  auto transpose_b_iter = attrs_.find(TRANSPOSE_B);
  if (transpose_b_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(transpose_b_iter->second);
    if (transpose_b_iter->second->isa<BoolImm>()) {
      transpose_b_ = transpose_b_iter->second->cast<BoolImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of transpose_b is not bool.";
      return FAILED;
    }
  }

  auto forward_reduce_scatter_iter = attrs_.find(FORWARD_REDUCE_SCATTER);
  if (forward_reduce_scatter_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(forward_reduce_scatter_iter->second);
    if (forward_reduce_scatter_iter->second->isa<BoolImm>()) {
      forward_reduce_scatter_ = forward_reduce_scatter_iter->second->cast<BoolImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of forward reduce scatter is not bool.";
      return FAILED;
    }
  }

  auto field_size_iter = attrs_.find(FIELD_SIZE);
  if (field_size_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(field_size_iter->second);
    if (field_size_iter->second->isa<Int64Imm>()) {
      field_size_ = field_size_iter->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of field_size is not int64_t.";
      return FAILED;
    }
  }

  // infer the dimension size of the inputs
  if ((inputs_shape_.size() != MATMUL_INPUTS_SIZE) || (outputs_shape_.size() != MATMUL_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
    return FAILED;
  }
  mat_a_dimension_ = inputs_shape_.at(0).size();
  mat_b_dimension_ = inputs_shape_.at(1).size();

  return SUCCESS;
}

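// Checks that the shared (batch) dimensions of the two strategies are sharded
// identically: the short strategy is right-aligned against the long one, and every
// aligned dimension except the last two of the short strategy must match.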
Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) {
  size_t long_size = long_strategy.size();
  size_t short_size = short_strategy.size();
  if (long_size < short_size) {
    MS_LOG(ERROR) << "Size error, the size of long strategy is " << long_size << ", the size of short strategy is "
                  << short_size;
    return FAILED;
  }

  size_t len_diff = long_size - short_size;
  for (size_t j = 0; j < SECOND_FROM_END(short_size); ++j) {
    if (long_strategy.at(len_diff + j) != short_strategy.at(j)) {
      MS_LOG(ERROR) << "Strategies of relevant dimensions are not equal, long strategy is "
                    << ShapeToString(long_strategy) << ", short strategy is " << ShapeToString(short_strategy);
      return FAILED;
    }
  }

  return SUCCESS;
}

Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    return FAILED;
  }

  Strategys stra = strategy->GetInputDim();
  Dimensions mat_a_strategy = stra.at(0);
  Dimensions mat_b_strategy = stra.at(1);

  size_t mat_a_size = mat_a_strategy.size();
  size_t mat_b_size = mat_b_strategy.size();
  if ((mat_a_size != mat_a_dimension_) || (mat_b_size != mat_b_dimension_)) {
    MS_LOG(ERROR) << name_ << " : The dimensions of mat_a's or mat_b's strategy are wrong.";
    return FAILED;
  }

  // for example: mat_a_strategy:[2,4,8,16], mat_b_strategy:[4,16,32]
  // dev_matrix_shape:[2,4,8,16,32] (transpose_b is false)
  // [16] in the example above
  if (!transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.at(SECOND_FROM_END(mat_b_size)))) {
    MS_LOG(ERROR) << name_ << " : Cannot do this operator with the strategy: " << StrategyToString(stra)
                  << ", transpose_b is false, the shard num of the first input's column is " << mat_a_strategy.back()
                  << ", but the shard num of the second input's row is " << mat_b_strategy.at(SECOND_FROM_END(mat_b_size));
    return FAILED;
  } else if (transpose_b_ && (mat_a_strategy.back() != mat_b_strategy.back())) {
    MS_LOG(ERROR) << name_ << " : Cannot do this operator with the strategy: " << StrategyToString(stra)
                  << ", transpose_b is true, the shard num of the first input's column is " << mat_a_strategy.back()
                  << ", but the shard num of the second input's column is " << mat_b_strategy.back();
    return FAILED;
  }

  if (mat_a_size >= mat_b_size) {
    if (CheckRelevantDimension(mat_a_strategy, mat_b_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  } else {
    if (CheckRelevantDimension(mat_b_strategy, mat_a_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << " : Strategies of relevant dimensions are not equal.";
      return FAILED;
    }
  }

  if ((mat_a_dimension_ != 2 || mat_b_dimension_ != 2) && forward_reduce_scatter_) {
    MS_LOG(WARNING) << name_
                    << ": The dimensions of mat_a and mat_b must be 2 in forward reduce scatter mode, "
                       "setting the forward reduce scatter mode to false here";
    forward_reduce_scatter_ = false;
  }

  return SUCCESS;
}

Status MatMulBase::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions mat_a_strategy = stra.at(0);
  Dimensions mat_b_strategy = stra.at(1);

  SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_);
  origin_dev_matrix_shape_ = dev_matrix_shape_;
  return SUCCESS;
}

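// Inserts the forward communication operator: an AllReduce (or a ReduceScatter in
// forward reduce scatter mode) over the devices that share the reduction dimension.
// If that dimension is not split, no forward communication is needed.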
Status MatMulBase::InferForwardCommunication() {
  forward_op_.clear();
  size_t dimension = origin_dev_matrix_shape_.size();
  size_t relevant_dimension_index = SECOND_FROM_END(dimension);
  // If the relevant dimension is not split, the all-reduce is not required.
  // Use origin_dev_matrix_shape_ here, since dev_matrix_shape_ is changed by repeated calculation.
  if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) {
    MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
    return SUCCESS;
  }

  std::vector<Group> group_list;
  if (CreateGroupByDim(relevant_dimension_index, &group_list) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Infer forward communication, create group failed.";
    return FAILED;
  } else if (group_list.empty()) {
    MS_LOG(INFO) << name_ << " : Forward all reduce is not required.";
    return SUCCESS;
  }

  Operator op;
  if (forward_reduce_scatter_) {
    op = CreateReduceScatterOp(REDUCE_OP_SUM, group_list[0].name());
  } else {
    op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name());
  }

  forward_op_.push_back(op);
  MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name();
  return SUCCESS;
}

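// Maps each tensor dimension of the inputs and the output onto a device-matrix
// dimension. For instance, with the dev_matrix_shape [2,4,8,16,32] used above, the
// full index map is [4,3,2,1,0]; the output map drops the reduction dimension
// ([4,3,2,0]), the mat_a map becomes [4,3,2,1], and the mat_b map becomes [3,1,0]
// (before the transpose_b swap is applied).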
Status MatMulBase::InferTensorMap() {
  size_t size = dev_matrix_shape_.size();
  if (repeated_calc_num_ > 1) {
    // skip the first dimension (repeated_calc_num_), just for the convenience of the tensor-map calculation
    size = dev_matrix_shape_.size() - 1;
  }

  Shape tensor_map_index;
  // e.g. size 5: tensor_map_index [4,3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back((int64_t)(LAST_INDEX(size) - i));
  }

  // infer output tensor map: [4,3,2,0], delete the second-from-end element
  TensorMap output_tensor_map = tensor_map_index;
  (void)output_tensor_map.erase(output_tensor_map.begin() + static_cast<different_type>(SECOND_FROM_END(size)));

  // infer mat_a tensor map
  // for example: mat_a_dimension is 4, mat_a tensor map:[4,3,2,1]
  TensorMap mat_a_tensor_map = tensor_map_index;
  // delete the last element
  mat_a_tensor_map.pop_back();
  // delete the first (dev_matrix_size - 1 - mat_a_dimension) elements
  (void)mat_a_tensor_map.erase(
    mat_a_tensor_map.begin(),
    mat_a_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_a_dimension_));

  // infer mat_b tensor map
  TensorMap mat_b_tensor_map = tensor_map_index;
  // delete the third-from-end element
  (void)mat_b_tensor_map.erase(mat_b_tensor_map.begin() + static_cast<different_type>(THIRD_FROM_END(size)));
  // delete the first (dev_matrix_size - 1 - mat_b_dimension) elements
  (void)mat_b_tensor_map.erase(
    mat_b_tensor_map.begin(),
    mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(size) - mat_b_dimension_));
  if (transpose_b_) {
    // swap the last two elements
    int64_t last_value = mat_b_tensor_map.back();
    mat_b_tensor_map.pop_back();
    (void)mat_b_tensor_map.insert(
      mat_b_tensor_map.begin() + static_cast<different_type>(LAST_INDEX(mat_b_tensor_map.size())), last_value);
  }

  if (forward_reduce_scatter_) {
    if (dev_matrix_shape_.size() != 3) {
      MS_LOG(WARNING) << name_
                      << ": The dimension of dev matrix shape must be 3 in forward reduce scatter mode, "
                         "setting the forward reduce scatter mode to false here";
      forward_reduce_scatter_ = false;
    } else if (outputs_shape_[0][0] % (dev_matrix_shape_[0] * dev_matrix_shape_[1]) != 0) {
      MS_LOG(WARNING) << name_
                      << ": The first dimension of output should be split by dev_matrix[0]*dev_matrix[1] in "
                         "forward reduce scatter mode, setting the forward reduce scatter mode to false here";
      forward_reduce_scatter_ = false;
    } else {
      // the forward reduce scatter only supports a 2-dimensional output
      output_tensor_map = {1, 0};
    }
  }

  inputs_tensor_map_.push_back(mat_a_tensor_map);
  inputs_tensor_map_.push_back(mat_b_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  return SUCCESS;
}

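// Builds the tensor layouts for the two inputs and the output. In forward reduce
// scatter mode the first two device-matrix dimensions are merged for the output,
// since the scattered result is split along the output's first dimension across
// both of those dimensions.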
Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) {
  Shape output_dev_matrix_shape;
  if (forward_reduce_scatter_) {
    if (dev_matrix_shape_.size() != 3) {
      MS_LOG(ERROR) << "The size of origin dev matrix shape must be 3 in forward reduce scatter mode";
      return FAILED;
    }
    output_dev_matrix_shape = {dev_matrix_shape_[0] * dev_matrix_shape_[1], dev_matrix_shape_[2]};
  } else {
    output_dev_matrix_shape = dev_matrix_shape_;
  }

  TensorLayout mat_a_layout, mat_b_layout, output_layout;
  if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) ||
      (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) ||
      (output_layout.InitFromVector(output_dev_matrix_shape, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS)) {
    return FAILED;
  }

  if (field_size_ != 0) {
    mat_b_layout.set_field_size(field_size_);
  }

  inputs_layout->push_back(mat_a_layout);
  inputs_layout->push_back(mat_b_layout);
  outputs_layout->push_back(output_layout);
  return SUCCESS;
}

Status MatMulBase::InferTensorInfo() {
  // infer tensor layout
  TensorLayouts inputs_layout, outputs_layout;
  if (InferTensorLayout(&inputs_layout, &outputs_layout) != SUCCESS) {
    return FAILED;
  }

  TensorLayout mat_a_layout = inputs_layout.at(0);
  TensorLayout mat_b_layout = inputs_layout.at(1);
  TensorLayout output_layout = outputs_layout.at(0);
  TensorInfo mat_a_tensor_info(mat_a_layout);
  TensorInfo mat_b_tensor_info(mat_b_layout);
  TensorInfo output_tensor_info(output_layout);

  inputs_tensor_info_.push_back(mat_a_tensor_info);
  inputs_tensor_info_.push_back(mat_b_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status MatMulBase::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }

  if (forward_reduce_scatter_) {
    virtual_div_op_.clear();
    MS_LOG(INFO) << "The forward reduce scatter mode does not involve repeated calculation, clear the virtual div op";
  }

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}

Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " : Init for cost model success.";
  return SUCCESS;
}

Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
  if (input->size() < 2) {
    MS_LOG(ERROR) << name_ << " : The size of the input is smaller than 2.";
    return FAILED;
  }
  auto last_1st_value = input->at(input->size() - 1);
  auto last_2nd_value = input->at(input->size() - 2);
  input->pop_back();
  input->pop_back();
  input->push_back(last_1st_value);
  input->push_back(last_2nd_value);
  return SUCCESS;
}

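// Enumerates candidate sharding strategies by recursively splitting every dimension
// of the combined shape (the longer input shape plus the other input's non-common
// dimension) into power-of-two partitions whose product stays within the device
// number, then scoring each valid strategy via SetCostUnderStrategy().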
Status MatMulBase::GenerateStrategies(int64_t stage_id) {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : GetAttrs failed.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id);
  size_t dev_num = dev_list.size();
  Shape input0_shape = inputs_shape_[0], input1_shape = inputs_shape_[1];
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  // The shapes of input0 and input1 after the optional transposes,
  // e.g., input0 = [100, 200, 300], input1 = [300, 400]

  // Combine input0_shape and input1_shape,
  // e.g., combined_shape = [100, 200, 300, 400]
  size_t input1_shape_size = input1_shape.size(), input0_shape_size = input0_shape.size();
  Dimensions combined_partitions;
  Shape combined_shape;
  // SwapLastTwoElements() guarantees that input0_shape.size() and input1_shape.size() are both at least 2
  if (input0_shape.size() >= input1_shape.size()) {
    combined_shape = input0_shape;
    combined_shape.push_back(input1_shape[input1_shape.size() - 1]);
  } else {
    combined_shape = input1_shape;
    combined_shape.push_back(input0_shape[input0_shape.size() - 2]);
  }
  std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &combined_partitions, &combined_shape,
                                                     &input1_shape_size, &recursive, &input0_shape_size,
                                                     this](uint64_t current_index, size_t n) {
    // Finishing the recursive steps: if the strategy is valid, calculate the cost
    // of this operator under the strategy.
    if (current_index == combined_shape.size()) {
      StrategyPtr sp;
      if (this->PrepareStrategy(stage_id, dev_num, combined_partitions, input0_shape_size, input1_shape_size, &sp) ==
          FAILED) {
        return;
      }
      if (this->SetCostUnderStrategy(sp) == FAILED) {
        MS_LOG(WARNING) << name_ << " : Calculating cost for strategy failed.";
        return;
      }
    } else {
      MS_LOG(DEBUG) << name_ << " : The value input0_shape_size: " << input0_shape_size
                    << ", input1_shape_size: " << input1_shape_size;
      for (uint64_t i = 1; i <= n; i *= 2) {
        if (n % i == 0 && LongToSize(combined_shape[current_index]) % i == 0) {
          combined_partitions.push_back(i);
          recursive(current_index + 1, n / i);
          combined_partitions.pop_back();
        }
      }
    }
  };
  recursive(0, dev_num);
  if (strategy_cost_.empty()) {
    MS_LOG(EXCEPTION) << name_ << " : No available strategy.";
  }
  return Status::SUCCESS;
}

std::vector<StrategyPtr> MatMulBase::GenerateOpStrategies(int64_t) {
  std::vector<StrategyPtr> sp_vector;
  return sp_vector;
}

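// Splits the combined partitions back into per-input partitions, re-applies the
// transposes, and builds a Strategy. The partition product must not exceed (or,
// when the cost model requires fully using the devices, must equal) the device number.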
Status MatMulBase::PrepareStrategy(int64_t stage_id, size_t dev_num,
                                   mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size,
                                   size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) {
  int64_t product =
    std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies<int64_t>());
  const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
  if (!fully_use_device) {
    if (LongToSize(product) > dev_num) {
      return FAILED;
    }
  } else {
    if (LongToSize(product) != dev_num) {
      return FAILED;
    }
  }
  Dimensions input0_partitions, input1_partitions;
  if (input0_shape_size >= input1_shape_size) {
    for (size_t i = 0; i < input0_shape_size; ++i) {
      input0_partitions.push_back(combined_partitions[i]);
    }
    if (input1_shape_size == 2) {
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 2]);
      input1_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    } else {
      // input1_shape.size() > 2
      for (size_t j = combined_partitions.size() - input1_shape_size - 1; j < combined_partitions.size(); ++j) {
        if (j == combined_partitions.size() - 3) {
          continue;
        }
        input1_partitions.push_back(combined_partitions[j]);
      }
    }
  } else {
    for (size_t i = 0; i < input1_shape_size; ++i) {
      input1_partitions.push_back(combined_partitions[i]);
    }
    for (size_t j = combined_partitions.size() - input0_shape_size - 1; j < combined_partitions.size() - 3; ++j) {
      input0_partitions.push_back(combined_partitions[j]);
    }
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 1]);
    input0_partitions.push_back(combined_partitions[combined_partitions.size() - 3]);
  }
  if (transpose_a_) {
    if (SwapLastTwoElements(&input0_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  if (transpose_b_) {
    if (SwapLastTwoElements(&input1_partitions) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
  }
  Strategys stras;
  stras.push_back(input0_partitions);
  stras.push_back(input1_partitions);
  (*sp) = std::make_shared<Strategy>(stage_id, stras);

  return SUCCESS;
}

void MatMulBase::InitTensorInfoForCost(std::vector<TensorInfo> *relica_inputs_tensor_vector) {
  TensorLayout tly;
  if (transpose_a_) {
    Shape replica_input0_shape(inputs_tensor_info_[0].shape());
    Shape replica_input0_slice_shape(inputs_tensor_info_[0].slice_shape());
    if (SwapLastTwoElements(&replica_input0_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
    if (SwapLastTwoElements(&replica_input0_slice_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }

    TensorInfo replica_input0_info(tly, replica_input0_shape, replica_input0_slice_shape);
    relica_inputs_tensor_vector->push_back(replica_input0_info);
  } else {
    relica_inputs_tensor_vector->push_back(inputs_tensor_info_[0]);
  }
  if (transpose_b_) {
    Shape replica_input1_shape(inputs_tensor_info_[1].shape());
    Shape replica_input1_slice_shape(inputs_tensor_info_[1].slice_shape());
    if (SwapLastTwoElements(&replica_input1_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }
    if (SwapLastTwoElements(&replica_input1_slice_shape) == FAILED) {
      MS_LOG(ERROR) << name_ << " : Swap last two elements failed.";
    }

    TensorInfo replica_input1_info(tly, replica_input1_shape, replica_input1_slice_shape);
    relica_inputs_tensor_vector->push_back(replica_input1_info);
  } else {
    relica_inputs_tensor_vector->push_back(inputs_tensor_info_[1]);
  }
}

Status MatMulBase::CheckForTensorSliceValid() const {
  const auto align_enable = CostModelContext::GetInstance()->tensor_slice_alignment_enable();
  const auto align_size = CostModelContext::GetInstance()->tensor_slice_alignment_size();
  if (!align_enable) {
    return SUCCESS;
  }
  if (inputs_tensor_info_.empty()) {
    return FAILED;
  }
  for (auto &one_input_tensor : inputs_tensor_info_) {
    auto slice_shape = one_input_tensor.slice_shape();
    if ((LongToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % align_size != 0) ||
        (LongToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % align_size != 0)) {
      return FAILED;
    }
  }
  return SUCCESS;
}

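// The default data-parallel strategy for BatchMatMul: split only the leading batch
// dimension across the devices of the stage and keep every other dimension whole.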
std::shared_ptr<Strategys> BatchMatMulInfo::GenerateBatchStrategies() {
  Dimensions batch_strategy(inputs_shape_[1].size() - 1, 1);
  batch_strategy.insert(batch_strategy.begin(), stage_device_size_);
  Strategys strategy_v = {batch_strategy, batch_strategy};
  return std::make_shared<Strategys>(strategy_v);
}

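// Evaluates a candidate strategy: initializes the operator for the cost model,
// validates the tensor slices, then records the computation and communication costs.
// The partial-parameter communication cost is
//   communication_without_parameter + gamma * (communication - communication_without_parameter),
// where gamma comes from the cost model context.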
Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
  if (InitForCostModel(strategy) == FAILED) {
    MS_LOG(ERROR) << name_ << " : Initialization under the strategy failed.";
    return FAILED;
  }
  PrintStrategy(strategy);
  // Check whether the tensor slice of input_tensor_info is valid or not
  if (CheckForTensorSliceValid() != SUCCESS) {
    MS_LOG(INFO) << name_ << " : The tensor slice is not valid under this strategy.";
    return FAILED;
  }
  // Here, replicated inputs are constructed for the transposed TensorInfos.
  std::vector<TensorInfo> relica_inputs_tensor_vector;
  InitTensorInfoForCost(&relica_inputs_tensor_vector);

  int64_t stage_id = strategy->GetInputStage();
  // Here, we use the origin outputs_, because we only use the slice size of the output tensor.
  // It does not matter whether the output tensor is transposed or not.
  double computation_cost =
    operator_cost()->GetForwardComputationCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(relica_inputs_tensor_vector, outputs_tensor_info_, stage_id);
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);

  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  MS_LOG(DEBUG) << name_ << " : computation_cost: " << result->computation_cost_
                << ", communication_cost: " << result->communication_cost_
                << ", communication_without_parameter_: " << result->communication_without_parameter_
                << ", communication_with_partial_para_: " << result->communication_with_partial_para_;
  // refine the communication cost calculation for practice
  RefineForPracticalCost(result, false);
  result->communication_forward_ = result->communication_without_parameter_;

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore