1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22
23 #include "ir/value.h"
24 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
25 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
26 #include "frontend/parallel/ops_info/operator_info.h"
27 #include "frontend/parallel/strategy.h"
28 #include "frontend/parallel/step_parallel.h"
29
30 namespace mindspore {
31 namespace parallel {
// Entry point of strategy generation for the recursive-programming auto-parallel
// algorithm. Assigns a parallel strategy to every operator in `ops`:
//  1. operators partitioned by the rec search (mapped through `index_list`),
//  2. eliminated operators, filled forward then backward along tensor-name links,
//  3. any operators still without a strategy,
// then overrides with user-defined strategies and, in predict/eval mode, resets
// special nodes back to the raw (all-ones) strategy.
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                      const std::vector<std::vector<std::string>> &input_tensor_names,
                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(eli_list);
  MS_EXCEPTION_IF_NULL(index_list);
  GeneratePartitionedOperatorStrategy(graph, ops, index_list);

  // Each eli_list entry's first element is the index of an eliminated operator
  // that still needs a strategy.
  std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
  for (size_t i = 0; i < eli_list->size(); i++) {
    no_stra_op_list->push_back(eli_list->at(i)[0]);
  }
  GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
  GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
  GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);

  for (auto &op : ops) {
    // A user-defined strategy attribute always wins over the generated one.
    auto attrs = op->attrs();
    if (StrategyFound(attrs)) {
      StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[STRATEGY]);
      op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
    }
    // In predict/eval, the last node and the virtual dataset keep the raw
    // (unpartitioned) strategy.
    if (!is_training) {
      if ((op->is_last_node()) || (op->type() == VIRTUAL_DATA_SET)) {
        SetBackToRawStrategy(op);
      }
    }
  }
}
64
// Builds the sharding strategy for a MatMul operator from the rec-graph node's
// tensor cut ratios (`tensor_str`), honoring transpose_a/transpose_b.
// Strategy values are derived as 1/str, i.e. the number of cuts per dimension.
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                        const size_t iter_graph, const size_t iter_ops) {
  Strategys strategies;
  auto attrs = ops[iter_ops]->attrs();
  bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
  bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();

  // HCCL does not support multi-dimension partition, and the hardware does not support excessive
  // number of EVENT, so we temporarily disable matmul's multi-dimension partition function.
  const float max_cut = 1.0 / SizeToFloat(g_device_manager->DeviceNum());
  // The rule of cut is 0.5, 0.125. To compare the result we have to use ">" so we multiply max_cut to 1.1
  if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h > max_cut * 1.1 &&
      graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w > max_cut * 1.1) {
    // Both the row dim of input A and the column dim of input B are cut less
    // than a full device split: reset everything to un-cut and redo the cut
    // on exactly one dimension below.
    graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0;
    graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0;
    graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0;
    graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0;
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;

    // shape_1: the output-row dimension of A; shape_4: the output-column
    // dimension of B (swapped when the corresponding transpose flag is set).
    auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
    if (transpose_a) {
      shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1];
    }
    auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1];
    if (transpose_b) {
      shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0];
    }

    // Cut the larger of the two output dimensions first; fall back to the
    // other if the larger one is not divisible by the device number.
    bool already_cut = false;
    if (shape_1 >= shape_4) {
      if (LongToSize(shape_1) % g_device_manager->DeviceNum() == 0) {
        graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut;
        graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut;
        already_cut = true;
      }
      if (!already_cut && LongToSize(shape_4) % g_device_manager->DeviceNum() == 0) {
        graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut;
        graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut;
        already_cut = true;
      }
    } else {
      if (LongToSize(shape_4) % g_device_manager->DeviceNum() == 0) {
        graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut;
        graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut;
        already_cut = true;
      }
      if (!already_cut && LongToSize(shape_1) % g_device_manager->DeviceNum() == 0) {
        graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut;
        graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut;
        already_cut = true;
      }
    }

    if (!already_cut) {
      MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid.";
    }
  }

  // Convert the per-argument cut ratios into integer strategy dimensions.
  // For a transposed input the (h, w) ratios are emitted in swapped order so
  // the strategy matches the stored (un-transposed) tensor layout.
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    Dimensions s;
    if (transpose_a && (iter_op_inputs == 0)) {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
    } else if (transpose_b && (iter_op_inputs == 1)) {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
    } else {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    }
    strategies.push_back(s);
  }
  return strategies;
}
146
PrepareBiasAdd(const std::shared_ptr<Dimensions> & s)147 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
148 Strategys strategies;
149 strategies.push_back(*s);
150 Dimensions s_biasadd;
151 s_biasadd.push_back(s->at(1));
152 strategies.push_back(s_biasadd);
153 return strategies;
154 }
155
// Builds the strategy for StridedSlice from a candidate strategy `basic_stra`,
// forcing any dimension to 1 (no cut) when cutting it would be unsafe:
// a non-unit stride, or a slice that does not fetch the full dimension.
Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                              Dimensions basic_stra) {
  Strategys stra;

  // begin / end / strides are the constant inputs 1..3 of StridedSlice.
  auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
  auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
  auto strides = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(3));

  // A strided (stride != 1) dimension cannot be cut.
  for (size_t i = 0; i < strides.size(); ++i) {
    if ((strides[i] != 1) && (basic_stra[i] > 1)) {
      basic_stra[i] = 1;
    }
  }

  // A dimension that is only partially fetched cannot be cut either.
  for (size_t i = 0; i < begin.size(); ++i) {
    bool no_fully_fetch = ((begin[i] != 0) || (end[i] < ops[iter_ops]->inputs_tensor_info()[0].shape()[i]));
    if (no_fully_fetch && (basic_stra[i] != 1)) {
      basic_stra[i] = 1;
    }
  }

  stra.push_back(basic_stra);
  return stra;
}
180
// Builds the strategy for OneHot. Starts from the generic rec-search strategy,
// then, when axis == -1 (one-hot expansion on the last dimension), moves the
// cut from the last dimension onto the first, because the expanded dimension
// cannot be cut. The two scalar inputs (on/off values) get empty strategies.
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                        const size_t iter_graph, const size_t iter_ops) {
  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);

  // Read the axis attribute; it must be an int64 scalar when present.
  int64_t axis = -1;
  auto iter = ops[iter_ops]->attrs().find(AXIS);
  if (iter != ops[iter_ops]->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis = iter->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
    }
  }
  if (axis == -1) {
    // Shift the cut off the expanded last dimension, and mirror the change in
    // the rec-graph node so downstream propagation stays consistent.
    strategies[0][0] = strategies[0][1];
    strategies[0][1] = 1;
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
  }

  // on_value / off_value are scalars: empty strategies.
  Dimensions s_empty = {};
  strategies.push_back(s_empty);
  strategies.push_back(s_empty);
  return strategies;
}
207
PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)208 Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
209 Strategys strategies;
210
211 auto axis_input = GetValue<int64_t>(ops[iter_ops]->input_value().at(2));
212 if (axis_input < 0) {
213 axis_input += SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
214 }
215 int64_t axis = axis_input;
216 if (axis >= SizeToLong(s.size())) {
217 MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
218 }
219 s[LongToSize(axis)] = 1;
220 strategies.push_back(s);
221
222 return strategies;
223 }
224
// Builds the strategy for the parallel GatherV2 (GatherP) variant.
// First computes a device assignment over the OUTPUT shape: output dimensions
// are cut by factors of 2, largest dimensions first, until the device number
// is consumed; then maps that output strategy back onto the two inputs
// (params, indices) according to the gather axis (only 0 and 1 supported).
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
  Strategys strategies;

  // Order the non-leading output dimensions by decreasing size; dimension 0
  // is then prepended so it is always cut first.
  auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
  Dimensions index(output_shape.size() - 1, 0);
  for (size_t i = 0; i < index.size(); i++) {
    index[i] = SizeToLong(i);
  }
  std::sort(index.begin(), index.end(), [&output_shape](const int64_t &a, const int64_t &b) {
    return (output_shape[LongToSize(a + 1)] > output_shape[LongToSize(b + 1)]);
  });
  std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
  index.insert(index.begin(), 0);

  // Greedily cut each dimension by 2 while it stays even, until all devices
  // are used. `strategie` holds the resulting per-dimension cut counts.
  Dimensions strategie(output_shape.size(), 1);
  size_t num_device = g_device_manager->DeviceNum();
  size_t cut = 1;
  for (size_t i = 0; i < index.size(); i++) {
    size_t index_i = LongToSize(index[i]);
    while (output_shape[index_i] % 2 == 0 && output_shape[index_i] > 0 && cut < num_device) {
      output_shape[index_i] /= 2;
      cut *= 2;
      strategie[index_i] *= 2;
    }
    if (cut == num_device) {
      break;
    }
  }

  // Normalize the gather axis (input 2) by the rank of the params input.
  auto axis_input = GetValue<int64_t>(ops[iter_ops]->input_value().at(2));
  if (axis_input < 0) {
    axis_input += SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
  }
  int64_t axis = axis_input;
  if (axis >= SizeToLong(s.size())) {
    MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
  }
  if (axis == 0) {
    // axis 0: params dim 0 is not cut; its remaining dims take the trailing
    // output cuts (the output is indices-shape ++ params-shape[1:]).
    s.clear();
    s.push_back(1);
    for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
      s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]);
    }
    strategies.push_back(s);
    // The indices input takes the leading output cuts.
    s.clear();
    for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
      s.push_back(strategie[i]);
    }
    strategies.push_back(s);
  } else if (axis == 1) {
    // axis 1: params dim 0 takes the first output cut, dim 1 is not cut.
    s.clear();
    s.push_back(strategie[0]);
    s.push_back(1);
    strategies.push_back(s);
    // The indices input takes the output cuts after the params dims.
    s.clear();
    for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
      s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]);
    }
    strategies.push_back(s);
  } else {
    MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1.";
  }

  return strategies;
}
290
PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index)291 Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
292 const size_t incoming_op_index) {
293 auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();
294 Dimensions index(output_shape.size() - 1, 0);
295 for (size_t i = 0; i < index.size(); i++) {
296 index[i] = SizeToLong(i);
297 }
298 std::sort(index.begin(), index.end(),
299 [&output_shape](const size_t &a, const size_t &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
300 std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
301 index.insert(index.begin(), 0);
302
303 Dimensions strategie(output_shape.size(), 1);
304 size_t num_device = g_device_manager->DeviceNum();
305 size_t cut = 1;
306 for (size_t i = 0; i < index.size(); i++) {
307 size_t index_i = LongToSize(index[i]);
308 while (output_shape[index_i] % 2 == 0 && output_shape[index_i] > 0 && cut < num_device) {
309 output_shape[index_i] /= 2;
310 cut *= 2;
311 strategie[index_i] *= 2;
312 }
313 if (cut == num_device) {
314 break;
315 }
316 }
317
318 return strategie;
319 }
320
PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)321 Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
322 Dimensions s) {
323 int64_t axis = 0;
324 auto iter = ops[iter_ops]->attrs().find(AXIS);
325 if (iter != ops[iter_ops]->attrs().end()) {
326 MS_EXCEPTION_IF_NULL(iter->second);
327 if (iter->second->isa<ValueSequeue>()) {
328 axis = GetValue<std::vector<int64_t>>(iter->second)[0];
329 } else {
330 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int64_t.";
331 }
332 }
333
334 int64_t axis_index = axis;
335 if (axis < 0) {
336 size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
337 axis_index = static_cast<int64_t>(input_dim) + axis;
338 }
339
340 s[LongToSize(axis_index)] = 1;
341
342 Strategys strategies;
343 strategies.push_back(s);
344 return strategies;
345 }
346
// Builds the strategy for axis-sensitive operators (Softmax, LayerNorm).
// Starts from the generic rec-search strategy, then forces every reduced /
// normalized axis to 1 since those dimensions cannot be cut.
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                     const size_t iter_ops) {
  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
  if (strategies.size() < 1) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
  }

  // LayerNorm stores its axis under "begin_norm_axis" and defaults to 1;
  // Softmax uses the standard AXIS attribute and defaults to -1.
  std::vector<int64_t> axis_list;
  string axis_name = AXIS;
  int64_t default_axis = -1;
  if (ops[iter_ops]->type() == LAYER_NORM) {
    axis_name = "begin_norm_axis";
    default_axis = 1;
  }

  // The attribute may be a single int64 or a tuple of int64.
  auto iter = ops[iter_ops]->attrs().find(axis_name);
  if (iter != ops[iter_ops]->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
    } else if (iter->second->isa<ValueTuple>()) {
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value_tuple is nullptr.";
      }
      std::vector<ValuePtr> value_vector = value_tuple->value();
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
                           [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
    }
  } else {
    axis_list.push_back(default_axis);
  }

  // Normalize negative axes and force each affected dimension to 1 (no cut).
  for (auto &axis : axis_list) {
    if (axis < 0) {
      int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
      axis = input_dim + axis;
    }
    if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
    }
    if (strategies[0][LongToSize(axis)] != 1) {
      strategies[0][LongToSize(axis)] = 1;
      MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
    }
  }
  return strategies;
}
398
// Converts the rec-graph node's per-input cut ratios (`tensor_str`) into an
// integer strategy, one Dimensions entry per operator input. Supports ranks
// 0-4; the (n, c, h, w) ratios map to the trailing dimensions of the input.
// Unsorted-segment ops are delegated to the data-parallel fallback.
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
                                const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }
  if (graph->nodes[iter_graph].apply.op_type == kRecUnsortedSegmentOp) {
    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
  }

  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
  Strategys strategies;
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }

    // Rank of this input, taken from the operator's original strategy.
    size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
    Dimensions s;
    // Strategy values are the reciprocal of the cut ratio: 1/str cuts.
    if (output_size == 4) {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 3) {
      // Experimental support for 3D data.
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 2) {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 1) {
      s.push_back(
        static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
    } else if (output_size == 0) {
      // Scalar input: empty strategy.
      s = {};
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted.";
    }
    strategies.push_back(s);
  }
  return strategies;
}
455
// Builds a plain data-parallel strategy: cut only the first (batch) dimension
// of every input by min(device number, batch size), all other dimensions 1.
// Also writes the resulting output cut ratio back into the rec-graph node so
// neighbouring operators can propagate it.
Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
                                   const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                   const size_t iter_ops) {
  if (ops.empty()) {
    MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
  }
  if (iter_ops >= ops.size()) {
    MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
  }

  StrategyPtr origin_strategy = ops[iter_ops]->strategy();
  Strategys strategies;
  size_t max_device_num = g_device_manager->DeviceNum();
  // Batch size of the first input; assumes shape()[0] is non-negative — TODO confirm.
  size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
  for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
    if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
      MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
    }

    Dimensions s;
    size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
    for (size_t dim = 0; dim < input_size; dim++) {
      // Experimental support for 3D data (input_size == 3).
      if (input_size >= 1 && input_size <= 4) {
        if (dim == 0) {
          // Only the batch dimension is cut, and never beyond the batch size.
          s.push_back(std::min(max_device_num, target_tensor_batch));
        } else {
          s.push_back(1);
        }
      } else if (input_size == 0) {
        s = {};
      } else {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
      }
    }
    strategies.push_back(s);
  }
  // Set default strategy.
  graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
  graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
  graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
  graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;

  // Update data parallel strategy: record the batch cut on the field that
  // corresponds to the output's first dimension for its rank.
  if (ops[iter_ops]->outputs_tensor_info().size() == 0) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
  }
  if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
    // Experimental support for 3D data.
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
    graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
  } else {
    MS_LOG(INFO) << ops[iter_ops]->name() << " output tensor shape is unexpected, using default value instead.";
  }

  return strategies;
}
518
MakeFullBatchStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)519 Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
520 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
521 const size_t iter_ops) {
522 if (ops.empty()) {
523 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
524 }
525 if (iter_ops >= ops.size()) {
526 MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
527 }
528
529 StrategyPtr origin_strategy = ops[iter_ops]->strategy();
530 Strategys strategies;
531 for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
532 if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
533 MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
534 }
535 Dimensions s;
536 size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
537 for (size_t dim = 0; dim < input_size; dim++) {
538 if (input_size >= 1 && input_size <= 4) {
539 s.push_back(1);
540 } else if (input_size == 0) {
541 s = {};
542 } else {
543 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
544 }
545 }
546 strategies.push_back(s);
547 }
548 // Update the output strategy of Rec Graph
549 graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
550 graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
551 graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
552 graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
553
554 return strategies;
555 }
556
SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> & op)557 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
558 StrategyPtr origin_strategy = op->strategy();
559 Strategys strategies;
560
561 for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
562 Dimensions s;
563 size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size();
564 for (size_t dim = 0; dim < strategy_size; dim++) {
565 if (strategy_size >= 1 && strategy_size <= 4) {
566 s.push_back(1);
567 } else if (strategy_size == 0) {
568 s = {};
569 } else {
570 MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
571 }
572 }
573 strategies.push_back(s);
574 }
575
576 StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
577 op->SetSelectedStrategyAndCost(sp, op->selected_cost());
578 }
579
PrepareStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)580 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
581 const size_t iter_graph, const size_t iter_ops) {
582 if (ops.empty()) {
583 MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
584 }
585 if (iter_ops >= ops.size()) {
586 MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
587 }
588 MS_EXCEPTION_IF_NULL(ops[iter_ops]);
589
590 auto type = ops[iter_ops]->type();
591 if (type == MATMUL) {
592 return PrepareMatMul(graph, ops, iter_graph, iter_ops);
593 } else if (type == ONEHOT) {
594 return PrepareOneHot(graph, ops, iter_graph, iter_ops);
595 } else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
596 return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
597 } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "Dropout") || (type == BATCH_MATMUL)) {
598 return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
599 } else if (type == "_VirtualDataset") {
600 if (ParallelContext::GetInstance()->full_batch()) {
601 return MakeFullBatchStrategy(graph, ops, iter_graph, iter_ops);
602 } else {
603 return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
604 }
605 } else {
606 return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
607 }
608 }
609
GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const std::shared_ptr<std::vector<size_t>> & index_list)610 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
611 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
612 const std::shared_ptr<std::vector<size_t>> &index_list) {
613 for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
614 Strategys strategies;
615 size_t iter_graph = index_list->at(iter_ops);
616 if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
617 strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
618 }
619 StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
620 ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
621 }
622 }
623
// Finds the producer of the first resolvable input of operator `iter_ops`.
// input_tensor_names[k][0] is operator k's own output name; the remaining
// entries are its input tensor names. Scans the inputs in order and returns
// the index of the first operator whose output name matches; SIZE_MAX if no
// input matches any producer.
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const size_t iter_ops) {
  const auto &names_of_op = input_tensor_names[iter_ops];
  for (size_t in_idx = 1; in_idx < names_of_op.size(); in_idx++) {
    for (size_t producer = 0; producer < input_tensor_names.size(); producer++) {
      if (names_of_op[in_idx] == input_tensor_names[producer][0]) {
        return producer;
      }
    }
  }
  return SIZE_MAX;
}
640
// Returns the first output cut ratio of the rec-graph node that is smaller
// than 1, scanning n, c, h, w in order; 1.0 when the output is not cut.
float CheckVirtualDatasetStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_graph) {
  // The values for str can only be 1.0, 0.5, 0.25, 0.125…
  // We want to find out the first str that is smaller than 1
  // (comparing against 0.9 avoids exact float equality with 1.0).
  if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_n < 0.9) {
    return graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
  }
  if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_c < 0.9) {
    return graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
  }
  if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_h < 0.9) {
    return graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
  }
  if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_w < 0.9) {
    return graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
  }
  return 1.0;
}
658
// Derives a strategy for an operator fed by the virtual dataset: cut the
// first dimension by the dataset's cut count (1/str), leave the rest at 1.
// Returns an empty strategy for a scalar input.
Dimensions CopyVirtualDataset(const std::shared_ptr<Graph> &graph,
                              const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                              const size_t iter_graph) {
  Dimensions s;
  auto input_stra_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
  auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
  if (input_stra_dim == 0) {
    return s;
  } else {
    // NOTE(review): exact float == 0 comparison; CheckVirtualDatasetStrategy
    // appears to return only powers of 1/2 or 1.0, so this looks like a
    // defensive guard against division by zero — confirm it can occur.
    if (virtual_dataset_str == 0) {
      s.push_back(1);
    } else {
      s.push_back(FloatToLong(1 / virtual_dataset_str));
    }
    for (size_t i = 1; i < input_stra_dim; i++) {
      s.push_back(1);
    }
  }
  return s;
}
679
// Derives operator iter_ops' input strategy from the OUTPUT cut ratios of its
// incoming rec-graph node. Uses the rank of iter_ops' first non-scalar input
// to decide how many of the (n, c, h, w) ratios apply. Virtual-dataset
// producers are handled by CopyVirtualDataset instead.
Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                              const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                              const size_t iter_ops, const size_t iter_graph,
                                              const size_t incoming_op_index) {
  Dimensions s;

  if (ops[incoming_op_index]->type() == VIRTUAL_DATA_SET) {
    s = CopyVirtualDataset(graph, ops, iter_ops, iter_graph);
    return s;
  }

  // Only the first non-scalar input is used (note the unconditional break at
  // the bottom of the loop body).
  for (auto input : ops[iter_ops]->inputs_tensor_info()) {
    auto input_stra_dim = input.shape().size();
    if (input_stra_dim == 0) {
      continue;
    }
    // Strategy values are 1/str cuts; higher ranks consume more ratio fields.
    if (input_stra_dim == 1) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 2) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 3) {
      // Experimental support for 3D data.
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 4) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
    }
    break;
  }
  return s;
}
718
// Fetch the already-selected input strategy of the incoming operator so it can
// be propagated to the current operator. Returns an empty strategy for
// operator types whose strategies cannot be copied directly.
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const size_t incoming_op_index) {
  Dimensions s;
  // Reshape/Transpose rearrange dimensions, so their input strategy does not
  // map onto this operator's layout.
  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) {
    return s;
  }
  if (ops[incoming_op_index]->type() == GATHERV2) {
    // Distinguish GatherInfo from GatherPInfo by the operator-info name
    // prefix before "Info".
    auto pos = ops[incoming_op_index]->name().find("Info");
    if (pos == std::string::npos) {
      return s;
    }
    auto name = ops[incoming_op_index]->name().substr(0, pos);
    if (name == "Gather") {
      return s;
    } else if (name == "GatherP") {
      return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
    } else {
      MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
    }
  }
  auto strategy = ops[incoming_op_index]->selected_strategy();
  if (strategy->GetInputNumber() == 0) {
    // The incoming operator has no selected strategy to copy yet.
    return s;
  }

  // Copy the strategy of the incoming operator's first non-scalar input.
  for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) {
    if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) {
      continue;
    }
    for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) {
      s.push_back(strategy->GetInputDim()[i][j]);
    }
    break;
  }
  return s;
}
755
GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const int64_t iter_ops)756 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops) {
757 Dimensions axis_list;
758 auto axis_param = ops[LongToSize(iter_ops)]->attrs().find(AXIS)->second;
759 std::vector<ValuePtr> elements;
760 if (axis_param->isa<ValueTuple>()) {
761 elements = axis_param->cast<ValueTuplePtr>()->value();
762 } else if (axis_param->isa<ValueList>()) {
763 elements = axis_param->cast<ValueListPtr>()->value();
764 } else {
765 MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl;
766 }
767
768 for (auto &element : elements) {
769 if (!element->isa<Int64Imm>()) {
770 MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl;
771 }
772 auto axis = element->cast<Int64ImmPtr>()->value();
773 axis_list.push_back(axis);
774 }
775 return axis_list;
776 }
777
ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index,Dimensions s)778 Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
779 const size_t incoming_op_index, Dimensions s) {
780 Dimensions s_Squeeze;
781 Dimensions stra_dim_list;
782 for (size_t i = 0; i < s.size(); i++) {
783 stra_dim_list.push_back(SizeToLong(i));
784 }
785
786 auto axis_list = GetAxisList(ops, SizeToLong(incoming_op_index));
787 for (auto axis : axis_list) {
788 auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
789 if (it == stra_dim_list.end()) {
790 MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
791 }
792 if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[LongToSize(axis)] != 1) {
793 MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl;
794 }
795 stra_dim_list.erase(it);
796 }
797
798 for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) {
799 s_Squeeze.push_back(s[LongToSize(stra_dim_list[i])]);
800 }
801 return s_Squeeze;
802 }
803
GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)804 bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
805 bool keepdims = false;
806 auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS);
807 if (keep_dims_iter == ops[iter_ops]->attrs().end()) {
808 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims.";
809 }
810 MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
811 if (!keep_dims_iter->second->isa<BoolImm>()) {
812 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool.";
813 }
814 keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value();
815 return keepdims;
816 }
817
GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)818 Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
819 Dimensions dim_list;
820 bool keep_dims = GetKeepDims(ops, iter_ops);
821 if (keep_dims != false) {
822 return dim_list;
823 }
824 auto input_value = ops[iter_ops]->input_value();
825 auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
826 if (input_value.back()->isa<ValueTuple>()) {
827 auto attr_axis = GetValue<std::vector<int64_t>>(input_value.back());
828 if (attr_axis.empty()) {
829 for (size_t i = 0; i < input_dim; i++) {
830 dim_list.push_back(SizeToLong(i));
831 }
832 } else {
833 for (auto &axis : attr_axis) {
834 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
835 }
836 }
837 } else if (input_value.back()->isa<Int64Imm>()) {
838 int64_t axis = GetValue<int64_t>(input_value.back());
839 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
840 } else {
841 MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl;
842 }
843 return dim_list;
844 }
845
ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index,Dimensions s)846 Dimensions ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
847 const size_t incoming_op_index, Dimensions s) {
848 Dimensions s_Reduce;
849 Dimensions axis_list;
850 for (size_t i = 0; i < s.size(); i++) {
851 axis_list.push_back(SizeToLong(i));
852 }
853
854 auto dim_list = GetDimList(ops, incoming_op_index);
855 for (auto axis : dim_list) {
856 auto it = find(axis_list.begin(), axis_list.end(), axis);
857 if (it == axis_list.end()) {
858 MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
859 }
860 axis_list.erase(it);
861 }
862
863 for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
864 s_Reduce.push_back(s[LongToSize(axis_list[i])]);
865 }
866 return s_Reduce;
867 }
868
GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)869 Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
870 Dimensions dim_list;
871 auto iter = ops[iter_ops]->attrs().find(AXIS);
872 if (iter == ops[iter_ops]->attrs().end()) {
873 MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis.";
874 }
875 auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
876 MS_EXCEPTION_IF_NULL(iter->second);
877 if (iter->second->isa<ValueTuple>()) {
878 auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
879 if (attr_axis.empty()) {
880 for (size_t i = 0; i < input_dim; ++i) {
881 dim_list.push_back(SizeToLong(i));
882 }
883 } else {
884 for (auto &axis : attr_axis) {
885 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
886 }
887 }
888 } else if (iter->second->isa<Int64Imm>()) {
889 int64_t axis = GetValue<int64_t>(iter->second);
890 axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
891 } else {
892 MS_LOG(EXCEPTION) << "Axis type is invalid.";
893 }
894 return dim_list;
895 }
896
// Drop the reduced dimensions from strategy s when the incoming operator is
// ArgMaxWithValue/ArgMinWithValue without keep_dims, so the strategy matches
// the incoming operator's output rank.
Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                       const size_t incoming_op_index, Dimensions s) {
  bool keepdims = GetKeepDims(ops, incoming_op_index);
  if (keepdims) {
    // Output rank is unchanged, so the strategy can be used as-is.
    return s;
  }

  Dimensions s_Arg;
  Dimensions axis_list;
  // Start with all dimension indexes of s ...
  for (size_t i = 0; i < s.size(); i++) {
    axis_list.push_back(SizeToLong(i));
  }

  // ... then remove each dimension the Arg operator reduces away.
  auto dim_list = GetDimListFromAttrs(ops, incoming_op_index);
  for (auto axis : dim_list) {
    auto it = find(axis_list.begin(), axis_list.end(), axis);
    if (it == axis_list.end()) {
      MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
    }
    axis_list.erase(it);
  }

  // Keep only the strategy entries of the surviving dimensions.
  for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
    s_Arg.push_back(s[LongToSize(axis_list[i])]);
  }
  return s_Arg;
}
924
CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index)925 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
926 const size_t incoming_op_index) {
927 Dimensions s;
928 s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
929 if (s.size() != 0) {
930 if (ops[incoming_op_index]->type() == SQUEEZE) {
931 s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
932 }
933 if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
934 ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
935 s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
936 }
937 if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
938 s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s);
939 }
940 }
941 return s;
942 }
943
// Expand a single-tensor strategy (basic_stra) into per-input strategies for
// ops[iter_ops], dispatching to operator-specific preparation routines for
// the types that need special handling (BiasAdd, StridedSlice, GatherV2,
// L2Normalize, element-wise broadcast ops). All other operators fall through
// to the generic divisibility check.
Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                         Dimensions basic_stra) {
  Strategys stra;
  MS_EXCEPTION_IF_NULL(ops[iter_ops]);

  // Empty basic strategy: emit one empty strategy per input tensor.
  if (basic_stra.size() == 0) {
    for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
         iter_op_inputs++) {
      stra.push_back(basic_stra);
    }
    return stra;
  }

  auto s_ptr = std::make_shared<Dimensions>(basic_stra);
  if (ops[iter_ops]->type() == BIAS_ADD) {
    return PrepareBiasAdd(s_ptr);
  }
  if (ops[iter_ops]->type() == STRIDED_SLICE) {
    return PrepareStridedSlice(ops, iter_ops, basic_stra);
  }
  if (ops[iter_ops]->type() == GATHERV2) {
    // Distinguish GatherInfo from GatherPInfo by the operator-info name
    // prefix before "Info".
    auto pos = ops[iter_ops]->name().find("Info");
    auto name = ops[iter_ops]->name().substr(0, pos);
    if (name == "Gather") {
      return PrepareGatherV2(ops, iter_ops, basic_stra);
    } else if (name == "GatherP") {
      return PrepareGatherV2P(ops, iter_ops, basic_stra);
    } else {
      MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
    }
  }
  if (ops[iter_ops]->type() == L2_NORMALIZE) {
    return PrepareL2Normalize(ops, iter_ops, basic_stra);
  }
  // Element-wise binary ops may broadcast between their two inputs.
  if (ops[iter_ops]->type() == ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
      ops[iter_ops]->type() == DIV) {
    return CheckBroadcast(ops, iter_ops, basic_stra);
  }

  // Default: verify each input dimension is divisible by the strategy.
  return CheckDivisible(ops, iter_ops, basic_stra);
}
985
// Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
// Builds the two input strategies: the lower-rank tensor gets a broadcast
// strategy derived from the higher-rank tensor's strategy s; equal ranks fall
// back to the generic divisibility check.
Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                         const Dimensions s) {
  Strategys stra;

  size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
  size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
  size_t s_dim = s.size();
  // Do Broadcasting in the second tensor.
  if (second_tensor_dim < first_tensor_dim) {
    bool broadcast_first_tensor = false;
    // Push back the first tensor's strategy.
    if (s_dim == first_tensor_dim) {
      stra.push_back(s);
    } else {
      // Rank mismatch between s and the tensor: fall back to all-ones.
      Dimensions broadcast_revise_s(first_tensor_dim, 1);
      stra.push_back(broadcast_revise_s);
    }
    // Push back the second tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
  } else if (second_tensor_dim > first_tensor_dim) {  // Do Broadcasting in the first tensor.
    bool broadcast_first_tensor = true;
    // Push back the first tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
    // Push back the second tensor's strategy.
    if (s_dim == second_tensor_dim) {
      stra.push_back(s);
    } else {
      // Rank mismatch between s and the tensor: fall back to all-ones.
      Dimensions broadcast_revise_s(second_tensor_dim, 1);
      stra.push_back(broadcast_revise_s);
    }
  } else {  // Broadcasting can be ignored or No broadcasting needs to be applied.
    stra = CheckDivisible(ops, iter_ops, s);
  }

  return stra;
}
1023
// Build the strategy for the broadcast (lower-rank) tensor of a binary
// element-wise op. The "target" tensor is the one being broadcast; the
// "refer" tensor is the higher-rank one whose strategy s is the reference.
// When the broadcast dimension cannot be identified unambiguously, the
// conservative strategy 1 (no split) is used.
Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
                          size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor) {
  Dimensions s_empty = {};
  Dimensions s_broadcast;
  size_t target_tensor_index = 0;
  size_t refer_tensor_index = 0;
  size_t target_tensor_dim;
  size_t refer_tensor_dim;

  // Indexing target and refer tensor.
  if (broadcast_first_tensor) {
    target_tensor_index = 0;
    refer_tensor_index = 1;
    target_tensor_dim = first_tensor_dim;
    refer_tensor_dim = second_tensor_dim;
  } else {
    target_tensor_index = 1;
    refer_tensor_index = 0;
    target_tensor_dim = second_tensor_dim;
    refer_tensor_dim = first_tensor_dim;
  }

  // When target tensor with an empty dim.
  if (target_tensor_dim == 0) {
    return s_empty;
  } else if (target_tensor_dim == 1) {  // When target tensor with a single dim.
    bool broadcast_dim_found = false;
    for (size_t iter = 0; iter < refer_tensor_dim; iter++) {
      // Find and copy that dim's strategy from the refer tensor: the refer
      // dim must match the target's single dim in size, be splittable
      // (size > 1), and s must cover the refer tensor's full rank.
      if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] ==
           ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) &&
          (ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) &&
          (refer_tensor_dim == s.size())) {
        s_broadcast.push_back(s.at(iter));
        broadcast_dim_found = true;
        break;
      }
    }
    // Cannot decide which dim it is, push back one.
    if (broadcast_dim_found == false) {
      s_broadcast.push_back(1);
    }
  } else {
    // Cannot decide which dim needs to do broadcast, push back one(s).
    for (size_t iter = 0; iter < target_tensor_dim; iter++) {
      s_broadcast.push_back(1);
    }
  }

  return s_broadcast;
}
1075
1076 // Check whether the operator can be divided by the current strategy.
CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,const Dimensions basic_stra)1077 Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1078 const Dimensions basic_stra) {
1079 Dimensions s_empty = {};
1080 Strategys stra;
1081
1082 // For all the input tensors.
1083 for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
1084 iter_op_inputs++) {
1085 // If input tensor is empty, return strategy as void.
1086 if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
1087 stra.push_back(s_empty);
1088 continue;
1089 }
1090
1091 Dimensions tmp_stra = basic_stra;
1092 bool modified = false;
1093
1094 // Make sure each tensor's dim shape is greater than 1. If not, push back strategy as 1 instead.
1095 for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
1096 if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
1097 tmp_stra[j] = 1;
1098 modified = true;
1099 }
1100 }
1101 if (modified) {
1102 stra.push_back(tmp_stra);
1103 } else {
1104 stra.push_back(basic_stra);
1105 }
1106 }
1107
1108 return stra;
1109 }
1110
// Forward pass over the operators that still have no strategy: try to derive
// each one's strategy from its incoming operator, either via the recursive-
// partition graph node (when the op was partitioned) or by copying the
// incoming op's selected strategy. Operators that still cannot be resolved
// are collected and written back into no_stra_op_list for later passes.
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                               const std::vector<std::vector<std::string>> &input_tensor_names,
                                               const std::shared_ptr<std::vector<size_t>> &index_list,
                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }
  // Operators that remain unresolved after this pass.
  std::vector<size_t> no_stra_op_list_bis;

  // Iterate the pending list from back to front.
  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
    size_t iter_ops = no_stra_op_list->at(iter_list - 1);
    Strategys stra;
    Dimensions s;
    size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
    if (incoming_op_index != SIZE_MAX) {
      auto iter_graph = index_list->at(incoming_op_index);
      if (iter_graph != SIZE_MAX) {
        // The incoming op was partitioned: read strides from its graph node.
        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph, incoming_op_index);
      } else {
        // Otherwise copy the incoming op's already-selected strategy.
        s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index);
      }
    }

    if (s.size() == 0) {
      // No strategy could be derived; retry in a later pass.
      no_stra_op_list_bis.push_back(iter_ops);
    } else {
      stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    }

    // Note: even in the unresolved case an empty Strategy is assigned here.
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }

  // Replace the pending list with the still-unresolved operators.
  no_stra_op_list->clear();
  for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
    no_stra_op_list->push_back(no_stra_op_list_bis[i]);
  }
}
1150
ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)1151 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1152 Dimensions s) {
1153 Dimensions s_Squeeze;
1154 auto axis_list = GetAxisList(ops, SizeToLong(iter_ops));
1155 size_t s_index = 0;
1156 size_t axis_list_index = 0;
1157 for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) {
1158 if (i == (size_t)axis_list[axis_list_index]) {
1159 s_Squeeze.push_back(1);
1160 axis_list_index++;
1161 } else {
1162 s_Squeeze.push_back(s[s_index]);
1163 s_index++;
1164 }
1165 }
1166
1167 size_t cut = 1;
1168 for (size_t i = 0; i < s_Squeeze.size(); i++) {
1169 cut *= LongToSize(s_Squeeze[i]);
1170 }
1171 if (cut != g_device_manager->DeviceNum()) {
1172 s_Squeeze.clear();
1173 }
1174
1175 return s_Squeeze;
1176 }
1177
// Derive a strategy for ops[iter_ops] by finding a downstream (outgoing)
// operator that consumes its output and copying the strategy that operator
// selected for the corresponding input. Returns an empty strategy for op
// types whose output layout differs from their input layout.
Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                             const std::vector<std::vector<std::string>> &input_tensor_names,
                                             const size_t iter_ops) {
  Dimensions s;
  // These ops change rank/layout between input and output, so a consumer's
  // input strategy cannot be mapped back onto this op.
  if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
      ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
      ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE ||
      ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
    return s;
  }

  // Search for the first operator that takes this op's output as an input
  // (input_tensor_names[i][0] is operator i's own name; inputs start at 1)
  // and that already has a selected strategy.
  bool found = false;
  size_t outgoing_op_index = SIZE_MAX;
  size_t iter_op_inputs = SIZE_MAX;
  for (size_t i = 0; i < input_tensor_names.size(); i++) {
    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
          ops[i]->selected_strategy()->GetInputNumber() != 0) {
        outgoing_op_index = i;
        iter_op_inputs = j - 1;
        found = true;
        break;
      }
    }
    if (found) {
      break;
    }
  }

  // Copy the consumer's strategy for the matching input slot, one entry per
  // dimension of this op's output tensor.
  if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
    for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) {
      s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
    }
  }
  return s;
}
1214
// Backward pass over the operators that still have no strategy: derive each
// one's strategy from a downstream consumer's input strategy. Operators that
// still cannot be resolved are written back into no_stra_op_list.
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }
  // Operators that remain unresolved after this pass.
  std::vector<size_t> no_stra_op_list_bis;

  // Iterate the pending list from back to front.
  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
    auto iter_ops = no_stra_op_list->at(iter_list - 1);
    Strategys stra;
    Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
    // A Squeeze op needs the squeezed axes re-inserted before the strategy
    // can apply to its (higher-rank) input.
    if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
    }
    if (s.size() != 0) {
      stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    } else {
      // No strategy could be derived; retry in a later pass.
      no_stra_op_list_bis.push_back(iter_ops);
    }

    // Note: even in the unresolved case an empty Strategy is assigned here.
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }

  // Replace the pending list with the still-unresolved operators.
  no_stra_op_list->clear();
  for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
    no_stra_op_list->push_back(no_stra_op_list_bis[i]);
  }
}
1245
// Resolve the remaining strategy-less operators: alternate forward and
// backward propagation until a fixed point is reached (no progress in one
// round), then give every still-unresolved operator the trivial all-ones
// (no-split) strategy sized to its largest-rank input.
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                       const std::vector<std::vector<std::string>> &input_tensor_names,
                                       const std::shared_ptr<std::vector<size_t>> &index_list,
                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }

  // Repeat until one full round makes no progress (list stops shrinking).
  size_t no_stra_op_list_size = no_stra_op_list->size();
  do {
    no_stra_op_list_size = no_stra_op_list->size();
    GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
    GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
  } while (no_stra_op_list_size > no_stra_op_list->size());

  // Fall back to all-ones strategies for whatever is left.
  for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) {
    auto iter_ops = no_stra_op_list->at(iter_list);
    Strategys stra;
    Dimensions s;

    // Use the largest input rank so the strategy covers every dimension.
    size_t max_dim_num = 0;
    for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
      if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) {
        max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size();
      }
    }
    for (size_t i = 0; i < max_dim_num; i++) {
      s.push_back(1);
    }

    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
1282 } // namespace parallel
1283 } // namespace mindspore
1284