1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/operator_info.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27
28 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
29 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
30 #include "frontend/parallel/step_parallel_utils.h"
31 #include "frontend/parallel/graph_util/graph_utils.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33 #include "include/common/debug/anf_dump_utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "ir/tensor.h"
36 #include "ir/value.h"
37 #include "mindspore/core/ops/framework_ops.h"
38 #include "utils/log_adapter.h"
39
40 namespace mindspore {
41 namespace parallel {
42 namespace {
43 struct InStrategyValueRegister {
InStrategyValueRegistermindspore::parallel::__anonc3f973980111::InStrategyValueRegister44 InStrategyValueRegister() noexcept {
45 AnfDumpHandler::SetInStrategyValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
46 auto operator_info = node->user_data<parallel::OperatorInfo>();
47 if (operator_info == nullptr) {
48 return nullptr;
49 }
50
51 auto in_strategy = operator_info->strategy();
52 if (in_strategy == nullptr) {
53 return nullptr;
54 }
55
56 return MakeValue(in_strategy->GetInputDim());
57 });
58 }
59 } in_regist;
60
61 struct InStrategyStageValueRegister {
InStrategyStageValueRegistermindspore::parallel::__anonc3f973980111::InStrategyStageValueRegister62 InStrategyStageValueRegister() noexcept {
63 AnfDumpHandler::SetInStrategyStageValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
64 auto operator_info = node->user_data<parallel::OperatorInfo>();
65 if (operator_info == nullptr) {
66 return nullptr;
67 }
68
69 auto in_strategy = operator_info->strategy();
70 if (in_strategy == nullptr) {
71 return nullptr;
72 }
73
74 return MakeValue(in_strategy->GetInputStage());
75 });
76 }
77 } in_stage_regist;
78
79 struct OutStrategyValueRegister {
OutStrategyValueRegistermindspore::parallel::__anonc3f973980111::OutStrategyValueRegister80 OutStrategyValueRegister() noexcept {
81 AnfDumpHandler::SetOutStrategyValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
82 auto operator_info = node->user_data<parallel::OperatorInfo>();
83 if (operator_info == nullptr) {
84 return nullptr;
85 }
86
87 auto in_strategy = operator_info->out_strategy();
88 if (in_strategy == nullptr) {
89 return nullptr;
90 }
91
92 return MakeValue(in_strategy->GetInputDim());
93 });
94 }
95 } out_regist;
96 } // namespace
97
StrategyToString(const Strategies & strategy)98 std::string StrategyToString(const Strategies &strategy) {
99 std::string strategy_str = "";
100 strategy_str += "(";
101 for (size_t i = 0; i < strategy.size(); ++i) {
102 strategy_str += "(";
103 for (size_t j = 0; j < strategy[i].size(); ++j) {
104 strategy_str += std::to_string(strategy[i][j]);
105 if (j != strategy[i].size() - 1) {
106 strategy_str += ", ";
107 }
108 }
109 strategy_str += ")";
110 if (i != strategy.size() - 1) {
111 strategy_str += ", ";
112 }
113 }
114 if (strategy.size() == 1) {
115 strategy_str += ",";
116 }
117 strategy_str += ")";
118 return strategy_str;
119 }
120
CheckOutputStrategy(const StrategyPtr & out_strategy)121 Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
122 if (out_strategy) {
123 MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
124 return FAILED;
125 }
126 return SUCCESS;
127 }
128
CheckStrategyByVector(const Shapes & stra,const Shapes & inputs_shape)129 Status OperatorInfo::CheckStrategyByVector(const Shapes &stra, const Shapes &inputs_shape) {
130 size_t strategy_size = stra.size();
131 size_t inputs_shape_size = inputs_shape.size();
132 if (strategy_size != inputs_shape_size) {
133 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
134 << " is not equal to inputs size: " << inputs_shape_size;
135 return FAILED;
136 }
137
138 for (size_t i = 0; i < strategy_size; ++i) {
139 Shape sub_strategy = stra.at(i);
140 Shape sub_input_shape = inputs_shape.at(i);
141 size_t strategy_len = sub_strategy.size();
142 size_t inputs_len = sub_input_shape.size();
143 MS_LOG(DEBUG) << "Compare: sub_input_shape:" << sub_input_shape << " sub_strategy: " << sub_strategy;
144 if (strategy_len != inputs_len) {
145 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
146 << " is not equal to inputs len: " << inputs_len << ", index: " << i;
147 return FAILED;
148 }
149
150 for (size_t j = 0; j < strategy_len; ++j) {
151 int64_t strategy_value = sub_strategy.at(j);
152 if (strategy_value < MIN_SLICE_NUM) {
153 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
154 << ", the value of strategy must be larger than 0, but get " << strategy_value;
155 return FAILED;
156 }
157
158 int64_t shape_value = sub_input_shape.at(j);
159 if (shape_value != -1 && (shape_value % strategy_value) != 0) {
160 if (dynamic_shape_flag_) {
161 Shapes origin_shapes = inputs_shape_clone_;
162 if (strategy_ != nullptr) { // if strategy_ is not null, means that check output strategy
163 origin_shapes = outputs_shape_clone_;
164 }
165 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape or divisor "
166 << shape_value << " at " << j << " cannot be divisible by strategy value " << strategy_value
167 << ", shape is " << ShapeToString(origin_shapes[i]) << ", divisor is "
168 << ShapeToString(sub_input_shape);
169 } else {
170 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
171 << " at " << j << " cannot be divisible by strategy value " << strategy_value << ", shape is "
172 << ShapeToString(sub_input_shape);
173 }
174 return FAILED;
175 }
176
177 if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
178 if ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0) {
179 MS_LOG(INFO) << name_
180 << ": The device num is not the power of 2, thus do not check the strategy as power of 2";
181 continue;
182 }
183 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
184 << ", the value of strategy must be the power of 2, but get " << strategy_value;
185 return FAILED;
186 }
187 }
188 }
189
190 return SUCCESS;
191 }
192
CheckStrategyValue(const StrategyPtr & strategy,const Shapes & inputs_shape)193 Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
194 if (strategy == nullptr) {
195 MS_LOG(ERROR) << name_ << ": The strategy is null.";
196 return FAILED;
197 }
198
199 Strategies stra = strategy->GetInputDim();
200 return CheckStrategyByVector(stra, inputs_shape);
201 }
202
ResetQueueMember()203 void OperatorInfo::ResetQueueMember() {
204 inputs_tensor_info_.clear();
205 outputs_tensor_info_.clear();
206 outputs_tensor_map_.clear();
207 out_dev_matrix_shape_.clear();
208 forward_op_.clear();
209 mirror_ops_.clear();
210 sub_ops_.clear();
211 replace_op_.clear();
212 replace_op_info_.clear();
213 virtual_div_op_.clear();
214 if (!is_layout_config_) {
215 inputs_tensor_map_.clear();
216 dev_matrix_shape_.clear();
217 }
218 strategy_ = nullptr;
219 out_strategy_ = nullptr;
220 }
221
CheckLayoutConfigBase()222 Status OperatorInfo::CheckLayoutConfigBase() {
223 // size
224 if (inputs_tensor_map_.size() != inputs_shape_.size()) {
225 MS_LOG(ERROR) << name_
226 << ": the size of inputs tensor map and inputs shape must be equal, but the inputs tensor map is "
227 << inputs_tensor_map_ << ", and the inputs shape is " << inputs_shape_;
228 return FAILED;
229 }
230
231 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
232 if (inputs_shape_[i].size() != inputs_tensor_map_[i].size()) {
233 MS_LOG(ERROR) << name_
234 << ": the size of input tensor map and input shape must be equal, but the inputs tensor map is "
235 << inputs_tensor_map_ << ", and the inputs shape is " << inputs_shape_ << ", the " << i
236 << "th is not equal";
237 return FAILED;
238 }
239 }
240
241 size_t dev_matrix_size = dev_matrix_shape_.size();
242 strategy_from_layout_.clear();
243
244 for (size_t j = 0; j < inputs_tensor_map_.size(); ++j) {
245 Shape tmp_strategy;
246 for (size_t k = 0; k < inputs_tensor_map_[j].size(); ++k) {
247 auto map = inputs_tensor_map_[j][k];
248
249 // range
250 if (map == MAP_NONE) {
251 tmp_strategy.push_back(NO_SPLIT_STRATEGY);
252 continue;
253 }
254
255 if (map < 0 || map >= SizeToLong(dev_matrix_size)) {
256 MS_LOG(ERROR) << name_ << ": the range of tensor_map's value is [-1, " << (dev_matrix_size - 1)
257 << "], but the inputs tensor map is " << inputs_tensor_map_;
258 return FAILED;
259 }
260
261 // divisible
262 auto shard_num = dev_matrix_shape_[dev_matrix_size - LongToSize(map) - 1];
263 MS_EXCEPTION_IF_ZERO("shard_num", shard_num);
264 if (inputs_shape_[j][k] % shard_num != 0) {
265 MS_LOG(ERROR) << name_ << ": the shape is not divisible by layout, the input shape is " << inputs_shape_
266 << ", the dev matrix is " << dev_matrix_shape_ << ", and the tensor map is "
267 << inputs_tensor_map_;
268 return FAILED;
269 }
270
271 // if shard_num is 1, reset the map to -1
272 if (shard_num == NO_SPLIT_STRATEGY) {
273 inputs_tensor_map_[j][k] = MAP_NONE;
274 }
275 tmp_strategy.push_back(shard_num);
276 }
277 strategy_from_layout_.push_back(tmp_strategy);
278 }
279
280 MS_LOG(INFO) << name_ << ": the strategy from layout is " << strategy_from_layout_;
281 return SUCCESS;
282 }
283
GetLayoutConfig()284 Status OperatorInfo::GetLayoutConfig() {
285 auto layout_iter = attrs_.find(LAYOUT);
286 if (layout_iter == attrs_.end()) {
287 return SUCCESS;
288 }
289
290 MS_EXCEPTION_IF_NULL(layout_iter->second);
291 auto layout = layout_iter->second;
292 if (!layout->isa<ValueDictionary>()) {
293 MS_LOG(ERROR) << name_ << ": the layout is not a dict";
294 return FAILED;
295 }
296
297 auto dict = layout->cast<ValueDictionaryPtr>();
298 for (const auto &kv : dict->value()) {
299 ValuePtr key_ptr = kv.first;
300 ValuePtr value_ptr = kv.second;
301 MS_EXCEPTION_IF_NULL(key_ptr);
302 MS_EXCEPTION_IF_NULL(value_ptr);
303 if (!key_ptr->isa<StringImm>()) {
304 MS_LOG(EXCEPTION) << name_ << ": the value of key is not string";
305 }
306
307 std::string key = key_ptr->cast<StringImmPtr>()->value();
308
309 if (key == DEV_MATRIX) {
310 if (!value_ptr->isa<ValueTuple>()) {
311 MS_LOG(ERROR) << name_ << ": the type of dev matrix is not tuple";
312 return FAILED;
313 }
314
315 dev_matrix_shape_ = GetValue<std::vector<int64_t>>(value_ptr);
316 auto used_devices =
317 std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
318 if (used_devices != stage_device_size_) {
319 MS_LOG(ERROR) << name_
320 << ": the product of dev matrix must be equal to the stage divece size, but the dev matrix is "
321 << dev_matrix_shape_ << ", but the stage device size is " << stage_device_size_;
322 return FAILED;
323 }
324 continue;
325 }
326
327 if (key == INPUT_TENSOR_MAP) {
328 auto var = value_ptr->cast<ValueTuplePtr>();
329 if (!value_ptr->isa<ValueTuple>()) {
330 MS_LOG(ERROR) << name_ << ": the type of input_tensor_map is not tuple";
331 return FAILED;
332 }
333
334 std::vector<ValuePtr> elements = var->value();
335 for (const auto &ele : elements) {
336 Shape tensor_map;
337 if (ele->isa<ValueSequence>()) {
338 auto value_tuple = ele->cast<ValueTuplePtr>();
339 std::vector<ValuePtr> value_vector = value_tuple->value();
340 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(tensor_map),
341 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
342 inputs_tensor_map_.push_back(tensor_map);
343 } else {
344 MS_LOG(ERROR) << name_ << ": the format of input tensor map is wrong! Need ValueSequence";
345 return FAILED;
346 }
347 }
348 continue;
349 }
350
351 MS_LOG(ERROR) << name_ << ": the invalid key for layout: " << key;
352 return FAILED;
353 }
354
355 MS_LOG(INFO) << name_ << ": the dev matrix is " << dev_matrix_shape_ << ", the tensor map is " << inputs_tensor_map_;
356
357 is_layout_config_ = true;
358
359 return CheckLayoutConfigBase();
360 }
361
IsDynamicShape()362 bool OperatorInfo::IsDynamicShape() {
363 for (auto &input_shape : inputs_shape_) {
364 auto in_it = std::find_if(input_shape.cbegin(), input_shape.cend(), [&](const int64_t ele) { return ele == -1; });
365 if (in_it != input_shape.end()) {
366 return True;
367 }
368 }
369
370 for (auto &output_shape : outputs_shape_) {
371 auto out_it =
372 std::find_if(output_shape.cbegin(), output_shape.cend(), [&](const int64_t ele) { return ele == -1; });
373 if (out_it != output_shape.end()) {
374 return True;
375 }
376 }
377 return False;
378 }
379
IsDynamicRank()380 bool OperatorInfo::IsDynamicRank() {
381 for (auto &input_shape : inputs_shape_) {
382 auto in_it = std::find_if(input_shape.cbegin(), input_shape.cend(), [&](const int64_t ele) { return ele == -2; });
383 if (in_it != input_shape.end()) {
384 return True;
385 }
386 }
387
388 for (auto &output_shape : outputs_shape_) {
389 auto out_it =
390 std::find_if(output_shape.cbegin(), output_shape.cend(), [&](const int64_t ele) { return ele == -2; });
391 if (out_it != output_shape.end()) {
392 return True;
393 }
394 }
395 return False;
396 }
397
IsSelfDefineShard()398 bool OperatorInfo::IsSelfDefineShard() {
399 bool self_define_shard = false;
400 auto attr_iter = attrs_.find(parallel::SELF_DEFINE_SHARD);
401 if (attr_iter != attrs_.end()) {
402 self_define_shard = attr_iter->second->cast<BoolImmPtr>()->value();
403 }
404 return self_define_shard;
405 }
406
GetRepeatedNumInDevMatrixRight()407 Status OperatorInfo::GetRepeatedNumInDevMatrixRight() {
408 bool repeated_num_right = true;
409 auto iter = attrs_.find(REPEATED_NUM_IN_DEV_MATRIX_RIGHT);
410 if (iter != attrs_.end()) {
411 MS_EXCEPTION_IF_NULL(iter->second);
412 if (iter->second->isa<BoolImm>()) {
413 repeated_num_right = iter->second->cast<BoolImmPtr>()->value();
414 MS_LOG(INFO) << name_ << ": attr " << REPEATED_NUM_IN_DEV_MATRIX_RIGHT << " will be set to "
415 << repeated_num_right;
416 } else {
417 MS_LOG(ERROR) << name_ << ": The value of " << REPEATED_NUM_IN_DEV_MATRIX_RIGHT << " is not bool.";
418 return FAILED;
419 }
420 }
421 repeated_num_in_dev_matrix_right_ = repeated_num_right;
422 return SUCCESS;
423 }
424
InferAttrs()425 Status OperatorInfo::InferAttrs() {
426 if (infer_attrs_completed_) {
427 return SUCCESS;
428 }
429
430 if (GetAttrs() != SUCCESS) {
431 return FAILED;
432 }
433
434 if (GetRepeatedNumInDevMatrixRight() != SUCCESS) {
435 return FAILED;
436 }
437
438 if (GetLayoutConfig() != SUCCESS) {
439 return FAILED;
440 }
441
442 if (is_layout_config_ && CheckLayoutConfig() != SUCCESS) {
443 return FAILED;
444 }
445
446 self_define_shard_ = IsSelfDefineShard();
447 is_dynamic_shape_ = IsDynamicShape();
448 is_dynamic_rank_ = IsDynamicRank();
449 if (is_dynamic_rank_) {
450 MS_LOG(ERROR) << name_
451 << ": it does not support dynamic rank now, the inupts' shape: " << ShapesToString(inputs_shape_)
452 << ", the outputs' shape: " << ShapesToString(outputs_shape_);
453 return FAILED;
454 }
455
456 inputs_shape_clone_ = inputs_shape_;
457 outputs_shape_clone_ = outputs_shape_;
458
459 infer_attrs_completed_ = true;
460 return SUCCESS;
461 }
462
InferMirrorOps()463 Status OperatorInfo::InferMirrorOps() {
464 mirror_ops_.clear();
465 if (inputs_shape_.empty()) {
466 MS_LOG(INFO) << name_ << ": The inputs size is empty";
467 return SUCCESS;
468 }
469
470 if (inputs_tensor_map_.size() != inputs_shape_.size()) {
471 MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
472 return FAILED;
473 }
474
475 bool group_is_empty = true;
476 for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
477 std::vector<Group> group;
478 if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
479 ReportError(name_ + ": Create group failed, the input index is " + std::to_string(i));
480 mirror_ops_.clear();
481 return FAILED;
482 }
483
484 OperatorVector mirror_op;
485 if (group.empty()) {
486 MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
487 mirror_ops_.push_back(mirror_op);
488 continue;
489 }
490
491 group_is_empty = false;
492 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
493 mirror_ops_.push_back(mirror_op);
494 }
495
496 if (group_is_empty) {
497 mirror_ops_.clear();
498 MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
499 }
500 return SUCCESS;
501 }
502
InferMirrorOpsByLayout()503 Status OperatorInfo::InferMirrorOpsByLayout() {
504 mirror_ops_.clear();
505 if (inputs_shape_.empty()) {
506 MS_LOG(INFO) << name_ << ": The inputs size is empty";
507 return SUCCESS;
508 }
509
510 bool group_is_empty = true;
511 for (size_t i = 0; i < inputs_tensor_info_.size(); ++i) {
512 auto input_tensor_layout = inputs_tensor_info_[i].tensor_layout();
513 auto repeated_rank_list = input_tensor_layout.InferRepeatedGroup();
514
515 OperatorVector mirror_op;
516 if (repeated_rank_list.size() == 1) {
517 MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
518 mirror_ops_.push_back(mirror_op);
519 continue;
520 }
521 if (is_auto_parallel_) {
522 if (g_device_manager->CheckDeviceList(repeated_rank_list) != SUCCESS) {
523 MS_LOG(INFO) << name_ << ": Try to create communication group : " << repeated_rank_list
524 << " failed in auto parallel mode, "
525 "this error can be ignored in parallel strategies searching step";
526 return FAILED;
527 }
528 return SUCCESS;
529 }
530
531 Group mirror_group;
532 if (g_device_manager->CreateGroup(repeated_rank_list, &mirror_group) != SUCCESS) {
533 MS_LOG(ERROR) << name_
534 << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
535 << ", the full_name of node is: " << cnode_->fullname_with_scope();
536 return FAILED;
537 }
538 group_is_empty = false;
539 mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
540 mirror_ops_.push_back(mirror_op);
541 }
542
543 if (group_is_empty) {
544 mirror_ops_.clear();
545 MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
546 }
547 return SUCCESS;
548 }
549
CreateTensorInfo(const Shape & device_matrix,const ShapeBasePtr & inputs_shape,const ShapeBasePtr & inputs_tensor_map)550 TensorInfoBasePtr CreateTensorInfo(const Shape &device_matrix, const ShapeBasePtr &inputs_shape,
551 const ShapeBasePtr &inputs_tensor_map) {
552 TensorInfoBasePtr out_tensor_info;
553 if (inputs_shape->is_list()) {
554 std::vector<TensorInfoBasePtr> tensor_info_list;
555 for (int64_t i = 0; i < SizeToLong(inputs_shape->size()); ++i) {
556 auto tensor_map = inputs_tensor_map->GetElement(i);
557 auto shape = inputs_shape->GetElement(i);
558 auto input_tensor_info = CreateTensorInfo(device_matrix, shape, tensor_map);
559 tensor_info_list.emplace_back(input_tensor_info);
560 }
561 out_tensor_info = std::make_shared<TensorInfoList>(tensor_info_list);
562 } else {
563 TensorLayout input_layout;
564 input_layout.InitFromVector(device_matrix, inputs_tensor_map->GetValue(), inputs_shape->GetValue());
565 TensorInfo input_tensor_info(input_layout);
566 out_tensor_info = std::make_shared<TensorInfoValue>(input_tensor_info);
567 }
568 return out_tensor_info;
569 }
570
InferTensorInfoNew()571 Status OperatorInfo::InferTensorInfoNew() {
572 size_t real_input_index = 0;
573 for (size_t i = 0; i < inputs_tensor_map_new_.size(); ++i) {
574 // Insert placeholder TensorInfo for optional input
575 while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
576 input_value_[real_input_index]->isa<None>()) {
577 (void)inputs_tensor_info_new_.emplace_back(std::make_shared<TensorInfoValue>(TensorInfo()));
578 ++real_input_index;
579 }
580 auto input_tensor_info = CreateTensorInfo(dev_matrix_shape_, inputs_shape_new_[i], inputs_tensor_map_new_[i]);
581 inputs_tensor_info_new_.emplace_back(input_tensor_info);
582 ++real_input_index;
583 }
584
585 for (size_t i = 0; i < outputs_tensor_map_new_.size(); ++i) {
586 auto output_tensor_info = CreateTensorInfo(dev_matrix_shape_, outputs_shape_new_[i], outputs_tensor_map_new_[i]);
587 outputs_tensor_info_new_.emplace_back(output_tensor_info);
588 }
589 return SUCCESS;
590 }
591
UpdateOutputTensorInfoForInterleaved()592 void OperatorInfo::UpdateOutputTensorInfoForInterleaved() {
593 if (!std::any_of(inputs_tensor_info_.begin(), inputs_tensor_info_.end(), [](const TensorInfo &input_tensor_info) {
594 return input_tensor_info.tensor_layout().IsInterleavedParallel();
595 })) {
596 return;
597 }
598 if (std::any_of(outputs_tensor_info_.begin(), outputs_tensor_info_.end(), [](const TensorInfo &output_tensor_info) {
599 return output_tensor_info.tensor_layout().IsInterleavedParallel();
600 })) {
601 return;
602 }
603 auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
604 auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
605 output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
606 Arrangement out_device_arrangement_interleaved;
607 out_device_arrangement_interleaved.Init(output_dev_matrix);
608 auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
609 new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
610 TensorInfo new_output_tensor_info(new_tensor_layout);
611 outputs_tensor_info_[kIndex0] = new_output_tensor_info;
612 }
613
InferTensorInfo()614 Status OperatorInfo::InferTensorInfo() {
615 if (!inputs_shape_new_.empty()) {
616 return InferTensorInfoNew();
617 }
618 if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
619 MS_LOG(ERROR) << name_ << ": Invalid args";
620 return FAILED;
621 }
622
623 size_t real_input_index = 0;
624 for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
625 // Insert placeholder TensorInfo for optional input
626 while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
627 input_value_[real_input_index]->isa<None>()) {
628 (void)inputs_tensor_info_.emplace_back(TensorInfo());
629 ++real_input_index;
630 }
631 TensorLayout input_layout;
632 if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
633 MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
634 return FAILED;
635 }
636 TensorInfo input_tensor_info(input_layout);
637 inputs_tensor_info_.push_back(input_tensor_info);
638 ++real_input_index;
639 }
640
641 for (size_t i = 0; i < outputs_tensor_map_.size(); ++i) {
642 TensorLayout output_layout;
643 if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
644 MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed, the index is " << i;
645 return FAILED;
646 }
647 TensorInfo output_tensor_info(output_layout);
648 outputs_tensor_info_.push_back(output_tensor_info);
649 }
650
651 return SUCCESS;
652 }
653
InferRepeatedCalcInfo()654 Status OperatorInfo::InferRepeatedCalcInfo() {
655 int64_t g_dev_list_size = stage_device_size_;
656 int64_t dev_matrix_size =
657 std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
658 if (dev_matrix_size == 0) {
659 MS_LOG(ERROR) << name_ << ": The dev matrix size is 0";
660 return FAILED;
661 }
662
663 if (g_dev_list_size == dev_matrix_size) {
664 repeated_calc_num_ = 1;
665 } else if (g_dev_list_size % dev_matrix_size == 0) {
666 repeated_calc_num_ = g_dev_list_size / dev_matrix_size;
667 } else {
668 MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategy_->GetInputDim()) << ", it requires "
669 << dev_matrix_size << " devices, "
670 << "but the device number of this stage is " << g_dev_list_size << ", it can not be divisible by "
671 << dev_matrix_size;
672 return FAILED;
673 }
674 return SUCCESS;
675 }
676
677 // If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
678 // because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
679 // to the last dimension of dev-matrix, there is no need to redistribution.
SetRepeatedCalcDevMatrix()680 void OperatorInfo::SetRepeatedCalcDevMatrix() {
681 if (repeated_calc_num_ <= 1) {
682 return;
683 }
684 if (repeated_num_in_dev_matrix_right_) {
685 dev_matrix_shape_.push_back(repeated_calc_num_);
686 } else {
687 (void)dev_matrix_shape_.insert(dev_matrix_shape_.cbegin(), repeated_calc_num_);
688 }
689 }
690
ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps * tensor_map_new)691 void OperatorInfo::ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps *tensor_map_new) {
692 MS_EXCEPTION_IF_NULL(tensor_map_new);
693 for (auto &tensor_map : *tensor_map_new) {
694 if (tensor_map->is_list()) {
695 std::vector<ShapeBasePtr> new_list;
696 for (auto &elements : tensor_map->GetAllElements()) {
697 std::vector<int64_t> new_shape;
698 for (auto &element : elements) {
699 if (element != MAP_NONE) {
700 element += 1;
701 }
702 new_shape.emplace_back(element);
703 }
704 new_list.emplace_back(std::make_shared<ShapeValue>(new_shape));
705 }
706 tensor_map->set_shape(std::make_shared<ShapeList>(new_list));
707 } else {
708 std::vector<int64_t> new_shape;
709 for (auto &element : tensor_map->GetValue()) {
710 if (element != MAP_NONE) {
711 element += 1;
712 }
713 new_shape.emplace_back(element);
714 }
715 tensor_map->set_shape(std::make_shared<ShapeValue>(new_shape));
716 }
717 }
718 }
719
720 // If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
721 // the index value of tensor map needs to be increased by 1.
ResetTensorMapIfRepeatedCalc()722 void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
723 if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
724 return;
725 }
726
727 MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
728 for (auto &tensor_map : inputs_tensor_map_) {
729 for (auto &element : tensor_map) {
730 if (element == MAP_NONE) {
731 continue;
732 }
733 element += 1;
734 }
735 }
736
737 for (auto &tensor_map : outputs_tensor_map_) {
738 for (auto &element : tensor_map) {
739 if (element == MAP_NONE) {
740 continue;
741 }
742 element += 1;
743 }
744 }
745
746 ResetTupleTensorMapIfRepeatedCalc(&inputs_tensor_map_new_);
747 ResetTupleTensorMapIfRepeatedCalc(&outputs_tensor_map_new_);
748 }
749
750 // use for loss repeated calculation
CreateVirtualDivOp(int64_t div_num)751 Operator CreateVirtualDivOp(int64_t div_num) {
752 OperatorName operator_name = VIRTUAL_DIV;
753 ValuePtr attr0_value = MakeValue(div_num);
754 Attr attr0 = std::make_pair(DIVISOR, attr0_value);
755 OperatorAttrs operator_attrs;
756 operator_attrs.push_back(attr0);
757
758 OperatorParams operator_param;
759 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
760
761 Operator op = std::make_pair(operator_name, operator_arg);
762 return op;
763 }
764
CreateDivOp(float scale)765 Operator CreateDivOp(float scale) {
766 OperatorName operator_name = REAL_DIV;
767 OperatorAttrs operator_attrs;
768 OperatorParams operator_param;
769 constexpr size_t parameter_pos = 2;
770 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scale);
771 ValuePtr scale_value = MakeValue(tensor_ptr);
772 (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
773 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
774
775 Operator op = std::make_pair(operator_name, operator_arg);
776 return op;
777 }
778
CreateScalarFloorDivOp(int64_t div_num)779 Operator CreateScalarFloorDivOp(int64_t div_num) {
780 OperatorName operator_name = SCALAR_FLOOR_DIV;
781 OperatorAttrs operator_attrs;
782 OperatorParams operator_param;
783 constexpr size_t parameter_pos = 2;
784 ValuePtr scale_value = MakeValue(div_num);
785 (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
786 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
787
788 Operator op = std::make_pair(operator_name, operator_arg);
789 return op;
790 }
791
CreateScalarDivOp(int64_t div_num)792 Operator CreateScalarDivOp(int64_t div_num) {
793 OperatorName operator_name = SCALAR_DIV;
794 OperatorAttrs operator_attrs;
795 OperatorParams operator_param;
796 constexpr size_t parameter_pos = 2;
797 ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(div_num));
798 (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
799 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
800
801 Operator op = std::make_pair(operator_name, operator_arg);
802 return op;
803 }
804
CreateScalarMulOp(int64_t scalar)805 Operator CreateScalarMulOp(int64_t scalar) {
806 OperatorName operator_name = SCALAR_MUL;
807 OperatorAttrs operator_attrs;
808 OperatorParams operator_param;
809 constexpr size_t parameter_pos = 2;
810 ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(scalar));
811 (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
812 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
813
814 Operator op = std::make_pair(operator_name, operator_arg);
815 return op;
816 }
817
CreateReduceCommunicationOpArgs(const std::string & reduce_op,const std::string & group)818 static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
819 ValuePtr attr0_value = MakeValue(reduce_op);
820 ValuePtr attr1_value = MakeValue(group);
821 Attr attr0 = std::make_pair(OP, attr0_value);
822 Attr attr1 = std::make_pair(GROUP, attr1_value);
823 OperatorAttrs operator_attrs;
824 operator_attrs.push_back(attr0);
825 operator_attrs.push_back(attr1);
826
827 OperatorParams operator_param;
828 return std::make_pair(operator_attrs, operator_param);
829 }
830
831 // use for forward all reduce
CreateAllReduceOp(const std::string & reduce_op,const std::string & group)832 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
833 OperatorName operator_name = ALL_REDUCE;
834 OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
835
836 Operator op = std::make_pair(operator_name, operator_arg);
837 MS_LOG(INFO) << "Create all reduce op success, the reduce_op is " << reduce_op << ", the group is " << group;
838 return op;
839 }
840
CreateReduceScatterOp(const std::string & reduce_op,const std::string & group)841 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
842 OperatorName operator_name = REDUCE_SCATTER;
843 OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
844
845 Operator op = std::make_pair(operator_name, operator_arg);
846 MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is " << reduce_op << ", the group is " << group;
847 return op;
848 }
849
CreateCastOp(TypePtr type)850 Operator CreateCastOp(TypePtr type) {
851 auto type_id = MakeValue(static_cast<int64_t>(type->type_id()));
852 Param param_type = std::make_pair(std::make_pair(DTYPE, type_id), 2);
853 OperatorAttrs attrs;
854 OperatorParams params = {param_type};
855 OperatorArgs args = std::make_pair(attrs, params);
856 Operator op_cast = std::make_pair(CAST, args);
857 return op_cast;
858 }
859
CreateScalarCastOp(TypePtr type)860 Operator CreateScalarCastOp(TypePtr type) {
861 auto type_id = MakeValue(static_cast<int64_t>(type->type_id()));
862 Param param_type = std::make_pair(std::make_pair(DTYPE, type_id), 2);
863 OperatorAttrs attrs;
864 OperatorParams params = {param_type};
865 OperatorArgs args = std::make_pair(attrs, params);
866 Operator op_cast = std::make_pair(SCALAR_CAST, args);
867 return op_cast;
868 }
869
AddCommOpFusionType(const CNodePtr & comm_node,const AnfNodePtr & param_node)870 int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node) {
871 MS_EXCEPTION_IF_NULL(comm_node);
872 MS_EXCEPTION_IF_NULL(param_node);
873 ParameterPtr param;
874 if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
875 param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
876 } else {
877 param = param_node->cast<ParameterPtr>();
878 }
879 MS_EXCEPTION_IF_NULL(param);
880 auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
881 MS_EXCEPTION_IF_NULL(prim);
882 auto attrs = prim->attrs();
883 auto param_info = param->param_info();
884 if (!param_info) {
885 MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
886 return 0;
887 }
888 int32_t fusion_type = param_info->comm_fusion();
889 attrs[FUSION] = MakeValue<int64_t>(fusion_type);
890 (void)prim->SetAttrs(attrs);
891 bool parallel_optimizer_comm_recompute = param_info->parallel_optimizer_comm_recompute();
892 std::string instance_name = prim->instance_name();
893 if (instance_name == PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE && parallel_optimizer_comm_recompute &&
894 prim->name() == ALL_GATHER) {
895 prim->set_attr(RECOMPUTE, MakeValue(true));
896 prim->set_instance_name(PARALLEL_OPTIMIZER_ALLGATHER);
897 }
898 MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
899 return fusion_type;
900 }
901
AddCommOpMeanFlag(const CNodePtr & comm_node)902 void AddCommOpMeanFlag(const CNodePtr &comm_node) {
903 MS_EXCEPTION_IF_NULL(comm_node);
904 auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
905 auto attrs = prim->attrs();
906 MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
907 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
908 attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
909 (void)prim->SetAttrs(attrs);
910 }
911
AddCNodePrimAttr(const CNodePtr & comm_node,const std::string & attr_name,const ValuePtr & attr_val)912 void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val) {
913 MS_EXCEPTION_IF_NULL(comm_node);
914 auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
915 auto attrs = prim->attrs();
916 attrs[attr_name] = attr_val;
917 (void)prim->SetAttrs(attrs);
918 }
919
AddCommOpParamFlag(const CNodePtr & comm_node)920 void AddCommOpParamFlag(const CNodePtr &comm_node) {
921 MS_EXCEPTION_IF_NULL(comm_node);
922 auto graph = comm_node->func_graph();
923 MS_EXCEPTION_IF_NULL(graph);
924 auto manager = graph->manager();
925 MS_EXCEPTION_IF_NULL(manager);
926 auto node_users = manager->node_users()[comm_node->input(1)];
927 for (auto &node_user : node_users) {
928 if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
929 auto prim = GetCNodePrimitive(comm_node);
930 (void)prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
931 return;
932 }
933 }
934 }
935
CreateAllGatherOp(const std::string & group)936 Operator CreateAllGatherOp(const std::string &group) {
937 OperatorName operator_name = ALL_GATHER;
938 // group
939 ValuePtr attr0_value = MakeValue(group);
940 Attr attr0 = std::make_pair(GROUP, attr0_value);
941 OperatorAttrs operator_attrs;
942 operator_attrs.push_back(attr0);
943
944 OperatorParams operator_param;
945 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
946
947 Operator op = std::make_pair(operator_name, operator_arg);
948 MS_LOG(INFO) << "Create allgather op success, the group is " << group;
949 return op;
950 }
951
CreateMicroStepAllGatherOp(const std::string & group)952 Operator CreateMicroStepAllGatherOp(const std::string &group) {
953 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
954 OperatorName operator_name = MICRO_STEP_ALL_GATHER;
955 // group
956 ValuePtr attr0_value = MakeValue(group);
957 Attr attr0 = std::make_pair(GROUP, attr0_value);
958 // mean_flag
959 ValuePtr attr1_value = MakeValue(mean_flag);
960 Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value);
961 OperatorAttrs operator_attrs;
962 operator_attrs.push_back(attr0);
963 operator_attrs.push_back(attr1);
964
965 OperatorParams operator_param;
966 OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
967
968 Operator op = std::make_pair(operator_name, operator_arg);
969 MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
970 return op;
971 }
972
973 // use for get tensor slice
CreateGetTensorSliceOp(const TensorLayout & tensor_layout)974 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
975 Shape tensor_map = tensor_layout.tensor_map().array();
976 Shape dev_matrix_shape = tensor_layout.device_arrangement().array();
977 Shape slice_shape = tensor_layout.base_slice_shape().array();
978 Shape full_shape = tensor_layout.tensor_shape().array();
979 OperatorName operator_name = GET_TENSOR_SLICE;
980
981 OperatorAttrs attrs;
982 ValuePtr dev_mat_value = MakeValue(dev_matrix_shape);
983 Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2);
984 ValuePtr tensor_map_value = MakeValue(tensor_map);
985 Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3);
986 ValuePtr slice_shape_value = MakeValue(slice_shape);
987 Param slice_shape_param = std::make_pair(std::make_pair(SLICE_SHAPE, slice_shape_value), 4);
988 ValuePtr full_shape_value = MakeValue(full_shape);
989 Param full_shape_param = std::make_pair(std::make_pair(FULL_SHAPE, full_shape_value), 5);
990 OperatorParams params = {dev_mat_param, tensor_map_param, slice_shape_param, full_shape_param};
991 OperatorArgs operator_arg = std::make_pair(attrs, params);
992
993 Operator op = std::make_pair(operator_name, operator_arg);
994 MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is "
995 << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map);
996 return op;
997 }
998
CreateMirrorOps(const std::string & group_name,size_t dev_num)999 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
1000 if (dev_num == 0) {
1001 MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num;
1002 }
1003 OperatorVector op_for_weight;
1004 bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
1005 int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1006 int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1007
1008 ValuePtr attr0_value = MakeValue(group_name);
1009 ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
1010 ValuePtr attr2_value = MakeValue(mean_flag);
1011
1012 Attr attr0 = std::make_pair(GROUP, attr0_value);
1013 Attr attr1 = std::make_pair(DEV_NUM, attr1_value);
1014 Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
1015
1016 OperatorAttrs operator_attrs;
1017 operator_attrs.push_back(attr0);
1018 operator_attrs.push_back(attr1);
1019 operator_attrs.push_back(attr2);
1020
1021 OperatorName operator_name;
1022 if (grad_accumulation_step > 1 || split_stage_num > 1) {
1023 operator_name = MIRROR_MICRO_STEP_OPERATOR;
1024 } else {
1025 operator_name = MIRROR_OPERATOR;
1026 }
1027
1028 OperatorParams operator_param;
1029 OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
1030
1031 Operator op = std::make_pair(operator_name, operator_args);
1032
1033 op_for_weight.push_back(op);
1034 MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is "
1035 << mean_flag;
1036 return op_for_weight;
1037 }
1038
CreateGroupByTensorMap(const Shape & tensor_map,std::vector<Group> * group)1039 Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group) {
1040 if (group == nullptr) {
1041 MS_LOG(ERROR) << name_ << ": The group is null.";
1042 return FAILED;
1043 }
1044 CheckGlobalDeviceManager();
1045 int64_t rank = g_device_manager->global_rank();
1046 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
1047 RankList group_devices;
1048 if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
1049 return FAILED;
1050 }
1051
1052 if (group_devices.size() == 1 && !((ParallelContext::GetInstance()->grad_accumulation_step() > 1 ||
1053 ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) &&
1054 ParallelContext::GetInstance()->enable_parallel_optimizer())) {
1055 MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
1056 return SUCCESS;
1057 }
1058 if (is_auto_parallel_) {
1059 if (g_device_manager->CheckDeviceList(group_devices) != SUCCESS) {
1060 MS_LOG(INFO) << name_ << ": Try to create communication group : " << group_devices
1061 << " failed in auto parallel mode, "
1062 "this error can be ignored in parallel strategies searching step";
1063 return FAILED;
1064 }
1065 return SUCCESS;
1066 }
1067
1068 Group g;
1069 if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
1070 MS_LOG(ERROR) << name_ << ": Create communication group by tensor_map failed, the rank_list is: " << group_devices
1071 << ", the input strategy is " << strategy_->GetInputDim()
1072 << ", the full_name of node is: " << cnode_->fullname_with_scope();
1073 return FAILED;
1074 }
1075 group->push_back(g);
1076 return SUCCESS;
1077 }
1078
CreateGroupForOptShard(TensorLayout * tensor_layout,std::vector<Group> * groups)1079 Status OperatorInfo::CreateGroupForOptShard(TensorLayout *tensor_layout, std::vector<Group> *groups) {
1080 if (groups == nullptr) {
1081 MS_LOG(ERROR) << name_ << ": The group is null.";
1082 return FAILED;
1083 }
1084 CheckGlobalDeviceManager();
1085 int64_t rank = g_device_manager->global_rank();
1086 DeviceMatrix dev_matrix(rank, stage_device_list_, tensor_layout->device_arrangement_origin().array());
1087 RankList group_devices;
1088 Shape tensor_map = tensor_layout->origin_tensor_map().array();
1089 if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
1090 return FAILED;
1091 }
1092
1093 if (group_devices.size() == 1) {
1094 MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
1095 return SUCCESS;
1096 }
1097 int64_t repeated_size = SizeToLong(group_devices.size());
1098 int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
1099 MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size);
1100 if (optimizer_weight_shard_size != -1 && repeated_size > optimizer_weight_shard_size) {
1101 // not fully use opt shard
1102 int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
1103 if (repeated_size % optimizer_weight_shard_size != 0 || repeated_size < optimizer_weight_shard_size) {
1104 MS_LOG(WARNING) << "Parallel optimizer:"
1105 << " optimizer_weight_shard_size " << optimizer_weight_shard_size
1106 << " can not be applied for the parameter used by" << cnode_->fullname_with_scope()
1107 << " The data parallel group size is " << repeated_size;
1108 return FAILED;
1109 }
1110 repeated_size = repeated_size / optimizer_weight_shard_size;
1111 // create allgather group
1112 // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
1113 RankList new_group_devices(
1114 group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
1115 group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
1116 Group allgather_group;
1117 if (g_device_manager->CreateGroup(new_group_devices, &allgather_group) != SUCCESS) {
1118 MS_LOG(ERROR) << name_
1119 << ": Create communication group for allgather in optimizer parallel failed,"
1120 " the rank_list is: "
1121 << group_devices << ", the input strategy is " << strategy_->GetInputDim()
1122 << ", the full_name of node is: " << cnode_->fullname_with_scope();
1123 return FAILED;
1124 }
1125 groups->push_back(allgather_group);
1126 tensor_layout->set_opt_shard_group(allgather_group.name());
1127 MS_LOG(INFO) << name_ << ": Parallel optimizer, create allgather group " << allgather_group.name();
1128 // create mirror group
1129 // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
1130 int64_t device_num = g_device_manager->stage_device_num();
1131 MS_EXCEPTION_IF_ZERO("repeated_size", repeated_size);
1132 Shape dev_mat = {repeated_size, device_num / repeated_size};
1133 DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
1134 RankList mirror_group_devices;
1135 if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
1136 return FAILED;
1137 }
1138 Group mirror_group;
1139 if (g_device_manager->CreateGroup(mirror_group_devices, &mirror_group) != SUCCESS) {
1140 MS_LOG(ERROR) << name_
1141 << ": Create communication group for mirror in optimizer parallel failed,"
1142 " the rank_list is: "
1143 << group_devices << ", the input strategy is " << strategy_->GetInputDim()
1144 << ", the full_name of node is: " << cnode_->fullname_with_scope();
1145 return FAILED;
1146 }
1147 groups->push_back(mirror_group);
1148 tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
1149 MS_LOG(INFO) << name_ << ": Parallel optimizer, create mirror group " << mirror_group.name();
1150 } else {
1151 // fully use opt shard
1152 // create allgather group
1153 Group allgather_group;
1154 if (g_device_manager->CreateGroup(group_devices, &allgather_group) != SUCCESS) {
1155 MS_LOG(ERROR) << name_
1156 << ": Create communication group for allgather in optimizer parallel failed,"
1157 " the rank_list is: "
1158 << group_devices << ", the input strategy is " << strategy_->GetInputDim()
1159 << ", the full_name of node is: " << cnode_->fullname_with_scope();
1160 return FAILED;
1161 }
1162 groups->push_back(allgather_group);
1163 tensor_layout->set_opt_shard_group(allgather_group.name());
1164 MS_LOG(INFO) << name_ << ": Parallel optimizer, create allgather group " << allgather_group.name();
1165 }
1166 // save in tensor_layout for strategy ckpt
1167 auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
1168 if (!integrated_save) {
1169 tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
1170 if (optimizer_weight_shard_size > 0 && group_devices.size() < LongToSize(optimizer_weight_shard_size)) {
1171 tensor_layout->set_opt_weight_shard_size(SizeToInt(group_devices.size()));
1172 }
1173 MS_EXCEPTION_IF_ZERO("SizeToLong(group_devices.size()) - 1", SizeToLong(group_devices.size()) - 1);
1174 int64_t opt_weight_shard_step =
1175 (group_devices.back() - group_devices.front()) / (SizeToLong(group_devices.size()) - 1);
1176 tensor_layout->set_opt_weight_shard_step(LongToInt(opt_weight_shard_step));
1177 MS_LOG(INFO) << name_ << "Parallel optimizer, save opt_weight_shard_step " << opt_weight_shard_step
1178 << " in strategy ckpt";
1179 }
1180 return SUCCESS;
1181 }
1182
InsertDivOpToNodeInput(const CNodePtr & node,int64_t div_num,size_t index,const string & instance_name)1183 static void InsertDivOpToNodeInput(const CNodePtr &node, int64_t div_num, size_t index, const string &instance_name) {
1184 MS_EXCEPTION_IF_NULL(node);
1185 FuncGraphPtr func_graph = node->func_graph();
1186 MS_EXCEPTION_IF_NULL(func_graph);
1187 // instance the div operator
1188 Operator div_op = CreateScalarFloorDivOp(div_num);
1189
1190 // Insert it as the input of the node
1191 AnfNodePtr input = node->input(index);
1192 MS_EXCEPTION_IF_NULL(input);
1193 InsertNode(div_op, node, index, node->input(index), func_graph, instance_name);
1194 }
1195
ChangeMakeTupleConstant(const CNodePtr & cnode,size_t make_tuple_index)1196 void OperatorInfo::ChangeMakeTupleConstant(const CNodePtr &cnode, size_t make_tuple_index) {
1197 if (!IsPrimitiveCNode(cnode->input(make_tuple_index), prim::kPrimMakeTuple)) {
1198 MS_LOG(EXCEPTION) << name_ << ": the dst shape is not make tuple";
1199 }
1200 size_t input_dim = inputs_shape_[0].size();
1201 auto shard_size = strategy_->GetInputDim()[0];
1202 if (input_dim != shard_size.size()) {
1203 MS_LOG(EXCEPTION) << name_ << ": the input dim is " << input_dim << ", but the size of strategy is "
1204 << shard_size.size();
1205 }
1206
1207 auto make_tuple = cnode->input(make_tuple_index);
1208 auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
1209 for (size_t i = 0; i < input_dim; ++i) {
1210 if (shard_size[i] <= 1) {
1211 continue;
1212 }
1213 auto value_node = GetValueNode(make_tuple_cnode->input(i + 1));
1214 if (value_node == nullptr) {
1215 std::string instance_name = name_ + "div";
1216 InsertDivOpToNodeInput(make_tuple_cnode, shard_size[i], i + 1, instance_name);
1217 } else if (value_node->isa<Int64Imm>()) {
1218 MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
1219 auto origin_value = GetValue<int64_t>(value_node);
1220 if (origin_value < 0 || origin_value % shard_size[i] != 0) {
1221 MS_LOG(EXCEPTION) << name_ << ": the origin value is " << origin_value << ", can not be div by shard size "
1222 << shard_size[i] << ", the make tuple index of this op is " << make_tuple_index
1223 << ", the input index of make_tuple is " << i + 1;
1224 }
1225 int64_t replace_value = GetValue<int64_t>(value_node) / shard_size[i];
1226 auto replace_value_ptr = MakeValue(replace_value);
1227 auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
1228 auto manager = make_tuple->func_graph()->manager();
1229 manager->SetEdge(make_tuple, i + 1, replace_value_node);
1230 } else {
1231 MS_LOG(EXCEPTION) << name_ << ": the input of make_tuple is value node but not int64, the index is " << (i + 1);
1232 }
1233 }
1234 }
1235
CreateGroupByDim(size_t axis,std::vector<Group> * group)1236 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
1237 if (group == nullptr) {
1238 MS_LOG(ERROR) << name_ << ": The group is null.";
1239 return FAILED;
1240 }
1241 CheckGlobalDeviceManager();
1242 int64_t rank = g_device_manager->global_rank();
1243 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
1244
1245 return CreateGroupByDimWithDevMatrix(&dev_matrix, axis, group);
1246 }
1247
CreateGroupByDimWithDevMatrix(DeviceMatrix * dev_matrix,size_t axis,std::vector<Group> * group)1248 Status OperatorInfo::CreateGroupByDimWithDevMatrix(DeviceMatrix *dev_matrix, size_t axis, std::vector<Group> *group) {
1249 if (group == nullptr) {
1250 MS_LOG(ERROR) << name_ << ": The group is null.";
1251 return FAILED;
1252 }
1253 CheckGlobalDeviceManager();
1254 RankList group_devices;
1255 if (dev_matrix->GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
1256 return FAILED;
1257 }
1258
1259 if (group_devices.size() == 1) {
1260 MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
1261 return SUCCESS;
1262 }
1263 if (is_auto_parallel_) {
1264 if (g_device_manager->CheckDeviceList(group_devices) != SUCCESS) {
1265 MS_LOG(INFO) << name_ << ": Try to create communication group : " << group_devices
1266 << " failed in auto parallel mode, "
1267 << "this error can be ignored in parallel strategies searching step";
1268 return FAILED;
1269 }
1270 return SUCCESS;
1271 }
1272 Group g;
1273 if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
1274 MS_LOG(ERROR) << name_ << ": Create communication group by dim failed, the rank_list is: " << group_devices
1275 << ", the input strategy is " << strategy_->GetInputDim()
1276 << ", the full_name of node is: " << cnode_->fullname_with_scope();
1277 return FAILED;
1278 }
1279 MS_LOG(INFO) << name_ << ": Create communication group by dim " << axis
1280 << " success, the rank_list is: " << group_devices;
1281 group->push_back(g);
1282 return SUCCESS;
1283 }
1284
GetSliceShape(const Shape & tensor_shape,const Dimensions & strategy)1285 Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
1286 Shape slice_shape;
1287 if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) {
1288 MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0";
1289 return slice_shape;
1290 }
1291 for (size_t i = 0; i < strategy.size(); ++i) {
1292 slice_shape.push_back(tensor_shape.at(i) / strategy.at(i));
1293 }
1294 return slice_shape;
1295 }
1296
Init(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy,const std::vector<std::shared_ptr<TensorLayout>> & in_tensor_layouts,const std::vector<std::shared_ptr<TensorLayout>> & out_tensor_layouts)1297 Status OperatorInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
1298 const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
1299 const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
1300 if (!in_tensor_layouts.empty()) {
1301 return InitWithTensorLayout(in_tensor_layouts, out_tensor_layouts);
1302 }
1303
1304 if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1305 MS_LOG(ERROR) << name_ << " : Init failed.";
1306 return FAILED;
1307 }
1308
1309 MS_LOG(INFO) << name_ << " : Init success.";
1310 return SUCCESS;
1311 }
1312
InitForCostModel(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy)1313 Status OperatorInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
1314 std::vector<std::shared_ptr<TensorLayout>> in_tensor_layouts;
1315 std::vector<std::shared_ptr<TensorLayout>> out_tensor_layouts;
1316 Status status =
1317 ExtractUserConfigLayout(attrs_, inputs_shape_, outputs_shape_, &in_tensor_layouts, &out_tensor_layouts);
1318 if (status != SUCCESS) {
1319 MS_LOG(EXCEPTION) << "Failure:operator " << name_ << " extract configured layout failed.";
1320 }
1321 if (!in_tensor_layouts.empty()) {
1322 out_tensor_layouts = {};
1323 return InitWithTensorLayout(in_tensor_layouts, out_tensor_layouts);
1324 }
1325 if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1326 ReportError(name_ + " : Init for cost model failed.");
1327 return FAILED;
1328 }
1329
1330 MS_LOG(INFO) << name_ << " : Init for cost model success.";
1331 return SUCCESS;
1332 }
1333
DivisorsReplaceShapes()1334 void OperatorInfo::DivisorsReplaceShapes() {
1335 if (!dynamic_shape_flag_) {
1336 return;
1337 }
1338
1339 inputs_shape_ = inputs_divisor_;
1340 outputs_shape_ = outputs_divisor_;
1341 }
1342
ResumeShapes()1343 void OperatorInfo::ResumeShapes() {
1344 if (!dynamic_shape_flag_) {
1345 return;
1346 }
1347
1348 inputs_shape_ = inputs_shape_clone_;
1349 outputs_shape_ = outputs_shape_clone_;
1350 }
1351
DynamicShapeCheckStrategyLog()1352 void OperatorInfo::DynamicShapeCheckStrategyLog() {
1353 if (!dynamic_shape_flag_) {
1354 return;
1355 }
1356 MS_LOG(ERROR) << name_ << ": the origin shape of inputs is " << ShapesToString(inputs_shape_clone_)
1357 << ", but the divisor info of inputs is " << ShapesToString(inputs_divisor_);
1358 }
1359
1360 // auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1
InitForCostModelWithAutoRepeatCalc(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy)1361 Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy,
1362 const StrategyPtr &out_strategy) {
1363 if (!is_layout_config_ && in_strategy == nullptr) {
1364 MS_LOG(ERROR) << name_ << ": The strategy is null, the inputs shape is " << inputs_shape_;
1365 return FAILED;
1366 }
1367
1368 // need to clear queues before Init(),
1369 // because Init() may be called multiple times by cost model
1370 ResetQueueMember();
1371
1372 if (InferAttrs() != SUCCESS) {
1373 MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
1374 return FAILED;
1375 }
1376
1377 // if layout is configured, no need to check strategy and infer dev matrix
1378 if (!is_layout_config_) {
1379 DivisorsReplaceShapes(); // in dynamic shape, using divisors replace to shapes before CheckStrategy
1380 // must be after InferAttrs()
1381 if (CheckStrategy(in_strategy) != SUCCESS) {
1382 DynamicShapeCheckStrategyLog();
1383 FILTER_LOG(is_auto_parallel_) << name_ << ": CheckStrategy failed.";
1384 return FAILED;
1385 }
1386     ResumeShapes();  // in dynamic shape, restore the original shapes after CheckStrategy
1387
1388 if (is_dynamic_shape_ && CheckStrategyForDynamicShape(in_strategy) != SUCCESS) {
1389 MS_LOG(ERROR) << name_ << ": Check strategy for dynamic shape failed";
1390 return FAILED;
1391 }
1392 strategy_ = in_strategy;
1393
1394 if (out_strategy && CheckOutputStrategy(out_strategy) != SUCCESS) {
1395 MS_LOG(ERROR) << name_ << ": The output strategy is invalid";
1396 return FAILED;
1397 }
1398 set_out_strategy(out_strategy);
1399
1400 if (InferDevMatrixShape() != SUCCESS) {
1401 MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
1402 return FAILED;
1403 }
1404
1405 used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
1406
1407 // must be after InferDevMatrixShape
1408 if (InferRepeatedCalcInfo() != SUCCESS) {
1409 MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
1410 return FAILED;
1411 }
1412
1413 // if repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix for layout
1414 SetRepeatedCalcDevMatrix();
1415
1416 if (InferTensorMap() != SUCCESS) {
1417 MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
1418 return FAILED;
1419 }
1420
1421 ResetTensorMapIfRepeatedCalc();
1422 } else {
1423 if (InferOutputTensorMap() != SUCCESS) {
1424 MS_LOG(ERROR) << name_ << ": InferOutputTensorMap failed.";
1425 return FAILED;
1426 }
1427 }
1428
1429 if (InferTensorInfo() != SUCCESS) {
1430 MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
1431 return FAILED;
1432 }
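  // stage_dev_num is a power of two iff (n & (n - 1)) == 0; in that case the communication inference
  // below is deferred to Init(). For non-power-of-two device numbers it runs here, so strategies whose
  // forward communication or mirror ops cannot be inferred are filtered out during strategy search.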
1433 auto stage_dev_num = LongToSize(g_device_manager->stage_device_num());
1434 if ((stage_dev_num & (stage_dev_num - 1)) == 0) {
1435 return SUCCESS;
1436 }
1437 if (InferForwardCommunication() != SUCCESS) {
1438     MS_LOG(WARNING) << name_ << ": InferForwardCommunication failed in the auto parallel strategy-searching step.";
1439 return FAILED;
1440 }
1441
1442 if (InferMirrorOps() != SUCCESS) {
1443     MS_LOG(WARNING) << name_ << ": InferMirrorOps failed in the auto parallel strategy-searching step.";
1444 return FAILED;
1445 }
1446 return SUCCESS;
1447 }
1448
1449 Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
1450 if (in_strategy == nullptr) {
1451 MS_LOG(ERROR) << name_ << ": The input strategy is null.";
1452 return FAILED;
1453 }
1454
1455 if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1456 return FAILED;
1457 }
1458
1459 if (InferForwardCommunication() != SUCCESS) {
1460 MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
1461 return FAILED;
1462 }
1463
1464 if (InferMirrorOps() != SUCCESS) {
1465 MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
1466 return FAILED;
1467 }
1468
1469 if (InferVirtualDivOps() != SUCCESS) {
1470 MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
1471 return FAILED;
1472 }
1473
1474 InferReplaceOps();
1475 return SUCCESS;
1476 }
1477
1478 Status OperatorInfo::CheckInputLayout() {
1479 MS_LOG(ERROR) << "Current op " << name_
1480                 << " does not support configuring layout. Please check "
1481                    "https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/operator_list_parallel.html for the "
1482                    "limitations and more details";
1483 // Check self_define_shard attribute
1484 if (!self_define_shard_) {
1485     MS_LOG(ERROR) << "Please set add_prim_attr('self_define_shard', True) on " << name_
1486                   << " to configure the layout for this op";
1487 return FAILED;
1488 }
1489 return FAILED;
1490 }
1491
1492 Status OperatorInfo::InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
1493 const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
1494 ResetQueueMember();
1495
1496 if (InferAttrs() != SUCCESS) {
1497 MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
1498 return FAILED;
1499 }
1500
1501 size_t real_input_index = 0;
1502 for (const auto &input_layout : in_tensor_layouts) {
1503 // Insert placeholder TensorInfo for optional input
1504 while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
1505 input_value_[real_input_index]->isa<None>()) {
1506 (void)inputs_tensor_info_.emplace_back(TensorInfo());
1507 ++real_input_index;
1508 }
1509 TensorInfo input_tensor_info(*input_layout);
1510 inputs_tensor_info_.push_back(input_tensor_info);
1511 ++real_input_index;
1512 }
1513 if (CheckInputLayout() != SUCCESS) {
1514 MS_LOG(ERROR) << name_ << ": CheckInputLayout failed.";
1515 return FAILED;
1516 }
1517 for (const auto &output_layout : out_tensor_layouts) {
1518 TensorInfo output_tensor_info(*output_layout);
1519 outputs_tensor_info_.push_back(output_tensor_info);
1520 }
1521
1522 if (outputs_tensor_info_.size() != outputs_shape_.size()) {
1523 outputs_tensor_info_.clear();
1524     // Needs to be overridden by subclasses
1525 if (InferOutputTensorInfo() != SUCCESS) {
1526       MS_LOG(ERROR) << name_ << ": InferOutputTensorInfo failed.";
1527 return FAILED;
1528 }
1529 }
1530
1531 if (outputs_tensor_info_.size() != outputs_shape_.size()) {
1532 MS_LOG(ERROR) << name_ << ": the output tensor layout num " << outputs_tensor_info_.size()
1533                   << " does not match the output num " << outputs_shape_.size();
1534 return FAILED;
1535 }
1536
1537 if (CheckOutputLayout() != SUCCESS) {
1538     MS_LOG(ERROR) << name_ << ": CheckOutputLayout failed.";
1539 return FAILED;
1540 }
1541
1542   // Needs to be overridden by subclasses
1543 if (InferForwardCommunicationByLayout() != SUCCESS) {
1544 MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
1545 return FAILED;
1546 }
1547
1548 if (InferMirrorOpsByLayout() != SUCCESS) {
1549 MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
1550 return FAILED;
1551 }
1552 if (InferVirtualDivOpsByLayout() != SUCCESS) {
1553 MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
1554 return FAILED;
1555 }
1556 return SUCCESS;
1557 }
1558
1559 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
1560 std::vector<std::shared_ptr<Edge>> ret;
1561 for (auto &edge : succ_edges_) {
1562 if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
1563 ret.push_back(edge);
1564 } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
1565 // CAST is ordered in front of L2NORMALIZE
1566 ret.push_back(edge);
1567 }
1568 }
1569 for (auto &edge : succ_edges_) {
1570 if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
1571 (edge->next_operator()->name().find(CAST) == std::string::npos)) {
1572 ret.push_back(edge);
1573 }
1574 }
1575 return ret;
1576 }
1577
1578 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() {
1579 std::vector<std::shared_ptr<Edge>> ret;
1580 for (auto &edge : prev_edges_) {
1581 if (edge->prev_operator()->is_alive()) {
1582 ret.push_back(edge);
1583 }
1584 }
1585 return ret;
1586 }
1587
1588 void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1589 if (op == nullptr) {
1590 MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null.";
1591 return;
1592 }
1593 for (auto &edge : prev_edges_) {
1594 if (edge->prev_operator() == op) {
1595 edge = new_edge;
1596 return;
1597 }
1598 }
1599 MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
1600 }
1601
1602 void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1603 if (op == nullptr) {
1604 MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null.";
1605 return;
1606 }
1607 for (auto &edge : succ_edges_) {
1608 if (edge->next_operator() == op) {
1609 edge = new_edge;
1610 return;
1611 }
1612 }
1613 MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
1614 }
1615
1616 void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1617 if (op == nullptr) {
1618 MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null.";
1619 return;
1620 }
1621 std::vector<std::shared_ptr<Edge>> update_pre_edges;
1622 for (auto &edge : prev_edges_) {
1623 if (edge->prev_operator() != op) {
1624 update_pre_edges.push_back(edge);
1625 }
1626 }
1627 update_pre_edges.push_back(new_edge);
1628 prev_edges_ = update_pre_edges;
1629 }
1630
1631 void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1632 if (op == nullptr) {
1633 MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null";
1634 return;
1635 }
1636   std::vector<std::shared_ptr<Edge>> update_succ_edges;
1637   for (auto &edge : succ_edges_) {
1638     if (edge->next_operator() != op) {
1639       update_succ_edges.push_back(edge);
1640     }
1641   }
1642   update_succ_edges.push_back(new_edge);
1643   succ_edges_ = update_succ_edges;
1644 }
1645
1646 std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategiesWithCheck() {
1647 if (InferAttrs() != SUCCESS) {
1648 MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
1649 }
1650   DivisorsReplaceShapes();  // in dynamic shape, replace the shapes with their divisor info before GenerateBatchStrategies
1651
1652 std::shared_ptr<Strategies> batch_strategy = GenerateBatchStrategies();
1653 if (batch_strategy->size() != inputs_shape_.size()) {
1654 MS_LOG(WARNING) << "The inputs size:" << inputs_shape_.size()
1655 << " is not equal to the generated batch parallel strategies size:" << batch_strategy->size();
1656 return batch_strategy;
1657 }
1658 int64_t shard_size = g_device_manager->stage_device_num();
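  // Shrink shard_size to the gcd of the device number and every split batch dimension,
  // e.g. a batch dimension of 6 on 8 devices yields shard_size = gcd(6, 8) = 2.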
1659 std::vector<std::pair<size_t, size_t>> changed_pos;
1660 for (size_t i = 0; i < inputs_shape_.size(); ++i) {
1661 auto stra = batch_strategy->at(i);
1662 auto input_shape = inputs_shape_.at(i);
1663 if (stra.size() != input_shape.size()) {
1664       MS_LOG(WARNING) << "The size of input " << i << ": " << input_shape.size()
1665                       << " is not equal to the size of its generated batch parallel strategy: " << stra.size();
1666 return batch_strategy;
1667 }
1668 for (size_t j = 0; j < input_shape.size(); ++j) {
1669 if (stra[j] == 1) {
1670 continue;
1671 }
1672 if (stra[j] != g_device_manager->stage_device_num()) {
1673 MS_LOG(WARNING) << "The batch parallel value is not equal to device num, skip adjust it.";
1674 return batch_strategy;
1675 }
1676 shard_size = std::gcd(input_shape[j], shard_size);
1677 changed_pos.push_back({i, j});
1678 }
1679 }
1680 for (auto &pair : changed_pos) {
1681 batch_strategy->at(pair.first).at(pair.second) = shard_size;
1682 }
1683
1684 ResumeShapes();
1685 return batch_strategy;
1686 }
1687
1688 std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
1689 const std::vector<bool> &split_flag_list) {
1690 if (shapes.size() != split_flag_list.size()) {
1691     MS_LOG(ERROR) << "Split_flag_list does not have the same size as the inputs shape, " << split_flag_list.size() << " : "
1692 << shapes.size();
1693 return nullptr;
1694 }
1695 CheckGlobalDeviceManager();
1696 int64_t dev_num = g_device_manager->stage_device_num();
1697 Strategies strategy_v;
1698 for (size_t i = 0; i != shapes.size(); i++) {
1699 if (shapes[i].empty()) {
1700       MS_LOG(INFO) << "An element of shapes is empty.";
1701 Dimensions empty_element;
1702 strategy_v.push_back(empty_element);
1703 } else {
1704 Dimensions element(shapes[i].size(), 1);
1705 if (split_flag_list[i]) {
1706 element[0] = dev_num;
1707 }
1708 strategy_v.push_back(element);
1709 }
1710 }
1711 return std::make_shared<Strategies>(strategy_v);
1712 }
1713
1714 void OperatorInfo::ReComputeBatchSplitFlagList() {
1715 if (!inputs_shape_.empty()) {
1716 split_flag_list_[0] = true;
1717 }
1718 }
1719
1720 void OperatorInfo::ComputeBatchSplitFlagList() {
1721 split_flag_list_.clear();
1722 for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) {
1723 split_flag_list_.push_back(false);
1724 }
1725 ReComputeBatchSplitFlagList();
1726 }
1727
1728 // This is a common method for checking whether the generated strategy uses the correct number of devices.
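// e.g. dev_num = 8 and inputs_partitions {{2, 2}, {2, 1}}: product = (2 * 2) * (2 * 1) = 8, which
// matches dev_num, so the strategy is accepted even when fully_use_device is true.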
1729 Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
1730 if (sp == nullptr) {
1731 MS_LOG(ERROR) << "The strategy is null.";
1732 return FAILED;
1733 }
1734 int64_t product = 1;
1735
1736 for (auto &input_partition : inputs_partitions) {
1737 product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
1738 }
1739 const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
1740 if (!fully_use_device) {
1741 if (LongToSize(product) > dev_num) {
1742 return FAILED;
1743 }
1744 } else {
1745 if ((product != 1) && (LongToSize(product) != dev_num)) {
1746 return FAILED;
1747 }
1748 }
1749 Strategies stras(inputs_partitions);
1750 (*sp) = std::make_shared<Strategy>(stage_id, stras);
1751 return SUCCESS;
1752 }
1753
1754 std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategies() {
1755 if (InferAttrs() != SUCCESS) {
1756 MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
1757 }
1758 ComputeBatchSplitFlagList();
1759 return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
1760 }
1761
1762 // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d])
1763 Status GenerateStrategiesForTwoEqualInputs(int64_t stage_id, const Shapes &inputs_shape,
1764 const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1765 if (sp_vector == nullptr) {
1766 MS_LOG(ERROR) << "The sp_vector is null.";
1767 return FAILED;
1768 }
1769
1770 if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1771 MS_LOG(ERROR) << "The inputs size is wrong.";
1772 return FAILED;
1773 }
1774
1775 if ((inputs_shape[0].size() != inputs_shape[1].size()) ||
1776 (splittable_inputs[0].size() != splittable_inputs[1].size())) {
1777     MS_LOG(ERROR) << "The sizes of the two inputs are not equal.";
1778 return FAILED;
1779 }
1780
1781 Shapes input0_shape = {inputs_shape[0]};
1782 Shapes input0_splittable = {splittable_inputs[0]};
1783 if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) {
1784 return FAILED;
1785 }
1786
1787 for (auto &sp : *sp_vector) {
1788 sp->ExpandInputDimFromOneToTwo();
1789 }
1790
1791 return SUCCESS;
1792 }
1793
1794 // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast
1795 // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
1796 Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1797 std::vector<StrategyPtr> *const sp_vector) {
1798 if (sp_vector == nullptr) {
1799 MS_LOG(ERROR) << "The sp_vector is null.";
1800 return FAILED;
1801 }
1802
1803 if (inputs_shape[0].size() >= inputs_shape[1].size()) {
1804 MS_LOG(ERROR) << "Invalid inputs shape.";
1805 return FAILED;
1806 }
1807
1808 // first, generate strategy for input0 the same as input1
1809 Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]};
1810 Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]};
1811 if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1812 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1813 return FAILED;
1814 }
1815
1816 // second, get the correct strategy for input0
1817 for (auto &sp : *sp_vector) {
1818 Strategies tmp_strategy;
1819 Dimensions input0_strategy = sp->GetInputDim()[0];
1820 size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
1821
1822 // erase the unnecessary part
1823 (void)input0_strategy.erase(input0_strategy.cbegin(),
1824 input0_strategy.cbegin() + static_cast<different_type>(size_diff));
1825
1826 // handle the case likes ([1, c, d], [a, b, c, d])
1827 for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1828 if (inputs_shape[0][i] == 1) {
1829 input0_strategy[i] = 1;
1830 } else {
1831 break;
1832 }
1833 }
1834
1835 // reset the strategy
1836 tmp_strategy.push_back(input0_strategy); // input0
1837 tmp_strategy.push_back(sp->GetInputDim()[1]); // input1
1838 sp->ResetInputs(tmp_strategy);
1839 }
1840 return SUCCESS;
1841 }
1842
1843 // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast
1844 // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
1845 Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &inputs_shape,
1846 const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1847 if (sp_vector == nullptr) {
1848 MS_LOG(ERROR) << "The sp_vector is null.";
1849 return FAILED;
1850 }
1851
1852 if (inputs_shape[0].size() <= inputs_shape[1].size()) {
1853 MS_LOG(ERROR) << "Invalid inputs shape.";
1854 return FAILED;
1855 }
1856
1857 // first, generate strategy for input1 the same as input0
1858 Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]};
1859 Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]};
1860 if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1861 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1862 return FAILED;
1863 }
1864
1865 // second, get the correct strategy for input1
1866 for (auto &sp : *sp_vector) {
1867 Strategies tmp_strategy;
1868 tmp_strategy.push_back(sp->GetInputDim()[0]); // input0
1869
1870 Dimensions input1_strategy = sp->GetInputDim()[1];
1871 size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size();
1872
1873 // erase the unnecessary part
1874 (void)input1_strategy.erase(input1_strategy.cbegin(),
1875 input1_strategy.cbegin() + static_cast<different_type>(size_diff));
1876
1877 // handle the case likes ([a, b, c, d], [1, c, d])
1878 for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
1879 if (inputs_shape[1][i] == 1) {
1880 input1_strategy[i] = 1;
1881 } else {
1882 break;
1883 }
1884 }
1885
1886 // reset the strategy
1887 tmp_strategy.push_back(input1_strategy); // input1
1888 sp->ResetInputs(tmp_strategy);
1889 }
1890 return SUCCESS;
1891 }
1892
1893 // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast
1894 // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
1895 Status GenerateStrategiesForBroadcastBoth(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1896 std::vector<StrategyPtr> *const sp_vector) {
1897 if (sp_vector == nullptr) {
1898 MS_LOG(ERROR) << "The sp_vector is null.";
1899 return FAILED;
1900 }
1901
1902 if (inputs_shape[0].size() != inputs_shape[1].size()) {
1903 MS_LOG(ERROR) << "Invalid inputs shape.";
1904 return FAILED;
1905 }
1906
1907 // step1: ([a, 1], [1, b]) -> [a, b]
1908 Shape max_shape, splittable_vector;
1909 for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1910 if (inputs_shape[0][i] >= inputs_shape[1][i]) {
1911 max_shape.push_back(inputs_shape[0][i]);
1912 splittable_vector.push_back(splittable_inputs[0][i]);
1913 } else {
1914 max_shape.push_back(inputs_shape[1][i]);
1915 splittable_vector.push_back(splittable_inputs[1][i]);
1916 }
1917 }
1918
1919 // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b])
1920 Shapes tmp_inputs_shape = {max_shape, max_shape};
1921 Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector};
1922 if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1923 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1924 return FAILED;
1925 }
1926
1927 // step3: reset the strategy if the dimension is 1
1928 for (auto &sp : *sp_vector) {
1929 Dimensions input0_strategy = sp->GetInputDim()[0];
1930 Dimensions input1_strategy = sp->GetInputDim()[1];
1931 for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1932 if (inputs_shape[0][i] == 1) {
1933 input0_strategy[i] = 1;
1934 }
1935
1936 if (inputs_shape[1][i] == 1) {
1937 input1_strategy[i] = 1;
1938 }
1939 }
1940 sp->ResetInputs({input0_strategy, input1_strategy});
1941 }
1942
1943 return SUCCESS;
1944 }
1945
1946 Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
1947 const Shapes &splittable_inputs,
1948 std::vector<StrategyPtr> *sp_vector) {
1949 Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions;
1950 for (size_t j = 0; j < inputs_shape.size(); ++j) {
1951 (void)combined_inputs_shape.insert(combined_inputs_shape.cend(), inputs_shape[j].cbegin(), inputs_shape[j].cend());
1952 (void)combined_splittable_inputs.insert(combined_splittable_inputs.cend(), splittable_inputs[j].cbegin(),
1953 splittable_inputs[j].cend());
1954 }
1955 std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape,
1956 &combined_splittable_inputs, &combined_partitions, &recursive,
1957 &inputs_shape](uint64_t current_index, size_t n) {
1958 if (current_index == combined_inputs_shape.size()) {
1959 MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size();
1960 Shapes inputs_partitions;
1961 size_t global_index = 0;
1962 for (auto &shape : inputs_shape) {
1963 Shape tmp_partition;
1964 for (size_t j = 0; j < shape.size(); ++j) {
1965 tmp_partition.push_back(combined_partitions[global_index]);
1966 global_index++;
1967 }
1968 inputs_partitions.push_back(tmp_partition);
1969 }
1970 StrategyPtr sp;
1971 if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) {
1972 sp_vector->push_back(sp);
1973 }
1974 return;
1975 } else {
1976 MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size();
1977 if (combined_splittable_inputs[current_index] == 0) {
1978 combined_partitions.push_back(MIN_SLICE_NUM);
1979 recursive(current_index + 1, n / MIN_SLICE_NUM);
1980 combined_partitions.pop_back();
1981 } else if (combined_splittable_inputs[current_index] == 1) {
1982 for (uint64_t i = 1; i <= n; i *= 2) {
1983 if (n % i == 0 && LongToSize(combined_inputs_shape[current_index]) % i == 0) {
1984 combined_partitions.push_back(i);
1985 recursive(current_index + 1, n / i);
1986 combined_partitions.pop_back();
1987 }
1988 }
1989 }
1990 }
1991 };
1992 recursive(0, dev_num);
1993 if (sp_vector->empty()) {
1994 MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo.";
1995 }
1996 return SUCCESS;
1997 }
1998
1999 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
2000 // the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding
2001 // dimension is splittable. 'inputs_partitions' is the result of partitions.
2002 // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring
2003 // specific dimensions in inputs have the identical partition should have individual implementation.
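// e.g. inputs_shape {{8, 8}}, splittable_inputs {{1, 1}} and 4 devices: the enumeration yields
// ((1, 1)), ((2, 2)), ((4, 1)) and ((1, 4)); partial splits such as ((2, 1)) are kept only when
// fully_use_device is false (see PrepareStrategyBase).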
2004 Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
2005 const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector) {
2006 if (sp_vector == nullptr) {
2007 MS_LOG(ERROR) << "The sp_vector is null.";
2008 return FAILED;
2009 }
2010 if (splittable_inputs.size() != inputs_shape.size()) {
2011     MS_LOG(ERROR) << "Splittable_inputs does not have the same size as the inputs shape, " << splittable_inputs.size()
2012 << " : " << inputs_shape.size();
2013 return FAILED;
2014 }
2015 CheckGlobalDeviceManager();
2016 size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
2017 auto dev_num_2_power = (dev_num & (dev_num - 1));
2018 if (dev_num_2_power == 0) {
2019 return GenerateStrategiesForIndependentInputsBase(stage_id, dev_num, inputs_shape, splittable_inputs, sp_vector);
2020 }
2021 MS_EXCEPTION_IF_ZERO("dev_num - dev_num_2_power", dev_num - dev_num_2_power);
2022 auto dev_num_not_2_power = dev_num / (dev_num - dev_num_2_power);
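  // e.g. dev_num = 6: 6 & 5 = 4, so dev_num - dev_num_2_power = 2 is the largest power-of-two divisor
  // and dev_num_not_2_power = 3 is the remaining odd factor; strategies are generated for the 2-device
  // part first, then one dimension of each strategy is multiplied by 3 below.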
2023 std::vector<StrategyPtr> sp_vector_2_power_part;
2024 if (GenerateStrategiesForIndependentInputsBase(stage_id, dev_num - dev_num_2_power, inputs_shape, splittable_inputs,
2025 &sp_vector_2_power_part) != SUCCESS) {
2026 MS_LOG(ERROR) << "Generate strategy in the power of 2 devices part failed.";
2027 return FAILED;
2028 }
2029   // Handle the non-power-of-two part.
2030 for (auto &stra : sp_vector_2_power_part) {
2031 auto stra_arrays = stra->GetInputDim();
2032 size_t stras_size = stra_arrays.size();
2033 for (size_t i = 0; i < stras_size; ++i) {
2034 auto split_input = splittable_inputs[i];
2035 size_t stra_size = stra_arrays[i].size();
2036 for (size_t j = 0; j < stra_size; ++j) {
2037 if (split_input[j] == 0) {
2038 continue;
2039 }
2040 auto new_stra_arrays{stra_arrays};
2041 new_stra_arrays[i][j] = new_stra_arrays[i][j] * UlongToLong(dev_num_not_2_power);
2042 // discard invalid strategy
2043 MS_EXCEPTION_IF_ZERO("new_stra_arrays[i][j]", new_stra_arrays[i][j]);
2044 if (inputs_shape[i][j] % new_stra_arrays[i][j] != 0) {
2045 continue;
2046 }
2047 StrategyPtr new_stra = std::make_shared<Strategy>(stage_id, new_stra_arrays);
2048 sp_vector->push_back(new_stra);
2049 }
2050 }
2051 }
2052 // add the repeated strategy
2053 auto repeated_stra_arrays{splittable_inputs};
2054 for (auto &stra_array : repeated_stra_arrays) {
2055 std::fill(stra_array.begin(), stra_array.end(), 1);
2056 }
2057 StrategyPtr repeated_stra = std::make_shared<Strategy>(stage_id, repeated_stra_arrays);
2058 sp_vector->push_back(repeated_stra);
2059 return SUCCESS;
2060 }
2061
2062 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
2063 // the corresponding dimension is unsplittable, otherwise means that the corresponding dimension is splittable.
2064 // In particular, if the same dimensions exist in 'splittable_inputs',
2065 // the corresponding dimensions in the strategy are the same.
2066 // 'sp' is the result of partitions.
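// e.g. inputs_shape {{8, 8}, {8}} with splittable_inputs {{1, 2}, {2}}: the two dimensions tagged 2
// always receive the same split factor, so ((2, 4), (4)) can be generated while ((2, 4), (2)) cannot.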
2067 Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
2068 const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp) {
2069 if (inputs_shape.size() != splittable_inputs.size()) {
2070     MS_LOG(EXCEPTION) << "The sizes of inputs_shape and splittable_inputs are not equal.";
2071 }
2072
2073 std::unordered_map<int64_t, int64_t> mp;
2074 for (size_t i = 0; i < inputs_shape.size(); ++i) {
2075 auto input_shape = inputs_shape[i];
2076 auto splittable_input = splittable_inputs[i];
2077 for (size_t j = 0; j < input_shape.size(); ++j) {
2078 int64_t indice = splittable_input[j];
2079 int64_t shape = input_shape[j];
2080 if (splittable_input[j] == 0) {
2081 continue;
2082 }
2083 if (mp.find(indice) == mp.end()) {
2084 mp[indice] = shape;
2085 } else {
2086 mp[indice] = std::gcd(mp[indice], shape);
2087 }
2088 }
2089 }
2090
2091 std::unordered_map<int64_t, size_t> indices_mp;
2092 Shape tmp_input_shape;
2093 Shapes tmp_splittable_inputs = {Shape(mp.size(), 1)};
2094
2095 for (const auto &item : mp) {
2096 indices_mp[item.first] = tmp_input_shape.size();
2097 tmp_input_shape.push_back(item.second);
2098 }
2099 Shapes tmp_inputs_shape = {tmp_input_shape};
2100 std::vector<StrategyPtr> tmp_sp_vector;
2101 if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &tmp_sp_vector) !=
2102 SUCCESS) {
2103 return FAILED;
2104 }
2105
2106 (void)std::transform(tmp_sp_vector.begin(), tmp_sp_vector.end(), std::back_inserter(*sp),
2107 [stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) {
2108 auto sp_strategies = sp->GetInputDim();
2109 auto sp_sub_strategy = sp_strategies.at(0);
2110 Strategies strategies(splittable_inputs);
2111 for (size_t i = 0; i < strategies.size(); ++i) {
2112 for (size_t j = 0; j < strategies[i].size(); ++j) {
2113 if (splittable_inputs[i][j] == 0) {
2114 strategies[i][j] = 1;
2115 } else {
2116 strategies[i][j] = sp_sub_strategy[indices_mp[splittable_inputs[i][j]]];
2117 }
2118 }
2119 }
2120 return std::make_shared<Strategy>(stage_id, strategies);
2121 });
2122 return SUCCESS;
2123 }
2124
2125 // generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
2126 // and the corresponding dimensions that are not broadcast are all relevant dimensions
2127 // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
2128 // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
2129 // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
2130 Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
2131 std::vector<StrategyPtr> *sp_vector) {
2132 if (sp_vector == nullptr) {
2133 MS_LOG(ERROR) << "The sp_vector is null.";
2134 return FAILED;
2135 }
2136
2137 if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
2138 MS_LOG(ERROR) << "The inputs' size is wrong.";
2139 return FAILED;
2140 }
2141
2142 if (inputs_shape[0] == inputs_shape[1]) {
2143 // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy
2144 if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2145 MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
2146 return FAILED;
2147 }
2148 MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success.";
2149 } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) {
2150 // ([a, b, c, d], []) or ([], [a, b, c, d])
2151 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2152 MS_LOG(ERROR) << "Generate strategies for scalar case failed.";
2153 return FAILED;
2154 }
2155 MS_LOG(INFO) << "Generate strategies for scalar case success.";
2156 } else if (inputs_shape[0].size() > inputs_shape[1].size()) {
2157 // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
2158 if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2159 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed.";
2160 return FAILED;
2161 }
2162 MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success.";
2163 } else if (inputs_shape[0].size() < inputs_shape[1].size()) {
2164 // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
2165 if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2166 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed.";
2167 return FAILED;
2168 }
2169 MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success.";
2170 } else { // same size, but different value
2171 // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
2172 if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2173 MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed.";
2174 return FAILED;
2175 }
2176 MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success.";
2177 }
2178 return SUCCESS;
2179 }
2180
2181 Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
2182 if (InitForCostModel(strategy, nullptr) == FAILED) {
2183 MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed.";
2184 return FAILED;
2185 }
2186 int64_t stage_id = strategy->GetInputStage();
2187 double computation_cost =
2188 operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
2189 double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
2190 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
2191 std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
2192 result->communication_without_parameter_ =
2193 operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
2194 result->communication_with_partial_para_ =
2195 result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);
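  // i.e. communication_with_partial_para = communication_without_parameter
  //      + gamma * (parameter-related communication), with gamma taken from the cost model context.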
2196
2197 // Breaking ties for preferring data parallelization
2198 BreakingTiesForPreferringDataParallel(strategy, result);
2199 // refine communication cost calculation for practice
2200 RefineForPracticalCost(result, false);
2201 result->communication_forward_ = result->communication_without_parameter_;
2202
2203 std::shared_ptr<StrategyWithCost> swc =
2204 std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
2205 swc->cost_list.push_back(result);
2206 (void)strategy_cost_.emplace_back(swc);
2207
2208 return SUCCESS;
2209 }
2210
2211 CostPtrList OperatorInfo::GetCostByStrategyPtr(const StrategyPtr &strategy) {
2212 auto target = std::find_if(
2213 strategy_cost_.begin(), strategy_cost_.end(),
2214 [&](const std::shared_ptr<StrategyWithCost> &stra_cost) { return stra_cost->strategy_ptr == strategy; });
2215 if (target == strategy_cost_.end()) {
2216     MS_LOG(EXCEPTION) << "There is no StrategyWithCost for the given strategy";
2217 }
2218 return (*target)->cost_list;
2219 }
2220
2221 TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index) {
2222 auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
2223 auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2224 if (it != strategy_cost_.end()) {
2225 const auto &input_info = (*it)->inputs_ptr[input_index];
2226 return std::move(input_info.tensor_layout());
2227 }
2228 TensorLayout empty;
2229 return empty;
2230 }
2231
2232 TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index) {
2233 auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
2234 auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2235 if (it != strategy_cost_.end()) {
2236 const auto &output_info = (*it)->outputs_ptr[output_index];
2237 return std::move(output_info.tensor_layout());
2238 }
2239 TensorLayout empty;
2240 return empty;
2241 }
2242
2243 StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index) {
2244 auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
2245 return swc->inputs_ptr[input_index].tensor_layout() == input_layout;
2246 };
2247 auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2248 if (it != strategy_cost_.end()) {
2249 return (*it)->strategy_ptr;
2250 }
2251 return nullptr;
2252 }
2253
2254 StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index) {
2255 auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
2256 return swc->outputs_ptr[output_index].tensor_layout() == output_layout;
2257 };
2258 auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2259 if (it != strategy_cost_.end()) {
2260 return (*it)->strategy_ptr;
2261 }
2262 return nullptr;
2263 }
2264
2265 bool OperatorInfo::IsReshape() const {
2266 if (name_.find(RESHAPEINFO) != std::string::npos) {
2267 return true;
2268 }
2269 return false;
2270 }
2271
2272 bool OperatorInfo::IsTmpIdentity() const {
2273 if (name_.find(IDENTITY_INFO) != std::string::npos) {
2274 return true;
2275 }
2276 return false;
2277 }
2278
2279 // Keep at most (1.0 / epsilon) number of available strategies for each operator.
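// e.g. epsilon = 0.1 keeps at most ceil(1 / 0.1) = 10 strategies: the list is sorted by
// alpha * computation_cost + beta * communication_with_partial_para and sampled at a stride of
// strategy_cost_.size() / 10.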
2280 void OperatorInfo::ApproximateStrategies() {
2281 auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
2282 if (!enable_approxi) {
2283 return;
2284 }
2285 MS_LOG(INFO) << name_ << ": Approximating strategy-cost";
2286 auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
2287 MS_EXCEPTION_IF_ZERO("epsilon", epsilon);
2288 auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
2289 if (strategy_cost_.size() <= target_num) {
2290 MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
2291 << ", no greater than target-num: " << target_num;
2292 return;
2293 }
2294 std::vector<std::shared_ptr<StrategyWithCost>> ret;
2295 auto &origin_stra_cost = strategy_cost_;
2296 auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
2297 auto beta = CostModelContext::GetInstance()->costmodel_beta();
2298 // sort
2299 std::sort(
2300 origin_stra_cost.begin(), origin_stra_cost.end(),
2301 [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
2302 if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
2303 alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
2304 return true;
2305 }
2306 return false;
2307 });
2308 MS_EXCEPTION_IF_ZERO("target_num", target_num);
2309 size_t step_length = origin_stra_cost.size() / target_num;
2310 for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
2311 ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
2312 }
2313
2314 strategy_cost_ = ret;
2315 is_strategy_cost_exact_ = false;
2316 }
2317
2318 void OperatorInfo::ExactStrategiesAndRelatedEdges() {
2319 if (is_strategy_cost_exact()) {
2320 return;
2321 }
2322 ClearStrategyCost();
2323 if (GenerateStrategies(0) != SUCCESS) {
2324 MS_LOG(EXCEPTION) << name_ << ": Strategy search failed.";
2325 }
2326 SetIsStrategyCostExactTrue();
2327 // re-init the previous edges
2328 for (auto &prev_edge : prev_edges()) {
2329 if (prev_edge->InitEdgeCost() != SUCCESS) {
2330 MS_LOG(EXCEPTION) << name_ << ": Edge: " << prev_edge->edge_name() << " cost init failed.";
2331 }
2332 }
2333 // re-init the successive edges
2334 for (auto &next_edge : succ_edges()) {
2335 if (next_edge->InitEdgeCost() != SUCCESS) {
2336 MS_LOG(EXCEPTION) << name_ << ": Edge: " << next_edge->edge_name() << " cost init failed.";
2337 }
2338 }
2339 }
2340
2341 int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
2342 if (is_output_parameter_involve_ != -1) {
2343 return is_output_parameter_involve_;
2344 }
2345 is_parameter_involve_ = is_parameter_;
2346 const auto &prev_edges = this->GetAlivePrevEdges();
2347 for (auto &p_edge : prev_edges) {
2348 auto input_index = p_edge->next_op_input_index();
2349 auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
2350 if (input_index >= is_parameter_involve_.size()) {
2351 MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
2352 << ", but got wrong input_index: " << input_index;
2353 }
2354 if (prev_op_para == 0) {
2355 is_parameter_involve_[input_index] = false;
2356 } else if (prev_op_para == 1) {
2357 is_parameter_involve_[input_index] = true;
2358 } else {
2359 MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
2360 }
2361 p_edge->set_parameter_involve(prev_op_para);
2362 }
2363 if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
2364 // If anyone of the input is a parameter_involved, the output is parameter_involved.
2365 is_output_parameter_involve_ = 1;
2366 } else {
2367 is_output_parameter_involve_ = 0;
2368 }
2369 // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
2370 // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
2371 operator_cost()->set_is_parameter_involve(is_parameter_involve_);
2372 operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
2373 // Calculating 'output_in_memory'
2374 operator_cost()->CalculateOutputInMemory();
2375 // Calculating 'inputs_in_memory'
2376 std::map<size_t, bool> input_in_memory;
2377 for (auto &p_edge : prev_edges) {
2378 auto input_index = p_edge->next_op_input_index();
2379 auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
2380 (void)input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
2381 }
2382 operator_cost()->CalculateInputsInMemory(input_in_memory);
2383
2384 return is_output_parameter_involve_;
2385 }
2386
2387 Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
2388 if (is_parameter.size() != inputs_shape_.size()) {
2389     MS_LOG(ERROR) << name_ << ": Is_parameter size: " << is_parameter.size()
2390                   << " does not match the number of inputs_shape_: " << inputs_shape_.size();
2391 return FAILED;
2392 }
2393 is_parameter_ = is_parameter;
2394 operator_cost()->set_is_parameter(is_parameter);
2395 return SUCCESS;
2396 }
2397
2398 Status OperatorInfo::CalculateMemoryCost() {
2399 if (is_parameter_involve_.size() != is_parameter_.size()) {
2400     MS_LOG(ERROR) << name_ << ": the size of 'is_parameter_' is " << is_parameter_.size()
2401                   << ", which does not match the size of 'is_parameter_involve_': "
2402                   << is_parameter_involve_.size();
2403 return FAILED;
2404 }
2405 // Set the memory cost in the 'strategy_cost_'
2406 for (auto &swc : strategy_cost_) {
2407 auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
2408 swc->cost_list[0]->memory_with_reuse_ = mem_cost;
2409 }
2410 return SUCCESS;
2411 }
2412
2413 Status OperatorInfo::CalculateMemoryCostForInference() {
2414 // First, set the 'is_outputs_critical_' flag into OperatorCost.
2415 if (is_output_critical_ == -1) {
2416 MS_LOG(EXCEPTION) << name_ << ": The critical flag is not set.";
2417 }
2418 operator_cost()->set_output_critical(is_output_critical_);
2419 // Set the memory cost in the 'strategy_cost_'
2420 for (auto &swc : strategy_cost_) {
2421 auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
2422 swc->cost_list[0]->memory_with_reuse_ = mem_cost;
2423 }
2424 return SUCCESS;
2425 }
2426
2427 Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
2428 for (auto &swc : strategy_cost_) {
2429 double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
2430 static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
2431 swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
2432 if (swc->cost_list[0]->memory_with_reuse_ < 0) {
2433 MS_LOG(WARNING) << name_ << ": The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
2434 << ", the parameter memory cost is: " << parameter_mem_cost;
2435 swc->cost_list[0]->memory_with_reuse_ = 0;
2436 }
2437 }
2438 return SUCCESS;
2439 }
2440
2441 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) {
2442 int64_t ret = -1;
2443
2444   // The number of repetitions is equal to the number of all devices divided by the number of devices
2445   // used by the tensor map.
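  // e.g. dev_matrix_shape [2, 4, 2] (16 devices) with tensor_map [2, -1]: the value 2 addresses
  // dev_matrix_shape[3 - 2 - 1] = 2, so the repeat number is 16 / 2 = 8.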
2446 int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int64_t>());
2447 for (auto &element : tensor_map) {
2448 // -1 means the corresponding dimension is not split.
2449 if (element == MAP_NONE) {
2450 continue;
2451 } else if ((element < 0) || (LongToSize(element) >= dev_matrix_shape.size())) {
2452 MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is "
2453 << ShapeToString(dev_matrix_shape);
2454 return ret;
2455 } else {
2456 size_t index = dev_matrix_shape.size() - LongToSize(element) - 1;
2457 if (dev_matrix_shape[index] <= 0) {
2458 MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape);
2459 return ret;
2460 }
2461 device_num /= dev_matrix_shape[index];
2462 }
2463 }
2464
2465 return device_num;
2466 }
2467
2468 Status OperatorInfo::InferAsLossDivisor() {
2469 if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
2470 as_loss_divisor_ = 1;
2471 return SUCCESS;
2472 }
2473 if (!inputs_shape_new_.empty()) {
2474 MS_LOG(ERROR) << name_ << ": For Tuple input ops, please override this function";
2475 return FAILED;
2476 }
2477 if (outputs_tensor_map_.empty()) {
2478 MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
2479 return FAILED;
2480 }
2481
2482 if (outputs_tensor_map_.size() > 1) {
2483 MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size()
2484                   << ", this function needs to be overridden";
2485 return FAILED;
2486 }
2487
2488 if (outputs_tensor_map_[0].empty()) {
2489 as_loss_divisor_ = stage_device_size_;
2490     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << " as the loss divisor.";
2491 return SUCCESS;
2492 }
2493
2494 if (out_dev_matrix_shape_.empty()) {
2495 out_dev_matrix_shape_ = dev_matrix_shape_;
2496 }
2497 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
2498 MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
2499 << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
2500 << as_loss_divisor_;
2501 return SUCCESS;
2502 }
2503
2504 Status OperatorInfo::InferAsLossDivisorByLayout() {
2505 if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
2506 as_loss_divisor_ = 1;
2507 return SUCCESS;
2508 }
2509
2510 if (outputs_tensor_info_.empty()) {
2511 MS_LOG(ERROR) << name_ << ": The outputs tensor info is empty.";
2512 return FAILED;
2513 }
2514
2515 if (outputs_tensor_info_.size() > 1) {
2516 MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_info_.size()
2517                   << ", this function needs to be overridden";
2518 return FAILED;
2519 }
2520
2521 TensorMaps outputs_tensor_map = outputs_tensor_info_[0].tensor_layout().tensor_map_before();
2522 if (outputs_tensor_map.empty()) {
2523     MS_LOG(INFO) << name_ << ": the output tensor map is empty";
2524     as_loss_divisor_ = stage_device_size_;
2525     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << " as the loss divisor.";
2526 return SUCCESS;
2527 }
2528
2529 auto out_dev_matrix_shape = outputs_tensor_info_[0].tensor_layout().device_arrangement_origin().array();
2530 if (out_dev_matrix_shape.empty()) {
2531 out_dev_matrix_shape = dev_matrix_shape_;
2532 }
2533 Shape squashed_tensor_map;
2534 for (const auto &tensor_map : outputs_tensor_map) {
2535 std::copy(tensor_map.begin(), tensor_map.end(), std::back_inserter(squashed_tensor_map));
2536 }
2537 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape, squashed_tensor_map);
2538 MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape)
2539 << ", the output tensor map is " << ShapeToString(squashed_tensor_map) << ", loss divisor is "
2540 << as_loss_divisor_;
2541 return SUCCESS;
2542 }
2543
2544 // If the operator is used as a loss, a div node is inserted for the grad of all its inputs.
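// e.g. if the loss output is replicated over 8 devices, as_loss_divisor_ becomes 8 and a VirtualDiv(8)
// operator is inserted so that each input's gradient is divided by 8.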
2545 Status OperatorInfo::InferVirtualDivOps() {
2546 if (InferAsLossDivisor() != SUCCESS) {
2547 MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
2548 return FAILED;
2549 }
2550
2551 if (as_loss_divisor_ <= 0) {
2552 MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
2553 return FAILED;
2554 } else if (as_loss_divisor_ == 1) {
2555 MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
2556 return SUCCESS;
2557 }
2558
2559 virtual_div_op_.clear();
2560 // if loss is repeated calculation, insert div op
2561 Operator op = CreateVirtualDivOp(as_loss_divisor_);
2562 virtual_div_op_.push_back(op);
2563 return SUCCESS;
2564 }
2565
2566 Status OperatorInfo::InferVirtualDivOpsByLayout() {
2567 if (InferAsLossDivisorByLayout() != SUCCESS) {
2568 MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
2569 return FAILED;
2570 }
2571
2572 if (as_loss_divisor_ <= 0) {
2573 MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
2574 return FAILED;
2575 } else if (as_loss_divisor_ == 1) {
2576 MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
2577 return SUCCESS;
2578 }
2579
2580 virtual_div_op_.clear();
2581 // if loss is repeated calculation, insert div op
2582 Operator op = CreateVirtualDivOp(as_loss_divisor_);
2583 virtual_div_op_.push_back(op);
2584 return SUCCESS;
2585 }
2586
2587 Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
2588 const std::vector<size_t> &output_lengths) {
2589 if (input_lengths.size() != inputs_shape_.size()) {
2590     MS_LOG(ERROR) << name_ << ": Input_lengths size: " << input_lengths.size()
2591                   << " does not match the number of inputs shape: " << inputs_shape_.size();
2592 return FAILED;
2593 }
2594 if (output_lengths.size() != outputs_shape_.size()) {
2595     MS_LOG(ERROR) << name_ << ": Output_lengths size: " << output_lengths.size()
2596                   << " does not match the number of outputs shape: " << outputs_shape_.size();
2597 return FAILED;
2598 }
2599 inputs_type_lengths_ = input_lengths;
2600 outputs_type_lengths_ = output_lengths;
2601 operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
2602 return SUCCESS;
2603 }
2604
2605 double OperatorInfo::GetOutputsTotalSize() {
2606 if (is_calculated_outputs_size_) {
2607 return outputs_total_size_;
2608 }
2609 if (outputs_type_lengths_.size() != outputs_shape_.size()) {
2610     MS_LOG(EXCEPTION) << name_ << ": Output_lengths size: " << outputs_type_lengths_.size()
2611                       << " does not match the number of outputs shape: " << outputs_shape_.size();
2612 }
2613 double sum = 0.0;
2614 for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
2615 auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
2616 std::multiplies<double>());
2617 sum += size * static_cast<double>(outputs_type_lengths_[i]);
2618 }
2619 is_calculated_outputs_size_ = true;
2620 outputs_total_size_ = sum;
2621 return outputs_total_size_;
2622 }
2623
2624 Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
2625 if (outputs_type.size() != outputs_shape_.size()) {
2626     MS_LOG(ERROR) << name_ << ": Outputs type size: " << outputs_type.size()
2627                   << " does not match the number of outputs shape: " << outputs_shape_.size();
2628 return FAILED;
2629 }
2630 outputs_type_ = outputs_type;
2631 return SUCCESS;
2632 }
2633
2634 void OperatorInfo::BreakingTiesForPreferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) const {
2635 if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
2636 if (stra->GetInputDim()[0][0] == stage_device_size_) {
2637 if (cost->computation_cost_ > 1.0) {
2638 cost->computation_cost_ -= 1.0;
2639 }
2640 if (cost->communication_cost_ > 1.0) {
2641 cost->communication_cost_ -= 1.0;
2642 }
2643 if (cost->communication_with_partial_para_ > 1.0) {
2644 cost->communication_with_partial_para_ -= 1.0;
2645 }
2646 if (cost->communication_without_parameter_ > 1.0) {
2647 cost->communication_without_parameter_ -= 1.0;
2648 }
2649 }
2650 }
2651 }
2652
2653 void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth) {
2654 MS_EXCEPTION_IF_NULL(s_strategy);
2655 if ((selected_strategy_depth_ != -1) && (SizeToLong(curr_depth) > selected_strategy_depth_)) {
2656     MS_LOG(INFO) << name_ << " has already been assigned a strategy.";
2657 return;
2658 }
2659 MS_LOG(INFO) << name_ << ": Set strategy " << s_strategy->ToString();
2660 selected_strategy_ = s_strategy;
2661 selected_strategy_depth_ = SizeToLong(curr_depth);
2662 }
2663
2664 void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) {
2665 MS_LOG(INFO) << name_ << ": Set SWC index: " << swc;
2666 selected_strategy_depth_ = depth;
2667 swc_index_ = swc;
2668 }
2669
2670 std::vector<CNodePtr> OperatorInfo::cnodes() { return cnodes_; }
2671
2672 double OperatorInfo::GetForwardMemoryCostFromCNode() {
2673 return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
2674 }
2675
2676 void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
2677 MS_EXCEPTION_IF_NULL(s_strategy);
2678 if (!s_strategy->IsEqual(selected_strategy_)) {
2679 MS_LOG(INFO) << name_
2680                  << "'s strategy may be suboptimal; the determined strategy: " << selected_strategy_->ToString()
2681                  << ", the minimal strategy: " << s_strategy->ToString();
2682 }
2683 }
2684
2685 void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
2686 strategy_cost_ = stra_cost;
2687 }
2688
2689 Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
2690 if (InferAttrs() != SUCCESS) {
2691 MS_LOG(ERROR) << name_ << ": Infer attrs failed";
2692 return FAILED;
2693 }
2694
2695   DivisorsReplaceShapes();  // in dynamic shape, replace the shapes with their divisor info before generating strategies
2696 std::vector<StrategyPtr> sp_vector = GenerateOpStrategies(stage_id);
2697 ResumeShapes(); // resume shapes
2698
2699 size_t success = 0;
2700 for (auto &sp : sp_vector) {
2701 if (SetCostUnderStrategy(sp) == SUCCESS) {
2702 success++;
2703 MS_LOG(INFO) << name_ << ": Successfully generated the " << GetSerialNumberString(success)
2704 << " strategy: " << sp->ToString();
2705 } else {
2706 MS_LOG(INFO) << name_ << ": SetCostUnderStrategy failed, the strategy is " << sp->ToString();
2707 }
2708 }
2709 return SUCCESS;
2710 }
2711
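// Typed accessors for the primitive attributes stored in attrs_ (GetIntAttr / GetBoolAttr /
// GetStringAttr / GetTupleIntAttr / GetFloatAttr). Each throws if the attribute is missing or has
// the wrong value type, e.g. (hypothetical attribute name):
//   int64_t axis = GetIntAttr("axis");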
int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<Int64Imm>()) {
    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
  }

  return attr_iter->second->cast<Int64ImmPtr>()->value();
}

bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not bool";
  }

  return attr_iter->second->cast<BoolImmPtr>()->value();
}

std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<StringImm>()) {
    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
  }

  return attr_iter->second->cast<StringImmPtr>()->value();
}

std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
  auto tuple_attr_iter = attrs_.find(attr_name);
  if (tuple_attr_iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
  }

  MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
  return GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
}

float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
  auto attr_iter = attrs_.find(attr_name);
  if (attr_iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Cannot find the attribute " << attr_name;
  }

  MS_EXCEPTION_IF_NULL(attr_iter->second);
  if (!attr_iter->second->isa<FP32Imm>()) {
    MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
  }

  return attr_iter->second->cast<FP32ImmPtr>()->value();
}

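// Flatten a ValueTuple or ValueList into a vector of its elements; any other value kind is logged
// as an error and yields an empty vector.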
std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence) {
  MS_EXCEPTION_IF_NULL(sequence);
  std::vector<ValuePtr> ret;
  if (!sequence->isa<ValueTuple>() && !sequence->isa<ValueList>()) {
    MS_LOG(ERROR) << "The arg is not a value tuple or value list";
    return ret;
  }

  if (sequence->isa<ValueTuple>()) {
    auto val_tuple = sequence->cast<ValueTuplePtr>();
    return val_tuple->value();
  }
  auto val = sequence->cast<ValueListPtr>();
  return val->value();
}

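// Helpers for the opposite direction: wrap raw int64 vectors back into Value sequences.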
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
  std::vector<ValuePtr> list;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
  return std::make_shared<ValueSequence>(list);
}

ValuePtr MakeTupleListValue(const Shapes &v) {
  std::vector<ValuePtr> tuple;
  (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
                       [](const std::vector<int64_t> &list) { return MakeListValue(list); });
  return std::make_shared<ValueTuple>(tuple);
}

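// Wrap a tuple of scalars (or tensors) into a ValueNode so it can be used directly as a CNode input.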
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple) {
  auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
  auto value_node = NewValueNode(value_ptr);
  return value_node->cast<AnfNodePtr>();
}

AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple) {
  auto tensor_ptr = MakeValue(tensor_tuple)->cast<ValueTuplePtr>();
  auto tensor_node = NewValueNode(tensor_ptr);
  return tensor_node->cast<AnfNodePtr>();
}

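// Build a RealDiv operator whose divisor is a scalar tensor of the given dtype, passed as the
// operator's second input (position 2).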
Operator CreateDivOpWithType(float divisor, const TypePtr &dtype) {
  OperatorName operator1_name = REAL_DIV;
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
  ValuePtr op1_param_value = MakeValue(tensor_ptr);
  Attr op1_param = std::make_pair("divisor", op1_param_value);
  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
  OperatorAttrs operator1_attrs;
  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
  Operator div_op = std::make_pair(operator1_name, operator1_args);
  return div_op;
}

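// A distributed ReduceMean is rewritten as AllReduce(sum) over the forward group followed by a RealDiv
// by the group size, since the mean over n devices equals the sum divided by n.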
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
  // Create AllReduceSum op
  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
  std::string group_name = forward_group[0].name();
  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;

  // Create RealDiv op
  std::vector<Device> device_list = forward_group[0].GetDevicesList();
  auto divisor = SizeToFloat(device_list.size());
  Operator op1 = CreateDivOpWithType(divisor, dtype);
  std::string dtype_name = dtype->ToString();
  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;

  return {op0, op1};
}

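// Read the elements of a tensor value into an int64 vector. The cast of data_c() assumes the tensor
// actually holds int64 data; other dtypes would be reinterpreted incorrectly.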
std::vector<int64_t> GetTensorValue(const ValuePtr &ori_value) {
  MS_EXCEPTION_IF_NULL(ori_value);
  if (!ori_value->isa<tensor::Tensor>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Value is not tensor";
  }
  auto tensor_ptr = ori_value->cast<tensor::TensorPtr>();
  std::vector<int64_t> value;
  auto element_size = tensor_ptr->data().size();
  auto *data = static_cast<int64_t *>(tensor_ptr->data_c());
  // use the element count's own type for the index to avoid a sign/width mismatch
  for (decltype(element_size) i = 0; i < element_size; i++) {
    value.push_back(data[i]);
  }
  return value;
}
}  // namespace parallel
}  // namespace mindspore