/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/auto_parallel/rec_core/rec_cost.h"

#include <algorithm>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "ir/anf.h"

namespace mindspore {
namespace parallel {
// Compute the redistribution cost of a node against the strategies already chosen for its neighbors.
double CostRedis(const Graph::NodeType &node,
                 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                 const std::vector<std::vector<float>> &mode, const Graph &graph) {
  // Accumulated redistribution cost.
  double cost_redis = 0;

  // Number of strategies decided so far.
  size_t num_strategy = node_name_to_strategy.size();

  // Numbers of incoming and outgoing nodes.
  size_t num_node_in = node.node_in.size();
  size_t num_node_out = node.node_out.size();

  // Effective tensor sizes: the original shapes scaled by the current cut fractions (str_*).
  double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n *
                        node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c *
                        node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h *
                        node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w;

  double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n *
                         node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c *
                         node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h *
                         node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w;

  // For each strategy already assigned to some node.
  for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) {
    // Match it against this node's incoming (forward) neighbors.
    for (size_t i_node = 0; i_node < num_node_in; i_node++) {
      if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first) {
        bool is_search_forward = true;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward);
      }
    }

    // Match it against this node's outgoing (backward) neighbors.
    for (size_t i_node = 0; i_node < num_node_out; i_node++) {
      if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first) {
        bool is_search_forward = false;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward);
      }
    }
  }

  return cost_redis;
}
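
// A rough numeric sketch of the sizes above (illustrative, not from the source): for an
// input of shape [n, c, h, w] = [32, 64, 128, 128] already cut in half along h
// (str_h = 0.5, all other str_* = 1), input_tensor = 32 * 64 * (128 * 0.5) * 128, i.e.
// the per-device slice size, which is what a redistribution would actually have to move.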

double CostRedisWithAdjacentNode(const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                 const std::vector<std::vector<float>> &mode, size_t i_strategy, size_t i_node,
                                 double tensor_size, bool search_forward) {
  double new_redis_cost = 0;
  int64_t counter = 0;

  if (search_forward) {
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_n) !=
        static_cast<int64_t>(1 / mode[i_node][0])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_c) !=
        static_cast<int64_t>(1 / mode[i_node][1])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_h) !=
        static_cast<int64_t>(1 / mode[i_node][2])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.outputTensor.str_w) !=
        static_cast<int64_t>(1 / mode[i_node][3])) {
      counter += 1;
    }
  } else {
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_n) !=
        static_cast<int64_t>(1 / mode[2][0])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_c) !=
        static_cast<int64_t>(1 / mode[2][1])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_h) !=
        static_cast<int64_t>(1 / mode[2][2])) {
      counter += 1;
    }
    if (static_cast<int64_t>(1 / node_name_to_strategy[i_strategy].second.inputTensor[0].str_w) !=
        static_cast<int64_t>(1 / mode[2][3])) {
      counter += 1;
    }
  }

  if (counter >= 2) {
    new_redis_cost = tensor_size / 4.0;
  } else if (counter == 0 || counter == 1) {
    new_redis_cost = 0;
  } else {
    MS_LOG(EXCEPTION) << "Failure: CostRedis failed.";
  }

  return new_redis_cost;
}
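
// A hedged reading of the heuristic above (the source does not spell it out): 1 / str_*
// is the number of slices along a dimension, so the counter tallies how many of the four
// dimensions the two strategies slice differently. Zero or one mismatch is treated as
// compatible (no redistribution); two or more charges a quarter of the tensor size as
// estimated communication volume. In the backward branch, mode[2] indexes the third row
// of the 3-row mode matrices built by the callers ({input0, input1, output}), i.e. the
// proposed output layout.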

// Get optimal strategy for MatMul
StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                      const Graph &graph) {
  int64_t edge_i =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
  int64_t edge_j =
    static_cast<int64_t>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
  int64_t edge_k =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);

  std::vector<double> cost_op;
  std::vector<std::vector<float>> mode;

  if (edge_i < 2 || edge_i % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrConcatDimI(edge_j, edge_k) + CostRedis(node, node_name_to_strategy,
                                                                mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}},
                                                                graph));
  }

  if (edge_j < 2 || edge_j % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrConcatDimJ(edge_i, edge_k) + CostRedis(node, node_name_to_strategy,
                                                                mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}},
                                                                graph));
  }

  if (edge_k < 2 || edge_k % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrReduceDimK(edge_i, edge_j) + CostRedis(node, node_name_to_strategy,
                                                                mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}},
                                                                graph));
  }

  return ChoseStr(cost_op, node.apply.str);
}
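
// The three candidates above halve i (rows of the result), j (columns of the result), or
// k (the contracted dimension). Each mode matrix lists the proposed str multipliers for
// {input0, input1, output}; cutting k, for instance, halves input0 along w and input1
// along h while leaving the output whole, which is why StrReduceDimK models a reduction
// rather than a concatenation.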

// Get the minimum candidate cut cost for MatMul, used as the node's weight.
double CostMatMul::GetMinCostIn(const OperatorRec &op) {
  int64_t edge_i = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t edge_j = static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
  int64_t edge_k = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);

  std::vector<double> cost_in;
  cost_in.push_back(StrConcatDimI(edge_j, edge_k));
  cost_in.push_back(StrConcatDimJ(edge_i, edge_k));
  cost_in.push_back(StrReduceDimK(edge_i, edge_j));

  return *std::min_element(cost_in.begin(), cost_in.end());
}

// Choose strategy for MatMul
StrategyRec CostMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_i_;
      break;

    case 1:
      str.inputTensor[1].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_j_;
      break;

    case 2:
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_k_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostMatMul failed.";
  }

  return str;
}
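
// Hedged usage sketch (the actual driver lives elsewhere in rec_core; the names below
// are illustrative only):
//   CostMatMul matmul_cost;
//   StrategyRec str = matmul_cost.GetOptimalStr(node, node_name_to_strategy, graph);
//   node.apply.str = str;  // adopt the refined strategy for the next halving round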

// Get optimal strategy for Conv
StrategyRec CostConvolution::GetOptimalStr(
  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
  const Graph &graph, bool channel_partition) {
  const OperatorRec &op = node.apply;

  int64_t input_tensor_h =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t input_tensor_w =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
  int64_t input_tensor_n =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
  int64_t input_tensor_c =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);

  int64_t tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c;

  int64_t tensor_filter_h =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
  int64_t tensor_filter_w =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
  int64_t tensor_filter_n =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
  int64_t tensor_filter_c =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);

  int64_t tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;

  int64_t output_tensor_h =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
  int64_t output_tensor_w =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
  int64_t output_tensor_n =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
  int64_t output_tensor_c =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);

  int64_t tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c;

  std::vector<double> cost_op;
  cost_op.reserve(7);
  std::vector<std::vector<float>> mode;

  if (input_tensor_n < 2 || input_tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy,
                                                         mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
  }

  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  if (!channel_partition || tensor_filter < 2 || tensor_filter % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy,
                                                     mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}}, graph));
  }

  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  if (!channel_partition || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy,
                                                      mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}}, graph));
  }

  return ChoseStr(cost_op, node.apply.str);
}
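
// cost_op is built with fixed positions {B, I, J, K, DI, DJ, Q} so that its indices line
// up with the switch cases in ChoseStr below. Positions 1, 2, 4, and 5 (the h/w cuts of
// the input and the filter) are padded with DOUBLE_MAX above, so this routine never
// selects them.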

// Get the minimum candidate cut cost for Conv, used as the node's weight.
double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
  const OperatorRec &op = node.apply;

  int64_t tensor_in = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
                      static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
                      static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
                      static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
  int64_t tensor_filter =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) *
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
  int64_t tensor_out =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);

  std::vector<double> cost_in;
  cost_in.push_back(StrDimB(tensor_filter));
  cost_in.push_back(StrDimI(tensor_in, tensor_filter));
  cost_in.push_back(StrDimJ(tensor_in, tensor_filter));
  cost_in.push_back(StrDimK(tensor_in));
  cost_in.push_back(StrDimDI(tensor_in, tensor_out));
  cost_in.push_back(StrDimDJ(tensor_in, tensor_out));
  cost_in.push_back(StrDimQ(tensor_out));

  return *std::min_element(cost_in.begin(), cost_in.end());
}

// Choose strategy for Conv
StrategyRec CostConvolution::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_b_;
      break;

    case 1:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_i_;
      break;

    case 2:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_j_;
      break;

    case 3:
      str.inputTensor[1].str_n /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_k_;
      break;

    case 4:
      str.inputTensor[1].str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_di_;
      break;

    case 5:
      str.inputTensor[1].str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_dj_;
      break;

    case 6:
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_q_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostConvolution failed.";
  }
  return str;
}

// Get optimal strategy for Pooling
StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node,
                                       const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                       const Graph &graph) {
  int64_t tensor_n = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
  int64_t tensor_c = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);

  std::vector<double> cost_op;
  std::vector<std::vector<float>> mode;

  if (tensor_n < 2 || tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
  }

  if (tensor_c < 2 || tensor_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
  }

  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  return ChoseStr(cost_op, node.apply.str);
}
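
// Positions 2 and 3 (the h and w cuts) are pinned to DOUBLE_MAX, so pooling is only ever
// cut along n or c here. A plausible reason, not stated in the source, is that pooling
// windows straddle h/w slice boundaries and would require halo exchange between devices.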

// Choose strategy for Pooling
StrategyRec CostPooling::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostPooling failed.";
  }
  return str;
}

// Choose strategy for TensorAdd
StrategyRec CostTensorAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.inputTensor[1].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostAdd failed.";
  }
  return str;
}

// Get optimal strategy for Reshape
StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); }

StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; }
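
// Reshape is a pass-through for the strategy search: it proposes no cut of its own and
// returns whatever strategy it is handed, presumably leaving any layout mismatch to the
// redistribution pass.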

// Choose strategy for BiasAdd
StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
  }
  return str;
}

// Get optimal strategy for Common OPs
StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node,
                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                      const Graph &graph) {
  const OperatorRec &op = node.apply;
  int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
  int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
  int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);

  std::vector<double> cost_op;
  std::vector<std::vector<float>> mode;

  if (tensor_n < 2 || tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}}, graph));
  }

  if (tensor_c < 2 || tensor_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}}, graph));
  }

  if (tensor_h < 2 || tensor_h % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}}, graph));
  }

  if (tensor_w < 2 || tensor_w % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy,
                                           mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}}, graph));
  }

  return ChoseStr(cost_op, node.apply.str);
}

// Choose strategy for Common op
StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostCommon failed.";
  }
  return str;
}

// Get optimal strategy for BatchParallel OPs
StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) {
  const OperatorRec &op = node.apply;
  int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
  int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
  int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);

  std::vector<double> cost_op;

  if (tensor_n < 2 || tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_);
  }

  if (tensor_c < 2 || tensor_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_);
  }

  if (tensor_h < 2 || tensor_h % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_);
  }

  if (tensor_w < 2 || tensor_w % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    cost_op.push_back(cost_in_);
  }

  return ChoseStr(cost_op, node.apply.str);
}
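
// Unlike the GetOptimalStr routines above, the batch-parallel candidates carry no
// CostRedis term: every dimension that can still be halved is charged the same flat
// cost_in_, so std::min_element picks the first eligible entry and the preference order
// is simply n, then c, h, w.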

// Choose strategy for BatchParallel op
StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed.";
  }
  return str;
}

// Choose strategy for SoftmaxCrossEntropyWithLogits
StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  uint64_t min_position = std::min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      str.inputTensor[0].str_n /= 2.0;
      str.inputTensor[1].str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      str.inputTensor[0].str_h /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
  }
  return str;
}
}  // namespace parallel
}  // namespace mindspore