1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/rec_core/rec_cost.h"
18
19 #include <algorithm>
20 #include <limits>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "include/common/utils/parallel_context.h"
26
27 namespace mindspore {
28 namespace parallel {
SameShape(const Shape4D & shape1,const Shape4D & shape2)29 bool SameShape(const Shape4D &shape1, const Shape4D &shape2) {
30 bool equal = (shape1 == shape2);
31
32 return (equal || !ONLY_REDIST_WITH_SAME_SHAPE);
33 }
34
costOfDistributing(const TensorParam & t)35 double costOfDistributing(const TensorParam &t) {
36 return (static_cast<double>(t.tensor_shape.shape_n) * t.tensor_str.str_n *
37 static_cast<double>(t.tensor_shape.shape_c) * t.tensor_str.str_c *
38 static_cast<double>(t.tensor_shape.shape_h) * t.tensor_str.str_h *
39 static_cast<double>(t.tensor_shape.shape_w) * t.tensor_str.str_w / 2.0);
40 }
41
minNodeSize(const Graph::NodeType & node)42 double minNodeSize(const Graph::NodeType &node) {
43 double distributing0 = costOfDistributing(node.apply.arguments[0]);
44 double distributing1 = costOfDistributing(node.apply.arguments[1]);
45 double distributing2 = costOfDistributing(node.tensor_parm);
46 double min_distribution = std::min(distributing0, distributing1);
47 min_distribution = std::min(min_distribution, distributing2);
48 min_distribution *= EXPERT_COEF;
49 return min_distribution;
50 }
51
// Compute redistributed cost
// Sums, over every node that already has a chosen strategy, the cost of
// redistributing tensors between that node and `node` if `node` adopted the
// candidate cut described by `mode`. `mode` is a 3x4 matrix of stride
// multipliers whose rows correspond to {input0, input1, output} (row 2 is the
// output row used for backward matches — see CostRedisWithAdjacentNode).
double CostRedis(const Graph::NodeType &node,
                 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                 const std::vector<std::vector<float>> &mode, const Graph &graph) {
  // Store value of cost redist
  double cost_redis = 0;

  // Number of current strategies.
  size_t num_strategy = node_name_to_strategy.size();

  // Number of node-in and node-out
  size_t num_node_in = node.node_in.size();
  size_t num_node_out = node.node_out.size();

  // Set tensor edge value with original tensor shape and cutting times.
  double input_tensor = node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n *
                        node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c *
                        node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h *
                        node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w;

  double output_tensor = node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n *
                         node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c *
                         node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h *
                         node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w;

  // For each strategy candidate.
  for (size_t i_strategy = 0; i_strategy < num_strategy; i_strategy++) {
    // Find its forward nodes
    for (size_t i_node = 0; i_node < num_node_in; i_node++) {
      // A producer contributes redistribution cost only when its output shape
      // matches the corresponding input slot of this node (see SameShape).
      if (graph.nodes[node.node_in[i_node]].name == node_name_to_strategy[i_strategy].first &&
          SameShape(graph.nodes[node.node_in[i_node]].tensor_parm.tensor_shape,
                    node.apply.arguments[i_node].tensor_shape)) {
        bool is_search_forward = true;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, input_tensor, is_search_forward);
      }
    }

    // Find its backward nodes
    for (size_t i_node = 0; i_node < num_node_out; i_node++) {
      // A consumer may take this node's output as either of its two inputs.
      bool is_same_shape =
        SameShape(graph.nodes[node.node_out[i_node]].apply.arguments[0].tensor_shape, node.tensor_parm.tensor_shape) ||
        SameShape(graph.nodes[node.node_out[i_node]].apply.arguments[1].tensor_shape, node.tensor_parm.tensor_shape);
      if (graph.nodes[node.node_out[i_node]].name == node_name_to_strategy[i_strategy].first && is_same_shape) {
        bool is_search_forward = false;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, i_node, output_tensor, is_search_forward);
      }
    }

    // Calculate the Redis Cost of node_in_aux
    for (size_t i_node = 0; i_node < node.node_in_aux.size(); i_node++) {
      // node_in_aux_idx maps the aux producer to its argument slot on this node.
      size_t index = node.node_in_aux_idx[i_node];
      if (graph.nodes[node.node_in_aux[i_node]].name == node_name_to_strategy[i_strategy].first &&
          SameShape(graph.nodes[node.node_in_aux[i_node]].tensor_parm.tensor_shape,
                    node.apply.arguments[index].tensor_shape)) {
        bool is_search_forward = true;
        cost_redis +=
          CostRedisWithAdjacentNode(node_name_to_strategy, mode, i_strategy, index, input_tensor, is_search_forward);
      }
    }
  }

  return cost_redis;
}
117
CostRedisWithAdjacentNode(const std::vector<std::pair<std::string,StrategyRec>> & node_name_to_strategy,const std::vector<std::vector<float>> & mode,size_t i_strategy,size_t i_node,double tensor_size,bool search_forward)118 double CostRedisWithAdjacentNode(const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
119 const std::vector<std::vector<float>> &mode, size_t i_strategy, size_t i_node,
120 double tensor_size, bool search_forward) {
121 double new_redis_cost = 0;
122 bool diff = false;
123
124 auto output_tensor = node_name_to_strategy[i_strategy].second.outputTensor;
125 auto input_tensor = node_name_to_strategy[i_strategy].second.inputTensor[0];
126
127 if (search_forward) {
128 float output_dims[NDIMS] = {output_tensor.str_n, output_tensor.str_c, output_tensor.str_h, output_tensor.str_w};
129 for (size_t i = 0; i < NDIMS; ++i) {
130 if (output_dims[i] == 0 || mode[i_node][i] == 0) {
131 MS_LOG(EXCEPTION) << "divisors cannot be 0!";
132 }
133 if (static_cast<int64_t>(1 / output_dims[i]) != static_cast<int64_t>(1 / mode[i_node][i])) {
134 diff = true;
135 break;
136 }
137 }
138 } else {
139 float input_dims[NDIMS] = {input_tensor.str_n, input_tensor.str_c, input_tensor.str_h, input_tensor.str_w};
140 for (size_t i = 0; i < NDIMS; ++i) {
141 if (input_dims[i] == 0 || mode[2][i] == 0) {
142 MS_LOG(EXCEPTION) << "divisors cannot be 0!";
143 }
144 if (static_cast<int64_t>(1 / input_dims[i]) != static_cast<int64_t>(1 / mode[2][i])) {
145 diff = true;
146 break;
147 }
148 }
149 }
150
151 if (diff) {
152 new_redis_cost = tensor_size * REDIS_COEF;
153 }
154
155 return new_redis_cost;
156 }
157
hasBeenSplitted(const Graph::NodeType & node,const bool dyn_shape_tmp_fix)158 bool hasBeenSplitted(const Graph::NodeType &node, const bool dyn_shape_tmp_fix) {
159 if (dyn_shape_tmp_fix) {
160 if (node.apply.arguments[0].tensor_str.str_h < 1 || node.apply.arguments[0].tensor_str.str_w < 1 ||
161 node.apply.arguments[1].tensor_str.str_w < 1 || node.apply.arguments[0].tensor_str.str_n < 1 ||
162 node.apply.arguments[0].tensor_str.str_c < 1) {
163 return true;
164 }
165 }
166 return false;
167 }
168
// Get optimal strategy for MatMul
// Evaluates four candidate cuts — I-axis (rows of input0), J-axis (cols of
// input1), K-axis (the reduction dim), and no-cut — each priced as op-cost
// plus redistribution cost, then delegates to ChoseStr to apply the cheapest.
// NOTE(review): isTraining is unused in this overload — presumably kept for
// signature uniformity with the other cost classes; confirm before removing.
StrategyRec CostMatMul::GetOptimalStr(const Graph::NodeType &node,
                                      const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                                      const Graph &graph, const bool isTraining) {
  // Remaining (already-cut) sizes of the three MatMul dimensions.
  int64_t edge_i =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
  int64_t edge_j =
    static_cast<int64_t>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
  int64_t edge_k =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);

  // With pipeline stages, refuse to cut I once the effective micro-batch
  // (micro_batch_size scaled by the current H stride) would drop to <= 1.
  bool isMicroBatchSizeLargeEnough = true;
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    if (graph.micro_batch_size * node.apply.arguments[0].tensor_str.str_h <= 1) {
      isMicroBatchSizeLargeEnough = false;
    }
  }

  std::vector<double> cost_op;
  if (node.apply.arguments[0].tensor_str.str_h == 0) {
    MS_LOG(EXCEPTION) << "str_h cannot be 0!";
  }
  // Candidate 0: cut the I-axis. Infeasible when I is not evenly halvable.
  if (edge_i < INT64_TWO || edge_i % INT64_TWO != 0 || !isMicroBatchSizeLargeEnough) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    // mode rows are {input0, input1, output} stride multipliers for this cut.
    std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}};
    double cost_if_cut_i = StrConcatDimI(edge_j, edge_k);
    double redist_if_cut_i = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_i = cost_if_cut_i + redist_if_cut_i;
    MS_LOG(INFO) << "If the I-axis is cut, the op-cost is " << cost_if_cut_i << ", the redist-cost is "
                 << redist_if_cut_i << ", and the total cost is " << total_cost_if_cut_i;
    cost_op.push_back(total_cost_if_cut_i);
  }

  // Do not partition the J-axis and K-axis for the same MatMul
  // Candidate 1: cut the J-axis.
  if (edge_j < INT64_TWO || edge_j % INT64_TWO != 0 || node.apply.arguments[0].tensor_str.str_w < 1) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
    double cost_if_cut_j = StrConcatDimJ(edge_i, edge_k);
    double redist_if_cut_j = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_j = cost_if_cut_j + redist_if_cut_j;
    MS_LOG(INFO) << "If the J-axis is cut, the op-cost is " << cost_if_cut_j << ", the redist-cost is "
                 << redist_if_cut_j << ", and the total cost is " << total_cost_if_cut_j;
    cost_op.push_back(total_cost_if_cut_j);
  }

  // Candidate 2: cut the K-axis (reduction dim).
  if (edge_k < INT64_TWO || edge_k % INT64_TWO != 0 || node.apply.arguments[1].tensor_str.str_w < 1) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}};
    double cost_if_cut_k = StrReduceDimK(edge_i, edge_j);
    double redist_if_cut_k = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_k = cost_if_cut_k + redist_if_cut_k;
    MS_LOG(INFO) << "If the K-axis is cut, the op-cost is " << cost_if_cut_k << ", the redist-cost is "
                 << redist_if_cut_k << ", and the total cost is " << total_cost_if_cut_k;
    cost_op.push_back(total_cost_if_cut_k);
  }

  // Candidate 3: do not cut at all.
  if (hasBeenSplitted(node, graph.dyn_shape_tmp_fix)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}};
    double cost_if_no_cut =
      StrRecom(StrConcatDimI(edge_j, edge_k), StrConcatDimJ(edge_i, edge_k), StrReduceDimK(edge_i, edge_j));
    double redist_if_no_cut = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_no_cut = cost_if_no_cut + redist_if_no_cut;
    MS_LOG(INFO) << "If do NOT cut the axis, the op-cost is " << cost_if_no_cut << ", the redist-cost is "
                 << redist_if_no_cut << ", and the total cost is " << total_cost_if_no_cut;
    cost_op.push_back(total_cost_if_no_cut);
  }

  // Compare candidates by magnitude in ChoseStr.
  for (auto &cost : cost_op) {
    cost = std::abs(cost);
  }

  return ChoseStr(cost_op, node.apply.str);
}
247
248 // Get weight for MatMul
GetMaxCostIn(const OperatorRec & op)249 double CostMatMul::GetMaxCostIn(const OperatorRec &op) {
250 int64_t edge_i = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
251 int64_t edge_j = static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
252 int64_t edge_k = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
253
254 double cost_if_cut_i = StrConcatDimI(edge_j, edge_k);
255 double cost_if_cut_j = StrConcatDimJ(edge_i, edge_k);
256 double cost_if_cut_k = StrReduceDimK(edge_i, edge_j);
257 double cost_if_no_cut = StrRecom(cost_if_cut_i, cost_if_cut_j, cost_if_cut_k);
258
259 std::vector<double> cost_in;
260 cost_in.push_back(cost_if_cut_i);
261 cost_in.push_back(cost_if_cut_j);
262 cost_in.push_back(cost_if_cut_k);
263 cost_in.push_back(cost_if_no_cut);
264
265 return *max_element(cost_in.begin(), cost_in.end());
266 }
267
// Chose strategy for MatMul
// Applies the cheapest candidate (index 0: cut I, 1: cut J, 2: cut K,
// 3: no cut) to `str` by halving the corresponding strides. Returns `str`
// unchanged when every candidate is infeasible (cost at DOUBLE_MAX).
StrategyRec CostMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
  MS_LOG(INFO) << "The costs of cutting the I-axis/J-axis/K-axis/no_cut are : " << cost_op;
  uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
  // All candidates at DOUBLE_MAX means nothing can be cut.
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      // Cut I: halve the H stride of input0 and of the output.
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_i_;
      MS_LOG(INFO) << "The I-axis is chosen to cut";
      break;

    case 1:
      // Cut J: halve the W stride of input1 and of the output.
      str.inputTensor[1].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_j_;
      MS_LOG(INFO) << "The J-axis is chosen to cut";
      break;

    case 2:
      // Cut K (reduction dim): halve input0's W and input1's H strides;
      // the output strides are unchanged.
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_k_;
      MS_LOG(INFO) << "The K-axis is chosen to cut";
      break;

    case 3:
      MS_LOG(INFO) << "Choose NOT to cut";
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure:CostMatMul failed.";
  }

  return str;
}
311
getBatchDimsSize(const OperatorRec & op)312 size_t CostBatchMatMul::getBatchDimsSize(const OperatorRec &op) {
313 return static_cast<double>(std::max(op.arguments[0].tensor_shape.shape_n, op.arguments[1].tensor_shape.shape_n)) *
314 std::max(op.arguments[0].tensor_str.str_n, op.arguments[1].tensor_str.str_n) *
315 static_cast<double>(std::max(op.arguments[0].tensor_shape.shape_c, op.arguments[1].tensor_shape.shape_c)) *
316 std::max(op.arguments[0].tensor_str.str_c, op.arguments[1].tensor_str.str_c);
317 }
318
// Per-axis cost of cutting a BatchMatMul node.
// B/X cuts are priced by how far the batch dims are from saturating the
// cores; I/J/K cuts by the distributing cost of the tensor that stays whole;
// R (no cut) by a replication penalty.
double CostBatchMatMul::cost(Axis a, const Graph::NodeType &node) {
  double mc_ratio;
  size_t batch_dims_size = getBatchDimsSize(node.apply);
  if (batch_dims_size == 1) {
    // No batch parallelism yet: maximal incentive to cut a batch dim.
    mc_ratio = static_cast<double>(NUMBER_ASCEND_CORES);
  } else {
    // Incentive shrinks to 0 as batch dims approach the core count.
    mc_ratio = std::max(NUMBER_ASCEND_CORES / static_cast<double>(batch_dims_size) - 1, 0.0);
  }
  double min_size = minNodeSize(node);

  switch (a) {
    // Calculate the cost if the Batch-axis of BatchMatMul is cut
    case B:
      return (mc_ratio * min_size);

    // Calculate the cost if the Expert-axis of BatchMatMul is cut
    // (the "- 1" makes X marginally cheaper than B at equal mc_ratio,
    // so the Expert-axis wins ties)
    case X:
      return (mc_ratio * min_size) - 1;

    // Calculate the cost if the I-axis of BatchMatMul is cut
    case I:
      return costOfDistributing(node.apply.arguments[1]);

    // Calculate the cost if the J-axis of BatchMatMul is cut
    case J:
      return costOfDistributing(node.apply.arguments[0]);

    // Calculate the cost if the K-axis of BatchMatMul is cut
    case K:
      return costOfDistributing(node.tensor_parm);

    // Calculate the cost if BatchMatMul is not cut
    case R:
      return min_size * min_size / REPLICATE_BELOW;

    default:
      MS_LOG(EXCEPTION) << "Axis " << a << " is not taken into account";
  }

  // Unreachable when MS_LOG(EXCEPTION) throws; kept so every control path
  // returns a value.
  return 1;
}
360
SplitOnlyOneDimension(const Graph & graph,float str)361 bool SplitOnlyOneDimension(const Graph &graph, float str) {
362 if (graph.dyn_shape_tmp_fix && str < 1) {
363 return true;
364 }
365 return false;
366 }
367
IsEdgeSplittable(const int64_t edge)368 bool IsEdgeSplittable(const int64_t edge) {
369 if (edge < INT64_TWO || edge % INT64_TWO != 0) {
370 return false;
371 }
372 return true;
373 }
374
// Get optimal strategy for BatchMatMul
// Evaluates six candidate cuts — Batch (N), Expert (C), I, J, K, and no-cut —
// each priced as cost(axis) plus redistribution cost, then delegates to
// ChoseStr to apply the cheapest. J/K totals are discounted by BMM_COEF.
StrategyRec CostBatchMatMul::GetOptimalStr(
  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
  const Graph &graph, const bool isTraining) {
  // Remaining (already-cut) sizes of the five BatchMatMul dimensions.
  int64_t edge_b =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_n * node.apply.arguments[0].tensor_str.str_n);
  int64_t edge_x =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_c * node.apply.arguments[0].tensor_str.str_c);
  int64_t edge_i =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_h * node.apply.arguments[0].tensor_str.str_h);
  int64_t edge_j =
    static_cast<int64_t>(node.apply.arguments[1].tensor_shape.shape_w * node.apply.arguments[1].tensor_str.str_w);
  int64_t edge_k =
    static_cast<int64_t>(node.apply.arguments[0].tensor_shape.shape_w * node.apply.arguments[0].tensor_str.str_w);

  // With pipeline stages, refuse to cut the Batch-axis once the effective
  // micro-batch (micro_batch_size scaled by the current N stride) is <= 1.
  bool isMicroBatchSizeLargeEnough = true;
  if (parallel::ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) {
    if (graph.micro_batch_size * node.apply.arguments[0].tensor_str.str_n <= 1) {
      isMicroBatchSizeLargeEnough = false;
    }
  }

  std::vector<double> cost_op;
  if (node.apply.arguments[0].tensor_str.str_n == 0) {
    MS_LOG(EXCEPTION) << "str_n cannot be 0!";
  }
  // Candidate 0: cut the Batch-axis.
  if (!IsEdgeSplittable(edge_b) || !isMicroBatchSizeLargeEnough) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    // mode rows are {input0, input1, output} stride multipliers for this cut.
    std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
    double cost_if_cut_b = cost(B, node);
    double redist_if_cut_b = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_b = cost_if_cut_b + redist_if_cut_b;
    MS_LOG(INFO) << "If the Batch-axis is cut, the op-cost is " << cost_if_cut_b << ", the redist-cost is "
                 << redist_if_cut_b << ", and the total cost is " << total_cost_if_cut_b;
    cost_op.push_back(total_cost_if_cut_b);
  }

  // Candidate 1: cut the Expert-axis (MoE expert dim, stored in C).
  if (!IsEdgeSplittable(edge_x)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
    double cost_if_cut_x = cost(X, node);
    double redist_if_cut_x = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_x = cost_if_cut_x + redist_if_cut_x;
    MS_LOG(INFO) << "If the Expert-axis is cut, the op-cost is " << cost_if_cut_x << ", the redist-cost is "
                 << redist_if_cut_x << ", and the total cost is " << total_cost_if_cut_x;
    cost_op.push_back(total_cost_if_cut_x);
  }

  // Candidate 2: cut the I-axis.
  if (!IsEdgeSplittable(edge_i) || SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 1, 1}, {1, 1, 0.5, 1}};
    double cost_if_cut_i = cost(I, node);
    double redist_if_cut_i = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_i = cost_if_cut_i + redist_if_cut_i;
    MS_LOG(INFO) << "If the I-axis is cut, the op-cost is " << cost_if_cut_i << ", the redist-cost is "
                 << redist_if_cut_i << ", and the total cost is " << total_cost_if_cut_i;
    cost_op.push_back(total_cost_if_cut_i);
  }

  // Candidate 3: cut the J-axis.
  if (!IsEdgeSplittable(edge_j) || node.apply.arguments[0].tensor_str.str_w < 1 ||
      SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
    double cost_if_cut_j = cost(J, node);
    double redist_if_cut_j = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_j = cost_if_cut_j + redist_if_cut_j;
    MS_LOG(INFO) << "If the J-axis is cut, the op-cost is " << cost_if_cut_j << ", the redist-cost is "
                 << redist_if_cut_j << ", and the total cost is " << total_cost_if_cut_j;
    // J cuts are favored by the BMM_COEF discount.
    cost_op.push_back(total_cost_if_cut_j / BMM_COEF);
  }

  // Candidate 4: cut the K-axis (reduction dim).
  if (!IsEdgeSplittable(edge_k) || node.apply.arguments[1].tensor_str.str_w < 1 ||
      SplitOnlyOneDimension(graph, node.apply.arguments[0].tensor_str.str_c)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 0.5, 1}, {1, 1, 1, 1}};
    double cost_if_cut_k = cost(K, node);
    double redist_if_cut_k = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_cut_k = cost_if_cut_k + redist_if_cut_k;
    MS_LOG(INFO) << "If the K-axis is cut, the op-cost is " << cost_if_cut_k << ", the redist-cost is "
                 << redist_if_cut_k << ", and the total cost is " << total_cost_if_cut_k;
    // K cuts are favored by the BMM_COEF discount.
    cost_op.push_back(total_cost_if_cut_k / BMM_COEF);
  }

  // Candidate 5: do not cut at all.
  if (hasBeenSplitted(node, graph.dyn_shape_tmp_fix)) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}};
    double cost_if_no_cut = cost(R, node);
    double redist_if_no_cut = CostRedis(node, node_name_to_strategy, mode, graph);
    double total_cost_if_no_cut = cost_if_no_cut + redist_if_no_cut;
    MS_LOG(INFO) << "If do NOT cut the axis, the op-cost is " << cost_if_no_cut << ", the redist-cost is "
                 << redist_if_no_cut << ", and the total cost is " << total_cost_if_no_cut;
    cost_op.push_back(total_cost_if_no_cut);
  }

  // Compare candidates by magnitude in ChoseStr (cost(X) can be negative).
  for (auto &cost : cost_op) {
    cost = std::abs(cost);
  }

  return ChoseStr(cost_op, node.apply.str);
}
481
482 // Get weight for BatchMatMul
GetMaxCostIn(const Graph::NodeType & node)483 double CostBatchMatMul::GetMaxCostIn(const Graph::NodeType &node) {
484 std::vector<double> cost_in;
485 cost_in.push_back(cost(B, node));
486 cost_in.push_back(cost(X, node));
487 cost_in.push_back(cost(I, node));
488 cost_in.push_back(cost(J, node));
489 cost_in.push_back(cost(K, node));
490 cost_in.push_back(cost(R, node));
491
492 return *max_element(cost_in.begin(), cost_in.end());
493 }
494
495 // Chose strategy for BatchMatMul
ChoseStr(const std::vector<double> & cost_op,StrategyRec str) const496 StrategyRec CostBatchMatMul::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
497 MS_LOG(INFO) << "The costs of cutting the Batch-axis/Expert-axis/I-axis/J-axis/K-axis/no_cut are : " << cost_op;
498 uint64_t min_position = min_element(cost_op.begin(), cost_op.end()) - cost_op.begin();
499 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
500 return str;
501 }
502
503 str.cut_counter += 1;
504 str.cost = str.cost + cost_op[min_position];
505
506 switch (min_position) {
507 case 0:
508 str.inputTensor[0].str_n /= 2.0;
509 str.inputTensor[1].str_n /= 2.0;
510 str.outputTensor.str_n /= 2.0;
511 MS_LOG(INFO) << "The Batch-axis is chosen to cut";
512 break;
513
514 case 1:
515 str.inputTensor[0].str_c /= 2.0;
516 str.inputTensor[1].str_c /= 2.0;
517 str.outputTensor.str_c /= 2.0;
518 MS_LOG(INFO) << "The Expert-axis is chosen to cut";
519 break;
520
521 case 2:
522 str.inputTensor[0].str_h /= 2.0;
523 str.outputTensor.str_h /= 2.0;
524 MS_LOG(INFO) << "The I-axis is chosen to cut";
525 break;
526
527 case 3:
528 str.inputTensor[1].str_w /= 2.0;
529 str.outputTensor.str_w /= 2.0;
530 MS_LOG(INFO) << "The J-axis is chosen to cut";
531 break;
532
533 case 4:
534 str.inputTensor[0].str_w /= 2.0;
535 str.inputTensor[1].str_h /= 2.0;
536 MS_LOG(INFO) << "The K-axis is chosen to cut";
537 break;
538
539 case 5:
540 MS_LOG(INFO) << "Choose NOT to cut";
541 break;
542
543 default:
544 MS_LOG(EXCEPTION) << "Failure:CostBatchMatMul failed.";
545 }
546
547 return str;
548 }
549
// Get optimal strategy for Conv
// Fills cost_op with seven slots matching ChoseStr's cases
// (0: batch B, 1: output H, 2: output W, 3: output channels K,
//  4: filter H, 5: filter W, 6: input channels Q). Slots 1, 2, 4 and 5 are
// hard-disabled with DOUBLE_MAX; channel cuts (3 and 6) additionally require
// channel_partition to be enabled.
StrategyRec CostConvolution::GetOptimalStr(
  const Graph::NodeType &node, const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
  const Graph &graph, bool channel_partition) {
  const OperatorRec &op = node.apply;

  // Stride-scaled sizes of the input tensor's four dimensions.
  int64_t input_tensor_h =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
  int64_t input_tensor_w =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
  int64_t input_tensor_n =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
  int64_t input_tensor_c =
    static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);

  int64_t tensor_in = input_tensor_h * input_tensor_w * input_tensor_n * input_tensor_c;

  // Stride-scaled sizes of the filter tensor's four dimensions.
  int64_t tensor_filter_h =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h);
  int64_t tensor_filter_w =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w);
  int64_t tensor_filter_n =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n);
  int64_t tensor_filter_c =
    static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);

  int64_t tensor_filter = tensor_filter_h * tensor_filter_w * tensor_filter_n * tensor_filter_c;

  // Stride-scaled sizes of the output tensor's four dimensions.
  int64_t output_tensor_h =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h);
  int64_t output_tensor_w =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w);
  int64_t output_tensor_n =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
  int64_t output_tensor_c =
    static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);

  int64_t tensor_out = output_tensor_h * output_tensor_w * output_tensor_n * output_tensor_c;

  std::vector<double> cost_op;
  cost_op.reserve(7);

  // Slot 0: cut the batch dimension (needs an even batch >= 2).
  if (input_tensor_n < 2 || input_tensor_n % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    // mode rows are {input0, input1, output} stride multipliers for this cut.
    std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {1, 1, 1, 1}, {0.5, 1, 1, 1}};
    cost_op.push_back(StrDimB(tensor_filter) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  // Slots 1-2 (output H/W cuts) are disabled.
  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  // Slot 3: cut the output channels (filter count).
  if (channel_partition == false || tensor_filter < 2 || tensor_filter % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 1, 1, 1}, {0.5, 1, 1, 1}, {1, 0.5, 1, 1}};
    cost_op.push_back(StrDimK(tensor_in) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  // Slots 4-5 (filter H/W cuts) are disabled.
  cost_op.push_back(DOUBLE_MAX);
  cost_op.push_back(DOUBLE_MAX);

  // Slot 6: cut the input channels.
  if (channel_partition == false || tensor_filter_c < 2 || tensor_filter_c % 2 != 0) {
    cost_op.push_back(DOUBLE_MAX);
  } else {
    std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 1, 1, 1}};
    cost_op.push_back(StrDimQ(tensor_out) + CostRedis(node, node_name_to_strategy, mode, graph));
  }

  return ChoseStr(cost_op, node.apply.str);
}
621
622 // Get weight for Conv
GetMinCostIn(const Graph::NodeType & node)623 double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
624 const OperatorRec &op = node.apply;
625
626 int64_t tensor_in = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
627 static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
628 static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
629 static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
630 int64_t tensor_filter =
631 static_cast<int64_t>(op.arguments[1].tensor_shape.shape_h * op.arguments[1].tensor_str.str_h) *
632 static_cast<int64_t>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
633 static_cast<int64_t>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
634 static_cast<int64_t>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
635 int64_t tensor_out = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
636 static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
637 static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
638 static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
639
640 std::vector<double> cost_in;
641 cost_in.push_back(StrDimB(tensor_filter));
642 cost_in.push_back(StrDimI(tensor_in, tensor_filter));
643 cost_in.push_back(StrDimJ(tensor_in, tensor_filter));
644 cost_in.push_back(StrDimK(tensor_in));
645 cost_in.push_back(StrDimDI(tensor_in, tensor_out));
646 cost_in.push_back(StrDimDJ(tensor_in, tensor_out));
647 cost_in.push_back(StrDimQ(tensor_out));
648
649 return *min_element(cost_in.begin(), cost_in.end());
650 }
651
// Chose strategy for Conv
// Applies the cheapest candidate (index 0: batch, 1: output H, 2: output W,
// 3: output channels, 4: filter H, 5: filter W, 6: input channels) to `str`
// by halving the corresponding strides. Returns `str` unchanged when every
// candidate is infeasible (cost at DOUBLE_MAX).
StrategyRec CostConvolution::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
  uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      // Cut batch: halve input N and output N strides.
      str.inputTensor[0].str_n /= 2.0;
      str.outputTensor.str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_b_;
      break;

    case 1:
      // Cut output H: halve input H and output H strides.
      str.inputTensor[0].str_h /= 2.0;
      str.outputTensor.str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_i_;
      break;

    case 2:
      // Cut output W: halve input W and output W strides.
      str.inputTensor[0].str_w /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_j_;
      break;

    case 3:
      // Cut output channels: halve filter-count (input1 N) and output C.
      str.inputTensor[1].str_n /= 2.0;
      str.outputTensor.str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_k_;
      break;

    case 4:
      // Cut filter H only.
      str.inputTensor[1].str_h /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_di_;
      break;

    case 5:
      // Cut filter W only.
      str.inputTensor[1].str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_dj_;
      break;

    case 6:
      // Cut input channels: halve C stride of both inputs.
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_q_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostConvolution failed.";
  }
  return str;
}
712
713 // Get optimal strategy for Pooling
GetOptimalStr(const Graph::NodeType & node,const std::vector<std::pair<std::string,StrategyRec>> & node_name_to_strategy,const Graph & graph) const714 StrategyRec CostPooling::GetOptimalStr(const Graph::NodeType &node,
715 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
716 const Graph &graph) const {
717 int64_t tensor_n = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n);
718 int64_t tensor_c = static_cast<int64_t>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
719
720 std::vector<double> cost_op;
721
722 if (tensor_n < 2 || tensor_n % 2 != 0) {
723 cost_op.push_back(DOUBLE_MAX);
724 } else {
725 std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
726 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
727 }
728
729 if (tensor_c < 2 || tensor_c % 2 != 0) {
730 cost_op.push_back(DOUBLE_MAX);
731 } else {
732 std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
733 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
734 }
735
736 cost_op.push_back(DOUBLE_MAX);
737 cost_op.push_back(DOUBLE_MAX);
738
739 return ChoseStr(cost_op, node.apply.str);
740 }
741
742 // Chose strategy for Pooling
ChoseStr(const std::vector<double> & cost_op,StrategyRec str) const743 StrategyRec CostPooling::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) const {
744 uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
745 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
746 return str;
747 }
748
749 switch (min_position) {
750 case 0:
751 str.inputTensor[0].str_n /= 2.0;
752 str.outputTensor.str_n /= 2.0;
753 str.cut_counter += 1;
754 str.cost = str.cost + cost_in_;
755 break;
756
757 case 1:
758 str.inputTensor[0].str_c /= 2.0;
759 str.outputTensor.str_c /= 2.0;
760 str.cut_counter += 1;
761 str.cost = str.cost + cost_in_;
762 break;
763
764 case 2:
765 str.inputTensor[0].str_h /= 2.0;
766 str.outputTensor.str_h /= 2.0;
767 str.cut_counter += 1;
768 str.cost = str.cost + cost_in_;
769 break;
770
771 case 3:
772 str.inputTensor[0].str_w /= 2.0;
773 str.outputTensor.str_w /= 2.0;
774 str.cut_counter += 1;
775 str.cost = str.cost + cost_in_;
776 break;
777
778 default:
779 MS_LOG(EXCEPTION) << "Failure: CostPooling failed.";
780 }
781 return str;
782 }
783
784 // Chose strategy for Add
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)785 StrategyRec CostTensorAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
786 uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
787 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
788 return str;
789 }
790
791 switch (min_position) {
792 case 0:
793 str.inputTensor[0].str_n /= 2.0;
794 str.inputTensor[1].str_n /= 2.0;
795 str.outputTensor.str_n /= 2.0;
796 str.cut_counter += 1;
797 str.cost = str.cost + cost_in_;
798 break;
799
800 case 1:
801 str.inputTensor[0].str_c /= 2.0;
802 str.inputTensor[1].str_c /= 2.0;
803 str.outputTensor.str_c /= 2.0;
804 str.cut_counter += 1;
805 str.cost = str.cost + cost_in_;
806 break;
807
808 case 2:
809 str.inputTensor[0].str_h /= 2.0;
810 str.inputTensor[1].str_h /= 2.0;
811 str.outputTensor.str_h /= 2.0;
812 str.cut_counter += 1;
813 str.cost = str.cost + cost_in_;
814 break;
815
816 case 3:
817 str.inputTensor[0].str_w /= 2.0;
818 str.inputTensor[1].str_w /= 2.0;
819 str.outputTensor.str_w /= 2.0;
820 str.cut_counter += 1;
821 str.cost = str.cost + cost_in_;
822 break;
823
824 default:
825 MS_LOG(EXCEPTION) << "Failure: CostAdd failed.";
826 }
827 return str;
828 }
829
830 // Get optimal strategy for Reshape
GetOptimalStr(const Graph::NodeType & node) const831 StrategyRec CostReshape::GetOptimalStr(const Graph::NodeType &node) const { return ChoseStr(node.apply.str); }
832
ChoseStr(StrategyRec str) const833 StrategyRec CostReshape::ChoseStr(StrategyRec str) const { return str; }
834
835 // Chose strategy for BiasAdd
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)836 StrategyRec CostBiasAdd::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
837 uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
838 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
839 return str;
840 }
841
842 switch (min_position) {
843 case 0:
844 str.inputTensor[0].str_n /= 2.0;
845 str.outputTensor.str_n /= 2.0;
846 str.cut_counter += 1;
847 str.cost = str.cost + cost_in_;
848 break;
849
850 case 1:
851 str.inputTensor[0].str_c /= 2.0;
852 str.outputTensor.str_c /= 2.0;
853 str.cut_counter += 1;
854 str.cost = str.cost + cost_in_;
855 break;
856
857 case 2:
858 str.inputTensor[0].str_h /= 2.0;
859 str.outputTensor.str_h /= 2.0;
860 str.cut_counter += 1;
861 str.cost = str.cost + cost_in_;
862 break;
863
864 case 3:
865 str.inputTensor[0].str_w /= 2.0;
866 str.inputTensor[1].str_w /= 2.0;
867 str.outputTensor.str_w /= 2.0;
868 str.cut_counter += 1;
869 str.cost = str.cost + cost_in_;
870 break;
871
872 default:
873 MS_LOG(EXCEPTION) << "Failure: CostBiasAdd failed.";
874 }
875 return str;
876 }
877
878 // Get optimal strategy for Common OPs
GetOptimalStr(const Graph::NodeType & node,const std::vector<std::pair<std::string,StrategyRec>> & node_name_to_strategy,const Graph & graph)879 StrategyRec CostCommon::GetOptimalStr(const Graph::NodeType &node,
880 const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
881 const Graph &graph) {
882 const OperatorRec &op = node.apply;
883 int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
884 int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
885 int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
886 int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
887
888 std::vector<double> cost_op;
889
890 if (tensor_n < 2 || tensor_n % 2 != 0) {
891 cost_op.push_back(DOUBLE_MAX);
892 } else {
893 std::vector<std::vector<float>> mode = {{0.5, 1, 1, 1}, {0.5, 1, 1, 1}, {0.5, 1, 1, 1}};
894 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
895 }
896
897 if (tensor_c < 2 || tensor_c % 2 != 0) {
898 cost_op.push_back(DOUBLE_MAX);
899 } else {
900 std::vector<std::vector<float>> mode = {{1, 0.5, 1, 1}, {1, 0.5, 1, 1}, {1, 0.5, 1, 1}};
901 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
902 }
903
904 if (tensor_h < 2 || tensor_h % 2 != 0) {
905 cost_op.push_back(DOUBLE_MAX);
906 } else {
907 std::vector<std::vector<float>> mode = {{1, 1, 0.5, 1}, {1, 1, 0.5, 1}, {1, 1, 0.5, 1}};
908 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
909 }
910
911 if (tensor_w < 2 || tensor_w % 2 != 0) {
912 cost_op.push_back(DOUBLE_MAX);
913 } else {
914 std::vector<std::vector<float>> mode = {{1, 1, 1, 0.5}, {1, 1, 1, 0.5}, {1, 1, 1, 0.5}};
915 cost_op.push_back(cost_in_ + CostRedis(node, node_name_to_strategy, mode, graph));
916 }
917
918 return ChoseStr(cost_op, node.apply.str);
919 }
920
921 // Chose strategy for Common op
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)922 StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
923 uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
924 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
925 return str;
926 }
927
928 switch (min_position) {
929 case 0:
930 str.inputTensor[0].str_n /= 2.0;
931 str.outputTensor.str_n /= 2.0;
932 str.cut_counter += 1;
933 str.cost = str.cost + cost_in_;
934 break;
935
936 case 1:
937 str.inputTensor[0].str_c /= 2.0;
938 str.outputTensor.str_c /= 2.0;
939 str.cut_counter += 1;
940 str.cost = str.cost + cost_in_;
941 break;
942
943 case 2:
944 str.inputTensor[0].str_h /= 2.0;
945 str.outputTensor.str_h /= 2.0;
946 str.cut_counter += 1;
947 str.cost = str.cost + cost_in_;
948 break;
949
950 case 3:
951 str.inputTensor[0].str_w /= 2.0;
952 str.outputTensor.str_w /= 2.0;
953 str.cut_counter += 1;
954 str.cost = str.cost + cost_in_;
955 break;
956
957 default:
958 MS_LOG(EXCEPTION) << "Failure: Common failed.";
959 }
960 return str;
961 }
962
963 // Get optimal strategy for BatchParallel OPs
GetOptimalStr(const Graph::NodeType & node)964 StrategyRec CostBatchParallel::GetOptimalStr(const Graph::NodeType &node) {
965 const OperatorRec &op = node.apply;
966 int64_t tensor_n = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n);
967 int64_t tensor_c = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
968 int64_t tensor_h = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h);
969 int64_t tensor_w = static_cast<int64_t>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w);
970
971 std::vector<double> cost_op;
972
973 if (tensor_n < 2 || tensor_n % 2 != 0) {
974 cost_op.push_back(DOUBLE_MAX);
975 } else {
976 cost_op.push_back(cost_in_);
977 }
978
979 if (tensor_c < 2 || tensor_c % 2 != 0) {
980 cost_op.push_back(DOUBLE_MAX);
981 } else {
982 cost_op.push_back(cost_in_);
983 }
984
985 if (tensor_h < 2 || tensor_h % 2 != 0) {
986 cost_op.push_back(DOUBLE_MAX);
987 } else {
988 cost_op.push_back(cost_in_);
989 }
990
991 if (tensor_w < 2 || tensor_w % 2 != 0) {
992 cost_op.push_back(DOUBLE_MAX);
993 } else {
994 cost_op.push_back(cost_in_);
995 }
996
997 return ChoseStr(cost_op, node.apply.str);
998 }
999
1000 // Chose strategy for BatchParallel op
ChoseStr(const std::vector<double> & cost_op,StrategyRec str)1001 StrategyRec CostBatchParallel::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
1002 uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
1003 if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
1004 return str;
1005 }
1006
1007 switch (min_position) {
1008 case 0:
1009 str.inputTensor[0].str_n /= 2.0;
1010 str.outputTensor.str_n /= 2.0;
1011 str.cut_counter += 1;
1012 str.cost = str.cost + cost_in_;
1013 break;
1014
1015 case 1:
1016 str.inputTensor[0].str_c /= 2.0;
1017 str.outputTensor.str_c /= 2.0;
1018 str.cut_counter += 1;
1019 str.cost = str.cost + cost_in_;
1020 break;
1021
1022 case 2:
1023 str.inputTensor[0].str_h /= 2.0;
1024 str.outputTensor.str_h /= 2.0;
1025 str.cut_counter += 1;
1026 str.cost = str.cost + cost_in_;
1027 break;
1028
1029 case 3:
1030 str.inputTensor[0].str_w /= 2.0;
1031 str.outputTensor.str_w /= 2.0;
1032 str.cut_counter += 1;
1033 str.cost = str.cost + cost_in_;
1034 break;
1035
1036 default:
1037 MS_LOG(EXCEPTION) << "Failure: CostBatchParallel failed.";
1038 }
1039 return str;
1040 }
1041
// Chose strategy for CostSoftmaxCrossEntropyWithLogits.
// Picks the cheapest dimension cut from cost_op and halves that dimension's
// stride on both inputs (logits and labels); only the H cut also adjusts the
// output tensor.
StrategyRec CostSoftmaxCrossEntropyWithLogits::ChoseStr(const std::vector<double> &cost_op, StrategyRec str) {
  // Index of the cheapest candidate cut (one entry per dimension N/C/H/W).
  uint64_t min_position = LongToUlong(min_element(cost_op.begin(), cost_op.end()) - cost_op.begin());
  // Every candidate at DOUBLE_MAX means no dimension can be cut; return unchanged.
  if (cost_op[min_position] > (DOUBLE_MAX - 0.1)) {
    return str;
  }

  switch (min_position) {
    case 0:
      // Cut N on both inputs; the output strategy is left untouched.
      str.inputTensor[0].str_n /= 2.0;
      str.inputTensor[1].str_n /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 1:
      // Cut C on both inputs; the output strategy is left untouched.
      str.inputTensor[0].str_c /= 2.0;
      str.inputTensor[1].str_c /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 2:
      // Cut H on both inputs, and W on the output.
      // NOTE(review): this is the only case that touches the output tensor,
      // and it maps the inputs' H axis to the output's W axis — presumably the
      // loss output lays its batch dimension on W; confirm against the
      // operator's tensor layout.
      str.inputTensor[0].str_h /= 2.0;
      str.inputTensor[1].str_h /= 2.0;
      str.outputTensor.str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    case 3:
      // Cut W on both inputs; the output strategy is left untouched.
      str.inputTensor[0].str_w /= 2.0;
      str.inputTensor[1].str_w /= 2.0;
      str.cut_counter += 1;
      str.cost = str.cost + cost_in_;
      break;

    default:
      MS_LOG(EXCEPTION) << "Failure: CostSoftmax failed.";
  }
  return str;
}
1084 } // namespace parallel
1085 } // namespace mindspore
1086