• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/operator_info.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
29 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
30 #include "frontend/parallel/step_parallel_utils.h"
31 #include "frontend/parallel/graph_util/graph_utils.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33 #include "include/common/debug/anf_dump_utils.h"
34 #include "include/common/utils/parallel_context.h"
35 #include "ir/tensor.h"
36 #include "ir/value.h"
37 #include "mindspore/core/ops/framework_ops.h"
38 #include "utils/log_adapter.h"
39 
40 namespace mindspore {
41 namespace parallel {
42 namespace {
43 struct InStrategyValueRegister {
InStrategyValueRegistermindspore::parallel::__anonbca9d9360111::InStrategyValueRegister44   InStrategyValueRegister() noexcept {
45     AnfDumpHandler::SetInStrategyValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
46       auto operator_info = node->user_data<parallel::OperatorInfo>();
47       if (operator_info == nullptr) {
48         return nullptr;
49       }
50 
51       auto in_strategy = operator_info->strategy();
52       if (in_strategy == nullptr) {
53         return nullptr;
54       }
55 
56       return MakeValue(in_strategy->GetInputDim());
57     });
58   }
59 } in_regist;
60 
61 struct InStrategyStageValueRegister {
InStrategyStageValueRegistermindspore::parallel::__anonbca9d9360111::InStrategyStageValueRegister62   InStrategyStageValueRegister() noexcept {
63     AnfDumpHandler::SetInStrategyStageValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
64       auto operator_info = node->user_data<parallel::OperatorInfo>();
65       if (operator_info == nullptr) {
66         return nullptr;
67       }
68 
69       auto in_strategy = operator_info->strategy();
70       if (in_strategy == nullptr) {
71         return nullptr;
72       }
73 
74       return MakeValue(in_strategy->GetInputStage());
75     });
76   }
77 } in_stage_regist;
78 
79 struct OutStrategyValueRegister {
OutStrategyValueRegistermindspore::parallel::__anonbca9d9360111::OutStrategyValueRegister80   OutStrategyValueRegister() noexcept {
81     AnfDumpHandler::SetOutStrategyValueHandler([](const std::shared_ptr<AnfNode> &node) -> ValuePtr {
82       auto operator_info = node->user_data<parallel::OperatorInfo>();
83       if (operator_info == nullptr) {
84         return nullptr;
85       }
86 
87       auto in_strategy = operator_info->out_strategy();
88       if (in_strategy == nullptr) {
89         return nullptr;
90       }
91 
92       return MakeValue(in_strategy->GetInputDim());
93     });
94   }
95 } out_regist;
96 }  // namespace
97 
StrategyToString(const Strategies & strategy)98 std::string StrategyToString(const Strategies &strategy) {
99   std::string strategy_str = "";
100   strategy_str += "(";
101   for (size_t i = 0; i < strategy.size(); ++i) {
102     strategy_str += "(";
103     for (size_t j = 0; j < strategy[i].size(); ++j) {
104       strategy_str += std::to_string(strategy[i][j]);
105       if (j != strategy[i].size() - 1) {
106         strategy_str += ", ";
107       }
108     }
109     strategy_str += ")";
110     if (i != strategy.size() - 1) {
111       strategy_str += ", ";
112     }
113   }
114   if (strategy.size() == 1) {
115     strategy_str += ",";
116   }
117   strategy_str += ")";
118   return strategy_str;
119 }
120 
CheckOutputStrategy(const StrategyPtr & out_strategy)121 Status OperatorInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
122   if (out_strategy) {
123     MS_LOG(ERROR) << name_ << ": It does not support to set output strategy now, please modify the shard set";
124     return FAILED;
125   }
126   return SUCCESS;
127 }
128 
// Validates a concrete strategy (vector form) against the input shapes:
// (1) one sub-strategy per input; (2) each sub-strategy has one value per
// dimension; (3) every strategy value is >= MIN_SLICE_NUM (i.e. positive);
// (4) every static dimension is divisible by its strategy value (dynamic
// dims, marked -1, are exempt); (5) each value is a power of two unless the
// device number itself is not a power of two.
// Returns FAILED (with an ERROR log) on the first violation, else SUCCESS.
Status OperatorInfo::CheckStrategyByVector(const Shapes &stra, const Shapes &inputs_shape) {
  size_t strategy_size = stra.size();
  size_t inputs_shape_size = inputs_shape.size();
  if (strategy_size != inputs_shape_size) {
    MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy size: " << strategy_size
                  << " is not equal to inputs size: " << inputs_shape_size;
    return FAILED;
  }

  for (size_t i = 0; i < strategy_size; ++i) {
    Shape sub_strategy = stra.at(i);
    Shape sub_input_shape = inputs_shape.at(i);
    size_t strategy_len = sub_strategy.size();
    size_t inputs_len = sub_input_shape.size();
    MS_LOG(DEBUG) << "Compare: sub_input_shape:" << sub_input_shape << " sub_strategy: " << sub_strategy;
    if (strategy_len != inputs_len) {
      MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", strategy len: " << strategy_len
                    << " is not equal to inputs len: " << inputs_len << ", index: " << i;
      return FAILED;
    }

    for (size_t j = 0; j < strategy_len; ++j) {
      // (3) strategy values must be positive
      int64_t strategy_value = sub_strategy.at(j);
      if (strategy_value < MIN_SLICE_NUM) {
        MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
                      << ", the value of strategy must be larger than 0, but get " << strategy_value;
        return FAILED;
      }

      // (4) divisibility; -1 marks a dynamic dimension and is exempt
      int64_t shape_value = sub_input_shape.at(j);
      if (shape_value != -1 && (shape_value % strategy_value) != 0) {
        if (dynamic_shape_flag_) {
          // Under dynamic shape, inputs_shape may carry divisors rather than
          // the original shapes, so also report the cloned originals.
          Shapes origin_shapes = inputs_shape_clone_;
          if (strategy_ != nullptr) {  // if strategy_ is not null, means that check output strategy
            origin_shapes = outputs_shape_clone_;
          }
          MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape or divisor "
                        << shape_value << " at " << j << " cannot be divisible by strategy value " << strategy_value
                        << ", shape is " << ShapeToString(origin_shapes[i]) << ", divisor is "
                        << ShapeToString(sub_input_shape);
        } else {
          MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra) << ", shape " << shape_value
                        << " at " << j << " cannot be divisible by strategy value " << strategy_value << ", shape is "
                        << ShapeToString(sub_input_shape);
        }
        return FAILED;
      }

      // (5) power-of-two check via the n & (n - 1) trick; skipped when the
      // device number itself is not a power of two.
      if ((LongToUlong(strategy_value) & LongToUlong(strategy_value - 1)) != 0) {
        if ((g_device_manager->DeviceNum() & (g_device_manager->DeviceNum() - 1)) != 0) {
          MS_LOG(INFO) << name_
                       << ": The device num is not the power of 2, thus do not check the strategy as power of 2";
          continue;
        }
        MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(stra)
                      << ", the value of strategy must be the power of 2, but get " << strategy_value;
        return FAILED;
      }
    }
  }

  return SUCCESS;
}
192 
CheckStrategyValue(const StrategyPtr & strategy,const Shapes & inputs_shape)193 Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
194   if (strategy == nullptr) {
195     MS_LOG(ERROR) << name_ << ": The strategy is null.";
196     return FAILED;
197   }
198 
199   Strategies stra = strategy->GetInputDim();
200   return CheckStrategyByVector(stra, inputs_shape);
201 }
202 
// Clears all strategy-derived state so the operator can be re-initialized
// with a new strategy.
void OperatorInfo::ResetQueueMember() {
  // tensor infos / maps computed from the previous strategy
  inputs_tensor_info_.clear();
  outputs_tensor_info_.clear();
  outputs_tensor_map_.clear();
  out_dev_matrix_shape_.clear();
  // generated operators
  forward_op_.clear();
  mirror_ops_.clear();
  sub_ops_.clear();
  replace_op_.clear();
  replace_op_info_.clear();
  virtual_div_op_.clear();
  // when the layout was configured explicitly, the input tensor maps and the
  // device matrix come from the user and must survive a reset
  if (!is_layout_config_) {
    inputs_tensor_map_.clear();
    dev_matrix_shape_.clear();
  }
  strategy_ = nullptr;
  out_strategy_ = nullptr;
}
221 
// Validates the user-configured layout (dev_matrix_shape_ plus
// inputs_tensor_map_) and derives strategy_from_layout_ from it:
// - sizes of tensor maps must match the input shapes, per input and per dim;
// - each map value must be MAP_NONE (-1) or a valid dev-matrix dimension;
// - each mapped dimension must divide its input dimension;
// - maps whose shard number is 1 are normalized back to MAP_NONE.
Status OperatorInfo::CheckLayoutConfigBase() {
  // size
  if (inputs_tensor_map_.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_
                  << ": the size of inputs tensor map and inputs shape must be equal, but the inputs tensor map is "
                  << inputs_tensor_map_ << ", and the inputs shape is " << inputs_shape_;
    return FAILED;
  }

  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
    if (inputs_shape_[i].size() != inputs_tensor_map_[i].size()) {
      MS_LOG(ERROR) << name_
                    << ": the size of input tensor map and input shape must be equal, but the inputs tensor map is "
                    << inputs_tensor_map_ << ", and the inputs shape is " << inputs_shape_ << ", the " << i
                    << "th is not equal";
      return FAILED;
    }
  }

  size_t dev_matrix_size = dev_matrix_shape_.size();
  strategy_from_layout_.clear();

  for (size_t j = 0; j < inputs_tensor_map_.size(); ++j) {
    Shape tmp_strategy;
    for (size_t k = 0; k < inputs_tensor_map_[j].size(); ++k) {
      auto map = inputs_tensor_map_[j][k];

      // range: MAP_NONE means "not sharded" -> strategy value 1
      if (map == MAP_NONE) {
        tmp_strategy.push_back(NO_SPLIT_STRATEGY);
        continue;
      }

      // NOTE(review): if dev_matrix_shape_ is empty, (dev_matrix_size - 1)
      // underflows size_t in this log message — confirm the layout path
      // guarantees a non-empty dev matrix before this point.
      if (map < 0 || map >= SizeToLong(dev_matrix_size)) {
        MS_LOG(ERROR) << name_ << ": the range of tensor_map's value is [-1, " << (dev_matrix_size - 1)
                      << "], but the inputs tensor map is " << inputs_tensor_map_;
        return FAILED;
      }

      // divisible: tensor maps index the dev matrix from the right
      auto shard_num = dev_matrix_shape_[dev_matrix_size - LongToSize(map) - 1];
      MS_EXCEPTION_IF_ZERO("shard_num", shard_num);
      if (inputs_shape_[j][k] % shard_num != 0) {
        MS_LOG(ERROR) << name_ << ": the shape is not divisible by layout, the input shape is " << inputs_shape_
                      << ", the dev matrix is " << dev_matrix_shape_ << ", and the tensor map is "
                      << inputs_tensor_map_;
        return FAILED;
      }

      // if shard_num is 1, reset the map to -1
      if (shard_num == NO_SPLIT_STRATEGY) {
        inputs_tensor_map_[j][k] = MAP_NONE;
      }
      tmp_strategy.push_back(shard_num);
    }
    strategy_from_layout_.push_back(tmp_strategy);
  }

  MS_LOG(INFO) << name_ << ": the strategy from layout is " << strategy_from_layout_;
  return SUCCESS;
}
283 
GetLayoutConfig()284 Status OperatorInfo::GetLayoutConfig() {
285   auto layout_iter = attrs_.find(LAYOUT);
286   if (layout_iter == attrs_.end()) {
287     return SUCCESS;
288   }
289 
290   MS_EXCEPTION_IF_NULL(layout_iter->second);
291   auto layout = layout_iter->second;
292   if (!layout->isa<ValueDictionary>()) {
293     MS_LOG(ERROR) << name_ << ": the layout is not a dict";
294     return FAILED;
295   }
296 
297   auto dict = layout->cast<ValueDictionaryPtr>();
298   for (const auto &kv : dict->value()) {
299     ValuePtr key_ptr = kv.first;
300     ValuePtr value_ptr = kv.second;
301     MS_EXCEPTION_IF_NULL(key_ptr);
302     MS_EXCEPTION_IF_NULL(value_ptr);
303     if (!key_ptr->isa<StringImm>()) {
304       MS_LOG(EXCEPTION) << name_ << ": the value of key is not string";
305     }
306 
307     std::string key = key_ptr->cast<StringImmPtr>()->value();
308 
309     if (key == DEV_MATRIX) {
310       if (!value_ptr->isa<ValueTuple>()) {
311         MS_LOG(ERROR) << name_ << ": the type of dev matrix is not tuple";
312         return FAILED;
313       }
314 
315       dev_matrix_shape_ = GetValue<std::vector<int64_t>>(value_ptr);
316       auto used_devices =
317         std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
318       if (used_devices != stage_device_size_) {
319         MS_LOG(ERROR) << name_
320                       << ": the product of dev matrix must be equal to the stage divece size, but the dev matrix is "
321                       << dev_matrix_shape_ << ", but the stage device size is " << stage_device_size_;
322         return FAILED;
323       }
324       continue;
325     }
326 
327     if (key == INPUT_TENSOR_MAP) {
328       auto var = value_ptr->cast<ValueTuplePtr>();
329       if (!value_ptr->isa<ValueTuple>()) {
330         MS_LOG(ERROR) << name_ << ": the type of input_tensor_map is not tuple";
331         return FAILED;
332       }
333 
334       std::vector<ValuePtr> elements = var->value();
335       for (const auto &ele : elements) {
336         Shape tensor_map;
337         if (ele->isa<ValueSequence>()) {
338           auto value_tuple = ele->cast<ValueTuplePtr>();
339           std::vector<ValuePtr> value_vector = value_tuple->value();
340           (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(tensor_map),
341                                [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
342           inputs_tensor_map_.push_back(tensor_map);
343         } else {
344           MS_LOG(ERROR) << name_ << ": the format of input tensor map is wrong! Need ValueSequence";
345           return FAILED;
346         }
347       }
348       continue;
349     }
350 
351     MS_LOG(ERROR) << name_ << ": the invalid key for layout: " << key;
352     return FAILED;
353   }
354 
355   MS_LOG(INFO) << name_ << ": the dev matrix is " << dev_matrix_shape_ << ", the tensor map is " << inputs_tensor_map_;
356 
357   is_layout_config_ = true;
358 
359   return CheckLayoutConfigBase();
360 }
361 
IsDynamicShape()362 bool OperatorInfo::IsDynamicShape() {
363   for (auto &input_shape : inputs_shape_) {
364     auto in_it = std::find_if(input_shape.cbegin(), input_shape.cend(), [&](const int64_t ele) { return ele == -1; });
365     if (in_it != input_shape.end()) {
366       return True;
367     }
368   }
369 
370   for (auto &output_shape : outputs_shape_) {
371     auto out_it =
372       std::find_if(output_shape.cbegin(), output_shape.cend(), [&](const int64_t ele) { return ele == -1; });
373     if (out_it != output_shape.end()) {
374       return True;
375     }
376   }
377   return False;
378 }
379 
IsDynamicRank()380 bool OperatorInfo::IsDynamicRank() {
381   for (auto &input_shape : inputs_shape_) {
382     auto in_it = std::find_if(input_shape.cbegin(), input_shape.cend(), [&](const int64_t ele) { return ele == -2; });
383     if (in_it != input_shape.end()) {
384       return True;
385     }
386   }
387 
388   for (auto &output_shape : outputs_shape_) {
389     auto out_it =
390       std::find_if(output_shape.cbegin(), output_shape.cend(), [&](const int64_t ele) { return ele == -2; });
391     if (out_it != output_shape.end()) {
392       return True;
393     }
394   }
395   return False;
396 }
397 
IsSelfDefineShard()398 bool OperatorInfo::IsSelfDefineShard() {
399   bool self_define_shard = false;
400   auto attr_iter = attrs_.find(parallel::SELF_DEFINE_SHARD);
401   if (attr_iter != attrs_.end()) {
402     self_define_shard = attr_iter->second->cast<BoolImmPtr>()->value();
403   }
404   return self_define_shard;
405 }
406 
GetRepeatedNumInDevMatrixRight()407 Status OperatorInfo::GetRepeatedNumInDevMatrixRight() {
408   bool repeated_num_right = true;
409   auto iter = attrs_.find(REPEATED_NUM_IN_DEV_MATRIX_RIGHT);
410   if (iter != attrs_.end()) {
411     MS_EXCEPTION_IF_NULL(iter->second);
412     if (iter->second->isa<BoolImm>()) {
413       repeated_num_right = iter->second->cast<BoolImmPtr>()->value();
414       MS_LOG(INFO) << name_ << ": attr " << REPEATED_NUM_IN_DEV_MATRIX_RIGHT << " will be set to "
415                    << repeated_num_right;
416     } else {
417       MS_LOG(ERROR) << name_ << ": The value of " << REPEATED_NUM_IN_DEV_MATRIX_RIGHT << " is not bool.";
418       return FAILED;
419     }
420   }
421   repeated_num_in_dev_matrix_right_ = repeated_num_right;
422   return SUCCESS;
423 }
424 
// Runs the attribute-inference pipeline once per operator: operator attrs,
// the repeated-num placement flag, the optional layout config, dynamic-shape
// flags, and shape clones. Idempotent via infer_attrs_completed_.
// Returns FAILED on the first failing step, or when dynamic rank is found
// (unsupported). The order of the calls matters: layout parsing and its
// check depend on GetAttrs() having run.
Status OperatorInfo::InferAttrs() {
  if (infer_attrs_completed_) {
    return SUCCESS;
  }

  if (GetAttrs() != SUCCESS) {
    return FAILED;
  }

  if (GetRepeatedNumInDevMatrixRight() != SUCCESS) {
    return FAILED;
  }

  if (GetLayoutConfig() != SUCCESS) {
    return FAILED;
  }

  // CheckLayoutConfig is the operator-specific check, only meaningful when a
  // layout was actually configured.
  if (is_layout_config_ && CheckLayoutConfig() != SUCCESS) {
    return FAILED;
  }

  self_define_shard_ = IsSelfDefineShard();
  is_dynamic_shape_ = IsDynamicShape();
  is_dynamic_rank_ = IsDynamicRank();
  if (is_dynamic_rank_) {
    MS_LOG(ERROR) << name_
                  << ": it does not support dynamic rank now, the inupts' shape: " << ShapesToString(inputs_shape_)
                  << ", the outputs' shape: " << ShapesToString(outputs_shape_);
    return FAILED;
  }

  // Keep pristine copies: inputs_shape_/outputs_shape_ may be rewritten later
  // (e.g. with divisors under dynamic shape); see CheckStrategyByVector.
  inputs_shape_clone_ = inputs_shape_;
  outputs_shape_clone_ = outputs_shape_;

  infer_attrs_completed_ = true;
  return SUCCESS;
}
462 
// Builds one mirror-operator vector per input, based on the communication
// group derived from each input's tensor map. Inputs whose group is empty
// get an empty OperatorVector placeholder so indices stay aligned; if no
// input needs mirroring at all, mirror_ops_ is left empty.
Status OperatorInfo::InferMirrorOps() {
  mirror_ops_.clear();
  if (inputs_shape_.empty()) {
    MS_LOG(INFO) << name_ << ": The inputs size is empty";
    return SUCCESS;
  }

  if (inputs_tensor_map_.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": The size of inputs tensor map is not equal to the size of inputs shape";
    return FAILED;
  }

  bool group_is_empty = true;
  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
    std::vector<Group> group;
    if (CreateGroupByTensorMap(inputs_tensor_map_[i], &group) != SUCCESS) {
      ReportError(name_ + ": Create group failed, the input index is " + std::to_string(i));
      // leave no partial state behind on failure
      mirror_ops_.clear();
      return FAILED;
    }

    OperatorVector mirror_op;
    if (group.empty()) {
      // placeholder keeps mirror_ops_ aligned with input indices
      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
      mirror_ops_.push_back(mirror_op);
      continue;
    }

    group_is_empty = false;
    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
    mirror_ops_.push_back(mirror_op);
  }

  if (group_is_empty) {
    mirror_ops_.clear();
    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
  }
  return SUCCESS;
}
502 
// Layout-based variant of InferMirrorOps: derives each input's repeated
// group from its tensor layout instead of from the tensor map. A
// single-rank repeated group means no mirroring for that input.
Status OperatorInfo::InferMirrorOpsByLayout() {
  mirror_ops_.clear();
  if (inputs_shape_.empty()) {
    MS_LOG(INFO) << name_ << ": The inputs size is empty";
    return SUCCESS;
  }

  bool group_is_empty = true;
  for (size_t i = 0; i < inputs_tensor_info_.size(); ++i) {
    auto input_tensor_layout = inputs_tensor_info_[i].tensor_layout();
    auto repeated_rank_list = input_tensor_layout.InferRepeatedGroup();

    OperatorVector mirror_op;
    if (repeated_rank_list.size() == 1) {
      // placeholder keeps mirror_ops_ aligned with input indices
      MS_LOG(INFO) << name_ << ": The mirror group is empty, the input index is " << i;
      mirror_ops_.push_back(mirror_op);
      continue;
    }
    // NOTE(review): in auto-parallel mode this returns (FAILED or SUCCESS)
    // from inside the loop on the FIRST input that needs a group, skipping
    // the remaining inputs — confirm this early exit is intended for the
    // strategy-search phase.
    if (is_auto_parallel_) {
      if (g_device_manager->CheckDeviceList(repeated_rank_list) != SUCCESS) {
        MS_LOG(INFO) << name_ << ": Try to create communication group : " << repeated_rank_list
                     << " failed in auto parallel mode, "
                        "this error can be ignored in parallel strategies searching step";
        return FAILED;
      }
      return SUCCESS;
    }

    Group mirror_group;
    if (g_device_manager->CreateGroup(repeated_rank_list, &mirror_group) != SUCCESS) {
      MS_LOG(ERROR) << name_
                    << ": Create communication group by tensor_map failed, the rank_list is: " << repeated_rank_list
                    << ", the full_name of node is: " << cnode_->fullname_with_scope();
      return FAILED;
    }
    group_is_empty = false;
    mirror_op = CreateMirrorOps(mirror_group.name(), mirror_group.GetDevNum());
    mirror_ops_.push_back(mirror_op);
  }

  if (group_is_empty) {
    mirror_ops_.clear();
    MS_LOG(INFO) << name_ << ": No need to insert mirror ops";
  }
  return SUCCESS;
}
549 
CreateTensorInfo(const Shape & device_matrix,const ShapeBasePtr & inputs_shape,const ShapeBasePtr & inputs_tensor_map)550 TensorInfoBasePtr CreateTensorInfo(const Shape &device_matrix, const ShapeBasePtr &inputs_shape,
551                                    const ShapeBasePtr &inputs_tensor_map) {
552   TensorInfoBasePtr out_tensor_info;
553   if (inputs_shape->is_list()) {
554     std::vector<TensorInfoBasePtr> tensor_info_list;
555     for (int64_t i = 0; i < SizeToLong(inputs_shape->size()); ++i) {
556       auto tensor_map = inputs_tensor_map->GetElement(i);
557       auto shape = inputs_shape->GetElement(i);
558       auto input_tensor_info = CreateTensorInfo(device_matrix, shape, tensor_map);
559       tensor_info_list.emplace_back(input_tensor_info);
560     }
561     out_tensor_info = std::make_shared<TensorInfoList>(tensor_info_list);
562   } else {
563     TensorLayout input_layout;
564     input_layout.InitFromVector(device_matrix, inputs_tensor_map->GetValue(), inputs_shape->GetValue());
565     TensorInfo input_tensor_info(input_layout);
566     out_tensor_info = std::make_shared<TensorInfoValue>(input_tensor_info);
567   }
568   return out_tensor_info;
569 }
570 
// Tuple-aware variant of InferTensorInfo: builds inputs_tensor_info_new_ /
// outputs_tensor_info_new_ from the "new" (possibly nested) shapes and
// tensor maps. Optional inputs (value None) get empty placeholder
// TensorInfos so indices keep matching the real input positions.
Status OperatorInfo::InferTensorInfoNew() {
  size_t real_input_index = 0;
  for (size_t i = 0; i < inputs_tensor_map_new_.size(); ++i) {
    // Insert placeholder TensorInfo for optional input
    while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
           input_value_[real_input_index]->isa<None>()) {
      (void)inputs_tensor_info_new_.emplace_back(std::make_shared<TensorInfoValue>(TensorInfo()));
      ++real_input_index;
    }
    auto input_tensor_info = CreateTensorInfo(dev_matrix_shape_, inputs_shape_new_[i], inputs_tensor_map_new_[i]);
    inputs_tensor_info_new_.emplace_back(input_tensor_info);
    ++real_input_index;
  }

  for (size_t i = 0; i < outputs_tensor_map_new_.size(); ++i) {
    auto output_tensor_info = CreateTensorInfo(dev_matrix_shape_, outputs_shape_new_[i], outputs_tensor_map_new_[i]);
    outputs_tensor_info_new_.emplace_back(output_tensor_info);
  }
  return SUCCESS;
}
591 
// When at least one input layout uses fine-grained interleaved parallel but
// no output layout does, propagates the interleave to output 0 by rewriting
// the last dimension of its device arrangement to the configured
// interleaved size. No-op otherwise.
void OperatorInfo::UpdateOutputTensorInfoForInterleaved() {
  // no interleaved input -> nothing to propagate
  if (!std::any_of(inputs_tensor_info_.begin(), inputs_tensor_info_.end(), [](const TensorInfo &input_tensor_info) {
        return input_tensor_info.tensor_layout().IsInterleavedParallel();
      })) {
    return;
  }
  // outputs already interleaved -> nothing to do
  if (std::any_of(outputs_tensor_info_.begin(), outputs_tensor_info_.end(), [](const TensorInfo &output_tensor_info) {
        return output_tensor_info.tensor_layout().IsInterleavedParallel();
      })) {
    return;
  }
  auto interleaved_num = ParallelContext::GetInstance()->fine_grained_micro_interleaved_size();
  // overwrite the last (innermost) dev-matrix dim of output 0 with the
  // interleave count
  auto output_dev_matrix = outputs_tensor_info_[kIndex0].tensor_layout().device_arrangement_origin().array();
  output_dev_matrix[output_dev_matrix.size() - 1] = interleaved_num;
  Arrangement out_device_arrangement_interleaved;
  out_device_arrangement_interleaved.Init(output_dev_matrix);
  auto new_tensor_layout = outputs_tensor_info_[kIndex0].tensor_layout();
  new_tensor_layout.set_device_arrangement_interleaved(out_device_arrangement_interleaved);
  TensorInfo new_output_tensor_info(new_tensor_layout);
  outputs_tensor_info_[kIndex0] = new_output_tensor_info;
}
613 
// Builds inputs_tensor_info_ and outputs_tensor_info_ from the device
// matrix, the tensor maps and the shapes. Delegates to InferTensorInfoNew()
// when nested ("new") shapes are present. Optional inputs (value None) get
// default-constructed placeholder TensorInfos to keep index alignment.
Status OperatorInfo::InferTensorInfo() {
  if (!inputs_shape_new_.empty()) {
    return InferTensorInfoNew();
  }
  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  size_t real_input_index = 0;
  for (size_t i = 0; i < inputs_tensor_map_.size(); ++i) {
    // Insert placeholder TensorInfo for optional input
    while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
           input_value_[real_input_index]->isa<None>()) {
      (void)inputs_tensor_info_.emplace_back(TensorInfo());
      ++real_input_index;
    }
    TensorLayout input_layout;
    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed, the index is " << i;
      return FAILED;
    }
    TensorInfo input_tensor_info(input_layout);
    inputs_tensor_info_.push_back(input_tensor_info);
    ++real_input_index;
  }

  for (size_t i = 0; i < outputs_tensor_map_.size(); ++i) {
    TensorLayout output_layout;
    if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[i], outputs_shape_[i]) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed, the index is " << i;
      return FAILED;
    }
    TensorInfo output_tensor_info(output_layout);
    outputs_tensor_info_.push_back(output_tensor_info);
  }

  return SUCCESS;
}
653 
InferRepeatedCalcInfo()654 Status OperatorInfo::InferRepeatedCalcInfo() {
655   int64_t g_dev_list_size = stage_device_size_;
656   int64_t dev_matrix_size =
657     std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
658   if (dev_matrix_size == 0) {
659     MS_LOG(ERROR) << name_ << ": The dev matrix size is 0";
660     return FAILED;
661   }
662 
663   if (g_dev_list_size == dev_matrix_size) {
664     repeated_calc_num_ = 1;
665   } else if (g_dev_list_size % dev_matrix_size == 0) {
666     repeated_calc_num_ = g_dev_list_size / dev_matrix_size;
667   } else {
668     MS_LOG(ERROR) << name_ << ": The strategy is " << StrategyToString(strategy_->GetInputDim()) << ", it requires "
669                   << dev_matrix_size << " devices, "
670                   << "but the device number of this stage is " << g_dev_list_size << ", it can not be divisible by "
671                   << dev_matrix_size;
672     return FAILED;
673   }
674   return SUCCESS;
675 }
676 
677 // If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix in default,
678 // because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num
679 // to the last dimension of dev-matrix, there is no need to redistribution.
SetRepeatedCalcDevMatrix()680 void OperatorInfo::SetRepeatedCalcDevMatrix() {
681   if (repeated_calc_num_ <= 1) {
682     return;
683   }
684   if (repeated_num_in_dev_matrix_right_) {
685     dev_matrix_shape_.push_back(repeated_calc_num_);
686   } else {
687     (void)dev_matrix_shape_.insert(dev_matrix_shape_.cbegin(), repeated_calc_num_);
688   }
689 }
690 
ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps * tensor_map_new)691 void OperatorInfo::ResetTupleTensorMapIfRepeatedCalc(NewTensorMaps *tensor_map_new) {
692   MS_EXCEPTION_IF_NULL(tensor_map_new);
693   for (auto &tensor_map : *tensor_map_new) {
694     if (tensor_map->is_list()) {
695       std::vector<ShapeBasePtr> new_list;
696       for (auto &elements : tensor_map->GetAllElements()) {
697         std::vector<int64_t> new_shape;
698         for (auto &element : elements) {
699           if (element != MAP_NONE) {
700             element += 1;
701           }
702           new_shape.emplace_back(element);
703         }
704         new_list.emplace_back(std::make_shared<ShapeValue>(new_shape));
705       }
706       tensor_map->set_shape(std::make_shared<ShapeList>(new_list));
707     } else {
708       std::vector<int64_t> new_shape;
709       for (auto &element : tensor_map->GetValue()) {
710         if (element != MAP_NONE) {
711           element += 1;
712         }
713         new_shape.emplace_back(element);
714       }
715       tensor_map->set_shape(std::make_shared<ShapeValue>(new_shape));
716     }
717   }
718 }
719 
720 // If repeated calculation, and the repeated_calc_num is inserted to the last dimension of the dev-matrix,
721 // the index value of tensor map needs to be increased by 1.
ResetTensorMapIfRepeatedCalc()722 void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
723   if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) {
724     return;
725   }
726 
727   MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
728   for (auto &tensor_map : inputs_tensor_map_) {
729     for (auto &element : tensor_map) {
730       if (element == MAP_NONE) {
731         continue;
732       }
733       element += 1;
734     }
735   }
736 
737   for (auto &tensor_map : outputs_tensor_map_) {
738     for (auto &element : tensor_map) {
739       if (element == MAP_NONE) {
740         continue;
741       }
742       element += 1;
743     }
744   }
745 
746   ResetTupleTensorMapIfRepeatedCalc(&inputs_tensor_map_new_);
747   ResetTupleTensorMapIfRepeatedCalc(&outputs_tensor_map_new_);
748 }
749 
750 // use for loss repeated calculation
CreateVirtualDivOp(int64_t div_num)751 Operator CreateVirtualDivOp(int64_t div_num) {
752   OperatorName operator_name = VIRTUAL_DIV;
753   ValuePtr attr0_value = MakeValue(div_num);
754   Attr attr0 = std::make_pair(DIVISOR, attr0_value);
755   OperatorAttrs operator_attrs;
756   operator_attrs.push_back(attr0);
757 
758   OperatorParams operator_param;
759   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
760 
761   Operator op = std::make_pair(operator_name, operator_arg);
762   return op;
763 }
764 
CreateDivOp(float scale)765 Operator CreateDivOp(float scale) {
766   OperatorName operator_name = REAL_DIV;
767   OperatorAttrs operator_attrs;
768   OperatorParams operator_param;
769   constexpr size_t parameter_pos = 2;
770   mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(scale);
771   ValuePtr scale_value = MakeValue(tensor_ptr);
772   (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
773   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
774 
775   Operator op = std::make_pair(operator_name, operator_arg);
776   return op;
777 }
778 
CreateScalarFloorDivOp(int64_t div_num)779 Operator CreateScalarFloorDivOp(int64_t div_num) {
780   OperatorName operator_name = SCALAR_FLOOR_DIV;
781   OperatorAttrs operator_attrs;
782   OperatorParams operator_param;
783   constexpr size_t parameter_pos = 2;
784   ValuePtr scale_value = MakeValue(div_num);
785   (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
786   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
787 
788   Operator op = std::make_pair(operator_name, operator_arg);
789   return op;
790 }
791 
CreateScalarDivOp(int64_t div_num)792 Operator CreateScalarDivOp(int64_t div_num) {
793   OperatorName operator_name = SCALAR_DIV;
794   OperatorAttrs operator_attrs;
795   OperatorParams operator_param;
796   constexpr size_t parameter_pos = 2;
797   ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(div_num));
798   (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
799   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
800 
801   Operator op = std::make_pair(operator_name, operator_arg);
802   return op;
803 }
804 
CreateScalarMulOp(int64_t scalar)805 Operator CreateScalarMulOp(int64_t scalar) {
806   OperatorName operator_name = SCALAR_MUL;
807   OperatorAttrs operator_attrs;
808   OperatorParams operator_param;
809   constexpr size_t parameter_pos = 2;
810   ValuePtr scale_value = MakeValue(std::make_shared<Int64Imm>(scalar));
811   (void)operator_param.emplace_back(std::make_pair(std::make_pair(Y, scale_value), parameter_pos));
812   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
813 
814   Operator op = std::make_pair(operator_name, operator_arg);
815   return op;
816 }
817 
CreateReduceCommunicationOpArgs(const std::string & reduce_op,const std::string & group)818 static OperatorArgs CreateReduceCommunicationOpArgs(const std::string &reduce_op, const std::string &group) {
819   ValuePtr attr0_value = MakeValue(reduce_op);
820   ValuePtr attr1_value = MakeValue(group);
821   Attr attr0 = std::make_pair(OP, attr0_value);
822   Attr attr1 = std::make_pair(GROUP, attr1_value);
823   OperatorAttrs operator_attrs;
824   operator_attrs.push_back(attr0);
825   operator_attrs.push_back(attr1);
826 
827   OperatorParams operator_param;
828   return std::make_pair(operator_attrs, operator_param);
829 }
830 
831 // use for forward all reduce
CreateAllReduceOp(const std::string & reduce_op,const std::string & group)832 Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) {
833   OperatorName operator_name = ALL_REDUCE;
834   OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
835 
836   Operator op = std::make_pair(operator_name, operator_arg);
837   MS_LOG(INFO) << "Create all reduce op success, the reduce_op is  " << reduce_op << ", the group is " << group;
838   return op;
839 }
840 
CreateReduceScatterOp(const std::string & reduce_op,const std::string & group)841 Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string &group) {
842   OperatorName operator_name = REDUCE_SCATTER;
843   OperatorArgs operator_arg = CreateReduceCommunicationOpArgs(reduce_op, group);
844 
845   Operator op = std::make_pair(operator_name, operator_arg);
846   MS_LOG(INFO) << "Create reduce scatter op success, the reduce_op is  " << reduce_op << ", the group is " << group;
847   return op;
848 }
849 
CreateCastOp(TypePtr type)850 Operator CreateCastOp(TypePtr type) {
851   auto type_id = MakeValue(static_cast<int64_t>(type->type_id()));
852   Param param_type = std::make_pair(std::make_pair(DTYPE, type_id), 2);
853   OperatorAttrs attrs;
854   OperatorParams params = {param_type};
855   OperatorArgs args = std::make_pair(attrs, params);
856   Operator op_cast = std::make_pair(CAST, args);
857   return op_cast;
858 }
859 
CreateScalarCastOp(TypePtr type)860 Operator CreateScalarCastOp(TypePtr type) {
861   auto type_id = MakeValue(static_cast<int64_t>(type->type_id()));
862   Param param_type = std::make_pair(std::make_pair(DTYPE, type_id), 2);
863   OperatorAttrs attrs;
864   OperatorParams params = {param_type};
865   OperatorArgs args = std::make_pair(attrs, params);
866   Operator op_cast = std::make_pair(SCALAR_CAST, args);
867   return op_cast;
868 }
869 
AddCommOpFusionType(const CNodePtr & comm_node,const AnfNodePtr & param_node)870 int32_t AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &param_node) {
871   MS_EXCEPTION_IF_NULL(comm_node);
872   MS_EXCEPTION_IF_NULL(param_node);
873   ParameterPtr param;
874   if (IsPrimitiveCNode(param_node, prim::kPrimReceive)) {
875     param = param_node->user_data<AnfNode>(PIPELINE_PARAM)->cast<ParameterPtr>();
876   } else {
877     param = param_node->cast<ParameterPtr>();
878   }
879   MS_EXCEPTION_IF_NULL(param);
880   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
881   MS_EXCEPTION_IF_NULL(prim);
882   auto attrs = prim->attrs();
883   auto param_info = param->param_info();
884   if (!param_info) {
885     MS_LOG(WARNING) << param->ToString() << "does not have parameter info.";
886     return 0;
887   }
888   int32_t fusion_type = param_info->comm_fusion();
889   attrs[FUSION] = MakeValue<int64_t>(fusion_type);
890   (void)prim->SetAttrs(attrs);
891   bool parallel_optimizer_comm_recompute = param_info->parallel_optimizer_comm_recompute();
892   std::string instance_name = prim->instance_name();
893   if (instance_name == PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE && parallel_optimizer_comm_recompute &&
894       prim->name() == ALL_GATHER) {
895     prim->set_attr(RECOMPUTE, MakeValue(true));
896     prim->set_instance_name(PARALLEL_OPTIMIZER_ALLGATHER);
897   }
898   MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
899   return fusion_type;
900 }
901 
AddCommOpMeanFlag(const CNodePtr & comm_node)902 void AddCommOpMeanFlag(const CNodePtr &comm_node) {
903   MS_EXCEPTION_IF_NULL(comm_node);
904   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
905   auto attrs = prim->attrs();
906   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
907   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
908   attrs[MEAN_FLAG] = MakeValue<bool>(mean_flag);
909   (void)prim->SetAttrs(attrs);
910 }
911 
AddCNodePrimAttr(const CNodePtr & comm_node,const std::string & attr_name,const ValuePtr & attr_val)912 void AddCNodePrimAttr(const CNodePtr &comm_node, const std::string &attr_name, const ValuePtr &attr_val) {
913   MS_EXCEPTION_IF_NULL(comm_node);
914   auto prim = GetValueNode<PrimitivePtr>(comm_node->input(0));
915   auto attrs = prim->attrs();
916   attrs[attr_name] = attr_val;
917   (void)prim->SetAttrs(attrs);
918 }
919 
AddCommOpParamFlag(const CNodePtr & comm_node)920 void AddCommOpParamFlag(const CNodePtr &comm_node) {
921   MS_EXCEPTION_IF_NULL(comm_node);
922   auto graph = comm_node->func_graph();
923   MS_EXCEPTION_IF_NULL(graph);
924   auto manager = graph->manager();
925   MS_EXCEPTION_IF_NULL(manager);
926   auto node_users = manager->node_users()[comm_node->input(1)];
927   for (auto &node_user : node_users) {
928     if (IsPrimitiveCNode(node_user.first, prim::kPrimSend)) {
929       auto prim = GetCNodePrimitive(comm_node);
930       (void)prim->AddAttr(PARAMETER_MICRO, MakeValue(0));
931       return;
932     }
933   }
934 }
935 
CreateAllGatherOp(const std::string & group)936 Operator CreateAllGatherOp(const std::string &group) {
937   OperatorName operator_name = ALL_GATHER;
938   // group
939   ValuePtr attr0_value = MakeValue(group);
940   Attr attr0 = std::make_pair(GROUP, attr0_value);
941   OperatorAttrs operator_attrs;
942   operator_attrs.push_back(attr0);
943 
944   OperatorParams operator_param;
945   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
946 
947   Operator op = std::make_pair(operator_name, operator_arg);
948   MS_LOG(INFO) << "Create allgather op success, the group is " << group;
949   return op;
950 }
951 
CreateMicroStepAllGatherOp(const std::string & group)952 Operator CreateMicroStepAllGatherOp(const std::string &group) {
953   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
954   OperatorName operator_name = MICRO_STEP_ALL_GATHER;
955   // group
956   ValuePtr attr0_value = MakeValue(group);
957   Attr attr0 = std::make_pair(GROUP, attr0_value);
958   // mean_flag
959   ValuePtr attr1_value = MakeValue(mean_flag);
960   Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value);
961   OperatorAttrs operator_attrs;
962   operator_attrs.push_back(attr0);
963   operator_attrs.push_back(attr1);
964 
965   OperatorParams operator_param;
966   OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param);
967 
968   Operator op = std::make_pair(operator_name, operator_arg);
969   MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group;
970   return op;
971 }
972 
973 // use for get tensor slice
CreateGetTensorSliceOp(const TensorLayout & tensor_layout)974 Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) {
975   Shape tensor_map = tensor_layout.tensor_map().array();
976   Shape dev_matrix_shape = tensor_layout.device_arrangement().array();
977   Shape slice_shape = tensor_layout.base_slice_shape().array();
978   Shape full_shape = tensor_layout.tensor_shape().array();
979   OperatorName operator_name = GET_TENSOR_SLICE;
980 
981   OperatorAttrs attrs;
982   ValuePtr dev_mat_value = MakeValue(dev_matrix_shape);
983   Param dev_mat_param = std::make_pair(std::make_pair(DEV_MAT, dev_mat_value), 2);
984   ValuePtr tensor_map_value = MakeValue(tensor_map);
985   Param tensor_map_param = std::make_pair(std::make_pair(TENSOR_MAP, tensor_map_value), 3);
986   ValuePtr slice_shape_value = MakeValue(slice_shape);
987   Param slice_shape_param = std::make_pair(std::make_pair(SLICE_SHAPE, slice_shape_value), 4);
988   ValuePtr full_shape_value = MakeValue(full_shape);
989   Param full_shape_param = std::make_pair(std::make_pair(FULL_SHAPE, full_shape_value), 5);
990   OperatorParams params = {dev_mat_param, tensor_map_param, slice_shape_param, full_shape_param};
991   OperatorArgs operator_arg = std::make_pair(attrs, params);
992 
993   Operator op = std::make_pair(operator_name, operator_arg);
994   MS_LOG(INFO) << "Create get tensor slice op success, the dev mat and tensor map is "
995                << ShapeToString(dev_matrix_shape) << ", " << ShapeToString(tensor_map);
996   return op;
997 }
998 
CreateMirrorOps(const std::string & group_name,size_t dev_num)999 OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
1000   if (dev_num == 0) {
1001     MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num;
1002   }
1003   OperatorVector op_for_weight;
1004   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
1005   int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
1006   int64_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num();
1007 
1008   ValuePtr attr0_value = MakeValue(group_name);
1009   ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
1010   ValuePtr attr2_value = MakeValue(mean_flag);
1011 
1012   Attr attr0 = std::make_pair(GROUP, attr0_value);
1013   Attr attr1 = std::make_pair(DEV_NUM, attr1_value);
1014   Attr attr2 = std::make_pair(MEAN_FLAG, attr2_value);
1015 
1016   OperatorAttrs operator_attrs;
1017   operator_attrs.push_back(attr0);
1018   operator_attrs.push_back(attr1);
1019   operator_attrs.push_back(attr2);
1020 
1021   OperatorName operator_name;
1022   if (grad_accumulation_step > 1 || split_stage_num > 1) {
1023     operator_name = MIRROR_MICRO_STEP_OPERATOR;
1024   } else {
1025     operator_name = MIRROR_OPERATOR;
1026   }
1027 
1028   OperatorParams operator_param;
1029   OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
1030 
1031   Operator op = std::make_pair(operator_name, operator_args);
1032 
1033   op_for_weight.push_back(op);
1034   MS_LOG(INFO) << "The group name is " << group_name << ", the dev num is " << dev_num << ", the mean flag is "
1035                << mean_flag;
1036   return op_for_weight;
1037 }
1038 
Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group) {
  // Create (at most one) communication group containing the devices selected
  // from this op's device matrix by tensor_map, and append it to *group.
  // Returns SUCCESS with *group untouched when no group is needed.
  if (group == nullptr) {
    MS_LOG(ERROR) << name_ << ": The group is null.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    return FAILED;
  }

  // A single-device "group" needs no communication — except when optimizer
  // parallel is combined with grad accumulation or pipeline stages, in which
  // case the group must still be created below.
  if (group_devices.size() == 1 && !((ParallelContext::GetInstance()->grad_accumulation_step() > 1 ||
                                      ParallelContext::GetInstance()->pipeline_stage_split_num() > 1) &&
                                     ParallelContext::GetInstance()->enable_parallel_optimizer())) {
    MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  if (is_auto_parallel_) {
    // During strategy search, only validate the device list; do not actually
    // create the group (and do not append to *group).
    if (g_device_manager->CheckDeviceList(group_devices) != SUCCESS) {
      MS_LOG(INFO) << name_ << ": Try to create communication group : " << group_devices
                   << " failed in auto parallel mode, "
                      "this error can be ignored in parallel strategies searching step";
      return FAILED;
    }
    return SUCCESS;
  }

  Group g;
  if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create communication group by tensor_map failed, the rank_list is: " << group_devices
                  << ", the input strategy is " << strategy_->GetInputDim()
                  << ", the full_name of node is: " << cnode_->fullname_with_scope();
    return FAILED;
  }
  group->push_back(g);
  return SUCCESS;
}
1078 
Status OperatorInfo::CreateGroupForOptShard(TensorLayout *tensor_layout, std::vector<Group> *groups) {
  // Create the allgather (and, when the weight is only partially sharded, the
  // mirror) communication groups for parallel-optimizer weight sharding, and
  // record them — plus the shard size/step — in tensor_layout.
  if (groups == nullptr) {
    MS_LOG(ERROR) << name_ << ": The group is null.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, tensor_layout->device_arrangement_origin().array());
  RankList group_devices;
  Shape tensor_map = tensor_layout->origin_tensor_map().array();
  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
    return FAILED;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  // repeated_size = number of devices holding a full replica of this weight.
  int64_t repeated_size = SizeToLong(group_devices.size());
  int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size();
  MS_EXCEPTION_IF_ZERO("optimizer_weight_shard_size", optimizer_weight_shard_size);
  if (optimizer_weight_shard_size != -1 && repeated_size > optimizer_weight_shard_size) {
    // not fully use opt shard
    int64_t index = std::find(group_devices.begin(), group_devices.end(), rank) - group_devices.begin();
    if (repeated_size % optimizer_weight_shard_size != 0 || repeated_size < optimizer_weight_shard_size) {
      MS_LOG(WARNING) << "Parallel optimizer:"
                      << " optimizer_weight_shard_size " << optimizer_weight_shard_size
                      << " can not be applied for the parameter used by" << cnode_->fullname_with_scope()
                      << " The data parallel group size is " << repeated_size;
      return FAILED;
    }
    repeated_size = repeated_size / optimizer_weight_shard_size;
    // create allgather group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 8], [16, 24]
    // Take the contiguous chunk of optimizer_weight_shard_size ranks that
    // contains this rank's index.
    RankList new_group_devices(
      group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size,
      group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size);
    Group allgather_group;
    if (g_device_manager->CreateGroup(new_group_devices, &allgather_group) != SUCCESS) {
      MS_LOG(ERROR) << name_
                    << ": Create communication group for allgather in optimizer parallel failed,"
                       " the rank_list is: "
                    << group_devices << ", the input strategy is " << strategy_->GetInputDim()
                    << ", the full_name of node is: " << cnode_->fullname_with_scope();
      return FAILED;
    }
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << name_ << ": Parallel optimizer, create allgather group " << allgather_group.name();
    // create mirror group
    // eg: optimizer_weight_shard_size = 2, [0, 8, 16, 24] -> [0, 16], [8, 24]
    // Build a (repeated_size x device_num/repeated_size) matrix and group
    // along dim 0 to pick the ranks holding the same shard.
    int64_t device_num = g_device_manager->stage_device_num();
    MS_EXCEPTION_IF_ZERO("repeated_size", repeated_size);
    Shape dev_mat = {repeated_size, device_num / repeated_size};
    DeviceMatrix temp_dev_matrix(rank, stage_device_list_, dev_mat);
    RankList mirror_group_devices;
    if (temp_dev_matrix.GetDevicesAlongDim(0, &mirror_group_devices) != SUCCESS) {
      return FAILED;
    }
    Group mirror_group;
    if (g_device_manager->CreateGroup(mirror_group_devices, &mirror_group) != SUCCESS) {
      MS_LOG(ERROR) << name_
                    << ": Create communication group for mirror in optimizer parallel failed,"
                       " the rank_list is: "
                    << group_devices << ", the input strategy is " << strategy_->GetInputDim()
                    << ", the full_name of node is: " << cnode_->fullname_with_scope();
      return FAILED;
    }
    groups->push_back(mirror_group);
    tensor_layout->set_opt_shard_mirror_group(mirror_group.name());
    MS_LOG(INFO) << name_ << ": Parallel optimizer, create mirror group " << mirror_group.name();
  } else {
    // fully use opt shard
    // create allgather group
    Group allgather_group;
    if (g_device_manager->CreateGroup(group_devices, &allgather_group) != SUCCESS) {
      MS_LOG(ERROR) << name_
                    << ": Create communication group for allgather in optimizer parallel failed,"
                       " the rank_list is: "
                    << group_devices << ", the input strategy is " << strategy_->GetInputDim()
                    << ", the full_name of node is: " << cnode_->fullname_with_scope();
      return FAILED;
    }
    groups->push_back(allgather_group);
    tensor_layout->set_opt_shard_group(allgather_group.name());
    MS_LOG(INFO) << name_ << ": Parallel optimizer, create allgather group " << allgather_group.name();
  }
  // save in tensor_layout for strategy ckpt
  auto integrated_save = ParallelContext::GetInstance()->optimizer_weight_shard_aggregated_save();
  if (!integrated_save) {
    tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size));
    if (optimizer_weight_shard_size > 0 && group_devices.size() < LongToSize(optimizer_weight_shard_size)) {
      tensor_layout->set_opt_weight_shard_size(SizeToInt(group_devices.size()));
    }
    // group_devices.size() >= 2 here (size 1 returned early above).
    MS_EXCEPTION_IF_ZERO("SizeToLong(group_devices.size()) - 1", SizeToLong(group_devices.size()) - 1);
    // Step between consecutive shard-holding ranks, assuming evenly spaced
    // ranks in group_devices.
    int64_t opt_weight_shard_step =
      (group_devices.back() - group_devices.front()) / (SizeToLong(group_devices.size()) - 1);
    tensor_layout->set_opt_weight_shard_step(LongToInt(opt_weight_shard_step));
    MS_LOG(INFO) << name_ << "Parallel optimizer, save opt_weight_shard_step " << opt_weight_shard_step
                 << " in strategy ckpt";
  }
  return SUCCESS;
}
1182 
InsertDivOpToNodeInput(const CNodePtr & node,int64_t div_num,size_t index,const string & instance_name)1183 static void InsertDivOpToNodeInput(const CNodePtr &node, int64_t div_num, size_t index, const string &instance_name) {
1184   MS_EXCEPTION_IF_NULL(node);
1185   FuncGraphPtr func_graph = node->func_graph();
1186   MS_EXCEPTION_IF_NULL(func_graph);
1187   // instance the div operator
1188   Operator div_op = CreateScalarFloorDivOp(div_num);
1189 
1190   // Insert it as the input of the node
1191   AnfNodePtr input = node->input(index);
1192   MS_EXCEPTION_IF_NULL(input);
1193   InsertNode(div_op, node, index, node->input(index), func_graph, instance_name);
1194 }
1195 
ChangeMakeTupleConstant(const CNodePtr & cnode,size_t make_tuple_index)1196 void OperatorInfo::ChangeMakeTupleConstant(const CNodePtr &cnode, size_t make_tuple_index) {
1197   if (!IsPrimitiveCNode(cnode->input(make_tuple_index), prim::kPrimMakeTuple)) {
1198     MS_LOG(EXCEPTION) << name_ << ": the dst shape is not make tuple";
1199   }
1200   size_t input_dim = inputs_shape_[0].size();
1201   auto shard_size = strategy_->GetInputDim()[0];
1202   if (input_dim != shard_size.size()) {
1203     MS_LOG(EXCEPTION) << name_ << ": the input dim is " << input_dim << ", but the size of strategy is "
1204                       << shard_size.size();
1205   }
1206 
1207   auto make_tuple = cnode->input(make_tuple_index);
1208   auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
1209   for (size_t i = 0; i < input_dim; ++i) {
1210     if (shard_size[i] <= 1) {
1211       continue;
1212     }
1213     auto value_node = GetValueNode(make_tuple_cnode->input(i + 1));
1214     if (value_node == nullptr) {
1215       std::string instance_name = name_ + "div";
1216       InsertDivOpToNodeInput(make_tuple_cnode, shard_size[i], i + 1, instance_name);
1217     } else if (value_node->isa<Int64Imm>()) {
1218       MS_EXCEPTION_IF_ZERO("shard_size", shard_size[i]);
1219       auto origin_value = GetValue<int64_t>(value_node);
1220       if (origin_value < 0 || origin_value % shard_size[i] != 0) {
1221         MS_LOG(EXCEPTION) << name_ << ": the origin value is " << origin_value << ", can not be div by shard size "
1222                           << shard_size[i] << ", the make tuple index of this op is " << make_tuple_index
1223                           << ", the input index of make_tuple is " << i + 1;
1224       }
1225       int64_t replace_value = GetValue<int64_t>(value_node) / shard_size[i];
1226       auto replace_value_ptr = MakeValue(replace_value);
1227       auto replace_value_node = std::make_shared<ValueNode>(replace_value_ptr);
1228       auto manager = make_tuple->func_graph()->manager();
1229       manager->SetEdge(make_tuple, i + 1, replace_value_node);
1230     } else {
1231       MS_LOG(EXCEPTION) << name_ << ": the input of make_tuple is value node but not int64, the index is " << (i + 1);
1232     }
1233   }
1234 }
1235 
CreateGroupByDim(size_t axis,std::vector<Group> * group)1236 Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector<Group> *group) {
1237   if (group == nullptr) {
1238     MS_LOG(ERROR) << name_ << ": The group is null.";
1239     return FAILED;
1240   }
1241   CheckGlobalDeviceManager();
1242   int64_t rank = g_device_manager->global_rank();
1243   DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
1244 
1245   return CreateGroupByDimWithDevMatrix(&dev_matrix, axis, group);
1246 }
1247 
Status OperatorInfo::CreateGroupByDimWithDevMatrix(DeviceMatrix *dev_matrix, size_t axis, std::vector<Group> *group) {
  // Create (at most one) communication group out of the devices lying along
  // `axis` of dev_matrix and append it to *group. Returns SUCCESS with *group
  // untouched when only one device lies along that dim.
  if (group == nullptr) {
    MS_LOG(ERROR) << name_ << ": The group is null.";
    return FAILED;
  }
  CheckGlobalDeviceManager();
  RankList group_devices;
  if (dev_matrix->GetDevicesAlongDim(SizeToUlong(axis), &group_devices) != SUCCESS) {
    return FAILED;
  }

  if (group_devices.size() == 1) {
    MS_LOG(INFO) << name_ << ": The dev size is 1, no need to create group.";
    return SUCCESS;
  }
  if (is_auto_parallel_) {
    // During strategy search, only validate the device list; the group itself
    // is not created (and *group is not modified).
    if (g_device_manager->CheckDeviceList(group_devices) != SUCCESS) {
      MS_LOG(INFO) << name_ << ": Try to create communication group : " << group_devices
                   << " failed in auto parallel mode, "
                   << "this error can be ignored in parallel strategies searching step";
      return FAILED;
    }
    return SUCCESS;
  }
  Group g;
  if (g_device_manager->CreateGroup(group_devices, &g) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create communication group by dim failed, the rank_list is: " << group_devices
                  << ", the input strategy is " << strategy_->GetInputDim()
                  << ", the full_name of node is: " << cnode_->fullname_with_scope();
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Create communication group by dim " << axis
               << " success, the rank_list is: " << group_devices;
  group->push_back(g);
  return SUCCESS;
}
1284 
// Compute the per-device slice shape by dividing each dimension of
// tensor_shape by the matching strategy element.
// Returns an empty shape (after logging an error) on invalid input: a
// non-positive strategy element, or a strategy longer than the tensor shape
// (the latter previously escaped as std::out_of_range via .at()).
Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) {
  Shape slice_shape;
  if (std::any_of(strategy.begin(), strategy.end(), [](int64_t value) { return value <= 0; })) {
    MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0";
    return slice_shape;
  }
  if (strategy.size() > tensor_shape.size()) {
    MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", its size " << strategy.size()
                  << " is larger than the tensor shape size " << tensor_shape.size();
    return slice_shape;
  }
  slice_shape.reserve(strategy.size());
  for (size_t i = 0; i < strategy.size(); ++i) {
    slice_shape.push_back(tensor_shape[i] / strategy[i]);
  }
  return slice_shape;
}
1296 
Init(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy,const std::vector<std::shared_ptr<TensorLayout>> & in_tensor_layouts,const std::vector<std::shared_ptr<TensorLayout>> & out_tensor_layouts)1297 Status OperatorInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
1298                           const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
1299                           const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
1300   if (!in_tensor_layouts.empty()) {
1301     return InitWithTensorLayout(in_tensor_layouts, out_tensor_layouts);
1302   }
1303 
1304   if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1305     MS_LOG(ERROR) << name_ << " : Init failed.";
1306     return FAILED;
1307   }
1308 
1309   MS_LOG(INFO) << name_ << " : Init success.";
1310   return SUCCESS;
1311 }
1312 
InitForCostModel(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy)1313 Status OperatorInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
1314   std::vector<std::shared_ptr<TensorLayout>> in_tensor_layouts;
1315   std::vector<std::shared_ptr<TensorLayout>> out_tensor_layouts;
1316   Status status =
1317     ExtractUserConfigLayout(attrs_, inputs_shape_, outputs_shape_, &in_tensor_layouts, &out_tensor_layouts);
1318   if (status != SUCCESS) {
1319     MS_LOG(EXCEPTION) << "Failure:operator " << name_ << " extract configured layout failed.";
1320   }
1321   if (!in_tensor_layouts.empty()) {
1322     out_tensor_layouts = {};
1323     return InitWithTensorLayout(in_tensor_layouts, out_tensor_layouts);
1324   }
1325   if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1326     ReportError(name_ + " : Init for cost model failed.");
1327     return FAILED;
1328   }
1329 
1330   MS_LOG(INFO) << name_ << " : Init for cost model success.";
1331   return SUCCESS;
1332 }
1333 
DivisorsReplaceShapes()1334 void OperatorInfo::DivisorsReplaceShapes() {
1335   if (!dynamic_shape_flag_) {
1336     return;
1337   }
1338 
1339   inputs_shape_ = inputs_divisor_;
1340   outputs_shape_ = outputs_divisor_;
1341 }
1342 
ResumeShapes()1343 void OperatorInfo::ResumeShapes() {
1344   if (!dynamic_shape_flag_) {
1345     return;
1346   }
1347 
1348   inputs_shape_ = inputs_shape_clone_;
1349   outputs_shape_ = outputs_shape_clone_;
1350 }
1351 
DynamicShapeCheckStrategyLog()1352 void OperatorInfo::DynamicShapeCheckStrategyLog() {
1353   if (!dynamic_shape_flag_) {
1354     return;
1355   }
1356   MS_LOG(ERROR) << name_ << ": the origin shape of inputs is " << ShapesToString(inputs_shape_clone_)
1357                 << ", but the divisor info of inputs is " << ShapesToString(inputs_divisor_);
1358 }
1359 
1360 // auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1
Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &in_strategy,
                                                        const StrategyPtr &out_strategy) {
  // Cost-model initialization pipeline. Phase order is significant:
  // attrs -> strategy check -> dev matrix -> repeated-calc info -> tensor
  // map -> tensor info, then (conditionally) communication inference.
  if (!is_layout_config_ && in_strategy == nullptr) {
    MS_LOG(ERROR) << name_ << ": The strategy is null, the inputs shape is " << inputs_shape_;
    return FAILED;
  }

  // need to clear queues before Init(),
  // because Init() may be called multiple times by cost model
  ResetQueueMember();

  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
    return FAILED;
  }

  // if layout is configured, no need to check strategy and infer dev matrix
  if (!is_layout_config_) {
    DivisorsReplaceShapes();  // in dynamic shape, using divisors replace to shapes before CheckStrategy
    // must be after InferAttrs()
    if (CheckStrategy(in_strategy) != SUCCESS) {
      DynamicShapeCheckStrategyLog();
      FILTER_LOG(is_auto_parallel_) << name_ << ": CheckStrategy failed.";
      return FAILED;
    }
    ResumeShapes();  // in dynamic shape, resume shapes after CheckStrategy

    if (is_dynamic_shape_ && CheckStrategyForDynamicShape(in_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Check strategy for dynamic shape failed";
      return FAILED;
    }
    strategy_ = in_strategy;

    if (out_strategy && CheckOutputStrategy(out_strategy) != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": The output strategy is invalid";
      return FAILED;
    }
    set_out_strategy(out_strategy);

    if (InferDevMatrixShape() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
      return FAILED;
    }

    // Total devices consumed by this strategy = product of dev-matrix dims.
    used_devices_ = std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());

    // must be after InferDevMatrixShape
    if (InferRepeatedCalcInfo() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
      return FAILED;
    }

    // if repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix for layout
    SetRepeatedCalcDevMatrix();

    if (InferTensorMap() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
      return FAILED;
    }

    ResetTensorMapIfRepeatedCalc();
  } else {
    if (InferOutputTensorMap() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": InferOutputTensorMap failed.";
      return FAILED;
    }
  }

  if (InferTensorInfo() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
    return FAILED;
  }
  // NOTE(review): when the stage device count is a power of two, the
  // communication inference below is skipped during cost modelling —
  // presumably an intentional fast path for the common case, but confirm;
  // also note stage_dev_num == 0 would satisfy this test.
  auto stage_dev_num = LongToSize(g_device_manager->stage_device_num());
  if ((stage_dev_num & (stage_dev_num - 1)) == 0) {
    return SUCCESS;
  }
  if (InferForwardCommunication() != SUCCESS) {
    MS_LOG(WARNING) << name_ << ": InferForwardCommunication failed in auto parallel searching strategies step.";
    return FAILED;
  }

  if (InferMirrorOps() != SUCCESS) {
    MS_LOG(WARNING) << name_ << ": InferMirrorOps failed in auto parallel searching strategies step.";
    return FAILED;
  }
  return SUCCESS;
}
1448 
InitWithAutoRepeatCalc(const StrategyPtr & in_strategy,const StrategyPtr & out_strategy)1449 Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
1450   if (in_strategy == nullptr) {
1451     MS_LOG(ERROR) << name_ << ": The input strategy is null.";
1452     return FAILED;
1453   }
1454 
1455   if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
1456     return FAILED;
1457   }
1458 
1459   if (InferForwardCommunication() != SUCCESS) {
1460     MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
1461     return FAILED;
1462   }
1463 
1464   if (InferMirrorOps() != SUCCESS) {
1465     MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
1466     return FAILED;
1467   }
1468 
1469   if (InferVirtualDivOps() != SUCCESS) {
1470     MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
1471     return FAILED;
1472   }
1473 
1474   InferReplaceOps();
1475   return SUCCESS;
1476 }
1477 
CheckInputLayout()1478 Status OperatorInfo::CheckInputLayout() {
1479   MS_LOG(ERROR) << "Current op " << name_
1480                 << " does not support config layout. Please check "
1481                    "https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/operator_list_parallel.html to get limitation "
1482                    "and more details";
1483   // Check self_define_shard attribute
1484   if (!self_define_shard_) {
1485     MS_LOG(ERROR) << "Please set add_prim_attr('self_define_shard', True) to " << name_
1486                   << " to config layout for this ops";
1487     return FAILED;
1488   }
1489   return FAILED;
1490 }
1491 
InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> & in_tensor_layouts,const std::vector<std::shared_ptr<TensorLayout>> & out_tensor_layouts)1492 Status OperatorInfo::InitWithTensorLayout(const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
1493                                           const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
1494   ResetQueueMember();
1495 
1496   if (InferAttrs() != SUCCESS) {
1497     MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
1498     return FAILED;
1499   }
1500 
1501   size_t real_input_index = 0;
1502   for (const auto &input_layout : in_tensor_layouts) {
1503     // Insert placeholder TensorInfo for optional input
1504     while (real_input_index < input_value_.size() && input_value_[real_input_index] != nullptr &&
1505            input_value_[real_input_index]->isa<None>()) {
1506       (void)inputs_tensor_info_.emplace_back(TensorInfo());
1507       ++real_input_index;
1508     }
1509     TensorInfo input_tensor_info(*input_layout);
1510     inputs_tensor_info_.push_back(input_tensor_info);
1511     ++real_input_index;
1512   }
1513   if (CheckInputLayout() != SUCCESS) {
1514     MS_LOG(ERROR) << name_ << ": CheckInputLayout failed.";
1515     return FAILED;
1516   }
1517   for (const auto &output_layout : out_tensor_layouts) {
1518     TensorInfo output_tensor_info(*output_layout);
1519     outputs_tensor_info_.push_back(output_tensor_info);
1520   }
1521 
1522   if (outputs_tensor_info_.size() != outputs_shape_.size()) {
1523     outputs_tensor_info_.clear();
1524     // Need be override
1525     if (InferOutputTensorInfo() != SUCCESS) {
1526       MS_LOG(ERROR) << name_ << ": InferOutputTensorLayout failed.";
1527       return FAILED;
1528     }
1529   }
1530 
1531   if (outputs_tensor_info_.size() != outputs_shape_.size()) {
1532     MS_LOG(ERROR) << name_ << ": the output tensor layout num " << outputs_tensor_info_.size()
1533                   << " dose not match the output num " << outputs_shape_.size();
1534     return FAILED;
1535   }
1536 
1537   if (CheckOutputLayout() != SUCCESS) {
1538     MS_LOG(ERROR) << name_ << ": CheckLayout failed.";
1539     return FAILED;
1540   }
1541 
1542   // Need be override
1543   if (InferForwardCommunicationByLayout() != SUCCESS) {
1544     MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
1545     return FAILED;
1546   }
1547 
1548   if (InferMirrorOpsByLayout() != SUCCESS) {
1549     MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
1550     return FAILED;
1551   }
1552   if (InferVirtualDivOpsByLayout() != SUCCESS) {
1553     MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
1554     return FAILED;
1555   }
1556   return SUCCESS;
1557 }
1558 
GetAliveSuccEdges()1559 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAliveSuccEdges() {
1560   std::vector<std::shared_ptr<Edge>> ret;
1561   for (auto &edge : succ_edges_) {
1562     if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) {
1563       ret.push_back(edge);
1564     } else if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(CAST) != std::string::npos)) {
1565       // CAST is ordered in front of L2NORMALIZE
1566       ret.push_back(edge);
1567     }
1568   }
1569   for (auto &edge : succ_edges_) {
1570     if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos) &&
1571         (edge->next_operator()->name().find(CAST) == std::string::npos)) {
1572       ret.push_back(edge);
1573     }
1574   }
1575   return ret;
1576 }
1577 
GetAlivePrevEdges()1578 std::vector<std::shared_ptr<Edge>> OperatorInfo::GetAlivePrevEdges() {
1579   std::vector<std::shared_ptr<Edge>> ret;
1580   for (auto &edge : prev_edges_) {
1581     if (edge->prev_operator()->is_alive()) {
1582       ret.push_back(edge);
1583     }
1584   }
1585   return ret;
1586 }
1587 
ReplacePreEdge(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & new_edge)1588 void OperatorInfo::ReplacePreEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1589   if (op == nullptr) {
1590     MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null.";
1591     return;
1592   }
1593   for (auto &edge : prev_edges_) {
1594     if (edge->prev_operator() == op) {
1595       edge = new_edge;
1596       return;
1597     }
1598   }
1599   MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
1600 }
1601 
ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & new_edge)1602 void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1603   if (op == nullptr) {
1604     MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null.";
1605     return;
1606   }
1607   for (auto &edge : succ_edges_) {
1608     if (edge->next_operator() == op) {
1609       edge = new_edge;
1610       return;
1611     }
1612   }
1613   MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced";
1614 }
1615 
ReplacePreEdges(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & new_edge)1616 void OperatorInfo::ReplacePreEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1617   if (op == nullptr) {
1618     MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null.";
1619     return;
1620   }
1621   std::vector<std::shared_ptr<Edge>> update_pre_edges;
1622   for (auto &edge : prev_edges_) {
1623     if (edge->prev_operator() != op) {
1624       update_pre_edges.push_back(edge);
1625     }
1626   }
1627   update_pre_edges.push_back(new_edge);
1628   prev_edges_ = update_pre_edges;
1629 }
1630 
ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> & op,const std::shared_ptr<Edge> & new_edge)1631 void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr<OperatorInfo> &op, const std::shared_ptr<Edge> &new_edge) {
1632   if (op == nullptr) {
1633     MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null";
1634     return;
1635   }
1636   std::vector<std::shared_ptr<Edge>> update_pre_edges;
1637   for (auto &edge : succ_edges_) {
1638     if (edge->next_operator() != op) {
1639       update_pre_edges.push_back(edge);
1640     }
1641   }
1642   update_pre_edges.push_back(new_edge);
1643   succ_edges_ = update_pre_edges;
1644 }
1645 
GenerateBatchStrategiesWithCheck()1646 std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategiesWithCheck() {
1647   if (InferAttrs() != SUCCESS) {
1648     MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
1649   }
1650   DivisorsReplaceShapes();  // in dynamic shape, using divisors replace to shapes before GenerateBatchStrategies
1651 
1652   std::shared_ptr<Strategies> batch_strategy = GenerateBatchStrategies();
1653   if (batch_strategy->size() != inputs_shape_.size()) {
1654     MS_LOG(WARNING) << "The inputs size:" << inputs_shape_.size()
1655                     << " is not equal to the generated batch parallel strategies size:" << batch_strategy->size();
1656     return batch_strategy;
1657   }
1658   int64_t shard_size = g_device_manager->stage_device_num();
1659   std::vector<std::pair<size_t, size_t>> changed_pos;
1660   for (size_t i = 0; i < inputs_shape_.size(); ++i) {
1661     auto stra = batch_strategy->at(i);
1662     auto input_shape = inputs_shape_.at(i);
1663     if (stra.size() != input_shape.size()) {
1664       MS_LOG(WARNING) << "The " << i << " input size:" << input_shape.size() << " is not equal to the " << i
1665                       << " generated batch parallel strategy size:" << stra.size();
1666       return batch_strategy;
1667     }
1668     for (size_t j = 0; j < input_shape.size(); ++j) {
1669       if (stra[j] == 1) {
1670         continue;
1671       }
1672       if (stra[j] != g_device_manager->stage_device_num()) {
1673         MS_LOG(WARNING) << "The batch parallel value is not equal to device num, skip adjust it.";
1674         return batch_strategy;
1675       }
1676       shard_size = std::gcd(input_shape[j], shard_size);
1677       changed_pos.push_back({i, j});
1678     }
1679   }
1680   for (auto &pair : changed_pos) {
1681     batch_strategy->at(pair.first).at(pair.second) = shard_size;
1682   }
1683 
1684   ResumeShapes();
1685   return batch_strategy;
1686 }
1687 
GenerateBatchStrategiesBySplitFlag(const Shapes & shapes,const std::vector<bool> & split_flag_list)1688 std::shared_ptr<Strategies> GenerateBatchStrategiesBySplitFlag(const Shapes &shapes,
1689                                                                const std::vector<bool> &split_flag_list) {
1690   if (shapes.size() != split_flag_list.size()) {
1691     MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : "
1692                   << shapes.size();
1693     return nullptr;
1694   }
1695   CheckGlobalDeviceManager();
1696   int64_t dev_num = g_device_manager->stage_device_num();
1697   Strategies strategy_v;
1698   for (size_t i = 0; i != shapes.size(); i++) {
1699     if (shapes[i].empty()) {
1700       MS_LOG(INFO) << "Elements of shapes is empty.";
1701       Dimensions empty_element;
1702       strategy_v.push_back(empty_element);
1703     } else {
1704       Dimensions element(shapes[i].size(), 1);
1705       if (split_flag_list[i]) {
1706         element[0] = dev_num;
1707       }
1708       strategy_v.push_back(element);
1709     }
1710   }
1711   return std::make_shared<Strategies>(strategy_v);
1712 }
1713 
ReComputeBatchSplitFlagList()1714 void OperatorInfo::ReComputeBatchSplitFlagList() {
1715   if (!inputs_shape_.empty()) {
1716     split_flag_list_[0] = true;
1717   }
1718 }
1719 
ComputeBatchSplitFlagList()1720 void OperatorInfo::ComputeBatchSplitFlagList() {
1721   split_flag_list_.clear();
1722   for (auto iter = inputs_shape_.begin(); iter != inputs_shape_.end(); ++iter) {
1723     split_flag_list_.push_back(false);
1724   }
1725   ReComputeBatchSplitFlagList();
1726 }
1727 
// This is a common method for checking whether the generated strategy has the correct number of devices.
PrepareStrategyBase(int64_t stage_id,size_t dev_num,const Shapes & inputs_partitions,StrategyPtr * const sp)1729 Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) {
1730   if (sp == nullptr) {
1731     MS_LOG(ERROR) << "The strategy is null.";
1732     return FAILED;
1733   }
1734   int64_t product = 1;
1735 
1736   for (auto &input_partition : inputs_partitions) {
1737     product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies<int64_t>());
1738   }
1739   const auto fully_use_device = CostModelContext::GetInstance()->fully_use_device();
1740   if (!fully_use_device) {
1741     if (LongToSize(product) > dev_num) {
1742       return FAILED;
1743     }
1744   } else {
1745     if ((product != 1) && (LongToSize(product) != dev_num)) {
1746       return FAILED;
1747     }
1748   }
1749   Strategies stras(inputs_partitions);
1750   (*sp) = std::make_shared<Strategy>(stage_id, stras);
1751   return SUCCESS;
1752 }
1753 
// Generate the default batch-parallel strategies for this operator: compute the
// per-input batch split flags, then split the flagged batch dimensions across
// all devices of the stage. Throws if attribute inference fails.
std::shared_ptr<Strategies> OperatorInfo::GenerateBatchStrategies() {
  // Attributes must be inferred first; split-flag computation may depend on them.
  if (InferAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
  }
  ComputeBatchSplitFlagList();
  return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
}
1761 
1762 // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d])
GenerateStrategiesForTwoEqualInputs(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1763 Status GenerateStrategiesForTwoEqualInputs(int64_t stage_id, const Shapes &inputs_shape,
1764                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1765   if (sp_vector == nullptr) {
1766     MS_LOG(ERROR) << "The sp_vector is null.";
1767     return FAILED;
1768   }
1769 
1770   if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
1771     MS_LOG(ERROR) << "The inputs size is wrong.";
1772     return FAILED;
1773   }
1774 
1775   if ((inputs_shape[0].size() != inputs_shape[1].size()) ||
1776       (splittable_inputs[0].size() != splittable_inputs[1].size())) {
1777     MS_LOG(ERROR) << "The size of two inputs are not equal.";
1778     return FAILED;
1779   }
1780 
1781   Shapes input0_shape = {inputs_shape[0]};
1782   Shapes input0_splittable = {splittable_inputs[0]};
1783   if (GenerateStrategiesForIndependentInputs(stage_id, input0_shape, input0_splittable, sp_vector) != SUCCESS) {
1784     return FAILED;
1785   }
1786 
1787   for (auto &sp : *sp_vector) {
1788     sp->ExpandInputDimFromOneToTwo();
1789   }
1790 
1791   return SUCCESS;
1792 }
1793 
1794 // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast
1795 // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
GenerateStrategiesForBroadcastLeft(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1796 Status GenerateStrategiesForBroadcastLeft(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1797                                           std::vector<StrategyPtr> *const sp_vector) {
1798   if (sp_vector == nullptr) {
1799     MS_LOG(ERROR) << "The sp_vector is null.";
1800     return FAILED;
1801   }
1802 
1803   if (inputs_shape[0].size() >= inputs_shape[1].size()) {
1804     MS_LOG(ERROR) << "Invalid inputs shape.";
1805     return FAILED;
1806   }
1807 
1808   // first, generate strategy for input0 the same as input1
1809   Shapes tmp_inputs_shape = {inputs_shape[1], inputs_shape[1]};
1810   Shapes tmp_splittable_inputs = {splittable_inputs[1], splittable_inputs[1]};
1811   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1812     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1813     return FAILED;
1814   }
1815 
1816   // second, get the correct strategy for input0
1817   for (auto &sp : *sp_vector) {
1818     Strategies tmp_strategy;
1819     Dimensions input0_strategy = sp->GetInputDim()[0];
1820     size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size();
1821 
1822     // erase the unnecessary part
1823     (void)input0_strategy.erase(input0_strategy.cbegin(),
1824                                 input0_strategy.cbegin() + static_cast<different_type>(size_diff));
1825 
1826     // handle the case likes ([1, c, d], [a, b, c, d])
1827     for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1828       if (inputs_shape[0][i] == 1) {
1829         input0_strategy[i] = 1;
1830       } else {
1831         break;
1832       }
1833     }
1834 
1835     // reset the strategy
1836     tmp_strategy.push_back(input0_strategy);       // input0
1837     tmp_strategy.push_back(sp->GetInputDim()[1]);  // input1
1838     sp->ResetInputs(tmp_strategy);
1839   }
1840   return SUCCESS;
1841 }
1842 
1843 // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast
1844 // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
GenerateStrategiesForBroadcastRight(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1845 Status GenerateStrategiesForBroadcastRight(int64_t stage_id, const Shapes &inputs_shape,
1846                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *const sp_vector) {
1847   if (sp_vector == nullptr) {
1848     MS_LOG(ERROR) << "The sp_vector is null.";
1849     return FAILED;
1850   }
1851 
1852   if (inputs_shape[0].size() <= inputs_shape[1].size()) {
1853     MS_LOG(ERROR) << "Invalid inputs shape.";
1854     return FAILED;
1855   }
1856 
1857   // first, generate strategy for input1 the same as input0
1858   Shapes tmp_inputs_shape = {inputs_shape[0], inputs_shape[0]};
1859   Shapes tmp_splittable_inputs = {splittable_inputs[0], splittable_inputs[0]};
1860   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1861     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1862     return FAILED;
1863   }
1864 
1865   // second, get the correct strategy for input1
1866   for (auto &sp : *sp_vector) {
1867     Strategies tmp_strategy;
1868     tmp_strategy.push_back(sp->GetInputDim()[0]);  // input0
1869 
1870     Dimensions input1_strategy = sp->GetInputDim()[1];
1871     size_t size_diff = inputs_shape[0].size() - inputs_shape[1].size();
1872 
1873     // erase the unnecessary part
1874     (void)input1_strategy.erase(input1_strategy.cbegin(),
1875                                 input1_strategy.cbegin() + static_cast<different_type>(size_diff));
1876 
1877     // handle the case likes ([a, b, c, d], [1, c, d])
1878     for (size_t i = 0; i < inputs_shape[1].size(); ++i) {
1879       if (inputs_shape[1][i] == 1) {
1880         input1_strategy[i] = 1;
1881       } else {
1882         break;
1883       }
1884     }
1885 
1886     // reset the strategy
1887     tmp_strategy.push_back(input1_strategy);  // input1
1888     sp->ResetInputs(tmp_strategy);
1889   }
1890   return SUCCESS;
1891 }
1892 
1893 // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast
1894 // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
GenerateStrategiesForBroadcastBoth(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * const sp_vector)1895 Status GenerateStrategiesForBroadcastBoth(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
1896                                           std::vector<StrategyPtr> *const sp_vector) {
1897   if (sp_vector == nullptr) {
1898     MS_LOG(ERROR) << "The sp_vector is null.";
1899     return FAILED;
1900   }
1901 
1902   if (inputs_shape[0].size() != inputs_shape[1].size()) {
1903     MS_LOG(ERROR) << "Invalid inputs shape.";
1904     return FAILED;
1905   }
1906 
1907   // step1: ([a, 1], [1, b]) -> [a, b]
1908   Shape max_shape, splittable_vector;
1909   for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1910     if (inputs_shape[0][i] >= inputs_shape[1][i]) {
1911       max_shape.push_back(inputs_shape[0][i]);
1912       splittable_vector.push_back(splittable_inputs[0][i]);
1913     } else {
1914       max_shape.push_back(inputs_shape[1][i]);
1915       splittable_vector.push_back(splittable_inputs[1][i]);
1916     }
1917   }
1918 
1919   // step2: ([a, 1], [1, b]) -> generate strategy for ([a, b], [a, b])
1920   Shapes tmp_inputs_shape = {max_shape, max_shape};
1921   Shapes tmp_splittable_inputs = {splittable_vector, splittable_vector};
1922   if (GenerateStrategiesForTwoEqualInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, sp_vector) != SUCCESS) {
1923     MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
1924     return FAILED;
1925   }
1926 
1927   // step3: reset the strategy if the dimension is 1
1928   for (auto &sp : *sp_vector) {
1929     Dimensions input0_strategy = sp->GetInputDim()[0];
1930     Dimensions input1_strategy = sp->GetInputDim()[1];
1931     for (size_t i = 0; i < inputs_shape[0].size(); ++i) {
1932       if (inputs_shape[0][i] == 1) {
1933         input0_strategy[i] = 1;
1934       }
1935 
1936       if (inputs_shape[1][i] == 1) {
1937         input1_strategy[i] = 1;
1938       }
1939     }
1940     sp->ResetInputs({input0_strategy, input1_strategy});
1941   }
1942 
1943   return SUCCESS;
1944 }
1945 
// Enumerate all power-of-two partitionings of 'dev_num' across the splittable
// dimensions of all inputs. The input shapes are flattened into one combined
// dimension list; a recursive lambda assigns a split factor to each flat
// dimension in turn, and each complete assignment is regrouped per input and
// kept if PrepareStrategyBase accepts it. Throws when no strategy survives.
Status GenerateStrategiesForIndependentInputsBase(int64_t stage_id, size_t dev_num, const Shapes &inputs_shape,
                                                  const Shapes &splittable_inputs,
                                                  std::vector<StrategyPtr> *sp_vector) {
  // Flatten all inputs so the recursion walks a single dimension list.
  Shape combined_inputs_shape, combined_splittable_inputs, combined_partitions;
  for (size_t j = 0; j < inputs_shape.size(); ++j) {
    (void)combined_inputs_shape.insert(combined_inputs_shape.cend(), inputs_shape[j].cbegin(), inputs_shape[j].cend());
    (void)combined_splittable_inputs.insert(combined_splittable_inputs.cend(), splittable_inputs[j].cbegin(),
                                            splittable_inputs[j].cend());
  }
  // recursive(current_index, n): choose a split factor for flat dimension
  // 'current_index', where 'n' is the device count still available for the
  // remaining dimensions.
  std::function<void(uint64_t, size_t)> recursive = [&stage_id, &dev_num, &sp_vector, &combined_inputs_shape,
                                                     &combined_splittable_inputs, &combined_partitions, &recursive,
                                                     &inputs_shape](uint64_t current_index, size_t n) {
    if (current_index == combined_inputs_shape.size()) {
      // Base case: every flat dim has a factor. Regroup the flat partition list
      // back into one partition vector per input.
      MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size();
      Shapes inputs_partitions;
      size_t global_index = 0;
      for (auto &shape : inputs_shape) {
        Shape tmp_partition;
        for (size_t j = 0; j < shape.size(); ++j) {
          tmp_partition.push_back(combined_partitions[global_index]);
          global_index++;
        }
        inputs_partitions.push_back(tmp_partition);
      }
      StrategyPtr sp;
      // Only keep assignments that pass the device-count validity check.
      if (PrepareStrategyBase(stage_id, dev_num, inputs_partitions, &sp) == SUCCESS) {
        sp_vector->push_back(sp);
      }
      return;
    } else {
      MS_LOG(DEBUG) << "The value of sp_vector size is " << sp_vector->size();
      if (combined_splittable_inputs[current_index] == 0) {
        // Unsplittable dimension: its factor is forced to MIN_SLICE_NUM.
        combined_partitions.push_back(MIN_SLICE_NUM);
        recursive(current_index + 1, n / MIN_SLICE_NUM);
        combined_partitions.pop_back();
      } else if (combined_splittable_inputs[current_index] == 1) {
        // Splittable dimension: try every power-of-two factor i that divides
        // both the remaining device count n and the dimension size.
        for (uint64_t i = 1; i <= n; i *= 2) {
          if (n % i == 0 && LongToSize(combined_inputs_shape[current_index]) % i == 0) {
            combined_partitions.push_back(i);
            recursive(current_index + 1, n / i);
            combined_partitions.pop_back();
          }
        }
      }
    }
  };
  recursive(0, dev_num);
  if (sp_vector->empty()) {
    MS_LOG(EXCEPTION) << "No available strategy for current OperatorInfo.";
  }
  return SUCCESS;
}
1998 
1999 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
2000 // the corresponding dimension is unsplittable, '1' in 'splittable_inputs' means that the corresponding
2001 // dimension is splittable. 'inputs_partitions' is the result of partitions.
2002 // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring
2003 // specific dimensions in inputs have the identical partition should have individual implementation.
Status GenerateStrategiesForIndependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                              const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp_vector) {
  if (sp_vector == nullptr) {
    MS_LOG(ERROR) << "The sp_vector is null.";
    return FAILED;
  }
  if (splittable_inputs.size() != inputs_shape.size()) {
    MS_LOG(ERROR) << "Splittable_inputs do not have the same input number of inputs shape, " << splittable_inputs.size()
                  << " : " << inputs_shape.size();
    return FAILED;
  }
  CheckGlobalDeviceManager();
  size_t dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
  // dev_num & (dev_num - 1) is zero iff dev_num is a power of two; in that case
  // the base enumeration (power-of-two splits only) covers everything.
  auto dev_num_2_power = (dev_num & (dev_num - 1));
  if (dev_num_2_power == 0) {
    return GenerateStrategiesForIndependentInputsBase(stage_id, dev_num, inputs_shape, splittable_inputs, sp_vector);
  }
  // Otherwise decompose dev_num: dev_num - dev_num_2_power is its lowest set bit,
  // i.e. the largest power-of-two divisor of dev_num, and dev_num_not_2_power is
  // the remaining (odd) factor.
  MS_EXCEPTION_IF_ZERO("dev_num - dev_num_2_power", dev_num - dev_num_2_power);
  auto dev_num_not_2_power = dev_num / (dev_num - dev_num_2_power);
  // Enumerate strategies over the power-of-two part of the device count.
  std::vector<StrategyPtr> sp_vector_2_power_part;
  if (GenerateStrategiesForIndependentInputsBase(stage_id, dev_num - dev_num_2_power, inputs_shape, splittable_inputs,
                                                 &sp_vector_2_power_part) != SUCCESS) {
    MS_LOG(ERROR) << "Generate strategy in the power of 2 devices part failed.";
    return FAILED;
  }
  // Handle the not power of 2 part: for each base strategy, produce variants that
  // additionally multiply exactly one splittable dimension by the odd factor.
  for (auto &stra : sp_vector_2_power_part) {
    auto stra_arrays = stra->GetInputDim();
    size_t stras_size = stra_arrays.size();
    for (size_t i = 0; i < stras_size; ++i) {
      auto split_input = splittable_inputs[i];
      size_t stra_size = stra_arrays[i].size();
      for (size_t j = 0; j < stra_size; ++j) {
        if (split_input[j] == 0) {
          continue;  // unsplittable dim cannot absorb the odd factor
        }
        auto new_stra_arrays{stra_arrays};
        new_stra_arrays[i][j] = new_stra_arrays[i][j] * UlongToLong(dev_num_not_2_power);
        // discard invalid strategy: the enlarged factor must divide the dim size
        MS_EXCEPTION_IF_ZERO("new_stra_arrays[i][j]", new_stra_arrays[i][j]);
        if (inputs_shape[i][j] % new_stra_arrays[i][j] != 0) {
          continue;
        }
        StrategyPtr new_stra = std::make_shared<Strategy>(stage_id, new_stra_arrays);
        sp_vector->push_back(new_stra);
      }
    }
  }
  // add the repeated strategy: all-ones (no splitting at all) is always valid.
  auto repeated_stra_arrays{splittable_inputs};
  for (auto &stra_array : repeated_stra_arrays) {
    std::fill(stra_array.begin(), stra_array.end(), 1);
  }
  StrategyPtr repeated_stra = std::make_shared<Strategy>(stage_id, repeated_stra_arrays);
  sp_vector->push_back(repeated_stra);
  return SUCCESS;
}
2061 
2062 // 'splittable_inputs' has the same dimensions as 'inputs_shape_'. '0' in 'splittable_inputs' means that
2063 // the corresponding dimension is unsplittable, otherwise means that the corresponding dimension is splittable.
2064 // In particular, if the same dimensions exist in 'splittable_inputs',
2065 // the corresponding dimensions in the strategy are the same.
2066 // 'sp' is the result of partitions.
Status GenerateStrategiesForDependentInputs(int64_t stage_id, const Shapes &inputs_shape,
                                            const Shapes &splittable_inputs, std::vector<StrategyPtr> *sp) {
  if (inputs_shape.size() != splittable_inputs.size()) {
    MS_LOG(EXCEPTION) << "Size of inputs_shape and splittable_inputs are not equal.";
  }

  // Group dimensions by their splittable index: dimensions sharing an index must
  // receive the same split factor, so the representative size for each group is
  // the gcd of all dimension sizes carrying that index. Index 0 = unsplittable.
  std::unordered_map<int64_t, int64_t> mp;
  for (size_t i = 0; i < inputs_shape.size(); ++i) {
    auto input_shape = inputs_shape[i];
    auto splittable_input = splittable_inputs[i];
    for (size_t j = 0; j < input_shape.size(); ++j) {
      int64_t indice = splittable_input[j];
      int64_t shape = input_shape[j];
      if (splittable_input[j] == 0) {
        continue;
      }
      if (mp.find(indice) == mp.end()) {
        mp[indice] = shape;
      } else {
        // gcd guarantees any factor of the group size divides every member dim.
        mp[indice] = std::gcd(mp[indice], shape);
      }
    }
  }

  // Build a synthetic single-input problem with one dimension per distinct index;
  // indices_mp remembers which synthetic position each index was assigned.
  std::unordered_map<int64_t, size_t> indices_mp;
  Shape tmp_input_shape;
  Shapes tmp_splittable_inputs = {Shape(mp.size(), 1)};

  for (const auto &item : mp) {
    indices_mp[item.first] = tmp_input_shape.size();
    tmp_input_shape.push_back(item.second);
  }
  Shapes tmp_inputs_shape = {tmp_input_shape};
  std::vector<StrategyPtr> tmp_sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, tmp_splittable_inputs, &tmp_sp_vector) !=
      SUCCESS) {
    return FAILED;
  }

  // Expand each synthetic strategy back to the real input shapes: every real
  // dimension takes the factor chosen for its splittable index (1 if index 0).
  (void)std::transform(tmp_sp_vector.begin(), tmp_sp_vector.end(), std::back_inserter(*sp),
                       [stage_id, &indices_mp, &splittable_inputs](const StrategyPtr &sp) {
                         auto sp_strategies = sp->GetInputDim();
                         auto sp_sub_strategy = sp_strategies.at(0);
                         // splittable_inputs is used only as a template for sizes here.
                         Strategies strategies(splittable_inputs);
                         for (size_t i = 0; i < strategies.size(); ++i) {
                           for (size_t j = 0; j < strategies[i].size(); ++j) {
                             if (splittable_inputs[i][j] == 0) {
                               strategies[i][j] = 1;
                             } else {
                               strategies[i][j] = sp_sub_strategy[indices_mp[splittable_inputs[i][j]]];
                             }
                           }
                         }
                         return std::make_shared<Strategy>(stage_id, strategies);
                       });
  return SUCCESS;
}
2124 
2125 // generate strategies for that have two inputs, and input0 or input1 maybe broadcast,
2126 // and the corresponding dimensions that are not broadcast are all relevant dimensions
2127 // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
2128 // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
2129 // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
GenerateStrategiesWithBroadcast(int64_t stage_id,const Shapes & inputs_shape,const Shapes & splittable_inputs,std::vector<StrategyPtr> * sp_vector)2130 Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs,
2131                                        std::vector<StrategyPtr> *sp_vector) {
2132   if (sp_vector == nullptr) {
2133     MS_LOG(ERROR) << "The sp_vector is null.";
2134     return FAILED;
2135   }
2136 
2137   if ((inputs_shape.size() != 2) || (splittable_inputs.size() != 2)) {
2138     MS_LOG(ERROR) << "The inputs' size is wrong.";
2139     return FAILED;
2140   }
2141 
2142   if (inputs_shape[0] == inputs_shape[1]) {
2143     // element wise operation([a, b, c, d], [a, b, c, d]), so input0's strategy is equal to input1's strategy
2144     if (GenerateStrategiesForTwoEqualInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2145       MS_LOG(ERROR) << "GenerateStrategiesForTwoEqualInputs failed.";
2146       return FAILED;
2147     }
2148     MS_LOG(INFO) << "GenerateStrategiesForTwoEqualInputs success.";
2149   } else if (inputs_shape[0].empty() || inputs_shape[1].empty()) {
2150     // ([a, b, c, d], []) or ([], [a, b, c, d])
2151     if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2152       MS_LOG(ERROR) << "Generate strategies for scalar case failed.";
2153       return FAILED;
2154     }
2155     MS_LOG(INFO) << "Generate strategies for scalar case success.";
2156   } else if (inputs_shape[0].size() > inputs_shape[1].size()) {
2157     // ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d])
2158     if (GenerateStrategiesForBroadcastRight(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2159       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastRight failed.";
2160       return FAILED;
2161     }
2162     MS_LOG(INFO) << "GenerateStrategiesForBroadcastRight success.";
2163   } else if (inputs_shape[0].size() < inputs_shape[1].size()) {
2164     // ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d])
2165     if (GenerateStrategiesForBroadcastLeft(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2166       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastLeft failed.";
2167       return FAILED;
2168     }
2169     MS_LOG(INFO) << "GenerateStrategiesForBroadcastLeft success.";
2170   } else {  // same size, but different value
2171     // ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d])
2172     if (GenerateStrategiesForBroadcastBoth(stage_id, inputs_shape, splittable_inputs, sp_vector) != SUCCESS) {
2173       MS_LOG(ERROR) << "GenerateStrategiesForBroadcastBoth failed.";
2174       return FAILED;
2175     }
2176     MS_LOG(INFO) << "GenerateStrategiesForBroadcastBoth success.";
2177   }
2178   return SUCCESS;
2179 }
2180 
Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) {
  // Initialize this operator's tensor infos under `strategy`, then compute the
  // corresponding Cost and append it (wrapped in a StrategyWithCost) to strategy_cost_.
  // Returns FAILED when the strategy is infeasible for this operator.
  if (InitForCostModel(strategy, nullptr) == FAILED) {
    MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed.";
    return FAILED;
  }
  int64_t stage_id = strategy->GetInputStage();
  double computation_cost =
    operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  double communication_cost = operator_cost()->GetCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
  std::shared_ptr<Cost> result = std::make_shared<Cost>(computation_cost, communication_cost);
  result->communication_without_parameter_ =
    operator_cost()->GetForwardCommCost(inputs_tensor_info_, outputs_tensor_info_, stage_id);
  // Partial-parameter communication = forward communication plus a gamma-weighted
  // share of the remaining (parameter-related) communication.
  result->communication_with_partial_para_ =
    result->communication_without_parameter_ + gamma * (communication_cost - result->communication_without_parameter_);

  // Breaking ties for preferring data parallelization
  BreakingTiesForPreferringDataParallel(strategy, result);
  // refine communication cost calculation for practice
  RefineForPracticalCost(result, false);
  result->communication_forward_ = result->communication_without_parameter_;

  std::shared_ptr<StrategyWithCost> swc =
    std::make_shared<StrategyWithCost>(strategy, inputs_tensor_info_, outputs_tensor_info_);
  swc->cost_list.push_back(result);
  (void)strategy_cost_.emplace_back(swc);

  return SUCCESS;
}
2210 
GetCostByStrategyPtr(const StrategyPtr & strategy)2211 CostPtrList OperatorInfo::GetCostByStrategyPtr(const StrategyPtr &strategy) {
2212   auto target = std::find_if(
2213     strategy_cost_.begin(), strategy_cost_.end(),
2214     [&](const std::shared_ptr<StrategyWithCost> &stra_cost) { return stra_cost->strategy_ptr == strategy; });
2215   if (target == strategy_cost_.end()) {
2216     MS_LOG(EXCEPTION) << "There is no StrategyWithCost with a strategy";
2217   }
2218   return (*target)->cost_list;
2219 }
2220 
GetInputLayoutFromSWCByStrategy(const StrategyPtr & stra,size_t input_index)2221 TensorLayout OperatorInfo::GetInputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t input_index) {
2222   auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
2223   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2224   if (it != strategy_cost_.end()) {
2225     const auto &input_info = (*it)->inputs_ptr[input_index];
2226     return std::move(input_info.tensor_layout());
2227   }
2228   TensorLayout empty;
2229   return empty;
2230 }
2231 
GetOutputLayoutFromSWCByStrategy(const StrategyPtr & stra,size_t output_index)2232 TensorLayout OperatorInfo::GetOutputLayoutFromSWCByStrategy(const StrategyPtr &stra, size_t output_index) {
2233   auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) { return swc->strategy_ptr->IsEqual(stra); };
2234   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2235   if (it != strategy_cost_.end()) {
2236     const auto &output_info = (*it)->outputs_ptr[output_index];
2237     return std::move(output_info.tensor_layout());
2238   }
2239   TensorLayout empty;
2240   return empty;
2241 }
2242 
GetStrategyFromSWCByInputLayout(const TensorLayout & input_layout,size_t input_index)2243 StrategyPtr OperatorInfo::GetStrategyFromSWCByInputLayout(const TensorLayout &input_layout, size_t input_index) {
2244   auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
2245     return swc->inputs_ptr[input_index].tensor_layout() == input_layout;
2246   };
2247   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2248   if (it != strategy_cost_.end()) {
2249     return (*it)->strategy_ptr;
2250   }
2251   return nullptr;
2252 }
2253 
GetStrategyFromSWCByOutputLayout(const TensorLayout & output_layout,size_t output_index)2254 StrategyPtr OperatorInfo::GetStrategyFromSWCByOutputLayout(const TensorLayout &output_layout, size_t output_index) {
2255   auto is_target = [&](const std::shared_ptr<StrategyWithCost> &swc) {
2256     return swc->outputs_ptr[output_index].tensor_layout() == output_layout;
2257   };
2258   auto it = std::find_if(strategy_cost_.begin(), strategy_cost_.end(), is_target);
2259   if (it != strategy_cost_.end()) {
2260     return (*it)->strategy_ptr;
2261   }
2262   return nullptr;
2263 }
2264 
IsReshape() const2265 bool OperatorInfo::IsReshape() const {
2266   if (name_.find(RESHAPEINFO) != std::string::npos) {
2267     return true;
2268   }
2269   return false;
2270 }
2271 
IsTmpIdentity() const2272 bool OperatorInfo::IsTmpIdentity() const {
2273   if (name_.find(IDENTITY_INFO) != std::string::npos) {
2274     return true;
2275   }
2276   return false;
2277 }
2278 
2279 // Keep at most (1.0 / epsilon) number of available strategies for each operator.
ApproximateStrategies()2280 void OperatorInfo::ApproximateStrategies() {
2281   auto enable_approxi = CostModelContext::GetInstance()->dp_algo_enable_approxi();
2282   if (!enable_approxi) {
2283     return;
2284   }
2285   MS_LOG(INFO) << name_ << ": Approximating strategy-cost";
2286   auto epsilon = CostModelContext::GetInstance()->dp_algo_approxi_epsilon();
2287   MS_EXCEPTION_IF_ZERO("epsilon", epsilon);
2288   auto target_num = static_cast<size_t>(std::ceil(1.0 / epsilon));
2289   if (strategy_cost_.size() <= target_num) {
2290     MS_LOG(INFO) << name_ << "'s strategy number is: " << strategy_cost_.size()
2291                  << ", no greater than target-num: " << target_num;
2292     return;
2293   }
2294   std::vector<std::shared_ptr<StrategyWithCost>> ret;
2295   auto &origin_stra_cost = strategy_cost_;
2296   auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
2297   auto beta = CostModelContext::GetInstance()->costmodel_beta();
2298   // sort
2299   std::sort(
2300     origin_stra_cost.begin(), origin_stra_cost.end(),
2301     [&alpha, &beta](const std::shared_ptr<StrategyWithCost> &s1, const std::shared_ptr<StrategyWithCost> &s2) {
2302       if (alpha * s1->cost_list[0]->computation_cost_ + beta * s1->cost_list[0]->communication_with_partial_para_ <
2303           alpha * s2->cost_list[0]->computation_cost_ + beta * s2->cost_list[0]->communication_with_partial_para_) {
2304         return true;
2305       }
2306       return false;
2307     });
2308   MS_EXCEPTION_IF_ZERO("target_num", target_num);
2309   size_t step_length = origin_stra_cost.size() / target_num;
2310   for (size_t i = 0; ret.size() < target_num && static_cast<size_t>(i * step_length) < origin_stra_cost.size(); ++i) {
2311     ret.push_back(origin_stra_cost[static_cast<size_t>(i * step_length)]);
2312   }
2313 
2314   strategy_cost_ = ret;
2315   is_strategy_cost_exact_ = false;
2316 }
2317 
void OperatorInfo::ExactStrategiesAndRelatedEdges() {
  // If the cached strategy-cost list was approximated (see ApproximateStrategies),
  // regenerate the exact strategies and re-initialize the costs of all edges touching
  // this operator, since edge costs were derived from the approximate list.
  if (is_strategy_cost_exact()) {
    return;
  }
  ClearStrategyCost();
  if (GenerateStrategies(0) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Strategy search failed.";
  }
  SetIsStrategyCostExactTrue();
  // re-init the previous edges
  for (auto &prev_edge : prev_edges()) {
    if (prev_edge->InitEdgeCost() != SUCCESS) {
      MS_LOG(EXCEPTION) << name_ << ": Edge: " << prev_edge->edge_name() << " cost init failed.";
    }
  }
  // re-init the successive edges
  for (auto &next_edge : succ_edges()) {
    if (next_edge->InitEdgeCost() != SUCCESS) {
      MS_LOG(EXCEPTION) << name_ << ": Edge: " << next_edge->edge_name() << " cost init failed.";
    }
  }
}
2340 
int64_t OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() {
  // Recursively determines whether this operator's output depends on any trainable
  // parameter: 1 = involved, 0 = not involved. Memoized via is_output_parameter_involve_
  // (-1 means "not yet computed"). Also pushes the per-input involvement flags and the
  // resulting in/out-memory decisions into the operator's cost model.
  if (is_output_parameter_involve_ != -1) {
    return is_output_parameter_involve_;
  }
  // Start from this operator's own inputs that are parameters, then fold in the
  // involvement of every alive predecessor.
  is_parameter_involve_ = is_parameter_;
  const auto &prev_edges = this->GetAlivePrevEdges();
  for (auto &p_edge : prev_edges) {
    auto input_index = p_edge->next_op_input_index();
    auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved();
    if (input_index >= is_parameter_involve_.size()) {
      MS_LOG(EXCEPTION) << name_ << " has input length: " << is_parameter_involve_.size()
                        << ", but got wrong input_index: " << input_index;
    }
    if (prev_op_para == 0) {
      is_parameter_involve_[input_index] = false;
    } else if (prev_op_para == 1) {
      is_parameter_involve_[input_index] = true;
    } else {
      MS_LOG(EXCEPTION) << name_ << " got wrong value: " << prev_op_para << ", input_index: " << input_index;
    }
    p_edge->set_parameter_involve(prev_op_para);
  }
  if (std::any_of(is_parameter_involve_.begin(), is_parameter_involve_.end(), [](bool value) { return value; })) {
    // If anyone of the input is a parameter_involved, the output is parameter_involved.
    is_output_parameter_involve_ = 1;
  } else {
    is_output_parameter_involve_ = 0;
  }
  // Set 'is_parameter_involve_' and 'is_output_parameter_involve_' into operatorCost, which are used in
  // calculating 'inputs_in_memory' and 'output_in_memory', respectively.
  operator_cost()->set_is_parameter_involve(is_parameter_involve_);
  operator_cost()->set_output_parameter_involve(is_output_parameter_involve_);
  // Calculating 'output_in_memory'
  operator_cost()->CalculateOutputInMemory();
  // Calculating 'inputs_in_memory'
  std::map<size_t, bool> input_in_memory;
  for (auto &p_edge : prev_edges) {
    auto input_index = p_edge->next_op_input_index();
    auto is_in_mem = p_edge->prev_operator()->operator_cost()->is_output_in_memory();
    (void)input_in_memory.emplace(std::make_pair(input_index, is_in_mem));
  }
  operator_cost()->CalculateInputsInMemory(input_in_memory);

  return is_output_parameter_involve_;
}
2386 
Status OperatorInfo::set_is_parameter(const std::vector<bool> &is_parameter) {
  // Record, per operator input, whether that input is a trainable parameter; exactly
  // one flag per input is required. The flags are mirrored into the cost model.
  if (is_parameter.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": Is_parameter: " << is_parameter.size()
                  << " do not have the same number of inputs_shape_: " << inputs_shape_.size();
    return FAILED;
  }
  is_parameter_ = is_parameter;
  operator_cost()->set_is_parameter(is_parameter);
  return SUCCESS;
}
2397 
CalculateMemoryCost()2398 Status OperatorInfo::CalculateMemoryCost() {
2399   if (is_parameter_involve_.size() != is_parameter_.size()) {
2400     MS_LOG(ERROR) << name_ << ": the size of 'is_parameter_' is " << is_parameter_.size()
2401                   << " does not have the same number of the size of 'is_parameter_involve_'."
2402                   << is_parameter_involve_.size();
2403     return FAILED;
2404   }
2405   // Set the memory cost in the 'strategy_cost_'
2406   for (auto &swc : strategy_cost_) {
2407     auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr);
2408     swc->cost_list[0]->memory_with_reuse_ = mem_cost;
2409   }
2410   return SUCCESS;
2411 }
2412 
CalculateMemoryCostForInference()2413 Status OperatorInfo::CalculateMemoryCostForInference() {
2414   // First, set the 'is_outputs_critical_' flag into OperatorCost.
2415   if (is_output_critical_ == -1) {
2416     MS_LOG(EXCEPTION) << name_ << ": The critical flag is not set.";
2417   }
2418   operator_cost()->set_output_critical(is_output_critical_);
2419   // Set the memory cost in the 'strategy_cost_'
2420   for (auto &swc : strategy_cost_) {
2421     auto mem_cost = operator_cost()->GetMemoryCostForInference(swc->inputs_ptr, swc->outputs_ptr);
2422     swc->cost_list[0]->memory_with_reuse_ = mem_cost;
2423   }
2424   return SUCCESS;
2425 }
2426 
CorrectMemoryCost(size_t input_index)2427 Status OperatorInfo::CorrectMemoryCost(size_t input_index) {
2428   for (auto &swc : strategy_cost_) {
2429     double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) *
2430                                 static_cast<double>(operator_cost()->inputs_type_lengths()[input_index]);
2431     swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost;
2432     if (swc->cost_list[0]->memory_with_reuse_ < 0) {
2433       MS_LOG(WARNING) << name_ << ": The memory cost after correction is: " << swc->cost_list[0]->memory_with_reuse_
2434                       << ", the parameter memory cost is: " << parameter_mem_cost;
2435       swc->cost_list[0]->memory_with_reuse_ = 0;
2436     }
2437   }
2438   return SUCCESS;
2439 }
2440 
ComputeRepeatDeviceNumByTensorMap(const Shape & dev_matrix_shape,const Shape & tensor_map)2441 int64_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) {
2442   int64_t ret = -1;
2443 
2444   // The number of repetitions is equal to the number of all devices divided by the number of devices use for
2445   // tensor map.
2446   int64_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies<int64_t>());
2447   for (auto &element : tensor_map) {
2448     // -1 means the corresponding dimension is not split.
2449     if (element == MAP_NONE) {
2450       continue;
2451     } else if ((element < 0) || (LongToSize(element) >= dev_matrix_shape.size())) {
2452       MS_LOG(ERROR) << "Invalid tensor map: " << ShapeToString(tensor_map) << ", the dev matrix shape is "
2453                     << ShapeToString(dev_matrix_shape);
2454       return ret;
2455     } else {
2456       size_t index = dev_matrix_shape.size() - LongToSize(element) - 1;
2457       if (dev_matrix_shape[index] <= 0) {
2458         MS_LOG(ERROR) << "Invalid dev matrix shape: " << ShapeToString(dev_matrix_shape);
2459         return ret;
2460       }
2461       device_num /= dev_matrix_shape[index];
2462     }
2463   }
2464 
2465   return device_num;
2466 }
2467 
Status OperatorInfo::InferAsLossDivisor() {
  // When this operator acts as the loss, compute the divisor applied to its inputs'
  // gradients: the number of devices that redundantly compute the (single) output.
  if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
    // Repeated-mean correction disabled: no division needed.
    as_loss_divisor_ = 1;
    return SUCCESS;
  }
  if (!inputs_shape_new_.empty()) {
    MS_LOG(ERROR) << name_ << ": For Tuple input ops, please override this function";
    return FAILED;
  }
  if (outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
    return FAILED;
  }

  // Only single-output operators are handled here; multi-output ops must override.
  if (outputs_tensor_map_.size() > 1) {
    MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_map_.size()
                  << ", need to override this function ";
    return FAILED;
  }

  if (outputs_tensor_map_[0].empty()) {
    // Scalar output: every device in the stage repeats the computation.
    as_loss_divisor_ = stage_device_size_;
    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
    return SUCCESS;
  }

  if (out_dev_matrix_shape_.empty()) {
    // No output-specific device matrix was inferred; fall back to the operator's.
    out_dev_matrix_shape_ = dev_matrix_shape_;
  }
  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape_, outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape_)
               << ", the output tensor map is " << ShapeToString(outputs_tensor_map_[0]) << ", loss divisor is "
               << as_loss_divisor_;
  return SUCCESS;
}
2503 
Status OperatorInfo::InferAsLossDivisorByLayout() {
  // Layout-based variant of InferAsLossDivisor: derives the loss divisor from the
  // single output's TensorLayout instead of the operator-level tensor maps.
  if (!ParallelContext::GetInstance()->loss_repeated_mean()) {
    // Repeated-mean correction disabled: no division needed.
    as_loss_divisor_ = 1;
    return SUCCESS;
  }

  if (outputs_tensor_info_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor info is empty.";
    return FAILED;
  }

  // Only single-output operators are handled here; multi-output ops must override.
  if (outputs_tensor_info_.size() > 1) {
    MS_LOG(ERROR) << name_ << ": The output size is " << outputs_tensor_info_.size()
                  << ", need to override this function ";
    return FAILED;
  }

  TensorMaps outputs_tensor_map = outputs_tensor_info_[0].tensor_layout().tensor_map_before();
  if (outputs_tensor_map.empty()) {
    // Scalar output: every device in the stage repeats the computation.
    MS_LOG(INFO) << name_ << ": out_dev_matrix_shape is empty";
    as_loss_divisor_ = stage_device_size_;
    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << ", loss divisor.";
    return SUCCESS;
  }

  auto out_dev_matrix_shape = outputs_tensor_info_[0].tensor_layout().device_arrangement_origin().array();
  if (out_dev_matrix_shape.empty()) {
    out_dev_matrix_shape = dev_matrix_shape_;
  }
  // Flatten the per-dimension tensor maps into one map before counting repeats.
  Shape squashed_tensor_map;
  for (const auto &tensor_map : outputs_tensor_map) {
    std::copy(tensor_map.begin(), tensor_map.end(), std::back_inserter(squashed_tensor_map));
  }
  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(out_dev_matrix_shape, squashed_tensor_map);
  MS_LOG(INFO) << name_ << ": the dev matrix shape is " << ShapeToString(out_dev_matrix_shape)
               << ", the output tensor map is " << ShapeToString(squashed_tensor_map) << ", loss divisor is "
               << as_loss_divisor_;
  return SUCCESS;
}
2543 
2544 // If the operator is used as a loss, a div node is inserted for the grad of all its inputs.
InferVirtualDivOps()2545 Status OperatorInfo::InferVirtualDivOps() {
2546   if (InferAsLossDivisor() != SUCCESS) {
2547     MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
2548     return FAILED;
2549   }
2550 
2551   if (as_loss_divisor_ <= 0) {
2552     MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
2553     return FAILED;
2554   } else if (as_loss_divisor_ == 1) {
2555     MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
2556     return SUCCESS;
2557   }
2558 
2559   virtual_div_op_.clear();
2560   // if loss is repeated calculation, insert div op
2561   Operator op = CreateVirtualDivOp(as_loss_divisor_);
2562   virtual_div_op_.push_back(op);
2563   return SUCCESS;
2564 }
2565 
InferVirtualDivOpsByLayout()2566 Status OperatorInfo::InferVirtualDivOpsByLayout() {
2567   if (InferAsLossDivisorByLayout() != SUCCESS) {
2568     MS_LOG(ERROR) << name_ << ": InferAsLossDivisor failed.";
2569     return FAILED;
2570   }
2571 
2572   if (as_loss_divisor_ <= 0) {
2573     MS_LOG(ERROR) << name_ << ": Invalid loss divisor: " << as_loss_divisor_;
2574     return FAILED;
2575   } else if (as_loss_divisor_ == 1) {
2576     MS_LOG(INFO) << name_ << ": The loss divisor is 1, no need to create virtual div op.";
2577     return SUCCESS;
2578   }
2579 
2580   virtual_div_op_.clear();
2581   // if loss is repeated calculation, insert div op
2582   Operator op = CreateVirtualDivOp(as_loss_divisor_);
2583   virtual_div_op_.push_back(op);
2584   return SUCCESS;
2585 }
2586 
Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector<size_t> &input_lengths,
                                                 const std::vector<size_t> &output_lengths) {
  // Record the byte length of each input/output element type; the lists must match
  // the input/output shape lists one-to-one. Also forwarded to the cost model.
  if (input_lengths.size() != inputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": Input_lengths: " << input_lengths.size()
                  << " do not have the same number of inputs shape: " << inputs_shape_.size();
    return FAILED;
  }
  if (output_lengths.size() != outputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": Output_lengths: " << output_lengths.size()
                  << " do not have the same number of outputs shape: " << outputs_shape_.size();
    return FAILED;
  }
  inputs_type_lengths_ = input_lengths;
  outputs_type_lengths_ = output_lengths;
  operator_cost()->SetInputAndOutputTypeLength(input_lengths, output_lengths);
  return SUCCESS;
}
2604 
GetOutputsTotalSize()2605 double OperatorInfo::GetOutputsTotalSize() {
2606   if (is_calculated_outputs_size_) {
2607     return outputs_total_size_;
2608   }
2609   if (outputs_type_lengths_.size() != outputs_shape_.size()) {
2610     MS_LOG(EXCEPTION) << name_ << ": Output_lengths: " << outputs_type_lengths_.size()
2611                       << " do not have the same number of outputs shape: " << outputs_shape_.size();
2612   }
2613   double sum = 0.0;
2614   for (size_t i = 0; i < outputs_type_lengths_.size(); ++i) {
2615     auto size = std::accumulate(outputs_shape_[i].begin(), outputs_shape_[i].end(), static_cast<double>(1.0),
2616                                 std::multiplies<double>());
2617     sum += size * static_cast<double>(outputs_type_lengths_[i]);
2618   }
2619   is_calculated_outputs_size_ = true;
2620   outputs_total_size_ = sum;
2621   return outputs_total_size_;
2622 }
2623 
Status OperatorInfo::set_outputs_type(const std::vector<TypePtr> &outputs_type) {
  // Record the element type of each output; one type per output shape is required.
  if (outputs_type.size() != outputs_shape_.size()) {
    MS_LOG(ERROR) << name_ << ": Outputs type: " << outputs_type.size()
                  << " do not have the same number of outputs shape: " << outputs_shape_.size();
    return FAILED;
  }
  outputs_type_ = outputs_type;
  return SUCCESS;
}
2633 
BreakingTiesForPreferringDataParallel(const StrategyPtr & stra,const CostPtr & cost) const2634 void OperatorInfo::BreakingTiesForPreferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) const {
2635   if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) {
2636     if (stra->GetInputDim()[0][0] == stage_device_size_) {
2637       if (cost->computation_cost_ > 1.0) {
2638         cost->computation_cost_ -= 1.0;
2639       }
2640       if (cost->communication_cost_ > 1.0) {
2641         cost->communication_cost_ -= 1.0;
2642       }
2643       if (cost->communication_with_partial_para_ > 1.0) {
2644         cost->communication_with_partial_para_ -= 1.0;
2645       }
2646       if (cost->communication_without_parameter_ > 1.0) {
2647         cost->communication_without_parameter_ -= 1.0;
2648       }
2649     }
2650   }
2651 }
2652 
SetSelectedStrategy(const StrategyPtr & s_strategy,size_t curr_depth)2653 void OperatorInfo::SetSelectedStrategy(const StrategyPtr &s_strategy, size_t curr_depth) {
2654   MS_EXCEPTION_IF_NULL(s_strategy);
2655   if ((selected_strategy_depth_ != -1) && (SizeToLong(curr_depth) > selected_strategy_depth_)) {
2656     MS_LOG(INFO) << name_ << " has already been set strategy.";
2657     return;
2658   }
2659   MS_LOG(INFO) << name_ << ": Set strategy " << s_strategy->ToString();
2660   selected_strategy_ = s_strategy;
2661   selected_strategy_depth_ = SizeToLong(curr_depth);
2662 }
2663 
set_swc_index(int64_t swc,int64_t depth)2664 void OperatorInfo::set_swc_index(int64_t swc, int64_t depth) {
2665   MS_LOG(INFO) << name_ << ": Set SWC index: " << swc;
2666   selected_strategy_depth_ = depth;
2667   swc_index_ = swc;
2668 }
2669 
// Accessor: returns (a copy of) the CNodes currently associated with this operator.
std::vector<CNodePtr> OperatorInfo::cnodes() { return cnodes_; }
2671 
// NOTE(review): despite the name, this returns the forward *computation* cost evaluated
// at stage 0 with the currently-inferred tensor infos — confirm intended semantics.
double OperatorInfo::GetForwardMemoryCostFromCNode() {
  return operator_cost()->GetForwardComputationCost(inputs_tensor_info_, outputs_tensor_info_, 0);
}
2675 
CheckSelectedStrategy(const StrategyPtr & s_strategy)2676 void OperatorInfo::CheckSelectedStrategy(const StrategyPtr &s_strategy) {
2677   MS_EXCEPTION_IF_NULL(s_strategy);
2678   if (!s_strategy->IsEqual(selected_strategy_)) {
2679     MS_LOG(INFO) << name_
2680                  << "'s strategy may cause suboptimal, the determined strategy: " << selected_strategy_->ToString()
2681                  << "The minimal strategy: " << s_strategy->ToString();
2682   }
2683 }
2684 
// Replaces the whole cached strategy-cost list.
void OperatorInfo::SetStrategyCost(const std::vector<std::shared_ptr<StrategyWithCost>> &stra_cost) {
  strategy_cost_ = stra_cost;
}
2688 
Status OperatorInfo::GenerateStrategies(int64_t stage_id) {
  // Generate all candidate strategies for this operator and attach a cost to every
  // feasible one via SetCostUnderStrategy. Infeasible candidates are skipped (logged),
  // so this function only fails when attribute inference fails.
  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
    return FAILED;
  }

  DivisorsReplaceShapes();  // in dynamic shape, using divisors replace to shapes before CheckStrategy and so on
  std::vector<StrategyPtr> sp_vector = GenerateOpStrategies(stage_id);
  ResumeShapes();  // resume shapes

  size_t success = 0;
  for (auto &sp : sp_vector) {
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated the " << GetSerialNumberString(success)
                   << " strategy: " << sp->ToString();
    } else {
      MS_LOG(INFO) << name_ << ": SetCostUnderStrategy failed, the strategy is " << sp->ToString();
    }
  }
  return SUCCESS;
}
2711 
GetIntAttr(const std::string & attr_name)2712 int64_t OperatorInfo::GetIntAttr(const std::string &attr_name) {
2713   auto attr_iter = attrs_.find(attr_name);
2714   if (attr_iter == attrs_.end()) {
2715     MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
2716   }
2717 
2718   MS_EXCEPTION_IF_NULL(attr_iter->second);
2719   if (!attr_iter->second->isa<Int64Imm>()) {
2720     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
2721   }
2722 
2723   return attr_iter->second->cast<Int64ImmPtr>()->value();
2724 }
2725 
GetBoolAttr(const std::string & attr_name)2726 bool OperatorInfo::GetBoolAttr(const std::string &attr_name) {
2727   auto attr_iter = attrs_.find(attr_name);
2728   if (attr_iter == attrs_.end()) {
2729     MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
2730   }
2731 
2732   MS_EXCEPTION_IF_NULL(attr_iter->second);
2733   if (!attr_iter->second->isa<BoolImm>()) {
2734     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not int";
2735   }
2736 
2737   return attr_iter->second->cast<BoolImmPtr>()->value();
2738 }
2739 
GetStringAttr(const std::string & attr_name)2740 std::string OperatorInfo::GetStringAttr(const std::string &attr_name) {
2741   std::string string_attr;
2742   auto attr_iter = attrs_.find(attr_name);
2743   if (attr_iter == attrs_.end()) {
2744     MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
2745   }
2746 
2747   MS_EXCEPTION_IF_NULL(attr_iter->second);
2748   if (!attr_iter->second->isa<StringImm>()) {
2749     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not string";
2750   }
2751 
2752   string_attr = attr_iter->second->cast<StringImmPtr>()->value();
2753   return string_attr;
2754 }
2755 
GetTupleIntAttr(const std::string & attr_name)2756 std::vector<int64_t> OperatorInfo::GetTupleIntAttr(const std::string &attr_name) {
2757   std::vector<int64_t> tuple_attr;
2758   auto tuple_attr_iter = attrs_.find(attr_name);
2759   if (tuple_attr_iter == attrs_.end()) {
2760     MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
2761   }
2762 
2763   MS_EXCEPTION_IF_NULL(tuple_attr_iter->second);
2764   tuple_attr = GetValue<std::vector<int64_t>>(tuple_attr_iter->second);
2765 
2766   return tuple_attr;
2767 }
2768 
GetFloatAttr(const std::string & attr_name)2769 float OperatorInfo::GetFloatAttr(const std::string &attr_name) {
2770   auto attr_iter = attrs_.find(attr_name);
2771   if (attr_iter == attrs_.end()) {
2772     MS_LOG(EXCEPTION) << name_ << ": Can not find the attribution of " << attr_name;
2773   }
2774 
2775   MS_EXCEPTION_IF_NULL(attr_iter->second);
2776   if (!attr_iter->second->isa<FP32Imm>()) {
2777     MS_LOG(EXCEPTION) << name_ << ": The value of " << attr_name << " is not float";
2778   }
2779 
2780   return attr_iter->second->cast<FP32ImmPtr>()->value();
2781 }
2782 
GetValueSequence(const ValuePtr & sequence)2783 std::vector<ValuePtr> GetValueSequence(const ValuePtr &sequence) {
2784   MS_EXCEPTION_IF_NULL(sequence);
2785   std::vector<ValuePtr> ret;
2786   if (!sequence->isa<ValueTuple>() && !sequence->isa<ValueList>()) {
2787     MS_LOG(ERROR) << "The arg is not value tuple or value list";
2788     return ret;
2789   }
2790 
2791   if (sequence->isa<ValueTuple>()) {
2792     auto val_tuple = sequence->cast<ValueTuplePtr>();
2793     return val_tuple->value();
2794   }
2795   auto val = sequence->cast<ValueListPtr>();
2796   return val->value();
2797 }
2798 
MakeListValue(const std::vector<int64_t> & v)2799 ValuePtr MakeListValue(const std::vector<int64_t> &v) {
2800   std::vector<ValuePtr> list;
2801   (void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
2802   return std::make_shared<ValueSequence>(list);
2803 }
2804 
MakeTupleListValue(const Shapes & v)2805 ValuePtr MakeTupleListValue(const Shapes &v) {
2806   std::vector<ValuePtr> tuple;
2807   (void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
2808                        [](const std::vector<int64_t> &list) { return MakeListValue(list); });
2809   return std::make_shared<ValueTuple>(tuple);
2810 }
2811 
CreateValueTupleAnfNodePtr(const std::vector<int64_t> & value_tuple)2812 AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple) {
2813   auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
2814   auto value_node = NewValueNode(value_ptr);
2815   return value_node->cast<AnfNodePtr>();
2816 }
2817 
CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList & tensor_tuple)2818 AnfNodePtr CreateTensorTupleAnfNodePtr(const tensor::TensorPtrList &tensor_tuple) {
2819   auto tensor_ptr = MakeValue(tensor_tuple)->cast<ValueTuplePtr>();
2820   auto tensor_node = NewValueNode(tensor_ptr);
2821   return tensor_node->cast<AnfNodePtr>();
2822 }
2823 
CreateDivOpWithType(float divisor,const TypePtr & dtype)2824 Operator CreateDivOpWithType(float divisor, const TypePtr &dtype) {
2825   OperatorName operator1_name = REAL_DIV;
2826   mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
2827   ValuePtr op1_param_value = MakeValue(tensor_ptr);
2828   Attr op1_param = std::make_pair("divisor", op1_param_value);
2829   OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
2830   OperatorAttrs operator1_attrs;
2831   OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
2832   Operator div_op = std::make_pair(operator1_name, operator1_args);
2833   return div_op;
2834 }
2835 
CreateReduceMeanForwardOp(const std::vector<Group> & forward_group,const TypePtr & dtype)2836 ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
2837   // Create AllReduceSum op
2838   Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
2839   std::string group_name = forward_group[0].name();
2840   MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
2841 
2842   // Create RealDiv op
2843   std::vector<Device> device_list = forward_group[0].GetDevicesList();
2844   auto divisor = SizeToFloat(device_list.size());
2845   Operator op1 = CreateDivOpWithType(divisor, dtype);
2846   std::string dtype_name = dtype->ToString();
2847   MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
2848 
2849   return {op0, op1};
2850 }
2851 
GetTensorValue(const ValuePtr & ori_value)2852 std::vector<int64_t> GetTensorValue(const ValuePtr &ori_value) {
2853   MS_EXCEPTION_IF_NULL(ori_value);
2854   if (!ori_value->isa<tensor::Tensor>()) {
2855     MS_LOG(INTERNAL_EXCEPTION) << "Value is not tensor";
2856   }
2857   auto tensor_ptr = ori_value->cast<tensor::TensorPtr>();
2858   std::vector<int64_t> value;
2859   auto element_size = tensor_ptr->data().size();
2860   auto *data = static_cast<int64_t *>(tensor_ptr->data_c());
2861   for (auto i = 0; i < element_size; i++) {
2862     value.push_back(data[i]);
2863   }
2864   return value;
2865 }
2866 }  // namespace parallel
2867 }  // namespace mindspore
2868