• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 
23 #include "ir/value.h"
24 #include "frontend/parallel/auto_parallel/rec_core/rec_parse_graph.h"
25 #include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
26 #include "frontend/parallel/ops_info/operator_info.h"
27 #include "frontend/parallel/strategy.h"
28 #include "frontend/parallel/step_parallel.h"
29 
30 namespace mindspore {
31 namespace parallel {
// Entry point of the recursive-programming strategy generation: assigns a
// parallel strategy to every operator.
//   graph       - rec graph produced by the partition phase (non-null).
//   ops         - all operators of the cost graph.
//   eli_list    - operators eliminated during graph simplification; element 0
//                 of each entry is the eliminated operator's index.
//   index_list  - maps each operator to its rec-graph node (SIZE_MAX if none).
//   is_training - false in predict/eval, which forces raw strategies on the
//                 last node and VirtualDataset.
void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                      const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
                      const std::vector<std::vector<std::string>> &input_tensor_names,
                      const std::shared_ptr<std::vector<size_t>> &index_list, bool is_training) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(eli_list);
  MS_EXCEPTION_IF_NULL(index_list);
  // First assign strategies to the operators that survived elimination.
  GeneratePartitionedOperatorStrategy(graph, ops, index_list);

  // Collect eliminated operators; they inherit strategies from neighbors.
  std::shared_ptr<std::vector<size_t>> no_stra_op_list(new std::vector<size_t>);
  for (size_t i = 0; i < eli_list->size(); i++) {
    no_stra_op_list->push_back(eli_list->at(i)[0]);
  }
  // Order matters: propagate forward first, then backward, then sweep up
  // whatever is still without a strategy.
  GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
  GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
  GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);

  for (auto &op : ops) {
    // Set user-defined strategy (overrides whatever was generated above).
    auto attrs = op->attrs();
    if (StrategyFound(attrs)) {
      StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[STRATEGY]);
      op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
    }
    // Set back to raw strategy for special node in predict/eval
    if (!is_training) {
      if ((op->is_last_node()) || (op->type() == VIRTUAL_DATA_SET)) {
        SetBackToRawStrategy(op);
      }
    }
  }
}
64 
PrepareMatMul(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)65 Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
66                         const size_t iter_graph, const size_t iter_ops) {
67   Strategys strategies;
68   auto attrs = ops[iter_ops]->attrs();
69   bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
70   bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
71 
72   // HCCL does not support multi-dimension partition, and the hardware does not support excessive
73   // number of EVENT, so we temporarily disable matmul's multi-dimension partition function.
74   const float max_cut = 1.0 / SizeToFloat(g_device_manager->DeviceNum());
75   // The rule of cut is 0.5, 0.125. To compare the result we have to use ">" so we multiply max_cut to 1.1
76   if (graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h > max_cut * 1.1 &&
77       graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w > max_cut * 1.1) {
78     graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = 1.0;
79     graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_w = 1.0;
80     graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_h = 1.0;
81     graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = 1.0;
82     graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
83     graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
84 
85     auto shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
86     if (transpose_a) {
87       shape_1 = ops[iter_ops]->inputs_tensor_info()[0].shape()[1];
88     }
89     auto shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[1];
90     if (transpose_b) {
91       shape_4 = ops[iter_ops]->inputs_tensor_info()[1].shape()[0];
92     }
93 
94     bool already_cut = false;
95     if (shape_1 >= shape_4) {
96       if (LongToSize(shape_1) % g_device_manager->DeviceNum() == 0) {
97         graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut;
98         graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut;
99         already_cut = true;
100       }
101       if (!already_cut && LongToSize(shape_4) % g_device_manager->DeviceNum() == 0) {
102         graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut;
103         graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut;
104         already_cut = true;
105       }
106     } else {
107       if (LongToSize(shape_4) % g_device_manager->DeviceNum() == 0) {
108         graph->nodes[iter_graph].apply.arguments[1].tensor_str.str_w = max_cut;
109         graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = max_cut;
110         already_cut = true;
111       }
112       if (!already_cut && LongToSize(shape_1) % g_device_manager->DeviceNum() == 0) {
113         graph->nodes[iter_graph].apply.arguments[0].tensor_str.str_h = max_cut;
114         graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = max_cut;
115         already_cut = true;
116       }
117     }
118 
119     if (!already_cut) {
120       MS_LOG(EXCEPTION) << "Failure: MatMul's shape is invalid.";
121     }
122   }
123 
124   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
125     Dimensions s;
126     if (transpose_a && (iter_op_inputs == 0)) {
127       s.push_back(
128         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
129       s.push_back(
130         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
131     } else if (transpose_b && (iter_op_inputs == 1)) {
132       s.push_back(
133         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
134       s.push_back(
135         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
136     } else {
137       s.push_back(
138         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
139       s.push_back(
140         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
141     }
142     strategies.push_back(s);
143   }
144   return strategies;
145 }
146 
PrepareBiasAdd(const std::shared_ptr<Dimensions> & s)147 Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s) {
148   Strategys strategies;
149   strategies.push_back(*s);
150   Dimensions s_biasadd;
151   s_biasadd.push_back(s->at(1));
152   strategies.push_back(s_biasadd);
153   return strategies;
154 }
155 
PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions basic_stra)156 Strategys PrepareStridedSlice(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
157                               Dimensions basic_stra) {
158   Strategys stra;
159 
160   auto begin = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(1));
161   auto end = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(2));
162   auto strides = GetValue<std::vector<int64_t>>(ops[iter_ops]->input_value().at(3));
163 
164   for (size_t i = 0; i < strides.size(); ++i) {
165     if ((strides[i] != 1) && (basic_stra[i] > 1)) {
166       basic_stra[i] = 1;
167     }
168   }
169 
170   for (size_t i = 0; i < begin.size(); ++i) {
171     bool no_fully_fetch = ((begin[i] != 0) || (end[i] < ops[iter_ops]->inputs_tensor_info()[0].shape()[i]));
172     if (no_fully_fetch && (basic_stra[i] != 1)) {
173       basic_stra[i] = 1;
174     }
175   }
176 
177   stra.push_back(basic_stra);
178   return stra;
179 }
180 
PrepareOneHot(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)181 Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
182                         const size_t iter_graph, const size_t iter_ops) {
183   Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
184 
185   int64_t axis = -1;
186   auto iter = ops[iter_ops]->attrs().find(AXIS);
187   if (iter != ops[iter_ops]->attrs().end()) {
188     MS_EXCEPTION_IF_NULL(iter->second);
189     if (iter->second->isa<Int64Imm>()) {
190       axis = iter->second->cast<Int64ImmPtr>()->value();
191     } else {
192       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
193     }
194   }
195   if (axis == -1) {
196     strategies[0][0] = strategies[0][1];
197     strategies[0][1] = 1;
198     graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
199     graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
200   }
201 
202   Dimensions s_empty = {};
203   strategies.push_back(s_empty);
204   strategies.push_back(s_empty);
205   return strategies;
206 }
207 
PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)208 Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
209   Strategys strategies;
210 
211   auto axis_input = GetValue<int64_t>(ops[iter_ops]->input_value().at(2));
212   if (axis_input < 0) {
213     axis_input += SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
214   }
215   int64_t axis = axis_input;
216   if (axis >= SizeToLong(s.size())) {
217     MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
218   }
219   s[LongToSize(axis)] = 1;
220   strategies.push_back(s);
221 
222   return strategies;
223 }
224 
PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)225 Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s) {
226   Strategys strategies;
227 
228   auto output_shape = ops[iter_ops]->outputs_tensor_info()[0].shape();
229   Dimensions index(output_shape.size() - 1, 0);
230   for (size_t i = 0; i < index.size(); i++) {
231     index[i] = SizeToLong(i);
232   }
233   std::sort(index.begin(), index.end(), [&output_shape](const int64_t &a, const int64_t &b) {
234     return (output_shape[LongToSize(a + 1)] > output_shape[LongToSize(b + 1)]);
235   });
236   std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
237   index.insert(index.begin(), 0);
238 
239   Dimensions strategie(output_shape.size(), 1);
240   size_t num_device = g_device_manager->DeviceNum();
241   size_t cut = 1;
242   for (size_t i = 0; i < index.size(); i++) {
243     size_t index_i = LongToSize(index[i]);
244     while (output_shape[index_i] % 2 == 0 && output_shape[index_i] > 0 && cut < num_device) {
245       output_shape[index_i] /= 2;
246       cut *= 2;
247       strategie[index_i] *= 2;
248     }
249     if (cut == num_device) {
250       break;
251     }
252   }
253 
254   auto axis_input = GetValue<int64_t>(ops[iter_ops]->input_value().at(2));
255   if (axis_input < 0) {
256     axis_input += SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
257   }
258   int64_t axis = axis_input;
259   if (axis >= SizeToLong(s.size())) {
260     MS_LOG(EXCEPTION) << "Failure: GatherV2' axis out of range.";
261   }
262   if (axis == 0) {
263     s.clear();
264     s.push_back(1);
265     for (size_t i = 1; i < ops[iter_ops]->inputs_tensor_info()[0].shape().size(); i++) {
266       s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[1].shape().size() - 1 + i]);
267     }
268     strategies.push_back(s);
269     s.clear();
270     for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
271       s.push_back(strategie[i]);
272     }
273     strategies.push_back(s);
274   } else if (axis == 1) {
275     s.clear();
276     s.push_back(strategie[0]);
277     s.push_back(1);
278     strategies.push_back(s);
279     s.clear();
280     for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
281       s.push_back(strategie[ops[iter_ops]->inputs_tensor_info()[0].shape().size() - 1 + i]);
282     }
283     strategies.push_back(s);
284   } else {
285     MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is neither 0 nor 1.";
286   }
287 
288   return strategies;
289 }
290 
PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index)291 Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
292                                           const size_t incoming_op_index) {
293   auto output_shape = ops[incoming_op_index]->outputs_tensor_info()[0].shape();
294   Dimensions index(output_shape.size() - 1, 0);
295   for (size_t i = 0; i < index.size(); i++) {
296     index[i] = SizeToLong(i);
297   }
298   std::sort(index.begin(), index.end(),
299             [&output_shape](const size_t &a, const size_t &b) { return (output_shape[a + 1] > output_shape[b + 1]); });
300   std::transform(std::begin(index), std::end(index), std::begin(index), [](int64_t x) { return x + 1; });
301   index.insert(index.begin(), 0);
302 
303   Dimensions strategie(output_shape.size(), 1);
304   size_t num_device = g_device_manager->DeviceNum();
305   size_t cut = 1;
306   for (size_t i = 0; i < index.size(); i++) {
307     size_t index_i = LongToSize(index[i]);
308     while (output_shape[index_i] % 2 == 0 && output_shape[index_i] > 0 && cut < num_device) {
309       output_shape[index_i] /= 2;
310       cut *= 2;
311       strategie[index_i] *= 2;
312     }
313     if (cut == num_device) {
314       break;
315     }
316   }
317 
318   return strategie;
319 }
320 
PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)321 Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
322                              Dimensions s) {
323   int64_t axis = 0;
324   auto iter = ops[iter_ops]->attrs().find(AXIS);
325   if (iter != ops[iter_ops]->attrs().end()) {
326     MS_EXCEPTION_IF_NULL(iter->second);
327     if (iter->second->isa<ValueSequeue>()) {
328       axis = GetValue<std::vector<int64_t>>(iter->second)[0];
329     } else {
330       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " : The value of axis is not int64_t.";
331     }
332   }
333 
334   int64_t axis_index = axis;
335   if (axis < 0) {
336     size_t input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
337     axis_index = static_cast<int64_t>(input_dim) + axis;
338   }
339 
340   s[LongToSize(axis_index)] = 1;
341 
342   Strategys strategies;
343   strategies.push_back(s);
344   return strategies;
345 }
346 
// Builds the strategy for axis-sensitive operators (Softmax, LayerNorm):
// starts from the rec-search strategy, then forces every reduced/normalized
// axis to stay un-split.
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                     const size_t iter_ops) {
  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
  if (strategies.size() < 1) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
  }

  // LayerNorm stores its axis under "begin_norm_axis" (default 1); other ops
  // use the common AXIS attribute (default -1, i.e. the last axis).
  std::vector<int64_t> axis_list;
  string axis_name = AXIS;
  int64_t default_axis = -1;
  if (ops[iter_ops]->type() == LAYER_NORM) {
    axis_name = "begin_norm_axis";
    default_axis = 1;
  }

  // The attribute may hold a single int64 or a tuple of int64.
  auto iter = ops[iter_ops]->attrs().find(axis_name);
  if (iter != ops[iter_ops]->attrs().end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {
      axis_list.push_back(iter->second->cast<Int64ImmPtr>()->value());
    } else if (iter->second->isa<ValueTuple>()) {
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value_tuple is nullptr.";
      }
      std::vector<ValuePtr> value_vector = value_tuple->value();
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_list),
                           [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t or tuple int64_t.";
    }
  } else {
    axis_list.push_back(default_axis);
  }

  for (auto &axis : axis_list) {
    // Normalize negative axes, validate, and pin the axis to 1 (no split).
    if (axis < 0) {
      int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
      axis = input_dim + axis;
    }
    if (axis >= SizeToLong(strategies[0].size()) || axis < 0) {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
    }
    if (strategies[0][LongToSize(axis)] != 1) {
      strategies[0][LongToSize(axis)] = 1;
      MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
    }
  }
  return strategies;
}
398 
MakeRecSearchStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)399 Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
400                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
401                                 const size_t iter_ops) {
402   if (ops.empty()) {
403     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
404   }
405   if (iter_ops >= ops.size()) {
406     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
407   }
408   if (graph->nodes[iter_graph].apply.op_type == kRecUnsortedSegmentOp) {
409     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
410   }
411 
412   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
413   Strategys strategies;
414   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
415     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
416       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
417     }
418 
419     size_t output_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
420     Dimensions s;
421     if (output_size == 4) {
422       s.push_back(
423         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_n));
424       s.push_back(
425         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
426       s.push_back(
427         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
428       s.push_back(
429         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
430     } else if (output_size == 3) {
431       // Experimental support for 3D data.
432       s.push_back(
433         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_c));
434       s.push_back(
435         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
436       s.push_back(
437         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
438     } else if (output_size == 2) {
439       s.push_back(
440         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_h));
441       s.push_back(
442         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
443     } else if (output_size == 1) {
444       s.push_back(
445         static_cast<int64_t>(1.0 / graph->nodes[iter_graph].apply.arguments[iter_op_inputs].tensor_str.str_w));
446     } else if (output_size == 0) {
447       s = {};
448     } else {
449       MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's output size is unexcepted.";
450     }
451     strategies.push_back(s);
452   }
453   return strategies;
454 }
455 
MakeDataParallelStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)456 Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
457                                    const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
458                                    const size_t iter_ops) {
459   if (ops.empty()) {
460     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
461   }
462   if (iter_ops >= ops.size()) {
463     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
464   }
465 
466   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
467   Strategys strategies;
468   size_t max_device_num = g_device_manager->DeviceNum();
469   size_t target_tensor_batch = ops[iter_ops]->inputs_tensor_info()[0].shape()[0];
470   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
471     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
472       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
473     }
474 
475     Dimensions s;
476     size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
477     for (size_t dim = 0; dim < input_size; dim++) {
478       // Experimental support for 3D data (input_size == 3).
479       if (input_size >= 1 && input_size <= 4) {
480         if (dim == 0) {
481           s.push_back(std::min(max_device_num, target_tensor_batch));
482         } else {
483           s.push_back(1);
484         }
485       } else if (input_size == 0) {
486         s = {};
487       } else {
488         MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
489       }
490     }
491     strategies.push_back(s);
492   }
493   // Set default strategy.
494   graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
495   graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
496   graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
497   graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
498 
499   // Update data parallel strategy.
500   if (ops[iter_ops]->outputs_tensor_info().size() == 0) {
501     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << " output tensor info is empty.";
502   }
503   if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
504     graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0 / std::min(max_device_num, target_tensor_batch);
505   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
506     graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0 / std::min(max_device_num, target_tensor_batch);
507   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 3) {
508     // Experimental support for 3D data.
509     graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0 / std::min(max_device_num, target_tensor_batch);
510   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) {
511     graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0 / std::min(max_device_num, target_tensor_batch);
512   } else {
513     MS_LOG(INFO) << ops[iter_ops]->name() << " output tensor shape is unexpected, using default value instead.";
514   }
515 
516   return strategies;
517 }
518 
MakeFullBatchStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)519 Strategys MakeFullBatchStrategy(const std::shared_ptr<Graph> &graph,
520                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
521                                 const size_t iter_ops) {
522   if (ops.empty()) {
523     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
524   }
525   if (iter_ops >= ops.size()) {
526     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
527   }
528 
529   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
530   Strategys strategies;
531   for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
532     if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
533       MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
534     }
535     Dimensions s;
536     size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
537     for (size_t dim = 0; dim < input_size; dim++) {
538       if (input_size >= 1 && input_size <= 4) {
539         s.push_back(1);
540       } else if (input_size == 0) {
541         s = {};
542       } else {
543         MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
544       }
545     }
546     strategies.push_back(s);
547   }
548   // Update the output strategy of Rec Graph
549   graph->nodes[iter_graph].tensor_parm.tensor_str.str_n = 1.0;
550   graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = 1.0;
551   graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = 1.0;
552   graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = 1.0;
553 
554   return strategies;
555 }
556 
SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> & op)557 void SetBackToRawStrategy(const std::shared_ptr<OperatorInfo> &op) {
558   StrategyPtr origin_strategy = op->strategy();
559   Strategys strategies;
560 
561   for (size_t iter_strategy = 0; iter_strategy < origin_strategy->GetInputDim().size(); iter_strategy++) {
562     Dimensions s;
563     size_t strategy_size = origin_strategy->GetInputDim()[iter_strategy].size();
564     for (size_t dim = 0; dim < strategy_size; dim++) {
565       if (strategy_size >= 1 && strategy_size <= 4) {
566         s.push_back(1);
567       } else if (strategy_size == 0) {
568         s = {};
569       } else {
570         MS_LOG(EXCEPTION) << op->name() << ": Strategy size " << strategy_size << " is unmatched.";
571       }
572     }
573     strategies.push_back(s);
574   }
575 
576   StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
577   op->SetSelectedStrategyAndCost(sp, op->selected_cost());
578 }
579 
PrepareStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_graph,const size_t iter_ops)580 Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
581                           const size_t iter_graph, const size_t iter_ops) {
582   if (ops.empty()) {
583     MS_LOG(EXCEPTION) << "Failure: Operators is empty.";
584   }
585   if (iter_ops >= ops.size()) {
586     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
587   }
588   MS_EXCEPTION_IF_NULL(ops[iter_ops]);
589 
590   auto type = ops[iter_ops]->type();
591   if (type == MATMUL) {
592     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
593   } else if (type == ONEHOT) {
594     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
595   } else if ((type == SOFTMAX) || (type == LAYER_NORM)) {
596     return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
597   } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "Dropout") || (type == BATCH_MATMUL)) {
598     return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
599   } else if (type == "_VirtualDataset") {
600     if (ParallelContext::GetInstance()->full_batch()) {
601       return MakeFullBatchStrategy(graph, ops, iter_graph, iter_ops);
602     } else {
603       return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
604     }
605   } else {
606     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
607   }
608 }
609 
GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const std::shared_ptr<std::vector<size_t>> & index_list)610 void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
611                                          const std::vector<std::shared_ptr<OperatorInfo>> &ops,
612                                          const std::shared_ptr<std::vector<size_t>> &index_list) {
613   for (size_t iter_ops = 0; iter_ops < (size_t)index_list->size(); iter_ops++) {
614     Strategys strategies;
615     size_t iter_graph = index_list->at(iter_ops);
616     if (iter_graph != SIZE_MAX && ops[iter_ops]->type() != GET_NEXT) {
617       strategies = PrepareStrategy(graph, ops, iter_graph, iter_ops);
618     }
619     StrategyPtr sp = std::make_shared<Strategy>(0, strategies);
620     ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
621   }
622 }
623 
// Finds the first operator whose output feeds this operator's inputs.
// input_tensor_names[k][0] is operator k's own output name; the remaining
// entries are its input tensor names. Returns SIZE_MAX when no producer of
// any input is found in the list.
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const size_t iter_ops) {
  const auto &consumer = input_tensor_names[iter_ops];
  for (size_t input_idx = 1; input_idx < consumer.size(); ++input_idx) {
    for (size_t producer = 0; producer < input_tensor_names.size(); ++producer) {
      if (consumer[input_idx] == input_tensor_names[producer][0]) {
        return producer;
      }
    }
  }
  return SIZE_MAX;
}
640 
CheckVirtualDatasetStrategy(const std::shared_ptr<Graph> & graph,const size_t iter_graph)641 float CheckVirtualDatasetStrategy(const std::shared_ptr<Graph> &graph, const size_t iter_graph) {
642   // The values for str can only be 1.0, 0.5, 0.25, 0.125…
643   // We want to find out the first str that is smaller than 1
644   if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_n < 0.9) {
645     return graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
646   }
647   if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_c < 0.9) {
648     return graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
649   }
650   if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_h < 0.9) {
651     return graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
652   }
653   if (graph->nodes[iter_graph].tensor_parm.tensor_str.str_w < 0.9) {
654     return graph->nodes[iter_graph].tensor_parm.tensor_str.str_w;
655   }
656   return 1.0;
657 }
658 
CopyVirtualDataset(const std::shared_ptr<Graph> & graph,const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,const size_t iter_graph)659 Dimensions CopyVirtualDataset(const std::shared_ptr<Graph> &graph,
660                               const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
661                               const size_t iter_graph) {
662   Dimensions s;
663   auto input_stra_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
664   auto virtual_dataset_str = CheckVirtualDatasetStrategy(graph, iter_graph);
665   if (input_stra_dim == 0) {
666     return s;
667   } else {
668     if (virtual_dataset_str == 0) {
669       s.push_back(1);
670     } else {
671       s.push_back(FloatToLong(1 / virtual_dataset_str));
672     }
673     for (size_t i = 1; i < input_stra_dim; i++) {
674       s.push_back(1);
675     }
676   }
677   return s;
678 }
679 
// Derives a strategy for operator iter_ops from the tensor layout recorded on
// graph node iter_graph (the incoming operator's output). VirtualDataset is
// handled separately; otherwise the rank (1..4) of the first non-scalar input
// selects which of the node's n/c/h/w stride fields are turned into cuts.
Dimensions CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
                                              const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                              const size_t iter_ops, const size_t iter_graph,
                                              const size_t incoming_op_index) {
  Dimensions s;

  if (ops[incoming_op_index]->type() == VIRTUAL_DATA_SET) {
    s = CopyVirtualDataset(graph, ops, iter_ops, iter_graph);
    return s;
  }

  for (auto input : ops[iter_ops]->inputs_tensor_info()) {
    auto input_stra_dim = input.shape().size();
    if (input_stra_dim == 0) {
      // Scalar inputs carry no strategy dimension; look at the next input.
      continue;
    }
    // Strides are stored as fractions (1.0, 0.5, 0.25, ...), so the cut count
    // per dimension is the inverse of the stride. Lower-rank tensors use the
    // trailing (w, then h, then c, then n) stride fields.
    if (input_stra_dim == 1) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 2) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 3) {
      // Experimental support for 3D data.
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else if (input_stra_dim == 4) {
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_n));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_c));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_h));
      s.push_back(FloatToLong(1 / graph->nodes[iter_graph].tensor_parm.tensor_str.str_w));
    } else {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
    }
    // Only the first non-scalar input determines the strategy.
    break;
  }
  return s;
}
718 
// Copies the selected input strategy of the incoming operator so it can be
// propagated to the current (strategy-less) operator. Returns an empty
// strategy for operators whose input strategy cannot be reused directly
// (Reshape/Transpose, plain Gather, or operators with no strategy yet).
Dimensions PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const size_t incoming_op_index) {
  Dimensions s;
  if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == TRANSPOSE) {
    return s;
  }
  if (ops[incoming_op_index]->type() == GATHERV2) {
    // The operator instance name encodes the concrete kind as "<Kind>Info...".
    auto pos = ops[incoming_op_index]->name().find("Info");
    if (pos == std::string::npos) {
      return s;
    }
    auto name = ops[incoming_op_index]->name().substr(0, pos);
    if (name == "Gather") {
      return s;
    } else if (name == "GatherP") {
      return PrepareGatherV2POutputStrategy(ops, incoming_op_index);
    } else {
      MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
    }
  }
  auto strategy = ops[incoming_op_index]->selected_strategy();
  if (strategy->GetInputNumber() == 0) {
    // The incoming operator has no strategy assigned yet; nothing to copy.
    return s;
  }

  // Copy the strategy of the first non-scalar input tensor.
  for (size_t i = 0; i < (size_t)ops[incoming_op_index]->inputs_tensor_info().size(); i++) {
    if (ops[incoming_op_index]->inputs_tensor_info()[i].shape().size() == 0) {
      continue;
    }
    for (size_t j = 0; j < ops[incoming_op_index]->inputs_tensor_info()[i].shape().size(); ++j) {
      s.push_back(strategy->GetInputDim()[i][j]);
    }
    break;
  }
  return s;
}
755 
GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const int64_t iter_ops)756 Dimensions GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int64_t iter_ops) {
757   Dimensions axis_list;
758   auto axis_param = ops[LongToSize(iter_ops)]->attrs().find(AXIS)->second;
759   std::vector<ValuePtr> elements;
760   if (axis_param->isa<ValueTuple>()) {
761     elements = axis_param->cast<ValueTuplePtr>()->value();
762   } else if (axis_param->isa<ValueList>()) {
763     elements = axis_param->cast<ValueListPtr>()->value();
764   } else {
765     MS_LOG(EXCEPTION) << "Failure: Axis type is invalid, neither tuple nor list." << std::endl;
766   }
767 
768   for (auto &element : elements) {
769     if (!element->isa<Int64Imm>()) {
770       MS_LOG(EXCEPTION) << "Failure: Dimension indexes is not Int32." << std::endl;
771     }
772     auto axis = element->cast<Int64ImmPtr>()->value();
773     axis_list.push_back(axis);
774   }
775   return axis_list;
776 }
777 
// Shrinks strategy s propagated through an incoming Squeeze operator by
// dropping the entries of the squeezed axes (which must have extent 1), so
// the result matches the rank of the Squeeze output.
Dimensions ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                           const size_t incoming_op_index, Dimensions s) {
  Dimensions s_Squeeze;
  Dimensions stra_dim_list;
  // Start from all dimension indexes of s, then erase the squeezed ones.
  for (size_t i = 0; i < s.size(); i++) {
    stra_dim_list.push_back(SizeToLong(i));
  }

  auto axis_list = GetAxisList(ops, SizeToLong(incoming_op_index));
  for (auto axis : axis_list) {
    auto it = find(stra_dim_list.begin(), stra_dim_list.end(), axis);
    if (it == stra_dim_list.end()) {
      MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
    }
    // A squeezed dimension must have extent 1, otherwise the graph is invalid.
    if (ops[incoming_op_index]->inputs_tensor_info()[0].shape()[LongToSize(axis)] != 1) {
      MS_LOG(EXCEPTION) << "Failure: Removed dimension's shape is not 1." << std::endl;
    }
    stra_dim_list.erase(it);
  }

  // Gather the strategy entries of the surviving dimensions, in order.
  for (size_t i = 0; i < (size_t)stra_dim_list.size(); i++) {
    s_Squeeze.push_back(s[LongToSize(stra_dim_list[i])]);
  }
  return s_Squeeze;
}
803 
GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)804 bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
805   bool keepdims = false;
806   auto keep_dims_iter = ops[iter_ops]->attrs().find(KEEP_DIMS);
807   if (keep_dims_iter == ops[iter_ops]->attrs().end()) {
808     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr keep_dims.";
809   }
810   MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
811   if (!keep_dims_iter->second->isa<BoolImm>()) {
812     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Keep_dims is not a bool.";
813   }
814   keepdims = keep_dims_iter->second->cast<BoolImmPtr>()->value();
815   return keepdims;
816 }
817 
GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)818 Dimensions GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
819   Dimensions dim_list;
820   bool keep_dims = GetKeepDims(ops, iter_ops);
821   if (keep_dims != false) {
822     return dim_list;
823   }
824   auto input_value = ops[iter_ops]->input_value();
825   auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
826   if (input_value.back()->isa<ValueTuple>()) {
827     auto attr_axis = GetValue<std::vector<int64_t>>(input_value.back());
828     if (attr_axis.empty()) {
829       for (size_t i = 0; i < input_dim; i++) {
830         dim_list.push_back(SizeToLong(i));
831       }
832     } else {
833       for (auto &axis : attr_axis) {
834         axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
835       }
836     }
837   } else if (input_value.back()->isa<Int64Imm>()) {
838     int64_t axis = GetValue<int64_t>(input_value.back());
839     axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
840   } else {
841     MS_LOG(EXCEPTION) << "Failure: Axis type is invalid." << std::endl;
842   }
843   return dim_list;
844 }
845 
ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index,Dimensions s)846 Dimensions ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
847                                           const size_t incoming_op_index, Dimensions s) {
848   Dimensions s_Reduce;
849   Dimensions axis_list;
850   for (size_t i = 0; i < s.size(); i++) {
851     axis_list.push_back(SizeToLong(i));
852   }
853 
854   auto dim_list = GetDimList(ops, incoming_op_index);
855   for (auto axis : dim_list) {
856     auto it = find(axis_list.begin(), axis_list.end(), axis);
857     if (it == axis_list.end()) {
858       MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
859     }
860     axis_list.erase(it);
861   }
862 
863   for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
864     s_Reduce.push_back(s[LongToSize(axis_list[i])]);
865   }
866   return s_Reduce;
867 }
868 
GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops)869 Dimensions GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops) {
870   Dimensions dim_list;
871   auto iter = ops[iter_ops]->attrs().find(AXIS);
872   if (iter == ops[iter_ops]->attrs().end()) {
873     MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Don't have attr axis.";
874   }
875   auto input_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
876   MS_EXCEPTION_IF_NULL(iter->second);
877   if (iter->second->isa<ValueTuple>()) {
878     auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
879     if (attr_axis.empty()) {
880       for (size_t i = 0; i < input_dim; ++i) {
881         dim_list.push_back(SizeToLong(i));
882       }
883     } else {
884       for (auto &axis : attr_axis) {
885         axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
886       }
887     }
888   } else if (iter->second->isa<Int64Imm>()) {
889     int64_t axis = GetValue<int64_t>(iter->second);
890     axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
891   } else {
892     MS_LOG(EXCEPTION) << "Axis type is invalid.";
893   }
894   return dim_list;
895 }
896 
ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index,Dimensions s)897 Dimensions ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
898                                        const size_t incoming_op_index, Dimensions s) {
899   bool keepdims = GetKeepDims(ops, incoming_op_index);
900   if (keepdims) {
901     return s;
902   }
903 
904   Dimensions s_Arg;
905   Dimensions axis_list;
906   for (size_t i = 0; i < s.size(); i++) {
907     axis_list.push_back(SizeToLong(i));
908   }
909 
910   auto dim_list = GetDimListFromAttrs(ops, incoming_op_index);
911   for (auto axis : dim_list) {
912     auto it = find(axis_list.begin(), axis_list.end(), axis);
913     if (it == axis_list.end()) {
914       MS_LOG(EXCEPTION) << "Failure: Can not find dimension indexes in Axis." << std::endl;
915     }
916     axis_list.erase(it);
917   }
918 
919   for (size_t i = 0; i < (size_t)axis_list.size(); i++) {
920     s_Arg.push_back(s[LongToSize(axis_list[i])]);
921   }
922   return s_Arg;
923 }
924 
CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t incoming_op_index)925 Dimensions CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
926                                              const size_t incoming_op_index) {
927   Dimensions s;
928   s = PrepareIncomingOperatorInputStrategy(ops, incoming_op_index);
929   if (s.size() != 0) {
930     if (ops[incoming_op_index]->type() == SQUEEZE) {
931       s = ModifyStrategyIfSqueezeIncoming(ops, incoming_op_index, s);
932     }
933     if (ops[incoming_op_index]->type() == REDUCE_SUM || ops[incoming_op_index]->type() == REDUCE_MAX ||
934         ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
935       s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
936     }
937     if (ops[incoming_op_index]->type() == ARGMAXWITHVALUE || ops[incoming_op_index]->type() == ARGMINWITHVALUE) {
938       s = ModifyStrategyIfArgIncoming(ops, incoming_op_index, s);
939     }
940   }
941   return s;
942 }
943 
// Expands the single-tensor strategy basic_stra into a full per-input
// strategy for the operator at iter_ops. Operators needing special handling
// (BiasAdd, StridedSlice, Gather/GatherP, L2Normalize and the broadcasting
// element-wise ops) are dispatched to dedicated Prepare*/Check* helpers;
// everything else goes through the generic divisibility check.
Strategys GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                         Dimensions basic_stra) {
  Strategys stra;
  MS_EXCEPTION_IF_NULL(ops[iter_ops]);

  // With no basic strategy to propagate, emit one empty strategy per input.
  if (basic_stra.size() == 0) {
    for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
         iter_op_inputs++) {
      stra.push_back(basic_stra);
    }
    return stra;
  }

  auto s_ptr = std::make_shared<Dimensions>(basic_stra);
  if (ops[iter_ops]->type() == BIAS_ADD) {
    return PrepareBiasAdd(s_ptr);
  }
  if (ops[iter_ops]->type() == STRIDED_SLICE) {
    return PrepareStridedSlice(ops, iter_ops, basic_stra);
  }
  if (ops[iter_ops]->type() == GATHERV2) {
    // Distinguish Gather from GatherP via the instance name ("<Kind>Info...").
    auto pos = ops[iter_ops]->name().find("Info");
    auto name = ops[iter_ops]->name().substr(0, pos);
    if (name == "Gather") {
      return PrepareGatherV2(ops, iter_ops, basic_stra);
    } else if (name == "GatherP") {
      return PrepareGatherV2P(ops, iter_ops, basic_stra);
    } else {
      MS_LOG(EXCEPTION) << "Failure: Unknown type of GatherV2." << std::endl;
    }
  }
  if (ops[iter_ops]->type() == L2_NORMALIZE) {
    return PrepareL2Normalize(ops, iter_ops, basic_stra);
  }
  // Element-wise ops where the two inputs may broadcast against each other.
  if (ops[iter_ops]->type() == ADD || ops[iter_ops]->type() == SUB || ops[iter_ops]->type() == MUL ||
      ops[iter_ops]->type() == DIV) {
    return CheckBroadcast(ops, iter_ops, basic_stra);
  }

  return CheckDivisible(ops, iter_ops, basic_stra);
}
985 
// Function to deal with ops with broadcasting, like TensorAdd/Sub/Mul/Div etc.
// Builds the two-input strategy from the one-tensor strategy s: the
// lower-rank input gets a broadcast-adjusted strategy via ApplyBroadcast,
// the other keeps s (or an all-ones strategy if s has the wrong rank). Equal
// ranks fall through to the generic divisibility check.
// NOTE(review): assumes the operator has at least two inputs_tensor_info()
// entries — confirm at call sites.
Strategys CheckBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                         const Dimensions s) {
  Strategys stra;

  size_t first_tensor_dim = ops[iter_ops]->inputs_tensor_info()[0].shape().size();
  size_t second_tensor_dim = ops[iter_ops]->inputs_tensor_info()[1].shape().size();
  size_t s_dim = s.size();
  // Do Broadcasting in the second tensor.
  if (second_tensor_dim < first_tensor_dim) {
    bool broadcast_first_tensor = false;
    // Push back the first tensor's strategy.
    if (s_dim == first_tensor_dim) {
      stra.push_back(s);
    } else {
      // s does not match this tensor's rank; fall back to an all-ones strategy.
      Dimensions broadcast_revise_s(first_tensor_dim, 1);
      stra.push_back(broadcast_revise_s);
    }
    // Push back the second tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
  } else if (second_tensor_dim > first_tensor_dim) {  // Do Broadcasting in the first tensor.
    bool broadcast_first_tensor = true;
    // Push back the first tensor's strategy after applying broadcast.
    stra.push_back(ApplyBroadcast(ops, iter_ops, s, first_tensor_dim, second_tensor_dim, broadcast_first_tensor));
    // Push back the second tensor's strategy.
    if (s_dim == second_tensor_dim) {
      stra.push_back(s);
    } else {
      Dimensions broadcast_revise_s(second_tensor_dim, 1);
      stra.push_back(broadcast_revise_s);
    }
  } else {  // Broadcasting can be ignored or No broadcasting needs to be applied.
    stra = CheckDivisible(ops, iter_ops, s);
  }

  return stra;
}
1023 
// Builds the strategy of the broadcast (lower-rank) "target" tensor, using
// the higher-rank "refer" tensor as reference. A rank-0 target gets an empty
// strategy; a rank-1 target copies the refer dimension's cut when exactly one
// refer dimension can be matched by shape, otherwise 1; higher ranks get
// all-ones since the broadcast dimension cannot be determined here.
Dimensions ApplyBroadcast(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s,
                          size_t first_tensor_dim, size_t second_tensor_dim, bool broadcast_first_tensor) {
  Dimensions s_empty = {};
  Dimensions s_broadcast;
  size_t target_tensor_index = 0;
  size_t refer_tensor_index = 0;
  size_t target_tensor_dim;
  size_t refer_tensor_dim;

  // Indexing target and refer tensor.
  if (broadcast_first_tensor) {
    target_tensor_index = 0;
    refer_tensor_index = 1;
    target_tensor_dim = first_tensor_dim;
    refer_tensor_dim = second_tensor_dim;
  } else {
    target_tensor_index = 1;
    refer_tensor_index = 0;
    target_tensor_dim = second_tensor_dim;
    refer_tensor_dim = first_tensor_dim;
  }

  // When target tensor with an empty dim.
  if (target_tensor_dim == 0) {
    return s_empty;
  } else if (target_tensor_dim == 1) {  // When target tensor with a single dim.
    bool broadcast_dim_found = false;
    for (size_t iter = 0; iter < refer_tensor_dim; iter++) {
      // Find and copy that dim's strategy from the refer tensor. The match
      // requires equal shapes, a non-broadcast (>1) extent, and that s
      // actually covers the refer tensor's rank.
      if ((ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] ==
           ops[iter_ops]->inputs_tensor_info()[target_tensor_index].shape()[0]) &&
          (ops[iter_ops]->inputs_tensor_info()[refer_tensor_index].shape()[iter] > 1) &&
          (refer_tensor_dim == s.size())) {
        s_broadcast.push_back(s.at(iter));
        broadcast_dim_found = true;
        break;
      }
    }
    // Cannot decide which dim it is, push back one.
    if (broadcast_dim_found == false) {
      s_broadcast.push_back(1);
    }
  } else {
    // Cannot decide which dim needs to do broadcast, push back one(s).
    for (size_t iter = 0; iter < target_tensor_dim; iter++) {
      s_broadcast.push_back(1);
    }
  }

  return s_broadcast;
}
1075 
1076 // Check whether the operator can be divided by the current strategy.
CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,const Dimensions basic_stra)1077 Strategys CheckDivisible(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1078                          const Dimensions basic_stra) {
1079   Dimensions s_empty = {};
1080   Strategys stra;
1081 
1082   // For all the input tensors.
1083   for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
1084        iter_op_inputs++) {
1085     // If input tensor is empty, return strategy as void.
1086     if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() == 0) {
1087       stra.push_back(s_empty);
1088       continue;
1089     }
1090 
1091     Dimensions tmp_stra = basic_stra;
1092     bool modified = false;
1093 
1094     // Make sure each tensor's dim shape is greater than 1. If not, push back strategy as 1 instead.
1095     for (size_t j = 0; j < (size_t)ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size(); j++) {
1096       if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape()[j] == 1) {
1097         tmp_stra[j] = 1;
1098         modified = true;
1099       }
1100     }
1101     if (modified) {
1102       stra.push_back(tmp_stra);
1103     } else {
1104       stra.push_back(basic_stra);
1105     }
1106   }
1107 
1108   return stra;
1109 }
1110 
// Forward pass for operators eliminated during graph partition: walks the
// pending list backwards and derives each operator's strategy from its
// incoming (producer) operator — from the graph node layout when the
// producer survived partition, or from the producer's assigned strategy when
// it was eliminated too. Operators that still cannot be resolved are
// re-queued into no_stra_op_list for a later pass.
void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
                                               const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                               const std::vector<std::vector<std::string>> &input_tensor_names,
                                               const std::shared_ptr<std::vector<size_t>> &index_list,
                                               const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }
  std::vector<size_t> no_stra_op_list_bis;

  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
    size_t iter_ops = no_stra_op_list->at(iter_list - 1);
    Strategys stra;
    Dimensions s;
    size_t incoming_op_index = FindIndexOfOperatorIncoming(input_tensor_names, iter_ops);
    if (incoming_op_index != SIZE_MAX) {
      auto iter_graph = index_list->at(incoming_op_index);
      if (iter_graph != SIZE_MAX) {
        // Producer survived partition: read its output layout from the graph.
        s = CopyIncomingOperatorOutputStrategy(graph, ops, iter_ops, iter_graph, incoming_op_index);
      } else {
        // Producer was eliminated: reuse its already-assigned strategy.
        s = CopyIncomingOperatorInputStrategy(ops, incoming_op_index);
      }
    }

    if (s.size() == 0) {
      // Unresolved; retry in the backward pass / next iteration.
      no_stra_op_list_bis.push_back(iter_ops);
    } else {
      stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    }

    // NOTE: a strategy (possibly empty) is installed unconditionally here.
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }

  // Replace the pending list with the operators still unresolved.
  no_stra_op_list->clear();
  for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
    no_stra_op_list->push_back(no_stra_op_list_bis[i]);
  }
}
1150 
ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> & ops,const size_t iter_ops,Dimensions s)1151 Dimensions ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
1152                                            Dimensions s) {
1153   Dimensions s_Squeeze;
1154   auto axis_list = GetAxisList(ops, SizeToLong(iter_ops));
1155   size_t s_index = 0;
1156   size_t axis_list_index = 0;
1157   for (size_t i = 0; i < (size_t)(s.size() + axis_list.size()); i++) {
1158     if (i == (size_t)axis_list[axis_list_index]) {
1159       s_Squeeze.push_back(1);
1160       axis_list_index++;
1161     } else {
1162       s_Squeeze.push_back(s[s_index]);
1163       s_index++;
1164     }
1165   }
1166 
1167   size_t cut = 1;
1168   for (size_t i = 0; i < s_Squeeze.size(); i++) {
1169     cut *= LongToSize(s_Squeeze[i]);
1170   }
1171   if (cut != g_device_manager->DeviceNum()) {
1172     s_Squeeze.clear();
1173   }
1174 
1175   return s_Squeeze;
1176 }
1177 
// Derives a strategy for operator iter_ops from a consumer (outgoing)
// operator: finds the first operator that takes iter_ops' output as an input
// and already has a strategy, then copies that input slot's strategy.
// Rank-changing or index-based operators (Reduce*, Reshape, Gather,
// Transpose, Arg*WithValue) are excluded since a consumer's input strategy
// would not match their output layout.
Dimensions CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                             const std::vector<std::vector<std::string>> &input_tensor_names,
                                             const size_t iter_ops) {
  Dimensions s;
  if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
      ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
      ops[iter_ops]->type() == GATHERV2 || ops[iter_ops]->type() == TRANSPOSE ||
      ops[iter_ops]->type() == ARGMAXWITHVALUE || ops[iter_ops]->type() == ARGMINWITHVALUE) {
    return s;
  }

  // Search every operator's input names for this operator's output name
  // (entry 0 of its own name list), requiring a non-empty strategy.
  bool found = false;
  size_t outgoing_op_index = SIZE_MAX;
  size_t iter_op_inputs = SIZE_MAX;
  for (size_t i = 0; i < input_tensor_names.size(); i++) {
    for (size_t j = 1; j < input_tensor_names[i].size(); j++) {
      if (input_tensor_names[i][j] == input_tensor_names[iter_ops][0] &&
          ops[i]->selected_strategy()->GetInputNumber() != 0) {
        outgoing_op_index = i;
        // j - 1: the name list holds the output name at index 0, so input j
        // corresponds to strategy slot j - 1.
        iter_op_inputs = j - 1;
        found = true;
        break;
      }
    }
    if (found) {
      break;
    }
  }

  // Copy the consumer's strategy for the slot fed by this operator's output.
  if (outgoing_op_index != SIZE_MAX && iter_op_inputs != SIZE_MAX) {
    for (size_t k = 0; k < ops[iter_ops]->outputs_tensor_info()[0].shape().size(); ++k) {
      s.push_back(ops[outgoing_op_index]->selected_strategy()->GetInputDim()[iter_op_inputs][k]);
    }
  }
  return s;
}
1214 
// Backward pass for operators still lacking a strategy after the forward
// pass: walks the pending list backwards and derives each operator's strategy
// from a consumer (outgoing) operator. Squeeze outputs are re-expanded to the
// input layout first. Unresolved operators are re-queued into
// no_stra_op_list for the caller to retry.
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                const std::vector<std::vector<std::string>> &input_tensor_names,
                                                const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }
  std::vector<size_t> no_stra_op_list_bis;

  for (size_t iter_list = no_stra_op_list->size(); iter_list > 0; iter_list--) {
    auto iter_ops = no_stra_op_list->at(iter_list - 1);
    Strategys stra;
    Dimensions s = CopyOutgoingOperatorInputStrategy(ops, input_tensor_names, iter_ops);
    if (s.size() != 0 && ops[iter_ops]->type() == SQUEEZE) {
      // The consumer sees the squeezed layout; map it back to this input.
      s = ModifyStrategyIfSqueezeOutgoing(ops, iter_ops, s);
    }
    if (s.size() != 0) {
      stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    } else {
      // Unresolved; retry in a later pass.
      no_stra_op_list_bis.push_back(iter_ops);
    }

    // NOTE: a strategy (possibly empty) is installed unconditionally here.
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }

  // Replace the pending list with the operators still unresolved.
  no_stra_op_list->clear();
  for (size_t i = 0; i < no_stra_op_list_bis.size(); i++) {
    no_stra_op_list->push_back(no_stra_op_list_bis[i]);
  }
}
1245 
// Final fallback: alternately runs the forward and backward propagation
// passes until the pending list stops shrinking, then assigns an all-ones
// (no-cut) strategy, sized by the widest input tensor, to every operator
// that is still unresolved.
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
                                       const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                       const std::vector<std::vector<std::string>> &input_tensor_names,
                                       const std::shared_ptr<std::vector<size_t>> &index_list,
                                       const std::shared_ptr<std::vector<size_t>> &no_stra_op_list) {
  if (no_stra_op_list->size() == 0) {
    return;
  }

  // Iterate propagation to a fixed point: stop once a full forward+backward
  // round resolves no additional operator.
  size_t no_stra_op_list_size = no_stra_op_list->size();
  do {
    no_stra_op_list_size = no_stra_op_list->size();
    GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
    GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
  } while (no_stra_op_list_size > no_stra_op_list->size());

  for (size_t iter_list = 0; iter_list < no_stra_op_list->size(); iter_list++) {
    auto iter_ops = no_stra_op_list->at(iter_list);
    Strategys stra;
    Dimensions s;

    // Build an all-ones basic strategy as wide as the largest input rank.
    size_t max_dim_num = 0;
    for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) {
      if (ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size() > max_dim_num) {
        max_dim_num = ops[iter_ops]->inputs_tensor_info()[iter_op_inputs].shape().size();
      }
    }
    for (size_t i = 0; i < max_dim_num; i++) {
      s.push_back(1);
    }

    stra = GenerateStrategiesFromStrategy(ops, iter_ops, s);
    StrategyPtr sp = std::make_shared<Strategy>(0, stra);
    ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost());
  }
}
1282 }  // namespace parallel
1283 }  // namespace mindspore
1284