/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/gather_info.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/dynamic_creator.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "include/common/utils/parallel_context.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/ps_context.h"
#endif

namespace mindspore {
namespace parallel {
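// Parse the "manual_split" attr. A worked example (illustrative values, not
// taken from any model): manual_split=(4, 4, 8) for a parameter with 16 rows
// yields param_split_shapes_ = [4, 4, 8] and index_offsets_ = [0, 4, 8]
// (the prefix sums of the split sizes).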
Status GatherInfo::GetManualSplitWithoutOffsetAttr() {
  auto manual_split_without_offset_iter = attrs_.find("manual_split");
  if (manual_split_without_offset_iter != attrs_.end()) {
    manual_split_ = true;
    MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
    if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
      MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequence";
      return FAILED;
    }
    std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
    MS_LOG(INFO) << name_ << ": manual split without offset is "
                 << manual_split_without_offset_iter->second->ToString();

    int64_t offset = 0;
    for (auto &ele : value_vector) {
      index_offsets_.push_back(offset);
      if (!ele->isa<Int64Imm>()) {
        MS_LOG(ERROR) << name_ << ": The element of manual split must be int64_t";
        return FAILED;
      }
      auto param_split_shape = GetValue<int64_t>(ele);
      if (param_split_shape <= 0) {
        MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
        return FAILED;
      }
      param_split_shapes_.push_back(param_split_shape);
      offset += param_split_shape;
    }
    if (param_split_shapes_.empty()) {
      MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
      return FAILED;
    }
  }

  return SUCCESS;
}

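// Parse the "manual_split_with_offset" attr, where each element is a
// (split_size, offset) pair. A worked example (illustrative values):
// manual_split_with_offset=((4, 0), (4, 4), (8, 8)) yields
// param_split_shapes_ = [4, 4, 8] and index_offsets_ = [0, 4, 8].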
Status GatherInfo::GetManualSplitAttr() {
  auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
  if (manual_split_with_offset_iter != attrs_.end()) {
    manual_split_ = true;
    auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
    if (var == nullptr) {
      MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequence";
      return FAILED;
    }

    MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
    for (auto &ele : var->value()) {
      if (!ele->isa<ValueSequence>()) {
        MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequence";
        return FAILED;
      }
      std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
      if (value_vector.size() != 2) {
        MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
        return FAILED;
      }
      int64_t param_split_row = GetValue<int64_t>(value_vector[0]);
      int64_t offset = GetValue<int64_t>(value_vector[1]);
      if ((param_split_row <= 0) || (offset < 0)) {
        MS_LOG(ERROR) << name_ << ": The value of param split shape must be positive, "
                      << "and the offset must be greater than or equal to 0";
        return FAILED;
      }
      param_split_shapes_.push_back(param_split_row);
      index_offsets_.push_back(offset);
    }

    if (param_split_shapes_.empty()) {
      MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
      return FAILED;
    }
    if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
      MS_LOG(ERROR) << name_ << ": Index offset must not be less than 0";
      return FAILED;
    }
    return SUCCESS;
  }

  if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
    return FAILED;
  }

  return SUCCESS;
}

void GatherInfo::GetBatchDims() noexcept {
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    batch_dims_ = 0;
    return;
  }
  auto batch_dims_opt = GetScalarValueFromInputs<int64_t>(input_value_, name_, BATCH_DIMS);
  if (batch_dims_opt.has_value()) {
    batch_dims_ = batch_dims_opt.value();
  } else {
    MS_LOG(EXCEPTION) << name_ << ": Failed to fetch the value of batch dims.";
  }
}

GatherUtilPtr GatherInfo::MakeManualUtil() {
  return std::make_shared<GatherManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

GatherUtilPtr SparseGatherV2Info::MakeManualUtil() {
  return std::make_shared<ManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

GatherUtilPtr EmbeddingLookupInfo::MakeManualUtil() {
  return std::make_shared<ManualImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
}

Status GatherInfo::GetAttrs() {
  if (attrs_.find(TARGET) != attrs_.end()) {
    target_ = GetStringAttr(TARGET);
  }

  if (name_.find(EMBEDDING_LOOKUP) != std::string::npos && target_ != CPU) {
    MS_LOG(ERROR) << name_ << ": the target must be set to CPU";
    return FAILED;
  }

  size_t axis_index = 2;
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    axis_index = 1;
  }

  MS_EXCEPTION_IF_NULL(input_value_[axis_index]);
  auto value = GetValue<int64_t>(input_value_[axis_index]);

  // Get the axis: the third input is the axis and it is a ValueNode. EmbeddingLookup has no axis; its third input
  // is the offset.
  if (target_ != CPU) {
    auto params_shape = inputs_shape_.at(0);
    if (params_shape.empty()) {
      MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
      return FAILED;
    }
    if (value < 0) {  // if axis is negative then convert it to positive
      value += SizeToLong(params_shape.size());
    }
    axis_ = value;
  } else {
    if (value != 0) {
      if (name_.find(EMBEDDING_LOOKUP) != std::string::npos) {
        MS_LOG(ERROR) << name_ << ": the target is cpu, and the offset must be 0, but got " << value;
      } else {
        MS_LOG(ERROR) << name_ << ": the target is cpu, and the axis must be 0, but got " << value;
      }
      return FAILED;
    }
  }

  if (GetManualSplitAttr() != SUCCESS) {
    return FAILED;
  }

  GetBatchDims();

  if (manual_split_ && (axis_ != 0)) {
    MS_LOG(ERROR) << name_ << ": The axis must be 0 if manual split, but got " << axis_;
    return FAILED;
  }

  if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
    dynamic_shape_indices_ = true;
  }
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    dynamic_shape_indices_ = true;
  }
#endif
  return SUCCESS;
}

// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
// otherwise return false
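// e.g. (illustrative) with 8 devices in this stage, param strategy [2, 1] and
// indices strategy [4, 1]: both first dimensions are split, 2 * 4 == 8, and
// neither split alone covers all devices, so this returns true.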
bool GatherInfo::ShardBatchAndAxis(const Shape &param_strategy, const Shape &indices_strategy) const {
  if (axis_ != 0) {
    return false;
  }

  if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
    return false;
  }

  if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
    return false;
  }

  if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
    return false;
  }

  if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
    return false;
  }

  return true;
}

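// The gather mode is chosen in priority order: BATCH (batch_dims > 0), NORMAL
// (the axis dim of param is not split), MANUAL, SHARD_BATCH_AND_AXIS, and
// finally SHARD_AXIS_0_STATIC / SHARD_AXIS_0_DYNAMIC or SHARD_AXIS_1
// depending on which dim of param is split.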
GatherMode GatherInfo::GetGatherMode(const Shape &param_strategy, const Shape &indices_strategy) const {
  if (batch_dims_ > 0) {
    return BATCH;
  }

  if (param_strategy[LongToSize(axis_)] == NO_SPLIT_STRATEGY) {
    return NORMAL;
  }

  if (manual_split_) {
    return MANUAL;
  }

  if (ShardBatchAndAxis(param_strategy, indices_strategy)) {
    return SHARD_BATCH_AND_AXIS;
  }

  if (axis_ == 0 && param_strategy[0] != NO_SPLIT_STRATEGY) {
    if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
      return SHARD_AXIS_0_DYNAMIC;
    } else {
      return SHARD_AXIS_0_STATIC;
    }
  }

  if (axis_ == 1 && param_strategy[1] != NO_SPLIT_STRATEGY) {
    return SHARD_AXIS_1;
  }

  return INVALID;
}

// axis can not be split, and the strategies of batch dims must be equal
// support repeat calculation
Status BatchImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (param_strategy[LongToSize(axis_)] != NO_SPLIT_STRATEGY) {
    MS_LOG(ERROR) << name_ << ": batch mode, the axis can not be split, but the param strategy is " << param_strategy
                  << ", and the axis is " << axis_;
    return FAILED;
  }

  for (size_t i = 0; i < LongToSize(batch_dims_); ++i) {
    if (param_strategy[i] != indices_strategy[i]) {
      MS_LOG(ERROR)
        << name_
        << ": batch mode, the strategy of the batch dims of param and indices must be equal, but the param strategy is "
        << param_strategy << ", and the indices strategy is " << indices_strategy << ", batch dims is " << batch_dims_;
      return FAILED;
    }
  }
  return SUCCESS;
}

// batch mode: axis can not be split
// param shape: [A, B, C, D, E]
// indices shape: [A, B, F, G]
// batch_dims = 2
// axis = 3
// out = gather(param, indices, axis)
// out shape: [A, B, C, F, G, E]
// parameter's strategy: [a, b, c, 1, e], indices' strategy: [a, b, f, g]
// output's strategy: [a, b, c, f, g, e]
// dev_matrix: [a, b, f, g, c, 1, e]
Status BatchImpl::InferDevMatrixShape() {
  auto indices_tmp = indices_strategy_;  // [a, b, f, g]
  auto param_tmp = param_strategy_;      // [a, b, c, d, e] = [a, b, c, 1, e]
  (void)param_tmp.erase(param_tmp.cbegin(), param_tmp.cbegin() + LongToSize(batch_dims_));  // [c, 1, e]

  Shape tmp = indices_tmp;
  (void)tmp.insert(tmp.cend(), param_tmp.cbegin(), param_tmp.cend());  // [a, b, f, g, c, 1, e]

  dev_matrix_shape_ = tmp;
  MS_LOG(INFO) << name_ << ": batch mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

Status BatchImpl::InferTensorMap() {
  TensorMap tmp_map;
  int64_t size = SizeToInt(outputs_shape_[0].size()) + 1;
  for (int i = 0; i < size; ++i) {
    tmp_map.push_back(size - i - 1);  // tmp_map: [a, b, f, g, c, 1, e]
  }

  TensorMap param_map = tmp_map;  // [a, b, f, g, c, 1, e]
  (void)param_map.erase(param_map.cbegin() + LongToSize(batch_dims_),
                        param_map.cbegin() + inputs_shape_[1].size());  // [a, b, c, 1, e]

  TensorMap indices_map = tmp_map;  // [a, b, f, g, c, 1, e]
  (void)indices_map.erase(indices_map.cbegin() + inputs_shape_[1].size(), indices_map.cend());  // [a, b, f, g]

  TensorMap out_map = param_map;                              // [a, b, c, 1, e]
  (void)out_map.erase(out_map.cbegin() + LongToSize(axis_));  // [a, b, c, e]

  TensorMap indices_rm_batch = indices_map;  // [a, b, f, g]
  (void)indices_rm_batch.erase(indices_rm_batch.cbegin(),
                               indices_rm_batch.cbegin() + LongToSize(batch_dims_));  // [f, g]

  (void)out_map.insert(out_map.cbegin() + LongToSize(axis_), indices_rm_batch.cbegin(),
                       indices_rm_batch.cend());  // [a, b, c, f, g, e]

  inputs_tensor_map_.push_back(param_map);    // param
  inputs_tensor_map_.push_back(indices_map);  // indices
  outputs_tensor_map_.push_back(out_map);     // out
  return SUCCESS;
}

// axis can not be split
// support repeat calculation
Status NormalImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (param_strategy[LongToSize(axis_)] != NO_SPLIT_STRATEGY) {
    MS_LOG(ERROR) << name_ << ": normal mode, the axis can not be split, but the param strategy is " << param_strategy
                  << ", and the axis is " << axis_;
    return FAILED;
  }

  return SUCCESS;
}

// normal mode: axis can not be split
// param shape: [C, D, E]
// indices shape: [F, G]
// axis = 1
// out = gather(param, indices, axis)
// out shape: [C, F, G, E]
// parameter's strategy: [c, 1, e], indices' strategy: [f, g]
// output's strategy: [c, f, g, e]
// dev_matrix: [f, g, c, 1, e]
Status NormalImpl::InferDevMatrixShape() {
  auto indices_tmp = indices_strategy_;  // [f, g]
  auto param_tmp = param_strategy_;      // [c, d, e] = [c, 1, e]

  Shape tmp = indices_tmp;
  (void)tmp.insert(tmp.cend(), param_tmp.cbegin(), param_tmp.cend());  // [f, g, c, 1, e]

  dev_matrix_shape_ = tmp;
  MS_LOG(INFO) << name_ << ": normal mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

Status NormalImpl::InferTensorMap() {
  TensorMap tmp_map;
  int64_t size = SizeToInt(outputs_shape_[0].size()) + 1;
  for (int i = 0; i < size; ++i) {
    tmp_map.push_back(size - i - 1);  // tmp_map: [f, g, c, 1, e]
  }

  TensorMap param_map = tmp_map;  // [f, g, c, 1, e]
  (void)param_map.erase(param_map.cbegin(), param_map.cbegin() + inputs_shape_[1].size());  // [c, 1, e]

  TensorMap indices_map = tmp_map;  // [f, g, c, 1, e]
  (void)indices_map.erase(indices_map.cbegin() + inputs_shape_[1].size(), indices_map.cend());  // [f, g]

  TensorMap out_map = param_map;                              // [c, 1, e]
  (void)out_map.erase(out_map.cbegin() + LongToSize(axis_));  // [c, e]
  (void)out_map.insert(out_map.cbegin() + LongToSize(axis_), indices_map.cbegin(), indices_map.cend());  // [c, f, g, e]

  inputs_tensor_map_.push_back(param_map);    // param
  inputs_tensor_map_.push_back(indices_map);  // indices
  outputs_tensor_map_.push_back(out_map);     // out
  return SUCCESS;
}

// constraint: the field dimension of indices is the last dimension
// parameter's dim >= 1, indices' dim >= 1, axis == 0
// parameter's strategy: [a, b, ..., c], indices' strategy: [1, ..., 1, a]
// output's strategy: [1, ..., 1, a, b, ..., c]
// dev_matrix: [a, b, ..., c]
// can not support repeated calculation
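// A worked example (illustrative values): param shape [20, 64] with manual
// split shapes [4, 6, 4, 6] (summing to 20), indices shape [2, 8], param
// strategy [4, 1], indices strategy [1, 4]. Only the last dim of indices is
// split, its split number 4 equals param_strategy[0] and the number of manual
// split shapes, and min_param_slice_row = 8 / 4 = 2 <= every split value.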
Status ManualImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  if (indices_strategy.size() < 1) {
    MS_LOG(ERROR) << name_ << ": The size of indices strategy must be positive, but got " << indices_strategy.size();
    return FAILED;
  }

  auto product_i = std::accumulate(indices_strategy.begin(), indices_strategy.end(), 1, std::multiplies<int64_t>());
  size_t indices_split_dim = indices_strategy.size() - 1;  // only the last dim of indices can be split
  if (product_i != indices_strategy[indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": Only the last dim of indices can be split, but got " << indices_strategy;
    return FAILED;
  }

  if (param_strategy[0] != indices_strategy[indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": The param_strategy[0] " << param_strategy[0]
                  << " must be equal to indices_strategy[-1] " << indices_strategy[indices_split_dim];
    return FAILED;
  }

  if (indices_strategy[indices_split_dim] != SizeToLong(param_split_shapes_.size())) {
    MS_LOG(ERROR) << name_ << ": The indices_strategy[-1] " << indices_strategy[indices_split_dim]
                  << " must be equal to manual split size " << param_split_shapes_.size();
    return FAILED;
  }
  MS_EXCEPTION_IF_ZERO("indices_strategy[indices_split_dim]", indices_strategy[indices_split_dim]);
  int64_t min_param_slice_row = inputs_shape_[1][indices_split_dim] / indices_strategy[indices_split_dim];
  bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
                             [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
  if (invalid) {
    MS_LOG(ERROR) << name_ << ": The split value " << param_split_shapes_
                  << " must be larger than or equal to indices field slice size " << min_param_slice_row;
    return FAILED;
  }

  if (inputs_shape_[0][0] < inputs_shape_[1][indices_split_dim]) {
    MS_LOG(ERROR) << name_ << ": The param's row size " << inputs_shape_[0][0]
                  << " is smaller than indices' field size " << inputs_shape_[1][indices_split_dim];
    return FAILED;
  }

  // Don't support repeated calc
  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  MS_EXCEPTION_IF_NULL(g_device_manager);
  if (product_p < SizeToLong(g_device_manager->GetDeviceListInThisStage().size())) {
    MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
    return FAILED;
  }

  int64_t split_shape_sum =
    std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), static_cast<int64_t>(0),
                    [](int64_t s, int64_t shape) { return s + shape; });
  if (split_shape_sum != inputs_shape_[0][0]) {
    MS_LOG(ERROR) << name_ << ": Sum of split shapes " << split_shape_sum << " must be equal to param_shape[0] "
                  << inputs_shape_[0][0];
    return FAILED;
  }
  return SUCCESS;
}

Status ManualImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = param_strategy_;
  MS_LOG(INFO) << name_ << ": manual mode, the dev matrix shape is " << dev_matrix_shape_;
  return SUCCESS;
}

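// e.g. (illustrative) param shape [20, 64] with strategy [4, 1] and indices
// shape [2, 8]: param_map = [1, 0], indices_map = [-1, 1] (only the last dim
// of indices is split, mapped to the same dev-matrix dim as param's dim 0),
// and out_map = [-1, 1, 0].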
Status ManualImpl::InferTensorMap() {
  Shape param_map;
  size_t size = inputs_shape_[0].size();
  for (size_t i = 0; i < size; ++i) {
    param_map.push_back(static_cast<int64_t>(size - i - 1));
  }

  size_t indices_size = inputs_shape_[1].size();
  Shape indices_map(indices_size, MAP_NONE);
  indices_map[indices_size - 1] = param_map[0];

  Shape out_map = param_map;
  (void)out_map.insert(out_map.begin(), indices_size - 1, MAP_NONE);

  (void)inputs_tensor_map_.emplace_back(std::move(param_map));
  (void)inputs_tensor_map_.emplace_back(std::move(indices_map));
  (void)outputs_tensor_map_.emplace_back(std::move(out_map));
  return SUCCESS;
}

Status ManualImpl::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape input_index_shape = inputs_shape_.at(1);
  Shape output_shape = outputs_shape_.at(0);
  int64_t rank = g_device_manager->rank_index_in_stage();
  // infer tensor layout
  TensorLayout input_tensor_layout;
  TensorLayout input_index_layout;
  TensorLayout output_tensor_layout;

  int64_t bias_size = 1;
  if (dev_matrix_shape_.size() > 1) {
    bias_size = std::accumulate(dev_matrix_shape_.begin() + 1, dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
  }
  if (bias_size == 0) {
    MS_LOG(ERROR) << name_ << ": Invalid device matrix " << dev_matrix_shape_;
    return FAILED;
  }
  input_shape[0] = param_split_shapes_[LongToSize(rank / bias_size)];
  input_shape[0] = input_shape[0] * dev_matrix_shape_[0];

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
    return FAILED;
  }

  input_tensor_layout.set_uniform_split(false);

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

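// Each rank picks its offset by its coordinate along dev-matrix dim 0:
// e.g. (illustrative) param strategy [4, 2] on 8 devices gives bias_size = 2,
// so ranks 0-1 use index_offsets_[0], ranks 2-3 use index_offsets_[1], and
// so on.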
Status ManualImpl::InferOffset() {
  CheckGlobalDeviceManager();
  size_t rank = LongToSize(g_device_manager->rank_index_in_stage());

  int64_t bias_size = 1;
  if (param_strategy_.size() > 1) {
    bias_size = std::accumulate(param_strategy_.begin() + 1, param_strategy_.end(), 1, std::multiplies<int64_t>());
  }
  MS_EXCEPTION_IF_ZERO("bias_size", LongToSize(bias_size));
  size_t index = rank / LongToSize(bias_size);
  if (index < index_offsets_.size()) {
    index_offset_ = index_offsets_[index];
    MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is " << index_offsets_.size();
  return FAILED;
}

Status ManualImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {  // if target is CPU, no need to replace graph
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }

  auto sub_node = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
  AnfNodePtr gather_v2_node = nullptr;
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), sub_node});
    input_nodes = {std::make_pair(sub_node, 3), std::make_pair(gather_v2_node, 1)};
  } else {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node, CreatInt64Imm(axis_)});
    input_nodes = {std::make_pair(sub_node, 2), std::make_pair(gather_v2_node, 1)};
  }

  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, gather_v2_node));
  return SUCCESS;
}

Status GatherManualImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {  // if target is CPU, no need to replace graph
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }

  auto sub_node = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
  AnfNodePtr gather_v2_node = nullptr;
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
  // Gather processing.
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), sub_node});
    input_nodes = {std::make_pair(sub_node, 3), std::make_pair(gather_v2_node, 1)};
  } else {
    gather_v2_node = gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node,
                                     CreatInt64Imm(axis_), CreatInt64Imm(0)});
    input_nodes = {std::make_pair(sub_node, 2), std::make_pair(gather_v2_node, 1)};
  }

  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, gather_v2_node));
  return SUCCESS;
}

Status ManualImpl::InferReplaceOps() {
  if (target_ != CPU) {  // if target is not CPU, no need to replace ops
    return SUCCESS;
  }

  int64_t bias = 0;

  if (InferOffset() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer offset failed.";
    return FAILED;
  }

  bias = index_offset_;

  OperatorName op_name = EMBEDDING_LOOKUP;
  OperatorAttrs attrs;
  Attr param_offset = std::make_pair("offset", MakeValue(bias));
  OperatorParams params = {std::make_pair(param_offset, 3)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(op_name, args);
  replace_op_.push_back(op);

  return SUCCESS;
}

Status ShardBatchAndAxisImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = {indices_strategy_[0], param_strategy_[0]};
  MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_;
  // if the forward uses reducescatter, the output's dev matrix is {indices_strategy[0] * param_strategy[0]}
  if (axis_split_forward_allreduce_) {
    out_dev_matrix_shape_ = dev_matrix_shape_;
  } else {
    out_dev_matrix_shape_ = {indices_strategy_[0] * param_strategy_[0]};
  }
  auto shard_product =
    std::accumulate(dev_matrix_shape_.begin(), dev_matrix_shape_.end(), 1, std::multiplies<int64_t>());
  auto stage_device_size = SizeToLong(g_device_manager->GetDeviceListInThisStage().size());
  if (shard_product < stage_device_size) {
    MS_EXCEPTION_IF_ZERO("shard_product", shard_product);
    repeated_calculation_num_ = stage_device_size / shard_product;  // set repeated calculation num
  }
  return SUCCESS;
}

Status ShardBatchAndAxisImpl::InferTensorMap() {
  Shape param_tensor_map = {0, MAP_NONE};
  Shape indices_tensor_map = {1, MAP_NONE};
  Shape out_tensor_map;
  if (axis_split_forward_allreduce_) {
    out_tensor_map = {1, MAP_NONE, MAP_NONE};  // the dev matrix is (indices_strategy[0], param_strategy[0])
  } else {
    out_tensor_map = {0, MAP_NONE, MAP_NONE};  // the dev matrix is (indices_strategy[0] * param_strategy[0])
  }

  (void)inputs_tensor_map_.emplace_back(std::move(param_tensor_map));    // param
  (void)inputs_tensor_map_.emplace_back(std::move(indices_tensor_map));  // indices
  (void)outputs_tensor_map_.emplace_back(std::move(out_tensor_map));     // output
  return SUCCESS;
}

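// e.g. (illustrative) param shape [64, 32] with param_strategy_[0] = 4:
// slice_size_ = 64 / 4 = 16, and rank 5 gets bias_ = (5 % 4) * 16 = 16.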
Status ShardBatchAndAxisImpl::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  MS_EXCEPTION_IF_ZERO("param_strategy_[0]", param_strategy_[0]);
  slice_size_ = input_shape[0] / param_strategy_[0];
  bias_ = rank % param_strategy_[0] * slice_size_;
  MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
               << ", bias is " << bias_;
  return SUCCESS;
}

void ShardAxisImpl::SetAttribute(const Shape &param_strategy) {
  // axis=0, index_shape(0)%param_strategy(0) must be 0
  Shape index_shape = inputs_shape_.at(1);
  MS_EXCEPTION_IF_ZERO("param_strategy.at(0)", param_strategy.at(0));
  if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
    axis_split_forward_allreduce_ = true;
  }

  auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  // Case 1: If repeated calculation, need to set the repeated num to the left of the dev-matrix. For example,
  // parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
  // and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
  // can communicate normally, and dev0 to dev7 all hold the full parameter.
  // Case 2: If not repeated calculation (such as data parallel), need to set the repeated num to the right,
  // as it's easy to introduce a redistribution after or before the gather operation, influencing the performance.
  auto stage_device_size = g_device_manager->GetDeviceListInThisStage().size();
  if (product_param == SizeToLong(stage_device_size) || product_param == 1) {
    repeated_num_in_dev_matrix_right_ = true;
  } else {
    repeated_num_in_dev_matrix_right_ = false;
  }
  MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
}

Status ShardAxisImpl::CheckSplitAxisStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  // param_strategy(axis) != 1, index can't be split
  auto stage_device_size = g_device_manager->GetDeviceListInThisStage().size();
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  bool is_auto_parallel = (parallel_mode == kAutoParallel);

  auto product_i = std::accumulate(indices_strategy.begin(), indices_strategy.end(), 1, std::multiplies<int64_t>());
  if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
    FILTER_LOG(is_auto_parallel) << name_ << ": param is split at dim (axis) " << axis_ << ", index can't be split.";
    return FAILED;
  }

  // param_strategy(axis) != 1, and axis != 0, don't support repeated calc
  auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  if ((product_p != SizeToLong(stage_device_size)) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ != 0)) {
    FILTER_LOG(is_auto_parallel) << name_ << ": Invalid strategy. Don't support repeated calc.";
    return FAILED;
  }

  if ((product_p != SizeToLong(stage_device_size)) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ == 0)) {
    if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
      FILTER_LOG(is_auto_parallel) << name_
                                   << ": axis(0) is split, and param_strategy[1] != 1, don't support"
                                      " repeated calc.";
      return FAILED;
    }
    MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
  }
  return SUCCESS;
}

Status ShardAxisImpl::CheckStrategy(const Shape &param_strategy, const Shape &indices_strategy) {
  // only support 1-dim and 2-dim param
  if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
    MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
    return FAILED;
  }

  // don't support scalar index
  if (inputs_shape_[1].empty()) {
    MS_LOG(ERROR) << name_ << ": Don't support scalar index.";
    return FAILED;
  }

  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
  MS_EXCEPTION_IF_ZERO("param_strategy", param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
  if (axis_ != 0 && inputs_shape_[0][0] % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(ERROR) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
    return FAILED;
  }

  if (CheckSplitAxisStrategy(param_strategy, indices_strategy) != SUCCESS) {
    return FAILED;
  }

  // According to the strategy, set the private members.
  SetAttribute(param_strategy);

  return SUCCESS;
}

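// e.g. (illustrative) axis = 1 and param strategy [2, 4]: dev_matrix_shape_
// is [2, 4], and the output's dev matrix becomes [1, 8] (the combined split
// 2 * 4 is placed last) before any repeated-calc dim is appended.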
Status ShardAxisImpl::InferDevMatrixShape() {
  dev_matrix_shape_ = param_strategy_;

  // infer out dev_matrix_shape
  // axis is not 0, split axis
  if (axis_ != 0 && param_strategy_.at(LongToSize(axis_)) != 1) {
    for (size_t i = 1; i < param_strategy_.size(); ++i) {
      if (i == LongToSize(axis_)) {
        out_dev_matrix_shape_.push_back(1);
      } else {
        out_dev_matrix_shape_.push_back(param_strategy_.at(i));
      }
    }
    out_dev_matrix_shape_.push_back(param_strategy_.at(0) * param_strategy_.at(LongToSize(axis_)));
  } else {
    out_dev_matrix_shape_ = dev_matrix_shape_;
  }
  auto param_product = std::accumulate(param_strategy_.begin(), param_strategy_.end(), 1, std::multiplies<int64_t>());
  auto index_product =
    std::accumulate(indices_strategy_.begin(), indices_strategy_.end(), 1, std::multiplies<int64_t>());
  auto stage_device_size = SizeToLong(g_device_manager->GetDeviceListInThisStage().size());
  if (param_product * index_product < stage_device_size) {
    MS_EXCEPTION_IF_ZERO("param_product * index_product", param_product * index_product);
    repeated_calculation_num_ = stage_device_size / (param_product * index_product);  // set the repeat calc num
    if (repeated_num_in_dev_matrix_right_) {
      out_dev_matrix_shape_.push_back(repeated_calculation_num_);
    } else {
      (void)out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), repeated_calculation_num_);
    }
  }

  return SUCCESS;
}

Status ShardAxisImpl::InferTensorMap() {
  // param_strategy(axis) is not 1
  // infer input tensor map
  size_t param_size = inputs_shape_.at(0).size();
  size_t index_size = inputs_shape_.at(1).size();
  Shape tensor_map_index;
  Shape tensor_map_params;

  (void)tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
  for (size_t i = 0; i < param_size; ++i) {
    tensor_map_params.push_back(SizeToLong(param_size - i - 1));
  }

  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
  (void)inputs_tensor_map_.emplace_back(std::move(tensor_map_index));

  // infer output tensor map
  Shape tensor_map_out;
  if (axis_ == 0) {
    if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
      // the output is repeat calculation
      (void)tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
    } else {
      (void)tensor_map_out.insert(tensor_map_out.end(), SizeToLong(param_size) - 1);
    }
    (void)tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
    for (size_t i = 1; i < param_size; ++i) {
      tensor_map_out.push_back(param_size - 1 - i);
    }
  } else {
    for (size_t i = 0; i < param_size; ++i) {
      if (i == LongToSize(axis_)) {
        (void)tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
      } else {
        if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
          tensor_map_out.push_back(MAP_NONE);
        }
        tensor_map_out.push_back(SizeToLong(i));
      }
    }
  }
  (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));

  return SUCCESS;
}

Status ShardAxisImpl::InferTensorInfo() {
  // infer tensor layout
  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), inputs_shape_[0]) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), inputs_shape_[1]) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), outputs_shape_[0]) !=
       SUCCESS)) {
    return FAILED;
  }

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

Status ShardAxisImpl::InferGroup() {
  size_t dim = LongToSize(axis_);

  int64_t rank = g_device_manager->global_rank();
  DeviceMatrix dev_matrix(rank, g_device_manager->GetDeviceListInThisStage(), dev_matrix_shape_);
  RankList group_devices;

  // the dev_matrix[0] is the repeated_calc_num, so the dim needs to add 1
  if (repeated_calculation_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
    dim = dim + 1;
  }

  if (gather_mode_ == SHARD_BATCH_AND_AXIS) {
    dim = 1;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
  }

  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
  }
  if (group_devices.size() == 1) {
    MS_LOG(INFO) << name_ << ": The group size is 1, no need to create group";
    return SUCCESS;
  }

  MS_LOG(INFO) << name_ << ": The group ranks is " << group_devices;
  if (g_device_manager->CreateGroup(group_devices, &group_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": create reduce group failed in table row split.";
    return FAILED;
  }
  return SUCCESS;
}

Status ShardAxisImpl::InferForwardCommunication() {
  forward_op_.clear();
  // if the target is not CPU, the forward communication is inserted by the replace graph, no need forward op here
  if (target_ != CPU) {
    return SUCCESS;
  }
  // split axis
  Attr attr_group;
  OperatorName operator_name;
  if (axis_split_forward_allreduce_) {
    operator_name = ALL_REDUCE;
  } else {
    operator_name = REDUCE_SCATTER;
  }

  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  if (group_.name().empty()) {
    return SUCCESS;
  }
  attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  OperatorAttrs attrs = {attr_op, attr_group};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(operator_name, args);

  forward_op_.push_back(op);
  return SUCCESS;
}

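// e.g. (illustrative) a 2-D param [64, 32] with axis = 0, param strategy
// [4, 2] and no repeated calc: slice_size_ = 64 / 4 = 16, and rank 5 gets
// bias_ = 5 / 2 * 16 = 32 (integer division).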
Status ShardAxisImpl::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  // params_size=1, axis=0
  if ((input_shape.size() == 1) && (axis_ == 0)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(0)", param_strategy_.at(0));
    slice_size_ = input_shape.at(0) / param_strategy_.at(0);
    // if repeated calculation, adjust the rank: divide it by the repeated num if the repeated num is on the right
    // of the dev-matrix, otherwise take it modulo the param strategy product
    if (repeated_calculation_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calculation_num_;
      } else {
        rank = rank % param_strategy_[0];
      }
    }
    bias_ = rank * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=0
  if ((input_shape.size() == 2) && (axis_ == 0)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(0)", param_strategy_.at(0));
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(1)", param_strategy_.at(1));
    slice_size_ = input_shape.at(0) / param_strategy_.at(0);
    // if repeated calculation, adjust the rank in the same way as the 1-D case
    if (repeated_calculation_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calculation_num_;
      } else {
        rank = rank % (param_strategy_[0] * param_strategy_[1]);
      }
    }
#if defined(__linux__) && defined(WITH_BACKEND)
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      if (ps::PSContext::instance()->enable_distributed_mindrt()) {
        bias_ = static_cast<int64_t>(embedding_cache_table_manager.cache_indices_lower_bound());
      }
      return SUCCESS;
    }
#endif
    bias_ = rank / param_strategy_.at(1) * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=1
  if ((input_shape.size() == 2) && (axis_ == 1)) {
    MS_EXCEPTION_IF_ZERO("param_strategy_.at(1)", param_strategy_.at(1));
    slice_size_ = input_shape.at(1) / param_strategy_.at(1);
    bias_ = rank % param_strategy_.at(1) * slice_size_;
    return SUCCESS;
  }
  MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_;
  return FAILED;
}

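// The replace graph masks out-of-range indices and reduces the partial
// results: sub = indices - bias, relu/minimum clamp the result into
// [0, slice_size_ - 1], equal records which indices were truly local, the
// gather runs on the local slice, the result is multiplied by the mask
// (cast to the gather's dtype), and an AllReduce or ReduceScatter combines
// the shards across devices.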
Status ShardAxisImpl::InferReplaceGraph(const CNodePtr &cnode) {
  if (target_ == CPU) {
    return SUCCESS;
  }

  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage() << ", the bias is " << bias_;
  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});

  AnfNodePtr gather_v2{nullptr};
  auto replace_op_name = GetPrimNameFromInfoName(replace_op_name_);
  if (replace_op_name == INDEX_SELECT) {
    gather_v2 =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), CreatInt64Imm(axis_), minimum});
  } else if (replace_op_name == GATHERV2) {
    gather_v2 = gen_g.PushBack(
      {gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_), CreatInt64Imm(0)});
  } else {
    gather_v2 =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_)});
  }

  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
  auto dtype_id =
    gen_g.PushBack({gen_g.NewOpInst(DTYPETOENUM), CreateStringImm("DtypeToEnum"), CreateStringImm("dtype"), dtype});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype_id});
  auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
  // don't need expand dim, if param_size = 1
  if (inputs_shape_.at(0).size() == 1) {
    mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
  }
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  OperatorAttrs attrs = {attr_op, attr_group};
  AnfNodePtr reduce_op;
  if (dynamic_shape_indices_ || axis_split_forward_allreduce_ || is_assigned_parallel_) {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
  } else {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
  }
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
  if (replace_op_name == INDEX_SELECT) {
    input_nodes = {std::make_pair(sub, 3), std::make_pair(gather_v2, 1)};
  }
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, reduce_op));

  return SUCCESS;
}

Status ShardAxisImpl::InferReplaceOps() {
  if (target_ != CPU) {  // if target is not CPU, no need to replace ops
    return SUCCESS;
  }

  int64_t bias = 0;

  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer bias failed.";
    return FAILED;
  }
  bias = bias_;

  OperatorName op_name = EMBEDDING_LOOKUP;
  OperatorAttrs attrs;
  Attr param_offset = std::make_pair("offset", MakeValue(bias));
  OperatorParams params = {std::make_pair(param_offset, 3)};
  OperatorArgs args = std::make_pair(attrs, params);
  Operator op = std::make_pair(op_name, args);
  replace_op_.push_back(op);

  return SUCCESS;
}

Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }
  gather_util_ = nullptr;
  gather_mode_ = INVALID;

  // the param slice shape is preferably 32-byte aligned
  auto param_shape = inputs_shape_[0];
  auto input_dim = strategy->GetInputDim();
  auto param_strategy = input_dim[0];
  auto indices_strategy = input_dim[1];
  MS_LOG(INFO) << name_ << ": the indices shape is " << inputs_shape_[1] << ", the strategy is " << input_dim[1];
  MS_EXCEPTION_IF_ZERO("param_strategy.at(param_strategy.size() - 1)", param_strategy.at(param_strategy.size() - 1));
  auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
  if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
    MS_LOG(WARNING) << "Gather: Last dim of param slice shape is not 32-byte aligned.";
  }

  // get the gather mode, and choose the corresponding implementation
  gather_mode_ = GetGatherMode(param_strategy, indices_strategy);
  switch (gather_mode_) {
    case BATCH: {
      gather_util_ = std::make_shared<BatchImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto batch_util = std::dynamic_pointer_cast<BatchImpl>(gather_util_);
      batch_util->set_batch_dims(batch_dims_);
      break;
    }
    case NORMAL:
      gather_util_ = std::make_shared<NormalImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      break;
    case MANUAL: {
      gather_util_ = MakeManualUtil();
      auto manual_util = std::dynamic_pointer_cast<ManualImpl>(gather_util_);
      manual_util->set_param_split_shapes(param_split_shapes_);
      manual_util->set_index_offsets(index_offsets_);
      manual_util->set_attrs(attrs_);
      manual_util->set_target(target_);
      manual_util->set_replace_op_name(replace_op_name_);
      break;
    }
    case SHARD_BATCH_AND_AXIS: {
      gather_util_ = std::make_shared<ShardBatchAndAxisImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto shard_batch_and_axis_util = std::dynamic_pointer_cast<ShardBatchAndAxisImpl>(gather_util_);
      shard_batch_and_axis_util->set_target(target_);
      shard_batch_and_axis_util->set_dynamic_shape_indices(dynamic_shape_indices_);
      shard_batch_and_axis_util->set_attrs(attrs_);
      shard_batch_and_axis_util->set_replace_op_name(replace_op_name_);
      shard_batch_and_axis_util->set_axis_split_forward_allreduce(
        true);  // Sharding batch and axis, and the forward uses allreduce
      break;
    }
    case SHARD_AXIS_0_DYNAMIC:
    case SHARD_AXIS_0_STATIC:
    case SHARD_AXIS_1: {
      gather_util_ = std::make_shared<ShardAxisImpl>(name_, inputs_shape_clone_, outputs_shape_clone_, axis_);
      auto shard_axis_util = std::dynamic_pointer_cast<ShardAxisImpl>(gather_util_);
      shard_axis_util->set_target(target_);
      shard_axis_util->set_dynamic_shape_indices(dynamic_shape_indices_);
      shard_axis_util->set_attrs(attrs_);
      shard_axis_util->set_replace_op_name(replace_op_name_);
      shard_axis_util->set_assigned_parallel(is_assigned_parallel_);
      break;
    }
    default:
      MS_LOG(ERROR) << name_ << ": invalid gather mode: " << gather_mode_;
      return FAILED;
  }

  gather_util_->set_dynamic_shape_flag(dynamic_shape_flag_);
  gather_util_->set_inputs_divisor(inputs_divisor_);
  gather_util_->set_outputs_divisor(outputs_divisor_);

  gather_util_->DivisorsReplaceShapes();
  if (gather_util_->CheckStrategy(param_strategy, indices_strategy) != SUCCESS) {
    return FAILED;
  }
  gather_util_->ResumeShapes();

  gather_util_->set_param_strategy(param_strategy);
  gather_util_->set_indices_strategy(indices_strategy);
  gather_util_->set_gather_mode(gather_mode_);
  MS_LOG(INFO) << name_ << ": the gather mode is " << gather_util_->GatherModeToString();

  repeated_num_in_dev_matrix_right_ = gather_util_->repeated_num_in_dev_matrix_right();  // set the base class member

  return SUCCESS;
}

Status GatherInfo::CheckStrategyForDynamicShape(const StrategyPtr &strategy) {
  Strategies strategies = strategy->GetInputDim();
  auto param_strategy = strategies[0];
  if (param_strategy[axis_] != 1 && inputs_shape_[0][axis_] == -1) {
    MS_LOG(ERROR) << name_ << ": the axis dim of first input can not be split if it's dynamic shape, the strategy is "
                  << ShapesToString(strategies) << ", the inputs' shape: " << ShapesToString(inputs_shape_)
                  << ", the axis " << axis_;
    return FAILED;
  }
  return SUCCESS;
}

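// e.g. (illustrative) axis = 0, param strategy [4, 1], indices strategy
// [2, 1]: the allreduce output strategy is [2, 1, 1] and the reduce-scatter
// output strategy is [8, 1, 1]; any other output strategy is rejected.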
Status GatherInfo::CheckOutputStrategy(const StrategyPtr &out_strategy) {
  if (out_strategy == nullptr) {
    MS_LOG(INFO) << name_ << ": The output strategy is null";
    return SUCCESS;
  }

  if (CheckStrategyValue(out_strategy, outputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid output strategy";
    return FAILED;
  }

  if (axis_ != 0 && batch_dims_ != 0) {
    MS_LOG(ERROR) << name_ << ": Set output strategy only for axis = 0 and batch_dims = 0, but the axis is " << axis_
                  << ", the batch_dims is " << batch_dims_;
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(gather_util_);
  auto shard_axis_util = std::dynamic_pointer_cast<ShardAxisImpl>(gather_util_);

  auto in_stra = strategy_->GetInputDim();
  auto param_strategy = in_stra[0];
  auto index_strategy = in_stra[1];

  // only for axis == 0
  auto allreduce_strategy = index_strategy;
  (void)allreduce_strategy.insert(allreduce_strategy.end(), param_strategy.begin() + 1, param_strategy.end());
  auto reduce_scatter_strategy = allreduce_strategy;
  reduce_scatter_strategy[0] *= param_strategy[0];

  auto out_stra = out_strategy->GetInputDim()[0];
  if (out_stra == allreduce_strategy) {
    if (shard_axis_util != nullptr) {
      shard_axis_util->set_axis_split_forward_allreduce(true);
    }

    MS_LOG(INFO) << name_ << ": The output strategy is " << out_stra << ", forward use allreduce";
    return SUCCESS;
  } else if (out_stra == reduce_scatter_strategy) {
    if (gather_util_->gather_mode() != SHARD_AXIS_0_STATIC && gather_util_->gather_mode() != SHARD_BATCH_AND_AXIS) {
      MS_LOG(ERROR) << name_ << ": The output strategy " << out_stra << " for gather mode "
                    << gather_util_->GatherModeToString() << " is invalid, it must be " << allreduce_strategy;
      return FAILED;
    }

    if (shard_axis_util) {
      shard_axis_util->set_axis_split_forward_allreduce(false);
    }
    MS_LOG(INFO) << name_ << ": The output strategy is " << out_stra << ", forward use reduce scatter";
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": The output strategy " << out_stra << " is invalid, it must be " << allreduce_strategy
                << " or " << reduce_scatter_strategy;
  return FAILED;
}

void GatherInfo::DealWithBatchDimsMirrorOp() noexcept {
  OperatorVector op_for_batch_dims;
  mirror_ops_.push_back(op_for_batch_dims);
}

Status GatherInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_a_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_a_group;
  if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
    ReportError(name_ + " : Create group for input a failed.");
    return FAILED;
  }

  OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
  if (input_a_group.empty()) {
    MS_LOG(INFO) << name_ << " : The mirror group is empty.";
    return SUCCESS;
  } else {
    op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
    MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
  }

  mirror_ops_.push_back(op_for_input_a);
  mirror_ops_.push_back(op_for_input_b);
  mirror_ops_.push_back(op_for_axis);
  DealWithBatchDimsMirrorOp();

  return SUCCESS;
}

Status GatherInfo::InferDevMatrixShape() {
  if (gather_util_->InferDevMatrixShape() != SUCCESS) {
    return FAILED;
  }
  dev_matrix_shape_ = gather_util_->dev_matrix_shape();
  out_dev_matrix_shape_ = gather_util_->out_dev_matrix_shape();  // set base class member
  return SUCCESS;
}

Status GatherInfo::InferTensorMap() {
  // the dev matrix shape may be changed if repeat calculation, so need to reset the dev matrix shape for gather_util
  gather_util_->set_dev_matrix_shape(dev_matrix_shape_);

  if (gather_util_->InferTensorMap() != SUCCESS) {
    return FAILED;
  }

  inputs_tensor_map_ = gather_util_->inputs_tensor_map();
  outputs_tensor_map_ = gather_util_->outputs_tensor_map();
  return SUCCESS;
}

Status GatherUtil::InferTensorInfoNoSplitAxis() {
  TensorLayout input_tensor_layout;
  TensorLayout input_index_layout;
  TensorLayout output_tensor_layout;

  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), inputs_shape_[0]) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), inputs_shape_[1]) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), outputs_shape_[0]) !=
       SUCCESS)) {
    return FAILED;
  }

  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

void GatherUtil::DivisorsReplaceShapes() {
  if (!dynamic_shape_flag_) {
    return;
  }

  inputs_shape_ = inputs_divisor_;
  outputs_shape_ = outputs_divisor_;
}

void GatherUtil::ResumeShapes() {
  if (!dynamic_shape_flag_) {
    return;
  }

  inputs_shape_ = inputs_shape_clone_;
  outputs_shape_ = outputs_shape_clone_;
}

Status GatherInfo::InferTensorInfo() {
  // the tensor map of gather_util may be changed if repeat calculation, so need to reset
  gather_util_->set_inputs_tensor_map(inputs_tensor_map_);
  gather_util_->set_outputs_tensor_map(outputs_tensor_map_);

  if (gather_util_->InferTensorInfo() != SUCCESS) {
    return FAILED;
  }

  inputs_tensor_info_ = gather_util_->inputs_tensor_info();
  outputs_tensor_info_ = gather_util_->outputs_tensor_info();
  if (name_.find(INDEX_SELECT) != std::string::npos) {
    TensorInfo axis_place_holder;
    if (inputs_tensor_info_.empty()) {
      MS_LOG(ERROR) << name_ << ": the tensor info of inputs is empty";
      return FAILED;
    }
    (void)inputs_tensor_info_.insert(inputs_tensor_info_.cbegin() + 1, axis_place_holder);
  }
  return SUCCESS;
}

Status GatherInfo::InferForwardCommunication() {
  if (gather_util_->InferForwardCommunication() != SUCCESS) {
    return FAILED;
  }
  forward_op_ = gather_util_->forward_op();
  return SUCCESS;
}

ReplaceGraphPtr GatherInfo::replace_graph(const CNodePtr &cnode) {
  // target_ == CPU, no need to replace graph
  if (target_ == CPU) {
    return nullptr;
  }
  if (gather_util_->InferReplaceGraph(cnode) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": infer replace graph failed.";
  }
  replace_graph_ = gather_util_->replace_graph();
  return replace_graph_;
}

Status GatherInfo::ComputeReplaceOp() {
  if (gather_util_->InferReplaceOps() != SUCCESS) {
    return FAILED;
  }
  replace_op_ = gather_util_->replace_op();
  return SUCCESS;
}

Status GatherInfo::Init(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy,
                        const std::vector<std::shared_ptr<TensorLayout>> &in_tensor_layouts,
                        const std::vector<std::shared_ptr<TensorLayout>> &out_tensor_layouts) {
  if (InitWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }
  // only if target_ == CPU, we need to replace the op
  if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": Init success.";
  return SUCCESS;
}

Status GatherInfo::InitForCostModel(const StrategyPtr &in_strategy, const StrategyPtr &out_strategy) {
  if (InitForCostModelWithAutoRepeatCalc(in_strategy, out_strategy) != SUCCESS) {
    FILTER_LOG(is_auto_parallel_) << name_ << ": Init for cost model failed.";
    return FAILED;
  }
  auto param_strategy = strategy_->GetInputDim().at(0);
  // cost model set axis and strategy
  auto gather_cost = std::dynamic_pointer_cast<GatherCost>(operator_cost());
  gather_cost->set_axis(axis_);
  gather_cost->set_strategy(param_strategy);
  MS_LOG(INFO) << name_ << ": Init for cost model success.";
  return SUCCESS;
}

Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
  if (manual_split_) {
    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support searching strategies";
  }
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shape input1_split(inputs_shape_[1].size(), 1);
  Shapes splittable_inputs = {input0_split, input1_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
  }
  return sp_vector;
}

std::shared_ptr<Strategies> GatherInfo::GenerateBatchStrategies() {
  if (GetAttrs() != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
  }
  if (manual_split_) {
    MS_LOG(EXCEPTION) << name_ << ": Manual split does not support generating batch strategies";
  }

  Dimensions param_strategy(inputs_shape_[0].size(), 1);
  Dimensions index_strategy;
  index_strategy.push_back(stage_device_size_);
  for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
    index_strategy.push_back(1);
  }

  if (batch_dims_ > 0 && !param_strategy.empty()) {
    param_strategy[0] = stage_device_size_;
  }
  Strategies strategy_v = {param_strategy, index_strategy};
  return std::make_shared<Strategies>(strategy_v);
}

REGISTER(GatherInfo);
REGISTER(IndexSelectInfo);
REGISTER(SparseGatherV2Info);
REGISTER(EmbeddingLookupInfo);
}  // namespace parallel
}  // namespace mindspore