• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/rec_core/rec_cost.h"
18 
19 #include <algorithm>
20 #include <limits>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "include/common/utils/parallel_context.h"
26 
27 namespace mindspore {
28 namespace parallel {
SameShape(const Shape4D & shape1,const Shape4D & shape2)29 bool SameShape(const Shape4D &shape1, const Shape4D &shape2) {
30   bool equal = (shape1 == shape2);
31 
32   return (equal || !ONLY_REDIST_WITH_SAME_SHAPE);
33 }
34 
costOfDistributing(const TensorParam & t)35 double costOfDistributing(const TensorParam &t) {
36   return (static_cast<double>(t.tensor_shape.shape_n) * t.tensor_str.str_n *
37           static_cast<double>(t.tensor_shape.shape_c) * t.tensor_str.str_c *
38           static_cast<double>(t.tensor_shape.shape_h) * t.tensor_str.str_h *
39           static_cast<double>(t.tensor_shape.shape_w) * t.tensor_str.str_w / 2.0);
40 }
41 
minNodeSize(const Graph::NodeType & node)42 double minNodeSize(const Graph::NodeType &node) {
43   double distributing0 = costOfDistributing(node.apply.arguments[0]);
44   double distributing1 = costOfDistributing(node.apply.arguments[1]);
45   double distributing2 = costOfDistributing(node.tensor_parm);
46   double min_distribution = std::min(distributing0, distributing1);
47   min_distribution = std::min(min_distribution, distributing2);
48   min_distribution *= EXPERT_COEF;
49   return min_distribution;
50 }
51 
52 // Compute redistributed cost
// Compute redistributed cost of `node` against every strategy already decided
// for its neighbours. `mode` is the candidate cut pattern: rows 0 and 1 carry
// the per-dimension strides applied to the inputs, row 2 the strides applied
// to the output (see how rows are selected in CostRedisWithAdjacentNode).
double CostRedis(const Graph::NodeType &node,
                 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                 const std::vector<std::vector<float>> &mode, const Graph &graph) {
  // Store value of cost redist
  double cost_redis = 0;

  // Number of current strategies.
  size_t num_strategy = node_name_to_strategy.size();

  // Number of node-in and node-out
  size_t num_node_in = node.node_in.size();
  size_t num_node_out = node.node_out.size();

  // Set tensor edge value with original tensor shape and cutting times.
  // NOTE(review): input_tensor is derived from arguments[0] only but is
  // reused for every input slot below — presumably an approximation; confirm.
  double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n *
                        node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c *
                        node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h *
                        node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w;

  double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n *
                         node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c *
                         node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h *
                         node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w;

  // For each strategy candidate.
  for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) {
    // Find its forward nodes: a predecessor whose output tensor feeds input
    // slot i_node and whose shape matches that slot.
    for (size_t i_node = 0; i_node < num_node_in; i_node++) {
      if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first &&
          SameShape(graph.nodes[node.node_in[i_node]].tensor_parm.tensor_shape,
                    node.apply.arguments[i_node].tensor_shape)) {
        bool is_search_forward = true;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward);
      }
    }

    // Find its backward nodes: a successor consuming this node's output as
    // either of its two arguments.
    for (size_t i_node = 0; i_node < num_node_out; i_node++) {
      bool is_same_shape =
        SameShape(graph.nodes[node.node_out[i_node]].apply.arguments[0].tensor_shape, node.tensor_parm.tensor_shape) ||
        SameShape(graph.nodes[node.node_out[i_node]].apply.arguments[1].tensor_shape, node.tensor_parm.tensor_shape);
      if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first && is_same_shape) {
        bool is_search_forward = false;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward);
      }
    }

    // Calculate the Redis Cost of node_in_aux; node_in_aux_idx maps each
    // auxiliary predecessor to the argument slot it feeds.
    for (size_t i_node = 0; i_node < node.node_in_aux.size(); i_node++) {
      size_t index = node.node_in_aux_idx[i_node];
      if (graph.nodes[node.node_in_aux[i_node]].name == node_name_to_strategy[i_strategy].first &&
          SameShape(graph.nodes[node.node_in_aux[i_node]].tensor_parm.tensor_shape,
                    node.apply.arguments[index].tensor_shape)) {
        bool is_search_forward = true;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, index, input_tensor, is_search_forward);
      }
    }
  }

  return cost_redis;
}
117 
// Cost contribution of one adjacent node under strategy candidate i_strategy.
// Strides are fractions in (0, 1], so 1/stride is the integer cut count of a
// dimension; if any dimension's cut count differs between the decided
// strategy and the candidate `mode`, the whole tensor (tensor_size) is
// assumed to be redistributed, scaled by REDIS_COEF.
double CostRedisWithAdjacentNode(const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                 const std::vector<std::vector<float>> &mode, size_t i_strategy, size_t i_node,
                                 double tensor_size, bool search_forward) {
  double new_redis_cost = 0;
  bool diff = false;

  auto output_tensor = node_name_to_strategy[i_strategy].second.outputTensor;
  auto input_tensor = node_name_to_strategy[i_strategy].second.inputTensor[0];

  if (search_forward) {
    // Forward: compare the neighbour's decided output strides against mode
    // row i_node (the candidate strides of the input slot it feeds).
    float output_dims[NDIMS] = {output_tensor.str_n, output_tensor.str_c, output_tensor.str_h, output_tensor.str_w};
    for (size_t i = 0; i < NDIMS; ++i) {
      if (output_dims[i] == 0 || mode[i_node][i] == 0) {
        MS_LOG(EXCEPTION) << "divisors cannot be 0!";
      }
      // Compare cut counts (1/stride) as integers to sidestep float equality.
      if (static_cast<int64_t>(1 / output_dims[i]) != static_cast<int64_t>(1 / mode[i_node][i])) {
        diff = true;
        break;
      }
    }
  } else {
    // Backward: compare the neighbour's decided first-input strides against
    // mode row 2 (the candidate strides of this node's output).
    float input_dims[NDIMS] = {input_tensor.str_n, input_tensor.str_c, input_tensor.str_h, input_tensor.str_w};
    for (size_t i = 0; i < NDIMS; ++i) {
      if (input_dims[i] == 0 || mode[2][i] == 0) {
        MS_LOG(EXCEPTION) << "divisors cannot be 0!";
      }
      if (static_cast<int64_t>(1 / input_dims[i]) != static_cast<int64_t>(1 / mode[2][i])) {
        diff = true;
        break;
      }
    }
  }

  // Any mismatched dimension forces a full redistribution of the tensor.
  if (diff) {
    new_redis_cost = tensor_size * REDIS_COEF;
  }

  return new_redis_cost;
}
157 
hasBeenSplitted(const Graph::NodeType & node,const bool dyn_shape_tmp_fix)158 bool hasBeenSplitted(const Graph::NodeType &node, const bool dyn_shape_tmp_fix) {
159   if (dyn_shape_tmp_fix) {
160     if (node.apply.arguments[0].tensor_str.str_h < 1 || node.apply.arguments[0].tensor_str.str_w < 1 ||
161         node.apply.arguments[1].tensor_str.str_w < 1 || node.apply.arguments[0].tensor_str.str_n < 1 ||
162         node.apply.arguments[0].tensor_str.str_c < 1) {
163       return true;
164     }
165   }
166   return false;
167 }
168 
169 // Get optimal strategy for MatMul
// Get optimal strategy for MatMul.
// Builds cost_op = [cut I, cut J, cut K, no cut]; an axis that cannot be
// halved (extent < 2, odd extent, or disabled by context) is priced at
// DOUBLE_MAX so ChoseStr skips it.
StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                      const Graph &graph, const bool isTraining) {
  // Remaining (not yet cut) extents of the I/J/K edges of C[i,j] = A[i,k]*B[k,j].
  int64_t edge_i =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
  int64_t edge_j =
    static_cast<int64_t>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
  int64_t edge_k =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);

  // With pipeline parallelism, refuse to cut I below one sample per micro-batch.
  bool isMicroBatchSizeLargeEnough = true;
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    if (graph.micro_batch_size * node.apply.arguments[0].tensor_str.str_h <= 1) {
      isMicroBatchSizeLargeEnough = false;
    }
  }

  std::vector<double> cost_op;
  if (node.apply.arguments[0].tensor_str.str_h == 0) {
    MS_LOG(EXCEPTION) << "str_h cannot be 0!";
  }
  if (edge_i < INT64_TWO || edge_i % INT64_TWO != 0 || !isMicroBatchSizeLargeEnough) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    // mode rows: input0 / input1 / output; 0.5 marks the halved dimension.
    std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}};
    double cost_if_cut_i = StrConcatDimI(edge_j, edge_k);
    double redist_if_cut_i = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_i = cost_if_cut_i + redist_if_cut_i;
    MS_LOG(INFO) << "If the I-axis is cut, the op-cost is " << cost_if_cut_i << ", the redist-cost is "
                 << redist_if_cut_i << ", and the total cost is " << total_cost_if_cut_i;
    cost_op.push_back(total_cost_if_cut_i);
  }

  // Do not partition the J-axis and K-axis for the same MatMul
  if (edge_j < INT64_TWO || edge_j % INT64_TWO != 0 || node.apply.arguments[0].tensor_str.str_w < 1) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
    double cost_if_cut_j = StrConcatDimJ(edge_i, edge_k);
    double redist_if_cut_j = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_j = cost_if_cut_j + redist_if_cut_j;
    MS_LOG(INFO) << "If the J-axis is cut, the op-cost is " << cost_if_cut_j << ", the redist-cost is "
                 << redist_if_cut_j << ", and the total cost is " << total_cost_if_cut_j;
    cost_op.push_back(total_cost_if_cut_j);
  }

  if (edge_k < INT64_TWO || edge_k % INT64_TWO != 0 || node.apply.arguments[1].tensor_str.str_w < 1) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}};
    double cost_if_cut_k = StrReduceDimK(edge_i, edge_j);
    double redist_if_cut_k = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_k = cost_if_cut_k + redist_if_cut_k;
    MS_LOG(INFO) << "If the K-axis is cut, the op-cost is " << cost_if_cut_k << ", the redist-cost is "
                 << redist_if_cut_k << ", and the total cost is " << total_cost_if_cut_k;
    cost_op.push_back(total_cost_if_cut_k);
  }

  // No-cut candidate: disabled entirely when the node was already split under
  // the dynamic-shape fix.
  if (hasBeenSplitted(node, graph.dyn_shape_tmp_fix)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}};
    double cost_if_no_cut =
      StrRecom(StrConcatDimI(edge_j, edge_k), StrConcatDimJ(edge_i, edge_k), StrReduceDimK(edge_i, edge_j));
    double redist_if_no_cut = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_no_cut = cost_if_no_cut + redist_if_no_cut;
    MS_LOG(INFO) << "If do NOT cut the axis, the op-cost is " << cost_if_no_cut << ", the redist-cost is "
                 << redist_if_no_cut << ", and the total cost is " << total_cost_if_no_cut;
    cost_op.push_back(total_cost_if_no_cut);
  }

  // Fold any negative cost into a positive magnitude before comparison.
  for (auto &cost : cost_op) {
    cost = std::abs(cost);
  }

  return ChoseStr(cost_op, node.apply.str);
}
247 
248 // Get weight for MatMul
GetMaxCostIn(const OperatorRec & op)249 double CostMatMul::GetMaxCostIn(const OperatorRec &op) {
250   int64_t edge_i = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
251   int64_t edge_j = static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
252   int64_t edge_k = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
253 
254   double cost_if_cut_i = StrConcatDimI(edge_j, edge_k);
255   double cost_if_cut_j = StrConcatDimJ(edge_i, edge_k);
256   double cost_if_cut_k = StrReduceDimK(edge_i, edge_j);
257   double cost_if_no_cut = StrRecom(cost_if_cut_i, cost_if_cut_j, cost_if_cut_k);
258 
259   std::vector<double> cost_in;
260   cost_in.push_back(cost_if_cut_i);
261   cost_in.push_back(cost_if_cut_j);
262   cost_in.push_back(cost_if_cut_k);
263   cost_in.push_back(cost_if_no_cut);
264 
265   return *max_element(cost_in.begin(), cost_in.end());
266 }
267 
268 // Chose strategy for MatMul
// Chose strategy for MatMul.
// cost_op holds [cut I, cut J, cut K, no cut]; the cheapest entry decides the
// cut. If every candidate is DOUBLE_MAX the strategy is returned unchanged.
// Halving a stride means that dimension is split once more across devices.
StrategyRec CostMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
  MS_LOG(INFO) << "The costs of cutting the I-axis/J-axis/K-axis/no_cut are : " << cost_op;
  uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      // Cut I: split the rows of input 0 and of the output.
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_i_;
      MS_LOG(INFO) << "The I-axis is chosen to cut";
      break;

    case 1:
      // Cut J: split the columns of input 1 and of the output.
      str.inputTensor[1].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_j_;
      MS_LOG(INFO) << "The J-axis is chosen to cut";
      break;

    case 2:
      // Cut K (reduction axis): split both inputs; the output is unchanged.
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_k_;
      MS_LOG(INFO) << "The K-axis is chosen to cut";
      break;

    case 3:
      MS_LOG(INFO) << "Choose NOT to cut";
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure:CostMatMul failed.";
  }

  return str;
}
311 
getBatchDimsSize(const OperatorRec & op)312 size_t CostBatchMatMul::getBatchDimsSize(const OperatorRec &op) {
313   return static_cast<double>(std::max(op.arguments[0].tensor_shape.shape_n, op.arguments[1].tensor_shape.shape_n)) *
314          std::max(op.arguments[0].tensor_str.str_n, op.arguments[1].tensor_str.str_n) *
315          static_cast<double>(std::max(op.arguments[0].tensor_shape.shape_c, op.arguments[1].tensor_shape.shape_c)) *
316          std::max(op.arguments[0].tensor_str.str_c, op.arguments[1].tensor_str.str_c);
317 }
318 
// Op-cost of cutting axis `a` of a BatchMatMul node.
// mc_ratio penalises under-utilisation: when the batch dimensions alone are
// too small to fill NUMBER_ASCEND_CORES, batch/expert cuts get a higher cost.
double CostBatchMatMul::cost(Axis a, const Graph::NodeType &node) {
  double mc_ratio;
  size_t batch_dims_size = getBatchDimsSize(node.apply);
  if (batch_dims_size == 1) {
    mc_ratio = static_cast<double>(NUMBER_ASCEND_CORES);
  } else {
    mc_ratio = std::max(NUMBER_ASCEND_CORES / static_cast<double>(batch_dims_size) - 1, 0.0);
  }
  double min_size = minNodeSize(node);

  switch (a) {
    // Calculate the cost if the Batch-axis of BatchMatMul is cut
    case B:
      return (mc_ratio * min_size);

    // Calculate the cost if the Expert-axis of BatchMatMul is cut
    // (the "- 1" makes an Expert cut marginally cheaper than a Batch cut,
    // breaking ties in its favour).
    case X:
      return (mc_ratio * min_size) - 1;

    // Calculate the cost if the I-axis of BatchMatMul is cut
    // (the second operand must be replicated/moved).
    case I:
      return costOfDistributing(node.apply.arguments[1]);

    // Calculate the cost if the J-axis of BatchMatMul is cut
    // (the first operand must be replicated/moved).
    case J:
      return costOfDistributing(node.apply.arguments[0]);

    // Calculate the cost if the K-axis of BatchMatMul is cut
    // (the output must be reduced/moved).
    case K:
      return costOfDistributing(node.tensor_parm);

    // Calculate the cost if BatchMatMul is not cut
    case R:
      return min_size * min_size / REPLICATE_BELOW;

    default:
      MS_LOG(EXCEPTION) << "Axis " << a << " is not taken into account";
  }

  // Unreachable: MS_LOG(EXCEPTION) throws; kept to satisfy the compiler.
  return 1;
}
360 
SplitOnlyOneDimension(const Graph & graph,float str)361 bool SplitOnlyOneDimension(const Graph &graph, float str) {
362   if (graph.dyn_shape_tmp_fix && str < 1) {
363     return true;
364   }
365   return false;
366 }
367 
IsEdgeSplittable(const int64_t edge)368 bool IsEdgeSplittable(const int64_t edge) {
369   if (edge < INT64_TWO || edge % INT64_TWO != 0) {
370     return false;
371   }
372   return true;
373 }
374 
375 // Get optimal strategy for BatchMatMul
// Get optimal strategy for BatchMatMul.
// Builds cost_op = [cut B, cut X, cut I, cut J, cut K, no cut]; an axis that
// cannot be halved (odd/too-small extent, or disabled by context) is priced
// at DOUBLE_MAX so ChoseStr skips it.
StrategyRec CostBatchMatMul::GetOptimalStr(
  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
  const Graph &graph, const bool isTraining) {
  // Remaining extents of the Batch/Expert/I/J/K edges after prior cuts.
  int64_t edge_b =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n);
  int64_t edge_x =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c);
  int64_t edge_i =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
  int64_t edge_j =
    static_cast<int64_t>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
  int64_t edge_k =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);

  // With pipeline parallelism, refuse to cut the batch below one sample per
  // micro-batch.
  bool isMicroBatchSizeLargeEnough = true;
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    if (graph.micro_batch_size * node.apply.arguments[0].tensor_str.str_n <= 1) {
      isMicroBatchSizeLargeEnough = false;
    }
  }

  std::vector<double> cost_op;
  if (node.apply.arguments[0].tensor_str.str_n == 0) {
    MS_LOG(EXCEPTION) << "str_n cannot be 0!";
  }
  if (!IsEdgeSplittable(edge_b) || !isMicroBatchSizeLargeEnough) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    // mode rows: input0 / input1 / output; 0.5 marks the halved dimension.
    std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
    double cost_if_cut_b = cost(B, node);
    double redist_if_cut_b = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_b = cost_if_cut_b + redist_if_cut_b;
    MS_LOG(INFO) << "If the Batch-axis is cut, the op-cost is " << cost_if_cut_b << ", the redist-cost is "
                 << redist_if_cut_b << ", and the total cost is " << total_cost_if_cut_b;
    cost_op.push_back(total_cost_if_cut_b);
  }

  if (!IsEdgeSplittable(edge_x)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
    double cost_if_cut_x = cost(X, node);
    double redist_if_cut_x = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_x = cost_if_cut_x + redist_if_cut_x;
    MS_LOG(INFO) << "If the Expert-axis is cut, the op-cost is " << cost_if_cut_x << ", the redist-cost is "
                 << redist_if_cut_x << ", and the total cost is " << total_cost_if_cut_x;
    cost_op.push_back(total_cost_if_cut_x);
  }

  if (!IsEdgeSplittable(edge_i) || SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}};
    double cost_if_cut_i = cost(I, node);
    double redist_if_cut_i = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_i = cost_if_cut_i + redist_if_cut_i;
    MS_LOG(INFO) << "If the I-axis is cut, the op-cost is " << cost_if_cut_i << ", the redist-cost is "
                 << redist_if_cut_i << ", and the total cost is " << total_cost_if_cut_i;
    cost_op.push_back(total_cost_if_cut_i);
  }

  if (!IsEdgeSplittable(edge_j) || node.apply.arguments[0].tensor_str.str_w < 1 ||
      SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
    double cost_if_cut_j = cost(J, node);
    double redist_if_cut_j = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_j = cost_if_cut_j + redist_if_cut_j;
    MS_LOG(INFO) << "If the J-axis is cut, the op-cost is " << cost_if_cut_j << ", the redist-cost is "
                 << redist_if_cut_j << ", and the total cost is " << total_cost_if_cut_j;
    // BMM_COEF biases the choice towards J/K cuts relative to B/X/I.
    cost_op.push_back(total_cost_if_cut_j / BMM_COEF);
  }

  if (!IsEdgeSplittable(edge_k) || node.apply.arguments[1].tensor_str.str_w < 1 ||
      SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}};
    double cost_if_cut_k = cost(K, node);
    double redist_if_cut_k = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_k = cost_if_cut_k + redist_if_cut_k;
    MS_LOG(INFO) << "If the K-axis is cut, the op-cost is " << cost_if_cut_k << ", the redist-cost is "
                 << redist_if_cut_k << ", and the total cost is " << total_cost_if_cut_k;
    cost_op.push_back(total_cost_if_cut_k / BMM_COEF);
  }

  // No-cut candidate: disabled when the node was already split under the
  // dynamic-shape fix.
  if (hasBeenSplitted(node, graph.dyn_shape_tmp_fix)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}};
    double cost_if_no_cut = cost(R, node);
    double redist_if_no_cut = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_no_cut = cost_if_no_cut + redist_if_no_cut;
    MS_LOG(INFO) << "If do NOT cut the axis, the op-cost is " << cost_if_no_cut << ", the redist-cost is "
                 << redist_if_no_cut << ", and the total cost is " << total_cost_if_no_cut;
    cost_op.push_back(total_cost_if_no_cut);
  }

  // Fold any negative cost into a positive magnitude before comparison.
  for (auto &cost : cost_op) {
    cost = std::abs(cost);
  }

  return ChoseStr(cost_op, node.apply.str);
}
481 
482 // Get weight for BatchMatMul
GetMaxCostIn(const Graph::NodeType & node)483 double CostBatchMatMul::GetMaxCostIn(const Graph::NodeType &node) {
484   std::vector<double> cost_in;
485   cost_in.push_back(cost(B, node));
486   cost_in.push_back(cost(X, node));
487   cost_in.push_back(cost(I, node));
488   cost_in.push_back(cost(J, node));
489   cost_in.push_back(cost(K, node));
490   cost_in.push_back(cost(R, node));
491 
492   return *max_element(cost_in.begin(), cost_in.end());
493 }
494 
495 // Chose strategy for BatchMatMul
ChoseStr(const std::vector<double> & cost_op,StrategyRec str) const496 StrategyRec CostBatchMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
497   MS_LOG(INFO) << "The costs of cutting the Batch-axis/Expert-axis/I-axis/J-axis/K-axis/no_cut are : " << cost_op;
498   uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
499   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
500     return str;
501   }
502 
503   str.cut_counter += 1;
504   str.cost = str.cost + cost_op[min_position];
505 
506   switch (min_position) {
507     case 0:
508       str.inputTensor[0].str_n /= 2.0;
509       str.inputTensor[1].str_n /= 2.0;
510       str.outputTensor.str_n /= 2.0;
511       MS_LOG(INFO) << "The Batch-axis is chosen to cut";
512       break;
513 
514     case 1:
515       str.inputTensor[0].str_c /= 2.0;
516       str.inputTensor[1].str_c /= 2.0;
517       str.outputTensor.str_c /= 2.0;
518       MS_LOG(INFO) << "The Expert-axis is chosen to cut";
519       break;
520 
521     case 2:
522       str.inputTensor[0].str_h /= 2.0;
523       str.outputTensor.str_h /= 2.0;
524       MS_LOG(INFO) << "The I-axis is chosen to cut";
525       break;
526 
527     case 3:
528       str.inputTensor[1].str_w /= 2.0;
529       str.outputTensor.str_w /= 2.0;
530       MS_LOG(INFO) << "The J-axis is chosen to cut";
531       break;
532 
533     case 4:
534       str.inputTensor[0].str_w /= 2.0;
535       str.inputTensor[1].str_h /= 2.0;
536       MS_LOG(INFO) << "The K-axis is chosen to cut";
537       break;
538 
539     case 5:
540       MS_LOG(INFO) << "Choose NOT to cut";
541       break;
542 
543     default:
544       MS_LOG(EXCEPTION) << "Failure:CostBatchMatMul failed.";
545   }
546 
547   return str;
548 }
549 
550 // Get optimal strategy for Conv
// Get optimal strategy for Conv.
// cost_op has seven slots ordered [B, I, J, K, DI, DJ, Q] to match ChoseStr;
// slots I, J, DI and DJ are unconditionally disabled with DOUBLE_MAX here,
// so only batch (B), output-channel (K) and input-channel (Q) cuts compete.
StrategyRec CostConvolution::GetOptimalStr(
  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
  const Graph &graph, bool channel_partition) {
  const OperatorRec &op = node.apply;

  // Remaining extents of the input feature map after prior cuts.
  int64_t input_tensor_h =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t input_tensor_w =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
  int64_t input_tensor_n =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
  int64_t input_tensor_c =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);

  int64_t tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c;

  // Remaining extents of the filter tensor.
  int64_t tensor_filter_h =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
  int64_t tensor_filter_w =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
  int64_t tensor_filter_n =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
  int64_t tensor_filter_c =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);

  int64_t tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;

  // Remaining extents of the output feature map.
  int64_t output_tensor_h =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
  int64_t output_tensor_w =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
  int64_t output_tensor_n =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
  int64_t output_tensor_c =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);

  int64_t tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c;

  std::vector<double> cost_op;
  cost_op.reserve(7);

  // Slot 0 (B): cut the batch dimension if it can still be halved.
  if (input_tensor_n < 2 || input_tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}};
    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  // Slots 1-2 (I, J): spatial cuts are disabled for convolution.
  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  // Slot 3 (K): cut the output channels when channel partitioning is allowed.
  if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}};
    cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  // Slots 4-5 (DI, DJ): filter spatial cuts are disabled.
  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  // Slot 6 (Q): cut the input channels when channel partitioning is allowed.
  if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}};
    cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  return ChoseStr(cost_op, node.apply.str);
}
621 
622 // Get weight for Conv
GetMinCostIn(const Graph::NodeType & node)623 double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
624   const OperatorRec &op = node.apply;
625 
626   int64_t tensor_in = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
627                       static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
628                       static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
629                       static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
630   int64_t tensor_filter =
631     static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) *
632     static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
633     static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
634     static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
635   int64_t tensor_out = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
636                        static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
637                        static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
638                        static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
639 
640   std::vector<double> cost_in;
641   cost_in.push_back(StrDimB(tensor_filter));
642   cost_in.push_back(StrDimI(tensor_in, tensor_filter));
643   cost_in.push_back(StrDimJ(tensor_in, tensor_filter));
644   cost_in.push_back(StrDimK(tensor_in));
645   cost_in.push_back(StrDimDI(tensor_in, tensor_out));
646   cost_in.push_back(StrDimDJ(tensor_in, tensor_out));
647   cost_in.push_back(StrDimQ(tensor_out));
648 
649   return *min_element(cost_in.begin(), cost_in.end());
650 }
651 
652 // Chose strategy for Conv
ChoseStr(const std::vector<double> & cost_op,StrategyRec str) const653 StrategyRec CostConvolution::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
654   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
655   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
656     return str;
657   }
658 
659   switch (min_position) {
660     case 0:
661       str.inputTensor[0].str_n /= 2.0;
662       str.outputTensor.str_n /= 2.0;
663       str.cut_counter += 1;
664       str.cost = str.cost + cost_in_b_;
665       break;
666 
667     case 1:
668       str.inputTensor[0].str_h /= 2.0;
669       str.outputTensor.str_h /= 2.0;
670       str.cut_counter += 1;
671       str.cost = str.cost + cost_in_i_;
672       break;
673 
674     case 2:
675       str.inputTensor[0].str_w /= 2.0;
676       str.outputTensor.str_w /= 2.0;
677       str.cut_counter += 1;
678       str.cost = str.cost + cost_in_j_;
679       break;
680 
681     case 3:
682       str.inputTensor[1].str_n /= 2.0;
683       str.outputTensor.str_c /= 2.0;
684       str.cut_counter += 1;
685       str.cost = str.cost + cost_in_k_;
686       break;
687 
688     case 4:
689       str.inputTensor[1].str_h /= 2.0;
690       str.cut_counter += 1;
691       str.cost = str.cost + cost_in_di_;
692       break;
693 
694     case 5:
695       str.inputTensor[1].str_w /= 2.0;
696       str.cut_counter += 1;
697       str.cost = str.cost + cost_in_dj_;
698       break;
699 
700     case 6:
701       str.inputTensor[0].str_c /= 2.0;
702       str.inputTensor[1].str_c /= 2.0;
703       str.cut_counter += 1;
704       str.cost = str.cost + cost_in_q_;
705       break;
706 
707     default:
708       MS_LOG(EXCEPTION) << "Failure: CostConvolution failed.";
709   }
710   return str;
711 }
712 
713 // Get optimal strategy for Pooling
GetOptimalStr(const Graph::NodeType & node,const std::vector<std::pair<std::string,StrategyRec>> & node_name_to_strategy,const Graph & graph) const714 StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node,
715                                        const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
716                                        const Graph &graph) const {
717   int64_t tensor_n = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
718   int64_t tensor_c = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
719 
720   std::vector<double> cost_op;
721 
722   if (tensor_n < 2 || tensor_n % 2 != 0) {
723     cost_op.push_back(DOUBLE_MAX);
724   } else {
725     std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
726     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
727   }
728 
729   if (tensor_c < 2 || tensor_c % 2 != 0) {
730     cost_op.push_back(DOUBLE_MAX);
731   } else {
732     std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
733     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
734   }
735 
736   cost_op.push_back(DOUBLE_MAX);
737   cost_op.push_back(DOUBLE_MAX);
738 
739   return ChoseStr(cost_op, node.apply.str);
740 }
741 
742 // Chose strategy for Pooling
ChoseStr(const std::vector<double> & cost_op,StrategyRec str) const743 StrategyRec CostPooling::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
744   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
745   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
746     return str;
747   }
748 
749   switch (min_position) {
750     case 0:
751       str.inputTensor[0].str_n /= 2.0;
752       str.outputTensor.str_n /= 2.0;
753       str.cut_counter += 1;
754       str.cost = str.cost + cost_in_;
755       break;
756 
757     case 1:
758       str.inputTensor[0].str_c /= 2.0;
759       str.outputTensor.str_c /= 2.0;
760       str.cut_counter += 1;
761       str.cost = str.cost + cost_in_;
762       break;
763 
764     case 2:
765       str.inputTensor[0].str_h /= 2.0;
766       str.outputTensor.str_h /= 2.0;
767       str.cut_counter += 1;
768       str.cost = str.cost + cost_in_;
769       break;
770 
771     case 3:
772       str.inputTensor[0].str_w /= 2.0;
773       str.outputTensor.str_w /= 2.0;
774       str.cut_counter += 1;
775       str.cost = str.cost + cost_in_;
776       break;
777 
778     default:
779       MS_LOG(EXCEPTION) << "Failure: CostPooling failed.";
780   }
781   return str;
782 }
783 
784 // Chose strategy for Add
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)785 StrategyRec CostTensorAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
786   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
787   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
788     return str;
789   }
790 
791   switch (min_position) {
792     case 0:
793       str.inputTensor[0].str_n /= 2.0;
794       str.inputTensor[1].str_n /= 2.0;
795       str.outputTensor.str_n /= 2.0;
796       str.cut_counter += 1;
797       str.cost = str.cost + cost_in_;
798       break;
799 
800     case 1:
801       str.inputTensor[0].str_c /= 2.0;
802       str.inputTensor[1].str_c /= 2.0;
803       str.outputTensor.str_c /= 2.0;
804       str.cut_counter += 1;
805       str.cost = str.cost + cost_in_;
806       break;
807 
808     case 2:
809       str.inputTensor[0].str_h /= 2.0;
810       str.inputTensor[1].str_h /= 2.0;
811       str.outputTensor.str_h /= 2.0;
812       str.cut_counter += 1;
813       str.cost = str.cost + cost_in_;
814       break;
815 
816     case 3:
817       str.inputTensor[0].str_w /= 2.0;
818       str.inputTensor[1].str_w /= 2.0;
819       str.outputTensor.str_w /= 2.0;
820       str.cut_counter += 1;
821       str.cost = str.cost + cost_in_;
822       break;
823 
824     default:
825       MS_LOG(EXCEPTION) << "Failure: CostAdd failed.";
826   }
827   return str;
828 }
829 
830 // Get optimal strategy for Reshape
GetOptimalStr(const Graph::NodeType & node) const831 StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); }
832 
ChoseStr(StrategyRec str) const833 StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; }
834 
835 // Chose strategy for BiasAdd
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)836 StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
837   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
838   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
839     return str;
840   }
841 
842   switch (min_position) {
843     case 0:
844       str.inputTensor[0].str_n /= 2.0;
845       str.outputTensor.str_n /= 2.0;
846       str.cut_counter += 1;
847       str.cost = str.cost + cost_in_;
848       break;
849 
850     case 1:
851       str.inputTensor[0].str_c /= 2.0;
852       str.outputTensor.str_c /= 2.0;
853       str.cut_counter += 1;
854       str.cost = str.cost + cost_in_;
855       break;
856 
857     case 2:
858       str.inputTensor[0].str_h /= 2.0;
859       str.outputTensor.str_h /= 2.0;
860       str.cut_counter += 1;
861       str.cost = str.cost + cost_in_;
862       break;
863 
864     case 3:
865       str.inputTensor[0].str_w /= 2.0;
866       str.inputTensor[1].str_w /= 2.0;
867       str.outputTensor.str_w /= 2.0;
868       str.cut_counter += 1;
869       str.cost = str.cost + cost_in_;
870       break;
871 
872     default:
873       MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
874   }
875   return str;
876 }
877 
878 // Get optimal strategy for Common OPs
GetOptimalStr(const Graph::NodeType & node,const std::vector<std::pair<std::string,StrategyRec>> & node_name_to_strategy,const Graph & graph)879 StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node,
880                                       const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
881                                       const Graph &graph) {
882   const OperatorRec &op = node.apply;
883   int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
884   int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
885   int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
886   int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
887 
888   std::vector<double> cost_op;
889 
890   if (tensor_n < 2 || tensor_n % 2 != 0) {
891     cost_op.push_back(DOUBLE_MAX);
892   } else {
893     std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
894     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
895   }
896 
897   if (tensor_c < 2 || tensor_c % 2 != 0) {
898     cost_op.push_back(DOUBLE_MAX);
899   } else {
900     std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
901     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
902   }
903 
904   if (tensor_h < 2 || tensor_h % 2 != 0) {
905     cost_op.push_back(DOUBLE_MAX);
906   } else {
907     std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}};
908     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
909   }
910 
911   if (tensor_w < 2 || tensor_w % 2 != 0) {
912     cost_op.push_back(DOUBLE_MAX);
913   } else {
914     std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
915     cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
916   }
917 
918   return ChoseStr(cost_op, node.apply.str);
919 }
920 
921 // Chose strategy for Common op
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)922 StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
923   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
924   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
925     return str;
926   }
927 
928   switch (min_position) {
929     case 0:
930       str.inputTensor[0].str_n /= 2.0;
931       str.outputTensor.str_n /= 2.0;
932       str.cut_counter += 1;
933       str.cost = str.cost + cost_in_;
934       break;
935 
936     case 1:
937       str.inputTensor[0].str_c /= 2.0;
938       str.outputTensor.str_c /= 2.0;
939       str.cut_counter += 1;
940       str.cost = str.cost + cost_in_;
941       break;
942 
943     case 2:
944       str.inputTensor[0].str_h /= 2.0;
945       str.outputTensor.str_h /= 2.0;
946       str.cut_counter += 1;
947       str.cost = str.cost + cost_in_;
948       break;
949 
950     case 3:
951       str.inputTensor[0].str_w /= 2.0;
952       str.outputTensor.str_w /= 2.0;
953       str.cut_counter += 1;
954       str.cost = str.cost + cost_in_;
955       break;
956 
957     default:
958       MS_LOG(EXCEPTION) << "Failure: Common failed.";
959   }
960   return str;
961 }
962 
963 // Get optimal strategy for BatchParallel OPs
GetOptimalStr(const Graph::NodeType & node)964 StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) {
965   const OperatorRec &op = node.apply;
966   int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
967   int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
968   int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
969   int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
970 
971   std::vector<double> cost_op;
972 
973   if (tensor_n < 2 || tensor_n % 2 != 0) {
974     cost_op.push_back(DOUBLE_MAX);
975   } else {
976     cost_op.push_back(cost_in_);
977   }
978 
979   if (tensor_c < 2 || tensor_c % 2 != 0) {
980     cost_op.push_back(DOUBLE_MAX);
981   } else {
982     cost_op.push_back(cost_in_);
983   }
984 
985   if (tensor_h < 2 || tensor_h % 2 != 0) {
986     cost_op.push_back(DOUBLE_MAX);
987   } else {
988     cost_op.push_back(cost_in_);
989   }
990 
991   if (tensor_w < 2 || tensor_w % 2 != 0) {
992     cost_op.push_back(DOUBLE_MAX);
993   } else {
994     cost_op.push_back(cost_in_);
995   }
996 
997   return ChoseStr(cost_op, node.apply.str);
998 }
999 
1000 // Chose strategy for BatchParallel op
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)1001 StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
1002   uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
1003   if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
1004     return str;
1005   }
1006 
1007   switch (min_position) {
1008     case 0:
1009       str.inputTensor[0].str_n /= 2.0;
1010       str.outputTensor.str_n /= 2.0;
1011       str.cut_counter += 1;
1012       str.cost = str.cost + cost_in_;
1013       break;
1014 
1015     case 1:
1016       str.inputTensor[0].str_c /= 2.0;
1017       str.outputTensor.str_c /= 2.0;
1018       str.cut_counter += 1;
1019       str.cost = str.cost + cost_in_;
1020       break;
1021 
1022     case 2:
1023       str.inputTensor[0].str_h /= 2.0;
1024       str.outputTensor.str_h /= 2.0;
1025       str.cut_counter += 1;
1026       str.cost = str.cost + cost_in_;
1027       break;
1028 
1029     case 3:
1030       str.inputTensor[0].str_w /= 2.0;
1031       str.outputTensor.str_w /= 2.0;
1032       str.cut_counter += 1;
1033       str.cost = str.cost + cost_in_;
1034       break;
1035 
1036     default:
1037       MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed.";
1038   }
1039   return str;
1040 }
1041 
// Chose strategy for CostSoftmaxCrossEntropyWithLogits
// Applies the cheapest feasible cut from cost_op (indices 0..3 map to
// N, C, H, W) to both inputs (logits and labels); returns the strategy
// unchanged when every candidate is DOUBLE_MAX (no dimension can be cut).
StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  // Index of the cheapest candidate cut.
  uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    // All candidates infeasible: keep the current strategy.
    return str;
  }

  switch (min_position) {
    case 0:
      // Halve the batch stride of both inputs.
      str.inputTensor[0].str_n /= 2.0;
      str.inputTensor[1].str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      // Halve the channel stride of both inputs.
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      // NOTE(review): cutting H of the inputs adjusts W of the OUTPUT here,
      // while every other case leaves the output untouched — presumably the
      // loss output has fewer dimensions, so the input H cut lands on the
      // output's W axis. Confirm this asymmetry is intended.
      str.inputTensor[0].str_h /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      // Halve the width stride of both inputs.
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
  }
  return str;
}
1084 }  // namespace parallel
1085 }  // namespace mindspore
1086