1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/operator_info.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "ir/dtype.h"
27 #include "ir/tensor.h"
28 #include "ir/value.h"
29 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
30 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
31 #include "frontend/parallel/context.h"
32 #include "utils/log_adapter.h"
33 
34 namespace mindspore {
35 namespace parallel {
36 std::string StrategyToString(const Strategys &strategy) {
37   std::string strategy_str = "";
38   strategy_str += "(";
39   for (size_t i = 0; i < strategy.size(); ++i) {
40     strategy_str += "(";
41     for (size_t j = 0; j < strategy[i].size(); ++j) {
42       strategy_str += std::to_string(strategy[i][j]);
43       if (j != strategy[i].size() - 1) {
44         strategy_str += ", ";
45       }
46     }
47     strategy_str += ")";
48     if (i != strategy.size() - 1) {
49       strategy_str += ", ";
50     }
51   }
52   if (strategy.size() == 1) {
53     strategy_str += ",";
54   }
55   strategy_str += ")";
56   return strategy_str;
57 }
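// For illustration: a two-input strategy {{4, 1}, {4, 1}} is rendered as "((4, 1), (4, 1))", while a
// single-input strategy {{8, 1}} is rendered as "((8, 1),)" (the trailing comma marks a one-element tuple).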
58 
59 Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
60   if (strategy == nullptr) {
61     MS_LOG(ERROR) << name_ << ": The strategy is null.";
62     return FAILED;
63   }
64 
65   size_t strategy_size = strategy->GetInputNumber();
66   size_t inputs_shape_size = inputs_shape.size();
67   Strategys stra = strategy->GetInputDim();
68   if (strategy_size != inputs_shape_size) {
69     if (is_auto_parallel_) {
70       MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
71                     << " is not equal to inputs size: " << inputs_shape_size;
72     } else {
73       MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
74                     << " is not equal to inputs size: " << inputs_shape_size;
75     }
76     return FAILED;
77   }
78 
79   for (size_t i = 0; i < strategy_size; ++i) {
80     Shape sub_strategy = stra.at(i);
81     Shape sub_input_shape = inputs_shape.at(i);
82     size_t strategy_len = sub_strategy.size();
83     size_t inputs_len = sub_input_shape.size();
84     if (strategy_len != inputs_len) {
85       if (is_auto_parallel_) {
86         MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
87                       << " is not equal to inputs len: " << inputs_len << ", index: " << i;
88       } else {
89         MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
90                       << " is not equal to inputs len: " << inputs_len << ", index: " << i;
91       }
92       return FAILED;
93     }
94 
95     for (size_t j = 0; j < strategy_len; ++j) {
96       int64_t strategy_value = sub_strategy.at(j);
97       if (strategy_value < MIN_SLICE_NUM) {
98         if (is_auto_parallel_) {
99           MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
100                         << ", the value of strategy must be larger than 0, but get " << strategy_value;
101         } else {
102           MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
103                         << ", the value of strategy must be larger than 0, but get " << strategy_value;
104         }
105         return FAILED;
106       }
107 
108       if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
109         if (is_auto_parallel_) {
110           MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
111                         << ", the value of strategy must be the power of 2, but get " << strategy_value;
112         } else {
113           MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
114                         << ", the value of strategy must be the power of 2, but get " << strategy_value;
115         }
116         return FAILED;
117       }
118 
119       int64_t shape_value = sub_input_shape.at(j);
120       if ((shape_value % strategy_value) != 0) {
121         if (is_auto_parallel_) {
122           MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
123                         << " cannot be divisible by strategy value " << strategy_value;
124         } else {
125           MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
126                         << " cannot be divisible by strategy value " << strategy_value;
127         }
128         return FAILED;
129       }
130     }
131   }
132 
133   return SUCCESS;
134 }
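// For illustration: with inputs_shape {{128, 64}}, the strategy {{4, 2}} passes all three checks (positive,
// power of 2, divides the shape); {{3, 2}} fails the power-of-2 check, and {{256, 1}} fails because
// 128 is not divisible by 256.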
135 
136 void OperatorInfo::ResetQueueMember() {
137   inputs_tensor_info_.clear();
138   outputs_tensor_info_.clear();
139   inputs_tensor_map_.clear();
140   outputs_tensor_map_.clear();
141   dev_matrix_shape_.clear();
142   forward_op_.clear();
143   mirror_ops_.clear();
144   sub_ops_.clear();
145   replace_op_.clear();
146   replace_op_info_.clear();
147   virtual_div_op_.clear();
148 }
149 
150 Status OperatorInfo::InferAttrs() {
151   if (infer_attrs_completed_) {
152     return SUCCESS;
153   }
154 
155   if (GetAttrs() != SUCCESS) {
156     return FAILED;
157   }
158   infer_attrs_completed_ = true;
159   return SUCCESS;
160 }
161 
162 Status OperatorInfo::InferMirrorOps() {
163   mirror_ops_.clear();
164   if (inputs_shape_.empty()) {
165     MS_LOG(INFO) << name_ << ": The inputs size is empty";
166     return SUCCESS;
167   }
168 
169   if (inputs_tensor_map_.size() != inputs_shape_.size()) {
170     MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
171     return FAILED;
172   }
173 
174   bool group_is_empty = true;
175   for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
176     std::vector<Group> group;
177     if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
178       MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
179       mirror_ops_.clear();
180       return FAILED;
181     }
182 
183     OperatorVector mirror_op;
184     if (group.empty()) {
185       MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
186       mirror_ops_.push_back(mirror_op);
187       continue;
188     }
189 
190     group_is_empty = false;
191     mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
192     mirror_ops_.push_back(mirror_op);
193   }
194 
195   if (group_is_empty) {
196     mirror_ops_.clear();
197     MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
198   }
199   return SUCCESS;
200 }
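// Note: CreateGroupByTensorMap returns the devices that hold identical slices of an input, so a mirror
// (gradient synchronization) operator is inserted only for inputs that are replicated across devices; if every
// input is fully sharded, mirror_ops_ stays empty.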
201 
202 Status OperatorInfo::InferTensorInfo() {
203   if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
204     MS_LOG(ERROR) << name_ << ": Invalid args";
205     return FAILED;
206   }
207 
208   for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
209     TensorLayout input_layout;
210     if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
211       MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
212       return FAILED;
213     }
214     TensorInfo input_tensor_info(input_layout);
215     inputs_tensor_info_.push_back(input_tensor_info);
216   }
217 
218   for (size_t i = 0; i < outputs_tensor_map_.size(); ++i) {
219     TensorLayout output_layout;
220     if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
221       MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed, the index is " << i;
222       return FAILED;
223     }
224     TensorInfo output_tensor_info(output_layout);
225     outputs_tensor_info_.push_back(output_tensor_info);
226   }
227 
228   return SUCCESS;
229 }
230 
231 Status OperatorInfo::InferRepeatedCalcInfo() {
232   int64_t g_dev_list_size = stage_device_size_;
233   int64_t dev_matrix_size =
234     std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
235   if (dev_matrix_size == 0) {
236     MS_LOG(ERROR) << name_ << ": The dev matrix size is 0";
237     return FAILED;
238   }
239 
240   if (g_dev_list_size == dev_matrix_size) {
241     repeated_calc_num_ = 1;
242   } else if (g_dev_list_size % dev_matrix_size == 0) {
243     repeated_calc_num_ = ((int64_t)(g_dev_list_size / dev_matrix_size));
244   } else {
245     MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategy_->GetInputDim()) << ", it requires "
246                   << dev_matrix_size << " devices, "
247                   << "but the device number of this stage is " << g_dev_list_size << ", it can not be divisible by "
248                   << dev_matrix_size;
249     return FAILED;
250   }
251   return SUCCESS;
252 }
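// For illustration: with 8 devices in the stage and a device matrix of (2, 2) (4 devices used),
// repeated_calc_num_ = 8 / 4 = 2; with 6 devices and the same matrix, 6 % 4 != 0 and an error is reported.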
253 
254 // If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
255 // because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
256 // to the last dimension of dev-matrix, there is no need to redistribution.
257 void OperatorInfo::SetRepeatedCalcDevMatrix() {
258   if (repeated_calc_num_ <= 1) {
259     return;
260   }
261   if (repeated_num_in_dev_matrix_right_) {
262     dev_matrix_shape_.push_back(repeated_calc_num_);
263   } else {
264     (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
265   }
266 }
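// For illustration: with dev_matrix_shape_ = (4, 2) and repeated_calc_num_ = 2, the matrix becomes (4, 2, 2)
// when repeated_num_in_dev_matrix_right_ is true, and (2, 4, 2) otherwise.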
267 
268 // If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
269 // the index value of tensor map needs to be increased by 1.
270 void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
271   if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
272     return;
273   }
274 
275   MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
276   for (auto &tensor_map : inputs_tensor_map_) {
277     for (auto &element : tensor_map) {
278       if (element == MAP_NONE) {
279         continue;
280       }
281       element += 1;
282     }
283   }
284 
285   for (auto &tensor_map : outputs_tensor_map_) {
286     for (auto &element : tensor_map) {
287       if (element == MAP_NONE) {
288         continue;
289       }
290       element += 1;
291     }
292   }
293 }
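// For illustration: appending the repeated dimension on the right adds a new lowest device-matrix axis, so a
// tensor map (1, 0) is shifted to (2, 1); MAP_NONE (-1) entries stay unchanged.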
294 
295 // use for loss repeated calculation
296 Operator CreateVirtualDivOp(int64_t div_num) {
297   OperatorName operator_name = VIRTUAL_DIV;
298   ValuePtr attr0_value = MakeValue(div_num);
299   Attr attr0 = std::make_pair(DIVISOR, attr0_value);
300   OperatorAttrs operator_attrs;
301   operator_attrs.push_back(attr0);
302 
303   OperatorParams operator_param;
304   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
305 
306   Operator op = std::make_pair(operator_name, operator_arg);
307   return op;
308 }
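// Illustrative usage: CreateVirtualDivOp(8) builds a VirtualDiv operator whose only attribute is {divisor: 8};
// it is inserted so that gradients are divided by the repeated-calculation number of the loss.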
309 
310 static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
311   ValuePtr attr0_value = MakeValue(reduce_op);
312   ValuePtr attr1_value = MakeValue(group);
313   Attr attr0 = std::make_pair(OP, attr0_value);
314   Attr attr1 = std::make_pair(GROUP, attr1_value);
315   OperatorAttrs operator_attrs;
316   operator_attrs.push_back(attr0);
317   operator_attrs.push_back(attr1);
318 
319   OperatorParams operator_param;
320   return std::make_pair(operator_attrs, operator_param);
321 }
322 
323 // use for forward all reduce
324 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
325   OperatorName operator_name = ALL_REDUCE;
326   OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
327 
328   Operator op = std::make_pair(operator_name, operator_arg);
329   MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group;
330   return op;
331 }
332 
333 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
334   OperatorName operator_name = REDUCE_SCATTER;
335   OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
336 
337   Operator op = std::make_pair(operator_name, operator_arg);
338   MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
339   return op;
340 }
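// Illustrative usage: both factories go through CreateReduceCommunicationOpArgs, so for example
// CreateAllReduceOp("sum", group_name) yields an AllReduce operator with attributes {op: "sum", group: group_name}.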
341 
342 void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
343   MS_EXCEPTION_IF_NULL(comm_node);
344   MS_EXCEPTION_IF_NULL(param_node);
345   ParameterPtr param;
346   if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
347     param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
348   } else {
349     param = param_node->cast<ParameterPtr>();
350   }
351   MS_EXCEPTION_IF_NULL(param);
352   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
353   MS_EXCEPTION_IF_NULL(prim);
354   auto attrs = prim->attrs();
355   auto param_info = param->param_info();
356   if (!param_info) {
357     MS_LOG(WARNING) << param->ToString() << " does not have parameter info.";
358     return;
359   }
360   int32_t fusion_type = param_info->comm_fusion();
361   attrs[FUSION] = MakeValue<int64_t>(fusion_type);
362   prim->SetAttrs(attrs);
363   bool parallel_optimizer_comm_recompute = param_info->parallel_optimizer_comm_recompute();
364   std::string instance_name = prim->instance_name();
365   if (instance_name == PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE && parallel_optimizer_comm_recompute &&
366       prim->name() == ALL_GATHER) {
367     prim->set_attr(RECOMPUTE, MakeValue(true));
368     prim->set_instance_name(PARALLEL_OPTIMIZER_ALLGATHER);
369   }
370   MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
371 }
372 
373 void AddCommOpMeanFlag(const CNodePtr &comm_node) {
374   MS_EXCEPTION_IF_NULL(comm_node);
375   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
376   auto attrs = prim->attrs();
377   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
378   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
379   attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
380   prim->SetAttrs(attrs);
381 }
382 
383 void AddCommOpParamFlag(const CNodePtr &comm_node) {
384   MS_EXCEPTION_IF_NULL(comm_node);
385   auto graph = comm_node->func_graph();
386   MS_EXCEPTION_IF_NULL(graph);
387   auto manager = graph->manager();
388   MS_EXCEPTION_IF_NULL(manager);
389   auto node_users = manager->node_users()[comm_node->input(1)];
390   for (auto &node_user : node_users) {
391     if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
392       auto prim = GetCNodePrimitive(comm_node);
393       (void)prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
394       return;
395     }
396   }
397 }
398 
399 Operator CreateAllGatherOp(const std::string &group) {
400   OperatorName operator_name = ALL_GATHER;
401   ValuePtr attr0_value = MakeValue(group);  // group
402   Attr attr0 = std::make_pair(GROUP, attr0_value);
403   OperatorAttrs operator_attrs;
404   operator_attrs.push_back(attr0);
405 
406   OperatorParams operator_param;
407   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
408 
409   Operator op = std::make_pair(operator_name, operator_arg);
410   MS_LOG(INFO) << "Create allgather op success, the group is " << group;
411   return op;
412 }
413 
414 Operator CreateMiniStepAllGatherOp(const std::string &group) {
415   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
416   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
417 
418   OperatorName operator_name = MINI_STEP_ALL_GATHER;
419   ValuePtr attr0_value = MakeValue(group);  // group
420   Attr attr0 = std::make_pair(GROUP, attr0_value);
421   ValuePtr attr1_value = MakeValue(grad_accumulation_step);  // grad_accumulation_step
422   Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value);
423   ValuePtr attr2_value = MakeValue(mean_flag);  // mean_flag
424   Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
425   OperatorAttrs operator_attrs;
426   operator_attrs.push_back(attr0);
427   operator_attrs.push_back(attr1);
428   operator_attrs.push_back(attr2);
429 
430   OperatorParams operator_param;
431   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
432 
433   Operator op = std::make_pair(operator_name, operator_arg);
434   MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group;
435   return op;
436 }
437 
438 Operator CreateMicroStepAllGatherOp(const std::string &group) {
439   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
440 
441   OperatorName operator_name = MICRO_STEP_ALL_GATHER;
442   ValuePtr attr0_value = MakeValue(group);  // group
443   Attr attr0 = std::make_pair(GROUP, attr0_value);
444   ValuePtr attr1_value = MakeValue(mean_flag);  // mean_flag
445   Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value);
446   OperatorAttrs operator_attrs;
447   operator_attrs.push_back(attr0);
448   operator_attrs.push_back(attr1);
449 
450   OperatorParams operator_param;
451   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
452 
453   Operator op = std::make_pair(operator_name, operator_arg);
454   MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
455   return op;
456 }
457 
458 // use for get tensor slice
459 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
460   Shape tensor_map = tensor_layout.tensor_map().array();
461   Shape dev_matrix_shape = tensor_layout.device_arrangement().array();
462   OperatorName operator_name = GET_TENSOR_SLICE;
463 
464   OperatorAttrs attrs;
465   ValuePtr dev_mat_value = MakeValue(dev_matrix_shape);
466   Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2);
467   ValuePtr tensor_map_value = MakeValue(tensor_map);
468   Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3);
469   OperatorParams params = {dev_mat_param, tensor_map_param};
470   OperatorArgs operator_arg = std::make_pair(attrs, params);
471 
472   Operator op = std::make_pair(operator_name, operator_arg);
473   MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is "
474                << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map);
475   return op;
476 }
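// Note: the integer in each Param is the operator input position, i.e. dev_mat and tensor_map are supplied as
// the 2nd and 3rd inputs of the GetTensorSlice operator (the tensor to be sliced is expected as the 1st input).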
477 
478 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
479   if ((dev_num == 0) || (dev_num == 1)) {
480     MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num;
481   }
482   OperatorVector op_for_weight;
483   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
484   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
485   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
486 
487   ValuePtr attr0_value = MakeValue(group_name);
488   ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
489   ValuePtr attr2_value = MakeValue(mean_flag);
490 
491   Attr attr0 = std::make_pair(GROUP, attr0_value);
492   Attr attr1 = std::make_pair(DEV_NUM, attr1_value);
493   Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
494 
495   OperatorAttrs operator_attrs;
496   operator_attrs.push_back(attr0);
497   operator_attrs.push_back(attr1);
498   operator_attrs.push_back(attr2);
499 
500   OperatorName operator_name;
501   if (grad_accumulation_step > 1) {
502     operator_name = MIRROR_MINI_STEP_OPERATOR;
503     ValuePtr attr3_value = MakeValue(grad_accumulation_step);
504     Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
505     operator_attrs.push_back(attr3);
506     MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
507   } else if (split_stage_num > 1) {
508     operator_name = MIRROR_MICRO_STEP_OPERATOR;
509   } else {
510     operator_name = MIRROR_OPERATOR;
511   }
512 
513   OperatorParams operator_param;
514   OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
515 
516   Operator op = std::make_pair(operator_name, operator_args);
517 
518   op_for_weight.push_back(op);
519   MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is "
520                << mean_flag;
521   return op_for_weight;
522 }
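// For illustration: with grad_accumulation_step <= 1 and a single pipeline stage, this returns one plain mirror
// operator whose attributes are {group, dev_num, mean_flag}; the mini-step / micro-step variants are selected
// only when gradient accumulation or pipeline parallelism is enabled.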
523 
524 Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group) {
525   if (group == nullptr) {
526     MS_LOG(ERROR) << "The group is null.";
527     return FAILED;
528   }
529   CheckGlobalDeviceManager();
530   int64_t rank = g_device_manager->global_rank();
531   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
532   RankList group_devices;
533   if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
534     return FAILED;
535   }
536 
537   if (group_devices.size() == 1) {
538     MS_LOG(INFO) << "The dev size is 1, no need to create group.";
539     return SUCCESS;
540   }
541 
542   Group g = g_device_manager->CreateGroup(group_devices);
543   group->push_back(g);
544   return SUCCESS;
545 }
546 
547 Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
548   if (groups == nullptr) {
549     MS_LOG(ERROR) << "The group is null. Operator is " << name_;
550     return FAILED;
551   }
552   CheckGlobalDeviceManager();
553   int64_t rank = g_device_manager->global_rank();
554   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
555   RankList group_devices;
556   Shape tensor_map = tensor_layout->origin_tensor_map().array();
557   if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
558     return FAILED;
559   }
560 
561   if (group_devices.size() == 1) {
562     MS_LOG(INFO) << "The dev size is 1, no need to create group.";
563     return SUCCESS;
564   }
565   int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
566   if (optimizer_weight_shard_size != -1) {
567     // not fully use opt shard
568     int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
569     int64_t repeated_size = SizeToLong(group_devices.size());
570     if (repeated_size % optimizer_weight_shard_size != 0) {
571       MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
572                       << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
573       return FAILED;
574     }
575     repeated_size = repeated_size / optimizer_weight_shard_size;
576     // create allgather group
577     // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
578     RankList new_group_devices(
579       group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
580       group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
581     Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
582     groups->push_back(allgather_group);
583     tensor_layout->set_opt_shard_group(allgather_group.name());
584     MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
585     // create mirror group
586     // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
587     int64_t device_num = g_device_manager->stage_device_num();
588     Shape dev_mat = {repeated_size, device_num / repeated_size};
589     DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
590     RankList mirror_group_devices;
591     if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
592       return FAILED;
593     }
594     Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
595     groups->push_back(mirror_group);
596     tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
597     MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
598   } else {
599     // fully use opt shard
600     // create allgather group
601     Group allgather_group = g_device_manager->CreateGroup(group_devices);
602     groups->push_back(allgather_group);
603     tensor_layout->set_opt_shard_group(allgather_group.name());
604     MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
605   }
606   // save in tensor_layout for strategy ckpt
607   auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
608   if (!integrated_save) {
609     tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
610     int64_t opt_weight_shard_step =
611       (group_devices.back() - group_devices.front()) / (SizeToLong(group_devices.size()) - 1);
612     tensor_layout->set_opt_weight_shard_step(LongToInt(opt_weight_shard_step));
613     MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
614   }
615   return SUCCESS;
616 }
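// For illustration: with group_devices = [0, 8, 16, 24] and optimizer_weight_shard_size = 2, rank 16 joins the
// allgather group [16, 24] and a mirror group derived from the (2, device_num / 2) device matrix; when aggregated
// saving is disabled, the recorded opt_weight_shard_step is (24 - 0) / (4 - 1) = 8.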
617 
618 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
619   if (group == nullptr) {
620     MS_LOG(ERROR) << "The group is null.";
621     return FAILED;
622   }
623   CheckGlobalDeviceManager();
624   int64_t rank = g_device_manager->global_rank();
625   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
626   RankList group_devices;
627   if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
628     return FAILED;
629   }
630 
631   if (group_devices.size() == 1) {
632     MS_LOG(INFO) << "The dev size is 1, no need to create group.";
633     return SUCCESS;
634   }
635 
636   Group g = g_device_manager->CreateGroup(group_devices);
637   group->push_back(g);
638   return SUCCESS;
639 }
640 
641 Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
642   Shape slice_shape;
643   if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) {
644     MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0";
645     return slice_shape;
646   }
647   for (size_t i = 0; i < strategy.size(); ++i) {
648     slice_shape.push_back(tensor_shape.at(i) / strategy.at(i));
649   }
650   return slice_shape;
651 }
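// For illustration: GetSliceShape({128, 64}, {4, 2}) returns {32, 32}; any non-positive strategy element makes
// the function log an error and return an empty shape.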
652 
653 Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) {
654   if (slice_shapes == nullptr) {
655     MS_LOG(ERROR) << "The slice_shapes is null.";
656     return FAILED;
657   }
658   if (strategys.size() != shapes.size()) {
659     MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size();
660     return FAILED;
661   }
662 
663   for (size_t i = 0; i < strategys.size(); ++i) {
664     if (strategys.at(i).size() != shapes.at(i).size()) {
665       MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension "
666                     << shapes.at(i).size();
667       slice_shapes->clear();
668       return FAILED;
669     }
670 
671     for (size_t j = 0; j < shapes.at(i).size(); ++j) {
672       if (strategys.at(i).at(j) <= 0) {
673         MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i])
674                       << " the element is less than or equal to 0.";
675         slice_shapes->clear();
676         return FAILED;
677       }
678       if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) {
679         MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : "
680                       << strategys.at(i).at(j);
681         slice_shapes->clear();
682         return FAILED;
683       }
684     }
685     Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i));
686     slice_shapes->push_back(slice_shape);
687   }
688 
689   return SUCCESS;
690 }
691 
692 Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
693                                      Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) {
694   if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) {
695     MS_LOG(ERROR) << "The slice_shape is null.";
696     return FAILED;
697   }
698 
699   if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) {
700     MS_LOG(ERROR) << "Infer inputs slice shape error.";
701     return FAILED;
702   }
703 
704   if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) {
705     MS_LOG(ERROR) << "Infer outputs slice shape error.";
706     inputs_slice_shape->clear();
707     return FAILED;
708   }
709 
710   return SUCCESS;
711 }
712 
713 // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1
714 Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) {
715   if (strategy == nullptr) {
716     MS_LOG(ERROR) << name_ << ": The strategy is null.";
717     return FAILED;
718   }
719 
720   if (InferAttrs() != SUCCESS) {
721     MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
722     return FAILED;
723   }
724 
725   // must be after InferAttrs()
726   if (CheckStrategy(strategy) != SUCCESS) {
727     if (is_auto_parallel_) {
728       MS_LOG(DEBUG) << name_ << ": CheckStrategy failed.";
729     } else {
730       MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
731     }
732     return FAILED;
733   }
734 
735   // need to clear queues before Init(),
736   // because Init() may be called multiple times by cost model
737   ResetQueueMember();
738 
739   strategy_ = strategy;
740 
741   if (InferDevMatrixShape() != SUCCESS) {
742     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
743     return FAILED;
744   }
745 
746   used_devices_ =
747     ((int64_t)(std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>())));
748 
749   // must be after InferDevMatrixShape
750   if (InferRepeatedCalcInfo() != SUCCESS) {
751     MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
752     return FAILED;
753   }
754 
755   // if repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix for layout
756   SetRepeatedCalcDevMatrix();
757 
758   if (InferTensorMap() != SUCCESS) {
759     MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
760     return FAILED;
761   }
762 
763   ResetTensorMapIfRepeatedCalc();
764 
765   if (InferTensorInfo() != SUCCESS) {
766     MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
767     return FAILED;
768   }
769 
770   return SUCCESS;
771 }
772 
773 // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape
774 Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) {
775   if (strategy == nullptr) {
776     MS_LOG(ERROR) << name_ << ": The strategy is null.";
777     return FAILED;
778   }
779 
780   if (InferAttrs() != SUCCESS) {
781     MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
782     return FAILED;
783   }
784 
785   // must be after InferAttrs()
786   if (CheckStrategy(strategy) != SUCCESS) {
787     MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
788     return FAILED;
789   }
790 
791   // need to clear queues before Init(),
792   // because Init() may be called multiple times by cost model
793   ResetQueueMember();
794 
795   strategy_ = strategy;
796 
797   if (InferDevMatrixShape() != SUCCESS) {
798     MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
799     return FAILED;
800   }
801 
802   // must be after InferDevMatrixShape
803   if (InferRepeatedCalcInfo() != SUCCESS) {
804     MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
805     return FAILED;
806   }
807 
808   if (InferTensorMap() != SUCCESS) {
809     MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
810     return FAILED;
811   }
812 
813   if (InferTensorInfo() != SUCCESS) {
814     MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
815     return FAILED;
816   }
817 
818   return SUCCESS;
819 }
820 
821 Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) {
822   if (strategy == nullptr) {
823     MS_LOG(ERROR) << name_ << ": The strategy is null.";
824     return FAILED;
825   }
826 
827   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
828     return FAILED;
829   }
830 
831   if (InferForwardCommunication() != SUCCESS) {
832     MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
833     return FAILED;
834   }
835 
836   if (InferMirrorOps() != SUCCESS) {
837     MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
838     return FAILED;
839   }
840 
841   if (InferVirtualDivOps() != SUCCESS) {
842     MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
843     return FAILED;
844   }
845 
846   return SUCCESS;
847 }
848 
849 Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) {
850   if (strategy == nullptr) {
851     MS_LOG(ERROR) << name_ << ": The strategy is null.";
852     return FAILED;
853   }
854 
855   if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) {
856     return FAILED;
857   }
858 
859   if (InferForwardCommunication() != SUCCESS) {
860     MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
861     return FAILED;
862   }
863 
864   if (InferMirrorOps() != SUCCESS) {
865     MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
866     return FAILED;
867   }
868 
869   if (InferVirtualDivOps() != SUCCESS) {
870     MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
871     return FAILED;
872   }
873 
874   return SUCCESS;
875 }
876 
877 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
878   std::vector<std::shared_ptr<Edge>> ret;
879   for (auto &edge : succ_edges_) {
880     if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
881       ret.push_back(edge);
882     } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
883       // CAST is ordered in front of L2NORMALIZE
884       ret.push_back(edge);
885     }
886   }
887   for (auto &edge : succ_edges_) {
888     if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
889         (edge->next_operator()->name().find(CAST) == std::string::npos)) {
890       ret.push_back(edge);
891     }
892   }
893   return ret;
894 }
895 
896 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() {
897   std::vector<std::shared_ptr<Edge>> ret;
898   for (auto &edge : prev_edges_) {
899     if (edge->prev_operator()->is_alive()) {
900       ret.push_back(edge);
901     }
902   }
903   return ret;
904 }
905 
906 void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
907   if (op == nullptr) {
908     MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null.";
909     return;
910   }
911   for (auto &edge : prev_edges_) {
912     if (edge->prev_operator() == op) {
913       edge = replace_edge;
914       return;
915     }
916   }
917   MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
918 }
919 
920 void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
921   if (op == nullptr) {
922     MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null.";
923     return;
924   }
925   for (auto &edge : succ_edges_) {
926     if (edge->next_operator() == op) {
927       edge = replace_edge;
928       return;
929     }
930   }
931   MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
932 }
933 
934 void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
935   if (op == nullptr) {
936     MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null.";
937     return;
938   }
939   std::vector<std::shared_ptr<Edge>> update_pre_edges;
940   for (auto &edge : prev_edges_) {
941     if (edge->prev_operator() != op) {
942       update_pre_edges.push_back(edge);
943     }
944   }
945   update_pre_edges.push_back(replace_edge);
946   prev_edges_ = update_pre_edges;
947 }
948 
949 void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op,
950                                     const std::shared_ptr<Edge> &replace_edge) {
951   if (op == nullptr) {
952     MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null";
953     return;
954   }
955   std::vector<std::shared_ptr<Edge>> update_pre_edges;
956   for (auto &edge : succ_edges_) {
957     if (edge->next_operator() != op) {
958       update_pre_edges.push_back(edge);
959     }
960   }
961   update_pre_edges.push_back(replace_edge);
962   succ_edges_ = update_pre_edges;
963 }
964 
965 std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
966                                                               const std::vector<bool> &split_flag_list) {
967   if (shapes.size() != split_flag_list.size()) {
968     MS_LOG(ERROR) << "Split_flag_list does not have the same size as inputs shape, " << split_flag_list.size() << " : "
969                   << shapes.size();
970     return nullptr;
971   }
972   CheckGlobalDeviceManager();
973   int64_t dev_num = g_device_manager->stage_device_num();
974   Strategys strategy_v;
975   for (size_t i = 0; i != shapes.size(); i++) {
976     if (shapes[i].empty()) {
977       MS_LOG(INFO) << "Elements of shapes is empty.";
978       Dimensions empty_element;
979       strategy_v.push_back(empty_element);
980     } else {
981       Dimensions element(shapes[i].size(), 1);
982       if (split_flag_list[i]) {
983         element[0] = dev_num;
984       }
985       strategy_v.push_back(element);
986     }
987   }
988   return std::make_shared<Strategys>(strategy_v);
989 }
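// For illustration: with 8 devices in the stage, shapes {{32, 10}, {10}} and split_flag_list {true, false}
// produce the data-parallel strategy {{8, 1}, {1}}: only the batch dimension of the first input is split.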
990 
991 void OperatorInfo::ReComputeBatchSplitFlagList() {
992   if (!inputs_shape_.empty()) {
993     split_flag_list_[0] = true;
994   }
995 }
996 
997 void OperatorInfo::ComputeBatchSplitFlagList() {
998   split_flag_list_.clear();
999   for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) {
1000     split_flag_list_.push_back(false);
1001   }
1002   ReComputeBatchSplitFlagList();
1003 }
1004 
1005 // This is a common method for checking whether the generated strategy has the correct number of devices.
1006 Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
1007   if (sp == nullptr) {
1008     MS_LOG(ERROR) << "The strategy is null.";
1009     return FAILED;
1010   }
1011   int64_t product = 1;
1012 
1013   for (auto &input_partition : inputs_partitions) {
1014     product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
1015   }
1016   const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
1017   if (!fully_use_device) {
1018     if (LongToSize(product) > dev_num) {
1019       return FAILED;
1020     }
1021   } else {
1022     if ((product != 1) && (LongToSize(product) != dev_num)) {
1023       return FAILED;
1024     }
1025   }
1026   Strategys stras(inputs_partitions);
1027   (*sp) = std::make_shared<Strategy>(stage_id, stras);
1028   return SUCCESS;
1029 }
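// For illustration: with dev_num = 8 and inputs_partitions {{4, 1}, {2, 1}} the combined product is 8, so a
// Strategy is created; under fully_use_device a product of 4 (neither 1 nor dev_num) would be rejected.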
1030 
1031 std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
1032   if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
1033     MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
1034   }
1035   ComputeBatchSplitFlagList();
1036   return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
1037 }
1038 
1039 void PrintStrategy(const StrategyPtr &strategy) {
1040   if (strategy == nullptr) {
1041     return;
1042   }
1043   std::string all_strategy = "";
1044   for (size_t i = 0; i < strategy->GetInputNumber(); ++i) {
1045     all_strategy += "[";
1046     for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) {
1047       all_strategy += std::to_string(strategy->GetInputDim()[i][j]);
1048       if (j != strategy->GetInputDim()[i].size() - 1) {
1049         all_strategy += ", ";
1050       }
1051     }
1052     all_strategy += "]";
1053     if (i != strategy->GetInputNumber() - 1) {
1054       all_strategy += ", ";
1055     }
1056   }
1057   MS_LOG(INFO) << "The strategy is: " << all_strategy;
1058 }
1059 
1060 // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d])
1061 Status GenerateStrategiesForTwoEqualInputs(int64_t stage_id, const Shapes &inputs_shape,
1062                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1063   if (sp_vector == nullptr) {
1064     MS_LOG(ERROR) << "The sp_vector is null.";
1065     return FAILED;
1066   }
1067 
1068   if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1069     MS_LOG(ERROR) << "The inputs size is wrong.";
1070     return FAILED;
1071   }
1072 
1073   if ((inputs_shape[0].size() != inputs_shape[1].size()) ||
1074       (splittable_inputs[0].size() != splittable_inputs[1].size())) {
1075     MS_LOG(ERROR) << "The size of two inputs are not equal.";
1076     return FAILED;
1077   }
1078 
1079   Shapes input0_shape = {inputs_shape[0]};
1080   Shapes input0_splittable = {splittable_inputs[0]};
1081   if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) {
1082     return FAILED;
1083   }
1084 
1085   for (auto &sp : *sp_vector) {
1086     sp->ExpandInputDimFromOneToTwo();
1087   }
1088 
1089   return SUCCESS;
1090 }
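// For illustration: candidate strategies are generated for input0 alone and then duplicated, so a candidate
// {{2, 4}} becomes {{2, 4}, {2, 4}} after ExpandInputDimFromOneToTwo().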
1091 
1092 // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast
1093 // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1094 Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1095                                           std::vector<StrategyPtr> *const sp_vector) {
1096   if (sp_vector == nullptr) {
1097     MS_LOG(ERROR) << "The sp_vector is null.";
1098     return FAILED;
1099   }
1100 
1101   if (inputs_shape[0].size() >= inputs_shape[1].size()) {
1102     MS_LOG(ERROR) << "Invalid inputs shape.";
1103     return FAILED;
1104   }
1105 
1106   // first, generate strategy for input0 the same as input1
1107   Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]};
1108   Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]};
1109   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1110     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1111     return FAILED;
1112   }
1113 
1114   // second, get the correct strategy for input0
1115   for (auto &sp : *sp_vector) {
1116     Strategys tmp_strategy;
1117     Dimensions input0_strategy = sp->GetInputDim()[0];
1118     size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
1119 
1120     // erase the unnecessary part
1121     (void)input0_strategy.erase(input0_strategy.begin(),
1122                                 input0_strategy.begin() + static_cast<different_type>(size_diff));
1123 
1124     // handle the case likes ([1, c, d], [a, b, c, d])
1125     for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1126       if (inputs_shape[0][i] == 1) {
1127         input0_strategy[i] = 1;
1128       } else {
1129         break;
1130       }
1131     }
1132 
1133     // reset the strategy
1134     tmp_strategy.push_back(input0_strategy);       // input0
1135     tmp_strategy.push_back(sp->GetInputDim()[1]);  // input1
1136     sp->ResetInputs(tmp_strategy);
1137   }
1138   return SUCCESS;
1139 }
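// For illustration: for shapes ([1, c, d], [a, b, c, d]) a candidate (2, 2, 2, 1) for input1 is first cut to
// (2, 2, 1) for input0, and its leading element is then reset to 1 because input0's first dimension is 1,
// giving the strategy pair ((1, 2, 1), (2, 2, 2, 1)).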
1140 
1141 // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast
1142 // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1143 Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &inputs_shape,
1144                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1145   if (sp_vector == nullptr) {
1146     MS_LOG(ERROR) << "The sp_vector is null.";
1147     return FAILED;
1148   }
1149 
1150   if (inputs_shape[0].size() <= inputs_shape[1].size()) {
1151     MS_LOG(ERROR) << "Invalid inputs shape.";
1152     return FAILED;
1153   }
1154 
1155   // first, generate strategy for input1 the same as input0
1156   Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]};
1157   Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]};
1158   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1159     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1160     return FAILED;
1161   }
1162 
1163   // second, get the correct strategy for input1
1164   for (auto &sp : *sp_vector) {
1165     Strategys tmp_strategy;
1166     tmp_strategy.push_back(sp->GetInputDim()[0]);  // input0
1167 
1168     Dimensions input1_strategy = sp->GetInputDim()[1];
1169     size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size();
1170 
1171     // erase the unnecessary part
1172     (void)input1_strategy.erase(input1_strategy.begin(),
1173                                 input1_strategy.begin() + static_cast<different_type>(size_diff));
1174 
1175     // handle the case likes ([a, b, c, d], [1, c, d])
1176     for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
1177       if (inputs_shape[1][i] == 1) {
1178         input1_strategy[i] = 1;
1179       } else {
1180         break;
1181       }
1182     }
1183 
1184     // reset the strategy
1185     tmp_strategy.push_back(input1_strategy);  // input1
1186     sp->ResetInputs(tmp_strategy);
1187   }
1188   return SUCCESS;
1189 }
1190 
1191 // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast
1192 // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
1193 Status GenerateStrategiesForBroadcastBoth(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1194                                           std::vector<StrategyPtr> *const sp_vector) {
1195   if (sp_vector == nullptr) {
1196     MS_LOG(ERROR) << "The sp_vector is null.";
1197     return FAILED;
1198   }
1199 
1200   if (inputs_shape[0].size() != inputs_shape[1].size()) {
1201     MS_LOG(ERROR) << "Invalid inputs shape.";
1202     return FAILED;
1203   }
1204 
1205   // step1: ([a, 1], [1, b]) -> [a, b]
1206   Shape max_shape, splittable_vector;
1207   for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1208     if (inputs_shape[0][i] >= inputs_shape[1][i]) {
1209       max_shape.push_back(inputs_shape[0][i]);
1210       splittable_vector.push_back(splittable_inputs[0][i]);
1211     } else {
1212       max_shape.push_back(inputs_shape[1][i]);
1213       splittable_vector.push_back(splittable_inputs[1][i]);
1214     }
1215   }
1216 
1217   // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b])
1218   Shapes tmp_inputs_shape = {max_shape, max_shape};
1219   Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector};
1220   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1221     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1222     return FAILED;
1223   }
1224 
1225   // step3: reset the strategy if the dimension is 1
1226   for (auto &sp : *sp_vector) {
1227     Dimensions input0_strategy = sp->GetInputDim()[0];
1228     Dimensions input1_strategy = sp->GetInputDim()[1];
1229     for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1230       if (inputs_shape[0][i] == 1) {
1231         input0_strategy[i] = 1;
1232       }
1233 
1234       if (inputs_shape[1][i] == 1) {
1235         input1_strategy[i] = 1;
1236       }
1237     }
1238     sp->ResetInputs({input0_strategy, input1_strategy});
1239   }
1240 
1241   return SUCCESS;
1242 }
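// For illustration: for shapes ([a, 1], [1, b]) candidates are generated for the merged shape [a, b], and the
// dimensions whose original extent is 1 are then reset, e.g. (2, 4) becomes the pair ((2, 1), (1, 4)).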
1243 
1244 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
1245 // the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding
1246 // dimension is splittable. 'inputs_partitions' is the result of partitions.
1247 // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring
1248 // specific dimensions in inputs have the identical partition should have individual implementation.
1249 Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
1250                                               const Shapes &splittable_inputs,
1251                                               std::vector<StrategyPtr> *const sp_vector) {
1252   if (sp_vector == nullptr) {
1253     MS_LOG(ERROR) << "The sp_vector is null.";
1254     return FAILED;
1255   }
1256   if (splittable_inputs.size() != inputs_shape.size()) {
1257     MS_LOG(ERROR) << "Splittable_inputs does not have the same size as inputs shape, " << splittable_inputs.size()
1258                   << " : " << inputs_shape.size();
1259     return FAILED;
1260   }
1261   CheckGlobalDeviceManager();
1262   size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
1263 
1264   Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions;
1265   for (size_t j = 0; j < inputs_shape.size(); ++j) {
1266     (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end());
1267     (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(),
1268                                             splittable_inputs[j].end());
1269   }
1270   std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape,
1271                                                      &combined_splittable_inputs, &combined_partitions, &recursive,
1272                                                      &inputs_shape](uint64_t current_index, size_t n) {
1273     if (current_index == combined_inputs_shape.size()) {
1274       MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size();
1275       Shapes inputs_partitions;
1276       size_t global_index = 0;
1277       for (auto &shape : inputs_shape) {
1278         Shape tmp_partition;
1279         for (size_t j = 0; j < shape.size(); ++j) {
1280           tmp_partition.push_back(combined_partitions[global_index]);
1281           global_index++;
1282         }
1283         inputs_partitions.push_back(tmp_partition);
1284       }
1285       StrategyPtr sp;
1286       if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) {
1287         sp_vector->push_back(sp);
1288       }
1289       return;
1290     } else {
1291       MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size();
1292       if (combined_splittable_inputs[current_index] == 0) {
1293         combined_partitions.push_back(MIN_SLICE_NUM);
1294         recursive(current_index + 1, n / MIN_SLICE_NUM);
1295         combined_partitions.pop_back();
1296       } else if (combined_splittable_inputs[current_index] == 1) {
1297         for (uint64_t i = 1; i <= n; i *= 2) {
1298           if (n % i == 0 && LongToSize(combined_inputs_shape[current_index]) % i == 0) {
1299             combined_partitions.push_back(i);
1300             recursive(current_index + 1, n / i);
1301             combined_partitions.pop_back();
1302           }
1303         }
1304       }
1305     }
1306   };
1307   recursive(0, dev_num);
1308   if (sp_vector->empty()) {
1309     MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo.";
1310   }
1311   return SUCCESS;
1312 }
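// For illustration: with 4 devices and one fully splittable input of shape (4, 4), the recursion enumerates
// power-of-2 partitions whose product does not exceed 4, e.g. (1, 1), (1, 2), (2, 2), (4, 1); PrepareStrategyBase
// then keeps only those allowed by the fully_use_device setting.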
1313 
1314 // generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
1315 // and the corresponding dimensions that are not broadcast are all relevant dimensions
1316 // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1317 // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1318 // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
1319 Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1320                                        std::vector<StrategyPtr> *const sp_vector) {
1321   if (sp_vector == nullptr) {
1322     MS_LOG(ERROR) << "The sp_vector is null.";
1323     return FAILED;
1324   }
1325 
1326   if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1327     MS_LOG(ERROR) << "The inputs' size is wrong: inputs_shape size is " << inputs_shape.size() << ", splittable_inputs size is " << splittable_inputs.size() << ", both must be 2.";
1328     return FAILED;
1329   }
1330 
1331   if (inputs_shape[0] == inputs_shape[1]) {
1332     // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy
1333     if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1334       MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1335       return FAILED;
1336     }
1337     MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success.";
1338   } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) {
1339     // ([a, b, c, d], []) or ([], [a, b, c, d])
1340     if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1341       MS_LOG(ERROR) << "Generate strategies for scalar case failed.";
1342       return FAILED;
1343     }
1344     MS_LOG(INFO) << "Generate strategies for scalar case success.";
1345   } else if (inputs_shape[0].size() > inputs_shape[1].size()) {
1346     // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1347     if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1348       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed.";
1349       return FAILED;
1350     }
1351     MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success.";
1352   } else if (inputs_shape[0].size() < inputs_shape[1].size()) {
1353     // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1354     if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1355       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed.";
1356       return FAILED;
1357     }
1358     MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success.";
1359   } else {  // same size, but different value
1360     // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
1361     if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1362       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed.";
1363       return FAILED;
1364     }
1365     MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success.";
1366   }
1367   return SUCCESS;
1368 }
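// Dispatch sketch for the broadcast cases above (hypothetical shapes):
//   ({8, 16}, {8, 16}) -> equal shapes, handled by GenerateStrategiesForTwoEqualInputs;
//   ({8, 16}, {})      -> a scalar input, handled by GenerateStrategiesForIndependentInputs;
//   ({8, 16}, {16})    -> input0 has more dimensions, GenerateStrategiesForBroadcastRight;
//   ({16}, {8, 16})    -> input1 has more dimensions, GenerateStrategiesForBroadcastLeft;
//   ({8, 1}, {1, 16})  -> same rank but different dimensions, GenerateStrategiesForBroadcastBoth.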
1369 
1370 Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
1371   if (InitForCostModel(strategy) == FAILED) {
1372     if (is_auto_parallel_) {
1373       MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed.";
1374     } else {
1375       MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed.";
1376     }
1377     return FAILED;
1378   }
1379   int64_t stage_id = strategy->GetInputStage();
1380   double computation_cost =
1381     operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
1382   double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
1383   const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1384   std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
1385   result->communication_without_parameter_ =
1386     operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
1387   result->communication_with_partial_para_ =
1388     result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
1389 
1390   // Breaking ties for preferring data parallelization
1391   BreakingTiesForPerferringDataParallel(strategy, result);
1392   // refine the communication cost calculation for practical use
1393   RefineForPracticalCost(result, false);
1394   result->communication_forward_ = result->communication_without_parameter_;
1395 
1396   std::shared_ptr<StrategyWithCost> swc =
1397     std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
1398   swc->cost_list.push_back(result);
1399   strategy_cost_.emplace_back(swc);
1400 
1401   return SUCCESS;
1402 }
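// Worked example of the partial-parameter communication term above (hypothetical numbers):
// with communication_cost = 10.0, communication_without_parameter_ = 4.0 and gamma = 0.5,
// communication_with_partial_para_ = 4.0 + 0.5 * (10.0 - 4.0) = 7.0, before the tie-breaking and
// RefineForPracticalCost adjustments are applied.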
1403 
1404 // Keep at most (1.0 / epsilon) available strategies for each operator.
1405 void OperatorInfo::ApproximateStrategies() {
1406   auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
1407   if (!enable_approxi) {
1408     return;
1409   }
1410   MS_LOG(INFO) << "Approximating strategy-cost for: " << name_;
1411   auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
1412   auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
1413   if (strategy_cost_.size() <= target_num) {
1414     MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
1415                  << ", no greater than target-num: " << target_num;
1416     return;
1417   }
1418   std::vector<std::shared_ptr<StrategyWithCost>> ret;
1419   auto &origin_stra_cost = strategy_cost_;
1420   auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
1421   auto beta = CostModelContext::GetInstance()->costmodel_beta();
1422   // sort
1423   std::sort(
1424     origin_stra_cost.begin(), origin_stra_cost.end(),
1425     [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
1426       if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
1427           alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
1428         return true;
1429       }
1430       return false;
1431     });
1432   size_t step_length = origin_stra_cost.size() / target_num;
1433   for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
1434     ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
1435   }
1436 
1437   strategy_cost_ = ret;
1438   is_strategy_cost_exact_ = false;
1439 }
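// Pruning sketch (hypothetical numbers): epsilon = 0.1 gives target_num = ceil(1.0 / 0.1) = 10;
// with 25 generated strategies, step_length = 25 / 10 = 2, so after sorting by
// alpha * computation_cost_ + beta * communication_with_partial_para_ the entries at indices
// 0, 2, 4, ..., 18 are kept and is_strategy_cost_exact_ is cleared.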
1440 
1441 void OperatorInfo::ExactStrategiesAndRelatedEdges() {
1442   if (is_strategy_cost_exact()) {
1443     return;
1444   }
1445   ClearStrategyCost();
1446   if (GenerateStrategies(0) != SUCCESS) {
1447     MS_LOG(EXCEPTION) << "Strategy search for Operator " << name() << " failed.";
1448     return;
1449   }
1450   SetIsStrategyCostExactTrue();
1451   // re-init the previous edges
1452   for (auto &prev_edge : prev_edges()) {
1453     if (prev_edge->InitEdgeCost() != SUCCESS) {
1454       MS_LOG(EXCEPTION) << "Edge: " << prev_edge->edge_name() << " cost init failed.";
1455     }
1456   }
1457   // re-init the successive edges
1458   for (auto &next_edge : succ_edges()) {
1459     if (next_edge->InitEdgeCost() != SUCCESS) {
1460       MS_LOG(EXCEPTION) << "Edge: " << next_edge->edge_name() << " cost init failed.";
1461     }
1462   }
1463 }
1464 
1465 int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
1466   if (is_output_parameter_involve_ != -1) {
1467     return is_output_parameter_involve_;
1468   }
1469   is_parameter_involve_ = is_parameter_;
1470   const auto &prev_edges = this->GetAlivePrevEdges();
1471   for (auto &p_edge : prev_edges) {
1472     auto input_index = p_edge->next_op_input_index();
1473     auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
1474     if (input_index >= is_parameter_involve_.size()) {
1475       MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
1476                         << ", but got wrong input_index: " << input_index;
1477     }
1478     if (prev_op_para == 0) {
1479       is_parameter_involve_[input_index] = false;
1480     } else if (prev_op_para == 1) {
1481       is_parameter_involve_[input_index] = true;
1482     } else {
1483       MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
1484     }
1485     p_edge->set_parameter_involve(prev_op_para);
1486   }
1487   if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
1488     // If any of the inputs is parameter-involved, the output is parameter-involved.
1489     is_output_parameter_involve_ = 1;
1490   } else {
1491     is_output_parameter_involve_ = 0;
1492   }
1493   // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
1494   // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
1495   operator_cost()->set_is_parameter_involve(is_parameter_involve_);
1496   operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
1497   // Calculating 'output_in_memory'
1498   operator_cost()->CalculateOutputInMemory();
1499   // Calculating 'inputs_in_memory'
1500   std::map<size_t, bool> input_in_memory;
1501   for (auto &p_edge : prev_edges) {
1502     auto input_index = p_edge->next_op_input_index();
1503     auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
1504     input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
1505   }
1506   operator_cost()->CalculateInputsInMemory(input_in_memory);
1507 
1508   return is_output_parameter_involve_;
1509 }
1510 
1511 Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
1512   if (is_parameter.size() != inputs_shape_.size()) {
1513     MS_LOG(ERROR) << "The size of is_parameter: " << is_parameter.size()
1514                   << " is not equal to the size of inputs_shape_: " << inputs_shape_.size();
1515     return FAILED;
1516   }
1517   is_parameter_ = is_parameter;
1518   operator_cost()->set_is_parameter(is_parameter);
1519   return SUCCESS;
1520 }
1521 
1522 Status OperatorInfo::CalculateMemoryCost() {
1523   if (is_parameter_involve_.size() != is_parameter_.size()) {
1524     MS_LOG(ERROR) << "'is_parameter_' and 'is_parameter_involve_' do not have the same size.";
1525     return FAILED;
1526   }
1527   // Set the memory cost in the 'strategy_cost_'
1528   for (auto &swc : strategy_cost_) {
1529     auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
1530     swc->cost_list[0]->memory_with_reuse_ = mem_cost;
1531   }
1532   return SUCCESS;
1533 }
1534 
1535 Status OperatorInfo::CalculateMemoryCostForInference() {
1536   // First, set the 'is_outputs_critical_' flag into OperatorCost.
1537   if (is_output_critical_ == -1) {
1538     MS_LOG(EXCEPTION) << "The critical flag is not set.";
1539     return FAILED;
1540   }
1541   operator_cost()->set_output_critical(is_output_critical_);
1542   // Set the memory cost in the 'strategy_cost_'
1543   for (auto &swc : strategy_cost_) {
1544     auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
1545     swc->cost_list[0]->memory_with_reuse_ = mem_cost;
1546   }
1547   return SUCCESS;
1548 }
1549 
1550 Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
1551   for (auto &swc : strategy_cost_) {
1552     double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
1553                                 static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
1554     swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
1555     if (swc->cost_list[0]->memory_with_reuse_ < 0) {
1556       MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
1557                     << ", the parameter memory cost is: " << parameter_mem_cost;
1558       return FAILED;
1559     }
1560   }
1561   return SUCCESS;
1562 }
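// Correction sketch (hypothetical numbers): if the parameter at input_index has slice shape {128, 64}
// and a 4-byte element type, parameter_mem_cost = 128 * 64 * 4 = 32768 is subtracted from
// memory_with_reuse_ of every strategy; a negative result is reported as an error.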
1563 
1564 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) {
1565   int64_t ret = -1;
1566 
1567   // The number of repetitions is equal to the total number of devices divided by the number of devices
1568   // used by the tensor map.
1569   int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), int64_t(1), std::multiplies<int64_t>());
1570   for (auto &element : tensor_map) {
1571     // -1 means the corresponding dimension is not split.
1572     if (element == MAP_NONE) {
1573       continue;
1574     } else if ((element < 0) || (LongToSize(element) >= dev_matrix_shape.size())) {
1575       MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is "
1576                     << ShapeToString(dev_matrix_shape);
1577       return ret;
1578     } else {
1579       size_t index = dev_matrix_shape.size() - LongToSize(element) - 1;
1580       if (dev_matrix_shape[index] <= 0) {
1581         MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape);
1582         return ret;
1583       }
1584       device_num /= dev_matrix_shape[index];
1585     }
1586   }
1587 
1588   return device_num;
1589 }
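// Worked example (hypothetical values): dev_matrix_shape = {2, 4, 2} gives 16 devices in total;
// for tensor_map = {2, -1}, element 2 maps to dev_matrix_shape[3 - 2 - 1] = dev_matrix_shape[0] = 2
// and -1 (MAP_NONE) is skipped, so the repeated (unsplit) device number is 16 / 2 = 8.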
1590 
1591 Status OperatorInfo::InferAsLossDivisor() {
1592   if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
1593     as_loss_divisor_ = 1;
1594     return SUCCESS;
1595   }
1596 
1597   if (outputs_tensor_map_.empty()) {
1598     MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
1599     return FAILED;
1600   }
1601 
1602   if (outputs_tensor_map_.size() > 1) {
1603     MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size()
1604                   << ", this function needs to be overridden.";
1605     return FAILED;
1606   }
1607 
1608   if (outputs_tensor_map_[0].empty()) {
1609     as_loss_divisor_ = stage_device_size_;
1610     MS_LOG(INFO) << name_ << ": The output is a scalar, use the device size " << as_loss_divisor_ << " as the loss divisor.";
1611     return SUCCESS;
1612   }
1613 
1614   if (out_dev_matrix_shape_.empty()) {
1615     out_dev_matrix_shape_ = dev_matrix_shape_;
1616   }
1617   as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
1618   MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
1619                << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
1620                << as_loss_divisor_;
1621   return SUCCESS;
1622 }
1623 
1624 // If the operator is used as a loss, a div node is inserted for the grad of all its inputs.
1625 Status OperatorInfo::InferVirtualDivOps() {
1626   if (InferAsLossDivisor() != SUCCESS) {
1627     MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
1628     return FAILED;
1629   }
1630 
1631   if (as_loss_divisor_ <= 0) {
1632     MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
1633     return FAILED;
1634   } else if (as_loss_divisor_ == 1) {
1635     MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
1636     return SUCCESS;
1637   }
1638 
1639   virtual_div_op_.clear();
1640   // if the loss is repeatedly calculated, insert a div op
1641   Operator op = CreateVirtualDivOp(as_loss_divisor_);
1642   virtual_div_op_.push_back(op);
1643   return SUCCESS;
1644 }
1645 
1646 Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
1647                                                  const std::vector<size_t> &output_lengths) {
1648   if (input_lengths.size() != inputs_shape_.size()) {
1649     MS_LOG(ERROR) << "The size of input_lengths: " << input_lengths.size()
1650                   << " is not equal to the size of inputs_shape_: " << inputs_shape_.size();
1651     return FAILED;
1652   }
1653   if (output_lengths.size() != outputs_shape_.size()) {
1654     MS_LOG(ERROR) << "The size of output_lengths: " << output_lengths.size()
1655                   << " is not equal to the size of outputs_shape_: " << outputs_shape_.size();
1656     return FAILED;
1657   }
1658   inputs_type_lengths_ = input_lengths;
1659   outputs_type_lengths_ = output_lengths;
1660   operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
1661   return SUCCESS;
1662 }
1663 
1664 double OperatorInfo::GetOutputsTotalSize() {
1665   if (is_calculated_outputs_size_) {
1666     return outputs_total_size_;
1667   }
1668   if (outputs_type_lengths_.size() != outputs_shape_.size()) {
1669     MS_LOG(EXCEPTION) << "The size of outputs_type_lengths_: " << outputs_type_lengths_.size()
1670                       << " is not equal to the size of outputs_shape_: " << outputs_shape_.size();
1671   }
1672   double sum = 0.0;
1673   for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
1674     auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
1675                                 std::multiplies<double>());
1676     sum += size * static_cast<double>(outputs_type_lengths_[i]);
1677   }
1678   is_calculated_outputs_size_ = true;
1679   outputs_total_size_ = sum;
1680   return outputs_total_size_;
1681 }
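// Sizing sketch (hypothetical values): for outputs_shape_ = {{32, 10}} and a 4-byte output type,
// the accumulated size is 32 * 10 * 4 = 1280 bytes, cached in outputs_total_size_ after the first call.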
1682 
1683 Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
1684   if (outputs_type.size() != outputs_shape_.size()) {
1685     MS_LOG(ERROR) << "The size of outputs_type: " << outputs_type.size()
1686                   << " is not equal to the size of outputs_shape_: " << outputs_shape_.size();
1687     return FAILED;
1688   }
1689   outputs_type_ = outputs_type;
1690   return SUCCESS;
1691 }
1692 
1693 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
1694   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
1695     if (stra->GetInputDim()[0][0] == stage_device_size_) {
1696       if (cost->computation_cost_ > 1.0) {
1697         cost->computation_cost_ -= 1.0;
1698       }
1699       if (cost->communication_cost_ > 1.0) {
1700         cost->communication_cost_ -= 1.0;
1701       }
1702       if (cost->communication_with_partial_para_ > 1.0) {
1703         cost->communication_with_partial_para_ -= 1.0;
1704       }
1705       if (cost->communication_without_parameter_ > 1.0) {
1706         cost->communication_without_parameter_ -= 1.0;
1707       }
1708     }
1709   }
1710 }
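// Tie-breaking sketch (hypothetical values): with stage_device_size_ = 8 and a strategy whose first
// input is split as (8, 1, 1, 1), each cost component larger than 1.0 is reduced by 1.0, so the
// data-parallel candidate wins ties against otherwise equally priced strategies.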
1711 
1712 void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth) {
1713   MS_EXCEPTION_IF_NULL(s_strategy);
1714   if ((selected_strategy_depth_ != -1) && (SizeToLong(curr_depth) > selected_strategy_depth_)) {
1715     MS_LOG(INFO) << name_ << " has already had its strategy set.";
1716     return;
1717   }
1718   MS_LOG(INFO) << "Set strategy for: " << name_;
1719   PrintStrategy(s_strategy);
1720   selected_strategy_ = s_strategy;
1721   selected_strategy_depth_ = SizeToLong(curr_depth);
1722 }
1723 
1724 CNodePtr OperatorInfo::cnode() {
1725   MS_EXCEPTION_IF_NULL(cnode_);
1726   return cnode_;
1727 }
1728 
1729 double OperatorInfo::GetForwardMemoryCostFromCNode() {
1730   return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
1731 }
1732 
1733 void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
1734   MS_EXCEPTION_IF_NULL(s_strategy);
1735   if (!s_strategy->IsEqual(selected_strategy_)) {
1736     MS_LOG(INFO) << name() << "'s strategy may be suboptimal; the determined strategy is:";
1737     PrintStrategy(selected_strategy_);
1738     MS_LOG(INFO) << "The minimal-cost strategy is:";
1739     PrintStrategy(s_strategy);
1740   }
1741 }
1742 
1743 void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
1744   strategy_cost_ = stra_cost;
1745 }
1746 
1747 Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
1748   if (InferAttrs() != SUCCESS) {
1749     MS_LOG(ERROR) << name_ << ": Infer attrs failed";
1750     return FAILED;
1751   }
1752 
1753   std::vector<StrategyPtr> sp_vector = GenerateOpStrategies(stage_id);
1754 
1755   size_t success = 0;
1756   for (auto &sp : sp_vector) {
1757     PrintStrategy(sp);
1758     if (SetCostUnderStrategy(sp) == SUCCESS) {
1759       success++;
1760       MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategies so far.";
1761       PrintStrategy(sp);
1762     }
1763   }
1764   return SUCCESS;
1765 }
1766 
1767 int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
1768   auto attr_iter = attrs_.find(attr_name);
1769   if (attr_iter == attrs_.end()) {
1770     MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
1771   }
1772 
1773   MS_EXCEPTION_IF_NULL(attr_iter->second);
1774   if (!attr_iter->second->isa<Int64Imm>()) {
1775     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
1776   }
1777 
1778   return attr_iter->second->cast<Int64ImmPtr>()->value();
1779 }
1780 
1781 bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
1782   auto attr_iter = attrs_.find(attr_name);
1783   if (attr_iter == attrs_.end()) {
1784     MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
1785   }
1786 
1787   MS_EXCEPTION_IF_NULL(attr_iter->second);
1788   if (!attr_iter->second->isa<BoolImm>()) {
1789     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not bool";
1790   }
1791 
1792   return attr_iter->second->cast<BoolImmPtr>()->value();
1793 }
1794 
1795 std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
1796   std::string string_attr;
1797   auto attr_iter = attrs_.find(attr_name);
1798   if (attr_iter == attrs_.end()) {
1799     MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
1800   }
1801 
1802   MS_EXCEPTION_IF_NULL(attr_iter->second);
1803   if (!attr_iter->second->isa<StringImm>()) {
1804     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
1805   }
1806 
1807   string_attr = attr_iter->second->cast<StringImmPtr>()->value();
1808   return string_attr;
1809 }
1810 
1811 std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
1812   std::vector<int64_t> tuple_attr;
1813   auto tuple_attr_iter = attrs_.find(attr_name);
1814   if (tuple_attr_iter == attrs_.end()) {
1815     MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
1816   }
1817 
1818   MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
1819   tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
1820 
1821   return tuple_attr;
1822 }
1823 
1824 float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
1825   auto attr_iter = attrs_.find(attr_name);
1826   if (attr_iter == attrs_.end()) {
1827     MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
1828   }
1829 
1830   MS_EXCEPTION_IF_NULL(attr_iter->second);
1831   if (!attr_iter->second->isa<FP32Imm>()) {
1832     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
1833   }
1834 
1835   return attr_iter->second->cast<FP32ImmPtr>()->value();
1836 }
1837 
1838 std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
1839   MS_EXCEPTION_IF_NULL(sequeue);
1840   std::vector<ValuePtr> ret;
1841   if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
1842     MS_LOG(ERROR) << "The arg is not a value tuple or a value list";
1843     return ret;
1844   }
1845 
1846   if (sequeue->isa<ValueTuple>()) {
1847     auto val_tuple = sequeue->cast<ValueTuplePtr>();
1848     return val_tuple->value();
1849   }
1850   auto val = sequeue->cast<ValueListPtr>();
1851   return val->value();
1852 }
1853 }  // namespace parallel
1854 }  // namespace mindspore
1855