/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/gather_info.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "include/common/utils/parallel_context.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

namespace mindspore {
namespace parallel {
Status GatherInfo::GetManualSplitWithoutOffsetAttr() {
  auto manual_split_without_offset_iter = attrs_.find("manual_split");
  if (manual_split_without_offset_iter != attrs_.end()) {
    manual_split_ = true;
    MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
    if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
      MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequence";
      return FAILED;
    }
    std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
    MS_LOG(INFO) << name_ << ": manual split without offset is " << manual_split_without_offset_iter->second->ToString();

    int64_t offset = 0;
    for (auto &ele : value_vector) {
      index_offsets_.push_back(offset);
      if (!ele->isa<Int64Imm>()) {
        MS_LOG(ERROR) << name_ << ": The element of manual split must be int64_t";
        return FAILED;
      }
      auto param_split_shape = GetValue<int64_t>(ele);
      if (param_split_shape <= 0) {
        MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
        return FAILED;
      }
      param_split_shapes_.push_back(param_split_shape);
      offset += param_split_shape;
    }
    if (param_split_shapes_.empty()) {
      MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
      return FAILED;
    }
  }

  return SUCCESS;
}

Status GatherInfo::GetManualSplitAttr() {
  auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
  if (manual_split_with_offset_iter != attrs_.end()) {
    manual_split_ = true;
    auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
    if (var == nullptr) {
      MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequence";
      return FAILED;
    }

    MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
    for (auto &ele : var->value()) {
      if (!ele->isa<ValueSequence>()) {
        MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequence";
        return FAILED;
      }
      std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
      if (value_vector.size() != 2) {
        MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
        return FAILED;
      }
      int64_t param_split_row = GetValue<int64_t>(value_vector[0]);
      int64_t offset = GetValue<int64_t>(value_vector[1]);
      if ((param_split_row <= 0) || (offset < 0)) {
        MS_LOG(ERROR) << name_ << ": The value of param split shape must be positive, "
                      << "and the offset must be greater than or equal to 0";
        return FAILED;
      }
      param_split_shapes_.push_back(param_split_row);
      index_offsets_.push_back(offset);
    }

    if (param_split_shapes_.empty()) {
      MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
      return FAILED;
    }
    if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
      MS_LOG(ERROR) << name_ << ": Index offset must not be less than 0";
      return FAILED;
    }
    return SUCCESS;
  }

  if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
    return FAILED;
  }

  return SUCCESS;
}

void GatherInfo::GetBatchDims() noexcept {
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    batch_dims_ = 0;
    return;
  }
  auto batch_dims_opt = GetScalarValueFromInputs<int64_t>(input_value_, name_, BATCH_DIMS);
  if (batch_dims_opt.has_value()) {
    batch_dims_ = batch_dims_opt.value();
  } else {
    MS_LOG(EXCEPTION) << name_ << ": Failed to fetch the value of batch dims.";
  }
}

GatherUtilPtr GatherInfo::MakeManualUtil() {
  return std::make_shared<GatherManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

GatherUtilPtr SparseGatherV2Info::MakeManualUtil() {
  return std::make_shared<ManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

GatherUtilPtr EmbeddingLookupInfo::MakeManualUtil() {
  return std::make_shared<ManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

Status GatherInfo::GetAttrs() {
  if (attrs_.find(TARGET) != attrs_.end()) {
    target_ = GetStringAttr(TARGET);
  }

  if (name_.find(EMBEDDING_LOOKUP) != std::string::npos && target_ != CPU) {
    MS_LOG(ERROR) << name_ << ": the target must be set to CPU";
    return FAILED;
  }

  size_t axis_index = 2;
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    axis_index = 1;
  }

  MS_EXCEPTION_IF_NULL(input_value_[axis_index]);
  auto value = GetValue<int64_t>(input_value_[axis_index]);

  // Get the axis: the third input is the axis, and it is a ValueNode. EmbeddingLookup doesn't have an axis;
  // its third input is the offset.
  if (target_ != CPU) {
    auto params_shape = inputs_shape_.at(0);
    if (params_shape.empty()) {
      MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
      return FAILED;
    }
    if (value < 0) {  // if axis is negative then convert it to positive
      value += SizeToLong(params_shape.size());
    }
    axis_ = value;
  } else {
    if (value != 0) {
      if (name_.find(EMBEDDING_LOOKUP) != std::string::npos) {
        MS_LOG(ERROR) << name_ << ": the target is cpu, and the offset must be 0, but got " << value;
      } else {
        MS_LOG(ERROR) << name_ << ": the target is cpu, and the axis must be 0, but got " << value;
      }
      return FAILED;
    }
  }

  if (GetManualSplitAttr() != SUCCESS) {
    return FAILED;
  }

  GetBatchDims();

  if (manual_split_ && (axis_ != 0)) {
    MS_LOG(ERROR) << name_ << ": The axis must be 0 if manual split, but got " << axis_;
    return FAILED;
  }

  if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
    dynamic_shape_indices_ = true;
  }
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    dynamic_shape_indices_ = true;
  }
#endif
  return SUCCESS;
}

// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
// otherwise return false
bool GatherInfo::ShardBatchAndAxis(const Shape &param_strategy, const Shape &indices_strategy) const {
  if (axis_ != 0) {
    return false;
  }

  if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
    return false;
  }

  if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
    return false;
  }

  if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
    return false;
  }

  if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
    return false;
  }

  return true;
}
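
// An illustrative example (hypothetical values, not from the original source): with 8 devices in this stage,
// param_strategy = [2, 1] and indices_strategy = [4, 1] pass all of the checks above: the axis is 0, both
// strategies are 2-dim, the second dims are 1, 2 * 4 == 8, and neither first dim alone covers all 8 devices.
// So the batch dimension of indices and the axis dimension of param are split at the same time.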

GatherMode GatherInfo::GetGatherMode(const Shape &param_strategy, const Shape &indices_strategy) const {
  if (batch_dims_ > 0) {
    return BATCH;
  }

  if (param_strategy[LongToSize(axis_)] == NO_SPLIT_STRATEGY) {
    return NORMAL;
  }

  if (manual_split_) {
    return MANUAL;
  }

  if (ShardBatchAndAxis(param_strategy, indices_strategy)) {
    return SHARD_BATCH_AND_AXIS;
  }

  if (axis_ == 0 && param_strategy[0] != NO_SPLIT_STRATEGY) {
    if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
      return SHARD_AXIS_0_DYNAMIC;
    } else {
      return SHARD_AXIS_0_STATIC;
    }
  }

  if (axis_ == 1 && param_strategy[1] != NO_SPLIT_STRATEGY) {
    return SHARD_AXIS_1;
  }

  return INVALID;
}
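
// An illustrative example (hypothetical values, not from the original source): with 8 devices, axis_ = 0,
// batch_dims_ = 0, param_strategy = [8, 1] and indices_strategy = [1, 1], the checks above fall through to the
// axis-0 branch: ShardBatchAndAxis() returns false because param_strategy[0] alone covers all 8 devices, so the
// mode is SHARD_AXIS_0_STATIC if the indices shape has no -1 dimension, and SHARD_AXIS_0_DYNAMIC otherwise.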

// axis can not be split, and the strategies of batch dims must be equal
// support repeat calculation
Status BatchImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (param_strategy[LongToSize(axis_)] != NO_SPLIT_STRATEGY) {
    MS_LOG(ERROR) << name_ << ": batch mode, the axis can not be split, but the param strategy is " << param_strategy
                  << ", and the axis is " << axis_;
    return FAILED;
  }

  for (size_t i = 0; i < LongToSize(batch_dims_); ++i) {
    if (param_strategy[i] != indices_strategy[i]) {
      MS_LOG(ERROR)
        << name_
        << ": batch mode, the strategy of the batch dims of param and indices must be equal, but the param strategy is "
        << param_strategy << ", and the indices strategy is " << indices_strategy << ", batch dims is " << batch_dims_;
      return FAILED;
    }
  }
  return SUCCESS;
}

// batch mode: axis can not be split
// param  shape: [A, B, C, D, E]
// indices shape: [A, B, F, G]
// batch_dims = 2
// axis = 3
// out = gather(param,  indices,  axis)
// out shape: [A, B, C, F, G, E]
// parameter's strategy: [a, b, c, 1, e], indices' strategy: [a, b, f, g]
// output's strategy: [a, b, c, f, g, e]
// dev_matrix: [a, b, f, g, c, 1, e]
Status BatchImpl::InferDevMatrixShape() {
  auto indices_tmp = indices_strategy_;  // [a, b, f, g]
  auto param_tmp = param_strategy_;      // [a, b, c, d, e]  = [a, b, c, 1, e]
  (void)param_tmp.erase(param_tmp.cbegin(), param_tmp.cbegin() + LongToSize(batch_dims_));  // [c, 1, e]

  Shape tmp = indices_tmp;
  (void)tmp.insert(tmp.cend(), param_tmp.cbegin(), param_tmp.cend());  // [a, b, f, g, c, 1, e]

  dev_matrix_shape_ = tmp;
  MS_LOG(INFO) << name_ << ": batch mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

Status BatchImpl::InferTensorMap() {
  TensorMap tmp_map;
  int64_t size = SizeToInt(outputs_shape_[0].size()) + 1;
  for (int i = 0; i < size; ++i) {
    tmp_map.push_back(size - i - 1);  // tmp_map: [a, b, f, g, c, 1, e]
  }

  TensorMap param_map = tmp_map;  // [a, b, f, g, c, 1, e]
  (void)param_map.erase(param_map.cbegin() + LongToSize(batch_dims_),
                        param_map.cbegin() + inputs_shape_[1].size());  // [a, b, c, 1, e]

  TensorMap indices_map = tmp_map;                                                              // [a, b, f, g, c, 1, e]
  (void)indices_map.erase(indices_map.cbegin() + inputs_shape_[1].size(), indices_map.cend());  // [a, b, f, g]

  TensorMap out_map = param_map;                              // [a, b, c, 1, e]
  (void)out_map.erase(out_map.cbegin() + LongToSize(axis_));  // [a, b, c, e]

  TensorMap indices_rm_batch = indices_map;  // [a, b, f, g]
  (void)indices_rm_batch.erase(indices_rm_batch.cbegin(),
                               indices_rm_batch.cbegin() + LongToSize(batch_dims_));  // [f, g]

  (void)out_map.insert(out_map.cbegin() + LongToSize(axis_), indices_rm_batch.cbegin(),
                       indices_rm_batch.cend());  // [a, b, c, f, g, e]

  inputs_tensor_map_.push_back(param_map);    // param
  inputs_tensor_map_.push_back(indices_map);  // indices
  outputs_tensor_map_.push_back(out_map);     // out
  return SUCCESS;
}

// axis can not be split
// support repeat calculation
Status NormalImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (param_strategy[LongToSize(axis_)] != NO_SPLIT_STRATEGY) {
    MS_LOG(ERROR) << name_ << ": normal mode, the axis can not be split, but the param strategy is " << param_strategy
                  << ", and the axis is " << axis_;
    return FAILED;
  }

  return SUCCESS;
}

// normal mode: axis can not be split
// param  shape: [C, D, E]
// indices shape: [F, G]
// axis = 1
// out = gather(param,  indices,  axis)
// out shape: [C, F, G, E]
// parameter's strategy: [c, 1, e], indices' strategy: [f, g]
// output's strategy: [c, f, g, e]
// dev_matrix: [f, g, c, 1, e]
Status NormalImpl::InferDevMatrixShape() {
  auto indices_tmp = indices_strategy_;  // [f, g]
  auto param_tmp = param_strategy_;      // [c, d, e]  = [c, 1, e]

  Shape tmp = indices_tmp;
  (void)tmp.insert(tmp.cend(), param_tmp.cbegin(), param_tmp.cend());  // [f, g, c, 1, e]

  dev_matrix_shape_ = tmp;
  MS_LOG(INFO) << name_ << ": normal mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

Status NormalImpl::InferTensorMap() {
  TensorMap tmp_map;
  int64_t size = SizeToInt(outputs_shape_[0].size()) + 1;
  for (int i = 0; i < size; ++i) {
    tmp_map.push_back(size - i - 1);  // tmp_map: [f, g, c, 1, e]
  }

  TensorMap param_map = tmp_map;                                                            // [f, g, c, 1, e]
  (void)param_map.erase(param_map.cbegin(), param_map.cbegin() + inputs_shape_[1].size());  // [c, 1, e]

  TensorMap indices_map = tmp_map;                                                              // [f, g, c, 1, e]
  (void)indices_map.erase(indices_map.cbegin() + inputs_shape_[1].size(), indices_map.cend());  // [f, g]

  TensorMap out_map = param_map;                                                                         // [c, 1, e]
  (void)out_map.erase(out_map.cbegin() + LongToSize(axis_));                                             // [c, e]
  (void)out_map.insert(out_map.cbegin() + LongToSize(axis_), indices_map.cbegin(), indices_map.cend());  // [c, f, g, e]

  inputs_tensor_map_.push_back(param_map);    // param
  inputs_tensor_map_.push_back(indices_map);  // indices
  outputs_tensor_map_.push_back(out_map);     // out
  return SUCCESS;
}

// constraint: the field dimension of indices is the last dimension
// parameter's dim >= 1, indices' dim >= 1, axis == 0
// parameter's strategy: [a, b, ..., c], indices' strategy: [1, ..., 1, a]
// output's strategy: [1, ..., 1, a, b, ..., c]
// dev_matrix: [a, b, ..., c]
// can not support repeated calculation
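// An illustrative example (hypothetical values, not from the original source): with 4 devices, a param shape
// of [20, 64] split as param_split_shapes_ = [5, 5, 5, 5] (index_offsets_ = [0, 5, 10, 15]), an indices shape
// of [1, 20], param_strategy = [4, 1] and indices_strategy = [1, 4] satisfy the checks below: only the last
// indices dim is split, param_strategy[0] equals indices_strategy[-1], the number of manual segments is 4,
// each segment (5) is at least the indices field slice size (20 / 4 = 5), and the segments sum to param_shape[0].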
Status ManualImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (indices_strategy.size() < 1) {
    MS_LOG(ERROR) << name_ << ": The size of indices strategy must be positive, but got " << indices_strategy.size();
    return FAILED;
  }

  auto product_i = std::accumulate(indices_strategy.begin(), indices_strategy.end(), 1, std::multiplies<int64_t>());
  size_t indices_split_dim = indices_strategy.size() - 1;  // only the last dim of indices can be split
  if (product_i != indices_strategy[indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": Only the last dim of indices can be split, but got " << indices_strategy;
    return FAILED;
  }

  if (param_strategy[0] != indices_strategy[indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": The param_strategy[0] " << param_strategy[0]
                  << " must be equal to indices_strategy[-1] " << indices_strategy[indices_split_dim];
    return FAILED;
  }

  if (indices_strategy[indices_split_dim] != SizeToLong(param_split_shapes_.size())) {
    MS_LOG(ERROR) << name_ << ": The indices_strategy[-1] " << indices_strategy[indices_split_dim]
                  << " must be equal to manual split size " << param_split_shapes_.size();
    return FAILED;
  }
  MS_EXCEPTION_IF_ZERO("indices_strategy[indices_split_dim]", indices_strategy[indices_split_dim]);
  int64_t min_param_slice_row = inputs_shape_[1][indices_split_dim] / indices_strategy[indices_split_dim];
  bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
                             [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
  if (invalid) {
    MS_LOG(ERROR) << name_ << ": The split value " << param_split_shapes_
                  << " must be larger than or equal to indices field slice size " << min_param_slice_row;
    return FAILED;
  }

  if (inputs_shape_[0][0] < inputs_shape_[1][indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": The param's row size " << inputs_shape_[0][0]
                  << " is smaller than indices' field size " << inputs_shape_[1][indices_split_dim];
    return FAILED;
  }

  // Don't support repeated calc
  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  MS_EXCEPTION_IF_NULL(g_device_manager);
  if (product_p < SizeToLong(g_device_manager->GetDeviceListInThisStage().size())) {
    MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
    return FAILED;
  }

  int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(),
                                            static_cast<int64_t>(0),
                                            [](int64_t s, int64_t shape) { return s + shape; });
  if (split_shape_sum != inputs_shape_[0][0]) {
    MS_LOG(ERROR) << name_ << ": Sum of split shapes " << split_shape_sum << " must be equal to param_shape[0] "
                  << inputs_shape_[0][0];
    return FAILED;
  }
  return SUCCESS;
}

Status ManualImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = param_strategy_;
  MS_LOG(INFO) << name_ << ": manual mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

Status ManualImpl::InferTensorMap() {
  Shape param_map;
  size_t size = inputs_shape_[0].size();
  for (size_t i = 0; i < size; ++i) {
    param_map.push_back(static_cast<int64_t>(size - i - 1));
  }

  size_t indices_size = inputs_shape_[1].size();
  Shape indices_map(indices_size, MAP_NONE);
  indices_map[indices_size - 1] = param_map[0];

  Shape out_map = param_map;
  (void)out_map.insert(out_map.begin(), indices_size - 1, MAP_NONE);

  (void)inputs_tensor_map_.emplace_back(std::move(param_map));
  (void)inputs_tensor_map_.emplace_back(std::move(indices_map));
  (void)outputs_tensor_map_.emplace_back(std::move(out_map));
  return SUCCESS;
}

Status ManualImpl::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape input_index_shape = inputs_shape_.at(1);
  Shape output_shape = outputs_shape_.at(0);
  int64_t rank = g_device_manager->rank_index_in_stage();
  // infer tensor layout
  TensorLayout input_tensor_layout;
  TensorLayout input_index_layout;
  TensorLayout output_tensor_layout;

  int64_t bias_size = 1;
  if (dev_matrix_shape_.size() > 1) {
    bias_size = std::accumulate(dev_matrix_shape_.begin() + 1, dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
  }
  if (bias_size == 0) {
    MS_LOG(ERROR) << name_ << ": Invalid device matrix " << dev_matrix_shape_;
    return FAILED;
  }
  input_shape[0] = param_split_shapes_[LongToSize(rank / bias_size)];
  input_shape[0] = input_shape[0] * dev_matrix_shape_[0];

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
    return FAILED;
  }

  input_tensor_layout.set_uniform_split(false);

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status ManualImpl::InferOffset() {
  CheckGlobalDeviceManager();
  size_t rank = LongToSize(g_device_manager->rank_index_in_stage());

  int64_t bias_size = 1;
  if (param_strategy_.size() > 1) {
    bias_size = std::accumulate(param_strategy_.begin() + 1, param_strategy_.end(), 1, std::multiplies<int64_t>());
  }
  MS_EXCEPTION_IF_ZERO("bias_size", LongToSize(bias_size));
  size_t index = rank / LongToSize(bias_size);
  if (index < index_offsets_.size()) {
    index_offset_ = index_offsets_[index];
    MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is " << index_offsets_.size();
  return FAILED;
}
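
// An illustrative example (hypothetical values, not from the original source): with param_strategy_ = [4, 1],
// bias_size = 1, so the device with rank 2 in this stage computes index = 2 / 1 = 2 and takes index_offsets_[2]
// as its index_offset_.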

Status ManualImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {  // if target is CPU, no need to replace graph
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }

  auto sub_node = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
  AnfNodePtr gather_v2_node = nullptr;
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), sub_node});
    input_nodes = {std::make_pair(sub_node, 3), std::make_pair(gather_v2_node, 1)};
  } else {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node, CreatInt64Imm(axis_)});
    input_nodes = {std::make_pair(sub_node, 2), std::make_pair(gather_v2_node, 1)};
  }

  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, gather_v2_node));
  return SUCCESS;
}

Status GatherManualImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {  // if target is CPU, no need to replace graph
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }

  auto sub_node = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
  AnfNodePtr gather_v2_node = nullptr;
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
  // Gather processing.
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), sub_node});
    input_nodes = {std::make_pair(sub_node, 3), std::make_pair(gather_v2_node, 1)};
  } else {
    gather_v2_node = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node,
                                     CreatInt64Imm(axis_), CreatInt64Imm(0)});
    input_nodes = {std::make_pair(sub_node, 2), std::make_pair(gather_v2_node, 1)};
  }

  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, gather_v2_node));
  return SUCCESS;
}

Status ManualImpl::InferReplaceOps() {
  if (target_ != CPU) {  // if target is not CPU, no need to replace ops
    return SUCCESS;
  }

  int64_t bias = 0;

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }

  bias = index_offset_;

  OperatorName op_name = EMBEDDING_LOOKUP;
  OperatorAttrs attrs;
  Attr param_offset = std::make_pair("offset", MakeValue(bias));
  OperatorParams params = {std::make_pair(param_offset, 3)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(op_name, args);
  replace_op_.push_back(op);

  return SUCCESS;
}

Status ShardBatchAndAxisImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = {indices_strategy_[0], param_strategy_[0]};
  MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_;
  // if the forward uses reducescatter, the output's dev matrix is {index_strategy[0] * param_strategy[0]}
  if (axis_split_forward_allreduce_) {
    out_dev_matrix_shape_ = dev_matrix_shape_;
  } else {
    out_dev_matrix_shape_ = {indices_strategy_[0] * param_strategy_[0]};
  }
  auto shard_product =
    std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
  auto stage_device_size = SizeToLong(g_device_manager->GetDeviceListInThisStage().size());
  if (shard_product < stage_device_size) {
    MS_EXCEPTION_IF_ZERO("shard_product", shard_product);
    repeated_calculation_num_ = stage_device_size / shard_product;  // set repeated calculation num
  }
  return SUCCESS;
}

Status ShardBatchAndAxisImpl::InferTensorMap() {
  Shape param_tensor_map = {0, MAP_NONE};
  Shape indices_tensor_map = {1, MAP_NONE};
  Shape out_tensor_map;
  if (axis_split_forward_allreduce_) {
    out_tensor_map = {1, MAP_NONE, MAP_NONE};  // the dev matrix is (index_strategy[0], param_strategy[0])
  } else {
    out_tensor_map = {0, MAP_NONE, MAP_NONE};  // the dev matrix is (index_strategy[0] * param_strategy[0])
  }

  (void)inputs_tensor_map_.emplace_back(std::move(param_tensor_map));    // param
  (void)inputs_tensor_map_.emplace_back(std::move(indices_tensor_map));  // indices
  (void)outputs_tensor_map_.emplace_back(std::move(out_tensor_map));     // output
  return SUCCESS;
}

Status ShardBatchAndAxisImpl::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  MS_EXCEPTION_IF_ZERO("param_strategy_[0]", param_strategy_[0]);
  slice_size_ = input_shape[0] / param_strategy_[0];
  bias_ = rank % param_strategy_[0] * slice_size_;
  MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
               << ", bias is " << bias_;
  return SUCCESS;
}
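
// An illustrative example (hypothetical values, not from the original source): with 8 devices,
// param_strategy_ = [2, 1], indices_strategy_ = [4, 1] and a param shape of [16, 32], slice_size_ = 16 / 2 = 8,
// and the device with rank 5 gets bias_ = 5 % 2 * 8 = 8, i.e. it owns param rows [8, 16).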

void ShardAxisImpl::SetAttribute(const Shape &param_strategy) {
  // axis=0, index_shape(0)%param_strategy(0) must be 0
  Shape index_shape = inputs_shape_.at(1);
  MS_EXCEPTION_IF_ZERO("param_strategy.at(0)", param_strategy.at(0));
  if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
    axis_split_forward_allreduce_ = true;
  }

  auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  // Case 1: If repeated calculation, need to set the repeated num to the left of the dev-matrix. For example,
  // the parameter strategy is [8, 1], the indices strategy is [1, 1], the dev num is 16,
  // and the dev_matrix is [2, 1, 8, 1, 1]; the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
  // can communicate normally, and dev0 to dev7 have all the parameters.
  // Case 2: If not repeated calculation (such as data parallel), need to set the repeated num to the right,
  // as it's easy to introduce a redistribution after or before the gather operation, influencing the performance.
  auto stage_device_size = g_device_manager->GetDeviceListInThisStage().size();
  if (product_param == SizeToLong(stage_device_size) || product_param == 1) {
    repeated_num_in_dev_matrix_right_ = true;
  } else {
    repeated_num_in_dev_matrix_right_ = false;
  }
  MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
}

Status ShardAxisImpl::CheckSplitAxisStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  // param_strategy(axis) != 1, index can't be split
  auto stage_device_size = g_device_manager->GetDeviceListInThisStage().size();
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  bool is_auto_parallel = (parallel_mode == kAutoParallel);

  auto product_i = std::accumulate(indices_strategy.begin(), indices_strategy.end(), 1, std::multiplies<int64_t>());
  if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
    FILTER_LOG(is_auto_parallel) << name_ << ": param is split at dim (axis) " << axis_ << ", index can't be split.";
    return FAILED;
  }

  // param_strategy(axis) != 1, and axis != 0, don't support repeated calc
  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  if ((product_p != SizeToLong(stage_device_size)) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ != 0)) {
    FILTER_LOG(is_auto_parallel) << name_ << ": Invalid strategy. Don't support repeated calc.";
    return FAILED;
  }

  if ((product_p != SizeToLong(stage_device_size)) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ == 0)) {
    if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
      FILTER_LOG(is_auto_parallel) << name_
                                   << ": axis(0) is split, and param_strategy[1] != 1, don't support"
                                      " repeated calc.";
      return FAILED;
    }
    MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
  }
  return SUCCESS;
}

Status ShardAxisImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  // only support 1-dim and 2-dim param
  if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
    MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
    return FAILED;
  }

  // don't support scalar index
  if (inputs_shape_[1].empty()) {
    MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
    return FAILED;
  }

  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
  MS_EXCEPTION_IF_ZERO("param_strategy", param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
  if (axis_ != 0 && inputs_shape_[0][0] % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(ERROR) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
    return FAILED;
  }

  if (CheckSplitAxisStrategy(param_strategy, indices_strategy) != SUCCESS) {
    return FAILED;
  }

  // According to the strategy, set the private members.
  SetAttribute(param_strategy);

  return SUCCESS;
}
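
// An illustrative example (hypothetical values, not from the original source): with 8 devices and axis_ = 1, a
// param shape of [16, 64] with param_strategy = [2, 4] and an all-ones indices strategy passes the checks above:
// 16 % (2 * 4) == 0, and since param_strategy(axis) != 1 the indices must not be split.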

Status ShardAxisImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = param_strategy_;

  // infer out dev_matrix_shape
  // axis is not 0, split axis
  if (axis_ != 0 && param_strategy_.at(LongToSize(axis_)) != 1) {
    for (size_t i = 1; i < param_strategy_.size(); ++i) {
      if (i == LongToSize(axis_)) {
        out_dev_matrix_shape_.push_back(1);
      } else {
        out_dev_matrix_shape_.push_back(param_strategy_.at(i));
      }
    }
    out_dev_matrix_shape_.push_back(param_strategy_.at(0) * param_strategy_.at(LongToSize(axis_)));
  } else {
    out_dev_matrix_shape_ = dev_matrix_shape_;
  }
  auto param_product = std::accumulate(param_strategy_.begin(), param_strategy_.end(), 1, std::multiplies<int64_t>());
  auto index_product =
    std::accumulate(indices_strategy_.begin(), indices_strategy_.end(), 1, std::multiplies<int64_t>());
  auto stage_device_size = SizeToLong(g_device_manager->GetDeviceListInThisStage().size());
  if (param_product * index_product < stage_device_size) {
    MS_EXCEPTION_IF_ZERO("param_product * index_product", param_product * index_product);
    repeated_calculation_num_ = stage_device_size / (param_product * index_product);  // set the repeat calc num
    if (repeated_num_in_dev_matrix_right_) {
      out_dev_matrix_shape_.push_back(repeated_calculation_num_);
    } else {
      (void)out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), repeated_calculation_num_);
    }
  }

  return SUCCESS;
}
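
// An illustrative example (hypothetical values, not from the original source): with 8 devices, axis_ = 1 and
// param_strategy_ = [2, 4], the dev matrix is [2, 4] and the output dev matrix becomes [1, 2 * 4] = [1, 8]: the
// combined factor param_strategy_[0] * param_strategy_[axis] is appended as the last dimension, since the
// forward communication merges the axis split into the output's first dimension.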

Status ShardAxisImpl::InferTensorMap() {
  // param_strategy(axis) is not 1
  // infer input tensor map
  size_t param_size = inputs_shape_.at(0).size();
  size_t index_size = inputs_shape_.at(1).size();
  Shape tensor_map_index;
  Shape tensor_map_params;

  (void)tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
  for (size_t i = 0; i < param_size; ++i) {
    tensor_map_params.push_back(SizeToLong(param_size - i - 1));
  }

  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_index));

  // infer output tensor map
  Shape tensor_map_out;
  if (axis_ == 0) {
    if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
      // the output is repeat calculation
      (void)tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
    } else {
      (void)tensor_map_out.insert(tensor_map_out.end(), SizeToLong(param_size) - 1);
    }
    (void)tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
    for (size_t i = 1; i < param_size; ++i) {
      tensor_map_out.push_back(param_size - 1 - i);
    }
  } else {
    for (size_t i = 0; i < param_size; ++i) {
      if (i == LongToSize(axis_)) {
        (void)tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
      } else {
        if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
          tensor_map_out.push_back(MAP_NONE);
        }
        tensor_map_out.push_back(SizeToLong(i));
      }
    }
  }
  (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));

  return SUCCESS;
}

Status ShardAxisImpl::InferTensorInfo() {
  // infer tensor layout
  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), inputs_shape_[0]) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), inputs_shape_[1]) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), outputs_shape_[0]) !=
       SUCCESS)) {
    return FAILED;
  }

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status ShardAxisImpl::InferGroup() {
  size_t dim = LongToSize(axis_);

  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), dev_matrix_shape_);
  RankList group_devices;

  // the dev_matrix[0] is repeated_calc_num, so the dim needs to add 1
  if (repeated_calculation_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    dim = dim + 1;
  }

  if (gather_mode_ == SHARD_BATCH_AND_AXIS) {
    dim = 1;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
  }

  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
  }
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << name_ << ": The group is empty";
    return SUCCESS;
  }

  MS_LOG(INFO) << name_ << ": The group ranks are " << group_devices;
  if (g_device_manager->CreateGroup(group_devices, &group_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create reduce group failed in table row split.";
    return FAILED;
  }
  return SUCCESS;
}

Status ShardAxisImpl::InferForwardCommunication() {
  forward_op_.clear();
  // if the axis is not split or the target is not CPU, no forward communication is needed
  if (target_ != CPU) {
    return SUCCESS;
  }
  // split axis
  Attr attr_group;
  OperatorName operator_name;
  if (axis_split_forward_allreduce_) {
    operator_name = ALL_REDUCE;
  } else {
    operator_name = REDUCE_SCATTER;
  }

  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  if (group_.name().empty()) {
    return SUCCESS;
  }
  attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  OperatorAttrs attrs = {attr_op, attr_group};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(operator_name, args);

  forward_op_.push_back(op);
  return SUCCESS;
}

Status ShardAxisImpl::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  // params_size=1, axis=0
  if ((input_shape.size() == 1) && (axis_ == 0)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(0)", param_strategy_.at(0));
    slice_size_ = input_shape.at(0) / param_strategy_.at(0);
    // if repeated calculation and the repeated num is on the right of the dev-matrix, the rank needs to be
    // divided by the repeated num
    if (repeated_calculation_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calculation_num_;
      } else {
        rank = rank % param_strategy_[0];
      }
    }
    bias_ = rank * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=0
  if ((input_shape.size() == 2) && (axis_ == 0)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(0)", param_strategy_.at(0));
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(1)", param_strategy_.at(1));
    slice_size_ = input_shape.at(0) / param_strategy_.at(0);
    // if repeated calculation and the repeated num is on the right of the dev-matrix, the rank needs to be
    // divided by the repeated num
    if (repeated_calculation_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calculation_num_;
      } else {
        rank = rank % (param_strategy_[0] * param_strategy_[1]);
      }
    }
#if defined(__linux__) && defined(WITH_BACKEND)
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      if (ps::PSContext::instance()->enable_distributed_mindrt()) {
        bias_ = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
      }
      return SUCCESS;
    }
#endif
    bias_ = rank / param_strategy_.at(1) * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=1
  if ((input_shape.size() == 2) && (axis_ == 1)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(1)", param_strategy_.at(1));
    slice_size_ = input_shape.at(1) / param_strategy_.at(1);
    bias_ = rank % param_strategy_.at(1) * slice_size_;
    return SUCCESS;
  }
  MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_;
  return FAILED;
}
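
// An illustrative example (hypothetical values, not from the original source): with 4 devices, axis_ = 0, a
// param shape of [8, 8] and param_strategy_ = [2, 2], slice_size_ = 8 / 2 = 4 and the device with rank 3 gets
// bias_ = 3 / 2 * 4 = 4, i.e. ranks {0, 1} own rows [0, 4) and ranks {2, 3} own rows [4, 8).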

Status ShardAxisImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage() << ", the bias is " << bias_;
  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});

  AnfNodePtr gather_v2{nullptr};
  auto replace_op_name = GetPrimNameFromInfoName(replace_op_name_);
  if (replace_op_name == INDEX_SELECT) {
    gather_v2 =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), minimum});
  } else if (replace_op_name == GATHERV2) {
    gather_v2 = gen_g.PushBack(
      {gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_), CreatInt64Imm(0)});
  } else {
    gather_v2 =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_)});
  }

  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
  auto dtype_id =
    gen_g.PushBack({gen_g.NewOpInst(DTYPETOENUM), CreateStringImm("DtypeToEnum"), CreateStringImm("dtype"), dtype});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype_id});
  auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
  // don't need expand dim, if param_size = 1
  if (inputs_shape_.at(0).size() == 1) {
    mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
  }
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  OperatorAttrs attrs = {attr_op, attr_group};
  AnfNodePtr reduce_op;
  if (dynamic_shape_indices_ || axis_split_forward_allreduce_ || is_assigned_parallel_) {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
  } else {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
  }
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
  if (replace_op_name == INDEX_SELECT) {
    input_nodes = {std::make_pair(sub, 3), std::make_pair(gather_v2, 1)};
  }
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, reduce_op));

  return SUCCESS;
}

Status ShardAxisImpl::InferReplaceOps() {
  if (target_ != CPU) {  // if target is not CPU, no need to replace ops
    return SUCCESS;
  }

  int64_t bias = 0;

  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }
  bias = bias_;

  OperatorName op_name = EMBEDDING_LOOKUP;
  OperatorAttrs attrs;
  Attr param_offset = std::make_pair("offset", MakeValue(bias));
  OperatorParams params = {std::make_pair(param_offset, 3)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(op_name, args);
  replace_op_.push_back(op);

  return SUCCESS;
}

Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }
  gather_util_ = nullptr;
  gather_mode_ = INVALID;

  // param slice shape preferably 32Byte aligned
  auto param_shape = inputs_shape_[0];
  auto input_dim = strategy->GetInputDim();
  auto param_strategy = input_dim[0];
  auto indices_strategy = input_dim[1];
  MS_LOG(INFO) << name_ << ": the indices shape is " << inputs_shape_[1] << ", the strategy is " << input_dim[1];
  MS_EXCEPTION_IF_ZERO("param_strategy.at(param_strategy.size() - 1)", param_strategy.at(param_strategy.size() - 1));
  auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
  if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
    MS_LOG(WARNING) << "Gather: Last dim of param slice shape is not 32Byte aligned.";
  }

  // get the gather mode, and choose the corresponding implementation
  gather_mode_ = GetGatherMode(param_strategy, indices_strategy);
  switch (gather_mode_) {
    case BATCH: {
      gather_util_ = std::make_shared<BatchImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto batch_util = std::dynamic_pointer_cast<BatchImpl>(gather_util_);
      batch_util->set_batch_dims(batch_dims_);
      break;
    }
    case NORMAL:
      gather_util_ = std::make_shared<NormalImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      break;
    case MANUAL: {
      gather_util_ = MakeManualUtil();
      auto manual_util = std::dynamic_pointer_cast<ManualImpl>(gather_util_);
      manual_util->set_param_split_shapes(param_split_shapes_);
      manual_util->set_index_offsets(index_offsets_);
      manual_util->set_attrs(attrs_);
      manual_util->set_target(target_);
      manual_util->set_replace_op_name(replace_op_name_);
      break;
    }
    case SHARD_BATCH_AND_AXIS: {
      gather_util_ = std::make_shared<ShardBatchAndAxisImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto shard_batch_and_axis_util = std::dynamic_pointer_cast<ShardBatchAndAxisImpl>(gather_util_);
      shard_batch_and_axis_util->set_target(target_);
      shard_batch_and_axis_util->set_dynamic_shape_indices(dynamic_shape_indices_);
      shard_batch_and_axis_util->set_attrs(attrs_);
      shard_batch_and_axis_util->set_replace_op_name(replace_op_name_);
      shard_batch_and_axis_util->set_axis_split_forward_allreduce(
        true);  // Sharding batch and axis, and the forward uses allreduce
      break;
    }
    case SHARD_AXIS_0_DYNAMIC:
    case SHARD_AXIS_0_STATIC:
    case SHARD_AXIS_1: {
      gather_util_ = std::make_shared<ShardAxisImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto shard_axis_util = std::dynamic_pointer_cast<ShardAxisImpl>(gather_util_);
      shard_axis_util->set_target(target_);
      shard_axis_util->set_dynamic_shape_indices(dynamic_shape_indices_);
      shard_axis_util->set_attrs(attrs_);
      shard_axis_util->set_replace_op_name(replace_op_name_);
      shard_axis_util->set_assigned_parallel(is_assigned_parallel_);
      break;
    }
    default:
      MS_LOG(ERROR) << name_ << ": invalid gather mode: " << gather_mode_;
      return FAILED;
  }

  gather_util_->set_dynamic_shape_flag(dynamic_shape_flag_);
  gather_util_->set_inputs_divisor(inputs_divisor_);
  gather_util_->set_outputs_divisor(outputs_divisor_);

  gather_util_->DivisorsReplaceShapes();
  if (gather_util_->CheckStrategy(param_strategy, indices_strategy) != SUCCESS) {
    return FAILED;
  }
  gather_util_->ResumeShapes();

  gather_util_->set_param_strategy(param_strategy);
  gather_util_->set_indices_strategy(indices_strategy);
  gather_util_->set_gather_mode(gather_mode_);
  MS_LOG(INFO) << name_ << ": the gather mode is " << gather_util_->GatherModeToString();

  repeated_num_in_dev_matrix_right_ = gather_util_->repeated_num_in_dev_matrix_right();  // set the base class member

  return SUCCESS;
}

Status GatherInfo::CheckStrategyForDynamicShape(const StrategyPtr &strategy) {
  Strategies strategies = strategy->GetInputDim();
  auto param_strategy = strategies[0];
  if (param_strategy[axis_] != 1 && inputs_shape_[0][axis_] == -1) {
    MS_LOG(ERROR) << name_ << ": the axis dim of the first input can not be split if it's dynamic shape, the strategy is "
                  << ShapesToString(strategies) << ", the inputs' shape: " << ShapesToString(inputs_shape_)
                  << ", the axis " << axis_;
    return FAILED;
  }
  return SUCCESS;
}

Status GatherInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
  if (out_strategy == nullptr) {
    MS_LOG(INFO) << name_ << ": The output strategy is null";
    return SUCCESS;
  }

  if (CheckStrategyValue(out_strategy, outputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid output strategy";
    return FAILED;
  }

  if (axis_ != 0 && batch_dims_ != 0) {
    MS_LOG(ERROR) << name_ << ": Set output strategy only for axis = 0 and batch_dims = 0, but the axis is " << axis_
                  << ", the batch_dims is " << batch_dims_;
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(gather_util_);
  auto shard_axis_util = std::dynamic_pointer_cast<ShardAxisImpl>(gather_util_);

  auto in_stra = strategy_->GetInputDim();
  auto param_strategy = in_stra[0];
  auto index_strategy = in_stra[1];

  // only for axis == 0
  auto allreduce_strategy = index_strategy;
  (void)allreduce_strategy.insert(allreduce_strategy.end(), param_strategy.begin() + 1, param_strategy.end());
  auto reduce_scatter_strategy = allreduce_strategy;
  reduce_scatter_strategy[0] *= param_strategy[0];

  auto out_stra = out_strategy->GetInputDim()[0];
  if (out_stra == allreduce_strategy) {
    if (shard_axis_util != nullptr) {
      shard_axis_util->set_axis_split_forward_allreduce(true);
    }

    MS_LOG(INFO) << name_ << ": The output strategy is " << out_stra << ", forward use allreduce";
    return SUCCESS;
  } else if (out_stra == reduce_scatter_strategy) {
    if (gather_util_->gather_mode() != SHARD_AXIS_0_STATIC && gather_util_->gather_mode() != SHARD_BATCH_AND_AXIS) {
      MS_LOG(ERROR) << name_ << ": The output strategy " << out_stra << " for gather mode "
                    << gather_util_->GatherModeToString() << " is invalid, it must be " << allreduce_strategy;
      return FAILED;
    }

    if (shard_axis_util) {
      shard_axis_util->set_axis_split_forward_allreduce(false);
    }
    MS_LOG(INFO) << name_ << ": The output strategy is " << out_stra << ", forward use reduce scatter";
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": The output strategy " << out_stra << " is invalid, it must be " << allreduce_strategy
                << " or " << reduce_scatter_strategy;
  return FAILED;
}

void GatherInfo::DealWithBatchDimsMirrorOp() noexcept {
  OperatorVector op_for_batch_dims;
  mirror_ops_.push_back(op_for_batch_dims);
}

Status GatherInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_a_group;
  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
    ReportError(name_ + " : Create group for input a failed.");
    return FAILED;
  }

  OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
  if (input_a_group.empty()) {
    MS_LOG(INFO) << name_ << " : The mirror group is empty.";
    return SUCCESS;
  } else {
    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
    MS_LOG(INFO) << name_ << " : Created the mirror ops for input a successfully, group is " << input_a_group[0].name();
  }

  mirror_ops_.push_back(op_for_input_a);
  mirror_ops_.push_back(op_for_input_b);
  mirror_ops_.push_back(op_for_axis);
  DealWithBatchDimsMirrorOp();

  return SUCCESS;
}
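
// Editor's note: mirror_ops_ needs one OperatorVector per operator input. Only input a (the
// parameter) gets a real mirror operator above; the vectors for input b (the indices) and the
// axis stay empty, and DealWithBatchDimsMirrorOp appends one more empty vector for batch_dims,
// so the empty entries act as placeholders that keep mirror_ops_ aligned with the input order.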

Status GatherInfo::InferDevMatrixShape() {
  if (gather_util_->InferDevMatrixShape() != SUCCESS) {
    return FAILED;
  }
  dev_matrix_shape_ = gather_util_->dev_matrix_shape();
  out_dev_matrix_shape_ = gather_util_->out_dev_matrix_shape();  // set base class member
  return SUCCESS;
}

Status GatherInfo::InferTensorMap() {
  // the dev matrix shape may be changed if there is repeated calculation, so reset it for gather_util
  gather_util_->set_dev_matrix_shape(dev_matrix_shape_);

  if (gather_util_->InferTensorMap() != SUCCESS) {
    return FAILED;
  }

  inputs_tensor_map_ = gather_util_->inputs_tensor_map();
  outputs_tensor_map_ = gather_util_->outputs_tensor_map();
  return SUCCESS;
}

Status GatherUtil::InferTensorInfoNoSplitAxis() {
  TensorLayout input_tensor_layout;
  TensorLayout input_index_layout;
  TensorLayout output_tensor_layout;

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), inputs_shape_[0]) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), inputs_shape_[1]) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), outputs_shape_[0]) !=
       SUCCESS)) {
    return FAILED;
  }

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
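
// Illustrative example (editor's note, assuming MindSpore's usual tensor-map convention: a
// tensor-map value k refers to the k-th device-matrix dimension counted from the right, and -1
// means the dimension is not sharded): with dev_matrix_shape_ = {4, 2} and a param tensor map
// of {1, -1}, dim 0 of the param is split across the device-matrix dimension of size 4 and
// dim 1 is replicated on every device.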

void GatherUtil::DivisorsReplaceShapes() {
  if (!dynamic_shape_flag_) {
    return;
  }

  inputs_shape_ = inputs_divisor_;
  outputs_shape_ = outputs_divisor_;
}

void GatherUtil::ResumeShapes() {
  if (!dynamic_shape_flag_) {
    return;
  }

  inputs_shape_ = inputs_shape_clone_;
  outputs_shape_ = outputs_shape_clone_;
}
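
// Editor's note: DivisorsReplaceShapes/ResumeShapes form a swap-and-restore pair for dynamic
// shapes: the real (possibly unknown) shapes are temporarily replaced by their divisor shapes,
// and afterwards the clones saved in inputs_shape_clone_ / outputs_shape_clone_ are restored.
// Callers are expected to invoke the two in matching pairs.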

Status GatherInfo::InferTensorInfo() {
  // the tensor map of gather_util may be changed if there is repeated calculation, so reset it
  gather_util_->set_inputs_tensor_map(inputs_tensor_map_);
  gather_util_->set_outputs_tensor_map(outputs_tensor_map_);

  if (gather_util_->InferTensorInfo() != SUCCESS) {
    return FAILED;
  }

  inputs_tensor_info_ = gather_util_->inputs_tensor_info();
  outputs_tensor_info_ = gather_util_->outputs_tensor_info();
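  // Editor's note: IndexSelect presumably takes (input, axis, index), so a placeholder
  // TensorInfo is inserted at position 1 for the scalar axis input, keeping
  // inputs_tensor_info_ aligned with the operator's input order.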
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    TensorInfo axis_place_holder;
    if (inputs_tensor_info_.empty()) {
      MS_LOG(ERROR) << name_ << ": the tensor info of inputs is empty";
      return FAILED;
    }
    (void)inputs_tensor_info_.insert(inputs_tensor_info_.cbegin() + 1, axis_place_holder);
  }
  return SUCCESS;
}

Status GatherInfo::InferForwardCommunication() {
  if (gather_util_->InferForwardCommunication() != SUCCESS) {
    return FAILED;
  }
  forward_op_ = gather_util_->forward_op();
  return SUCCESS;
}

ReplaceGraphPtr GatherInfo::replace_graph(const CNodePtr &cnode) {
  // when target_ is CPU, there is no need to replace the graph
  if (target_ == CPU) {
    return nullptr;
  }
  if (gather_util_->InferReplaceGraph(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": infer replace graph failed.";
  }
  replace_graph_ = gather_util_->replace_graph();
  return replace_graph_;
}

Status GatherInfo::ComputeReplaceOp() {
  if (gather_util_->InferReplaceOps() != SUCCESS) {
    return FAILED;
  }
  replace_op_ = gather_util_->replace_op();
  return SUCCESS;
}

Status GatherInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                        const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                        const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  // only when target_ is CPU do we need to replace the op
  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status GatherInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
  if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    FILTER_LOG(is_auto_parallel_) << name_ << ": Init for cost model failed.";
    return FAILED;
  }
  auto param_strategy = strategy_->GetInputDim().at(0);
  // set the axis and strategy for the cost model
  auto gather_cost = std::dynamic_pointer_cast<GatherCost>(operator_cost());
  MS_EXCEPTION_IF_NULL(gather_cost);
  gather_cost->set_axis(axis_);
  gather_cost->set_strategy(param_strategy);
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
  if (manual_split_) {
    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support searching for strategies";
  }
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shape input1_split(inputs_shape_[1].size(), 1);
  Shapes splittable_inputs = {input0_split, input1_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs failed.";
  }
  return sp_vector;
}
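
// Editor's note: the all-ones splittable shapes above mark every dimension of both inputs as
// splittable, so the search enumerates candidate strategies over all feasible per-dimension
// splits of the two inputs for this stage's devices, treating the inputs as independent.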

std::shared_ptr<Strategies> GatherInfo::GenerateBatchStrategies() {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
  }
  if (manual_split_) {
    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support generating batch strategies";
  }

  Dimensions param_strategy(inputs_shape_[0].size(), 1);
  Dimensions index_strategy;
  index_strategy.push_back(stage_device_size_);
  for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
    index_strategy.push_back(1);
  }

  if (batch_dims_ > 0 && !param_strategy.empty()) {
    param_strategy[0] = stage_device_size_;
  }
  Strategies strategy_v = {param_strategy, index_strategy};
  return std::make_shared<Strategies>(strategy_v);
}
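
// Illustrative example (editor's note): with inputs_shape_ = {{64, 32}, {16, 8}},
// stage_device_size_ = 8 and batch_dims_ = 0, the generated batch strategy is
// {{1, 1}, {8, 1}}: the param is replicated and the index's batch dimension is split across
// all 8 devices. With batch_dims_ > 0, param_strategy[0] is also set to 8, giving
// {{8, 1}, {8, 1}}.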

REGISTER(GatherInfo);
REGISTER(IndexSelectInfo);
REGISTER(SparseGatherV2Info);
REGISTER(EmbeddingLookupInfo);
}  // namespace parallel
}  // namespace mindspore