/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <set>
#include <vector>

#include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
#include "frontend/parallel/ops_info/flash_attention_score_info.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/ops_info/strided_slice_info.h"
#include "frontend/parallel/ops_info/gather_info.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/strategy.h"
#include "include/common/utils/utils.h"
#include "ir/value.h"
#include "mindspore/core/ops/ops_func_impl/flash_attention_score.h"
#include "ops/op_enum.h"

namespace mindspore {
namespace parallel {
namespace {
using PrepareStraFuncPtr = Strategies (*)(const std::shared_ptr<OperatorInfo> &, Dimensions, bool);
std::map<std::string, PrepareStraFuncPtr> g_prepare_stra_map;

std::optional<bool> GetKeepDimsFromAttrs(const std::shared_ptr<OperatorInfo> &op) {
  auto keep_dims_iter = op->attrs().find(KEEP_DIMS);
  if (keep_dims_iter == op->attrs().end()) {
    return std::nullopt;
  }
  auto keep_dims_ptr = keep_dims_iter->second;
  MS_EXCEPTION_IF_NULL(keep_dims_ptr);
  if (!keep_dims_ptr->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << op->name() << ": Keep_dims is not a bool.";
  }
  auto keepdims = keep_dims_ptr->cast<BoolImmPtr>()->value();
  return keepdims;
}

std::optional<bool> GetKeepDimsFromInputs(const std::shared_ptr<OperatorInfo> &op) {
  auto keep_dims_opt = GetScalarValueFromInputs<bool>(op->input_value(), op->name(), KEEP_DIMS);
  return keep_dims_opt;
}

bool GetKeepDims(const std::shared_ptr<OperatorInfo> &op) {
  auto keep_dims_opt = GetKeepDimsFromAttrs(op);
  if (!keep_dims_opt.has_value()) {
    keep_dims_opt = GetKeepDimsFromInputs(op);
  }
  if (!keep_dims_opt.has_value()) {
    MS_LOG(EXCEPTION) << op->name() << ": Don't have attr keep_dims.";
  }
  auto keepdims = keep_dims_opt.value();
  return keepdims;
}

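// Collects the reduction axes of a reduce-like operator as non-negative indices.
// Returns an empty list when keep_dims is true (no dimension disappears); when the axis input is an
// empty tuple, every dimension of the first input is treated as reduced.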
Dimensions GetDimList(const std::shared_ptr<OperatorInfo> &op) {
  Dimensions dim_list;
  bool keep_dims = GetKeepDims(op);
  if (keep_dims) {
    return dim_list;
  }

  const auto &name = op->name();
  auto dim_list_opt = GetArrayValueFromInputs<int64_t>(op->input_value(), name, AXIS);
  if (!dim_list_opt.has_value()) {
    MS_LOG(EXCEPTION) << "For " << name << ", failed to get value for " << AXIS << ".";
  }

  dim_list = dim_list_opt.value();
  auto x_dim = op->inputs_shape()[0].size();
  // axis is (), reduce all dim
  if (dim_list.empty()) {
    for (size_t i = 0; i < x_dim; ++i) {
      dim_list.push_back(SizeToLong(i));
    }
  } else {
    auto AxisCorrectFunc = [x_dim](const int64_t axis) {
      if (axis < 0) {
        return axis + SizeToLong(x_dim);
      }
      return axis;
    };
    std::transform(dim_list.begin(), dim_list.end(), dim_list.begin(), AxisCorrectFunc);
  }
  return dim_list;
}
}  // namespace

size_t OpNameToId(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::shared_ptr<OperatorInfo> &op) {
  for (size_t i = 0; i < ops.size(); ++i) {
    if (ops[i]->name() == op->name()) {
      return i;
    }
  }

  return SIZE_MAX;
}

bool IsDimensionsFlat(const Dimensions &dims) {
  return !std::any_of(dims.begin(), dims.end(), [](const int64_t &dim) { return dim != 1; });
}

bool IsDimensionsEmpty(const Dimensions &dims) { return dims.empty(); }

bool IsStrategyFlat(const StrategyPtr &str) {
  const auto &input_dims = str->GetInputDim();
  return !std::any_of(input_dims.begin(), input_dims.end(),
                      [](const Dimensions &dims) { return !IsDimensionsFlat(dims); });
}

size_t DevicesForDimensions(const Dimensions &dims) {
  return std::accumulate(dims.begin(), dims.end(), static_cast<size_t>(1), std::multiplies<size_t>());
}

bool HasStrategy(std::shared_ptr<OperatorInfo> op) {
  StrategyPtr s_strategy = op->selected_strategy();
  if (s_strategy != nullptr && !s_strategy->ToString().empty()) {
    return true;
  }
  return false;
}

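// Scans the inputs of the operator at iter_ops and returns the index of the producer operator preferred
// as a strategy source: the first producer that already carries a non-flat strategy, otherwise the last
// producer found. Returns SIZE_MAX when no producer is found or the chosen producer is the virtual dataset.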
size_t FindIndexOfOperatorIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                   const std::vector<std::vector<std::string>> &input_tensor_names, size_t iter_ops) {
  size_t incoming_op_index = SIZE_MAX;
  for (size_t i = 1; i < input_tensor_names[iter_ops].size(); i++) {
    for (size_t j = 0; j < input_tensor_names.size(); j++) {
      if (input_tensor_names[iter_ops][i] == input_tensor_names[j][0]) {
        incoming_op_index = j;
        break;
      }
    }
    if (incoming_op_index != SIZE_MAX && HasStrategy(ops.at(incoming_op_index)) &&
        !IsStrategyFlat(ops.at(incoming_op_index)->selected_strategy())) {
      break;
    }
  }
  if (incoming_op_index != SIZE_MAX &&
      ops.at(incoming_op_index)->name().find(VIRTUALDATASETINFO) != std::string::npos) {
    return SIZE_MAX;
  }
  return incoming_op_index;
}

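// Finds the first consumer of the operator at iter_ops that already has a selected strategy, together with
// the position of the consumed tensor among that consumer's inputs (clamped to its input count).
// Returns {SIZE_MAX, SIZE_MAX} when no such consumer exists.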
std::pair<size_t, size_t> FindIndexOfOperatorOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const std::vector<std::vector<std::string>> &input_tensor_names,
                                                      const size_t iter_ops) {
  bool found = false;
  size_t outgoing_op_index = SIZE_MAX;
  size_t iter_op_inputs = SIZE_MAX;

  for (size_t i = 0; i < input_tensor_names.size(); i++) {
    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
          ops[i]->selected_strategy()->GetInputNumber() != 0) {
        outgoing_op_index = i;
        iter_op_inputs = std::min(j - 1, ops[outgoing_op_index]->inputs_shape().size() - 1);
        found = true;
        break;
      }
    }
    if (found) {
      break;
    }
  }

  std::pair<size_t, size_t> res = std::make_pair(outgoing_op_index, iter_op_inputs);

  return res;
}

int64_t GetGatherAxis(const std::shared_ptr<OperatorInfo> &op) {
  auto axis_input = GetValue<int64_t>(op->input_value().at(2));
  if (axis_input < 0) {
    axis_input += SizeToLong(op->inputs_shape()[0].size());
  }
  if (axis_input >= SizeToLong(op->inputs_shape()[0].size())) {
    MS_LOG(EXCEPTION) << "Failure: Gather's axis out of range.";
  }
  return axis_input;
}

int64_t GetGatherBatchDims(const std::shared_ptr<OperatorInfo> &op) {
  int64_t batch_dims = -1;
  auto batch_dims_val = GetScalarValueFromInputs<int64_t>(op->input_value(), op->name(), BATCH_DIMS);
  if (batch_dims_val.has_value()) {
    batch_dims = batch_dims_val.value();
  } else {
    MS_LOG(EXCEPTION) << op->name() << ": Failed to fetch the value of batch dims";
  }
  return batch_dims;
}

void ReverseRemainingList(const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  MS_LOG(INFO) << "ReverseRemainingList";
  std::reverse(no_stra_op_list->begin(), no_stra_op_list->end());
}

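// Entry point of strategy generation for the recursive-programming algorithm: builds a
// RecStrategyPropagator, optionally adds extra MatMul batch-dimension sharding on large clusters,
// then runs the training flow (GenerateStrategyV3) or the inference flow (GenerateStrategyV1).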
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                      const std::vector<std::vector<std::string>> &input_tensor_names,
                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
                      const std::vector<std::vector<size_t>> &param_users_ops_index, const FuncGraphPtr &root) {
  RecStrategyPropagator propagator(graph, ops, eli_list, input_tensor_names, index_list, is_training,
                                   param_users_ops_index, root);

  if (g_device_manager->DeviceNum() > SIZE_THIRTY_TWO) {
    propagator.ExtraShardMatmulOnBatchDim();
  }
  if (is_training) {
    propagator.GenerateStrategyV3();
  } else {
    propagator.GenerateStrategyV1();
  }
}

void FillFlashLayoutIndexes(const std::shared_ptr<FlashAttentionScoreInfo> &flashOp, size_t *batch_split_idx,
                            size_t *n_split_idx, size_t *s_split_idx) {
  MS_EXCEPTION_IF_NULL(flashOp);
  MS_EXCEPTION_IF_NULL(batch_split_idx);
  MS_EXCEPTION_IF_NULL(n_split_idx);
  MS_EXCEPTION_IF_NULL(s_split_idx);

  size_t tmp_batch_split_idx;
  size_t tmp_n_split_idx;
  size_t tmp_s_split_idx;

  using mindspore::ops::FASInputLayoutMode;
  switch (flashOp->input_layout()) {
    case FASInputLayoutMode::BSH:
    case FASInputLayoutMode::BSND:
      tmp_batch_split_idx = kIndex0;
      tmp_s_split_idx = kIndex1;
      tmp_n_split_idx = kIndex2;
      break;
    case FASInputLayoutMode::BNSD:
      tmp_batch_split_idx = kIndex0;
      tmp_n_split_idx = kIndex1;
      tmp_s_split_idx = kIndex2;
      break;
    case FASInputLayoutMode::SBH:
      tmp_s_split_idx = kIndex0;
      tmp_batch_split_idx = kIndex1;
      tmp_n_split_idx = kIndex2;
      break;
    default:
      MS_LOG(EXCEPTION) << flashOp->name() << ": unknown input_layout: " << flashOp->input_layout();
  }

  *batch_split_idx = tmp_batch_split_idx;
  *n_split_idx = tmp_n_split_idx;
  *s_split_idx = tmp_s_split_idx;
}

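// Expands a basic (query-like) strategy into per-input strategies for FlashAttentionScore, mapping the
// batch/N/S split factors onto query, key/value and the optional shift, drop and attention mask inputs
// according to the input layout; unused entries are erased at the end.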
Strategies PrepareFlashAttentionScore(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
                                      bool dyn_shape_tmp_fix) {
  std::shared_ptr<FlashAttentionScoreInfo> flashOp = std::static_pointer_cast<FlashAttentionScoreInfo>(op);

  if (flashOp->InitAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << flashOp->name() << " : InitAttrs failed.";
  }

  Strategies expect_strategies = Strategies(ops::kFlashAttentionScoreInputsNum);
  auto is_input_passed = flashOp->is_input_passed();

  size_t batch_idx;
  size_t n_split_idx;
  size_t s_split_idx;

  FillFlashLayoutIndexes(flashOp, &batch_idx, &n_split_idx, &s_split_idx);

  int64_t batch_split_num = basic_stra[batch_idx];
  int64_t s1_split_num = basic_stra[s_split_idx];
  int64_t n1_split_num = basic_stra[n_split_idx];
  int64_t n2_split_num = flashOp->kv_split() ? n1_split_num : 1;

  Dimensions q_stra(op->inputs_shape()[ops::kFlashAttentionScoreInputQueryIndex].size(), 1);
  q_stra[batch_idx] = batch_split_num;
  q_stra[s_split_idx] = s1_split_num;
  q_stra[n_split_idx] = n1_split_num;

  Dimensions kv_stra(op->inputs_shape()[ops::kFlashAttentionScoreInputKeyIndex].size(), 1);
  kv_stra[batch_idx] = batch_split_num;
  kv_stra[n_split_idx] = n2_split_num;

  expect_strategies[ops::kFlashAttentionScoreInputQueryIndex] = q_stra;
  expect_strategies[ops::kFlashAttentionScoreInputKeyIndex] = kv_stra;
  expect_strategies[ops::kFlashAttentionScoreInputValueIndex] = kv_stra;

  if (is_input_passed[ops::kFlashAttentionScoreInputRealShiftIndex]) {
    int64_t real_shift_s1_split_num = flashOp->real_shift_have_s1_dim() ? s1_split_num : 1;
    int64_t real_shift_batch_split_num = flashOp->real_shift_have_batch_dim() ? batch_split_num : 1;
    expect_strategies[ops::kFlashAttentionScoreInputRealShiftIndex] = {real_shift_batch_split_num, n1_split_num,
                                                                       real_shift_s1_split_num, 1};
  }

  if (is_input_passed[ops::kFlashAttentionScoreInputDropMaskIndex]) {
    expect_strategies[ops::kFlashAttentionScoreInputDropMaskIndex] = {batch_split_num, n1_split_num, s1_split_num, 1};
  }

  if (is_input_passed[ops::kFlashAttentionScoreInputPaddingMaskIndex]) {
    expect_strategies[ops::kFlashAttentionScoreInputPaddingMaskIndex] = {};
  }

  if (is_input_passed[ops::kFlashAttentionScoreInputAttnMaskIndex]) {
    auto attn_mask_shape =
      flashOp->inputs_shape().at(flashOp->GetStrategyRealIndex(ops::kFlashAttentionScoreInputAttnMaskIndex));
    int64_t s1_split_num_attn_mask = flashOp->is_attn_mask_compressed() ? 1 : s1_split_num;
    if (attn_mask_shape.size() == kSizeTwo) {
      // attn_mask_shape: (S1, S2)
      expect_strategies[ops::kFlashAttentionScoreInputAttnMaskIndex] = {s1_split_num_attn_mask, 1};
    } else if (attn_mask_shape.size() == kSizeFour) {
      // attn_mask_shape: (B, N1, S1, S2) or (B, 1, S1, S2)
      auto attn_mask_n1_split_num = flashOp->attn_mask_have_n1_dim() ? n1_split_num : 1;
      auto attn_batch_split_num = flashOp->attn_mask_have_batch_dim() ? batch_split_num : 1;
      expect_strategies[ops::kFlashAttentionScoreInputAttnMaskIndex] = {attn_batch_split_num, attn_mask_n1_split_num,
                                                                        s1_split_num_attn_mask, 1};
    }
  }

  if (is_input_passed[ops::kFlashAttentionScoreInputPrefixIndex]) {
    expect_strategies[ops::kFlashAttentionScoreInputPrefixIndex] = {batch_split_num};
  }

  if (is_input_passed[ops::kFlashAttentionScoreInputActualSeqQlenIndex]) {
    expect_strategies[ops::kFlashAttentionScoreInputActualSeqQlenIndex] = {NO_SPLIT_STRATEGY};
  }
  if (is_input_passed[ops::kFlashAttentionScoreInputActualSeqKVlenIndex]) {
    expect_strategies[ops::kFlashAttentionScoreInputActualSeqKVlenIndex] = {NO_SPLIT_STRATEGY};
  }

  expect_strategies.erase(std::remove(expect_strategies.begin(), expect_strategies.end(), Shape{}),
                          expect_strategies.end());
  return expect_strategies;
}

Strategies PrepareFillV2(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  Strategies strategies;

  if (op->outputs_shape().size() == 0) {
    MS_LOG(EXCEPTION) << op->name() << " output tensor info is empty.";
  }

  for (size_t i = basic_stra.size(); i < op->outputs_shape()[0].size(); i++) {
    basic_stra.push_back(1);
  }

  strategies.push_back(basic_stra);
  basic_stra.clear();
  strategies.push_back(basic_stra);
  return strategies;
}

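// Builds a MatMul input strategy from the partitioned tensor layout stored in the rec graph node. The rec
// solver records each cut as a fraction in tensor_str (e.g. 0.25 means the dimension is split into 4 parts),
// so the partition number is recovered as 1 / tensor_str; for a transposed operand the (h, w) axes swap.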
Dimensions PrepareMatMulStrategy(Graph::NodeType *node, bool transpose_a, bool transpose_b, size_t iter_op_inputs) {
  Dimensions strategy;
  if (transpose_a && (iter_op_inputs == 0)) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else if (transpose_b && (iter_op_inputs == 1)) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
  }
  return strategy;
}

Strategies PrepareMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op) {
  Strategies strategies;
  auto input_value = op->input_value();
  bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
  bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();

  for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
    Dimensions strategy = PrepareMatMulStrategy(node, transpose_a, transpose_b, iter_op_inputs);
    strategies.push_back(strategy);
  }
  return strategies;
}

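// Propagates a strategy given for the BatchMatMul output back onto its two inputs: batch dimensions are
// copied, the contracted (k) axis is left unsplit, and the i/j axes take the last two entries of the
// output strategy, honoring transpose_a/transpose_b.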
Strategies PreparePropagateBatchMatMul(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
                                       bool dyn_shape_tmp_fix) {
  if (dyn_shape_tmp_fix) {
    return CheckDivisible(op, basic_stra);
  }
  // This backward propagation does NOT complete strategy on k. Could be done later
  Strategies stra;
  auto input_value = op->input_value();
  bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
  bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();

  size_t first_input_size = op->inputs_shape()[0].size();
  size_t second_input_size = op->inputs_shape()[1].size();

  Dimensions first_input_dim(first_input_size);
  Dimensions second_input_dim(second_input_size);

  // first input
  if (!transpose_a) {
    first_input_dim[first_input_size - 1] = 1;                                  // k axis
    first_input_dim[first_input_size - 2] = basic_stra[basic_stra.size() - 2];  // i axis
  } else {
    first_input_dim[first_input_size - 2] = 1;                                  // k axis
    first_input_dim[first_input_size - 1] = basic_stra[basic_stra.size() - 2];  // i axis
  }

  for (size_t idx = 3; idx <= first_input_size; idx++) {
    first_input_dim[first_input_size - idx] = basic_stra[basic_stra.size() - idx];
  }

  // second input
  if (!transpose_b) {
    second_input_dim[second_input_size - 2] = 1;                                  // k axis
    second_input_dim[second_input_size - 1] = basic_stra[basic_stra.size() - 1];  // j axis
  } else {
    second_input_dim[second_input_size - 1] = 1;                                  // k axis
    second_input_dim[second_input_size - 2] = basic_stra[basic_stra.size() - 1];  // j axis
  }

  for (size_t idx = 3; idx <= second_input_size; idx++) {
    second_input_dim[second_input_size - idx] = basic_stra[basic_stra.size() - idx];
  }

  stra.push_back(first_input_dim);
  stra.push_back(second_input_dim);
  return stra;
}

Dimensions PrepareBatchMatMulStrategy(Graph::NodeType *node, const bool transpose_a, const bool transpose_b,
                                      const size_t iter_op_inputs, const size_t dim_num) {
  if (node->apply.arguments[iter_op_inputs].tensor_str.str_n == 0 ||
      node->apply.arguments[iter_op_inputs].tensor_str.str_c == 0 ||
      node->apply.arguments[iter_op_inputs].tensor_str.str_h == 0 ||
      node->apply.arguments[iter_op_inputs].tensor_str.str_w == 0) {
    MS_LOG(EXCEPTION) << "The strategy is 0";
  }

  Dimensions strategy;
  if (dim_num >= SIZE_FOUR) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_n));
  }
  if (dim_num >= SIZE_THREE) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
  }
  if (transpose_a && (iter_op_inputs == 0)) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else if (transpose_b && (iter_op_inputs == 1)) {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
  } else {
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
    strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
  }
  return strategy;
}

Strategies PrepareBatchMatMul(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op) {
  Strategies strategies;
  auto input_value = op->input_value();
  bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
  bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();

  for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
    Dimensions strategy = PrepareBatchMatMulStrategy(node, transpose_a, transpose_b, iter_op_inputs,
                                                     op->inputs_shape()[iter_op_inputs].size());
    strategies.push_back(strategy);
  }
  return strategies;
}

Strategies PrepareBiasAdd(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  auto strategy = std::make_shared<Dimensions>(basic_stra);
  Strategies strategies;
  strategies.push_back(*strategy);
  Dimensions s_biasadd;
  s_biasadd.push_back(strategy->at(1));
  strategies.push_back(s_biasadd);
  return strategies;
}

Strategies PrepareStandAlone(const std::shared_ptr<OperatorInfo> &op) {
  Strategies strategies;
  Dimensions strategy;

  for (size_t i = 0; i < op->outputs_tensor_info().size(); i++) {
    strategy.clear();
    for (size_t j = 0; j < op->inputs_tensor_info()[i].shape().size(); j++) {
      strategy.push_back(1);
    }
    strategies.push_back(strategy);
  }

  return strategies;
}

Strategies PrepareDataParallel(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  size_t numDev = g_device_manager->stage_device_num();

  Strategies strategies;
  Dimensions strategy;

  if (numDev == 0) {
    MS_LOG(EXCEPTION) << "The number of devices is 0";
  }

  for (size_t i = 0; i < op->inputs_shape().size(); i++) {
    strategy.clear();
    if (LongToSize(op->inputs_shape()[i][0]) % numDev == 0) {
      strategy.push_back(numDev);
    } else {
      strategy.push_back(1);
    }
    for (size_t j = 1; j < op->inputs_shape()[i].size(); j++) {
      strategy.push_back(1);
    }
    strategies.push_back(strategy);
  }

  return strategies;
}

Dimensions PrepareOneHotOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
  auto op_strategy = op->selected_strategy();
  Dimensions strategy;

  for (size_t i = 0; i < static_cast<size_t>(op->inputs_shape().size()); i++) {
    if (op->inputs_shape()[i].size() == 0) {
      continue;
    }
    // copy the full strategy (Assume strategy has the same size as the following operator input shape)
    for (size_t j = 0; j < op_strategy->GetInputDim().at(i).size(); ++j) {
      strategy.push_back(op_strategy->GetInputDim().at(i).at(j));
    }
    break;
  }
  return strategy;
}

Strategies PrepareStridedSlice(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  Strategies strategies;

  if (dyn_shape_tmp_fix) {
    return strategies;
  }

  auto strided_slice = std::static_pointer_cast<StridedSliceInfo>(op);
  strided_slice->GetAttrs();
  auto begin = strided_slice->begin();
  auto strides = strided_slice->strides();
  auto new_axis_mask_bitmap = strided_slice->new_axis_mask_bitmap();
  auto fully_fetch_flag = strided_slice->fully_fetch_flag();
  auto skip_redistribution = strided_slice->skip_redistribution();

  Shape strategy_in_process = Shape(basic_stra.size(), 0);
  for (size_t i = 0; i < new_axis_mask_bitmap.size() && i < begin.size() && i < basic_stra.size(); ++i) {
    if (new_axis_mask_bitmap[i]) {
      strategy_in_process[i] = 1;
    }
  }

  size_t count = 0;
  for (auto &ele : strategy_in_process) {
    if (ele != 0) {
      continue;
    }
    ele = basic_stra[count];
    count++;
  }

  (void)strategy_in_process.insert(strategy_in_process.end(), basic_stra.begin() + count, basic_stra.end());
  MS_LOG(INFO) << op->name() << ": The strategy in process is " << strategy_in_process;

  for (size_t j = 0; j < strides.size(); ++j) {
    if ((strides[j] != 1) && (strategy_in_process[j] > 1)) {
      strategy_in_process[j] = 1;
    }
  }

  for (size_t k = 0; k < begin.size(); ++k) {
    if (!fully_fetch_flag[k] && (strategy_in_process[k] != 1) && !skip_redistribution) {
      strategy_in_process[k] = 1;
    }
  }

  strategies.push_back(strategy_in_process);
  return strategies;
}

std::vector<int64_t> FindAxisProperty(const std::shared_ptr<OperatorInfo> &op) {
  std::vector<int64_t> axis_list;
  string axis_name = AXIS;
  auto input_value = op->input_value();

  auto op_name = op->name();

  if (input_value[input_value.size() - 1]->isa<ValueSequence>()) {  // Softmax axis is a tuple
    std::optional<std::vector<int64_t>> axis_opt = GetArrayValueFromInputs<int64_t>(input_value, op_name, axis_name);
    if (axis_opt.has_value()) {
      std::vector<int64_t> axis_val = axis_opt.value();
      axis_list.swap(axis_val);
    } else {
      axis_list.push_back(-1);
    }
  } else {  // LogSoftmax axis is a scalar
    std::optional<int64_t> axis_opt = GetScalarValueFromInputs<int64_t>(input_value, op_name, axis_name);
    if (axis_opt.has_value()) {
      axis_list.push_back(axis_opt.value());
    } else {
      axis_list.push_back(-1);
    }
  }
  return axis_list;
}

Strategies PrepareSoftMax(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  Strategies strategies;
  strategies.push_back(basic_stra);
  std::vector<int64_t> axis_list = FindAxisProperty(op);

  for (auto &axis : axis_list) {
    if (axis < 0) {
      int64_t input_dim = SizeToLong(op->inputs_shape()[0].size());
      axis = input_dim + axis;
    }
    if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
      MS_LOG(EXCEPTION) << op->name() << ": axis value is out of range.";
    }
    if (strategies[0][LongToSize(axis)] != 1) {
      strategies[0][LongToSize(axis)] = 1;
      MS_LOG(INFO) << op->name() << ": adjust strategy to 1 on axis " << axis;
    }
  }

  // Strategy protection to avoid that partition number is larger than the shape of related dimension.
  for (size_t i = 0; i < op->inputs_shape().size(); i++) {
    for (size_t j = 0; j < op->inputs_shape()[i].size(); j++) {
      if (strategies[i][j] > op->inputs_shape()[i][j] || op->inputs_shape()[i][j] % strategies[i][j] != 0) {
        strategies[i][j] = 1;
      }
    }
  }

  return strategies;
}

Strategies PrepareLayerNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  Strategies strategies;
  strategies.push_back(basic_stra);
  std::vector<int64_t> axis_list;
  string axis_name = AXIS;

  auto iter = op->attrs().find(axis_name);
  if (iter != op->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
    } else if (iter->second->isa<ValueTuple>()) {
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(EXCEPTION) << op->name() << ": The value_tuple is nullptr.";
      }

      std::vector<ValuePtr> value_vector = value_tuple->value();
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
                           [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
    } else {
      MS_LOG(EXCEPTION) << op->name() << ": The value of axis is not int64_t or tuple int64_t.";
    }
  } else {
    axis_list.push_back(-1);
  }

  for (auto &axis : axis_list) {
    if (axis < 0) {
      int64_t input_dim = SizeToLong(op->inputs_shape()[0].size());
      axis = input_dim + axis;
    }
    if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
      MS_LOG(EXCEPTION) << op->name() << ": axis value is out of range.";
    }
    if (strategies[0][LongToSize(axis)] != 1) {
      strategies[0][LongToSize(axis)] = 1;
      MS_LOG(INFO) << op->name() << ": adjust strategy to 1 on axis " << axis;
    }
  }
  Dimensions d = {1};
  strategies.push_back(d);
  strategies.push_back(d);
  return strategies;
}

Strategies PrepareRmsNorm(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra, bool dyn_shape_tmp_fix) {
  Strategies strategies;
  auto inputs = op->inputs_shape();
  auto input = inputs[0];
  Shape strategy_in_process = Shape(input.size(), 1);
  int64_t devices = SizeToLong(g_device_manager->DeviceNum());

  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() == 0) {
    MS_LOG(EXCEPTION) << "divisors cannot be 0!";
  }
  int64_t max_cut = devices / parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
  strategy_in_process[0] = input[0] < max_cut ? input[0] : max_cut;

  auto gamma = inputs[1];
  size_t gamma_diff = input.size() - gamma.size();
  Dimensions gamma_strategy;
  for (size_t j = 0; j < gamma.size(); ++j) {
    gamma_strategy.push_back(strategy_in_process[gamma_diff + j]);
  }

  strategies.push_back(strategy_in_process);
  strategies.push_back(gamma_strategy);
  return strategies;
}

Strategies PrepareOneHot(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
  Strategies strategies;

  // OneHot's strategy depends on its output shape.
  for (size_t i = strategy.size(); i < op->outputs_shape()[0].size(); i++) {
    strategy.push_back(1);
  }

  // Partition number should not exceed the number of devices
  for (size_t i = 0; i < op->outputs_shape()[0].size(); i++) {
    if (strategy[i] > op->outputs_shape()[0][i]) {
      strategy[i] = 1;
    }
  }

  strategies.push_back(strategy);

  // Push two empty Dimensions for the other two input tensors.
  Dimensions s_empty = {};
  strategies.push_back(s_empty);
  strategies.push_back(s_empty);

  return strategies;
}

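// Derives a strategy for Gather from a target (output) shape: dimension 0 is considered first, then the
// remaining dimensions from largest to smallest; each is repeatedly halved while it stays divisible by 2,
// until the product of cuts reaches the number of devices in the stage.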
Dimensions GenGatherStra(Shape targeted_shape) {
  Dimensions index(targeted_shape.size() - 1, 0);
  for (size_t i = 0; i < index.size(); i++) {
    index[i] = SizeToLong(i);
  }

  std::sort(index.begin(), index.end(), [&targeted_shape](const size_t &a, const size_t &b) {
    return (targeted_shape[a + 1] > targeted_shape[b + 1]);
  });
  (void)std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
  (void)index.insert(index.cbegin(), 0);

  Dimensions strategie(targeted_shape.size(), 1);

  size_t num_device = LongToSize(g_device_manager->stage_device_num());
  size_t cut = 1;
  for (size_t i = 0; i < index.size(); i++) {
    size_t index_i = LongToSize(index[i]);
    while (targeted_shape[index_i] % SIZE_TWO == 0 && targeted_shape[index_i] > 0 && cut < num_device) {
      targeted_shape[index_i] /= SIZE_TWO;
      cut *= SIZE_TWO;
      strategie[index_i] *= SIZE_TWO;  // We apply 2-parts partitioning for Gather.
    }
    if (cut == num_device) {
      break;
    }
  }

  return strategie;
}

Strategies GatherForDynamicShape(const std::shared_ptr<OperatorInfo> &op, const size_t dim) {
  Strategies strategies;
  auto gather_input_0_shape = op->inputs_shape()[0];
  if (dim >= gather_input_0_shape.size()) {
    MS_LOG(EXCEPTION) << "Failure: Gather's axis out of range.";
  }
  Dimensions gather_input_0_strategy(gather_input_0_shape.size(), 1);
  int64_t num_device = g_device_manager->stage_device_num();
  if (gather_input_0_shape[dim] % num_device == 0) {
    size_t cut = 1;
    while (gather_input_0_shape[dim] > 0 && gather_input_0_shape[dim] % SIZE_TWO == 0 && cut < LongToSize(num_device)) {
      gather_input_0_shape[dim] /= SIZE_TWO;
      cut *= SIZE_TWO;
      gather_input_0_strategy[dim] *= SIZE_TWO;
    }
  }
  strategies.push_back(gather_input_0_strategy);
  for (size_t i = 1; i < op->inputs_shape().size(); i++) {
    Dimensions gather_input_i_strategy(op->inputs_shape()[i].size(), 1);
    strategies.push_back(gather_input_i_strategy);
  }
  return strategies;
}

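// Builds the Gather strategy from the desired output strategy. The param and indices strategies depend on
// the gather axis (only axis 0 and 1 are handled here) and on batch_dims; for sharded-axis gather modes the
// param strategy on the gather axis is reset to 1 whenever the indices themselves are split.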
Strategies PrepareGather(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
  if (dyn_shape_tmp_fix) {
    Strategies strategies;
    strategies.push_back(strategy);
    for (size_t i = 1; i < op->inputs_shape().size(); i++) {
      Dimensions gather_input_i_strategy(op->inputs_shape()[i].size(), 1);
      strategies.push_back(gather_input_i_strategy);
    }
    return strategies;
  }

  Strategies strategies;
  Shape targeted_shape = op->outputs_shape()[0];
  Dimensions strategie = GenGatherStra(targeted_shape);

  int64_t axis = GetGatherAxis(op);
  MS_LOG(INFO) << op->name() << ": the axis is " << axis;

  int64_t batch_dims = GetGatherBatchDims(op);
  MS_LOG(INFO) << op->name() << ": the batch_dims is " << batch_dims;

  if (batch_dims > 1) {
    for (size_t i = 0; i < op->inputs_shape().size(); i++) {
      strategies.push_back(strategie);
    }
    strategies[0][axis] = 1;
    return strategies;
  }

  strategy.clear();
  if (axis == 0) {
    Shape param_strategy = Shape(op->inputs_shape()[0].size(), 1);
    Shape indices_strategy = Shape(op->inputs_shape()[1].size(), 1);
    strategies.push_back(param_strategy);
    strategies.push_back(indices_strategy);
    size_t num_device = LongToSize(g_device_manager->stage_device_num());
    size_t cut = 1;
    int gather_inputs_num = SizeToInt(op->inputs_shape().size());
    for (int i = gather_inputs_num - 1; i >= 0; --i) {
      auto tensor_shape = op->inputs_shape()[i];
      while (tensor_shape[0] % SIZE_TWO == 0 && tensor_shape[0] > 0 && cut < num_device) {
        tensor_shape[0] /= SIZE_TWO;
        cut *= SIZE_TWO;
        strategies[i][0] *= SIZE_TWO;  // We apply 2-parts partitioning for Gather.
      }
      if (cut == num_device) {
        break;
      }
    }
  } else if (axis == 1) {
    strategy.push_back(strategie[0]);
    strategy.push_back(1);
    strategies.push_back(strategy);
    strategy.clear();
    for (size_t i = 0; i < op->inputs_shape()[1].size(); i++) {
      strategy.push_back(strategie[op->inputs_shape()[0].size() - 1 + i]);
    }
    strategies.push_back(strategy);
  } else {
    MS_LOG(EXCEPTION) << "Failure: Normal Gather's axis is neither 0 nor 1.";
  }

  auto gather = std::static_pointer_cast<GatherInfo>(op);
  auto gather_mode = gather->GetGatherMode(strategies[0], strategies[1]);
  MS_LOG(INFO) << op->name() << ": the gather_mode is " << gather_mode;
  if (gather_mode == SHARD_AXIS_0_DYNAMIC || gather_mode == SHARD_AXIS_0_STATIC || gather_mode == SHARD_AXIS_1) {
    if (DevicesForDimensions(strategies[1]) != 1 && strategies[0][axis] != 1) {
      strategies[0][axis] = 1;
      MS_LOG(INFO) << op->name() << ": param_strategy[" << axis << "] is changed to 1.";
    }
  }

  return strategies;
}

Dimensions PrepareGatherV2OutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
  auto targeted_shape = op->outputs_shape()[0];
  Dimensions strategie = GenGatherStra(targeted_shape);
  return strategie;
}

Strategies PrepareL2Normalize(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
  int64_t axis = 0;
  auto iter = op->attrs().find(AXIS);
  if (iter != op->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<ValueSequence>()) {
      axis = GetValue<std::vector<int64_t>>(iter->second)[0];
    } else {
      MS_LOG(EXCEPTION) << op->name() << " : The value of axis is not int64_t.";
    }
  }

  int64_t axis_index = axis;
  if (axis < 0) {
    size_t input_dim = op->inputs_shape()[0].size();
    axis_index = static_cast<int64_t>(input_dim) + axis;
  }

  strategy[LongToSize(axis_index)] = 1;

  Strategies strategies;
  strategies.push_back(strategy);
  return strategies;
}

Strategies PrepareAxisRelatedStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                      const size_t iter_ops) {
  Strategies strategies = MakeRecSearchStrategy(node, ops, iter_ops);
  if (strategies.size() < 1) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
  }

  std::vector<int64_t> axis_list;
  string axis_name = AXIS;
  int64_t default_axis = -1;
  if (ops[iter_ops]->type() == LAYER_NORM) {
    axis_name = "begin_norm_axis";
    default_axis = 1;
  }

  auto iter = ops[iter_ops]->attrs().find(axis_name);
  if (iter != ops[iter_ops]->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
    } else if (iter->second->isa<ValueTuple>()) {
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value_tuple is nullptr.";
      }
      std::vector<ValuePtr> value_vector = value_tuple->value();
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
                           [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
    }
  } else {
    axis_list.push_back(default_axis);
  }

  for (auto &axis : axis_list) {
    if (axis < 0) {
      int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_shape()[0].size());
      axis = input_dim + axis;
    }
    if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
    }
    if (strategies[0][LongToSize(axis)] != 1) {
      strategies[0][LongToSize(axis)] = 1;
      MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
    }
  }
  return strategies;
}

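// Converts the partition fractions stored in the rec graph node into a per-input strategy for the operator
// at iter_ops, handling tensors of rank 0 to 4; unsorted-segment ops fall back to data parallelism.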
Strategies MakeRecSearchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                 const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }
  if (node->apply.op_type == kRecUnsortedSegmentOp) {
    return MakeDataParallelStrategy(node, ops, iter_ops);
  }

  Strategies strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
    if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }

    size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
    Dimensions strategy;
    if (input_size == SIZE_FOUR) {
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_n));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (input_size == SIZE_THREE) {
      // Experimental support for 3D data.
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_c));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (input_size == SIZE_TWO) {
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_h));
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (input_size == SIZE_ONE) {
      strategy.push_back(static_cast<int64_t>(1.0 / node->apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (input_size == SIZE_ZERO) {
      strategy = {};
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's input size is unexpected.";
    }
    strategies.push_back(strategy);
  }
  return strategies;
}

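// Generates a pure data-parallel strategy: only the first (batch) dimension of each input is split, by
// min(stage device number, batch size of the first input), and the node's output layout in the rec graph
// is updated to match.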
Strategies MakeDataParallelStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                    const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }

  Strategies strategies;
  size_t max_device_num = LongToSize(g_device_manager->stage_device_num());
  size_t target_tensor_batch = LongToUlong(ops[iter_ops]->inputs_shape()[0][0]);
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
    if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }

    Dimensions strategy;
    size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
    for (size_t dim = 0; dim < input_size; dim++) {
      // Experimental support for 3D data (input_size == 3).
      if (input_size >= SIZE_ONE && input_size <= STR_DIM_NUM) {
        if (dim == 0) {
          strategy.push_back(std::min(max_device_num, target_tensor_batch));
        } else {
          strategy.push_back(1);
        }
      } else if (input_size == 0) {
        strategy = {};
      } else {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
      }
    }
    strategies.push_back(strategy);
  }
  // Set default strategy.
  node->tensor_parm.tensor_str.str_n = 1.0;
  node->tensor_parm.tensor_str.str_c = 1.0;
  node->tensor_parm.tensor_str.str_h = 1.0;
  node->tensor_parm.tensor_str.str_w = 1.0;

  // Update data parallel strategy.
  if (ops[iter_ops]->outputs_shape().size() == SIZE_ZERO) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
  }
  if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_ONE) {
    node->tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_TWO) {
    node->tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_THREE) {
    // Experimental support for 3D data.
    node->tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_shape()[0].size() == SIZE_FOUR) {  // Experimental support for 4D data.
    node->tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else {
    MS_LOG(INFO) << ops[iter_ops]->name() << " output tensor shape is unexpected, using default value instead.";
  }

  return strategies;
}

Strategies MakeFullBatchStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                 const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }

  Strategies strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_shape().size(); iter_op_inputs++) {
    if (iter_op_inputs >= ops[iter_ops]->inputs_shape().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }
    Dimensions strategy;
    size_t input_size = ops[iter_ops]->inputs_shape()[iter_op_inputs].size();
    for (size_t dim = 0; dim < input_size; dim++) {
      if (input_size >= SIZE_ONE && input_size <= SIZE_FOUR) {
        strategy.push_back(1);
      } else if (input_size == 0) {
        strategy = {};
      } else {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
      }
    }
    strategies.push_back(strategy);
  }
  // Update the output strategy of Rec Graph
  node->tensor_parm.tensor_str.str_n = 1.0;
  node->tensor_parm.tensor_str.str_c = 1.0;
  node->tensor_parm.tensor_str.str_h = 1.0;
  node->tensor_parm.tensor_str.str_w = 1.0;

  return strategies;
}

void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
  Strategies strategies;

  for (size_t iter_strategy = 0; iter_strategy < op->inputs_shape().size(); iter_strategy++) {
    Dimensions strategy;
    size_t strategy_size = op->inputs_shape()[iter_strategy].size();
    for (size_t dim = 0; dim < strategy_size; dim++) {
      if (strategy_size >= SIZE_ONE && strategy_size <= SIZE_FOUR) {
        strategy.push_back(1);
      } else if (strategy_size == 0) {
        strategy = {};
      } else {
        MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
      }
    }
    strategies.push_back(strategy);
  }

  StrategyPtr sp = std::make_shared<Strategy>(0, strategies);

  op->SetSelectedStrategyAndCost(sp, op->selected_cost());
}

Strategies PrepareStrategy(Graph::NodeType *node, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                           const size_t iter_ops, const bool dyn_shape_tmp_fix) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }
  MS_EXCEPTION_IF_NULL(ops[iter_ops]);

  auto type = ops[iter_ops]->type();
  MS_LOG(INFO) << "Processing main operator " << ops[iter_ops]->name() << " (type=" << type << ")";
  if (type == MATMUL) {
    return PrepareMatMul(node, ops[iter_ops]);
  } else if (dyn_shape_tmp_fix && type == BATCH_MATMUL) {
    return PrepareBatchMatMul(node, ops[iter_ops]);
  } else if (type == LAYER_NORM) {
    return PrepareAxisRelatedStrategy(node, ops, iter_ops);
  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
    return MakeDataParallelStrategy(node, ops, iter_ops);
  } else if (type == VIRTUAL_DATA_SET) {
    if (ParallelContext::GetInstance()->full_batch()) {
      return MakeFullBatchStrategy(node, ops, iter_ops);
    } else {
      return MakeDataParallelStrategy(node, ops, iter_ops);
    }
  } else {
    return MakeRecSearchStrategy(node, ops, iter_ops);
  }
}

float CheckVirtualDatasetStrategy(Graph::NodeType *node) {
  // The values for str can only be 1.0, 0.5, 0.25, 0.125...
  // We want to find out the first str that is smaller than 1
  if (node->tensor_parm.tensor_str.str_n < 0.9) {
    return node->tensor_parm.tensor_str.str_n;
  }
  if (node->tensor_parm.tensor_str.str_c < 0.9) {
    return node->tensor_parm.tensor_str.str_c;
  }
  if (node->tensor_parm.tensor_str.str_h < 0.9) {
    return node->tensor_parm.tensor_str.str_h;
  }
  if (node->tensor_parm.tensor_str.str_w < 0.9) {
    return node->tensor_parm.tensor_str.str_w;
  }
  return 1.0;
}

Dimensions CopyVirtualDataset(Graph::NodeType *node, const std::shared_ptr<OperatorInfo> &op,
                              float epsilon = 0.00005f) {
  Dimensions strategy;
  auto input_stra_dim = op->inputs_shape()[0].size();
  auto virtual_dataset_str = CheckVirtualDatasetStrategy(node);
  MS_EXCEPTION_IF_ZERO("Virtual_Dataset", virtual_dataset_str);
  if (input_stra_dim == 0) {
    return strategy;
  } else {
    if (std::fabs(virtual_dataset_str) < epsilon) {
      strategy.push_back(1);
    } else {
      strategy.push_back(FloatToLong(1 / virtual_dataset_str));
    }
    for (size_t i = 1; i < input_stra_dim; i++) {
      strategy.push_back(1);
    }
  }
  return strategy;
}

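// Builds a strategy for the first non-scalar input of the operator at iter_ops from the output layout
// recorded in the given rec graph node, converting the stored fractions back into partition numbers;
// a virtual dataset producer is delegated to CopyVirtualDataset.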
Dimensions CopyIncomingOperatorOutputStrategy(Graph::NodeType *node,
                                              const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                              const size_t iter_ops, const size_t incoming_op_index) {
  Dimensions strategy;

  if (ops[incoming_op_index]->type() == VIRTUAL_DATA_SET) {
    strategy = CopyVirtualDataset(node, ops[iter_ops]);
    return strategy;
  }

  for (auto inputs_shape : ops[iter_ops]->inputs_shape()) {
    auto input_stra_dim = inputs_shape.size();
    if (input_stra_dim == SIZE_ZERO) {
      continue;
    }
    if (input_stra_dim == SIZE_ONE) {
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == SIZE_TWO) {
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == SIZE_THREE) {
      // Experimental support for 3D data.
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_c));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == SIZE_FOUR) {
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_n));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_c));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_h));
      strategy.push_back(FloatToLong(1 / node->tensor_parm.tensor_str.str_w));
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
    }
    break;
  }
  return strategy;
}

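// Maps a strategy across a reshape: in a first pass, whole split factors are moved onto output dimensions
// they divide exactly; a second pass distributes whatever remains using the gcd of the leftover split
// factor and the output dimension.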
Dimensions PrepareReshape(std::vector<int64_t> from_shape, std::vector<int64_t> to_shape,
                          std::vector<int64_t> from_strat) {
  Dimensions to_strat(to_shape.size(), 1);
  std::vector<int64_t> from_shape_cpy(from_shape);
  std::vector<int64_t> to_shape_cpy(to_shape);
  size_t from_idx = 0;
  size_t to_idx = 0;

  // Attempt to assign full strategy to one dimension
  while (from_idx < from_shape.size() && to_idx < to_shape.size()) {
    if (from_shape[from_idx] > to_shape[to_idx]) {
      if (to_shape[to_idx] % from_strat[from_idx] == 0) {
        to_strat[to_idx] *= from_strat[from_idx];
        from_strat[from_idx] = 1;
      }
      from_shape[from_idx] /= to_shape[to_idx];
      to_idx++;
    } else if (from_shape[from_idx] < to_shape[to_idx]) {
      to_shape[to_idx] /= from_shape[from_idx];
      from_idx++;
    } else {
      if (to_shape[to_idx] % from_strat[from_idx] == 0) {
        to_strat[to_idx] *= from_strat[from_idx];
        from_strat[from_idx] = 1;
      }
      from_idx++;
      to_idx++;
    }
  }

  // Reset shapes & indices
  from_idx = 0;
  to_idx = 0;
  from_shape = from_shape_cpy;
  to_shape = to_shape_cpy;

  // Assign remaining strategy
  while (from_idx < from_shape.size() && to_idx < to_shape.size()) {
    if (from_shape[from_idx] > to_shape[to_idx]) {
      int64_t d = std::gcd(from_strat[from_idx], to_shape[to_idx]);
      to_strat[to_idx] *= d;
      from_strat[from_idx] /= d;
      from_shape[from_idx] /= to_shape[to_idx];
      to_idx++;
    } else if (from_shape[from_idx] < to_shape[to_idx]) {
      to_strat[to_idx] *= from_strat[from_idx];
      to_shape[to_idx] /= from_shape[from_idx];
      from_idx++;
    } else {  // equal case
      to_strat[to_idx] *= from_strat[from_idx];
      from_idx++;
      to_idx++;
    }
  }
  return to_strat;
}

Dimensions PrepareReshapeOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
  auto output_shape = op->outputs_shape()[0];
  auto input_shape = op->inputs_shape()[0];
  auto strategy = op->selected_strategy();

  return PrepareReshape(input_shape, output_shape, strategy->GetInputDim()[0]);
}

Dimensions PrepareTransposeOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
  Dimensions strategy;
  auto permutation = GetValue<std::vector<int64_t>>(op->input_value().at(1));
  auto op_strategy = op->selected_strategy();
  // The strategies are assigned according to the order in permutation (user defined).
  for (size_t i = 0; i < permutation.size(); i++) {
    strategy.push_back(op_strategy->GetInputDim()[0][LongToSize(permutation[i])]);
  }
  return strategy;
}

Dimensions PrepareExpandDimsOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
  Dimensions strategy;

  auto axis_input = GetValue<int64_t>(op->input_value().at(1));
  auto op_strategy = op->selected_strategy();
  bool already_expand = false;

  // axis_input can be negative, in which case the index is computed backward from the shape size.
  if (axis_input < 0) {
    axis_input = SizeToLong(op->inputs_shape()[0].size()) + axis_input + 1;
  }

  // The strategy of the expanded dimension will be assigned 1, the others take the strategies of corresponding
  // dimensions.
  for (size_t i = 0; i < op->inputs_shape()[0].size() + 1; i++) {
    if (UlongToLong(i) == axis_input) {
      strategy.push_back(1);
      already_expand = true;
    } else if (UlongToLong(i) != axis_input && !already_expand) {
      strategy.push_back(op_strategy->GetInputDim()[0][i]);
    } else {
1350 if (i < 1) {
1351 MS_LOG(EXCEPTION) << "The index i - 1 is less than 0. Please check the situation.";
1352 }
1353 strategy.push_back(op_strategy->GetInputDim()[0][i - 1]);
1354 }
1355 }
1356
1357 return strategy;
1358 }
1359
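// Note: cumulative ops (CumSum/CumProd) cannot be split along the accumulation axis, so
// the strategy propagated to consumers keeps every dimension's cut except the cumulated
// axis, which is forced to 1 (illustrative example, values assumed: strategy [4, 2] with
// axis 1 becomes [4, 1]).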
1360 Dimensions PrepareCumOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1361 Dimensions strategy;
1362
1363 int64_t axis_input = 1;
1364
1365 if (op->input_value().at(1)->isa<Int64Imm>()) {
1366 axis_input = GetValue<int64_t>(op->input_value().at(1));
1367 MS_LOG(INFO) << op->name() << " is a prefix sum on axis " << axis_input;
1368 } else {
1369 MS_LOG(INFO) << op->name() << " is supposedly a cumulative op, but its axis is not an int64";
1370 }
1371
1372 auto op_strategy = op->selected_strategy();
1373
1374 // axis_input can be negative, in which case the index is computed backward from the shape size.
1375 if (axis_input < 0) {
1376 axis_input = SizeToLong(op->inputs_shape()[0].size()) + axis_input;
1377 }
1378
1379 // The strategy of the cumulated axis will be assigned 1, the others take the strategies of corresponding dimensions.
1380 for (size_t i = 0; i < op->inputs_shape()[0].size(); i++) {
1381 if ((int64_t)i == axis_input) {
1382 strategy.push_back(1);
1383 } else {
1384 strategy.push_back(op_strategy->GetInputDim()[0][i]);
1385 }
1386 }
1387
1388 return strategy;
1389 }
1390
1391 ShapeVector GetReduceAxisList(const std::shared_ptr<OperatorInfo> &op) {
1392 ShapeVector axis_list;
1393 auto input_value = op->input_value();
1394 auto input_dim = op->inputs_shape()[0].size();
1395
1396 if (input_value.back()->isa<ValueTuple>()) {
1397 auto attr_axis = GetValue<std::vector<int64_t>>(input_value.back());
1398 if (attr_axis.empty()) {
1399 for (size_t i = 0; i < input_dim; i++) {
1400 axis_list.push_back(SizeToLong(i));
1401 }
1402 } else {
1403 axis_list = attr_axis;
1404 }
1405 } else if (input_value.back()->isa<Int64Imm>()) {
1406 int64_t axis = GetValue<int64_t>(input_value.back());
1407 axis_list.push_back(axis < 0 ? axis + SizeToLong(input_dim) : axis);
1408 } else {
1409 MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl;
1410 }
1411
1412 return axis_list;
1413 }
1414
1415 Dimensions PrepareCumInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1416 size_t outgoing_op_index, size_t i_input) {
1417 Dimensions strategy;
1418 int64_t axis_input = 1;
1419
1420 if (ops[i_ops]->input_value().at(1)->isa<Int64Imm>()) {
1421 axis_input = GetValue<int64_t>(ops[i_ops]->input_value().at(1));
1422 MS_LOG(INFO) << ops[i_ops]->name() << " is a prefix sum on axis " << axis_input;
1423 } else {
1424 MS_LOG(INFO) << ops[i_ops]->name() << " is supposedly a cumulative op, but its axis is not an int64";
1425 }
1426
1427 auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1428
1429 size_t n_dim = op_strategy->GetInputDim()[i_input].size();
1430
1431 if (axis_input < 0) {
1432 axis_input = SizeToLong(n_dim) + axis_input;
1433 }
1434
1435 MS_EXCEPTION_IF_CHECK_FAIL(axis_input >= 0, "Input axis is lower than 0");
1436
1437 for (size_t i_dim = 0; i_dim < n_dim; ++i_dim) {
1438 if (i_dim == size_t(axis_input)) {
1439 strategy.push_back(1);
1440 } else {
1441 strategy.push_back(op_strategy->GetInputDim()[i_input][i_dim]);
1442 }
1443 }
1444
1445 return strategy;
1446 }
1447
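// Note: for broadcast-style arithmetic ops the strategy propagated forward is taken from
// the highest-rank input, since lower-rank inputs are broadcast against it. Illustrative
// example (values assumed): inputs of rank 1 and rank 3 with selected strategies [1] and
// [2, 1, 4] propagate [2, 1, 4].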
1448 Dimensions PrepareIncomingArithmeticOpeartorInputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1449 Dimensions strategy;
1450 size_t max = 0;
1451 for (size_t i = 1; i < op->inputs_shape().size(); i++) {
1452 if (op->inputs_shape()[i].size() > op->inputs_shape()[max].size()) {
1453 max = i;
1454 }
1455 }
1456
1457 for (size_t j = 0; j < op->inputs_shape()[max].size(); j++) {
1458 strategy.push_back(op->selected_strategy()->GetInputDim()[max][j]);
1459 }
1460
1461 return strategy;
1462 }
1463
1464 Dimensions PrepareIncomingOperatorInputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1465 Dimensions strategy;
1466
1467 if (op->type() == GATHERV2) {
1468 auto pos = op->name().find("Info");
1469 if (pos == std::string::npos) {
1470 return strategy;
1471 }
1472 auto name = op->name().substr(0, pos);
1473 if (name == "Gather") {
1474 return PrepareGatherV2OutputStrategy(op);
1475 } else {
1476 MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2.";
1477 }
1478 }
1479
1480 if (!HasStrategy(op)) {
1481 return strategy;
1482 }
1483
1484 auto op_strategy = op->selected_strategy();
1485 if (op_strategy->GetInputNumber() == 0) {
1486 return strategy;
1487 }
1488
1489 if (op->type() == MUL || op->type() == SUB || op->type() == ADD || op->type() == BIAS_ADD) {
1490 strategy = PrepareIncomingArithmeticOpeartorInputStrategy(op);
1491 return strategy;
1492 }
1493
1494 if (op->type() == RESHAPE) {
1495 return PrepareReshapeOutputStrategy(op);
1496 } else if (op->type() == TRANSPOSE) {
1497 return PrepareTransposeOutputStrategy(op);
1498 } else if (op->type() == EXPAND_DIMS) {
1499 return PrepareExpandDimsOutputStrategy(op);
1500 } else if (op->type() == CUM_SUM || op->type() == CUM_PROD) {
1501 return PrepareCumOutputStrategy(op);
1502 } else if (op->type() == ONEHOT) {
1503 return PrepareOneHotOutputStrategy(op);
1504 }
1505
1506 for (size_t i = 0; i < static_cast<size_t>(op->inputs_shape().size()); i++) {
1507 if (op->inputs_shape()[i].size() == 0) {
1508 continue;
1509 }
1510 for (size_t j = 0; j < op->inputs_shape()[i].size(); ++j) {
1511 strategy.push_back(op_strategy->GetInputDim()[i][j]);
1512 }
1513 break;
1514 }
1515 return strategy;
1516 }
1517
1518 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops) {
1519 Dimensions axis_list;
1520 auto axis_param = ops[LongToSize(iter_ops)]->attrs().find(AXIS)->second;
1521 std::vector<ValuePtr> elements;
1522 if (axis_param->isa<ValueTuple>()) {
1523 elements = axis_param->cast<ValueTuplePtr>()->value();
1524 } else if (axis_param->isa<ValueList>()) {
1525 elements = axis_param->cast<ValueListPtr>()->value();
1526 } else {
1527 MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list.";
1528 }
1529
1530 for (auto &element : elements) {
1531 if (!element->isa<Int64Imm>()) {
1532 MS_LOG(EXCEPTION) << "Failure: Dimension index is not Int64.";
1533 }
1534 auto axis = element->cast<Int64ImmPtr>()->value();
1535 axis_list.push_back(axis);
1536 }
1537 return axis_list;
1538 }
1539
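// Note: when the incoming operator is a Squeeze, the axes it removes disappear from its
// output, so the corresponding entries are dropped from the propagated strategy; each
// removed axis must have shape 1. Illustrative example (values assumed): strategy
// [2, 1, 4] with squeeze axis (1,) becomes [2, 4].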
1540 Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1541 const size_t incoming_op_index, Dimensions strategy) {
1542 Dimensions s_Squeeze;
1543 Dimensions stra_dim_list;
1544 for (size_t i = 0; i < strategy.size(); i++) {
1545 stra_dim_list.push_back(SizeToLong(i));
1546 }
1547
1548 auto axis_list = GetAxisList(ops, SizeToLong(incoming_op_index));
1549 for (auto axis : axis_list) {
1550 axis = (axis < 0) ? (strategy.size() + axis) : axis;
1551 auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
1552 if (it == stra_dim_list.end()) {
1553 MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1554 }
1555 if (ops[incoming_op_index]->inputs_shape()[0][LongToSize(axis)] != 1) {
1556 MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1.";
1557 }
1558 (void)stra_dim_list.erase(it);
1559 }
1560
1561 for (size_t i = 0; i < stra_dim_list.size(); i++) {
1562 s_Squeeze.push_back(strategy[LongToSize(stra_dim_list[i])]);
1563 }
1564 return s_Squeeze;
1565 }
1566
1567 Dimensions ModifyStrategyIfReduceIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1568 Dimensions s_Reduce;
1569 Dimensions axis_list;
1570 for (size_t i = 0; i < strategy.size(); i++) {
1571 axis_list.push_back(SizeToLong(i));
1572 }
1573
1574 auto dim_list = GetDimList(op);
1575 for (auto axis : dim_list) {
1576 auto it = find(axis_list.begin(), axis_list.end(), axis);
1577 if (it == axis_list.end()) {
1578 MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1579 }
1580 (void)axis_list.erase(it);
1581 }
1582
1583 for (size_t i = 0; i < axis_list.size(); i++) {
1584 s_Reduce.push_back(strategy[LongToSize(axis_list[i])]);
1585 }
1586 return s_Reduce;
1587 }
1588
1589 Dimensions GetDimListFromAttrs(const std::shared_ptr<OperatorInfo> &op) {
1590 Dimensions dim_list;
1591 auto iter = op->attrs().find(AXIS);
1592 if (iter == op->attrs().end()) {
1593 MS_LOG(EXCEPTION) << op->name() << ": Don't have attr axis.";
1594 }
1595 auto input_dim = op->inputs_shape()[0].size();
1596 MS_EXCEPTION_IF_NULL(iter->second);
1597 if (iter->second->isa<ValueTuple>()) {
1598 auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
1599 if (attr_axis.empty()) {
1600 for (size_t i = 0; i < input_dim; ++i) {
1601 dim_list.push_back(SizeToLong(i));
1602 }
1603 } else {
1604 for (auto &axis : attr_axis) {
1605 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
1606 }
1607 }
1608 } else if (iter->second->isa<Int64Imm>()) {
1609 int64_t axis = GetValue<int64_t>(iter->second);
1610 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
1611 } else {
1612 MS_LOG(EXCEPTION) << "Axis type is invalid.";
1613 }
1614 return dim_list;
1615 }
1616
1617 Dimensions ModifyStrategyIfArgIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1618 bool keepdims = GetKeepDims(op);
1619 if (keepdims) {
1620 return strategy;
1621 }
1622
1623 Dimensions s_Arg;
1624 Dimensions axis_list;
1625 for (size_t i = 0; i < strategy.size(); i++) {
1626 axis_list.push_back(SizeToLong(i));
1627 }
1628
1629 auto dim_list = GetDimListFromAttrs(op);
1630 for (auto axis : dim_list) {
1631 auto it = find(axis_list.begin(), axis_list.end(), axis);
1632 if (it == axis_list.end()) {
1633 MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis.";
1634 }
1635 (void)axis_list.erase(it);
1636 }
1637
1638 for (size_t i = 0; i < axis_list.size(); i++) {
1639 s_Arg.push_back(strategy[LongToSize(axis_list[i])]);
1640 }
1641 return s_Arg;
1642 }
1643
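// Note: Flatten merges dimensions [start_dim, end_dim] into one, so the incoming strategy
// is folded by multiplying the cuts of the merged dimensions. Illustrative example
// (values assumed): with start_dim = 1 and end_dim = 2, strategy [2, 4, 2, 1] becomes
// [2, 8, 1].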
1644 Dimensions ModifyStrategyIfFlattenIncoming(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy) {
1645 Dimensions new_strategy;
1646 int start_dim = 1, end_dim = strategy.size() - 1;
1647 auto start_dim_iter = op->attrs().find("start_dim");
1648 if (start_dim_iter != op->attrs().end()) {
1649 start_dim = GetValue<int64_t>(start_dim_iter->second);
1650 }
1651 auto end_dim_iter = op->attrs().find("end_dim");
1652 if (end_dim_iter != op->attrs().end() && GetValue<int64_t>(end_dim_iter->second) >= 0) {
1653 end_dim = GetValue<int64_t>(end_dim_iter->second);
1654 }
1655
1656 for (int idx = 0; idx < start_dim; idx++) {
1657 new_strategy.push_back(strategy[idx]);
1658 }
1659
1660 int flatten_strategy = 1;
1661 for (int idx = start_dim; idx < end_dim + 1; idx++) {
1662 flatten_strategy *= strategy[idx];
1663 }
1664 new_strategy.push_back(flatten_strategy);
1665 if (IntToSize(end_dim + 1) < strategy.size()) {
1666 for (size_t idx = end_dim + 1; idx < strategy.size(); idx++) {
1667 new_strategy.push_back(strategy[idx]);
1668 }
1669 }
1670
1671 return new_strategy;
1672 }
1673
1674 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
1675 const size_t iter_ops, const size_t incoming_op_index) {
1676 Dimensions strategy;
1677 if (ops[iter_ops]->type() == ONEHOT) {
1678 return strategy;
1679 }
1680 if (ops[iter_ops]->type() == TRANSPOSE) {
1681 return strategy;
1682 }
1683 if (ops[incoming_op_index]->type() == STRIDED_SLICE) {
1684 return strategy;
1685 }
1686 strategy = PrepareIncomingOperatorInputStrategy(ops[incoming_op_index]);
1687 if (strategy.size() != 0) {
1688 if (ops[incoming_op_index]->type() == SQUEEZE) {
1689 strategy = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, strategy);
1690 }
1691 if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
1692 ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
1693 strategy = ModifyStrategyIfReduceIncoming(ops[incoming_op_index], strategy);
1694 }
1695 if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
1696 strategy = ModifyStrategyIfArgIncoming(ops[incoming_op_index], strategy);
1697 }
1698 if (ops[incoming_op_index]->type() == FLATTEN) {
1699 strategy = ModifyStrategyIfFlattenIncoming(ops[incoming_op_index], strategy);
1700 }
1701 }
1702 return strategy;
1703 }
1704
1705 Strategies PrepareDropoutDoMask(const std::shared_ptr<OperatorInfo> &op, Dimensions basic_stra,
1706 bool dyn_shape_tmp_fix) {
1707 // DropoutDoMask only takes a strategy for its first input, so a single strategy entry is pushed.
1708 Strategies strategies;
1709 strategies.clear();
1710 strategies.push_back(basic_stra);
1711 return strategies;
1712 }
1713
1714 // Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
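// The three branches below cover: second input with lower rank, first input with lower
// rank, and equal ranks. When the propagated strategy matches the larger tensor's rank it
// is kept for that tensor and ApplyBroadcast derives the smaller tensor's strategy;
// otherwise both strategies fall back to all-1 as a safe default.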
1715 Strategies CheckBroadcast(const std::shared_ptr<OperatorInfo> &op, Dimensions strategy, bool dyn_shape_tmp_fix) {
1716 Strategies strategies;
1717
1718 size_t first_tensor_dim = op->inputs_shape()[0].size();
1719 size_t second_tensor_dim = op->inputs_shape()[1].size();
1720 size_t s_dim = strategy.size();
1721 // Do Broadcasting in the second tensor.
1722 if (second_tensor_dim < first_tensor_dim) {
1723 if (s_dim == first_tensor_dim) {
1724 bool broadcast_first_tensor = false;
1725 strategies.push_back(strategy);
1726 strategies.push_back(ApplyBroadcast(op, strategy, broadcast_first_tensor));
1727 } else {
1728 // When the strategy is from the smaller tensor, make the strategy all 1.
1729 Dimensions broadcast_revise_s(first_tensor_dim, 1);
1730 strategies.push_back(broadcast_revise_s);
1731 Dimensions broadcast_s(strategy.size(), 1);
1732 strategies.push_back(broadcast_s);
1733 }
1734 } else if (second_tensor_dim > first_tensor_dim) { // Do Broadcasting in the first tensor.
1735 if (s_dim == second_tensor_dim) {
1736 bool broadcast_first_tensor = true;
1737 strategies.push_back(ApplyBroadcast(op, strategy, broadcast_first_tensor));
1738 strategies.push_back(strategy);
1739 } else {
1740 // When the strategy is from the smaller tensor, make the strategy all 1.
1741 Dimensions broadcast_s(strategy.size(), 1);
1742 strategies.push_back(broadcast_s);
1743 Dimensions broadcast_revise_s(second_tensor_dim, 1);
1744 strategies.push_back(broadcast_revise_s);
1745 }
1746 } else { // Equal ranks: no broadcasting needs to be applied.
1747 strategies = CheckDivisible(op, strategy);
1748 }
1749 // Strategy protection to avoid that partition number is larger than the shape of related dimension.
1750 for (size_t i = 0; i < op->inputs_shape().size(); i++) {
1751 for (size_t j = 0; j < op->inputs_shape()[i].size(); j++) {
1752 if (strategies[i][j] > op->inputs_shape()[i][j] || op->inputs_shape()[i][j] % strategies[i][j] != 0) {
1753 strategies[i][j] = 1;
1754 }
1755 }
1756 }
1757
1758 return strategies;
1759 }
1760
1761 void InitializeStrategyMap() {
1762 if (g_prepare_stra_map.empty()) {
1763 g_prepare_stra_map =
1764 std::map<std::string, PrepareStraFuncPtr>{{FILLV2, &PrepareFillV2},
1765 {BIAS_ADD, &PrepareBiasAdd},
1766 {STRIDED_SLICE, &PrepareStridedSlice},
1767 {GATHERV2, &PrepareGather},
1768 {ONEHOT, &PrepareOneHot},
1769 {L2_NORMALIZE, &PrepareL2Normalize},
1770 {ADD, &CheckBroadcast},
1771 {SUB, &CheckBroadcast},
1772 {MUL, &CheckBroadcast},
1773 {DIV, &CheckBroadcast},
1774 {SOFTMAX, &PrepareSoftMax},
1775 {LOG_SOFTMAX, &PrepareSoftMax},
1776 {FLATTEN, &PrepareDataParallel},
1777 {GATHERD, &PrepareDataParallel},
1778 {LAYER_NORM, &PrepareLayerNorm},
1779 {RMS_NORM, &PrepareRmsNorm},
1780 {BATCH_MATMUL, &PreparePropagateBatchMatMul},
1781 {DROPOUT_DO_MASK, &PrepareDropoutDoMask},
1782 {FLASH_ATTENTION_SCORE, &PrepareFlashAttentionScore}};
1783 }
1784 }
1785
1786 Strategies GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1787 Dimensions basic_stra, bool dyn_shape_tmp_fix) {
1788 if (iter_ops >= ops.size()) {
1789 MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
1790 }
1791 
1792 MS_EXCEPTION_IF_NULL(ops[iter_ops]);
1793
1794 Strategies strategies;
1795 if (basic_stra.size() == 0) {
1796 for (size_t iter_op_inputs = 0; iter_op_inputs < static_cast<size_t>(ops[iter_ops]->inputs_shape().size());
1797 iter_op_inputs++) {
1798 strategies.push_back(basic_stra);
1799 }
1800 return strategies;
1801 }
1802 InitializeStrategyMap();
1803 auto type = ops[iter_ops]->type();
1804 auto iter_stra_func = g_prepare_stra_map.find(type);
1805 if (iter_stra_func != g_prepare_stra_map.end()) {
1806 auto stra = iter_stra_func->second(ops[iter_ops], basic_stra, dyn_shape_tmp_fix);
1807 return stra;
1808 }
1809
1810 return CheckDivisible(ops[iter_ops], basic_stra);
1811 }
1812
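// Note: ApplyBroadcast maps the larger-rank tensor's strategy onto the smaller-rank tensor
// by aligning trailing dimensions and assigning 1 wherever the smaller tensor's dimension
// is 1. Illustrative example (values assumed): shapes [4, 8] and [8] with strategy [2, 4]
// give [4] for the second input, while shapes [4, 8] and [1] give [1].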
1813 Dimensions ApplyBroadcast(const std::shared_ptr<OperatorInfo> &op, const Dimensions &strategy,
1814 bool broadcast_first_tensor) {
1815 Dimensions s_broadcast;
1816 size_t target_tensor_index = 0;
1817 size_t target_tensor_dim = 1;
1818
1819 // Indexing target and refer tensor.
1820 if (!broadcast_first_tensor) {
1821 target_tensor_index = 1;
1822 }
1823
1824 target_tensor_dim = op->inputs_shape()[target_tensor_index].size();
1825 for (size_t iter = 0; iter < target_tensor_dim; iter++) {
1826 if (op->inputs_shape()[target_tensor_index][target_tensor_dim - 1 - iter] == 1) {
1827 s_broadcast.insert(s_broadcast.begin(), 1);
1828 } else {
1829 s_broadcast.insert(s_broadcast.begin(), strategy[strategy.size() - 1 - iter]);
1830 }
1831 }
1832
1833 return s_broadcast;
1834 }
1835
1836 // Check whether the operator can be divided by the current strategy.
1837 Strategies CheckDivisible(const std::shared_ptr<OperatorInfo> &op, const Dimensions &basic_stra) {
1838 Dimensions s_empty = {};
1839 Strategies strategies;
1840
1841 // For all the input tensors.
1842 for (size_t iter_op_inputs = 0; iter_op_inputs < op->inputs_shape().size(); iter_op_inputs++) {
1843 // If the input tensor has no dimensions, push back an empty strategy for it.
1844 if (op->inputs_shape()[iter_op_inputs].size() == 0) {
1845 strategies.push_back(s_empty);
1846 continue;
1847 }
1848
1849 Dimensions tmp_stra;
1850
1851 // Make sure each tensor's dim shape is greater than 1. If not, push back strategy as 1 instead.
1852 for (size_t j = 0; j < op->inputs_shape()[iter_op_inputs].size(); j++) {
1853 if (op->inputs_shape()[iter_op_inputs][j] == 1) {
1854 tmp_stra.push_back(1);
1855 } else if (j < basic_stra.size()) {
1856 tmp_stra.push_back(basic_stra[j]);
1857 } else {
1858 tmp_stra.push_back(1);
1859 }
1860 }
1861 strategies.push_back(tmp_stra);
1862 }
1863
1864 return strategies;
1865 }
1866
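// Note: the outgoing counterpart of the Squeeze handling above -- when propagating
// backward into a Squeeze, a 1 is re-inserted at every squeezed axis so the strategy
// matches the input rank again; if the resulting cuts no longer cover the full stage
// device number, the strategy is cleared and left for a later pass.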
1867 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1868 Dimensions strategy) {
1869 Dimensions s_Squeeze;
1870 auto axis_list = GetAxisList(ops, SizeToLong(iter_ops));
1871 size_t s_index = 0;
1872 size_t axis_list_index = 0;
1873 for (size_t i = 0; i < strategy.size() + axis_list.size(); i++) {
1874 if (axis_list_index < axis_list.size() && i == LongToSize(axis_list[axis_list_index])) {
1875 s_Squeeze.push_back(1);
1876 axis_list_index++;
1877 } else {
1878 s_Squeeze.push_back(strategy[s_index]);
1879 s_index++;
1880 }
1881 }
1882
1883 size_t cut = 1;
1884 for (size_t i = 0; i < s_Squeeze.size(); i++) {
1885 cut *= LongToSize(s_Squeeze[i]);
1886 }
1887 if (cut != size_t(g_device_manager->stage_device_num())) {
1888 s_Squeeze.clear();
1889 }
1890
1891 return s_Squeeze;
1892 }
1893
1894 Dimensions PrepareExpandDimsInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1895 size_t outgoing_op_index, size_t i_input) {
1896 Dimensions strategy;
1897
1898 int64_t axis_input = GetValue<int64_t>(ops[i_ops]->input_value().at(1));
1899
1900 auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1901
1902 size_t n_dim = op_strategy->GetInputDim()[i_input].size();
1903
1904 if (axis_input < 0) {
1905 axis_input = SizeToLong(n_dim) + axis_input;
1906 }
1907
1908 MS_EXCEPTION_IF_CHECK_FAIL(axis_input >= 0, "Input axis is lower than 0");
1909
1910 for (size_t i_dim = 0; i_dim < n_dim; ++i_dim) {
1911 if (i_dim != size_t(axis_input)) {
1912 strategy.push_back(op_strategy->GetInputDim()[i_input][i_dim]);
1913 }
1914 }
1915
1916 return strategy;
1917 }
1918
1919 Dimensions PrepareReshapeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1920 size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix) {
1921 if (dyn_shape_tmp_fix) {
1922 Dimensions empty_strategy;
1923 return empty_strategy;
1924 }
1925 auto output_shape = ops[i_ops]->outputs_shape()[0];
1926 auto input_shape = ops[i_ops]->inputs_shape()[0];
1927 auto strategy = ops[outgoing_op_index]->selected_strategy();
1928
1929 return PrepareReshape(output_shape, input_shape, strategy->GetInputDim()[iter_op_inputs]);
1930 }
1931
1932 Dimensions PrepareGatherV2InputStrategy(const std::shared_ptr<OperatorInfo> &op, size_t i_input) {
1933 auto targeted_shape = op->inputs_shape()[i_input];
1934 Dimensions strategie = GenGatherStra(targeted_shape);
1935 return strategie;
1936 }
1937
1938 Dimensions PrepareReduceOutputStrategy(const std::shared_ptr<OperatorInfo> &op) {
1939 bool keep_dims = GetKeepDims(op);
1940 auto axis_list = GetDimList(op);
1941 auto basic_stra = op->selected_strategy()->GetInputDim().at(0);
1942
1943 Dimensions strategy;
1944
1945 for (size_t i = 0; i < basic_stra.size(); ++i) {
1946 if (std::find(axis_list.begin(), axis_list.end(), i) != axis_list.end()) {
1947 if (keep_dims) {
1948 strategy.push_back(1);
1949 }
1950 } else {
1951 strategy.push_back(basic_stra.at(i));
1952 }
1953 }
1954
1955 return strategy;
1956 }
1957
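// Note: when propagating backward through a Reduce op, reduced axes get strategy 1 and the
// remaining axes take the consumer's strategy. Illustrative example (values assumed): an
// input of rank 3 reduced on axis 1 with keep_dims = false and consumer strategy [4, 2]
// produces the input strategy [4, 1, 2].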
1958 Dimensions PrepareReduceInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1959 size_t outgoing_op_index, size_t i_input) {
1960 bool keep_dims = GetKeepDims(ops[i_ops]);
1961
1962 auto axis_list = GetDimList(ops[i_ops]);
1963
1964 Dimensions strategy;
1965
1966 auto basic_stra = ops[outgoing_op_index]->selected_strategy()->GetInputDim().at(i_input);
1967
1968 for (size_t i = 0, i_stra = 0; i < ops[i_ops]->inputs_shape()[0].size(); ++i) {
1969 if (std::find(axis_list.begin(), axis_list.end(), i) != axis_list.end()) {
1970 strategy.push_back(1);
1971 if (keep_dims) {
1972 ++i_stra;
1973 }
1974 } else {
1975 strategy.push_back(basic_stra.at(i_stra++));
1976 }
1977 }
1978
1979 return strategy;
1980 }
1981
1982 Dimensions PrepareTransposeInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t i_ops,
1983 size_t outgoing_op_index, size_t iter_op_inputs) {
1984 Dimensions strategy;
1985 auto permutation = GetValue<std::vector<int64_t>>(ops[i_ops]->input_value().at(1));
1986 auto op_strategy = ops[outgoing_op_index]->selected_strategy();
1987 // The strategies are assigned according to the order in permutation (user defined).
1988 for (size_t i = 0; i < permutation.size(); i++) {
1989 strategy.push_back(op_strategy->GetInputDim()[iter_op_inputs][LongToSize(permutation[i])]);
1990 }
1991 return strategy;
1992 }
1993
1994 Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops,
1995 size_t outgoing_op_index, size_t iter_op_inputs, bool dyn_shape_tmp_fix) {
1996 Dimensions strategy;
1997 // Propagation not implemented for these operators
1998 if (ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
1999 return strategy;
2000 }
2001
2002 // Propagation not allowed for these operators
2003 if (ops[iter_ops]->type() == FLATTEN) {
2004 return strategy;
2005 }
2006
2007 if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
2008 std::string type = ops[iter_ops]->type();
2009 if (type == EXPAND_DIMS) {
2010 strategy = PrepareExpandDimsInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2011 } else if (type == RESHAPE) {
2012 strategy = PrepareReshapeInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs, dyn_shape_tmp_fix);
2013 return strategy;
2014 } else if (type == GATHERV2) {
2015 strategy = PrepareGatherV2InputStrategy(ops[outgoing_op_index], iter_op_inputs);
2016 return strategy;
2017 } else if (type == REDUCE_MEAN || type == REDUCE_MAX || type == REDUCE_MIN || type == REDUCE_SUM) {
2018 strategy = PrepareReduceInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2019 } else if (type == TRANSPOSE) {
2020 strategy = PrepareTransposeInputStrategy(ops, iter_ops, outgoing_op_index, iter_op_inputs);
2021 return strategy;
2022 } else {
2023 for (size_t k = 0; k < ops[iter_ops]->outputs_shape()[0].size(); ++k) {
2024 strategy.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
2025 }
2026 }
2027 if (!IsDimensionsEmpty(strategy) && ops[iter_ops]->type() == SQUEEZE) {
2028 strategy = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, strategy);
2029 }
2030 }
2031
2032 return strategy;
2033 }
2034
2035 void RecStrategyPropagator::ApplyStrategy(size_t i_op, const Strategies &strategies) {
2036 StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2037 ops_[i_op]->SetSelectedStrategyAndCost(sp, ops_[i_op]->selected_cost());
2038 }
2039
2040 size_t RecStrategyPropagator::GetMaxDimNum(size_t i_op) {
2041 size_t max_dim_num = 0;
2042 for (size_t iter_op_inputs = 0; iter_op_inputs < ops_[i_op]->inputs_shape().size(); iter_op_inputs++) {
2043 if (ops_[i_op]->inputs_shape()[iter_op_inputs].size() > max_dim_num) {
2044 max_dim_num = ops_[i_op]->inputs_shape()[iter_op_inputs].size();
2045 }
2046 }
2047
2048 return max_dim_num;
2049 }
2050
2051 Dimensions RecStrategyPropagator::GetDefaultStrategy(size_t i_op) {
2052 Dimensions strategy;
2053 size_t max_dim_num = GetMaxDimNum(i_op);
2054 for (size_t i = 0; i < max_dim_num; i++) {
2055 strategy.push_back(1);
2056 }
2057
2058 return strategy;
2059 }
2060
2061 bool StopPropAtOP(std::string op_type) {
2062 const std::set<std::string> stop_at = {GATHERV2, ASSIGN, EXPAND_DIMS};
2063 return stop_at.find(op_type) != stop_at.end();
2064 }
2065
2066 size_t RecStrategyPropagator::GenerateEliminatedOperatorStrategyForward(size_t min_devices) {
2067 MS_LOG(INFO) << "There are " << no_stra_op_list_->size() << " operators left that do not have strategy.";
2068 size_t changes = 0;
2069 if (no_stra_op_list_->empty()) {
2070 return changes;
2071 }
2072
2073 std::vector<size_t> no_stra_op_list_bis;
2074 for (size_t iter_list = no_stra_op_list_->size(); iter_list > 0; iter_list--) {
2075 size_t iter_ops = no_stra_op_list_->at(iter_list - 1);
2076 Strategies strategies;
2077 size_t incoming_op_index = FindIndexOfOperatorIncoming(ops_, input_tensor_names_, iter_ops);
2078 Dimensions strategy = GetInputStrategy(graph_, ops_, index_list_, iter_ops, incoming_op_index);
2079 if (IsDimensionsEmpty(strategy) || DevicesForDimensions(strategy) < min_devices ||
2080 StopPropAtOP(ops_[incoming_op_index]->type())) {
2081 no_stra_op_list_bis.push_back(iter_ops);
2082 } else {
2083 strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2084 ApplyStrategy(iter_ops, strategies);
2085 ++changes;
2086 MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategies " << StrategyToString(strategies) << " from "
2087 << ops_[incoming_op_index]->name() << " with strategy " << strategy;
2088 }
2089 }
2090 *no_stra_op_list_ = no_stra_op_list_bis;
2091
2092 return changes;
2093 }
2094
2095 size_t RecStrategyPropagator::GenerateEliminatedOperatorStrategyBackward(size_t min_devices) {
2096 MS_LOG(INFO) << "There are " << no_stra_op_list_->size() << " operators left that do not have strategy.";
2097 size_t changes = 0;
2098 if (no_stra_op_list_->empty()) {
2099 return changes;
2100 }
2101
2102 std::vector<size_t> no_stra_op_list_bis;
2103 for (size_t iter_list = no_stra_op_list_->size(); iter_list > 0; iter_list--) {
2104 auto iter_ops = no_stra_op_list_->at(iter_list - 1);
2105 Strategies strategies;
2106 std::pair<size_t, size_t> idx = FindIndexOfOperatorOutgoing(ops_, input_tensor_names_, iter_ops);
2107 size_t outgoing_op_index = idx.first;
2108 size_t iter_op_inputs = idx.second;
2109 Dimensions strategy =
2110 CopyOutgoingOperatorInputStrategy(ops_, iter_ops, outgoing_op_index, iter_op_inputs, graph_->dyn_shape_tmp_fix);
2111 if (IsDimensionsEmpty(strategy) || DevicesForDimensions(strategy) < min_devices ||
2112 StopPropAtOP(ops_[outgoing_op_index]->type())) {
2113 no_stra_op_list_bis.push_back(iter_ops);
2114 } else {
2115 strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2116 ApplyStrategy(iter_ops, strategies);
2117 ++changes;
2118 MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategies " << StrategyToString(strategies) << " from "
2119 << ops_[outgoing_op_index]->name() << " with strategy " << strategy;
2120 }
2121 }
2122 *no_stra_op_list_ = no_stra_op_list_bis;
2123
2124 return changes;
2125 }
2126
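// Operators that still have no strategy are handled by re-running the forward and backward
// passes until the pending list stops shrinking; the leftovers then receive an all-1
// (fully replicated) default strategy.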
2127 size_t RecStrategyPropagator::GenerateRemainingOperatorStrategy() {
2128 size_t changes = 0;
2129
2130 if (no_stra_op_list_->empty()) {
2131 return changes;
2132 }
2133
2134 size_t no_stra_op_list_size = no_stra_op_list_->size();
2135 do {
2136 no_stra_op_list_size = no_stra_op_list_->size();
2137 changes += GenerateEliminatedOperatorStrategyForward();
2138 changes += GenerateEliminatedOperatorStrategyBackward();
2139 } while (no_stra_op_list_size > no_stra_op_list_->size());
2140
2141 for (size_t iter_list = 0; iter_list < no_stra_op_list_->size(); iter_list++) {
2142 auto iter_ops = no_stra_op_list_->at(iter_list);
2143 Dimensions strategy = GetDefaultStrategy(iter_ops);
2144 if (graph_->dyn_shape_tmp_fix && strategy.empty()) {
2145 continue;
2146 }
2147 Strategies strategies = GenerateStrategiesFromStrategy(ops_, iter_ops, strategy, graph_->dyn_shape_tmp_fix);
2148 ApplyStrategy(iter_ops, strategies);
2149 ++changes;
2150 MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned default strategies " << StrategyToString(strategies)
2151 << " with strategy " << strategy;
2152 }
2153
2154 return changes;
2155 }
2156
2157 // Maps each parameter name to its users, stored as (operator index, input index) pairs.
2158 std::map<std::string, std::vector<std::pair<size_t, size_t>>> RecStrategyPropagator::GetParamUsers() {
2159 std::map<std::string, std::vector<std::pair<size_t, size_t>>> param_users;
2160
2161 AnfNodePtr ret = root_->get_return();
2162 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);
2163
2164 for (auto &node : all_nodes) {
2165 if (node->isa<Parameter>()) {
2166 ParameterUsersInfo parameter_users_info = FindParameterUsers(node, IsParallelCareNode, all_nodes);
2167 auto users_set = parameter_users_info.second.second;
2168 if (users_set.size() >= 1) {
2169 MS_LOG(INFO) << "Parameter " << parameter_users_info.first << " has " << users_set.size() << " users.";
2170 for (auto &user : users_set) {
2171 MS_LOG(INFO) << "with ID: " << user.first->UniqueId() << " and name: " << user.first->UniqueName();
2172
2173 std::pair<size_t, size_t> user_index = std::make_pair(SIZE_MAX, SIZE_MAX);
2174 for (size_t i = 0; i < input_tensor_names_.size(); i++) {
2175 if (input_tensor_names_[i][0] == user.first->UniqueId()) {
2176 size_t input_index = 0;
2177 if ((ops_[i]->type() == MATMUL) || (ops_[i]->type() == BATCH_MATMUL)) {
2178 input_index = 1;
2179 }
2180 user_index = std::make_pair(i, input_index);
2181 }
2182 }
2183 if (user_index.first != SIZE_MAX) {
2184 param_users[parameter_users_info.first].push_back(user_index);
2185 }
2186 }
2187 }
2188 }
2189 }
2190
2191 return param_users;
2192 }
2193
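// For every shared parameter, pick among its users that already have a strategy the one
// cutting the parameter across the most devices (ties broken by the smallest remaining
// shape/strategy ratio) and record it in param_strategy_, so ApplyParamStrategy can later
// align all users to it.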
2194 void RecStrategyPropagator::SetParamStrategy() {
2195 std::map<std::string, std::vector<std::pair<size_t, size_t>>> params_users = GetParamUsers(); // perhaps store this ?
2196 for (auto ¶m : params_users) {
2197 MS_LOG(INFO) << "Treat parameter " << param.first << " with " << param.second.size() << " users";
2198 if (param_strategy_.find(param.first) == param_strategy_.end() && !param.second.empty()) {
2199 Dimensions strategy;
2200 Dimensions max_strat;
2201 int max_stra_cut_num = 1;
2202 int max_stra_cut_ratio = INT_MAX;
2203
2204 for (auto &user : param.second) {
2205 MS_LOG(INFO) << "user is " << ops_[user.first]->name() << ", param goes to input " << user.second;
2206 if (!HasStrategy(ops_[user.first])) {
2207 continue;
2208 }
2209 strategy = ops_[user.first]->selected_strategy()->GetInputDim()[user.second];
2210 if (strategy.empty()) {
2211 MS_LOG(INFO) << "user has no strategy";
2212 continue;
2213 }
2214 MS_LOG(INFO) << "This user wants strategy " << strategy;
2215
2216 auto param_shape = ops_[user.first]->inputs_shape()[user.second];
2217 auto ratio = 0;
2218 for (size_t idx = 0; idx < strategy.size(); idx++) {
2219 MS_EXCEPTION_IF_ZERO("strategy", strategy[idx]);
2220 ratio += param_shape[idx] / strategy[idx];
2221 }
2222
2223 int cut_num = DevicesForDimensions(strategy);
2224 if (cut_num >= max_stra_cut_num && ratio < max_stra_cut_ratio) {
2225 max_stra_cut_num = cut_num;
2226 max_stra_cut_ratio = ratio;
2227 max_strat = strategy;
2228 }
2229 }
2230 if (!max_strat.empty()) {
2231 param_strategy_[param.first] = max_strat;
2232 }
2233 }
2234 }
2235 MS_LOG(INFO) << "Done";
2236 }
2237
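// Note: builds a Gather strategy from a shared parameter's strategy. If the gather axis of
// the parameter is not cut, the remaining stage devices are placed on the first dimension
// of the indices; otherwise the indices are not cut at all. Illustrative example (values
// assumed): 8 devices, param_strategy [1, 2] and gather axis 0 give index strategy [4, 1].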
2238 Strategies MakeGatherStratFromParam(const std::shared_ptr<OperatorInfo> &op, Dimensions param_strategy) {
2239 Strategies strategies;
2240 Dimensions index_strategy;
2241 int64_t axis = GetGatherAxis(op);
2242 if (param_strategy.at(LongToSize(axis)) == 1) {
2243 size_t num_device_used = 1;
2244 for (size_t i = 0; i < param_strategy.size(); i++) {
2245 num_device_used *= param_strategy[i];
2246 }
2247 MS_EXCEPTION_IF_ZERO("num_device_used", num_device_used);
2248 index_strategy.push_back(g_device_manager->stage_device_num() / num_device_used);
2249 } else {
2250 index_strategy.push_back(1);
2251 }
2252
2253 for (size_t i = 1; i < op->inputs_shape()[1].size(); ++i) {
2254 index_strategy.push_back(1);
2255 }
2256
2257 strategies.push_back(param_strategy);
2258 strategies.push_back(index_strategy);
2259
2260 MS_LOG(INFO) << "Gather is assigned strategy " << StrategyToString(strategies);
2261
2262 return strategies;
2263 }
2264
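// Note: aligns a MatMul with the strategy chosen for its weight parameter (input 1). The
// parameter strategy is reordered when transpose_b is set, k_cuts is the number of cuts on
// the reduction dimension, and input 0 is re-cut so its reduction dimension matches k_cuts
// while staying within the stage device number.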
2265 Strategies MakeMatMulStratFromParam(const std::shared_ptr<OperatorInfo> &op, Dimensions param_strategy) {
2266 Strategies new_strategy;
2267 Dimensions new_param_strat;
2268 Dimensions input0_strat = op->selected_strategy()->GetInputDim()[0];
2269 int64_t k_cuts = 1;
2270
2271 auto input_value = op->input_value();
2272 bool transpose_a = input_value[2]->cast<BoolImmPtr>()->value();
2273 bool transpose_b = input_value[3]->cast<BoolImmPtr>()->value();
2274
2275 k_cuts = param_strategy[0];
2276 if (transpose_b) {
2277 new_param_strat.push_back(param_strategy[1]);
2278 new_param_strat.push_back(param_strategy[0]);
2279 } else {
2280 new_param_strat.push_back(param_strategy[0]);
2281 new_param_strat.push_back(param_strategy[1]);
2282 }
2283
2284 if (transpose_a) {
2285 input0_strat[0] = k_cuts;
2286 input0_strat[1] = std::min(input0_strat[1], g_device_manager->stage_device_num() / k_cuts);
2287 } else {
2288 input0_strat[1] = k_cuts;
2289 input0_strat[0] = std::min(input0_strat[0], g_device_manager->stage_device_num() / k_cuts);
2290 }
2291
2292 new_strategy.push_back(input0_strat);
2293 new_strategy.push_back(new_param_strat);
2294
2295 MS_LOG(INFO) << "Transpose B : " << transpose_b << "; Transpose A : " << transpose_a << "; K cuts : " << k_cuts;
2296
2297 MS_LOG(INFO) << "MatMul is assigned strategy " << StrategyToString(new_strategy);
2298
2299 return new_strategy;
2300 }
2301
2302 size_t RecStrategyPropagator::ApplyParamStrategy() {
2303 size_t changes = 0;
2304 std::map<std::string, std::vector<std::pair<size_t, size_t>>> params_users = GetParamUsers();
2305
2306 for (auto ¶m : params_users) {
2307 if (param_strategy_.find(param.first) != param_strategy_.end()) {
2308 for (auto &user : param.second) {
2309 if (graph_->dyn_shape_tmp_fix && ops_[user.first]->type() == GATHERV2) {
2310 if (param.first.find(".output.ffn.projection.weight") != std::string::npos) {
2311 ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 1));
2312 continue;
2313 }
2314 if (param.first.find(".output.ffn.mapping.bias") != std::string::npos) {
2315 ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 3));
2316 continue;
2317 }
2318 if (param.first.find(".output.ffn.mapping.weight") != std::string::npos) {
2319 ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 2));
2320 continue;
2321 }
2322 // This Gather uses shared parameter, but it is not treated as using shared parameter.
2323 // Temporary workaround until this issue is fixed.
2324 if (param.first.find(".embedding.word_embedding.embedding_table") != std::string::npos) {
2325 ApplyStrategy(user.first, GatherForDynamicShape(ops_[user.first], 0));
2326 continue;
2327 }
2328 }
2329
2330 if (!HasStrategy(ops_[user.first]) ||
2331 param_strategy_[param.first] != ops_[user.first]->selected_strategy()->GetInputDim()[user.second]) {
2332 Strategies strategies;
2333 if (ops_[user.first]->type() == GATHERV2) {
2334 strategies = MakeGatherStratFromParam(ops_[user.first], param_strategy_[param.first]);
2335 } else if (ops_[user.first]->type() == MATMUL) {
2336 strategies = MakeMatMulStratFromParam(ops_[user.first], param_strategy_[param.first]);
2337 } else if (ops_[user.first]->type() == STRIDED_SLICE) {
2338 strategies = CheckDivisible(ops_[user.first], param_strategy_[param.first]);
2339 } else {
2340 strategies =
2341 GenerateStrategiesFromStrategy(ops_, user.first, param_strategy_[param.first], graph_->dyn_shape_tmp_fix);
2342 }
2343 ApplyStrategy(user.first, strategies);
2344 MS_LOG(INFO) << ops_[user.first]->name() << " assigned strategy " << StrategyToString(strategies)
2345 << " from parameter " << param.first;
2346 ++changes;
2347 }
2348 }
2349 }
2350 }
2351 return changes;
2352 }
2353
2354 size_t RecStrategyPropagator::ModifyParamSharingOpsStrategy() {
2355 size_t changes = 0;
2356
2357 for (auto tensor : shared_tensors_ops_) {
2358 for (auto op_i : tensor) {
2359 for (auto op_j : tensor) {
2360 if (op_i != op_j) {
2361 MS_LOG(INFO) << "Operator " << ops_[op_i]->name() << " sharing parameter with operator "
2362 << ops_[op_j]->name();
2363 }
2364 }
2365 }
2366 }
2367
2368 for (auto tensor : shared_tensors_ops_) {
2369 for (auto op_i : tensor) {
2370 if (ops_[op_i]->type() == GATHERV2) {
2371 for (auto op_j : tensor) {
2372 if (op_i != op_j) {
2373 Dimensions str_j;
2374 if (ops_[op_j]->type() == CAST) {
2375 str_j = ops_[op_j]->selected_strategy()->GetInputDim()[0];
2376 } else if (ops_[op_j]->type() == MATMUL) {
2377 str_j = ops_[op_j]->selected_strategy()->GetInputDim()[1];
2378 } else if (ops_[op_j]->type() == MUL) {
2379 str_j = ops_[op_j]->selected_strategy()->GetInputDim()[0];
2380 } else {
2381 continue;
2382 }
2383
2384 Strategies strategies;
2385 Dimensions param_strategy, index_strategy;
2386
2387 param_strategy = str_j;
2388
2389 size_t num_device_used = 1;
2390 for (size_t i = 0; i < str_j.size(); i++) {
2391 num_device_used *= LongToSize(str_j[i]);
2392 }
2393 MS_EXCEPTION_IF_ZERO("num_device_used", num_device_used);
2394 index_strategy.push_back(g_device_manager->stage_device_num() / num_device_used);
2395
2396 for (size_t i = 1; i < ops_[op_i]->inputs_shape()[1].size(); ++i) {
2397 index_strategy.push_back(1);
2398 }
2399
2400 strategies.push_back(param_strategy);
2401 strategies.push_back(index_strategy);
2402
2403 MS_LOG(INFO) << "Changing strategy of " << ops_[op_i]->name() << " with " << ops_[op_j]->name();
2404 MS_LOG(INFO) << ops_[op_i]->name() << " assigned strategy " << StrategyToString(strategies)
2405 << " from ModifyParamSharingOpsStrategy";
2406
2407 ApplyStrategy(op_i, strategies);
2408 ++changes;
2409 }
2410 }
2411 }
2412 }
2413 }
2414
2415 return changes;
2416 }
2417
2418 RecStrategyPropagator::RecStrategyPropagator(const std::shared_ptr<Graph> &graph,
2419 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
2420 const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
2421 const std::vector<std::vector<std::string>> &input_tensor_names,
2422 const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training,
2423 const std::vector<std::vector<size_t>> &shared_tensors_ops,
2424 const FuncGraphPtr &root)
2425 : graph_(graph),
2426 ops_(ops),
2427 eli_list_(eli_list),
2428 input_tensor_names_(input_tensor_names),
2429 index_list_(index_list),
2430 is_training_(is_training),
2431 shared_tensors_ops_(shared_tensors_ops),
2432 root_(root) {}
2433
2434 size_t RecStrategyPropagator::CopyMainOperatorsStrategy() {
2435 size_t changes = 0;
2436
2437 for (size_t i_op = 0; i_op < static_cast<size_t>(index_list_->size()); i_op++) {
2438 Strategies strategies;
2439 size_t iter_graph = index_list_->at(i_op);
2440 if (iter_graph != SIZE_MAX && ops_[i_op]->type() != GET_NEXT) {
2441 strategies = PrepareStrategy(&graph_->nodes[iter_graph], ops_, i_op, graph_->dyn_shape_tmp_fix);
2442 }
2443 if (!strategies.empty()) {
2444 source_ops_.push_back(i_op);
2445 ++changes;
2446 }
2447 StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2448 ops_[i_op]->SetSelectedStrategyAndCost(sp, ops_[i_op]->selected_cost());
2449 }
2450
2451 return changes;
2452 }
2453
2454 Dimensions GetInputStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
2455 const std::shared_ptr<std::vector<size_t>> &index_list, size_t i_op,
2456 size_t incoming_op_index) {
2457 Dimensions strategy;
2458 if (incoming_op_index != SIZE_MAX) {
2459 auto iter_graph = index_list->at(incoming_op_index);
2460 if (iter_graph != SIZE_MAX) {
2461 strategy = CopyIncomingOperatorOutputStrategy(&graph->nodes[iter_graph], ops, i_op, incoming_op_index);
2462 } else {
2463 strategy = CopyIncomingOperatorInputStrategy(ops, i_op, incoming_op_index);
2464 }
2465 }
2466
2467 return strategy;
2468 }
2469
2470 size_t RecStrategyPropagator::PropagateFromInputs() { return 0; }
2471
2472 size_t RecStrategyPropagator::PropagateFromOutputs() { return 0; }
2473
2474 void RecStrategyPropagator::GenerateNoStraList() {
2475 no_stra_op_list_ = std::make_shared<std::vector<size_t>>();
2476 for (size_t i = 0; i < eli_list_->size(); i++) {
2477 no_stra_op_list_->push_back(eli_list_->at(i)[0]);
2478 }
2479 }
2480
2481 void RecStrategyPropagator::FixInvalidStra() {
2482 for (auto &op : ops_) {
2483 bool modified = false;
2484 if (!HasStrategy(op)) {
2485 continue;
2486 }
2487 if (op->type() == FILLV2) {
2488 continue;
2489 }
2490 if (graph_->dyn_shape_tmp_fix && (op->type() == ASSIGN || op->type() == ONEHOT)) {
2491 continue;
2492 }
2493 StrategyPtr old_strategys = op->selected_strategy();
2494 Strategies new_strategys;
2495 for (size_t iter_op_inputs = 0; iter_op_inputs < old_strategys->GetInputDim().size(); iter_op_inputs++) {
2496 Dimensions strategies;
2497 for (size_t iter_op_input_stra = 0; iter_op_input_stra < op->inputs_shape()[iter_op_inputs].size();
2498 iter_op_input_stra++) {
2499 if (graph_->dyn_shape_tmp_fix && op->inputs_shape()[iter_op_inputs][iter_op_input_stra] == -1) {
2500 strategies.push_back(old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra]);
2501 continue;
2502 }
2503 if (op->inputs_shape()[iter_op_inputs][iter_op_input_stra] <
2504 old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra] ||
2505 op->inputs_shape()[iter_op_inputs][iter_op_input_stra] %
2506 old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra] !=
2507 0) {
2508 strategies.push_back(1);
2509 modified = true;
2510 } else {
2511 strategies.push_back(old_strategys->GetInputDim()[iter_op_inputs][iter_op_input_stra]);
2512 }
2513 }
2514 new_strategys.push_back(strategies);
2515 }
2516 if (modified) {
2517 StrategyPtr sp = std::make_shared<Strategy>(0, new_strategys);
2518 op->SetSelectedStrategyAndCost(sp, op->selected_cost());
2519 MS_LOG(INFO) << "CHANGE INVALID STRATEGY FOR : " << op->name() << " from " << old_strategys->GetInputDim()
2520 << " to " << StrategyToString(new_strategys);
2521 }
2522 }
2523 }
2524
2525 void RecStrategyPropagator::AjustToNoTraining() {
2526 for (auto &op : ops_) {
2527 // Set back to raw strategy for special node in predict/eval
2528 if (!is_training_) {
2529 if ((op->is_last_node()) || (op->type() == VIRTUAL_DATA_SET)) {
2530 SetBackToRawStrategy(op);
2531 }
2532 }
2533 }
2534 }
2535
2536 void RecStrategyPropagator::GenerateStrategyV1() {
2537 MS_EXCEPTION_IF_NULL(graph_);
2538 MS_EXCEPTION_IF_NULL(eli_list_);
2539 MS_EXCEPTION_IF_NULL(index_list_);
2540
2541 no_stra_op_list_ = std::make_shared<std::vector<size_t>>();
2542 for (size_t i = eli_list_->size(); i > 0; i--) {
2543 no_stra_op_list_->push_back(eli_list_->at(i - 1)[0]);
2544 }
2545
2546 size_t changes;
2547 changes = CopyMainOperatorsStrategy();
2548 MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after CopyMainOperatorsStrategy.";
2549
2550 changes = GenerateEliminatedOperatorStrategyForward();
2551 MS_LOG(INFO) << "The strategies of " << changes
2552 << " operators are modified after GenerateEliminatedOperatorStrategyForward.";
2553
2554 changes = GenerateEliminatedOperatorStrategyBackward();
2555 MS_LOG(INFO) << "The strategies of " << changes
2556 << " operators are modified after GenerateEliminatedOperatorStrategyBackward.";
2557
2558 changes = GenerateRemainingOperatorStrategy();
2559 MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after GenerateRemainingOperatorStrategy.";
2560
2561 if (graph_->dyn_shape_tmp_fix) {
2562 for (auto &op : ops_) {
2563 if (op->type() == ASSIGN) {
2564 Strategies strategies;
2565 auto assign_input_0_shape = op->inputs_shape()[0];
2566 Dimensions assign_input_0_strategy(assign_input_0_shape.size(), 1);
2567 size_t num_device = LongToSize(g_device_manager->stage_device_num());
2568 if (assign_input_0_shape[1] > 0 && assign_input_0_shape[1] % num_device == 0) {
2569 assign_input_0_strategy[1] = num_device;
2570 }
2571 for (size_t i = 0; i < op->inputs_shape().size(); i++) {
2572 strategies.push_back(assign_input_0_strategy);
2573 }
2574 StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
2575 op->SetSelectedStrategyAndCost(sp, op->selected_cost());
2576 }
2577 }
2578 }
2579
2580 SetParamStrategy();
2581 changes = ApplyParamStrategy();
2582 MS_LOG(INFO) << "The strategies of " << changes << " operators are modified after ApplyParamStrategy.";
2583
2584 FixInvalidStra();
2585 AjustToNoTraining();
2586 }
2587
2588 size_t RecStrategyPropagator::AssignStandaloneAndBatchParallelOpStrategy() {
2589 size_t changes = 0;
2590 for (size_t iter_ops = 0; iter_ops < ops_.size(); iter_ops++) {
2591 auto pos = ops_[iter_ops]->name().find("Info");
2592 auto name = ops_[iter_ops]->name().substr(0, pos);
2593 if (name == STAND_ALONE) {
2594 Strategies strategies = PrepareStandAlone(ops_[iter_ops]);
2595 ApplyStrategy(iter_ops, strategies);
2596 changes++;
2597 MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategy " << StrategyToString(strategies);
2598 auto iter = find(no_stra_op_list_->begin(), no_stra_op_list_->end(), iter_ops);
2599 if (iter != no_stra_op_list_->end()) {
2600 no_stra_op_list_->erase(iter);
2601 }
2602 }
2603 if (name == BATCH_PARALLEL) {
2604 Strategies strategies;
2605 auto split_flag_list = ops_[iter_ops]->split_flag_list();
2606 auto inputs_shape = ops_[iter_ops]->inputs_shape();
2607 for (size_t i = 0; i < inputs_shape.size(); i++) {
2608 Shape temp(inputs_shape[i].size(), 1);
2609 if (split_flag_list[i]) {
2610 temp[0] = g_device_manager->stage_device_num();
2611 }
2612 strategies.push_back(temp);
2613 }
2614 ApplyStrategy(iter_ops, strategies);
2615 changes++;
2616 MS_LOG(INFO) << ops_[iter_ops]->name() << " assigned strategy " << StrategyToString(strategies);
2617 auto iter = find(no_stra_op_list_->begin(), no_stra_op_list_->end(), iter_ops);
2618 if (iter != no_stra_op_list_->end()) {
2619 no_stra_op_list_->erase(iter);
2620 }
2621 }
2622 }
2623 return changes;
2624 }
2625
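// Note: in the recursive-programming graph the strategies store the reciprocal of the cut
// number (str_h == 1/cuts along that dimension). This helper estimates how many additional
// cuts the MatMul output could still take on its batch/height dimension. Illustrative
// example (values assumed): 8 devices, input cut 1/2 x 1/2, output cut 1/2 x 1/1 gives
// min(4, 8) / 2 = 2.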
2626 static size_t CalMatmulBatchDimFactor(size_t num_device, const StrategyRec &str) {
2627 size_t max_shard_num = FloatToSize(1 / str.inputTensor[0].str_h) * FloatToSize(1 / str.inputTensor[0].str_w);
2628 max_shard_num = max_shard_num < num_device ? max_shard_num : num_device;
2629 return max_shard_num / (FloatToSize(1 / str.outputTensor.str_h) * FloatToSize(1 / str.outputTensor.str_w));
2630 }
2631
2632 void RecStrategyPropagator::ExtraShardMatmulOnBatchDim() {
2633 MS_EXCEPTION_IF_NULL(graph_);
2634 MS_EXCEPTION_IF_NULL(eli_list_);
2635 MS_EXCEPTION_IF_NULL(index_list_);
2636
2637 for (size_t i_op = 0; i_op < static_cast<size_t>(index_list_->size()); i_op++) {
2638 size_t iter_graph = index_list_->at(i_op);
2639 if (iter_graph == SIZE_MAX || ops_[i_op]->type() != MATMUL) {
2640 continue;
2641 }
2642 Graph::NodeType &node = graph_->nodes[iter_graph];
2643 size_t matmulBatchDimFactor = CalMatmulBatchDimFactor(g_device_manager->stage_device_num(), node.apply.str);
2644 if (matmulBatchDimFactor > 1) {
2645 MS_LOG(INFO) << ops_[i_op]->name() << " matmulBatchDimFactor " << matmulBatchDimFactor;
2646 node.apply.str.outputTensor.str_h /= matmulBatchDimFactor;
2647 node.tensor_parm.tensor_str.str_h = node.apply.str.outputTensor.str_h;
2648
2649 Strategies strategies;
2650 Dimensions strategy;
2651 strategy.push_back(static_cast<int64_t>(1.0 / node.apply.str.outputTensor.str_h));
2652 strategy.push_back(static_cast<int64_t>(1.0 / node.apply.str.outputTensor.str_w));
2653 strategies.push_back(strategy);
2654
2655 int64_t stage_id = g_device_manager->stage_id();
2656 StrategyPtr strategyPtr = NewStrategy(stage_id, strategies);
2657 ops_[i_op]->set_out_strategy(strategyPtr);
2658 }
2659 }
2660 }
2661
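// V3 propagation: copy the strategies chosen by the recursive partitioner, assign
// stand-alone and batch-parallel ops, then repeat forward/backward propagation while
// progressively halving the minimum device count an inherited strategy must use, and
// finally fill the remainder, align parameter-sharing ops and sanitize invalid strategies.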
2662 void RecStrategyPropagator::GenerateStrategyV3() {
2663 MS_EXCEPTION_IF_NULL(graph_);
2664 MS_EXCEPTION_IF_NULL(eli_list_);
2665 MS_EXCEPTION_IF_NULL(index_list_);
2666
2667 GenerateNoStraList();
2668 size_t changes;
2669 changes = CopyMainOperatorsStrategy();
2670 MS_LOG(INFO) << "CopyMainOperatorsStrategy has " << changes << " changes";
2671 AssignStandaloneAndBatchParallelOpStrategy();
2672
2673 for (auto min_devices = g_device_manager->stage_device_num(); min_devices > 1; min_devices /= SIZE_TWO) {
2674 size_t pass_changes = 1;
2675 while (pass_changes > 0) {
2676 pass_changes = 0;
2677
2678 changes = GenerateEliminatedOperatorStrategyForward(min_devices);
2679 MS_LOG(INFO) << "GenerateEliminatedOperatorStrategyForward has " << changes << " changes";
2680
2681 pass_changes += changes;
2682 if (changes > 0) continue;
2683
2684 changes = GenerateEliminatedOperatorStrategyBackward(min_devices);
2685 MS_LOG(INFO) << "GenerateEliminatedOperatorStrategyBackward has " << changes << " changes";
2686
2687 pass_changes += changes;
2688 if (changes > 0) continue;
2689 }
2690 }
2691
2692 changes = GenerateRemainingOperatorStrategy();
2693 MS_LOG(INFO) << "GenerateRemainingOperatorStrategy has " << changes << " changes";
2694
2695 changes = ModifyParamSharingOpsStrategy();
2696 MS_LOG(INFO) << "ModifyParamSharingOpsStrategy has " << changes << " changes";
2697
2698 FixInvalidStra();
2699 AjustToNoTraining();
2700 }
2701 } // namespace parallel
2702 } // namespace mindspore
2703