• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/ops_info/gather_v2_p_info.h"
18 
19 #include <vector>
20 #include <numeric>
21 #include <functional>
22 #include <utility>
23 #include <algorithm>
24 
25 #include "frontend/parallel/device_matrix.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/context.h"
28 #if ((defined ENABLE_CPU) && (!defined _WIN32))
29 #include "ps/ps_cache/ps_cache_manager.h"
30 #include "utils/ms_context.h"
31 #endif
32 
33 namespace mindspore {
34 namespace parallel {
GetManualSplitWithoutOffsetAttr()35 Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
36   auto manual_split_without_offset_iter = attrs_.find("manual_split");
37   if (manual_split_without_offset_iter != attrs_.end()) {
38     manual_split_ = true;
39     MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
40     if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
41       MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
42       return FAILED;
43     }
44     std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
45     MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
46 
47     int64_t offset = 0;
48     for (auto &ele : value_vector) {
49       index_offsets_.push_back(offset);
50       if (!ele->isa<Int64Imm>()) {
51         MS_LOG(ERROR) << name_ << ": The element of manual split must be int64_t";
52         return FAILED;
53       }
54       int64_t param_split_shape = static_cast<int64_t>(GetValue<int64_t>(ele));
55       if (param_split_shape <= 0) {
56         MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
57         return FAILED;
58       }
59       param_split_shapes_.push_back(param_split_shape);
60       offset += param_split_shape;
61     }
62     if (param_split_shapes_.empty()) {
63       MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
64       return FAILED;
65     }
66   }
67 
68   return SUCCESS;
69 }
70 
GetManualSplitAttr()71 Status GatherPInfo::GetManualSplitAttr() {
72   auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
73   if (manual_split_with_offset_iter != attrs_.end()) {
74     manual_split_ = true;
75     auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
76     if (var == nullptr) {
77       MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
78       return FAILED;
79     }
80 
81     MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
82     for (auto &ele : var->value()) {
83       if (!ele->isa<ValueSequeue>()) {
84         MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
85         return FAILED;
86       }
87       std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
88       if (value_vector.size() != 2) {
89         MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
90         return FAILED;
91       }
92       int64_t param_split_row = (GetValue<int64_t>(value_vector[0]));
93       int64_t offset = (GetValue<int64_t>(value_vector[1]));
94       if ((param_split_row <= 0) || (offset < 0)) {
95         MS_LOG(ERROR) << name_
96                       << ": The value of param split shape must be positive, and the offset must larger or equal to 0";
97         return FAILED;
98       }
99       param_split_shapes_.push_back(param_split_row);
100       index_offsets_.push_back(offset);
101     }
102 
103     if (param_split_shapes_.empty()) {
104       MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
105       return FAILED;
106     }
107     if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
108       MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
109       return FAILED;
110     }
111     return SUCCESS;
112   }
113 
114   if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
115     return FAILED;
116   }
117 
118   return SUCCESS;
119 }
120 
GetAttrs()121 Status GatherPInfo::GetAttrs() {
122   // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
123   if (target_ != CPU) {
124     if (input_value_.at(2) == nullptr) {
125       MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
126       return FAILED;
127     }
128     auto axis = GetValue<int64_t>(input_value_.at(2));
129     // if axis is negative then convert it to positive
130     auto params_shape = inputs_shape_.at(0);
131     if (params_shape.size() == 0) {
132       MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
133       return FAILED;
134     }
135     if (axis < 0) {
136       axis += SizeToLong(inputs_shape_[0].size());
137     }
138     axis_ = axis;
139   }
140 
141   auto target_iter = attrs_.find(TARGET);
142   if (target_iter != attrs_.end()) {
143     MS_EXCEPTION_IF_NULL(target_iter->second);
144     if (target_iter->second->isa<StringImm>()) {
145       target_ = target_iter->second->cast<StringImmPtr>()->value();
146     } else {
147       MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
148     }
149   }
150 
151   if (GetManualSplitAttr() != SUCCESS) {
152     return FAILED;
153   }
154 
155   if (manual_split_ && (axis_ != 0)) {
156     MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
157     return FAILED;
158   }
159 
160   if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
161     dynamic_shape_indices_ = true;
162   }
163 #if ((defined ENABLE_CPU) && (!defined _WIN32))
164   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
165   bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
166   if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
167     dynamic_shape_indices_ = true;
168   }
169 #endif
170   return SUCCESS;
171 }
172 
CheckManualSplit(const Strategys & strategy)173 Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
174   if (strategy.size() != 2) {
175     MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
176     return FAILED;
177   }
178   Dimensions param_strategy = strategy[0];
179   Dimensions indices_strategy = strategy[1];
180   if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
181     MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
182     return FAILED;
183   }
184 
185   if (indices_strategy[0] != 1) {
186     MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
187     return FAILED;
188   }
189 
190   if (param_strategy[0] != indices_strategy[1]) {
191     MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
192     return FAILED;
193   }
194 
195   if (indices_strategy[1] != SizeToLong(param_split_shapes_.size())) {
196     MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
197     return FAILED;
198   }
199 
200   int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
201   bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
202                              [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
203   if (invalid) {
204     MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
205     return FAILED;
206   }
207 
208   if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
209     MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
210     return FAILED;
211   }
212 
213   // Don't support repeated calc
214   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
215   if (product_p < stage_device_size_) {
216     MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
217     return FAILED;
218   }
219 
220   int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
221                                             [](int64_t s, int64_t shape) { return s + shape; });
222   if (split_shape_sum != inputs_shape_[0][0]) {
223     MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]";
224     return FAILED;
225   }
226   return SUCCESS;
227 }
228 
CheckSplitAxisStrategy(const StrategyPtr & strategy)229 Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
230   auto param_strategy = strategy->GetInputDim().at(0);
231   auto index_strategy = strategy->GetInputDim().at(1);
232   // param_strategy(axis) != 1, index can't be split
233   auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
234   if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
235     MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split.";
236     return FAILED;
237   }
238 
239   // param_strategy(axis) != 1, and axis != 0, don't support repeated calc
240   auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
241   if ((product_p != stage_device_size_) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ != 0)) {
242     MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
243     return FAILED;
244   }
245 
246   if ((product_p != stage_device_size_) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ == 0)) {
247     if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
248       MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
249       return FAILED;
250     }
251     MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
252   }
253   return SUCCESS;
254 }
255 
256 // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
257 // otherwise return false
ShardBatchAndAxis(const Strategys & strategy) const258 bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const {
259   if (axis_ != 0) {
260     return false;
261   }
262 
263   if (strategy.size() != 2) {
264     return false;
265   }
266 
267   Dimensions param_strategy = strategy[0];
268   Dimensions indices_strategy = strategy[1];
269   if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
270     return false;
271   }
272 
273   if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
274     return false;
275   }
276 
277   if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
278     return false;
279   }
280 
281   if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
282     return false;
283   }
284 
285   return true;
286 }
287 
SetAttribute(const StrategyPtr & strategy)288 void GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
289   auto param_strategy = strategy->GetInputDim().at(0);
290   // axis=0, index_shape(0)%param_strategy(0) must be 0
291   Shape index_shape = inputs_shape_.at(1);
292   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
293     MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
294     axis_split_forward_allreduce_ = true;
295   } else if (is_auto_parallel_) {
296     // in auto parallel mode, this function will be called many times, so need to reset the flags
297     axis_split_forward_allreduce_ = false;
298   }
299 
300   auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
301   // Cast 1: If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
302   // parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
303   // and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
304   // can communicate normally, and dev0 to dev7 have the all parameters.
305   // Cast 2: If not repeated calculation(such as data parallel), need to set repeated num to the right,
306   // as it's easy to introduce the redistribution after or before gather operation, influencing the performance.
307   if (product_param == stage_device_size_ || product_param == 1) {
308     repeated_num_in_dev_matrix_right_ = true;
309   } else {
310     repeated_num_in_dev_matrix_right_ = false;
311   }
312   MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
313 }
314 
// Validate the full sharding strategy for gather:
// - generic strategy/shape consistency,
// - 32-byte alignment of the param slice's last dim (non-CPU targets),
// - param must be 1-D or 2-D, indices must not be scalar,
// - then one of: batch+axis sharding, manual split, or the generic axis checks.
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }

  // param slice shape need 32Byte aligned
  auto param_shape = inputs_shape_.at(0);
  auto param_strategy = strategy->GetInputDim().at(0);
  auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
  // 8 elements -- presumably assumes a 4-byte element type; TODO confirm.
  if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
    MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
    return FAILED;
  }

  // only support 1-dim and 2-dim param
  if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
    MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
    return FAILED;
  }

  // don't support scalar index
  if (inputs_shape_.at(1).size() == 0) {
    MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
    return FAILED;
  }

  if (ShardBatchAndAxis(strategy->GetInputDim())) {
    shard_batch_and_axis_ = true;
    axis_split_forward_allreduce_ = true;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce";
    return SUCCESS;
  } else if (is_auto_parallel_) {
    // in auto parallel mode, this function will be called many times, so need to reset the flags
    shard_batch_and_axis_ = false;
    axis_split_forward_allreduce_ = false;
  }

  if (manual_split_) {
    if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
      return FAILED;
    }
    // when using manual_split, no need to check belowings.
    return SUCCESS;
  }

  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
  if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
    return FAILED;
  }

  if (CheckSplitAxisStrategy(strategy) != SUCCESS) {
    return FAILED;
  }

  // According to the strategy, set the private members.
  SetAttribute(strategy);

  return SUCCESS;
}
375 
InferMirrorOps()376 Status GatherPInfo::InferMirrorOps() {
377   // There is no mirror operators for manual split
378   if (manual_split_) {
379     return SUCCESS;
380   }
381 
382   mirror_ops_.clear();
383   Shape input_a_tensor_map = inputs_tensor_map_.at(0);
384   std::vector<Group> input_a_group;
385   if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
386     MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
387     return FAILED;
388   }
389 
390   OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
391   if (input_a_group.empty()) {
392     MS_LOG(INFO) << name_ << " : The mirror group is empty.";
393     return SUCCESS;
394   } else {
395     op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
396     MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
397   }
398 
399   mirror_ops_.push_back(op_for_input_a);
400   mirror_ops_.push_back(op_for_input_b);
401   mirror_ops_.push_back(op_for_axis);
402 
403   return SUCCESS;
404 }
405 
InferDevMatrixShape()406 Status GatherPInfo::InferDevMatrixShape() {
407   dev_matrix_shape_.clear();
408   out_dev_matrix_shape_.clear();
409   // infer input dev_matrix_shape
410   auto param_strategy = strategy_->GetInputDim().at(0);
411   auto index_strategy = strategy_->GetInputDim().at(1);
412 
413   if (manual_split_) {
414     dev_matrix_shape_ = param_strategy;
415     out_dev_matrix_shape_ = dev_matrix_shape_;
416     return SUCCESS;
417   }
418 
419   if (shard_batch_and_axis_) {
420     dev_matrix_shape_ = {index_strategy[0], param_strategy[0]};
421     // if forward use reducescatter, the dev matrix is {index_strategy[0] * param_strategy[0]}
422     out_dev_matrix_shape_ = dev_matrix_shape_;
423     MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_
424                  << ", out dev matrix is " << out_dev_matrix_shape_;
425     return SUCCESS;
426   }
427 
428   dev_matrix_shape_ = param_strategy;
429 
430   // param_strategy(axis) is 1
431   if (param_strategy.at(LongToSize(axis_)) == 1) {
432     dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
433   }
434 
435   // infer out dev_matrix_shape
436   // axis is not 0, split axis
437   if (axis_ != 0 && param_strategy.at(LongToSize(axis_)) != 1) {
438     for (size_t i = 1; i < param_strategy.size(); ++i) {
439       if (i == LongToSize(axis_)) {
440         out_dev_matrix_shape_.push_back(1);
441       } else {
442         out_dev_matrix_shape_.push_back(param_strategy.at(i));
443       }
444     }
445     out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
446   } else {
447     out_dev_matrix_shape_ = dev_matrix_shape_;
448   }
449   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
450   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
451   if (param_product * index_product < stage_device_size_) {
452     auto repeated_calc_num = stage_device_size_ / (param_product * index_product);
453     if (repeated_num_in_dev_matrix_right_) {
454       out_dev_matrix_shape_.push_back(repeated_calc_num);
455     } else {
456       (void)out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), repeated_calc_num);
457     }
458   }
459 
460   return SUCCESS;
461 }
462 
InferInputsTensorMap()463 void GatherPInfo::InferInputsTensorMap() {
464   // infer input tensor map
465   // param_strategy(axis) is not 1
466   size_t param_size = inputs_shape_.at(0).size();
467   size_t index_size = inputs_shape_.at(1).size();
468   size_t total_size = param_size + index_size;
469   Shape tensor_map_index;
470   Shape tensor_map_params;
471   auto param_strategy = strategy_->GetInputDim().at(0);
472   if (param_strategy.at(LongToSize(axis_)) != 1) {
473     tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
474     for (size_t i = 0; i < param_size; ++i) {
475       tensor_map_params.push_back(SizeToLong(param_size - i - 1));
476     }
477   } else {
478     // param_strategy(axis) is 1
479     for (size_t i = 0; i < param_size; ++i) {
480       tensor_map_params.push_back(SizeToLong(total_size - i - 1));
481     }
482     for (size_t i = 0; i < index_size; ++i) {
483       tensor_map_index.push_back(SizeToLong(index_size - i - 1));
484     }
485   }
486   inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
487   inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
488 }
489 
// Infer the output tensor map from the param strategy.
// The output shape is the param shape with the axis dim replaced by the
// indices dims, and the map mirrors that substitution.
void GatherPInfo::InferOutputsTensorMap() {
  // infer output tensor map
  size_t param_size = inputs_shape_.at(0).size();
  size_t index_size = inputs_shape_.at(1).size();
  size_t total_size = param_size + index_size;
  Shape tensor_map_out;
  auto param_strategy = strategy_->GetInputDim().at(0);
  if (param_strategy.at(LongToSize(axis_)) == 1) {
    // param_strategy(axis) is 1: the axis position expands into the indices'
    // maps; the other param dims keep their shifted dev-matrix positions.
    for (size_t i = 0; i < param_size; ++i) {
      if (i == LongToSize(axis_)) {
        for (size_t j = 0; j < index_size; ++j) {
          tensor_map_out.push_back(SizeToLong(index_size - j - 1));
        }
      } else {
        tensor_map_out.push_back(SizeToLong(total_size - i - 1));
      }
    }
  } else {
    // param_strategy(axis) is not 1
    if (axis_ == 0) {
      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
        // the output is repeat calculation
        tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
      } else {
        // First output dim follows the axis split in the dev matrix.
        tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
      }
      tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
      for (size_t i = 1; i < param_size; ++i) {
        tensor_map_out.push_back(param_size - 1 - i);
      }
    } else {
      // axis > 0: the dims coming from indices are replicated; a leading
      // MAP_NONE is prepended for dynamic-shape indices on non-CPU targets
      // -- presumably matching an extra output dim in that mode; TODO confirm.
      for (size_t i = 0; i < param_size; ++i) {
        if (i == LongToSize(axis_)) {
          tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
        } else {
          if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
            tensor_map_out.push_back(MAP_NONE);
          }
          tensor_map_out.push_back(SizeToLong(i));
        }
      }
    }
  }
  (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
}
536 
InferTensorMap()537 Status GatherPInfo::InferTensorMap() {
538   if (manual_split_) {
539     Shape param_map = {1, 0};
540     Shape indices_map = {-1, 1};
541     Shape out_map = {-1, 1, 0};
542     (void)inputs_tensor_map_.emplace_back(std::move(param_map));
543     (void)inputs_tensor_map_.emplace_back(std::move(indices_map));
544     (void)outputs_tensor_map_.emplace_back(std::move(out_map));
545     return SUCCESS;
546   }
547 
548   if (shard_batch_and_axis_) {
549     Shape param_tensor_map = {0, -1};
550     Shape indices_tensor_map = {1, -1};
551     Shape out_tensor_map = {1, -1, -1};
552     (void)inputs_tensor_map_.emplace_back(std::move(param_tensor_map));    // param
553     (void)inputs_tensor_map_.emplace_back(std::move(indices_tensor_map));  // indices
554     (void)outputs_tensor_map_.emplace_back(
555       std::move(out_tensor_map));  // output, if forward use reducescatter, tensormap is {0, -1, -1}
556     return SUCCESS;
557   }
558   InferInputsTensorMap();
559   InferOutputsTensorMap();
560   return SUCCESS;
561 }
562 
// Build TensorLayout/TensorInfo objects for both inputs and the output from
// the previously inferred dev-matrix shapes and tensor maps.
Status GatherPInfo::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape input_index_shape = inputs_shape_.at(1);
  Shape output_shape = outputs_shape_.at(0);
  int64_t rank = g_device_manager->rank_index_in_stage();
  // infer tensor layout
  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
  if (manual_split_) {
    // Pick this rank's manual-split slice (dev_matrix_shape_[1] ranks share a
    // slice), then scale back by dev_matrix_shape_[0] -- presumably because
    // the layout divides by that factor again; TODO confirm the round-trip.
    input_shape[0] = param_split_shapes_[LongToSize(rank / dev_matrix_shape_[1])];
    input_shape[0] = input_shape[0] * dev_matrix_shape_[0];
  }
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
       SUCCESS)) {
    return FAILED;
  }

  if (manual_split_) {
    // Slice sizes differ between ranks, so the split is not uniform.
    input_tensor_layout.set_uniform_split(false);
  }
  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
595 
// Compute bias_ (the first parameter row/column owned by this rank) and
// slice_size_, used by the masked-gather replace graph.
Status GatherPInfo::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  auto params_strategy = strategy_->GetInputDim().at(0);

  if (shard_batch_and_axis_) {
    // Dev matrix is {index_strategy[0], param_strategy[0]}: the fastest-varying
    // coordinate (rank % params_strategy[0]) selects the param slice.
    slice_size_ = input_shape[0] / params_strategy[0];
    bias_ = rank % params_strategy[0] * slice_size_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
                 << ", bias is " << bias_;
    return SUCCESS;
  }

  // axis don't split
  if (params_strategy.at(LongToSize(axis_)) == 1) {
    bias_ = 0;
    return SUCCESS;
  }
  // params_size=1, axis=0
  if ((input_shape.size() == 1) && (axis_ == 0)) {
    slice_size_ = input_shape.at(0) / params_strategy.at(0);
    // if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
    if (repeated_calc_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calc_num_;
      } else {
        // Repeat factor on the left: take the coordinate inside one replica.
        rank = rank % params_strategy[0];
      }
    }
    bias_ = rank * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=0
  if ((input_shape.size() == 2) && (axis_ == 0)) {
    slice_size_ = input_shape.at(0) / params_strategy.at(0);
    // if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
    if (repeated_calc_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calc_num_;
      } else {
        rank = rank % (params_strategy[0] * params_strategy[1]);
      }
    }
#if ((defined ENABLE_CPU) && (!defined _WIN32))
    // With the parameter-server cache, the cached-indices lower bound replaces
    // the rank-derived bias.
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
      return SUCCESS;
    }
#endif
    // params_strategy[1] ranks share each row slice.
    bias_ = rank / params_strategy.at(1) * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=1
  if ((input_shape.size() == 2) && (axis_ == 1)) {
    slice_size_ = input_shape.at(1) / params_strategy.at(1);
    bias_ = rank % params_strategy.at(1) * slice_size_;
    return SUCCESS;
  }
  MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_;
  return FAILED;
}
658 
InferOffset()659 Status GatherPInfo::InferOffset() {
660   CheckGlobalDeviceManager();
661   size_t rank = LongToSize(g_device_manager->rank_index_in_stage());
662 
663   MS_EXCEPTION_IF_NULL(strategy_);
664   auto param_strategy = strategy_->GetInputDim()[0];
665   if (param_strategy.size() != 2) {
666     MS_LOG(ERROR) << "The size of param strategy must be 2";
667     return FAILED;
668   }
669   size_t index = rank / LongToSize(param_strategy[1]);
670   if (index < index_offsets_.size()) {
671     index_offset_ = index_offsets_[index];
672     MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
673     return SUCCESS;
674   }
675 
676   MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size();
677   return FAILED;
678 }
679 
// Create the communication group along the split-axis dimension of the device
// matrix; used by the forward AllReduce/ReduceScatter. Leaves group_ unset
// when only this device lies along that dimension.
Status GatherPInfo::InferGroup() {
  size_t dim = LongToSize(axis_);

  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
  RankList group_devices;

  // the dev_matrix[0] is repeated_calc_num, so the dim need to add 1
  if ((repeated_calc_num_ > 1) && !repeated_num_in_dev_matrix_right_) {
    dim = dim + 1;
  }

  if (shard_batch_and_axis_) {
    // Dev matrix is {index_strategy[0], param_strategy[0]}; the axis split is dim 1.
    dim = 1;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
  }

  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
  }
  if (group_devices.size() == 1) {
    // Single device along the dim: no communication needed.
    MS_LOG(INFO) << name_ << ": The group is empty";
    return SUCCESS;
  }

  MS_LOG(INFO) << name_ << ": The group ranks is " << group_devices;
  group_ = g_device_manager->CreateGroup(group_devices);
  return SUCCESS;
}
710 
InferForwardCommunication()711 Status GatherPInfo::InferForwardCommunication() {
712   if (manual_split_) {
713     return SUCCESS;
714   }
715 
716   forward_op_.clear();
717   auto param_strategy = strategy_->GetInputDim().at(0);
718   // don't split axis or target is not CPU, no need forward communication
719   if (target_ != CPU || param_strategy.at(LongToSize(axis_)) == 1) {
720     return SUCCESS;
721   }
722   // split axis
723   OperatorName operator_name;
724   if (InferGroup() != SUCCESS) {
725     MS_LOG(ERROR) << name_ << ": Infer Group failed.";
726     return FAILED;
727   }
728   Attr attr_group;
729   operator_name = REDUCE_SCATTER;
730   if (InferGroup() != SUCCESS) {
731     MS_LOG(ERROR) << name_ << ": Infer Group failed.";
732     return FAILED;
733   }
734   if (group_.name().empty()) {
735     return SUCCESS;
736   }
737   attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
738   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
739   OperatorAttrs attrs = {attr_op, attr_group};
740   OperatorParams params;
741   OperatorArgs args = std::make_pair(attrs, params);
742   Operator op = std::make_pair(operator_name, args);
743 
744   forward_op_.push_back(op);
745   return SUCCESS;
746 }
747 
// Build the replacement subgraph used when the gather axis is sharded (or for
// manual split on a non-CPU target). The graph remaps global indices to this
// rank's local slice, masks out-of-range rows to zero, and combines partial
// results across ranks. On success the graph is stored in replace_graph_.
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << "GenerateGraph Init failed";
    return FAILED;
  }
  // Manual-split path: the user supplied per-rank row offsets, so the graph is
  // simply indices - offset followed by a local gather; no masking or
  // communication is needed.
  if (manual_split_ && target_ != CPU) {
    if (InferOffset() != SUCCESS) {
      MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
      return FAILED;
    }
    auto sub_node =
      gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
    auto gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node, CreatInt64Imm(axis_)});
    // The int in each pair is the original cnode input index the virtual input
    // maps to: 2 = indices (consumed by Sub), 1 = parameter (consumed by Gather).
    std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub_node, 2),
                                                               std::make_pair(gather_v2_node, 1)};
    replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
      std::make_pair(input_nodes, gather_v2_node));
    return SUCCESS;
  }
  // General path: derive this rank's bias (start row of its parameter slice).
  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage() << ", the bias is " << bias_;
  // local = indices - bias; ReLU clamps negatives (rows before this slice) to 0,
  // Minimum clamps to slice_size_ - 1 (rows after this slice).
  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
  // equal == true exactly when the original local index was already in
  // [0, slice_size_): clamping did not change it -> the row belongs to this rank.
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
  auto gather_v2 =
    gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_)});
  // Cast the boolean mask to the gathered dtype so it can multiply the result.
  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
  // NOTE(review): the mask is expanded at dim axis_ - 1 — presumably this aligns
  // it with the gathered output layout; confirm the intent for axis_ > 1.
  auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
  // don't need expand dim, if param_size = 1
  if (inputs_shape_.at(0).size() == 1) {
    mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
  }
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  // Combine the masked partial results across the group: AllReduce when indices
  // are dynamic or an AllReduce was explicitly requested, else ReduceScatter.
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  OperatorAttrs attrs = {attr_op, attr_group};
  AnfNodePtr reduce_op;
  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
  } else {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
  }
  // Input index mapping as above: 2 = indices (Sub), 1 = parameter (Gather).
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, reduce_op));

  return SUCCESS;
}
807 
replace_graph(const CNodePtr & cnode)808 ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
809   if (manual_split_ && target_ != CPU) {
810     if (ComputeReplaceGraph(cnode) != SUCCESS) {
811       MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
812     }
813     return replace_graph_;
814   }
815 
816   auto param_strategy = strategy_->GetInputDim().at(0);
817   // target_ == CPU, no need to replace graph
818   if (target_ == CPU) {
819     return nullptr;
820   }
821   if (param_strategy.at(LongToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
822     MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
823   }
824   return replace_graph_;
825 }
826 
ComputeReplaceOp()827 Status GatherPInfo::ComputeReplaceOp() {
828   int64_t bias = 0;
829   if (manual_split_) {
830     if (InferOffset() != SUCCESS) {
831       MS_LOG(ERROR) << name_ << ": Infer offset failed.";
832       return FAILED;
833     }
834     bias = index_offset_;
835   } else {
836     if (InferBias() != SUCCESS) {
837       MS_LOG(ERROR) << name_ << ": Infer offset failed.";
838       return FAILED;
839     }
840     bias = bias_;
841   }
842 
843   OperatorName op_name = EMBEDDING_LOOKUP;
844   OperatorAttrs attrs;
845   Attr param_offset = std::make_pair("offset", MakeValue(bias));
846   OperatorParams params = {std::make_pair(param_offset, 3)};
847   OperatorArgs args = std::make_pair(attrs, params);
848   Operator op = std::make_pair(op_name, args);
849   replace_op_.push_back(op);
850 
851   return SUCCESS;
852 }
853 
Init(const StrategyPtr & strategy)854 Status GatherPInfo::Init(const StrategyPtr &strategy) {
855   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
856     MS_LOG(ERROR) << name_ << ": Init failed.";
857     return FAILED;
858   }
859   // only target_ == CPU, we need to replace op
860   if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
861     MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
862   }
863   MS_LOG(INFO) << name_ << ": Init success.";
864   return SUCCESS;
865 }
866 
InitForCostModel(const StrategyPtr & strategy)867 Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
868   if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
869     if (is_auto_parallel_) {
870       MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
871     } else {
872       MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
873     }
874     return FAILED;
875   }
876   auto param_strategy = strategy_->GetInputDim().at(0);
877   // cost model set axis and strategy
878   auto gatherv2_2cost = std::dynamic_pointer_cast<GatherV2PCost>(operator_cost());
879   gatherv2_2cost->set_axis(axis_);
880   gatherv2_2cost->set_strategy(param_strategy);
881   MS_LOG(INFO) << name_ << ": Init for cost model success.";
882   return SUCCESS;
883 }
884 
SetCostUnderStrategy(const StrategyPtr & strategy)885 Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
886 
GenerateOpStrategies(int64_t stage_id)887 std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) {
888   if (manual_split_) {
889     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy";
890   }
891   is_auto_parallel_ = true;
892   Shape input0_split(inputs_shape_[0].size(), 1);
893   Shape input1_split(inputs_shape_[1].size(), 1);
894   Shapes splittable_inputs = {input0_split, input1_split};
895 
896   std::vector<StrategyPtr> sp_vector;
897   if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
898     MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
899   }
900   return sp_vector;
901 }
902 
GenerateBatchStrategies()903 std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
904   if (GetAttrs() != SUCCESS) {
905     MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
906   }
907   if (manual_split_) {
908     MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
909   }
910 
911   Dimensions param_strategy(inputs_shape_[0].size(), 1);
912   Dimensions index_strategy;
913   index_strategy.push_back(stage_device_size_);
914   for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
915     index_strategy.push_back(1);
916   }
917   Strategys strategy_v = {param_strategy, index_strategy};
918   return std::make_shared<Strategys>(strategy_v);
919 }
920 }  // namespace parallel
921 }  // namespace mindspore
922