1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/operator_info.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "ir/dtype.h"
27 #include "ir/tensor.h"
28 #include "ir/value.h"
29 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
30 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
31 #include "frontend/parallel/context.h"
32 #include "utils/log_adapter.h"
33
34 namespace mindspore {
35 namespace parallel {
StrategyToString(const Strategys & strategy)36 std::string StrategyToString(const Strategys &strategy) {
37 std::string strategy_str = "";
38 strategy_str += "(";
39 for (size_t i = 0; i < strategy.size(); ++i) {
40 strategy_str += "(";
41 for (size_t j = 0; j < strategy[i].size(); ++j) {
42 strategy_str += std::to_string(strategy[i][j]);
43 if (j != strategy[i].size() - 1) {
44 strategy_str += ", ";
45 }
46 }
47 strategy_str += ")";
48 if (i != strategy.size() - 1) {
49 strategy_str += ", ";
50 }
51 }
52 if (strategy.size() == 1) {
53 strategy_str += ",";
54 }
55 strategy_str += ")";
56 return strategy_str;
57 }
58
// Validates a sharding strategy against the operator's input shapes. A
// strategy is accepted only when:
//   1. it provides one sub-strategy per input tensor,
//   2. each sub-strategy has the same rank as its input shape,
//   3. every split factor is >= MIN_SLICE_NUM (i.e. positive),
//   4. every split factor is a power of 2,
//   5. every input dimension is evenly divisible by its split factor.
// In auto-parallel mode rejected candidate strategies are expected during the
// search, so failures are logged at DEBUG level instead of ERROR.
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
  if (strategy == nullptr) {
    MS_LOG(ERROR) << name_ << ": The strategy is null.";
    return FAILED;
  }

  size_t strategy_size = strategy->GetInputNumber();
  size_t inputs_shape_size = inputs_shape.size();
  Strategys stra = strategy->GetInputDim();
  // Rule 1: the strategy must cover every input tensor.
  if (strategy_size != inputs_shape_size) {
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
                    << " is not equal to inputs size: " << inputs_shape_size;
    } else {
      MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
                    << " is not equal to inputs size: " << inputs_shape_size;
    }
    return FAILED;
  }

  for (size_t i = 0; i < strategy_size; ++i) {
    Shape sub_strategy = stra.at(i);
    Shape sub_input_shape = inputs_shape.at(i);
    size_t strategy_len = sub_strategy.size();
    size_t inputs_len = sub_input_shape.size();
    // Rule 2: each sub-strategy must have one split factor per dimension.
    if (strategy_len != inputs_len) {
      if (is_auto_parallel_) {
        MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
                      << " is not equal to inputs len: " << inputs_len << ", index: " << i;
      } else {
        MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
                      << " is not equal to inputs len: " << inputs_len << ", index: " << i;
      }
      return FAILED;
    }

    for (size_t j = 0; j < strategy_len; ++j) {
      int64_t strategy_value = sub_strategy.at(j);
      // Rule 3: split factors must be positive.
      if (strategy_value < MIN_SLICE_NUM) {
        if (is_auto_parallel_) {
          MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
                        << ", the value of strategy must be larger than 0, but get " << strategy_value;
        } else {
          MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
                        << ", the value of strategy must be larger than 0, but get " << strategy_value;
        }
        return FAILED;
      }

      // Rule 4: (v & (v - 1)) == 0 iff v is a power of 2; v >= 1 is already
      // guaranteed by the check above, so the bit trick is safe here.
      if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
        if (is_auto_parallel_) {
          MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra)
                        << ", the value of strategy must be the power of 2, but get " << strategy_value;
        } else {
          MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
                        << ", the value of strategy must be the power of 2, but get " << strategy_value;
        }
        return FAILED;
      }

      // Rule 5: the dimension must split into equal slices.
      int64_t shape_value = sub_input_shape.at(j);
      if ((shape_value % strategy_value) != 0) {
        if (is_auto_parallel_) {
          MS_LOG(DEBUG) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
                        << " cannot be divisible by strategy value " << strategy_value;
        } else {
          MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
                        << " cannot be divisible by strategy value " << strategy_value;
        }
        return FAILED;
      }
    }
  }

  return SUCCESS;
}
135
ResetQueueMember()136 void OperatorInfo::ResetQueueMember() {
137 inputs_tensor_info_.clear();
138 outputs_tensor_info_.clear();
139 inputs_tensor_map_.clear();
140 outputs_tensor_map_.clear();
141 dev_matrix_shape_.clear();
142 forward_op_.clear();
143 mirror_ops_.clear();
144 sub_ops_.clear();
145 replace_op_.clear();
146 replace_op_info_.clear();
147 virtual_div_op_.clear();
148 }
149
InferAttrs()150 Status OperatorInfo::InferAttrs() {
151 if (infer_attrs_completed_) {
152 return SUCCESS;
153 }
154
155 if (GetAttrs() != SUCCESS) {
156 return FAILED;
157 }
158 infer_attrs_completed_ = true;
159 return SUCCESS;
160 }
161
InferMirrorOps()162 Status OperatorInfo::InferMirrorOps() {
163 mirror_ops_.clear();
164 if (inputs_shape_.empty()) {
165 MS_LOG(INFO) << name_ << ": The inputs size is empty";
166 return SUCCESS;
167 }
168
169 if (inputs_tensor_map_.size() != inputs_shape_.size()) {
170 MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
171 return FAILED;
172 }
173
174 bool group_is_empty = true;
175 for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
176 std::vector<Group> group;
177 if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
178 MS_LOG(ERROR) << name_ << ": Create group failed, the input index is " << i;
179 mirror_ops_.clear();
180 return FAILED;
181 }
182
183 OperatorVector mirror_op;
184 if (group.empty()) {
185 MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
186 mirror_ops_.push_back(mirror_op);
187 continue;
188 }
189
190 group_is_empty = false;
191 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
192 mirror_ops_.push_back(mirror_op);
193 }
194
195 if (group_is_empty) {
196 mirror_ops_.clear();
197 MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
198 }
199 return SUCCESS;
200 }
201
InferTensorInfo()202 Status OperatorInfo::InferTensorInfo() {
203 if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
204 MS_LOG(ERROR) << name_ << ": Invalid args";
205 return FAILED;
206 }
207
208 for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
209 TensorLayout input_layout;
210 if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
211 MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
212 return FAILED;
213 }
214 TensorInfo input_tensor_info(input_layout);
215 inputs_tensor_info_.push_back(input_tensor_info);
216 }
217
218 for (size_t i = 0; i < outputs_tensor_map_.size(); ++i) {
219 TensorLayout output_layout;
220 if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
221 MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed, the index is " << i;
222 return FAILED;
223 }
224 TensorInfo output_tensor_info(output_layout);
225 outputs_tensor_info_.push_back(output_tensor_info);
226 }
227
228 return SUCCESS;
229 }
230
InferRepeatedCalcInfo()231 Status OperatorInfo::InferRepeatedCalcInfo() {
232 int64_t g_dev_list_size = stage_device_size_;
233 int64_t dev_matrix_size =
234 std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
235 if (dev_matrix_size == 0) {
236 MS_LOG(ERROR) << name_ << ": The dev matrix size is 0";
237 return FAILED;
238 }
239
240 if (g_dev_list_size == dev_matrix_size) {
241 repeated_calc_num_ = 1;
242 } else if (g_dev_list_size % dev_matrix_size == 0) {
243 repeated_calc_num_ = ((int64_t)(g_dev_list_size / dev_matrix_size));
244 } else {
245 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategy_->GetInputDim()) << ", it requires "
246 << dev_matrix_size << " devices, "
247 << "but the device number of this stage is " << g_dev_list_size << ", it can not be divisible by "
248 << dev_matrix_size;
249 return FAILED;
250 }
251 return SUCCESS;
252 }
253
254 // If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
255 // because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
256 // to the last dimension of dev-matrix, there is no need to redistribution.
SetRepeatedCalcDevMatrix()257 void OperatorInfo::SetRepeatedCalcDevMatrix() {
258 if (repeated_calc_num_ <= 1) {
259 return;
260 }
261 if (repeated_num_in_dev_matrix_right_) {
262 dev_matrix_shape_.push_back(repeated_calc_num_);
263 } else {
264 (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
265 }
266 }
267
268 // If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
269 // the index value of tensor map needs to be increased by 1.
ResetTensorMapIfRepeatedCalc()270 void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
271 if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
272 return;
273 }
274
275 MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
276 for (auto &tensor_map : inputs_tensor_map_) {
277 for (auto &element : tensor_map) {
278 if (element == MAP_NONE) {
279 continue;
280 }
281 element += 1;
282 }
283 }
284
285 for (auto &tensor_map : outputs_tensor_map_) {
286 for (auto &element : tensor_map) {
287 if (element == MAP_NONE) {
288 continue;
289 }
290 element += 1;
291 }
292 }
293 }
294
295 // use for loss repeated calculation
CreateVirtualDivOp(int64_t div_num)296 Operator CreateVirtualDivOp(int64_t div_num) {
297 OperatorName operator_name = VIRTUAL_DIV;
298 ValuePtr attr0_value = MakeValue(div_num);
299 Attr attr0 = std::make_pair(DIVISOR, attr0_value);
300 OperatorAttrs operator_attrs;
301 operator_attrs.push_back(attr0);
302
303 OperatorParams operator_param;
304 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
305
306 Operator op = std::make_pair(operator_name, operator_arg);
307 return op;
308 }
309
CreateReduceCommunicationOpArgs(const std::string & reduce_op,const std::string & group)310 static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
311 ValuePtr attr0_value = MakeValue(reduce_op);
312 ValuePtr attr1_value = MakeValue(group);
313 Attr attr0 = std::make_pair(OP, attr0_value);
314 Attr attr1 = std::make_pair(GROUP, attr1_value);
315 OperatorAttrs operator_attrs;
316 operator_attrs.push_back(attr0);
317 operator_attrs.push_back(attr1);
318
319 OperatorParams operator_param;
320 return std::make_pair(operator_attrs, operator_param);
321 }
322
323 // use for forward all reduce
CreateAllReduceOp(const std::string & reduce_op,const std::string & group)324 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
325 OperatorName operator_name = ALL_REDUCE;
326 OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
327
328 Operator op = std::make_pair(operator_name, operator_arg);
329 MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group;
330 return op;
331 }
332
CreateReduceScatterOp(const std::string & reduce_op,const std::string & group)333 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
334 OperatorName operator_name = REDUCE_SCATTER;
335 OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
336
337 Operator op = std::make_pair(operator_name, operator_arg);
338 MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
339 return op;
340 }
341
AddCommOpFusionType(const CNodePtr & comm_node,const AnfNodePtr & param_node)342 void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) {
343 MS_EXCEPTION_IF_NULL(comm_node);
344 MS_EXCEPTION_IF_NULL(param_node);
345 ParameterPtr param;
346 if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
347 param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
348 } else {
349 param = param_node->cast<ParameterPtr>();
350 }
351 MS_EXCEPTION_IF_NULL(param);
352 auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
353 MS_EXCEPTION_IF_NULL(prim);
354 auto attrs = prim->attrs();
355 auto param_info = param->param_info();
356 if (!param_info) {
357 MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
358 return;
359 }
360 int32_t fusion_type = param_info->comm_fusion();
361 attrs[FUSION] = MakeValue<int64_t>(fusion_type);
362 prim->SetAttrs(attrs);
363 bool parallel_optimizer_comm_recompute = param_info->parallel_optimizer_comm_recompute();
364 std::string instance_name = prim->instance_name();
365 if (instance_name == PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE && parallel_optimizer_comm_recompute &&
366 prim->name() == ALL_GATHER) {
367 prim->set_attr(RECOMPUTE, MakeValue(true));
368 prim->set_instance_name(PARALLEL_OPTIMIZER_ALLGATHER);
369 }
370 MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
371 }
372
AddCommOpMeanFlag(const CNodePtr & comm_node)373 void AddCommOpMeanFlag(const CNodePtr &comm_node) {
374 MS_EXCEPTION_IF_NULL(comm_node);
375 auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
376 auto attrs = prim->attrs();
377 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
378 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
379 attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
380 prim->SetAttrs(attrs);
381 }
382
AddCommOpParamFlag(const CNodePtr & comm_node)383 void AddCommOpParamFlag(const CNodePtr &comm_node) {
384 MS_EXCEPTION_IF_NULL(comm_node);
385 auto graph = comm_node->func_graph();
386 MS_EXCEPTION_IF_NULL(graph);
387 auto manager = graph->manager();
388 MS_EXCEPTION_IF_NULL(manager);
389 auto node_users = manager->node_users()[comm_node->input(1)];
390 for (auto &node_user : node_users) {
391 if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
392 auto prim = GetCNodePrimitive(comm_node);
393 (void)prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
394 return;
395 }
396 }
397 }
398
CreateAllGatherOp(const std::string & group)399 Operator CreateAllGatherOp(const std::string &group) {
400 OperatorName operator_name = ALL_GATHER;
401 ValuePtr attr0_value = MakeValue(group); // group
402 Attr attr0 = std::make_pair(GROUP, attr0_value);
403 OperatorAttrs operator_attrs;
404 operator_attrs.push_back(attr0);
405
406 OperatorParams operator_param;
407 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
408
409 Operator op = std::make_pair(operator_name, operator_arg);
410 MS_LOG(INFO) << "Create allgather op success, the group is " << group;
411 return op;
412 }
413
CreateMiniStepAllGatherOp(const std::string & group)414 Operator CreateMiniStepAllGatherOp(const std::string &group) {
415 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
416 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
417
418 OperatorName operator_name = MINI_STEP_ALL_GATHER;
419 ValuePtr attr0_value = MakeValue(group); // group
420 Attr attr0 = std::make_pair(GROUP, attr0_value);
421 ValuePtr attr1_value = MakeValue(grad_accumulation_step); // grad_accumulation_step
422 Attr attr1 = std::make_pair(GRAD_ACCUMULATION_STEP, attr1_value);
423 ValuePtr attr2_value = MakeValue(mean_flag); // mean_flag
424 Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
425 OperatorAttrs operator_attrs;
426 operator_attrs.push_back(attr0);
427 operator_attrs.push_back(attr1);
428 operator_attrs.push_back(attr2);
429
430 OperatorParams operator_param;
431 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
432
433 Operator op = std::make_pair(operator_name, operator_arg);
434 MS_LOG(INFO) << "Create MINI_STEP_ALL_GATHER success, the group is " << group;
435 return op;
436 }
437
CreateMicroStepAllGatherOp(const std::string & group)438 Operator CreateMicroStepAllGatherOp(const std::string &group) {
439 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
440
441 OperatorName operator_name = MICRO_STEP_ALL_GATHER;
442 ValuePtr attr0_value = MakeValue(group); // group
443 Attr attr0 = std::make_pair(GROUP, attr0_value);
444 ValuePtr attr1_value = MakeValue(mean_flag); // mean_flag
445 Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value);
446 OperatorAttrs operator_attrs;
447 operator_attrs.push_back(attr0);
448 operator_attrs.push_back(attr1);
449
450 OperatorParams operator_param;
451 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
452
453 Operator op = std::make_pair(operator_name, operator_arg);
454 MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
455 return op;
456 }
457
458 // use for get tensor slice
CreateGetTensorSliceOp(const TensorLayout & tensor_layout)459 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
460 Shape tensor_map = tensor_layout.tensor_map().array();
461 Shape dev_matrix_shape = tensor_layout.device_arrangement().array();
462 OperatorName operator_name = GET_TENSOR_SLICE;
463
464 OperatorAttrs attrs;
465 ValuePtr dev_mat_value = MakeValue(dev_matrix_shape);
466 Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2);
467 ValuePtr tensor_map_value = MakeValue(tensor_map);
468 Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3);
469 OperatorParams params = {dev_mat_param, tensor_map_param};
470 OperatorArgs operator_arg = std::make_pair(attrs, params);
471
472 Operator op = std::make_pair(operator_name, operator_arg);
473 MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is "
474 << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map);
475 return op;
476 }
477
CreateMirrorOps(const std::string & group_name,size_t dev_num)478 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
479 if ((dev_num == 0) || (dev_num == 1)) {
480 MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num;
481 }
482 OperatorVector op_for_weight;
483 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
484 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
485 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
486
487 ValuePtr attr0_value = MakeValue(group_name);
488 ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
489 ValuePtr attr2_value = MakeValue(mean_flag);
490
491 Attr attr0 = std::make_pair(GROUP, attr0_value);
492 Attr attr1 = std::make_pair(DEV_NUM, attr1_value);
493 Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
494
495 OperatorAttrs operator_attrs;
496 operator_attrs.push_back(attr0);
497 operator_attrs.push_back(attr1);
498 operator_attrs.push_back(attr2);
499
500 OperatorName operator_name;
501 if (grad_accumulation_step > 1) {
502 operator_name = MIRROR_MINI_STEP_OPERATOR;
503 ValuePtr attr3_value = MakeValue(grad_accumulation_step);
504 Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
505 operator_attrs.push_back(attr3);
506 MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
507 } else if (split_stage_num > 1) {
508 operator_name = MIRROR_MICRO_STEP_OPERATOR;
509 } else {
510 operator_name = MIRROR_OPERATOR;
511 }
512
513 OperatorParams operator_param;
514 OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
515
516 Operator op = std::make_pair(operator_name, operator_args);
517
518 op_for_weight.push_back(op);
519 MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is "
520 << mean_flag;
521 return op_for_weight;
522 }
523
CreateGroupByTensorMap(const Shape & tensor_map,std::vector<Group> * group)524 Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group) {
525 if (group == nullptr) {
526 MS_LOG(ERROR) << "The group is null.";
527 return FAILED;
528 }
529 CheckGlobalDeviceManager();
530 int64_t rank = g_device_manager->global_rank();
531 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
532 RankList group_devices;
533 if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
534 return FAILED;
535 }
536
537 if (group_devices.size() == 1) {
538 MS_LOG(INFO) << "The dev size is 1, no need to create group.";
539 return SUCCESS;
540 }
541
542 Group g = g_device_manager->CreateGroup(group_devices);
543 group->push_back(g);
544 return SUCCESS;
545 }
546
// Creates the communication groups required by the parallel optimizer
// (optimizer weight sharding): an AllGather group that reassembles the
// sharded weight for the forward pass and, when the weight is only partially
// sharded, a mirror group that synchronizes gradients across the remaining
// replicas. Group names are recorded in the tensor layout.
Status OperatorInfo::CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *groups) {
  if (groups == nullptr) {
    MS_LOG(ERROR) << "The group is null. Operator is " << name_;
    return FAILED;
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  // The repeated-calculation devices of this weight, derived from its
  // original (pre-opt-shard) tensor map.
  Shape tensor_map = tensor_layout->origin_tensor_map().array();
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    return FAILED;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << "The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  // -1 means "fully shard the weight across all repeated devices".
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  if (optimizer_weight_shard_size != -1) {
    // not fully use opt shard
    // Position of this rank inside the repeated-device list.
    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
    int64_t repeated_size = SizeToLong(group_devices.size());
    if (repeated_size % optimizer_weight_shard_size != 0) {
      MS_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size
                      << " can not be applied. The repeated size of Operator " << name_ << " is " << repeated_size;
      return FAILED;
    }
    // Number of mirror replicas left after sharding.
    repeated_size = repeated_size / optimizer_weight_shard_size;
    // create allgather group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
    // Slice out the contiguous chunk of size optimizer_weight_shard_size that
    // contains this rank.
    RankList new_group_devices(
      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
    Group allgather_group = g_device_manager->CreateGroup(new_group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
    // create mirror group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
    int64_t device_num = g_device_manager->stage_device_num();
    // 2-D helper matrix [replicas, shard-span] used to pick the ranks holding
    // the same weight slice across replicas (dim 0).
    Shape dev_mat = {repeated_size, device_num / repeated_size};
    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
    RankList mirror_group_devices;
    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
      return FAILED;
    }
    Group mirror_group = g_device_manager->CreateGroup(mirror_group_devices);
    groups->push_back(mirror_group);
    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create mirror group " << mirror_group.name();
  } else {
    // fully use opt shard
    // create allgather group
    Group allgather_group = g_device_manager->CreateGroup(group_devices);
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << "Parallel optimizer: create allgather group " << allgather_group.name();
  }
  // save in tensor_layout for strategy ckpt
  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
  if (!integrated_save) {
    tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
    // Rank stride between neighbouring shard owners; group_devices.size() > 1
    // is guaranteed by the early return above, so no division by zero.
    int64_t opt_weight_shard_step =
      (group_devices.back() - group_devices.front()) / (SizeToLong(group_devices.size()) - 1);
    tensor_layout->set_opt_weight_shard_step(LongToInt(opt_weight_shard_step));
    MS_LOG(INFO) << "Parallel optimizer: save opt_weight_shard_step " << opt_weight_shard_step << " in strategy ckpt";
  }
  return SUCCESS;
}
617
CreateGroupByDim(size_t axis,std::vector<Group> * group)618 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
619 if (group == nullptr) {
620 MS_LOG(ERROR) << "The group is null.";
621 return FAILED;
622 }
623 CheckGlobalDeviceManager();
624 int64_t rank = g_device_manager->global_rank();
625 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
626 RankList group_devices;
627 if (dev_matrix.GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
628 return FAILED;
629 }
630
631 if (group_devices.size() == 1) {
632 MS_LOG(INFO) << "The dev size is 1, no need to create group.";
633 return SUCCESS;
634 }
635
636 Group g = g_device_manager->CreateGroup(group_devices);
637 group->push_back(g);
638 return SUCCESS;
639 }
640
// Computes the per-device slice shape: each dimension divided by its split
// factor. Returns an empty shape (with an error log) on invalid input.
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
  Shape slice_shape;
  // Fixed: a strategy longer than the shape previously threw
  // std::out_of_range from tensor_shape.at(i); report it instead.
  if (strategy.size() > tensor_shape.size()) {
    MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", strategy size " << strategy.size()
                  << " is larger than tensor shape size " << tensor_shape.size();
    return slice_shape;
  }
  if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) {
    MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0";
    return slice_shape;
  }
  slice_shape.reserve(strategy.size());
  for (size_t i = 0; i < strategy.size(); ++i) {
    slice_shape.push_back(tensor_shape.at(i) / strategy.at(i));
  }
  return slice_shape;
}
652
InferSliceShapeByStrategy(const Strategys & strategys,const Shapes & shapes,Shapes * slice_shapes)653 Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) {
654 if (slice_shapes == nullptr) {
655 MS_LOG(ERROR) << "The slice_shapes is null.";
656 return FAILED;
657 }
658 if (strategys.size() != shapes.size()) {
659 MS_LOG(ERROR) << "Strategy size " << strategys.size() << " not equal to shape size " << shapes.size();
660 return FAILED;
661 }
662
663 for (size_t i = 0; i < strategys.size(); ++i) {
664 if (strategys.at(i).size() != shapes.at(i).size()) {
665 MS_LOG(ERROR) << "Strategy dimension " << strategys.at(i).size() << " not equal to shape dimension "
666 << shapes.at(i).size();
667 slice_shapes->clear();
668 return FAILED;
669 }
670
671 for (size_t j = 0; j < shapes.at(i).size(); ++j) {
672 if (strategys.at(i).at(j) <= 0) {
673 MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategys[i])
674 << " the element is less than or equal to 0.";
675 slice_shapes->clear();
676 return FAILED;
677 }
678 if (shapes.at(i).at(j) % strategys.at(i).at(j) != 0) {
679 MS_LOG(ERROR) << "Shape cannot be divisible by strategy, " << shapes.at(i).at(j) << " : "
680 << strategys.at(i).at(j);
681 slice_shapes->clear();
682 return FAILED;
683 }
684 }
685 Shape slice_shape = GetSliceShape(shapes.at(i), strategys.at(i));
686 slice_shapes->push_back(slice_shape);
687 }
688
689 return SUCCESS;
690 }
691
InferSliceShape(const Strategys & inputs_strategy,const Strategys & outputs_strategy,Shapes * inputs_slice_shape,Shapes * outputs_slice_shape)692 Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy,
693 Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) {
694 if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) {
695 MS_LOG(ERROR) << "The slice_shape is null.";
696 return FAILED;
697 }
698
699 if (InferSliceShapeByStrategy(inputs_strategy, inputs_shape_, inputs_slice_shape) != SUCCESS) {
700 MS_LOG(ERROR) << "Infer inputs slice shape error.";
701 return FAILED;
702 }
703
704 if (InferSliceShapeByStrategy(outputs_strategy, outputs_shape_, outputs_slice_shape) != SUCCESS) {
705 MS_LOG(ERROR) << "Infer outputs slice shape error.";
706 inputs_slice_shape->clear();
707 return FAILED;
708 }
709
710 return SUCCESS;
711 }
712
713 // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1
InitForCostModelWithAutoRepeatCalc(const StrategyPtr & strategy)714 Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) {
715 if (strategy == nullptr) {
716 MS_LOG(ERROR) << name_ << ": The strategy is null.";
717 return FAILED;
718 }
719
720 if (InferAttrs() != SUCCESS) {
721 MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
722 return FAILED;
723 }
724
725 // must be after InferAttrs()
726 if (CheckStrategy(strategy) != SUCCESS) {
727 if (is_auto_parallel_) {
728 MS_LOG(DEBUG) << name_ << ": CheckStrategy failed.";
729 } else {
730 MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
731 }
732 return FAILED;
733 }
734
735 // need to clear queues before Init(),
736 // because Init() may be called multiple times by cost model
737 ResetQueueMember();
738
739 strategy_ = strategy;
740
741 if (InferDevMatrixShape() != SUCCESS) {
742 MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
743 return FAILED;
744 }
745
746 used_devices_ =
747 ((int64_t)(std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>())));
748
749 // must be after InferDevMatrixShape
750 if (InferRepeatedCalcInfo() != SUCCESS) {
751 MS_LOG(ERROR) << ": InferRepeatedCalcInfo failed.";
752 return FAILED;
753 }
754
755 // if repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix for layout
756 SetRepeatedCalcDevMatrix();
757
758 if (InferTensorMap() != SUCCESS) {
759 MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
760 return FAILED;
761 }
762
763 ResetTensorMapIfRepeatedCalc();
764
765 if (InferTensorInfo() != SUCCESS) {
766 MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
767 return FAILED;
768 }
769
770 return SUCCESS;
771 }
772
773 // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape
InitForCostModelWithManualRepeatCalc(const StrategyPtr & strategy)774 Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) {
775 if (strategy == nullptr) {
776 MS_LOG(ERROR) << name_ << ": The strategy is null.";
777 return FAILED;
778 }
779
780 if (InferAttrs() != SUCCESS) {
781 MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
782 return FAILED;
783 }
784
785 // must be after InferAttrs()
786 if (CheckStrategy(strategy) != SUCCESS) {
787 MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
788 return FAILED;
789 }
790
791 // need to clear queues before Init(),
792 // because Init() may be called multiple times by cost model
793 ResetQueueMember();
794
795 strategy_ = strategy;
796
797 if (InferDevMatrixShape() != SUCCESS) {
798 MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
799 return FAILED;
800 }
801
802 // must be after InferDevMatrixShape
803 if (InferRepeatedCalcInfo() != SUCCESS) {
804 MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
805 return FAILED;
806 }
807
808 if (InferTensorMap() != SUCCESS) {
809 MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
810 return FAILED;
811 }
812
813 if (InferTensorInfo() != SUCCESS) {
814 MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
815 return FAILED;
816 }
817
818 return SUCCESS;
819 }
820
InitWithAutoRepeatCalc(const StrategyPtr & strategy)821 Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) {
822 if (strategy == nullptr) {
823 MS_LOG(ERROR) << name_ << ": The strategy is null.";
824 return FAILED;
825 }
826
827 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
828 return FAILED;
829 }
830
831 if (InferForwardCommunication() != SUCCESS) {
832 MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
833 return FAILED;
834 }
835
836 if (InferMirrorOps() != SUCCESS) {
837 MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
838 return FAILED;
839 }
840
841 if (InferVirtualDivOps() != SUCCESS) {
842 MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
843 return FAILED;
844 }
845
846 return SUCCESS;
847 }
848
InitWithManualRepeatCalc(const StrategyPtr & strategy)849 Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) {
850 if (strategy == nullptr) {
851 MS_LOG(ERROR) << name_ << ": The strategy is null.";
852 return FAILED;
853 }
854
855 if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) {
856 return FAILED;
857 }
858
859 if (InferForwardCommunication() != SUCCESS) {
860 MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
861 return FAILED;
862 }
863
864 if (InferMirrorOps() != SUCCESS) {
865 MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
866 return FAILED;
867 }
868
869 if (InferVirtualDivOps() != SUCCESS) {
870 MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
871 return FAILED;
872 }
873
874 return SUCCESS;
875 }
876
GetAliveSuccEdges()877 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
878 std::vector<std::shared_ptr<Edge>> ret;
879 for (auto &edge : succ_edges_) {
880 if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
881 ret.push_back(edge);
882 } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
883 // CAST is ordered in front of L2NORMALIZE
884 ret.push_back(edge);
885 }
886 }
887 for (auto &edge : succ_edges_) {
888 if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
889 (edge->next_operator()->name().find(CAST) == std::string::npos)) {
890 ret.push_back(edge);
891 }
892 }
893 return ret;
894 }
895
GetAlivePrevEdges()896 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() {
897 std::vector<std::shared_ptr<Edge>> ret;
898 for (auto &edge : prev_edges_) {
899 if (edge->prev_operator()->is_alive()) {
900 ret.push_back(edge);
901 }
902 }
903 return ret;
904 }
905
ReplacePreEdge(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & replace_edge)906 void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
907 if (op == nullptr) {
908 MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null.";
909 return;
910 }
911 for (auto &edge : prev_edges_) {
912 if (edge->prev_operator() == op) {
913 edge = replace_edge;
914 return;
915 }
916 }
917 MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
918 }
919
ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & replace_edge)920 void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
921 if (op == nullptr) {
922 MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null.";
923 return;
924 }
925 for (auto &edge : succ_edges_) {
926 if (edge->next_operator() == op) {
927 edge = replace_edge;
928 return;
929 }
930 }
931 MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
932 }
933
ReplacePreEdges(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & replace_edge)934 void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &replace_edge) {
935 if (op == nullptr) {
936 MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null.";
937 return;
938 }
939 std::vector<std::shared_ptr<Edge>> update_pre_edges;
940 for (auto &edge : prev_edges_) {
941 if (edge->prev_operator() != op) {
942 update_pre_edges.push_back(edge);
943 }
944 }
945 update_pre_edges.push_back(replace_edge);
946 prev_edges_ = update_pre_edges;
947 }
948
ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & replace_edge)949 void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op,
950 const std::shared_ptr<Edge> &replace_edge) {
951 if (op == nullptr) {
952 MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null";
953 return;
954 }
955 std::vector<std::shared_ptr<Edge>> update_pre_edges;
956 for (auto &edge : succ_edges_) {
957 if (edge->next_operator() != op) {
958 update_pre_edges.push_back(edge);
959 }
960 }
961 update_pre_edges.push_back(replace_edge);
962 succ_edges_ = update_pre_edges;
963 }
964
GenerateBatchStrategiesBySplitFlag(const Shapes & shapes,const std::vector<bool> & split_flag_list)965 std::shared_ptr<Strategys> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
966 const std::vector<bool> &split_flag_list) {
967 if (shapes.size() != split_flag_list.size()) {
968 MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : "
969 << shapes.size();
970 return nullptr;
971 }
972 CheckGlobalDeviceManager();
973 int64_t dev_num = g_device_manager->stage_device_num();
974 Strategys strategy_v;
975 for (size_t i = 0; i != shapes.size(); i++) {
976 if (shapes[i].empty()) {
977 MS_LOG(INFO) << "Elements of shapes is empty.";
978 Dimensions empty_element;
979 strategy_v.push_back(empty_element);
980 } else {
981 Dimensions element(shapes[i].size(), 1);
982 if (split_flag_list[i]) {
983 element[0] = dev_num;
984 }
985 strategy_v.push_back(element);
986 }
987 }
988 return std::make_shared<Strategys>(strategy_v);
989 }
990
ReComputeBatchSplitFlagList()991 void OperatorInfo::ReComputeBatchSplitFlagList() {
992 if (!inputs_shape_.empty()) {
993 split_flag_list_[0] = true;
994 }
995 }
996
ComputeBatchSplitFlagList()997 void OperatorInfo::ComputeBatchSplitFlagList() {
998 split_flag_list_.clear();
999 for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) {
1000 split_flag_list_.push_back(false);
1001 }
1002 ReComputeBatchSplitFlagList();
1003 }
1004
// This is a common method for checking whether the generated strategy has the correct number of devices.
PrepareStrategyBase(int64_t stage_id,size_t dev_num,const Shapes & inputs_partitions,StrategyPtr * const sp)1006 Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
1007 if (sp == nullptr) {
1008 MS_LOG(ERROR) << "The strategy is null.";
1009 return FAILED;
1010 }
1011 int64_t product = 1;
1012
1013 for (auto &input_partition : inputs_partitions) {
1014 product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
1015 }
1016 const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
1017 if (!fully_use_device) {
1018 if (LongToSize(product) > dev_num) {
1019 return FAILED;
1020 }
1021 } else {
1022 if ((product != 1) && (LongToSize(product) != dev_num)) {
1023 return FAILED;
1024 }
1025 }
1026 Strategys stras(inputs_partitions);
1027 (*sp) = std::make_shared<Strategy>(stage_id, stras);
1028 return SUCCESS;
1029 }
1030
GenerateBatchStrategies()1031 std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
1032 if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
1033 MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
1034 }
1035 ComputeBatchSplitFlagList();
1036 return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
1037 }
1038
PrintStrategy(const StrategyPtr & strategy)1039 void PrintStrategy(const StrategyPtr &strategy) {
1040 if (strategy == nullptr) {
1041 return;
1042 }
1043 std::string all_strategy = "";
1044 for (size_t i = 0; i < strategy->GetInputNumber(); ++i) {
1045 all_strategy += "[";
1046 for (size_t j = 0; j < strategy->GetInputDim()[i].size(); ++j) {
1047 all_strategy += std::to_string(strategy->GetInputDim()[i][j]);
1048 if (j != strategy->GetInputDim()[i].size() - 1) {
1049 all_strategy += ", ";
1050 }
1051 }
1052 all_strategy += "]";
1053 if (i != strategy->GetInputNumber() - 1) {
1054 all_strategy += ", ";
1055 }
1056 }
1057 MS_LOG(INFO) << "The strategy is: " << all_strategy;
1058 }
1059
1060 // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d])
GenerateStrategiesForTwoEqualInputs(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1061 Status GenerateStrategiesForTwoEqualInputs(int64_t stage_id, const Shapes &inputs_shape,
1062 const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1063 if (sp_vector == nullptr) {
1064 MS_LOG(ERROR) << "The sp_vector is null.";
1065 return FAILED;
1066 }
1067
1068 if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1069 MS_LOG(ERROR) << "The inputs size is wrong.";
1070 return FAILED;
1071 }
1072
1073 if ((inputs_shape[0].size() != inputs_shape[1].size()) ||
1074 (splittable_inputs[0].size() != splittable_inputs[1].size())) {
1075 MS_LOG(ERROR) << "The size of two inputs are not equal.";
1076 return FAILED;
1077 }
1078
1079 Shapes input0_shape = {inputs_shape[0]};
1080 Shapes input0_splittable = {splittable_inputs[0]};
1081 if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) {
1082 return FAILED;
1083 }
1084
1085 for (auto &sp : *sp_vector) {
1086 sp->ExpandInputDimFromOneToTwo();
1087 }
1088
1089 return SUCCESS;
1090 }
1091
1092 // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast
1093 // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
                                          std::vector<StrategyPtr> *const sp_vector) {
  // Generates strategies when input0 is the shorter (broadcast) input, e.g.
  // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]). Strategies are
  // first generated as if input0 had input1's full shape, then trimmed.
  if (sp_vector == nullptr) {
    MS_LOG(ERROR) << "The sp_vector is null.";
    return FAILED;
  }

  // input0 must have strictly smaller rank than input1 in this case.
  if (inputs_shape[0].size() >= inputs_shape[1].size()) {
    MS_LOG(ERROR) << "Invalid inputs shape.";
    return FAILED;
  }

  // first, generate strategy for input0 the same as input1
  Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]};
  Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]};
  if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
    MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
    return FAILED;
  }

  // second, get the correct strategy for input0
  for (auto &sp : *sp_vector) {
    Strategys tmp_strategy;
    Dimensions input0_strategy = sp->GetInputDim()[0];
    size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();

    // erase the unnecessary part: drop the strategy entries for the leading
    // dimensions that input0 does not actually have
    (void)input0_strategy.erase(input0_strategy.begin(),
                                input0_strategy.begin() + static_cast<different_type>(size_diff));

    // handle the case likes ([1, c, d], [a, b, c, d]): a size-1 dimension
    // cannot be split, so its strategy entry is forced back to 1.
    // NOTE(review): only the leading run of 1s is reset — the loop breaks at
    // the first non-1 dimension, so a 1 appearing after it keeps its generated
    // split factor. Confirm this is intended for shapes like ([a, 1, d], ...).
    for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
      if (inputs_shape[0][i] == 1) {
        input0_strategy[i] = 1;
      } else {
        break;
      }
    }

    // reset the strategy
    tmp_strategy.push_back(input0_strategy);       // input0
    tmp_strategy.push_back(sp->GetInputDim()[1]);  // input1
    sp->ResetInputs(tmp_strategy);
  }
  return SUCCESS;
}
1140
1141 // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast
1142 // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
GenerateStrategiesForBroadcastRight(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1143 Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &inputs_shape,
1144 const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1145 if (sp_vector == nullptr) {
1146 MS_LOG(ERROR) << "The sp_vector is null.";
1147 return FAILED;
1148 }
1149
1150 if (inputs_shape[0].size() <= inputs_shape[1].size()) {
1151 MS_LOG(ERROR) << "Invalid inputs shape.";
1152 return FAILED;
1153 }
1154
1155 // first, generate strategy for input1 the same as input0
1156 Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]};
1157 Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]};
1158 if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1159 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1160 return FAILED;
1161 }
1162
1163 // second, get the correct strategy for input1
1164 for (auto &sp : *sp_vector) {
1165 Strategys tmp_strategy;
1166 tmp_strategy.push_back(sp->GetInputDim()[0]); // input0
1167
1168 Dimensions input1_strategy = sp->GetInputDim()[1];
1169 size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size();
1170
1171 // erase the unnecessary part
1172 (void)input1_strategy.erase(input1_strategy.begin(),
1173 input1_strategy.begin() + static_cast<different_type>(size_diff));
1174
1175 // handle the case likes ([a, b, c, d], [1, c, d])
1176 for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
1177 if (inputs_shape[1][i] == 1) {
1178 input1_strategy[i] = 1;
1179 } else {
1180 break;
1181 }
1182 }
1183
1184 // reset the strategy
1185 tmp_strategy.push_back(input1_strategy); // input1
1186 sp->ResetInputs(tmp_strategy);
1187 }
1188 return SUCCESS;
1189 }
1190
1191 // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast
1192 // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
GenerateStrategiesForBroadcastBoth(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1193 Status GenerateStrategiesForBroadcastBoth(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1194 std::vector<StrategyPtr> *const sp_vector) {
1195 if (sp_vector == nullptr) {
1196 MS_LOG(ERROR) << "The sp_vector is null.";
1197 return FAILED;
1198 }
1199
1200 if (inputs_shape[0].size() != inputs_shape[1].size()) {
1201 MS_LOG(ERROR) << "Invalid inputs shape.";
1202 return FAILED;
1203 }
1204
1205 // step1: ([a, 1], [1, b]) -> [a, b]
1206 Shape max_shape, splittable_vector;
1207 for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1208 if (inputs_shape[0][i] >= inputs_shape[1][i]) {
1209 max_shape.push_back(inputs_shape[0][i]);
1210 splittable_vector.push_back(splittable_inputs[0][i]);
1211 } else {
1212 max_shape.push_back(inputs_shape[1][i]);
1213 splittable_vector.push_back(splittable_inputs[1][i]);
1214 }
1215 }
1216
1217 // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b])
1218 Shapes tmp_inputs_shape = {max_shape, max_shape};
1219 Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector};
1220 if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1221 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1222 return FAILED;
1223 }
1224
1225 // step3: reset the strategy if the dimension is 1
1226 for (auto &sp : *sp_vector) {
1227 Dimensions input0_strategy = sp->GetInputDim()[0];
1228 Dimensions input1_strategy = sp->GetInputDim()[1];
1229 for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1230 if (inputs_shape[0][i] == 1) {
1231 input0_strategy[i] = 1;
1232 }
1233
1234 if (inputs_shape[1][i] == 1) {
1235 input1_strategy[i] = 1;
1236 }
1237 }
1238 sp->ResetInputs({input0_strategy, input1_strategy});
1239 }
1240
1241 return SUCCESS;
1242 }
1243
1244 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
1245 // the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding
1246 // dimension is splittable. 'inputs_partitions' is the result of partitions.
1247 // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring
1248 // specific dimensions in inputs have the identical partition should have individual implementation.
Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                              const Shapes &splittable_inputs,
                                              std::vector<StrategyPtr> *const sp_vector) {
  // Enumerates every way to split the splittable dimensions of all inputs
  // across power-of-two device counts, appending each feasible Strategy to
  // *sp_vector. Raises MS_LOG(EXCEPTION) if no strategy is feasible.
  if (sp_vector == nullptr) {
    MS_LOG(ERROR) << "The sp_vector is null.";
    return FAILED;
  }
  if (splittable_inputs.size() != inputs_shape.size()) {
    MS_LOG(ERROR) << "Splittable_inputs do not have the same input number of inputs shape, " << splittable_inputs.size()
                  << " : " << inputs_shape.size();
    return FAILED;
  }
  CheckGlobalDeviceManager();
  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();

  // Flatten all inputs' dimensions into single combined vectors so the
  // enumeration below can treat every dimension uniformly.
  Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions;
  for (size_t j = 0; j < inputs_shape.size(); ++j) {
    (void)combined_inputs_shape.insert(combined_inputs_shape.end(), inputs_shape[j].begin(), inputs_shape[j].end());
    (void)combined_splittable_inputs.insert(combined_splittable_inputs.end(), splittable_inputs[j].begin(),
                                            splittable_inputs[j].end());
  }
  // Depth-first enumeration: `current_index` walks the combined dimensions;
  // `n` is the number of devices still available for the remaining dimensions.
  std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape,
                                                     &combined_splittable_inputs, &combined_partitions, &recursive,
                                                     &inputs_shape](uint64_t current_index, size_t n) {
    if (current_index == combined_inputs_shape.size()) {
      // All dimensions assigned: un-flatten the combined partition back into
      // per-input shapes and try to build a Strategy from it.
      MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size();
      Shapes inputs_partitions;
      size_t global_index = 0;
      for (auto &shape : inputs_shape) {
        Shape tmp_partition;
        for (size_t j = 0; j < shape.size(); ++j) {
          tmp_partition.push_back(combined_partitions[global_index]);
          global_index++;
        }
        inputs_partitions.push_back(tmp_partition);
      }
      StrategyPtr sp;
      // PrepareStrategyBase rejects partitions incompatible with dev_num.
      if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) {
        sp_vector->push_back(sp);
      }
      return;
    } else {
      MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size();
      if (combined_splittable_inputs[current_index] == 0) {
        // Unsplittable dimension: it always takes MIN_SLICE_NUM (no split).
        combined_partitions.push_back(MIN_SLICE_NUM);
        recursive(current_index + 1, n / MIN_SLICE_NUM);
        combined_partitions.pop_back();
      } else if (combined_splittable_inputs[current_index] == 1) {
        // Splittable dimension: try every power-of-two factor i that divides
        // both the remaining device count and the dimension size.
        for (uint64_t i = 1; i <= n; i *= 2) {
          if (n % i == 0 && LongToSize(combined_inputs_shape[current_index]) % i == 0) {
            combined_partitions.push_back(i);
            recursive(current_index + 1, n / i);
            combined_partitions.pop_back();
          }
        }
      }
    }
  };
  recursive(0, dev_num);
  if (sp_vector->empty()) {
    MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo.";
  }
  return SUCCESS;
}
1313
1314 // generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
1315 // and the corresponding dimensions that are not broadcast are all relevant dimensions
1316 // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1317 // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1318 // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
GenerateStrategiesWithBroadcast(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1319 Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1320 std::vector<StrategyPtr> *const sp_vector) {
1321 if (sp_vector == nullptr) {
1322 MS_LOG(ERROR) << "The sp_vector is null.";
1323 return FAILED;
1324 }
1325
1326 if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1327 MS_LOG(ERROR) << "The inputs' size is wrong.";
1328 return FAILED;
1329 }
1330
1331 if (inputs_shape[0] == inputs_shape[1]) {
1332 // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy
1333 if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1334 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1335 return FAILED;
1336 }
1337 MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success.";
1338 } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) {
1339 // ([a, b, c, d], []) or ([], [a, b, c, d])
1340 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1341 MS_LOG(ERROR) << "Generate strategies for scalar case failed.";
1342 return FAILED;
1343 }
1344 MS_LOG(INFO) << "Generate strategies for scalar case success.";
1345 } else if (inputs_shape[0].size() > inputs_shape[1].size()) {
1346 // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1347 if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1348 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed.";
1349 return FAILED;
1350 }
1351 MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success.";
1352 } else if (inputs_shape[0].size() < inputs_shape[1].size()) {
1353 // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1354 if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1355 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed.";
1356 return FAILED;
1357 }
1358 MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success.";
1359 } else { // same size, but different value
1360 // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
1361 if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
1362 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed.";
1363 return FAILED;
1364 }
1365 MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success.";
1366 }
1367 return SUCCESS;
1368 }
1369
Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
  // Initializes the operator under `strategy`, evaluates its computation and
  // communication costs, and appends a StrategyWithCost entry to strategy_cost_.
  if (InitForCostModel(strategy) == FAILED) {
    // In auto-parallel mode an infeasible candidate strategy is expected and
    // only worth a DEBUG line; otherwise it is a real error.
    if (is_auto_parallel_) {
      MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed.";
    } else {
      MS_LOG(ERROR) << name_ << ": Initialization under the strategy failed.";
    }
    return FAILED;
  }
  int64_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  // Partial-parameter communication: forward-only communication plus a
  // gamma-weighted share of the parameter-related communication.
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);

  // Breaking ties for preferring data parallelization
  BreakingTiesForPerferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);
  result->communication_forward_ = result->communication_without_parameter_;

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  strategy_cost_.emplace_back(swc);

  return SUCCESS;
}
1403
1404 // Keep at most (1.0 / epsilon) number of available strategies for each operator.
ApproximateStrategies()1405 void OperatorInfo::ApproximateStrategies() {
1406 auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
1407 if (!enable_approxi) {
1408 return;
1409 }
1410 MS_LOG(INFO) << "Approximating strategy-cost for: " << name_;
1411 auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
1412 auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
1413 if (strategy_cost_.size() <= target_num) {
1414 MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
1415 << ", no greater than target-num: " << target_num;
1416 return;
1417 }
1418 std::vector<std::shared_ptr<StrategyWithCost>> ret;
1419 auto &origin_stra_cost = strategy_cost_;
1420 auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
1421 auto beta = CostModelContext::GetInstance()->costmodel_beta();
1422 // sort
1423 std::sort(
1424 origin_stra_cost.begin(), origin_stra_cost.end(),
1425 [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
1426 if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
1427 alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
1428 return true;
1429 }
1430 return false;
1431 });
1432 size_t step_length = origin_stra_cost.size() / target_num;
1433 for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
1434 ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
1435 }
1436
1437 strategy_cost_ = ret;
1438 is_strategy_cost_exact_ = false;
1439 }
1440
ExactStrategiesAndRelatedEdges()1441 void OperatorInfo::ExactStrategiesAndRelatedEdges() {
1442 if (is_strategy_cost_exact()) {
1443 return;
1444 }
1445 ClearStrategyCost();
1446 if (GenerateStrategies(0) != SUCCESS) {
1447 MS_LOG(EXCEPTION) << "Strategy search for Operator " << name() << " failed.";
1448 return;
1449 }
1450 SetIsStrategyCostExactTrue();
1451 // re-init the previous edges
1452 for (auto &prev_edge : prev_edges()) {
1453 if (prev_edge->InitEdgeCost() != SUCCESS) {
1454 MS_LOG(EXCEPTION) << "Edge: " << prev_edge->edge_name() << " cost init failed.";
1455 }
1456 }
1457 // re-init the successive edges
1458 for (auto &next_edge : succ_edges()) {
1459 if (next_edge->InitEdgeCost() != SUCCESS) {
1460 MS_LOG(EXCEPTION) << "Edge: " << next_edge->edge_name() << " cost init failed.";
1461 }
1462 }
1463 }
1464
int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
  // Recursively determines whether this operator's output depends on any
  // trainable parameter: returns 1 (involved) or 0 (not involved), memoized
  // in is_output_parameter_involve_ (-1 means "not computed yet").
  if (is_output_parameter_involve_ != -1) {
    return is_output_parameter_involve_;
  }
  // Start from the direct parameter flags, then fold in the predecessors.
  is_parameter_involve_ = is_parameter_;
  const auto &prev_edges = this->GetAlivePrevEdges();
  for (auto &p_edge : prev_edges) {
    // An input is parameter-involved iff the producing operator's output is.
    auto input_index = p_edge->next_op_input_index();
    auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
    if (input_index >= is_parameter_involve_.size()) {
      MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
                        << ", but got wrong input_index: " << input_index;
    }
    if (prev_op_para == 0) {
      is_parameter_involve_[input_index] = false;
    } else if (prev_op_para == 1) {
      is_parameter_involve_[input_index] = true;
    } else {
      MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
    }
    p_edge->set_parameter_involve(prev_op_para);
  }
  if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
    // If anyone of the input is a parameter_involved, the output is parameter_involved.
    is_output_parameter_involve_ = 1;
  } else {
    is_output_parameter_involve_ = 0;
  }
  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
  // Calculating 'output_in_memory'
  operator_cost()->CalculateOutputInMemory();
  // Calculating 'inputs_in_memory': for each input, record whether the
  // producing operator keeps its output resident in memory.
  std::map<size_t, bool> input_in_memory;
  for (auto &p_edge : prev_edges) {
    auto input_index = p_edge->next_op_input_index();
    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
    input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
  }
  operator_cost()->CalculateInputsInMemory(input_in_memory);

  return is_output_parameter_involve_;
}
1510
set_is_parameter(const std::vector<bool> & is_parameter)1511 Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
1512 if (is_parameter.size() != inputs_shape_.size()) {
1513 MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size()
1514 << " do not have the same number of inputs_shape_: " << inputs_shape_.size();
1515 return FAILED;
1516 }
1517 is_parameter_ = is_parameter;
1518 operator_cost()->set_is_parameter(is_parameter);
1519 return SUCCESS;
1520 }
1521
CalculateMemoryCost()1522 Status OperatorInfo::CalculateMemoryCost() {
1523 if (is_parameter_involve_.size() != is_parameter_.size()) {
1524 MS_LOG(ERROR) << "'is_parameter_' does not have the same number of input size of 'is_parameter_involve_'.";
1525 return FAILED;
1526 }
1527 // Set the memory cost in the 'strategy_cost_'
1528 for (auto &swc : strategy_cost_) {
1529 auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
1530 swc->cost_list[0]->memory_with_reuse_ = mem_cost;
1531 }
1532 return SUCCESS;
1533 }
1534
CalculateMemoryCostForInference()1535 Status OperatorInfo::CalculateMemoryCostForInference() {
1536 // First, set the 'is_outputs_critical_' flag into OperatorCost.
1537 if (is_output_critical_ == -1) {
1538 MS_LOG(EXCEPTION) << "The critical flag is not set.";
1539 return FAILED;
1540 }
1541 operator_cost()->set_output_critical(is_output_critical_);
1542 // Set the memory cost in the 'strategy_cost_'
1543 for (auto &swc : strategy_cost_) {
1544 auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
1545 swc->cost_list[0]->memory_with_reuse_ = mem_cost;
1546 }
1547 return SUCCESS;
1548 }
1549
CorrectMemoryCost(size_t input_index)1550 Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
1551 for (auto &swc : strategy_cost_) {
1552 double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
1553 static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
1554 swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
1555 if (swc->cost_list[0]->memory_with_reuse_ < 0) {
1556 MS_LOG(ERROR) << "The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
1557 << ", the parameter memory cost is: " << parameter_mem_cost;
1558 return FAILED;
1559 }
1560 }
1561 return SUCCESS;
1562 }
1563
ComputeRepeatDeviceNumByTensorMap(const Shape & dev_matrix_shape,const Shape & tensor_map)1564 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) {
1565 int64_t ret = -1;
1566
1567 // The number of repetitions is equal to the number of all devices divided by the number of devices use for
1568 // tensor map.
1569 int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int64_t>());
1570 for (auto &element : tensor_map) {
1571 // -1 means the corresponding dimension is not split.
1572 if (element == MAP_NONE) {
1573 continue;
1574 } else if ((element < 0) || (LongToSize(element) >= dev_matrix_shape.size())) {
1575 MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is "
1576 << ShapeToString(dev_matrix_shape);
1577 return ret;
1578 } else {
1579 size_t index = dev_matrix_shape.size() - LongToSize(element) - 1;
1580 if (dev_matrix_shape[index] <= 0) {
1581 MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape);
1582 return ret;
1583 }
1584 device_num /= dev_matrix_shape[index];
1585 }
1586 }
1587
1588 return device_num;
1589 }
1590
InferAsLossDivisor()1591 Status OperatorInfo::InferAsLossDivisor() {
1592 if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
1593 as_loss_divisor_ = 1;
1594 return SUCCESS;
1595 }
1596
1597 if (outputs_tensor_map_.empty()) {
1598 MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
1599 return FAILED;
1600 }
1601
1602 if (outputs_tensor_map_.size() > 1) {
1603 MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size()
1604 << ", need to override this function ";
1605 return FAILED;
1606 }
1607
1608 if (outputs_tensor_map_[0].empty()) {
1609 as_loss_divisor_ = stage_device_size_;
1610 MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
1611 return SUCCESS;
1612 }
1613
1614 if (out_dev_matrix_shape_.empty()) {
1615 out_dev_matrix_shape_ = dev_matrix_shape_;
1616 }
1617 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
1618 MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
1619 << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
1620 << as_loss_divisor_;
1621 return SUCCESS;
1622 }
1623
1624 // If the operator is used as a loss, a div node is inserted for the grad of all its inputs.
InferVirtualDivOps()1625 Status OperatorInfo::InferVirtualDivOps() {
1626 if (InferAsLossDivisor() != SUCCESS) {
1627 MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
1628 return FAILED;
1629 }
1630
1631 if (as_loss_divisor_ <= 0) {
1632 MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
1633 return FAILED;
1634 } else if (as_loss_divisor_ == 1) {
1635 MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
1636 return SUCCESS;
1637 }
1638
1639 virtual_div_op_.clear();
1640 // if loss is repeated calculation, insert div op
1641 Operator op = CreateVirtualDivOp(as_loss_divisor_);
1642 virtual_div_op_.push_back(op);
1643 return SUCCESS;
1644 }
1645
// Records the byte length of each input/output element type, both locally and in the
// cost model. Fails (with an error log) if the counts do not match the operator's
// input/output shape counts.
Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                                const std::vector<size_t> &output_lengths) {
  // One type length is required per input shape.
  if (input_lengths.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size()
                  << " do not have the same number of inputs shape: " << inputs_shape_.size();
    return FAILED;
  }
  // One type length is required per output shape.
  if (output_lengths.size() != outputs_shape_.size()) {
    MS_LOG(ERROR) << "Output_lengths: " << output_lengths.size()
                  << " do not have the same number of outputs shape: " << outputs_shape_.size();
    return FAILED;
  }
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
  // Keep the cost model in sync with the recorded lengths.
  operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
  return SUCCESS;
}
1663
GetOutputsTotalSize()1664 double OperatorInfo::GetOutputsTotalSize() {
1665 if (is_calculated_outputs_size_) {
1666 return outputs_total_size_;
1667 }
1668 if (outputs_type_lengths_.size() != outputs_shape_.size()) {
1669 MS_LOG(EXCEPTION) << "Output_lengths: " << outputs_type_lengths_.size()
1670 << " do not have the same number of outputs shape: " << outputs_shape_.size();
1671 }
1672 double sum = 0.0;
1673 for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
1674 auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
1675 std::multiplies<double>());
1676 sum += size * static_cast<double>(outputs_type_lengths_[i]);
1677 }
1678 is_calculated_outputs_size_ = true;
1679 outputs_total_size_ = sum;
1680 return outputs_total_size_;
1681 }
1682
// Records the element type of each output. Fails if the number of types does not
// match the number of output shapes.
Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
  if (outputs_type.size() != outputs_shape_.size()) {
    MS_LOG(ERROR) << "Outputs type: " << outputs_type.size()
                  << " do not have the same number of outputs shape: " << outputs_shape_.size();
    return FAILED;
  }
  outputs_type_ = outputs_type;
  return SUCCESS;
}
1692
BreakingTiesForPerferringDataParallel(const StrategyPtr & stra,const CostPtr & cost)1693 void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) {
1694 if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
1695 if (stra->GetInputDim()[0][0] == stage_device_size_) {
1696 if (cost->computation_cost_ > 1.0) {
1697 cost->computation_cost_ -= 1.0;
1698 }
1699 if (cost->communication_cost_ > 1.0) {
1700 cost->communication_cost_ -= 1.0;
1701 }
1702 if (cost->communication_with_partial_para_ > 1.0) {
1703 cost->communication_with_partial_para_ -= 1.0;
1704 }
1705 if (cost->communication_without_parameter_ > 1.0) {
1706 cost->communication_without_parameter_ -= 1.0;
1707 }
1708 }
1709 }
1710 }
1711
SetSelectedStrategy(const StrategyPtr & s_strategy,size_t curr_depth)1712 void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth) {
1713 MS_EXCEPTION_IF_NULL(s_strategy);
1714 if ((selected_strategy_depth_ != -1) && (SizeToLong(curr_depth) > selected_strategy_depth_)) {
1715 MS_LOG(INFO) << name_ << " has already been set strategy.";
1716 return;
1717 }
1718 MS_LOG(INFO) << "Set strategy for: " << name_;
1719 PrintStrategy(s_strategy);
1720 selected_strategy_ = s_strategy;
1721 selected_strategy_depth_ = SizeToLong(curr_depth);
1722 }
1723
// Returns the CNode bound to this operator; raises if it has not been set.
CNodePtr OperatorInfo::cnode() {
  MS_EXCEPTION_IF_NULL(cnode_);
  return cnode_;
}
1728
// Returns the forward computation cost for the current tensor info, delegating to the
// cost model (stage id 0).
double OperatorInfo::GetForwardMemoryCostFromCNode() {
  return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
1732
// Logs a diagnostic when the determined strategy differs from the minimal-cost one;
// this is informational only and does not change the selection.
void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
  MS_EXCEPTION_IF_NULL(s_strategy);
  if (!s_strategy->IsEqual(selected_strategy_)) {
    MS_LOG(INFO) << name() << "'s strategy may cause suboptimal, the determined strategy:";
    PrintStrategy(selected_strategy_);
    MS_LOG(INFO) << "The minimal strategy:";
    PrintStrategy(s_strategy);
  }
}
1742
// Replaces the operator's strategy-cost list with 'stra_cost'.
void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
  strategy_cost_ = stra_cost;
}
1746
GenerateStrategies(int64_t stage_id)1747 Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
1748 if (InferAttrs() != SUCCESS) {
1749 MS_LOG(ERROR) << name_ << ": Infer attrs failed";
1750 return FAILED;
1751 }
1752
1753 std::vector<StrategyPtr> sp_vector = GenerateOpStrategies(stage_id);
1754
1755 size_t success = 0;
1756 for (auto &sp : sp_vector) {
1757 PrintStrategy(sp);
1758 if (SetCostUnderStrategy(sp) == SUCCESS) {
1759 success++;
1760 MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
1761 PrintStrategy(sp);
1762 }
1763 }
1764 return SUCCESS;
1765 }
1766
GetIntAttr(const std::string & attr_name)1767 int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
1768 auto attr_iter = attrs_.find(attr_name);
1769 if (attr_iter == attrs_.end()) {
1770 MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
1771 }
1772
1773 MS_EXCEPTION_IF_NULL(attr_iter->second);
1774 if (!attr_iter->second->isa<Int64Imm>()) {
1775 MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
1776 }
1777
1778 return attr_iter->second->cast<Int64ImmPtr>()->value();
1779 }
1780
GetBoolAttr(const std::string & attr_name)1781 bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
1782 auto attr_iter = attrs_.find(attr_name);
1783 if (attr_iter == attrs_.end()) {
1784 MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
1785 }
1786
1787 MS_EXCEPTION_IF_NULL(attr_iter->second);
1788 if (!attr_iter->second->isa<BoolImm>()) {
1789 MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
1790 }
1791
1792 return attr_iter->second->cast<BoolImmPtr>()->value();
1793 }
1794
GetStringAttr(const std::string & attr_name)1795 std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
1796 std::string string_attr;
1797 auto attr_iter = attrs_.find(attr_name);
1798 if (attr_iter == attrs_.end()) {
1799 MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
1800 }
1801
1802 MS_EXCEPTION_IF_NULL(attr_iter->second);
1803 if (!attr_iter->second->isa<StringImm>()) {
1804 MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
1805 }
1806
1807 string_attr = attr_iter->second->cast<StringImmPtr>()->value();
1808 return string_attr;
1809 }
1810
GetTupleIntAttr(const std::string & attr_name)1811 std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
1812 std::vector<int64_t> tuple_attr;
1813 auto tuple_attr_iter = attrs_.find(attr_name);
1814 if (tuple_attr_iter == attrs_.end()) {
1815 MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
1816 }
1817
1818 MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
1819 tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
1820
1821 return tuple_attr;
1822 }
1823
GetFloatAttr(const std::string & attr_name)1824 float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
1825 auto attr_iter = attrs_.find(attr_name);
1826 if (attr_iter == attrs_.end()) {
1827 MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
1828 }
1829
1830 MS_EXCEPTION_IF_NULL(attr_iter->second);
1831 if (!attr_iter->second->isa<FP32Imm>()) {
1832 MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
1833 }
1834
1835 return attr_iter->second->cast<FP32ImmPtr>()->value();
1836 }
1837
GetValueSequeue(const ValuePtr & sequeue)1838 std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
1839 MS_EXCEPTION_IF_NULL(sequeue);
1840 std::vector<ValuePtr> ret;
1841 if (!sequeue->isa<ValueTuple>() && !sequeue->isa<ValueList>()) {
1842 MS_LOG(ERROR) << "The arg is not value tuple or value list";
1843 return ret;
1844 }
1845
1846 if (sequeue->isa<ValueTuple>()) {
1847 auto val_tuple = sequeue->cast<ValueTuplePtr>();
1848 return val_tuple->value();
1849 }
1850 auto val = sequeue->cast<ValueListPtr>();
1851 return val->value();
1852 }
1853 } // namespace parallel
1854 } // namespace mindspore
1855