1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <map>
22 #include <memory>
#include <numeric>  // std::accumulate, std::gcd
23 #include <optional>
24 #include <set>
25 #include <vector>
26 
27 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
28 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
29 #include "frontend/parallel/ops_info/flash_attention_score_info.h"
30 #include "frontend/parallel/ops_info/operator_info.h"
31 #include "frontend/parallel/ops_info/strided_slice_info.h"
32 #include "frontend/parallel/ops_info/gather_info.h"
33 #include "frontend/parallel/parameter_manager.h"
34 #include "frontend/parallel/step_parallel.h"
35 #include "frontend/parallel/step_parallel_utils.h"
36 #include "frontend/parallel/strategy.h"
37 #include "include/common/utils/utils.h"
38 #include "ir/value.h"
39 #include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
40 #include "ops/op_enum.h"
41 
42 namespace mindspore {
43 namespace parallel {
44 namespace {
45 using PrepareStraFuncPtr = Strategies (*)(const std::shared_ptr<OperatorInfo> &, Dimensions, bool);
46 std::map<std::string, PrepareStraFuncPtr> g_prepare_stra_map;
47 
48 std::optional<bool> GetKeepDimsFromAttrs(const std::shared_ptr<OperatorInfo> &op) {
49   auto keep_dims_iter = op->attrs().find(KEEP_DIMS);
50   if (keep_dims_iter == op->attrs().end()) {
51     return std::nullopt;
52   }
53   auto keep_dims_ptr = keep_dims_iter->second;
54   MS_EXCEPTION_IF_NULL(keep_dims_ptr);
55   if (!keep_dims_ptr->isa<BoolImm>()) {
56     MS_LOG(EXCEPTION) << op->name() << ": Keep_dims is not a bool.";
57   }
58   auto keepdims = keep_dims_ptr->cast<BoolImmPtr>()->value();
59   return keepdims;
60 }
61 
62 std::optional<bool> GetKeepDimsFromInputs(const std::shared_ptr<OperatorInfo> &op) {
63   auto keep_dims_opt = GetScalarValueFromInputs<bool>(op->input_value(), op->name(), KEEP_DIMS);
64   return keep_dims_opt;
65 }
66 
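// Resolves keep_dims for a reduce operator, first from the attrs map and then from the input values; throws if it cannot be found.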
67 bool GetKeepDims(const std::shared_ptr<OperatorInfo> &op) {
68   auto keep_dims_opt = GetKeepDimsFromAttrs(op);
69   if (!keep_dims_opt.has_value()) {
70     keep_dims_opt = GetKeepDimsFromInputs(op);
71   }
72   if (!keep_dims_opt.has_value()) {
73     MS_LOG(EXCEPTION) << op->name() << ": Don't have attr keep_dims.";
74   }
75   auto keepdims = keep_dims_opt.value();
76   return keepdims;
77 }
78 
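// Returns the axes reduced by a reduce operator: empty when keep_dims is true, every axis when the axis input is an empty tuple, otherwise the given axes normalized to non-negative values.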
79 Dimensions GetDimList(const std::shared_ptr<OperatorInfo> &op) {
80   Dimensions dim_list;
81   bool keep_dims = GetKeepDims(op);
82   if (keep_dims) {
83     return dim_list;
84   }
85 
86   const auto &name = op->name();
87   auto dim_list_opt = GetArrayValueFromInputs<int64_t>(op->input_value(), name, AXIS);
88   if (!dim_list_opt.has_value()) {
89     MS_LOG(EXCEPTION) << "For " << name << ", failed to get value for " << AXIS << ".";
90   }
91 
92   dim_list = dim_list_opt.value();
93   auto x_dim = op->inputs_shape()[0].size();
94   // axis is (), so reduce over all dimensions
95   if (dim_list.empty()) {
96     for (size_t i = 0; i < x_dim; ++i) {
97       dim_list.push_back(SizeToLong(i));
98     }
99   } else {
100     auto AxisCorrectFunc = [x_dim](const int64_t axis) {
101       if (axis < 0) {
102         return axis + SizeToLong(x_dim);
103       }
104       return axis;
105     };
106     std::transform(dim_list.begin(), dim_list.end(), dim_list.begin(), AxisCorrectFunc);
107   }
108   return dim_list;
109 }
110 }  // namespace
111 
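// Returns the position of op inside ops, matched by name; SIZE_MAX when the operator is not found.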
112 size_t OpNameToId(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::shared_ptr<OperatorInfo> &op) {
113   for (size_t i = 0; i < ops.size(); ++i) {
114     if (ops[i]->name() == op->name()) {
115       return i;
116     }
117   }
118 
119   return SIZE_MAX;
120 }
121 
122 bool IsDimensionsFlat(const Dimensions &dims) {
123   return !std::any_of(dims.begin(), dims.end(), [](const int64_t &dim) { return dim != 1; });
124 }
125 
126 bool IsDimensionsEmpty(const Dimensions &dims) { return dims.empty(); }
127 
128 bool IsStrategyFlat(const StrategyPtr &str) {
129   const auto &input_dims = str->GetInputDim();
130   return !std::any_of(input_dims.begin(), input_dims.end(),
131                       [](const Dimensions &dims) { return !IsDimensionsFlat(dims); });
132 }
133 
134 size_t DevicesForDimensions(const Dimensions &dims) {
135   return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
136 }
137 
138 bool HasStrategy(std::shared_ptr<OperatorInfo> op) {
139   StrategyPtr s_strategy = op->selected_strategy();
140   if (s_strategy != nullptr && !s_strategy->ToString().empty()) {
141     return true;
142   }
143   return false;
144 }
145 
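// Searches this operator's inputs for a producer operator, preferring one whose selected strategy is set and not fully flat; returns SIZE_MAX when no producer is found or when the producer is a VirtualDataset.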
146 size_t FindIndexOfOperatorIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
147                                    const std::vector<std::vector<std::string>> &input_tensor_names, size_t iter_ops) {
148   size_t incoming_op_index = SIZE_MAX;
149   for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) {
150     for (size_t j = 0; j < input_tensor_names.size(); j++) {
151       if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) {
152         incoming_op_index = j;
153         break;
154       }
155     }
156     if (incoming_op_index != SIZE_MAX && HasStrategy(ops.at(incoming_op_index)) &&
157         !IsStrategyFlat(ops.at(incoming_op_index)->selected_strategy())) {
158       break;
159     }
160   }
161   if (incoming_op_index != SIZE_MAX &&
162       ops.at(incoming_op_index)->name().find(VIRTUALDATASETINFO) != std::string::npos) {
163     return SIZE_MAX;
164   }
165   return incoming_op_index;
166 }
167 
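// Finds a consumer operator that takes this operator's output as one of its inputs and already has a non-empty selected strategy; returns the consumer's index together with the matching input slot, or (SIZE_MAX, SIZE_MAX).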
168 std::pair<size_t, size_t> FindIndexOfOperatorOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
169                                                       const std::vector<std::vector<std::string>> &input_tensor_names,
170                                                       const size_t iter_ops) {
171   bool found = false;
172   size_t outgoing_op_index = SIZE_MAX;
173   size_t iter_op_inputs = SIZE_MAX;
174 
175   for (size_t i = 0; i < input_tensor_names.size(); i++) {
176     for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
177       if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
178           ops[i]->selected_strategy()->GetInputNumber() != 0) {
179         outgoing_op_index = i;
180         iter_op_inputs = std::min(j - 1, ops[outgoing_op_index]->inputs_shape().size() - 1);
181         found = true;
182         break;
183       }
184     }
185     if (found) {
186       break;
187     }
188   }
189 
190   std::pair<size_t, size_t> res = std::make_pair(outgoing_op_index, iter_op_inputs);
191 
192   return res;
193 }
194 
195 int64_t GetGatherAxis(const std::shared_ptr<OperatorInfo> &op) {
196   auto axis_input = GetValue<int64_t>(op->input_value().at(2));
197   if (axis_input < 0) {
198     axis_input += SizeToLong(op->inputs_shape()[0].size());
199   }
200   if (axis_input >= SizeToLong(op->inputs_shape()[0].size())) {
201     MS_LOG(EXCEPTION) << "Failure: Gather's axis out of range.";
202   }
203   return axis_input;
204 }
205 
206 int64_t GetGatherBatchDims(const std::shared_ptr<OperatorInfo> &op) {
207   int64_t batch_dims = -1;
208   auto batch_dims_val = GetScalarValueFromInputs<int64_t>(op->input_value(), op->name(), BATCH_DIMS);
209   if (batch_dims_val.has_value()) {
210     batch_dims = batch_dims_val.value();
211   } else {
212     MS_LOG(EXCEPTION) << op->name() << ": Failed to fetch the value of batch dims";
213   }
214   return batch_dims;
215 }
216 
217 void ReverseRemainingList(const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
218   MS_LOG(INFO) << "ReverseRemainingList";
219   std::reverse(no_stra_op_list->begin(), no_stra_op_list->end());
220 }
221 
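// Entry point of strategy generation for the recursive-programming algorithm: builds a RecStrategyPropagator, optionally adds extra batch-dimension sharding for MatMul on clusters with more than 32 devices, then runs the V3 propagation for training and V1 for inference.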
222 void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
223                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
224                       const std::vector<std::vector<std::string>> &input_tensor_names,
225                       const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
226                       const std::vector<std::vector<size_t>> &param_users_ops_index, const FuncGraphPtr &root) {
227   RecStrategyPropagator propagator(graph, ops, eli_list, input_tensor_names, index_list, is_training,
228                                    param_users_ops_index, root);
229 
230   if (g_device_manager->DeviceNum() > SIZE_THIRTY_TWO) {
231     propagator.ExtraShardMatmulOnBatchDim();
232   }
233   if (is_training) {
234     propagator.GenerateStrategyV3();
235   } else {
236     propagator.GenerateStrategyV1();
237   }
238 }
239 
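// Maps FlashAttentionScore's input_layout (BSH/BSND, BNSD or SBH) to the positions of the batch, head-number (N) and sequence (S) dimensions used when building its strategy.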
240 void FillFlashLayoutIndexes(const std::shared_ptr<FlashAttentionScoreInfo> &flashOp, size_t *batch_split_idx,
241                             size_t *n_split_idx, size_t *s_split_idx) {
242   MS_EXCEPTION_IF_NULL(flashOp);
243   MS_EXCEPTION_IF_NULL(batch_split_idx);
244   MS_EXCEPTION_IF_NULL(n_split_idx);
245   MS_EXCEPTION_IF_NULL(s_split_idx);
246 
247   size_t tmp_batch_split_idx;
248   size_t tmp_n_split_idx;
249   size_t tmp_s_split_idx;
250 
251   using mindspore::ops::FASInputLayoutMode;
252   switch (flashOp->input_layout()) {
253     case FASInputLayoutMode::BSH:
254     case FASInputLayoutMode::BSND:
255       tmp_batch_split_idx = kIndex0;
256       tmp_s_split_idx = kIndex1;
257       tmp_n_split_idx = kIndex2;
258       break;
259     case FASInputLayoutMode::BNSD:
260       tmp_batch_split_idx = kIndex0;
261       tmp_n_split_idx = kIndex1;
262       tmp_s_split_idx = kIndex2;
263       break;
264     case FASInputLayoutMode::SBH:
265       tmp_s_split_idx = kIndex0;
266       tmp_batch_split_idx = kIndex1;
267       tmp_n_split_idx = kIndex2;
268       break;
269     default:
270       MS_LOG(EXCEPTION) << flashOp->name() << ": unknown input_layout: " << flashOp->input_layout();
271   }
272 
273   *batch_split_idx = tmp_batch_split_idx;
274   *n_split_idx = tmp_n_split_idx;
275   *s_split_idx = tmp_s_split_idx;
276 }
277 
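// Expands the basic strategy into per-input strategies for FlashAttentionScore: query splits batch, S1 and N1; key and value follow the batch split and split N2 only when kv_split is set; the optional inputs (real_shift, drop_mask, attn_mask, prefix, actual sequence lengths) receive matching or no-split strategies, and unused entries are removed at the end.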
278 Strategies PrepareFlashAttentionScore(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
279                                       bool dyn_shape_tmp_fix) {
280   std::shared_ptr<FlashAttentionScoreInfo> flashOp = std::static_pointer_cast<FlashAttentionScoreInfo>(op);
281 
282   if (flashOp->InitAttrs() != SUCCESS) {
283     MS_LOG(EXCEPTION) << flashOp->name() << " : InitAttrs failed.";
284   }
285 
286   Strategies expect_strategies = Strategies(ops::kFlashAttentionScoreInputsNum);
287   auto is_input_passed = flashOp->is_input_passed();
288 
289   size_t batch_idx;
290   size_t n_split_idx;
291   size_t s_split_idx;
292 
293   FillFlashLayoutIndexes(flashOp, &batch_idx, &n_split_idx, &s_split_idx);
294 
295   int64_t batch_split_num = basic_stra[batch_idx];
296   int64_t s1_split_num = basic_stra[s_split_idx];
297   int64_t n1_split_num = basic_stra[n_split_idx];
298   int64_t n2_split_num = flashOp->kv_split() ? n1_split_num : 1;
299 
300   Dimensions q_stra(op->inputs_shape()[ops::kFlashAttentionScoreInputQueryIndex].size(), 1);
301   q_stra[batch_idx] = batch_split_num;
302   q_stra[s_split_idx] = s1_split_num;
303   q_stra[n_split_idx] = n1_split_num;
304 
305   Dimensions kv_stra(op->inputs_shape()[ops::kFlashAttentionScoreInputKeyIndex].size(), 1);
306   kv_stra[batch_idx] = batch_split_num;
307   kv_stra[n_split_idx] = n2_split_num;
308 
309   expect_strategies[ops::kFlashAttentionScoreInputQueryIndex] = q_stra;
310   expect_strategies[ops::kFlashAttentionScoreInputKeyIndex] = kv_stra;
311   expect_strategies[ops::kFlashAttentionScoreInputValueIndex] = kv_stra;
312 
313   if (is_input_passed[ops::kFlashAttentionScoreInputRealShiftIndex]) {
314     int64_t real_shift_s1_split_num = flashOp->real_shift_have_s1_dim() ? s1_split_num : 1;
315     int64_t real_shift_batch_split_num = flashOp->real_shift_have_batch_dim() ? batch_split_num : 1;
316     expect_strategies[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_split_num, n1_split_num,
317                                                                        real_shift_s1_split_num, 1};
318   }
319 
320   if (is_input_passed[ops::kFlashAttentionScoreInputDropMaskIndex]) {
321     expect_strategies[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_split_num, n1_split_num, s1_split_num, 1};
322   }
323 
324   if (is_input_passed[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
325     expect_strategies[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
326   }
327 
328   if (is_input_passed[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
329     auto attn_mask_shape =
330       flashOp->inputs_shape().at(flashOp->GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
331     int64_t s1_split_num_attn_mask = flashOp->is_attn_mask_compressed() ? 1 : s1_split_num;
332     if (attn_mask_shape.size() == kSizeTwo) {
333       // attn_mask_shape: (S1, S2)
334       expect_strategies[ops::kFlashAttentionScoreInputAttnMaskIndex] = {s1_split_num_attn_mask, 1};
335     } else if (attn_mask_shape.size() == kSizeFour) {
336       // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
337       auto attn_mask_n1_split_num = flashOp->attn_mask_have_n1_dim() ? n1_split_num : 1;
338       auto attn_batch_split_num = flashOp->attn_mask_have_batch_dim() ? batch_split_num : 1;
339       expect_strategies[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_batch_split_num, attn_mask_n1_split_num,
340                                                                         s1_split_num_attn_mask, 1};
341     }
342   }
343 
344   if (is_input_passed[ops::kFlashAttentionScoreInputPrefixIndex]) {
345     expect_strategies[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_split_num};
346   }
347 
348   if (is_input_passed[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
349     expect_strategies[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {NO_SPLIT_STRATEGY};
350   }
351   if (is_input_passed[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
352     expect_strategies[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {NO_SPLIT_STRATEGY};
353   }
354 
355   expect_strategies.erase(std::remove(expect_strategies.begin(), expect_strategies.end(), Shape{}),
356                           expect_strategies.end());
357   return expect_strategies;
358 }
359 
360 Strategies PrepareFillV2(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
361   Strategies strategies;
362 
363   if (op->outputs_shape().size() == 0) {
364     MS_LOG(EXCEPTION) << op->name() << " output tensor info is empty.";
365   }
366 
367   for (size_t i = basic_stra.size(); i < op->outputs_shape()[0].size(); i++) {
368     basic_stra.push_back(1);
369   }
370 
371   strategies.push_back(basic_stra);
372   basic_stra.clear();
373   strategies.push_back(basic_stra);
374   return strategies;
375 }
376 
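// Converts the node's tensor slicing factors (str_h, str_w) into a MatMul input strategy, swapping the two axes for the input whose transpose flag is set.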
377 Dimensions PrepareMatMulStrategy(Graph::NodeType *node, bool transpose_a, bool transpose_b, size_t iter_op_inputs) {
378   Dimensions strategy;
379   if (transpose_a && (iter_op_inputs == 0)) {
380     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
381     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
382   } else if (transpose_b && (iter_op_inputs == 1)) {
383     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
384     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
385   } else {
386     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
387     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
388   }
389   return strategy;
390 }
391 
392 Strategies PrepareMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op) {
393   Strategies strategies;
394   auto input_value = op->input_value();
395   bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
396   bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
397 
398   for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
399     Dimensions strategy = PrepareMatMulStrategy(node, transpose_a, transpose_b, iter_op_inputs);
400     strategies.push_back(strategy);
401   }
402   return strategies;
403 }
404 
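// Propagates a basic strategy backwards onto BatchMatMul inputs: leading batch dimensions copy the basic strategy, the contracted k axis stays unsplit, and the i/j axes take the last two entries of the basic strategy (transposes handled accordingly).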
405 Strategies PreparePropagateBatchMatMul(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
406                                        bool dyn_shape_tmp_fix) {
407   if (dyn_shape_tmp_fix) {
408     return CheckDivisible(op, basic_stra);
409   }
410   // This backward propagation does NOT complete the strategy on the k axis; it could be done later.
411   Strategies stra;
412   auto input_value = op->input_value();
413   bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
414   bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
415 
416   size_t first_input_size = op->inputs_shape()[0].size();
417   size_t second_input_size = op->inputs_shape()[1].size();
418 
419   Dimensions first_input_dim(first_input_size);
420   Dimensions second_input_dim(second_input_size);
421 
422   // first input
423   if (!transpose_a) {
424     first_input_dim[first_input_size - 1] = 1;                                  // k axis
425     first_input_dim[first_input_size - 2] = basic_stra[basic_stra.size() - 2];  // i axis
426   } else {
427     first_input_dim[first_input_size - 2] = 1;                                  // k axis
428     first_input_dim[first_input_size - 1] = basic_stra[basic_stra.size() - 2];  // i axis
429   }
430 
431   for (size_t idx = 3; idx <= first_input_size; idx++) {
432     first_input_dim[first_input_size - idx] = basic_stra[basic_stra.size() - idx];
433   }
434 
435   // second input
436   if (!transpose_b) {
437     second_input_dim[second_input_size - 2] = 1;                                  // k axis
438     second_input_dim[second_input_size - 1] = basic_stra[basic_stra.size() - 1];  // j axis
439   } else {
440     second_input_dim[second_input_size - 1] = 1;                                  // k axis
441     second_input_dim[second_input_size - 2] = basic_stra[basic_stra.size() - 1];  // j axis
442   }
443 
444   for (size_t idx = 3; idx <= second_input_size; idx++) {
445     second_input_dim[second_input_size - idx] = basic_stra[basic_stra.size() - idx];
446   }
447 
448   stra.push_back(first_input_dim);
449   stra.push_back(second_input_dim);
450   return stra;
451 }
452 
453 Dimensions PrepareBatchMatMulStrategy(Graph::NodeType *node, const bool transpose_a, const bool transpose_b,
454                                       const size_t iter_op_inputs, const size_t dim_num) {
455   if (node->apply.arguments[iter_op_inputs].tensor_str.str_n == 0 ||
456       node->apply.arguments[iter_op_inputs].tensor_str.str_c == 0 ||
457       node->apply.arguments[iter_op_inputs].tensor_str.str_h == 0 ||
458       node->apply.arguments[iter_op_inputs].tensor_str.str_w == 0) {
459     MS_LOG(EXCEPTION) << "The strategy is 0";
460   }
461 
462   Dimensions strategy;
463   if (dim_num >= SIZE_FOUR) {
464     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_n));
465   }
466   if (dim_num >= SIZE_THREE) {
467     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
468   }
469   if (transpose_a && (iter_op_inputs == 0)) {
470     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
471     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
472   } else if (transpose_b && (iter_op_inputs == 1)) {
473     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
474     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
475   } else {
476     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
477     strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
478   }
479   return strategy;
480 }
481 
482 Strategies PrepareBatchMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op) {
483   Strategies strategies;
484   auto input_value = op->input_value();
485   bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
486   bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
487 
488   for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
489     Dimensions strategy = PrepareBatchMatMulStrategy(node, transpose_a, transpose_b, iter_op_inputs,
490                                                      op->inputs_shape()[iter_op_inputs].size());
491     strategies.push_back(strategy);
492   }
493   return strategies;
494 }
495 
496 Strategies PrepareBiasAdd(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
497   auto strategy = std::make_shared<Dimensions>(basic_stra);
498   Strategies strategies;
499   strategies.push_back(*strategy);
500   Dimensions s_biasadd;
501   s_biasadd.push_back(strategy->at(1));
502   strategies.push_back(s_biasadd);
503   return strategies;
504 }
505 
506 Strategies PrepareStandAlone(const std::shared_ptr<OperatorInfo> &op) {
507   Strategies strategies;
508   Dimensions strategy;
509 
510   for (size_t i = 0; i < op->outputs_tensor_info().size(); i++) {
511     strategy.clear();
512     for (size_t j = 0; j < op->inputs_tensor_info()[i].shape().size(); j++) {
513       strategy.push_back(1);
514     }
515     strategies.push_back(strategy);
516   }
517 
518   return strategies;
519 }
520 
521 Strategies PrepareDataParallel(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
522   size_t numDev = g_device_manager->stage_device_num();
523 
524   Strategies strategies;
525   Dimensions strategy;
526 
527   if (numDev == 0) {
528     MS_LOG(EXCEPTION) << "The number of devices is 0";
529   }
530 
531   for (size_t i = 0; i < op->inputs_shape().size(); i++) {
532     strategy.clear();
533     if (LongToSize(op->inputs_shape()[i][0]) % numDev == 0) {
534       strategy.push_back(numDev);
535     } else {
536       strategy.push_back(1);
537     }
538     for (size_t j = 1; j < op->inputs_shape()[i].size(); j++) {
539       strategy.push_back(1);
540     }
541     strategies.push_back(strategy);
542   }
543 
544   return strategies;
545 }
546 
547 Dimensions PrepareOneHotOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
548   auto op_strategy = op->selected_strategy();
549   Dimensions strategy;
550 
551   for (size_t i = 0; i < static_cast<size_t>(op->inputs_shape().size()); i++) {
552     if (op->inputs_shape()[i].size() == 0) {
553       continue;
554     }
555     // copy the full strategy (Assume strategy has the same size as the following operator input shape)
556     for (size_t j = 0; j < op_strategy->GetInputDim().at(i).size(); ++j) {
557       strategy.push_back(op_strategy->GetInputDim().at(i).at(j));
558     }
559     break;
560   }
561   return strategy;
562 }
563 
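// Builds a StridedSlice strategy from the basic strategy: dimensions introduced by new_axis_mask are fixed to 1, and a dimension's split is dropped when its stride is not 1 or when it is not fully fetched (unless redistribution is skipped).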
564 Strategies PrepareStridedSlice(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
565   Strategies strategies;
566 
567   if (dyn_shape_tmp_fix) {
568     return strategies;
569   }
570 
571   auto strided_slice = std::static_pointer_cast<StridedSliceInfo>(op);
572   strided_slice->GetAttrs();
573   auto begin = strided_slice->begin();
574   auto strides = strided_slice->strides();
575   auto new_axis_mask_bitmap = strided_slice->new_axis_mask_bitmap();
576   auto fully_fetch_flag = strided_slice->fully_fetch_flag();
577   auto skip_redistribution = strided_slice->skip_redistribution();
578 
579   Shape strategy_in_process = Shape(basic_stra.size(), 0);
580   for (size_t i = 0; i < new_axis_mask_bitmap.size() && i < begin.size() && i < basic_stra.size(); ++i) {
581     if (new_axis_mask_bitmap[i]) {
582       strategy_in_process[i] = 1;
583     }
584   }
585 
586   size_t count = 0;
587   for (auto &ele : strategy_in_process) {
588     if (ele != 0) {
589       continue;
590     }
591     ele = basic_stra[count];
592     count++;
593   }
594 
595   (void)strategy_in_process.insert(strategy_in_process.end(), basic_stra.begin() + count, basic_stra.end());
596   MS_LOG(INFO) << op->name() << ": The strategy in process is " << strategy_in_process;
597 
598   for (size_t j = 0; j < strides.size(); ++j) {
599     if ((strides[j] != 1) && (strategy_in_process[j] > 1)) {
600       strategy_in_process[j] = 1;
601     }
602   }
603 
604   for (size_t k = 0; k < begin.size(); ++k) {
605     if (!fully_fetch_flag[k] && (strategy_in_process[k] != 1) && !skip_redistribution) {
606       strategy_in_process[k] = 1;
607     }
608   }
609 
610   strategies.push_back(strategy_in_process);
611   return strategies;
612 }
613 
614 std::vector<int64_t> FindAxisProperty(const std::shared_ptr<OperatorInfo> &op) {
615   std::vector<int64_t> axis_list;
616   string axis_name = AXIS;
617   auto input_value = op->input_value();
618 
619   auto op_name = op->name();
620 
621   if (input_value[input_value.size() - 1]->isa<ValueSequence>()) {  // Softmax axis is a tuple
622     std::optional<std::vector<int64_t>> axis_opt = GetArrayValueFromInputs<int64_t>(input_value, op_name, axis_name);
623     if (axis_opt.has_value()) {
624       std::vector<int64_t> axis_val = axis_opt.value();
625       axis_list.swap(axis_val);
626     } else {
627       axis_list.push_back(-1);
628     }
629   } else {  // LogSoftmax axis is a scalar
630     std::optional<int64_t> axis_opt = GetScalarValueFromInputs<int64_t>(input_value, op_name, axis_name);
631     if (axis_opt.has_value()) {
632       int64_t axis_val = axis_opt.value();
633       axis_list.push_back(axis_val);
634     } else {
635       axis_list.push_back(-1);
636     }
637   }
638   return axis_list;
639 }
640 
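// Applies the basic strategy to Softmax-like operators, forcing the softmax axis (or axes) to 1 and resetting any split that does not divide the corresponding input dimension.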
641 Strategies PrepareSoftMax(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
642   Strategies strategies;
643   strategies.push_back(basic_stra);
644   std::vector<int64_t> axis_list = FindAxisProperty(op);
645 
646   for (auto &axis : axis_list) {
647     if (axis < 0) {
648       int64_t input_dim = SizeToLong(op->inputs_shape()[0].size());
649       axis = input_dim + axis;
650     }
651     if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
652       MS_LOG(EXCEPTION) << op->name() << ": axis value is out of range.";
653     }
654     if (strategies[0][LongToSize(axis)] != 1) {
655       strategies[0][LongToSize(axis)] = 1;
656       MS_LOG(INFO) << op->name() << ": adjust strategy to 1 on axis " << axis;
657     }
658   }
659 
660   // Strategy protection: the partition number must not exceed, and must divide, the size of the corresponding dimension.
661   for (size_t i = 0; i < op->inputs_shape().size(); i++) {
662     for (size_t j = 0; j < op->inputs_shape()[i].size(); j++) {
663       if (strategies[i][j] > op->inputs_shape()[i][j] || op->inputs_shape()[i][j] % strategies[i][j] != 0) {
664         strategies[i][j] = 1;
665       }
666     }
667   }
668 
669   return strategies;
670 }
671 
672 Strategies PrepareLayerNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
673   Strategies strategies;
674   strategies.push_back(basic_stra);
675   std::vector<int64_t> axis_list;
676   string axis_name = AXIS;
677 
678   auto iter = op->attrs().find(axis_name);
679   if (iter != op->attrs().end()) {
680     MS_EXCEPTION_IF_NULL(iter->second);
681     if (iter->second->isa<Int64Imm>()) {
682       axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
683     } else if (iter->second->isa<ValueTuple>()) {
684       ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
685       if (value_tuple == nullptr) {
686         MS_LOG(EXCEPTION) << op->name() << ": The value_tuple is nullptr.";
687       }
688 
689       std::vector<ValuePtr> value_vector = value_tuple->value();
690       (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
691                            [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
692     } else {
693       MS_LOG(EXCEPTION) << op->name() << ": The value of axis is not int64_t or tuple int64_t.";
694     }
695   } else {
696     axis_list.push_back(-1);
697   }
698 
699   for (auto &axis : axis_list) {
700     if (axis < 0) {
701       int64_t input_dim = SizeToLong(op->inputs_shape()[0].size());
702       axis = input_dim + axis;
703     }
704     if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
705       MS_LOG(EXCEPTION) << op->name() << ": axis value is out of range.";
706     }
707     if (strategies[0][LongToSize(axis)] != 1) {
708       strategies[0][LongToSize(axis)] = 1;
709       MS_LOG(INFO) << op->name() << ": adjust strategy to 1 on axis " << axis;
710     }
711   }
712   Dimensions d = {1};
713   strategies.push_back(d);
714   strategies.push_back(d);
715   return strategies;
716 }
717 
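// RmsNorm strategy: split only the leading dimension of the first input, bounded by the number of devices per pipeline stage, and mirror the trailing splits onto the gamma input.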
718 Strategies PrepareRmsNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
719   Strategies strategies;
720   auto inputs = op->inputs_shape();
721   auto input = inputs[0];
722   Shape strategy_in_process = Shape(input.size(), 1);
723   int64_t devices = SizeToLong(g_device_manager->DeviceNum());
724 
725   if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() == 0) {
726     MS_LOG(EXCEPTION) << "divisors cannot be 0!";
727   }
728   int64_t max_cut = devices / parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
729   strategy_in_process[0] = input[0] < max_cut ? input[0] : max_cut;
730 
731   auto gamma = inputs[1];
732   size_t gamma_diff = input.size() - gamma.size();
733   Dimensions gamma_strategy;
734   for (size_t j = 0; j < gamma.size(); ++j) {
735     gamma_strategy.push_back(strategy_in_process[gamma_diff + j]);
736   }
737 
738   strategies.push_back(strategy_in_process);
739   strategies.push_back(gamma_strategy);
740   return strategies;
741 }
742 
743 Strategies PrepareOneHot(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
744   Strategies strategies;
745 
746   // OneHot's strategy depends on its output shape.
747   for (size_t i = strategy.size(); i < op->outputs_shape()[0].size(); i++) {
748     strategy.push_back(1);
749   }
750 
751   // The split of each dimension must not exceed the size of that output dimension.
752   for (size_t i = 0; i < op->outputs_shape()[0].size(); i++) {
753     if (strategy[i] > op->outputs_shape()[0][i]) {
754       strategy[i] = 1;
755     }
756   }
757 
758   strategies.push_back(strategy);
759 
760   // Push two empty Dimensions for the other two input tensors.
761   Dimensions s_empty = {};
762   strategies.push_back(s_empty);
763   strategies.push_back(s_empty);
764 
765   return strategies;
766 }
767 
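// Derives a Gather strategy from the target output shape: dimension 0 is split first, then the remaining dimensions from largest to smallest, halving each divisible dimension until the stage device number is reached.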
768 Dimensions GenGatherStra(Shape targeted_shape) {
769   Dimensions index(targeted_shape.size() - 1, 0);
770   for (size_t i = 0; i < index.size(); i++) {
771     index[i] = SizeToLong(i);
772   }
773 
774   std::sort(index.begin(), index.end(), [&targeted_shape](const size_t &a, const size_t &b) {
775     return (targeted_shape[a + 1] > targeted_shape[b + 1]);
776   });
777   (void)std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
778   (void)index.insert(index.cbegin(), 0);
779 
780   Dimensions strategie(targeted_shape.size(), 1);
781 
782   size_t num_device = LongToSize(g_device_manager->stage_device_num());
783   size_t cut = 1;
784   for (size_t i = 0; i < index.size(); i++) {
785     size_t index_i = LongToSize(index[i]);
786     while (targeted_shape[index_i] % SIZE_TWO == 0 && targeted_shape[index_i] > 0 && cut < num_device) {
787       targeted_shape[index_i] /= SIZE_TWO;
788       cut *= SIZE_TWO;
789       strategie[index_i] *= SIZE_TWO;  // We apply 2-parts partitioning for Gather.
790     }
791     if (cut == num_device) {
792       break;
793     }
794   }
795 
796   return strategie;
797 }
798 
799 Strategies GatherForDynamicShape(const std::shared_ptr<OperatorInfo> &op, const size_t dim) {
800   Strategies strategies;
801   auto gather_input_0_shape = op->inputs_shape()[0];
802   if (dim >= gather_input_0_shape.size()) {
803     MS_LOG(EXCEPTION) << "Failure: Gather's axis out of range.";
804   }
805   Dimensions gather_input_0_strategy(gather_input_0_shape.size(), 1);
806   int64_t num_device = g_device_manager->stage_device_num();
807   if (gather_input_0_shape[dim] % num_device == 0) {
808     size_t cut = 1;
809     while (gather_input_0_shape[dim] > 0 && gather_input_0_shape[dim] % SIZE_TWO == 0 && cut < LongToSize(num_device)) {
810       gather_input_0_shape[dim] /= SIZE_TWO;
811       cut *= SIZE_TWO;
812       gather_input_0_strategy[dim] *= SIZE_TWO;
813     }
814   }
815   strategies.push_back(gather_input_0_strategy);
816   for (size_t i = 1; i < op->inputs_shape().size(); i++) {
817     Dimensions gather_input_i_strategy(op->inputs_shape()[i].size(), 1);
818     strategies.push_back(gather_input_i_strategy);
819   }
820   return strategies;
821 }
822 
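// Gather strategy: with batch_dims > 1 every input reuses the output-derived strategy with the gather axis left unsplit; otherwise axis 0 splits the leading dimension of indices and then param, axis 1 reuses the output-derived splits, and the param's axis split is dropped when the indices are split under a shard-axis gather mode.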
823 Strategies PrepareGather(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
824   if (dyn_shape_tmp_fix) {
825     Strategies strategies;
826     strategies.push_back(strategy);
827     for (size_t i = 1; i < op->inputs_shape().size(); i++) {
828       Dimensions gather_input_i_strategy(op->inputs_shape()[i].size(), 1);
829       strategies.push_back(gather_input_i_strategy);
830     }
831     return strategies;
832   }
833 
834   Strategies strategies;
835   Shape targeted_shape = op->outputs_shape()[0];
836   Dimensions strategie = GenGatherStra(targeted_shape);
837 
838   int64_t axis = GetGatherAxis(op);
839   MS_LOG(INFO) << op->name() << ": the axis is " << axis;
840 
841   int64_t batch_dims = GetGatherBatchDims(op);
842   MS_LOG(INFO) << op->name() << ": the batch_dims is " << batch_dims;
843 
844   if (batch_dims > 1) {
845     for (size_t i = 0; i < op->inputs_shape().size(); i++) {
846       strategies.push_back(strategie);
847     }
848     strategies[0][axis] = 1;
849     return strategies;
850   }
851 
852   strategy.clear();
853   if (axis == 0) {
854     Shape param_strategy = Shape(op->inputs_shape()[0].size(), 1);
855     Shape indices_strategy = Shape(op->inputs_shape()[1].size(), 1);
856     strategies.push_back(param_strategy);
857     strategies.push_back(indices_strategy);
858     size_t num_device = LongToSize(g_device_manager->stage_device_num());
859     size_t cut = 1;
860     int gather_inputs_num = SizeToInt(op->inputs_shape().size());
861     for (int i = gather_inputs_num - 1; i >= 0; --i) {
862       auto tensor_shape = op->inputs_shape()[i];
863       while (tensor_shape[0] % SIZE_TWO == 0 && tensor_shape[0] > 0 && cut < num_device) {
864         tensor_shape[0] /= SIZE_TWO;
865         cut *= SIZE_TWO;
866         strategies[i][0] *= SIZE_TWO;  // We apply 2-parts partitioning for Gather.
867       }
868       if (cut == num_device) {
869         break;
870       }
871     }
872   } else if (axis == 1) {
873     strategy.push_back(strategie[0]);
874     strategy.push_back(1);
875     strategies.push_back(strategy);
876     strategy.clear();
877     for (size_t i = 0; i < op->inputs_shape()[1].size(); i++) {
878       strategy.push_back(strategie[op->inputs_shape()[0].size() - 1 + i]);
879     }
880     strategies.push_back(strategy);
881   } else {
882     MS_LOG(EXCEPTION) << "Failure: Normal Gather's axis is neither 0 nor 1.";
883   }
884 
885   auto gather = std::static_pointer_cast<GatherInfo>(op);
886   auto gather_mode = gather->GetGatherMode(strategies[0], strategies[1]);
887   MS_LOG(INFO) << op->name() << ": the gather_mode is " << gather_mode;
888   if (gather_mode == SHARD_AXIS_0_DYNAMIC || gather_mode == SHARD_AXIS_0_STATIC || gather_mode == SHARD_AXIS_1) {
889     if (DevicesForDimensions(strategies[1]) != 1 && strategies[0][axis] != 1) {
890       strategies[0][axis] = 1;
891       MS_LOG(INFO) << op->name() << ": param_strategy[" << axis << "] is changed to 1.";
892     }
893   }
894 
895   return strategies;
896 }
897 
898 Dimensions PrepareGatherV2OutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
899   auto targeted_shape = op->outputs_shape()[0];
900   Dimensions strategie = GenGatherStra(targeted_shape);
901   return strategie;
902 }
903 
904 Strategies PrepareL2Normalize(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
905   int64_t axis = 0;
906   auto iter = op->attrs().find(AXIS);
907   if (iter != op->attrs().end()) {
908     MS_EXCEPTION_IF_NULL(iter->second);
909     if (iter->second->isa<ValueSequence>()) {
910       axis = GetValue<std::vector<int64_t>>(iter->second)[0];
911     } else {
912       MS_LOG(EXCEPTION) << op->name() << " : The value of axis is not int64_t.";
913     }
914   }
915 
916   int64_t axis_index = axis;
917   if (axis < 0) {
918     size_t input_dim = op->inputs_shape()[0].size();
919     axis_index = static_cast<int64_t>(input_dim) + axis;
920   }
921 
922   strategy[LongToSize(axis_index)] = 1;
923 
924   Strategies strategies;
925   strategies.push_back(strategy);
926   return strategies;
927 }
928 
929 Strategies PrepareAxisRelatedStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
930                                       const size_t iter_ops) {
931   Strategies strategies = MakeRecSearchStrategy(node, ops, iter_ops);
932   if (strategies.size() < 1) {
933     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
934   }
935 
936   std::vector<int64_t> axis_list;
937   string axis_name = AXIS;
938   int64_t default_axis = -1;
939   if (ops[iter_ops]->type() == LAYER_NORM) {
940     axis_name = "begin_norm_axis";
941     default_axis = 1;
942   }
943 
944   auto iter = ops[iter_ops]->attrs().find(axis_name);
945   if (iter != ops[iter_ops]->attrs().end()) {
946     MS_EXCEPTION_IF_NULL(iter->second);
947     if (iter->second->isa<Int64Imm>()) {
948       axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
949     } else if (iter->second->isa<ValueTuple>()) {
950       ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
951       if (value_tuple == nullptr) {
952         MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value_tuple is nullptr.";
953       }
954       std::vector<ValuePtr> value_vector = value_tuple->value();
955       (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
956                            [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
957     } else {
958       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
959     }
960   } else {
961     axis_list.push_back(default_axis);
962   }
963 
964   for (auto &axis : axis_list) {
965     if (axis < 0) {
966       int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_shape()[0].size());
967       axis = input_dim + axis;
968     }
969     if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
970       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
971     }
972     if (strategies[0][LongToSize(axis)] != 1) {
973       strategies[0][LongToSize(axis)] = 1;
974       MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
975     }
976   }
977   return strategies;
978 }
979 
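// Turns the slicing factors stored on the rec-graph node (str_n/str_c/str_h/str_w) into one strategy per operator input, handling 1D to 4D tensors and falling back to data parallelism for UnsortedSegment ops.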
980 Strategies MakeRecSearchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
981                                  const size_t iter_ops) {
982   if (ops.empty()) {
983     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
984   }
985   if (iter_ops >= ops.size()) {
986     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
987   }
988   if (node->apply.op_type == kRecUnsortedSegmentOp) {
989     return MakeDataParallelStrategy(node, ops, iter_ops);
990   }
991 
992   Strategies strategies;
993   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
994     if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
995       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
996     }
997 
998     size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
999     Dimensions strategy;
1000     if (input_size == SIZE_FOUR) {
1001       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_n));
1002       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
1003       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
1004       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
1005     } else if (input_size == SIZE_THREE) {
1006       // Experimental support for 3D data.
1007       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
1008       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
1009       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
1010     } else if (input_size == SIZE_TWO) {
1011       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
1012       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
1013     } else if (input_size == SIZE_ONE) {
1014       strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
1015     } else if (input_size == SIZE_ZERO) {
1016       strategy = {};
1017     } else {
1018       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's input size is unexpected.";
1019     }
1020     strategies.push_back(strategy);
1021   }
1022   return strategies;
1023 }
1024 
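// Data-parallel fallback: split only the first dimension of every input by min(stage device number, batch size), leave the rest at 1, and record the resulting slicing factor on the node's output tensor.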
1025 Strategies MakeDataParallelStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1026                                     const size_t iter_ops) {
1027   if (ops.empty()) {
1028     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
1029   }
1030   if (iter_ops >= ops.size()) {
1031     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
1032   }
1033 
1034   Strategies strategies;
1035   size_t max_device_num = LongToSize(g_device_manager->stage_device_num());
1036   size_t target_tensor_batch = LongToUlong(ops[iter_ops]->inputs_shape()[0][0]);
1037   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
1038     if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
1039       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
1040     }
1041 
1042     Dimensions strategy;
1043     size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
1044     for (size_t dim = 0; dim < input_size; dim++) {
1045       // Experimental support for 3D data (input_size == 3).
1046       if (input_size >= SIZE_ONE && input_size <= STR_DIM_NUM) {
1047         if (dim == 0) {
1048           strategy.push_back(std::min(max_device_num, target_tensor_batch));
1049         } else {
1050           strategy.push_back(1);
1051         }
1052       } else if (input_size == 0) {
1053         strategy = {};
1054       } else {
1055         MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
1056       }
1057     }
1058     strategies.push_back(strategy);
1059   }
1060   // Set default strategy.
1061   node->tensor_parm.tensor_str.str_n = 1.0;
1062   node->tensor_parm.tensor_str.str_c = 1.0;
1063   node->tensor_parm.tensor_str.str_h = 1.0;
1064   node->tensor_parm.tensor_str.str_w = 1.0;
1065 
1066   // Update data parallel strategy.
1067   if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) {
1068     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
1069   }
1070   if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) {
1071     node->tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
1072   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) {
1073     node->tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
1074   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) {
1075     // Experimental support for 3D data.
1076     node->tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch);
1077   } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) {  // Experimental support for 4D data.
1078     node->tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
1079   } else {
1080     MS_LOG(INFO) << ops[iter_ops]->name() << " output tensor shape is unexpected, using default value instead.";
1081   }
1082 
1083   return strategies;
1084 }
1085 
1086 Strategies MakeFullBatchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1087                                  const size_t iter_ops) {
1088   if (ops.empty()) {
1089     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
1090   }
1091   if (iter_ops >= ops.size()) {
1092     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
1093   }
1094 
1095   Strategies strategies;
1096   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
1097     if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
1098       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
1099     }
1100     Dimensions strategy;
1101     size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
1102     for (size_t dim = 0; dim < input_size; dim++) {
1103       if (input_size >= SIZE_ONE && input_size <= SIZE_FOUR) {
1104         strategy.push_back(1);
1105       } else if (input_size == 0) {
1106         strategy = {};
1107       } else {
1108         MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
1109       }
1110     }
1111     strategies.push_back(strategy);
1112   }
1113   // Update the output strategy of Rec Graph
1114   node->tensor_parm.tensor_str.str_n = 1.0;
1115   node->tensor_parm.tensor_str.str_c = 1.0;
1116   node->tensor_parm.tensor_str.str_h = 1.0;
1117   node->tensor_parm.tensor_str.str_w = 1.0;
1118 
1119   return strategies;
1120 }
1121 
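// Resets the operator to an all-ones (no split) strategy while keeping its previously selected cost.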
1122 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
1123   Strategies strategies;
1124 
1125   for (size_t iter_strategy = 0; iter_strategy < op->inputs_shape().size(); iter_strategy++) {
1126     Dimensions strategy;
1127     size_t strategy_size = op->inputs_shape()[iter_strategy].size();
1128     for (size_t dim = 0; dim < strategy_size; dim++) {
1129       if (strategy_size >= SIZE_ONE && strategy_size <= SIZE_FOUR) {
1130         strategy.push_back(1);
1131       } else if (strategy_size == 0) {
1132         strategy = {};
1133       } else {
1134         MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
1135       }
1136     }
1137     strategies.push_back(strategy);
1138   }
1139 
1140   StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
1141 
1142   op->SetSelectedStrategyAndCost(sp, op->selected_cost());
1143 }
1144 
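// Dispatches strategy generation for a main operator by type: MatMul, BatchMatMul (only under the dynamic-shape fix), LayerNorm, SparseSoftmaxCrossEntropyWithLogits and VirtualDataset get dedicated handlers; every other operator uses MakeRecSearchStrategy.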
1145 Strategies PrepareStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1146                            const size_t iter_ops, const bool dyn_shape_tmp_fix) {
1147   if (ops.empty()) {
1148     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
1149   }
1150   if (iter_ops >= ops.size()) {
1151     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
1152   }
1153   MS_EXCEPTION_IF_NULL(ops[iter_ops]);
1154 
1155   auto type = ops[iter_ops]->type();
1156   MS_LOG(INFO) << "Processing main operator " << ops[iter_ops]->name() << " (type=" << type << ")";
1157   if (type == MATMUL) {
1158     return PrepareMatMul(node, ops[iter_ops]);
1159   } else if (dyn_shape_tmp_fix && type == BATCH_MATMUL) {
1160     return PrepareBatchMatMul(node, ops[iter_ops]);
1161   } else if (type == LAYER_NORM) {
1162     return PrepareAxisRelatedStrategy(node, ops, iter_ops);
1163   } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
1164     return MakeDataParallelStrategy(node, ops, iter_ops);
1165   } else if (type == VIRTUAL_DATA_SET) {
1166     if (ParallelContext::GetInstance()->full_batch()) {
1167       return MakeFullBatchStrategy(node, ops, iter_ops);
1168     } else {
1169       return MakeDataParallelStrategy(node, ops, iter_ops);
1170     }
1171   } else {
1172     return MakeRecSearchStrategy(node, ops, iter_ops);
1173   }
1174 }
1175 
1176 float CheckVirtualDatasetStrategy(Graph::NodeType *node) {
1177   // The values for str can only be 1.0, 0.5, 0.25, 0.125…
1178   // We want to find out the first str that is smaller than 1
1179   if (node->tensor_parm.tensor_str.str_n < 0.9) {
1180     return node->tensor_parm.tensor_str.str_n;
1181   }
1182   if (node->tensor_parm.tensor_str.str_c < 0.9) {
1183     return node->tensor_parm.tensor_str.str_c;
1184   }
1185   if (node->tensor_parm.tensor_str.str_h < 0.9) {
1186     return node->tensor_parm.tensor_str.str_h;
1187   }
1188   if (node->tensor_parm.tensor_str.str_w < 0.9) {
1189     return node->tensor_parm.tensor_str.str_w;
1190   }
1191   return 1.0;
1192 }
1193 
1194 Dimensions CopyVirtualDataset(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op,
1195                               float epsilon = 0.00005f) {
1196   Dimensions strategy;
1197   auto input_stra_dim = op->inputs_shape()[0].size();
1198   auto virtual_dataset_str = CheckVirtualDatasetStrategy(node);
1199   MS_EXCEPTION_IF_ZERO("Virtual_Dataset", virtual_dataset_str);
1200   if (input_stra_dim == 0) {
1201     return strategy;
1202   } else {
1203     if (std::fabs(virtual_dataset_str) < epsilon) {
1204       strategy.push_back(1);
1205     } else {
1206       strategy.push_back(FloatToLong(1 / virtual_dataset_str));
1207     }
1208     for (size_t i = 1; i < input_stra_dim; i++) {
1209       strategy.push_back(1);
1210     }
1211   }
1212   return strategy;
1213 }
1214 
1215 Dimensions CopyIncomingOperatorOutputStrategy(Graph::NodeType *node,
1216                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1217                                               const size_t iter_ops, const size_t incoming_op_index) {
1218   Dimensions strategy;
1219 
1220   if (ops[incoming_op_index]->type() == VIRTUAL_DATA_SET) {
1221     strategy = CopyVirtualDataset(node, ops[iter_ops]);
1222     return strategy;
1223   }
1224 
1225   for (auto inputs_shape : ops[iter_ops]->inputs_shape()) {
1226     auto input_stra_dim = inputs_shape.size();
1227     if (input_stra_dim == SIZE_ZERO) {
1228       continue;
1229     }
1230     if (input_stra_dim == SIZE_ONE) {
1231       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
1232     } else if (input_stra_dim == SIZE_TWO) {
1233       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
1234       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
1235     } else if (input_stra_dim == SIZE_THREE) {
1236       // Experimental support for 3D data.
1237       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_c));
1238       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
1239       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
1240     } else if (input_stra_dim == SIZE_FOUR) {
1241       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_n));
1242       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_c));
1243       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
1244       strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
1245     } else {
1246       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
1247     }
1248     break;
1249   }
1250   return strategy;
1251 }
1252 
1253 Dimensions PrepareReshape(std::vector<int64_t> from_shape, std::vector<int64_t> to_shape,
1254                           std::vector<int64_t> from_strat) {
1255   Dimensions to_strat(to_shape.size(), 1);
1256   std::vector<int64_t> from_shape_cpy(from_shape);
1257   std::vector<int64_t> to_shape_cpy(to_shape);
1258   size_t from_idx = 0;
1259   size_t to_idx = 0;
1260 
1261   // Attempt to assign full strategy to one dimension
1262   while (from_idx < from_shape.size() && to_idx < to_shape.size()) {
1263     if (from_shape[from_idx] > to_shape[to_idx]) {
1264       if (to_shape[to_idx] % from_strat[from_idx] == 0) {
1265         to_strat[to_idx] *= from_strat[from_idx];
1266         from_strat[from_idx] = 1;
1267       }
1268       from_shape[from_idx] /= to_shape[to_idx];
1269       to_idx++;
1270     } else if (from_shape[from_idx] < to_shape[to_idx]) {
1271       to_shape[to_idx] /= from_shape[from_idx];
1272       from_idx++;
1273     } else {
1274       if (to_shape[to_idx] % from_strat[from_idx] == 0) {
1275         to_strat[to_idx] *= from_strat[from_idx];
1276         from_strat[from_idx] = 1;
1277       }
1278       from_idx++;
1279       to_idx++;
1280     }
1281   }
1282 
1283   // Reset shapes & indices
1284   from_idx = 0;
1285   to_idx = 0;
1286   from_shape = from_shape_cpy;
1287   to_shape = to_shape_cpy;
1288 
1289   // Assign remaining strategy
1290   while (from_idx < from_shape.size() && to_idx < to_shape.size()) {
1291     if (from_shape[from_idx] > to_shape[to_idx]) {
1292       int64_t d = std::gcd(from_strat[from_idx], to_shape[to_idx]);
1293       to_strat[to_idx] *= d;
1294       from_strat[from_idx] /= d;
1295       from_shape[from_idx] /= to_shape[to_idx];
1296       to_idx++;
1297     } else if (from_shape[from_idx] < to_shape[to_idx]) {
1298       to_strat[to_idx] *= from_strat[from_idx];
1299       to_shape[to_idx] /= from_shape[from_idx];
1300       from_idx++;
1301     } else {  // equal case
1302       to_strat[to_idx] *= from_strat[from_idx];
1303       from_idx++;
1304       to_idx++;
1305     }
1306   }
1307   return to_strat;
1308 }
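// Sketch of the two passes above: reshaping an [8, 16] tensor with strategy [2, 4] into [128]
// moves the factor 4 in the first pass (16 % 4 == 0) and the remaining factor 2 through the gcd
// pass, giving the output strategy [8]; the total number of cuts is preserved.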
1309 
1310 Dimensions PrepareReshapeOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1311   auto output_shape = op->outputs_shape()[0];
1312   auto input_shape = op->inputs_shape()[0];
1313   auto strategy = op->selected_strategy();
1314 
1315   return PrepareReshape(input_shape, output_shape, strategy->GetInputDim()[0]);
1316 }
1317 
1318 Dimensions PrepareTransposeOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1319   Dimensions strategy;
1320   auto permutation = GetValue<std::vector<int64_t>>(op->input_value().at(1));
1321   auto op_strategy = op->selected_strategy();
1322   // The strategies are assigned according to the order in permutation (user defined).
1323   for (size_t i = 0; i < permutation.size(); i++) {
1324     strategy.push_back(op_strategy->GetInputDim()[0][LongToSize(permutation[i])]);
1325   }
1326   return strategy;
1327 }
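// Example: with permutation (1, 0) and an input strategy of [2, 4], the output strategy is [4, 2].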
1328 
1329 Dimensions PrepareExpandDimsOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1330   Dimensions strategy;
1331 
1332   auto axis_input = GetValue<int64_t>(op->input_value().at(1));
1333   auto op_strategy = op->selected_strategy();
1334   bool already_expand = false;
1335 
1336   // axis_input can be negative, in which case the index is computed backward from the shape size.
1337   if (axis_input < 0) {
1338     axis_input = SizeToLong(op->inputs_shape()[0].size()) + axis_input + 1;
1339   }
1340 
1341   // The strategy of the expanded dimension will be assigned 1, the others take the strategies of corresponding
1342   // dimensions.
1343   for (size_t i = 0; i < op->inputs_shape()[0].size() + 1; i++) {
1344     if (UlongToLong(i) == axis_input) {
1345       strategy.push_back(1);
1346       already_expand = true;
1347     } else if (UlongToLong(i) != axis_input && !already_expand) {
1348       strategy.push_back(op_strategy->GetInputDim()[0][i]);
1349     } else {
1350       if (i < 1) {
1351         MS_LOG(EXCEPTION) << "The index i - 1 is less than 0, which is invalid.";
1352       }
1353       strategy.push_back(op_strategy->GetInputDim()[0][i - 1]);
1354     }
1355   }
1356 
1357   return strategy;
1358 }
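// Example: an input strategy of [2, 4] expanded on axis 1 becomes [2, 1, 4]; negative axes are
// normalized against the expanded rank first.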
1359 
1360 Dimensions PrepareCumOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1361   Dimensions strategy;
1362 
1363   int64_t axis_input = 1;
1364 
1365   if (op->input_value().at(1)->isa<Int64Imm>()) {
1366     axis_input = GetValue<int64_t>(op->input_value().at(1));
1367     MS_LOG(INFO) << op->name() << " is a prefix sum on axis " << axis_input;
1368   } else {
1369     MS_LOG(INFO) << op->name() << " is supposed to be a cumulative op, but its axis is not an int64.";
1370   }
1371 
1372   auto op_strategy = op->selected_strategy();
1373 
1374   // axis_input can be negative, in which case the index is computed backward from the shape size.
1375   if (axis_input < 0) {
1376     axis_input = op->inputs_shape()[0].size() + axis_input + 1;
1377   }
1378 
1379   // The strategy of the cumulated axis will be assigned 1, the others take the strategies of corresponding dimensions.
1380   for (size_t i = 0; i < op->inputs_shape()[0].size(); i++) {
1381     if ((int64_t)i == axis_input) {
1382       strategy.push_back(1);
1383     } else {
1384       strategy.push_back(op_strategy->GetInputDim()[0][i]);
1385     }
1386   }
1387 
1388   return strategy;
1389 }
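// Example: a CumSum over axis 0 with input strategy [4, 2] yields [1, 2], since the accumulated
// axis is not sharded here.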
1390 
1391 ShapeVector GetReduceAxisList(const std::shared_ptr<OperatorInfo> &op) {
1392   ShapeVector axis_list;
1393   auto input_value = op->input_value();
1394   auto input_dim = op->inputs_shape()[0].size();
1395 
1396   if (input_value.back()->isa<ValueTuple>()) {
1397     auto attr_axis = GetValue<std::vector<int64_t>>(input_value.back());
1398     if (attr_axis.empty()) {
1399       for (size_t i = 0; i < input_dim; i++) {
1400         axis_list.push_back(i);
1401       }
1402     } else {
1403       axis_list = attr_axis;
1404     }
1405   } else if (input_value.back()->isa<Int64Imm>()) {
1406     int64_t axis = GetValue<int64_t>(input_value.back());
1407     axis_list.push_back(axis < 0 ? axis + SizeToLong(input_dim) : axis);
1408   } else {
1409     MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl;
1410   }
1411 
1412   return axis_list;
1413 }
1414 
1415 Dimensions PrepareCumInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1416                                    size_t outgoing_op_index, size_t i_input) {
1417   Dimensions strategy;
1418   int64_t axis_input = 1;
1419 
1420   if (ops[i_ops]->input_value().at(1)->isa<Int64Imm>()) {
1421     axis_input = GetValue<int64_t>(ops[i_ops]->input_value().at(1));
1422     MS_LOG(INFO) << ops[i_ops]->name() << " is a prefix sum on axis " << axis_input;
1423   } else {
1424     MS_LOG(INFO) << ops[i_ops]->name() << " is supposed to be a cumulative op, but its axis is not an int64.";
1425   }
1426 
1427   auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1428 
1429   size_t n_dim = op_strategy->GetInputDim()[i_input].size();
1430 
1431   if (axis_input < 0) {
1432     axis_input = SizeToLong(n_dim) + axis_input;
1433   }
1434 
1435   MS_EXCEPTION_IF_CHECK_FAIL(axis_input >= 0, "Input axis is lower than 0");
1436 
1437   for (size_t i_dim = 0; i_dim < n_dim; ++i_dim) {
1438     if (i_dim == size_t(axis_input)) {
1439       strategy.push_back(1);
1440     } else {
1441       strategy.push_back(op_strategy->GetInputDim()[i_input][i_dim]);
1442     }
1443   }
1444 
1445   return strategy;
1446 }
1447 
1448 Dimensions PrepareIncomingArithmeticOpeartorInputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1449   Dimensions strategy;
1450   size_t max = 0;
1451   for (size_t i = 1; i < op->inputs_shape().size(); i++) {
1452     if (op->inputs_shape()[i].size() > op->inputs_shape()[max].size()) {
1453       max = i;
1454     }
1455   }
1456 
1457   for (size_t j = 0; j < op->inputs_shape()[max].size(); j++) {
1458     strategy.push_back(op->selected_strategy()->GetInputDim()[max][j]);
1459   }
1460 
1461   return strategy;
1462 }
1463 
1464 Dimensions PrepareIncomingOperatorInputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1465   Dimensions strategy;
1466 
1467   if (op->type() == GATHERV2) {
1468     auto pos = op->name().find("Info");
1469     if (pos == std::string::npos) {
1470       return strategy;
1471     }
1472     auto name = op->name().substr(0, pos);
1473     if (name == "Gather") {
1474       return PrepareGatherV2OutputStrategy(op);
1475     } else {
1476       MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2.";
1477     }
1478   }
1479 
1480   if (!HasStrategy(op)) {
1481     return strategy;
1482   }
1483 
1484   auto op_strategy = op->selected_strategy();
1485   if (op_strategy->GetInputNumber() == 0) {
1486     return strategy;
1487   }
1488 
1489   if (op->type() == MUL || op->type() == SUB || op->type() == ADD || op->type() == BIAS_ADD) {
1490     strategy = PrepareIncomingArithmeticOpeartorInputStrategy(op);
1491     return strategy;
1492   }
1493 
1494   if (op->type() == RESHAPE) {
1495     return PrepareReshapeOutputStrategy(op);
1496   } else if (op->type() == TRANSPOSE) {
1497     return PrepareTransposeOutputStrategy(op);
1498   } else if (op->type() == EXPAND_DIMS) {
1499     return PrepareExpandDimsOutputStrategy(op);
1500   } else if (op->type() == CUM_SUM || op->type() == CUM_PROD) {
1501     return PrepareCumOutputStrategy(op);
1502   } else if (op->type() == ONEHOT) {
1503     return PrepareOneHotOutputStrategy(op);
1504   }
1505 
1506   for (size_t i = 0; i < static_cast<size_t>(op->inputs_shape().size()); i++) {
1507     if (op->inputs_shape()[i].size() == 0) {
1508       continue;
1509     }
1510     for (size_t j = 0; j < op->inputs_shape()[i].size(); ++j) {
1511       strategy.push_back(op_strategy->GetInputDim()[i][j]);
1512     }
1513     break;
1514   }
1515   return strategy;
1516 }
1517 
1518 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops) {
1519   Dimensions axis_list;
1520   auto axis_param = ops[LongToSize(iter_ops)]->attrs().find(AXIS)->second;
1521   std::vector<ValuePtr> elements;
1522   if (axis_param->isa<ValueTuple>()) {
1523     elements = axis_param->cast<ValueTuplePtr>()->value();
1524   } else if (axis_param->isa<ValueList>()) {
1525     elements = axis_param->cast<ValueListPtr>()->value();
1526   } else {
1527     MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list.";
1528   }
1529 
1530   for (auto &element : elements) {
1531     if (!element->isa<Int64Imm>()) {
1532       MS_LOG(EXCEPTION) << "Failure: Dimension index is not Int64.";
1533     }
1534     auto axis = element->cast<Int64ImmPtr>()->value();
1535     axis_list.push_back(axis);
1536   }
1537   return axis_list;
1538 }
1539 
1540 Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1541                                            const size_t incoming_op_index, Dimensions strategy) {
1542   Dimensions s_Squeeze;
1543   Dimensions stra_dim_list;
1544   for (size_t i = 0; i < strategy.size(); i++) {
1545     stra_dim_list.push_back(SizeToLong(i));
1546   }
1547 
1548   auto axis_list = GetAxisList(ops, SizeToLong(incoming_op_index));
1549   for (auto axis : axis_list) {
1550     axis = (axis < 0) ? (strategy.size() + axis) : axis;
1551     auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
1552     if (it == stra_dim_list.end()) {
1553       MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1554     }
1555     if (ops[incoming_op_index]->inputs_shape()[0][LongToSize(axis)] != 1) {
1556       MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1.";
1557     }
1558     (void)stra_dim_list.erase(it);
1559   }
1560 
1561   for (size_t i = 0; i < stra_dim_list.size(); i++) {
1562     s_Squeeze.push_back(strategy[LongToSize(stra_dim_list[i])]);
1563   }
1564   return s_Squeeze;
1565 }
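// Example: if the incoming Squeeze removes axis 1 of an [8, 1, 16] input and its strategy is
// [4, 1, 2], the strategy handed on to the current operator becomes [4, 2].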
1566 
1567 Dimensions ModifyStrategyIfReduceIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1568   Dimensions s_Reduce;
1569   Dimensions axis_list;
1570   for (size_t i = 0; i < strategy.size(); i++) {
1571     axis_list.push_back(SizeToLong(i));
1572   }
1573 
1574   auto dim_list = GetDimList(op);
1575   for (auto axis : dim_list) {
1576     auto it = find(axis_list.begin(), axis_list.end(), axis);
1577     if (it == axis_list.end()) {
1578       MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1579     }
1580     (void)axis_list.erase(it);
1581   }
1582 
1583   for (size_t i = 0; i < axis_list.size(); i++) {
1584     s_Reduce.push_back(strategy[LongToSize(axis_list[i])]);
1585   }
1586   return s_Reduce;
1587 }
1588 
1589 Dimensions GetDimListFromAttrs(const std::shared_ptr<OperatorInfo> &op) {
1590   Dimensions dim_list;
1591   auto iter = op->attrs().find(AXIS);
1592   if (iter == op->attrs().end()) {
1593     MS_LOG(EXCEPTION) << op->name() << ": Don't have attr axis.";
1594   }
1595   auto input_dim = op->inputs_shape()[0].size();
1596   MS_EXCEPTION_IF_NULL(iter->second);
1597   if (iter->second->isa<ValueTuple>()) {
1598     auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
1599     if (attr_axis.empty()) {
1600       for (size_t i = 0; i < input_dim; ++i) {
1601         dim_list.push_back(SizeToLong(i));
1602       }
1603     } else {
1604       for (auto &axis : attr_axis) {
1605         axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
1606       }
1607     }
1608   } else if (iter->second->isa<Int64Imm>()) {
1609     int64_t axis = GetValue<int64_t>(iter->second);
1610     axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
1611   } else {
1612     MS_LOG(EXCEPTION) << "Axis type is invalid.";
1613   }
1614   return dim_list;
1615 }
1616 
1617 Dimensions ModifyStrategyIfArgIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1618   bool keepdims = GetKeepDims(op);
1619   if (keepdims) {
1620     return strategy;
1621   }
1622 
1623   Dimensions s_Arg;
1624   Dimensions axis_list;
1625   for (size_t i = 0; i < strategy.size(); i++) {
1626     axis_list.push_back(SizeToLong(i));
1627   }
1628 
1629   auto dim_list = GetDimListFromAttrs(op);
1630   for (auto axis : dim_list) {
1631     auto it = find(axis_list.begin(), axis_list.end(), axis);
1632     if (it == axis_list.end()) {
1633       MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1634     }
1635     (void)axis_list.erase(it);
1636   }
1637 
1638   for (size_t i = 0; i < axis_list.size(); i++) {
1639     s_Arg.push_back(strategy[LongToSize(axis_list[i])]);
1640   }
1641   return s_Arg;
1642 }
1643 
1644 Dimensions ModifyStrategyIfFlattenIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1645   Dimensions new_strategy;
1646   int start_dim = 1, end_dim = strategy.size() - 1;
1647   auto start_dim_iter = op->attrs().find("start_dim");
1648   if (start_dim_iter != op->attrs().end()) {
1649     start_dim = GetValue<int64_t>(start_dim_iter->second);
1650   }
1651   auto end_dim_iter = op->attrs().find("end_dim");
1652   if (end_dim_iter != op->attrs().end() && GetValue<int64_t>(end_dim_iter->second) >= 0) {
1653     end_dim = GetValue<int64_t>(end_dim_iter->second);
1654   }
1655 
1656   for (int idx = 0; idx < start_dim; idx++) {
1657     new_strategy.push_back(strategy[idx]);
1658   }
1659 
1660   int flatten_strategy = 1;
1661   for (int idx = start_dim; idx < end_dim + 1; idx++) {
1662     flatten_strategy *= strategy[idx];
1663   }
1664   new_strategy.push_back(flatten_strategy);
1665   if (IntToSize(end_dim + 1) < strategy.size()) {
1666     for (size_t idx = end_dim + 1; idx < strategy.size(); idx++) {
1667       new_strategy.push_back(strategy[idx]);
1668     }
1669   }
1670 
1671   return new_strategy;
1672 }
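// Example: with the default start_dim = 1 and an incoming strategy of [4, 2, 2], the flattened
// dimensions are merged and the strategy handed on becomes [4, 4].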
1673 
1674 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1675                                              const size_t iter_ops, const size_t incoming_op_index) {
1676   Dimensions strategy;
1677   if (ops[iter_ops]->type() == ONEHOT) {
1678     return strategy;
1679   }
1680   if (ops[iter_ops]->type() == TRANSPOSE) {
1681     return strategy;
1682   }
1683   if (ops[incoming_op_index]->type() == STRIDED_SLICE) {
1684     return strategy;
1685   }
1686   strategy = PrepareIncomingOperatorInputStrategy(ops[incoming_op_index]);
1687   if (strategy.size() != 0) {
1688     if (ops[incoming_op_index]->type() == SQUEEZE) {
1689       strategy = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, strategy);
1690     }
1691     if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
1692         ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
1693       strategy = ModifyStrategyIfReduceIncoming(ops[incoming_op_index], strategy);
1694     }
1695     if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
1696       strategy = ModifyStrategyIfArgIncoming(ops[incoming_op_index], strategy);
1697     }
1698     if (ops[incoming_op_index]->type() == FLATTEN) {
1699       strategy = ModifyStrategyIfFlattenIncoming(ops[incoming_op_index], strategy);
1700     }
1701   }
1702   return strategy;
1703 }
1704 
1705 Strategies PrepareDropoutDoMask(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
1706                                 bool dyn_shape_tmp_fix) {
1707   // DropoutDoMask takes a single input strategy, copied directly from the given basic strategy.
1708   Strategies strategies;
1709   strategies.clear();
1710   strategies.push_back(basic_stra);
1711   return strategies;
1712 }
1713 
1714 // Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
1715 Strategies CheckBroadcast(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
1716   Strategies strategies;
1717 
1718   size_t first_tensor_dim = op->inputs_shape()[0].size();
1719   size_t second_tensor_dim = op->inputs_shape()[1].size();
1720   size_t s_dim = strategy.size();
1721   // Do Broadcasting in the second tensor.
1722   if (second_tensor_dim < first_tensor_dim) {
1723     if (s_dim == first_tensor_dim) {
1724       bool broadcast_first_tensor = false;
1725       strategies.push_back(strategy);
1726       strategies.push_back(ApplyBroadcast(op, strategy, broadcast_first_tensor));
1727     } else {
1728       // When the strategy is from the smaller tensor, make the strategy all 1.
1729       Dimensions broadcast_revise_s(first_tensor_dim, 1);
1730       strategies.push_back(broadcast_revise_s);
1731       Dimensions broadcast_s(strategy.size(), 1);
1732       strategies.push_back(broadcast_s);
1733     }
1734   } else if (second_tensor_dim > first_tensor_dim) {  // Do Broadcasting in the first tensor.
1735     if (s_dim == second_tensor_dim) {
1736       bool broadcast_first_tensor = true;
1737       strategies.push_back(ApplyBroadcast(op, strategy, broadcast_first_tensor));
1738       strategies.push_back(strategy);
1739     } else {
1740       // When the strategy is from the smaller tensor, make the strategy all 1.
1741       Dimensions broadcast_s(strategy.size(), 1);
1742       strategies.push_back(broadcast_s);
1743       Dimensions broadcast_revise_s(second_tensor_dim, 1);
1744       strategies.push_back(broadcast_revise_s);
1745     }
1746   } else {  // Broadcasting can be ignored or No broadcasting needs to be applied.
1747     strategies = CheckDivisible(op, strategy);
1748   }
1749   // Strategy protection to avoid that partition number is larger than the shape of related dimension.
1750   for (size_t i = 0; i < op->inputs_shape().size(); i++) {
1751     for (size_t j = 0; j < op->inputs_shape()[i].size(); j++) {
1752       if (strategies[i][j] > op->inputs_shape()[i][j] || op->inputs_shape()[i][j] % strategies[i][j] != 0) {
1753         strategies[i][j] = 1;
1754       }
1755     }
1756   }
1757 
1758   return strategies;
1759 }
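// Example: for an Add of shapes [8, 16] and [16] with incoming strategy [2, 4], the first input
// keeps [2, 4] and the second gets the broadcast strategy [4]; a 1 is used wherever the broadcast
// dimension has size 1.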
1760 
1761 void InitializeStrategyMap() {
1762   if (g_prepare_stra_map.empty()) {
1763     g_prepare_stra_map =
1764       std::map<std::string, PrepareStraFuncPtr>{{FILLV2, &PrepareFillV2},
1765                                                 {BIAS_ADD, &PrepareBiasAdd},
1766                                                 {STRIDED_SLICE, &PrepareStridedSlice},
1767                                                 {GATHERV2, &PrepareGather},
1768                                                 {ONEHOT, &PrepareOneHot},
1769                                                 {L2_NORMALIZE, &PrepareL2Normalize},
1770                                                 {ADD, &CheckBroadcast},
1771                                                 {SUB, &CheckBroadcast},
1772                                                 {MUL, &CheckBroadcast},
1773                                                 {DIV, &CheckBroadcast},
1774                                                 {SOFTMAX, &PrepareSoftMax},
1775                                                 {LOG_SOFTMAX, &PrepareSoftMax},
1776                                                 {FLATTEN, &PrepareDataParallel},
1777                                                 {GATHERD, &PrepareDataParallel},
1778                                                 {LAYER_NORM, &PrepareLayerNorm},
1779                                                 {RMS_NORM, &PrepareRmsNorm},
1780                                                 {BATCH_MATMUL, &PreparePropagateBatchMatMul},
1781                                                 {DROPOUT_DO_MASK, &PrepareDropoutDoMask},
1782                                                 {FLASH_ATTENTION_SCORE, &PrepareFlashAttentionScore}};
1783   }
1784 }
1785 
1786 Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1787                                           Dimensions basic_stra, bool dyn_shape_tmp_fix) {
1788   if (iter_ops >= ops.size()) {
1789     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
1790   }
1791   MS_EXCEPTION_IF_NULL(ops[iter_ops]);
1792 
1793 
1794   Strategies strategies;
1795   if (basic_stra.size() == 0) {
1796     for (size_t iter_op_inputs = 0; iter_op_inputs < static_cast<size_t>(ops[iter_ops]->inputs_shape().size());
1797          iter_op_inputs++) {
1798       strategies.push_back(basic_stra);
1799     }
1800     return strategies;
1801   }
1802   InitializeStrategyMap();
1803   auto type = ops[iter_ops]->type();
1804   auto iter_stra_func = g_prepare_stra_map.find(type);
1805   if (iter_stra_func != g_prepare_stra_map.end()) {
1806     auto stra = iter_stra_func->second(ops[iter_ops], basic_stra, dyn_shape_tmp_fix);
1807     return stra;
1808   }
1809 
1810   return CheckDivisible(ops[iter_ops], basic_stra);
1811 }
1812 
1813 Dimensions ApplyBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy,
1814                           bool broadcast_first_tensor) {
1815   Dimensions s_broadcast;
1816   size_t target_tensor_index = 0;
1817   size_t target_tensor_dim = 1;
1818 
1819   // Select the target tensor, i.e. the input whose strategy is generated by broadcasting.
1820   if (!broadcast_first_tensor) {
1821     target_tensor_index = 1;
1822   }
1823 
1824   target_tensor_dim = op->inputs_shape()[target_tensor_index].size();
1825   for (size_t iter = 0; iter < target_tensor_dim; iter++) {
1826     if (op->inputs_shape()[target_tensor_index][target_tensor_dim - 1 - iter] == 1) {
1827       s_broadcast.insert(s_broadcast.begin(), 1);
1828     } else {
1829       s_broadcast.insert(s_broadcast.begin(), strategy[strategy.size() - 1 - iter]);
1830     }
1831   }
1832 
1833   return s_broadcast;
1834 }
1835 
1836 // Check whether the operator can be divided by the current strategy.
1837 Strategies CheckDivisible(const std::shared_ptr<OperatorInfo> &op, const Dimensions &basic_stra) {
1838   Dimensions s_empty = {};
1839   Strategies strategies;
1840 
1841   // For all the input tensors.
1842   for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
1843     // If input tensor is empty, return strategy as void.
1844     if (op->inputs_shape()[iter_op_inputs].size() == 0) {
1845       strategies.push_back(s_empty);
1846       continue;
1847     }
1848 
1849     Dimensions tmp_stra;
1850 
1851     // Make sure each tensor's dim shape is greater than 1. If not, push back strategy as 1 instead.
1852     for (size_t j = 0; j < op->inputs_shape()[iter_op_inputs].size(); j++) {
1853       if (op->inputs_shape()[iter_op_inputs][j] == 1) {
1854         tmp_stra.push_back(1);
1855       } else if (j < basic_stra.size()) {
1856         tmp_stra.push_back(basic_stra[j]);
1857       } else {
1858         tmp_stra.push_back(1);
1859       }
1860     }
1861     strategies.push_back(tmp_stra);
1862   }
1863 
1864   return strategies;
1865 }
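// Example: with basic_stra = [4, 2] and an input of shape [1, 64], the generated strategy is
// [1, 2]; dimensions of size 1, or beyond the basic strategy, fall back to 1.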
1866 
1867 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1868                                            Dimensions strategy) {
1869   Dimensions s_Squeeze;
1870   auto axis_list = GetAxisList(ops, SizeToLong(iter_ops));
1871   size_t s_index = 0;
1872   size_t axis_list_index = 0;
1873   for (size_t i = 0; i < strategy.size() + axis_list.size(); i++) {
1874     if (axis_list_index < axis_list.size() && axis_list[axis_list_index] > 0 && i == LongToSize(axis_list[axis_list_index])) {
1875       s_Squeeze.push_back(1);
1876       axis_list_index++;
1877     } else {
1878       s_Squeeze.push_back(strategy[s_index]);
1879       s_index++;
1880     }
1881   }
1882 
1883   size_t cut = 1;
1884   for (size_t i = 0; i < s_Squeeze.size(); i++) {
1885     cut *= LongToSize(s_Squeeze[i]);
1886   }
1887   if (cut != size_t(g_device_manager->stage_device_num())) {
1888     s_Squeeze.clear();
1889   }
1890 
1891   return s_Squeeze;
1892 }
1893 
1894 Dimensions PrepareExpandDimsInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1895                                           size_t outgoing_op_index, size_t i_input) {
1896   Dimensions strategy;
1897 
1898   int64_t axis_input = GetValue<int64_t>(ops[i_ops]->input_value().at(1));
1899 
1900   auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1901 
1902   size_t n_dim = op_strategy->GetInputDim()[i_input].size();
1903 
1904   if (axis_input < 0) {
1905     axis_input = SizeToLong(n_dim) + axis_input;
1906   }
1907 
1908   MS_EXCEPTION_IF_CHECK_FAIL(axis_input >= 0, "Input axis is lower than 0");
1909 
1910   for (size_t i_dim = 0; i_dim < n_dim; ++i_dim) {
1911     if (i_dim != size_t(axis_input)) {
1912       strategy.push_back(op_strategy->GetInputDim()[i_input][i_dim]);
1913     }
1914   }
1915 
1916   return strategy;
1917 }
1918 
1919 Dimensions PrepareReshapeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1920                                        size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix) {
1921   if (dyn_shape_tmp_fix) {
1922     Dimensions empty_strategy;
1923     return empty_strategy;
1924   }
1925   auto output_shape = ops[i_ops]->outputs_shape()[0];
1926   auto input_shape = ops[i_ops]->inputs_shape()[0];
1927   auto strategy = ops[outgoing_op_index]->selected_strategy();
1928 
1929   return PrepareReshape(output_shape, input_shape, strategy->GetInputDim()[iter_op_inputs]);
1930 }
1931 
1932 Dimensions PrepareGatherV2InputStrategy(const std::shared_ptr<OperatorInfo> &op, size_t i_input) {
1933   auto targeted_shape = op->inputs_shape()[i_input];
1934   Dimensions strategy = GenGatherStra(targeted_shape);
1935   return strategy;
1936 }
1937 
1938 Dimensions PrepareReduceOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1939   bool keep_dims = GetKeepDims(op);
1940   auto axis_list = GetDimList(op);
1941   auto basic_stra = op->selected_strategy()->GetInputDim().at(0);
1942 
1943   Dimensions strategy;
1944 
1945   for (size_t i = 0; i < basic_stra.size(); ++i) {
1946     if (std::find(axis_list.begin(), axis_list.end(), i) != axis_list.end()) {
1947       if (keep_dims) {
1948         strategy.push_back(1);
1949       }
1950     } else {
1951       strategy.push_back(basic_stra.at(i));
1952     }
1953   }
1954 
1955   return strategy;
1956 }
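// Example: for a ReduceSum over axis 1 with strategy [4, 2, 2] and keep_dims false, the reduced
// axis is dropped and the output strategy becomes [4, 2].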
1957 
1958 Dimensions PrepareReduceInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1959                                       size_t outgoing_op_index, size_t i_input) {
1960   bool keep_dims = GetKeepDims(ops[i_ops]);
1961 
1962   auto axis_list = GetDimList(ops[i_ops]);
1963 
1964   Dimensions strategy;
1965 
1966   auto basic_stra = ops[outgoing_op_index]->selected_strategy()->GetInputDim().at(i_input);
1967 
1968   for (size_t i = 0, i_stra = 0; i < ops[i_ops]->inputs_shape()[0].size(); ++i) {
1969     if (std::find(axis_list.begin(), axis_list.end(), i) != axis_list.end()) {
1970       strategy.push_back(1);
1971       if (keep_dims) {
1972         ++i_stra;
1973       }
1974     } else {
1975       strategy.push_back(basic_stra.at(i_stra++));
1976     }
1977   }
1978 
1979   return strategy;
1980 }
1981 
1982 Dimensions PrepareTransposeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1983                                          size_t outgoing_op_index, size_t iter_op_inputs) {
1984   Dimensions strategy;
1985   auto permutation = GetValue<std::vector<int64_t>>(ops[i_ops]->input_value().at(1));
1986   auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1987   // The strategies are assigned according to the order in permutation (user defined).
1988   for (size_t i = 0; i < permutation.size(); i++) {
1989     strategy.push_back(op_strategy->GetInputDim()[iter_op_inputs][LongToSize(permutation[i])]);
1990   }
1991   return strategy;
1992 }
1993 
1994 Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops,
1995                                              size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix) {
1996   Dimensions strategy;
1997   // Propagation not implemented for these operators
1998   if (ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
1999     return strategy;
2000   }
2001 
2002   // Propagation not allowed for these operators
2003   if (ops[iter_ops]->type() == FLATTEN) {
2004     return strategy;
2005   }
2006 
2007   if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
2008     std::string type = ops[iter_ops]->type();
2009     if (type == EXPAND_DIMS) {
2010       strategy = PrepareExpandDimsInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2011     } else if (type == RESHAPE) {
2012       strategy = PrepareReshapeInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs, dyn_shape_tmp_fix);
2013       return strategy;
2014     } else if (type == GATHERV2) {
2015       strategy = PrepareGatherV2InputStrategy(ops[outgoing_op_index], iter_op_inputs);
2016       return strategy;
2017     } else if (type == REDUCE_MEAN || type == REDUCE_MAX || type == REDUCE_MIN || type == REDUCE_SUM) {
2018       strategy = PrepareReduceInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2019     } else if (type == TRANSPOSE) {
2020       strategy = PrepareTransposeInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2021       return strategy;
2022     } else {
2023       for (size_t k = 0; k < ops[iter_ops]->outputs_shape()[0].size(); ++k) {
2024         strategy.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
2025       }
2026     }
2027     if (!IsDimensionsEmpty(strategy) && ops[iter_ops]->type() == SQUEEZE) {
2028       strategy = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, strategy);
2029     }
2030   }
2031 
2032   return strategy;
2033 }
2034 
2035 void RecStrategyPropagator::ApplyStrategy(size_t i_op, const Strategies &strategies) {
2036   StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2037   ops_[i_op]->SetSelectedStrategyAndCost(sp, ops_[i_op]->selected_cost());
2038 }
2039 
2040 size_t RecStrategyPropagator::GetMaxDimNum(size_t i_op) {
2041   size_t max_dim_num = 0;
2042   for (size_t iter_op_inputs = 0; iter_op_inputs < ops_[i_op]->inputs_shape().size(); iter_op_inputs++) {
2043     if (ops_[i_op]->inputs_shape()[iter_op_inputs].size() > max_dim_num) {
2044       max_dim_num = ops_[i_op]->inputs_shape()[iter_op_inputs].size();
2045     }
2046   }
2047 
2048   return max_dim_num;
2049 }
2050 
2051 Dimensions RecStrategyPropagator::GetDefaultStrategy(size_t i_op) {
2052   Dimensions strategy;
2053   size_t max_dim_num = GetMaxDimNum(i_op);
2054   for (size_t i = 0; i < max_dim_num; i++) {
2055     strategy.push_back(1);
2056   }
2057 
2058   return strategy;
2059 }
2060 
2061 bool StopPropAtOP(std::string op_type) {
2062   const std::set<std::string> stop_at = {GATHERV2, ASSIGN, EXPAND_DIMS};
2063   return stop_at.find(op_type) != stop_at.end();
2064 }
2065 
2066 size_t RecStrategyPropagator::GenerateEliminatedOperatorStrategyForward(size_t min_devices) {
2067   MS_LOG(INFO) << "There are " << no_stra_op_list_->size() << " operators left that do not have strategy.";
2068   size_t changes = 0;
2069   if (no_stra_op_list_->empty()) {
2070     return changes;
2071   }
2072 
2073   std::vector<size_t> no_stra_op_list_bis;
2074   for (size_t iter_list = no_stra_op_list_->size(); iter_list > 0; iter_list--) {
2075     size_t iter_ops = no_stra_op_list_->at(iter_list - 1);
2076     Strategies strategies;
2077     size_t incoming_op_index = FindIndexOfOperatorIncoming(ops_, input_tensor_names_, iter_ops);
2078     Dimensions strategy = GetInputStrategy(graph_, ops_, index_list_, iter_ops, incoming_op_index);
2079     if (IsDimensionsEmpty(strategy) || DevicesForDimensions(strategy) < min_devices ||
2080         StopPropAtOP(ops_[incoming_op_index]->type())) {
2081       no_stra_op_list_bis.push_back(iter_ops);
2082     } else {
2083       strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2084       ApplyStrategy(iter_ops, strategies);
2085       ++changes;
2086       MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategies " << StrategyToString(strategies) << " from "
2087                    << ops_[incoming_op_index]->name() << " with strategy " << strategy;
2088     }
2089   }
2090   *no_stra_op_list_ = no_stra_op_list_bis;
2091 
2092   return changes;
2093 }
2094 
2095 size_t RecStrategyPropagator::GenerateEliminatedOperatorStrategyBackward(size_t min_devices) {
2096   MS_LOG(INFO) << "There are " << no_stra_op_list_->size() << " operators left that do not have strategy.";
2097   size_t changes = 0;
2098   if (no_stra_op_list_->empty()) {
2099     return changes;
2100   }
2101 
2102   std::vector<size_t> no_stra_op_list_bis;
2103   for (size_t iter_list = no_stra_op_list_->size(); iter_list > 0; iter_list--) {
2104     auto iter_ops = no_stra_op_list_->at(iter_list - 1);
2105     Strategies strategies;
2106     std::pair<size_t, size_t> idx = FindIndexOfOperatorOutgoing(ops_, input_tensor_names_, iter_ops);
2107     size_t outgoing_op_index = idx.first;
2108     size_t iter_op_inputs = idx.second;
2109     Dimensions strategy =
2110       CopyOutgoingOperatorInputStrategy(ops_, iter_ops, outgoing_op_index, iter_op_inputs, graph_->dyn_shape_tmp_fix);
2111     if (IsDimensionsEmpty(strategy) || DevicesForDimensions(strategy) < min_devices ||
2112         StopPropAtOP(ops_[outgoing_op_index]->type())) {
2113       no_stra_op_list_bis.push_back(iter_ops);
2114     } else {
2115       strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2116       ApplyStrategy(iter_ops, strategies);
2117       ++changes;
2118       MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategies " << StrategyToString(strategies) << " from "
2119                    << ops_[outgoing_op_index]->name() << " with strategy " << strategy;
2120     }
2121   }
2122   *no_stra_op_list_ = no_stra_op_list_bis;
2123 
2124   return changes;
2125 }
2126 
2127 size_t RecStrategyPropagator::GenerateRemainingOperatorStrategy() {
2128   size_t changes = 0;
2129 
2130   if (no_stra_op_list_->empty()) {
2131     return changes;
2132   }
2133 
2134   size_t no_stra_op_list_size = no_stra_op_list_->size();
2135   do {
2136     no_stra_op_list_size = no_stra_op_list_->size();
2137     changes += GenerateEliminatedOperatorStrategyForward();
2138     changes += GenerateEliminatedOperatorStrategyBackward();
2139   } while (no_stra_op_list_size > no_stra_op_list_->size());
2140 
2141   for (size_t iter_list = 0; iter_list < no_stra_op_list_->size(); iter_list++) {
2142     auto iter_ops = no_stra_op_list_->at(iter_list);
2143     Dimensions strategy = GetDefaultStrategy(iter_ops);
2144     if (graph_->dyn_shape_tmp_fix && strategy.empty()) {
2145       continue;
2146     }
2147     Strategies strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2148     ApplyStrategy(iter_ops, strategies);
2149     ++changes;
2150     MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned default strategies " << StrategyToString(strategies)
2151                  << " with strategy  " << strategy;
2152   }
2153 
2154   return changes;
2155 }
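// Note: the loop above alternates forward and backward propagation until the pending list stops
// shrinking; any operator still unresolved then falls back to an all-ones default strategy
// expanded through GenerateStrategiesFromStrategy.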
2156 
2157 // Maps each parameter name to its users, stored as (operator index, input index) pairs.
2158 std::map<std::string, std::vector<std::pair<size_t, size_t>>> RecStrategyPropagator::GetParamUsers() {
2159   std::map<std::string, std::vector<std::pair<size_t, size_t>>> param_users;
2160 
2161   AnfNodePtr ret = root_->get_return();
2162   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
2163 
2164   for (auto &node : all_nodes) {
2165     if (node->isa<Parameter>()) {
2166       ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
2167       auto users_set = parameter_users_info.second.second;
2168       if (users_set.size() >= 1) {
2169         MS_LOG(INFO) << "Parameter " << parameter_users_info.first << " has " << users_set.size() << " users.";
2170         for (auto &user : users_set) {
2171           MS_LOG(INFO) << "with ID: " << user.first->UniqueId() << " and name: " << user.first->UniqueName();
2172 
2173           std::pair<size_t, size_t> user_index = std::make_pair(SIZE_MAX, SIZE_MAX);
2174           for (size_t i = 0; i < input_tensor_names_.size(); i++) {
2175             if (input_tensor_names_[i][0] == user.first->UniqueId()) {
2176               size_t input_index = 0;
2177               if ((ops_[i]->type() == MATMUL) || (ops_[i]->type() == BATCH_MATMUL)) {
2178                 input_index = 1;
2179               }
2180               user_index = std::make_pair(i, input_index);
2181             }
2182           }
2183           if (user_index.first != SIZE_MAX) {
2184             param_users[parameter_users_info.first].push_back(user_index);
2185           }
2186         }
2187       }
2188     }
2189   }
2190 
2191   return param_users;
2192 }
2193 
2194 void RecStrategyPropagator::SetParamStrategy() {
2195   std::map<std::string, std::vector<std::pair<size_t, size_t>>> params_users = GetParamUsers();  // perhaps store this ?
2196   for (auto &param : params_users) {
2197     MS_LOG(INFO) << "Treat parameter " << param.first << " with " << param.second.size() << " users";
2198     if (param_strategy_.find(param.first) == param_strategy_.end() && !param.second.empty()) {
2199       Dimensions strategy;
2200       Dimensions max_strat;
2201       int max_stra_cut_num = 1;
2202       int max_stra_cut_ratio = INT_MAX;
2203 
2204       for (auto &user : param.second) {
2205         MS_LOG(INFO) << "user is " << ops_[user.first]->name() << " param goes to input " << user.second;
2206         if (!HasStrategy(ops_[user.first])) {
2207           continue;
2208         }
2209         strategy = ops_[user.first]->selected_strategy()->GetInputDim()[user.second];
2210         if (strategy.empty()) {
2211           MS_LOG(INFO) << "user has no strategy";
2212           continue;
2213         }
2214         MS_LOG(INFO) << "This user wants strategy " << strategy;
2215 
2216         auto param_shape = ops_[user.first]->inputs_shape()[user.second];
2217         auto ratio = 0;
2218         for (size_t idx = 0; idx < strategy.size(); idx++) {
2219           MS_EXCEPTION_IF_ZERO("strategy", strategy[idx]);
2220           ratio += param_shape[idx] / strategy[idx];
2221         }
2222 
2223         int cut_num = DevicesForDimensions(strategy);
2224         if (cut_num >= max_stra_cut_num && ratio < max_stra_cut_ratio) {
2225           max_stra_cut_num = cut_num;
2226           max_stra_cut_ratio = ratio;
2227           max_strat = strategy;
2228         }
2229       }
2230       if (!max_strat.empty()) {
2231         param_strategy_[param.first] = max_strat;
2232       }
2233     }
2234   }
2235   MS_LOG(INFO) << "SetParamStrategy done.";
2236 }
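// Note: roughly, the retained strategy for a parameter is the one among its users that spreads
// the parameter over the most devices while keeping the smallest accumulated shape/strategy
// ratio; ApplyParamStrategy later enforces it on every user.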
2237 
2238 Strategies MakeGatherStratFromParam(const std::shared_ptr<OperatorInfo> &op, Dimensions param_strategy) {
2239   Strategies strategies;
2240   Dimensions index_strategy;
2241   int64_t axis = GetGatherAxis(op);
2242   if (param_strategy.at(LongToSize(axis)) == 1) {
2243     size_t num_device_used = 1;
2244     for (size_t i = 0; i < param_strategy.size(); i++) {
2245       num_device_used *= param_strategy[i];
2246     }
2247     MS_EXCEPTION_IF_ZERO("num_device_used", num_device_used);
2248     index_strategy.push_back(g_device_manager->stage_device_num() / num_device_used);
2249   } else {
2250     index_strategy.push_back(1);
2251   }
2252 
2253   for (size_t i = 1; i < op->inputs_shape()[1].size(); ++i) {
2254     index_strategy.push_back(1);
2255   }
2256 
2257   strategies.push_back(param_strategy);
2258   strategies.push_back(index_strategy);
2259 
2260   MS_LOG(INFO) << "Gather is assigned strategy " << StrategyToString(strategies);
2261 
2262   return strategies;
2263 }
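// Sketch (assuming an 8-device stage): with a parameter strategy of [1, 2] and the gather axis
// left uncut, the index input receives [4, 1, ...] so the remaining devices are used on the index
// batch dimension; when the gather axis is cut, the index input is not sharded.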
2264 
2265 Strategies MakeMatMulStratFromParam(const std::shared_ptr<OperatorInfo> &op, Dimensions param_strategy) {
2266   Strategies new_strategy;
2267   Dimensions new_param_strat;
2268   Dimensions input0_strat = op->selected_strategy()->GetInputDim()[0];
2269   int64_t k_cuts = 1;
2270 
2271   auto input_value = op->input_value();
2272   bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
2273   bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
2274 
2275   k_cuts = param_strategy[0];
2276   if (transpose_b) {
2277     new_param_strat.push_back(param_strategy[1]);
2278     new_param_strat.push_back(param_strategy[0]);
2279   } else {
2280     new_param_strat.push_back(param_strategy[0]);
2281     new_param_strat.push_back(param_strategy[1]);
2282   }
2283 
2284   if (transpose_a) {
2285     input0_strat[0] = k_cuts;
2286     input0_strat[1] = std::min(input0_strat[1], g_device_manager->stage_device_num() / k_cuts);
2287   } else {
2288     input0_strat[1] = k_cuts;
2289     input0_strat[0] = std::min(input0_strat[0], g_device_manager->stage_device_num() / k_cuts);
2290   }
2291 
2292   new_strategy.push_back(input0_strat);
2293   new_strategy.push_back(new_param_strat);
2294 
2295   MS_LOG(INFO) << "Transpose B : " << transpose_b << "; Transpose A : " << transpose_a << "; K cuts : " << k_cuts;
2296 
2297   MS_LOG(INFO) << "MatMul is assigned strategy " << StrategyToString(new_strategy);
2298 
2299   return new_strategy;
2300 }
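// Note: the first element of param_strategy is taken as the cut on the reduction (K) dimension;
// the activation input applies the same cut on its own K axis and caps its other axis at
// stage_device_num / k_cuts so the overall cut count stays within the stage.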
2301 
2302 size_t RecStrategyPropagator::ApplyParamStrategy() {
2303   size_t changes = 0;
2304   std::map<std::string, std::vector<std::pair<size_t, size_t>>> params_users = GetParamUsers();
2305 
2306   for (auto &param : params_users) {
2307     if (param_strategy_.find(param.first) != param_strategy_.end()) {
2308       for (auto &user : param.second) {
2309         if (graph_->dyn_shape_tmp_fix && ops_[user.first]->type() == GATHERV2) {
2310           if (param.first.find(".output.ffn.projection.weight") != std::string::npos) {
2311             ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 1));
2312             continue;
2313           }
2314           if (param.first.find(".output.ffn.mapping.bias") != std::string::npos) {
2315             ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 3));
2316             continue;
2317           }
2318           if (param.first.find(".output.ffn.mapping.weight") != std::string::npos) {
2319             ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 2));
2320             continue;
2321           }
2322           // This Gather uses shared parameter, but it is not treated as using shared parameter.
2323           // Temporary workaround until this issue is fixed.
2324           if (param.first.find(".embedding.word_embedding.embedding_table") != std::string::npos) {
2325             ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 0));
2326             continue;
2327           }
2328         }
2329 
2330         if (!HasStrategy(ops_[user.first]) ||
2331             param_strategy_[param.first] != ops_[user.first]->selected_strategy()->GetInputDim()[user.second]) {
2332           Strategies strategies;
2333           if (ops_[user.first]->type() == GATHERV2) {
2334             strategies = MakeGatherStratFromParam(ops_[user.first], param_strategy_[param.first]);
2335           } else if (ops_[user.first]->type() == MATMUL) {
2336             strategies = MakeMatMulStratFromParam(ops_[user.first], param_strategy_[param.first]);
2337           } else if (ops_[user.first]->type() == STRIDED_SLICE) {
2338             strategies = CheckDivisible(ops_[user.first], param_strategy_[param.first]);
2339           } else {
2340             strategies =
2341               GenerateStrategiesFromStrategy(ops_, user.first, param_strategy_[param.first], graph_->dyn_shape_tmp_fix);
2342           }
2343           ApplyStrategy(user.first, strategies);
2344           MS_LOG(INFO) << ops_[user.first]->name() << " assigned strategy " << StrategyToString(strategies)
2345                        << " from parameter " << param.first;
2346           ++changes;
2347         }
2348       }
2349     }
2350   }
2351   return changes;
2352 }
2353 
2354 size_t RecStrategyPropagator::ModifyParamSharingOpsStrategy() {
2355   size_t changes = 0;
2356 
2357   for (auto tensor : shared_tensors_ops_) {
2358     for (auto op_i : tensor) {
2359       for (auto op_j : tensor) {
2360         if (op_i != op_j) {
2361           MS_LOG(INFO) << "Operator " << ops_[op_i]->name() << " sharing parameter with operator "
2362                        << ops_[op_j]->name();
2363         }
2364       }
2365     }
2366   }
2367 
2368   for (auto tensor : shared_tensors_ops_) {
2369     for (auto op_i : tensor) {
2370       if (ops_[op_i]->type() == GATHERV2) {
2371         for (auto op_j : tensor) {
2372           if (op_i != op_j) {
2373             Dimensions str_j;
2374             if (ops_[op_j]->type() == CAST) {
2375               str_j = ops_[op_j]->selected_strategy()->GetInputDim()[0];
2376             } else if (ops_[op_j]->type() == MATMUL) {
2377               str_j = ops_[op_j]->selected_strategy()->GetInputDim()[1];
2378             } else if (ops_[op_j]->type() == MUL) {
2379               str_j = ops_[op_j]->selected_strategy()->GetInputDim()[0];
2380             } else {
2381               continue;
2382             }
2383 
2384             Strategies strategies;
2385             Dimensions param_strategy, index_strategy;
2386 
2387             param_strategy = str_j;
2388 
2389             size_t num_device_used = 1;
2390             for (size_t i = 0; i < str_j.size(); i++) {
2391               num_device_used *= LongToSize(str_j[i]);
2392             }
2393             MS_EXCEPTION_IF_ZERO("num_device_used", num_device_used);
2394             index_strategy.push_back(g_device_manager->stage_device_num() / num_device_used);
2395 
2396             for (size_t i = 1; i < ops_[op_i]->inputs_shape()[1].size(); ++i) {
2397               index_strategy.push_back(1);
2398             }
2399 
2400             strategies.push_back(param_strategy);
2401             strategies.push_back(index_strategy);
2402 
2403             MS_LOG(INFO) << "Changing strategy of " << ops_[op_i]->name() << " with " << ops_[op_j]->name();
2404             MS_LOG(INFO) << ops_[op_i]->name() << " assigned strategy " << StrategyToString(strategies)
2405                          << " from ModifyParamSharingOpsStrategy";
2406 
2407             ApplyStrategy(op_i, strategies);
2408             ++changes;
2409           }
2410         }
2411       }
2412     }
2413   }
2414 
2415   return changes;
2416 }
2417 
2418 RecStrategyPropagator::RecStrategyPropagator(const std::shared_ptr<Graph> &graph,
2419                                              const std::vector<std::shared_ptr<OperatorInfo>> &ops,
2420                                              const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
2421                                              const std::vector<std::vector<std::string>> &input_tensor_names,
2422                                              const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
2423                                              const std::vector<std::vector<size_t>> &shared_tensors_ops,
2424                                              const FuncGraphPtr &root)
2425     : graph_(graph),
2426       ops_(ops),
2427       eli_list_(eli_list),
2428       input_tensor_names_(input_tensor_names),
2429       index_list_(index_list),
2430       is_training_(is_training),
2431       shared_tensors_ops_(shared_tensors_ops),
2432       root_(root) {}
2433 
2434 size_t RecStrategyPropagator::CopyMainOperatorsStrategy() {
2435   size_t changes = 0;
2436 
2437   for (size_t i_op = 0; i_op < static_cast<size_t>(index_list_->size()); i_op++) {
2438     Strategies strategies;
2439     size_t iter_graph = index_list_->at(i_op);
2440     if (iter_graph != SIZE_MAX && ops_[i_op]->type() != GET_NEXT) {
2441       strategies = PrepareStrategy(&graph_->nodes[iter_graph], ops_, i_op, graph_->dyn_shape_tmp_fix);
2442     }
2443     if (!strategies.empty()) {
2444       source_ops_.push_back(i_op);
2445       ++changes;
2446     }
2447     StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2448     ops_[i_op]->SetSelectedStrategyAndCost(sp, ops_[i_op]->selected_cost());
2449   }
2450 
2451   return changes;
2452 }
2453 
2454 Dimensions GetInputStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
2455                             const std::shared_ptr<std::vector<size_t>> &index_list, size_t i_op,
2456                             size_t incoming_op_index) {
2457   Dimensions strategy;
2458   if (incoming_op_index != SIZE_MAX) {
2459     auto iter_graph = index_list->at(incoming_op_index);
2460     if (iter_graph != SIZE_MAX) {
2461       strategy = CopyIncomingOperatorOutputStrategy(&graph->nodes[iter_graph], ops, i_op, incoming_op_index);
2462     } else {
2463       strategy = CopyIncomingOperatorInputStrategy(ops, i_op, incoming_op_index);
2464     }
2465   }
2466 
2467   return strategy;
2468 }
2469 
2470 size_t RecStrategyPropagator::PropagateFromInputs() { return 0; }
2471 
2472 size_t RecStrategyPropagator::PropagateFromOutputs() { return 0; }
2473 
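// Collects, in forward order, the index of every eliminated operator (first entry of each eli_list_ record)
// into no_stra_op_list_, i.e. the operators that still have no strategy assigned.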
2474 void RecStrategyPropagator::GenerateNoStraList() {
2475   no_stra_op_list_ = std::make_shared<std::vector<size_t>>();
2476   for (size_t i = 0; i < eli_list_->size(); i++) {
2477     no_stra_op_list_->push_back(eli_list_->at(i)[0]);
2478   }
2479 }
2480 
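// Sanity-checks every selected strategy: a shard factor that is larger than, or does not evenly divide, the
// corresponding input dimension is reset to 1 (e.g. a factor of 4 on a dimension of length 6 is dropped since
// 6 % 4 != 0). FillV2 is always skipped, Assign/OneHot are skipped under dyn_shape_tmp_fix, and dimensions
// reported as -1 under dyn_shape_tmp_fix keep their original factor. Any repaired strategy is logged.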
2481 void RecStrategyPropagator::FixInvalidStra() {
2482   for (auto &op : ops_) {
2483     bool modified = false;
2484     if (!HasStrategy(op)) {
2485       continue;
2486     }
2487     if (op->type() == FILLV2) {
2488       continue;
2489     }
2490     if (graph_->dyn_shape_tmp_fix && (op->type() == ASSIGN || op->type() == ONEHOT)) {
2491       continue;
2492     }
2493     StrategyPtr old_strategys = op->selected_strategy();
2494     Strategies new_strategys;
2495     for (size_t iter_op_inputs = 0; iter_op_inputs < old_strategys->GetInputDim().size(); iter_op_inputs++) {
2496       Dimensions strategies;
2497       for (size_t iter_op_input_stra = 0; iter_op_input_stra < op->inputs_shape()[iter_op_inputs].size();
2498            iter_op_input_stra++) {
2499         if (graph_->dyn_shape_tmp_fix && op->inputs_shape()[iter_op_inputs][iter_op_input_stra] == -1) {
2500           strategies.push_back(old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra]);
2501           continue;
2502         }
2503         if (op->inputs_shape()[iter_op_inputs][iter_op_input_stra] <
2504               old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra] ||
2505             op->inputs_shape()[iter_op_inputs][iter_op_input_stra] %
2506                 old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra] !=
2507               0) {
2508           strategies.push_back(1);
2509           modified = true;
2510         } else {
2511           strategies.push_back(old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra]);
2512         }
2513       }
2514       new_strategys.push_back(strategies);
2515     }
2516     if (modified) {
2517       StrategyPtr sp = std::make_shared<Strategy>(0, new_strategys);
2518       op->SetSelectedStrategyAndCost(sp, op->selected_cost());
2519       MS_LOG(INFO) << "CHANGE INVALID STRATEGY FOR : " << op->name() << " from " << old_strategys->GetInputDim()
2520                    << " to " << StrategyToString(new_strategys);
2521     }
2522   }
2523 }
2524 
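// In predict/eval mode (is_training_ == false) the last node and VirtualDataset operators are reset to their
// raw strategy; this runs at the end of both GenerateStrategyV1() and GenerateStrategyV3().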
2525 void RecStrategyPropagator::AjustToNoTraining() {
2526   for (auto &op : ops_) {
2527     // Set back to raw strategy for special node in predict/eval
2528     if (!is_training_) {
2529       if ((op->is_last_node()) || (op->type() == VIRTUAL_DATA_SET)) {
2530         SetBackToRawStrategy(op);
2531       }
2532     }
2533   }
2534 }
2535 
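// Single-pass propagation pipeline: seed strategies from the partitioned operators, propagate them forward and
// then backward through the eliminated operators, fill in the remaining operators, then apply the shared
// parameter strategy, repair invalid shard factors and adjust for inference. Under dyn_shape_tmp_fix, Assign
// operators additionally get dimension 1 of their first input sharded across all stage devices when that
// dimension is static and divisible (note: the code below assumes that input has at least two dimensions).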
2536 void RecStrategyPropagator::GenerateStrategyV1() {
2537   MS_EXCEPTION_IF_NULL(graph_);
2538   MS_EXCEPTION_IF_NULL(eli_list_);
2539   MS_EXCEPTION_IF_NULL(index_list_);
2540 
2541   no_stra_op_list_ = std::make_shared<std::vector<size_t>>();
2542   for (size_t i = eli_list_->size(); i > 0; i--) {
2543     no_stra_op_list_->push_back(eli_list_->at(i - 1)[0]);
2544   }
2545 
2546   size_t changes;
2547   changes = CopyMainOperatorsStrategy();
2548   MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after CopyMainOperatorsStrategy.";
2549 
2550   changes = GenerateEliminatedOperatorStrategyForward();
2551   MS_LOG(INFO) << "The strategies of " << changes
2552                << " operators are modified after GenerateEliminatedOperatorStrategyForward.";
2553 
2554   changes = GenerateEliminatedOperatorStrategyBackward();
2555   MS_LOG(INFO) << "The strategies of " << changes
2556                << " operators are modified after GenerateEliminatedOperatorStrategyBackward.";
2557 
2558   changes = GenerateRemainingOperatorStrategy();
2559   MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after GenerateRemainingOperatorStrategy.";
2560 
2561   if (graph_->dyn_shape_tmp_fix) {
2562     for (auto &op : ops_) {
2563       if (op->type() == ASSIGN) {
2564         Strategies strategies;
2565         auto assign_input_0_shape = op->inputs_shape()[0];
2566         Dimensions assign_input_0_strategy(assign_input_0_shape.size(), 1);
2567         size_t num_device = LongToSize(g_device_manager->stage_device_num());
2568         if (assign_input_0_shape[1] > 0 && assign_input_0_shape[1] % num_device == 0) {
2569           assign_input_0_strategy[1] = num_device;
2570         }
2571         for (size_t i = 0; i < op->inputs_shape().size(); i++) {
2572           strategies.push_back(assign_input_0_strategy);
2573         }
2574         StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2575         op->SetSelectedStrategyAndCost(sp, op->selected_cost());
2576       }
2577     }
2578   }
2579 
2580   SetParamStrategy();
2581   changes = ApplyParamStrategy();
2582   MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after ApplyParamStrategy.";
2583 
2584   FixInvalidStra();
2585   AjustToNoTraining();
2586 }
2587 
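// Operators whose OperatorInfo name (the part before "Info") marks them as stand-alone receive an all-ones
// strategy from PrepareStandAlone(); batch-parallel operators get the first dimension of every split-flagged
// input sharded across stage_device_num(), e.g. with 8 devices (illustrative count) a splittable input of
// shape [32, 10] is assigned {8, 1}. Handled operators are removed from no_stra_op_list_.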
2588 size_t RecStrategyPropagator::AssignStandaloneAndBatchParallelOpStrategy() {
2589   size_t changes = 0;
2590   for (size_t iter_ops = 0; iter_ops < ops_.size(); iter_ops++) {
2591     auto pos = ops_[iter_ops]->name().find("Info");
2592     auto name = ops_[iter_ops]->name().substr(0, pos);
2593     if (name == STAND_ALONE) {
2594       Strategies strategies = PrepareStandAlone(ops_[iter_ops]);
2595       ApplyStrategy(iter_ops, strategies);
2596       changes++;
2597       MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategy " << StrategyToString(strategies);
2598       auto iter = find(no_stra_op_list_->begin(), no_stra_op_list_->end(), iter_ops);
2599       if (iter != no_stra_op_list_->end()) {
2600         no_stra_op_list_->erase(iter);
2601       }
2602     }
2603     if (name == BATCH_PARALLEL) {
2604       Strategies strategies;
2605       auto split_flag_list = ops_[iter_ops]->split_flag_list();
2606       auto inputs_shape = ops_[iter_ops]->inputs_shape();
2607       for (size_t i = 0; i < inputs_shape.size(); i++) {
2608         Shape temp(inputs_shape[i].size(), 1);
2609         if (split_flag_list[i]) {
2610           temp[0] = g_device_manager->stage_device_num();
2611         }
2612         strategies.push_back(temp);
2613       }
2614       ApplyStrategy(iter_ops, strategies);
2615       changes++;
2616       MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategy " << StrategyToString(strategies);
2617       auto iter = find(no_stra_op_list_->begin(), no_stra_op_list_->end(), iter_ops);
2618       if (iter != no_stra_op_list_->end()) {
2619         no_stra_op_list_->erase(iter);
2620       }
2621     }
2622   }
2623   return changes;
2624 }
2625 
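// Returns how much sharding capacity of a MatMul is still unused: the shard count of input tensor 0, capped at
// num_device, divided by the shard count of the output. Worked example: with num_device = 8, input 0 sharded
// 4-way (str_h = 0.25, str_w = 1) and the output sharded 2-way, the factor is min(4, 8) / 2 = 2.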
2626 static size_t CalMatmulBatchDimFactor(size_t num_device, const StrategyRec &str) {
2627   size_t max_shard_num = FloatToSize(1 / str.inputTensor[0].str_h) * FloatToSize(1 / str.inputTensor[0].str_w);
2628   max_shard_num = max_shard_num < num_device ? max_shard_num : num_device;
2629   return max_shard_num / (FloatToSize(1 / str.outputTensor.str_h) * FloatToSize(1 / str.outputTensor.str_w));
2630 }
2631 
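// For every MatMul still present in the recursive-partition graph, if CalMatmulBatchDimFactor() reports unused
// capacity the output's row dimension (str_h) is sharded further by that factor, and the resulting output
// strategy is attached to the operator through set_out_strategy().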
2632 void RecStrategyPropagator::ExtraShardMatmulOnBatchDim() {
2633   MS_EXCEPTION_IF_NULL(graph_);
2634   MS_EXCEPTION_IF_NULL(eli_list_);
2635   MS_EXCEPTION_IF_NULL(index_list_);
2636 
2637   for (size_t i_op = 0; i_op < static_cast<size_t>(index_list_->size()); i_op++) {
2638     size_t iter_graph = index_list_->at(i_op);
2639     if (iter_graph == SIZE_MAX || ops_[i_op]->type() != MATMUL) {
2640       continue;
2641     }
2642     Graph::NodeType &node = graph_->nodes[iter_graph];
2643     size_t matmulBatchDimFactor = CalMatmulBatchDimFactor(g_device_manager->stage_device_num(), node.apply.str);
2644     if (matmulBatchDimFactor > 1) {
2645       MS_LOG(INFO) << ops_[i_op]->name() << " matmulBatchDimFactor " << matmulBatchDimFactor;
2646       node.apply.str.outputTensor.str_h /= matmulBatchDimFactor;
2647       node.tensor_parm.tensor_str.str_h = node.apply.str.outputTensor.str_h;
2648 
2649       Strategies strategies;
2650       Dimensions strategy;
2651       strategy.push_back(static_cast<int64_t>(1.0 / node.apply.str.outputTensor.str_h));
2652       strategy.push_back(static_cast<int64_t>(1.0 / node.apply.str.outputTensor.str_w));
2653       strategies.push_back(strategy);
2654 
2655       int64_t stage_id = g_device_manager->stage_id();
2656       StrategyPtr strategyPtr = NewStrategy(stage_id, strategies);
2657       ops_[i_op]->set_out_strategy(strategyPtr);
2658     }
2659   }
2660 }
2661 
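// Iterative variant of the propagation pipeline: after seeding from the partitioned operators and handling
// stand-alone/batch-parallel operators, the forward and backward propagation passes are repeated while the
// minimum device count is halved (divided by SIZE_TWO) down to 2, progressively relaxing the device requirement
// passed to them; the pass then fills the remaining operators, reconciles parameter-sharing operators and runs
// the common FixInvalidStra()/AjustToNoTraining() cleanup.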
2662 void RecStrategyPropagator::GenerateStrategyV3() {
2663   MS_EXCEPTION_IF_NULL(graph_);
2664   MS_EXCEPTION_IF_NULL(eli_list_);
2665   MS_EXCEPTION_IF_NULL(index_list_);
2666 
2667   GenerateNoStraList();
2668   size_t changes;
2669   changes = CopyMainOperatorsStrategy();
2670   MS_LOG(INFO) << "CopyMainOperatorsStrategy has " << changes << " changes";
2671   AssignStandaloneAndBatchParallelOpStrategy();
2672 
2673   for (auto min_devices = g_device_manager->stage_device_num(); min_devices > 1; min_devices /= SIZE_TWO) {
2674     size_t pass_changes = 1;
2675     while (pass_changes > 0) {
2676       pass_changes = 0;
2677 
2678       changes = GenerateEliminatedOperatorStrategyForward(min_devices);
2679       MS_LOG(INFO) << "GenerateEliminatedOperatorStrategyForward has " << changes << " changes";
2680 
2681       pass_changes += changes;
2682       if (changes > 0) continue;
2683 
2684       changes = GenerateEliminatedOperatorStrategyBackward(min_devices);
2685       MS_LOG(INFO) << "GenerateEliminatedOperatorStrategyBackward has " << changes << " changes";
2686 
2687       pass_changes += changes;
2688       if (changes > 0) continue;
2689     }
2690   }
2691 
2692   changes = GenerateRemainingOperatorStrategy();
2693   MS_LOG(INFO) << "GenerateRemainingOperatorStrategy has " << changes << " changes";
2694 
2695   changes = ModifyParamSharingOpsStrategy();
2696   MS_LOG(INFO) << "ModifyParamSharingOpsStrategy has " << changes << " changes";
2697 
2698   FixInvalidStra();
2699   AjustToNoTraining();
2700 }
2701 }  // namespace parallel
2702 }  // namespace mindspore
2703