1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/gather_v2_p_info.h"
18
19 #include <vector>
20 #include <numeric>
21 #include <functional>
22 #include <utility>
23 #include <algorithm>
24
25 #include "frontend/parallel/device_matrix.h"
26 #include "frontend/parallel/graph_util/generate_graph.h"
27 #include "frontend/parallel/context.h"
28 #if ((defined ENABLE_CPU) && (!defined _WIN32))
29 #include "ps/ps_cache/ps_cache_manager.h"
30 #include "utils/ms_context.h"
31 #endif
32
33 namespace mindspore {
34 namespace parallel {
GetManualSplitWithoutOffsetAttr()35 Status GatherPInfo::GetManualSplitWithoutOffsetAttr() {
36 auto manual_split_without_offset_iter = attrs_.find("manual_split");
37 if (manual_split_without_offset_iter != attrs_.end()) {
38 manual_split_ = true;
39 MS_EXCEPTION_IF_NULL(manual_split_without_offset_iter->second);
40 if (manual_split_without_offset_iter->second->cast<ValueTuplePtr>() == nullptr) {
41 MS_LOG(ERROR) << name_ << ": Manual split without offset strategy's format is wrong! Need ValueSequeue";
42 return FAILED;
43 }
44 std::vector<ValuePtr> value_vector = manual_split_without_offset_iter->second->cast<ValueTuplePtr>()->value();
45 MS_LOG(INFO) << name_ << ": manual split with offset is " << manual_split_without_offset_iter->second->ToString();
46
47 int64_t offset = 0;
48 for (auto &ele : value_vector) {
49 index_offsets_.push_back(offset);
50 if (!ele->isa<Int64Imm>()) {
51 MS_LOG(ERROR) << name_ << ": The element of manual split must be int64_t";
52 return FAILED;
53 }
54 int64_t param_split_shape = static_cast<int64_t>(GetValue<int64_t>(ele));
55 if (param_split_shape <= 0) {
56 MS_LOG(ERROR) << name_ << ": The value of manual split must be positive, but got " << param_split_shape;
57 return FAILED;
58 }
59 param_split_shapes_.push_back(param_split_shape);
60 offset += param_split_shape;
61 }
62 if (param_split_shapes_.empty()) {
63 MS_LOG(ERROR) << name_ << ": Failed to extract param split's split info";
64 return FAILED;
65 }
66 }
67
68 return SUCCESS;
69 }
70
GetManualSplitAttr()71 Status GatherPInfo::GetManualSplitAttr() {
72 auto manual_split_with_offset_iter = attrs_.find("manual_split_with_offset");
73 if (manual_split_with_offset_iter != attrs_.end()) {
74 manual_split_ = true;
75 auto var = manual_split_with_offset_iter->second->cast<ValueTuplePtr>();
76 if (var == nullptr) {
77 MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
78 return FAILED;
79 }
80
81 MS_LOG(INFO) << name_ << ": manual split with offset strategy " << var->ToString();
82 for (auto &ele : var->value()) {
83 if (!ele->isa<ValueSequeue>()) {
84 MS_LOG(ERROR) << name_ << ": Manual split with offset strategy's format is wrong! Need ValueSequeue";
85 return FAILED;
86 }
87 std::vector<ValuePtr> value_vector = ele->cast<ValueTuplePtr>()->value();
88 if (value_vector.size() != 2) {
89 MS_LOG(ERROR) << name_ << ": Size of manual split with offset's element must be 2";
90 return FAILED;
91 }
92 int64_t param_split_row = (GetValue<int64_t>(value_vector[0]));
93 int64_t offset = (GetValue<int64_t>(value_vector[1]));
94 if ((param_split_row <= 0) || (offset < 0)) {
95 MS_LOG(ERROR) << name_
96 << ": The value of param split shape must be positive, and the offset must larger or equal to 0";
97 return FAILED;
98 }
99 param_split_shapes_.push_back(param_split_row);
100 index_offsets_.push_back(offset);
101 }
102
103 if (param_split_shapes_.empty()) {
104 MS_LOG(ERROR) << name_ << ": Failed to extract param split with offset's split info";
105 return FAILED;
106 }
107 if (std::any_of(index_offsets_.begin(), index_offsets_.end(), [](const int64_t &offset) { return offset < 0; })) {
108 MS_LOG(ERROR) << name_ << ": Index offset must not less than 0";
109 return FAILED;
110 }
111 return SUCCESS;
112 }
113
114 if (GetManualSplitWithoutOffsetAttr() != SUCCESS) {
115 return FAILED;
116 }
117
118 return SUCCESS;
119 }
120
GetAttrs()121 Status GatherPInfo::GetAttrs() {
122 // get axis, the third input is the axis, is a ValueNode, embeddinglookup doesn't have axis.
123 if (target_ != CPU) {
124 if (input_value_.at(2) == nullptr) {
125 MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
126 return FAILED;
127 }
128 auto axis = GetValue<int64_t>(input_value_.at(2));
129 // if axis is negative then convert it to positive
130 auto params_shape = inputs_shape_.at(0);
131 if (params_shape.size() == 0) {
132 MS_LOG(ERROR) << name_ << ": params can not be a scalar!";
133 return FAILED;
134 }
135 if (axis < 0) {
136 axis += SizeToLong(inputs_shape_[0].size());
137 }
138 axis_ = axis;
139 }
140
141 auto target_iter = attrs_.find(TARGET);
142 if (target_iter != attrs_.end()) {
143 MS_EXCEPTION_IF_NULL(target_iter->second);
144 if (target_iter->second->isa<StringImm>()) {
145 target_ = target_iter->second->cast<StringImmPtr>()->value();
146 } else {
147 MS_LOG(ERROR) << name_ << ": The value of target is not a string.";
148 }
149 }
150
151 if (GetManualSplitAttr() != SUCCESS) {
152 return FAILED;
153 }
154
155 if (manual_split_ && (axis_ != 0)) {
156 MS_LOG(ERROR) << name_ << ": The axis or offset must be 0 if manual split, bug got " << axis_;
157 return FAILED;
158 }
159
160 if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) {
161 dynamic_shape_indices_ = true;
162 }
163 #if ((defined ENABLE_CPU) && (!defined _WIN32))
164 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
165 bool enable_sparse = MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_SPARSE);
166 if (ps::PsDataPrefetch::GetInstance().cache_enable() && enable_sparse) {
167 dynamic_shape_indices_ = true;
168 }
169 #endif
170 return SUCCESS;
171 }
172
CheckManualSplit(const Strategys & strategy)173 Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
174 if (strategy.size() != 2) {
175 MS_LOG(ERROR) << name_ << ": The size of strategy must be 2, but got " << strategy.size();
176 return FAILED;
177 }
178 Dimensions param_strategy = strategy[0];
179 Dimensions indices_strategy = strategy[1];
180 if (param_strategy.size() != 2 || indices_strategy.size() != 2) {
181 MS_LOG(ERROR) << name_ << ": The size of param strategy or indices strategy must be 2";
182 return FAILED;
183 }
184
185 if (indices_strategy[0] != 1) {
186 MS_LOG(ERROR) << name_ << ": The indices_strategy[0] must be 1, bug got " << indices_strategy[0];
187 return FAILED;
188 }
189
190 if (param_strategy[0] != indices_strategy[1]) {
191 MS_LOG(ERROR) << name_ << ": The param_strategy[0] must be equal to indices_strategy[1]";
192 return FAILED;
193 }
194
195 if (indices_strategy[1] != SizeToLong(param_split_shapes_.size())) {
196 MS_LOG(ERROR) << name_ << ": The indices_strategy[1] must be equal to manual split size";
197 return FAILED;
198 }
199
200 int64_t min_param_slice_row = inputs_shape_[1][1] / indices_strategy[1];
201 bool invalid = std::any_of(param_split_shapes_.begin(), param_split_shapes_.end(),
202 [&min_param_slice_row](int64_t v) { return v < min_param_slice_row; });
203 if (invalid) {
204 MS_LOG(ERROR) << name_ << ": The split value must be larger than or equal to indices slice's column num";
205 return FAILED;
206 }
207
208 if (inputs_shape_[0][0] < inputs_shape_[1][1]) {
209 MS_LOG(ERROR) << name_ << ": The param's row smaller than indices' column";
210 return FAILED;
211 }
212
213 // Don't support repeated calc
214 auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
215 if (product_p < stage_device_size_) {
216 MS_LOG(ERROR) << name_ << ": Manual split doesn't support repeated calc";
217 return FAILED;
218 }
219
220 int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0,
221 [](int64_t s, int64_t shape) { return s + shape; });
222 if (split_shape_sum != inputs_shape_[0][0]) {
223 MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]";
224 return FAILED;
225 }
226 return SUCCESS;
227 }
228
CheckSplitAxisStrategy(const StrategyPtr & strategy)229 Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
230 auto param_strategy = strategy->GetInputDim().at(0);
231 auto index_strategy = strategy->GetInputDim().at(1);
232 // param_strategy(axis) != 1, index can't be split
233 auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
234 if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
235 MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split.";
236 return FAILED;
237 }
238
239 // param_strategy(axis) != 1, and axis != 0, don't support repeated calc
240 auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
241 if ((product_p != stage_device_size_) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ != 0)) {
242 MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
243 return FAILED;
244 }
245
246 if ((product_p != stage_device_size_) && (param_strategy.at(LongToSize(axis_)) != 1) && (axis_ == 0)) {
247 if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
248 MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
249 return FAILED;
250 }
251 MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
252 }
253 return SUCCESS;
254 }
255
256 // return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
257 // otherwise return false
ShardBatchAndAxis(const Strategys & strategy) const258 bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) const {
259 if (axis_ != 0) {
260 return false;
261 }
262
263 if (strategy.size() != 2) {
264 return false;
265 }
266
267 Dimensions param_strategy = strategy[0];
268 Dimensions indices_strategy = strategy[1];
269 if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
270 return false;
271 }
272
273 if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
274 return false;
275 }
276
277 if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
278 return false;
279 }
280
281 if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
282 return false;
283 }
284
285 return true;
286 }
287
SetAttribute(const StrategyPtr & strategy)288 void GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
289 auto param_strategy = strategy->GetInputDim().at(0);
290 // axis=0, index_shape(0)%param_strategy(0) must be 0
291 Shape index_shape = inputs_shape_.at(1);
292 if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
293 MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
294 axis_split_forward_allreduce_ = true;
295 } else if (is_auto_parallel_) {
296 // in auto parallel mode, this function will be called many times, so need to reset the flags
297 axis_split_forward_allreduce_ = false;
298 }
299
300 auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
301 // Cast 1: If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
302 // parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
303 // and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
304 // can communicate normally, and dev0 to dev7 have the all parameters.
305 // Cast 2: If not repeated calculation(such as data parallel), need to set repeated num to the right,
306 // as it's easy to introduce the redistribution after or before gather operation, influencing the performance.
307 if (product_param == stage_device_size_ || product_param == 1) {
308 repeated_num_in_dev_matrix_right_ = true;
309 } else {
310 repeated_num_in_dev_matrix_right_ = false;
311 }
312 MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
313 }
314
// Validate a sharding strategy for Gather. Runs the generic checks, then the
// op-specific ones, and finally dispatches to the shard-batch-and-axis /
// manual-split / split-axis validation paths. FAILED rejects the strategy.
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
  }

  // param slice shape need 32Byte aligned
  // (8 elements per 32 bytes; a fully-split last dim of size 1 is also allowed)
  auto param_shape = inputs_shape_.at(0);
  auto param_strategy = strategy->GetInputDim().at(0);
  auto slice_shape = param_shape.at(param_shape.size() - 1) / param_strategy.at(param_strategy.size() - 1);
  if ((target_ != CPU) && (slice_shape % 8 != 0) && (slice_shape != 1)) {
    MS_LOG(ERROR) << name_ << ": Last dim of param slice shape need 32Byte aligned.";
    return FAILED;
  }

  // only support 1-dim and 2-dim param
  if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) {
    MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size();
    return FAILED;
  }

  // don't support scalar index
  if (inputs_shape_.at(1).size() == 0) {
    MS_LOG(DEBUG) << name_ << ": Don't support scalar index.";
    return FAILED;
  }

  // Special case: both the indices batch dim and the gather axis are split;
  // the forward pass then combines partial results with an allreduce.
  if (ShardBatchAndAxis(strategy->GetInputDim())) {
    shard_batch_and_axis_ = true;
    axis_split_forward_allreduce_ = true;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce";
    return SUCCESS;
  } else if (is_auto_parallel_) {
    // in auto parallel mode, this function will be called many times, so need to reset the flags
    shard_batch_and_axis_ = false;
    axis_split_forward_allreduce_ = false;
  }

  if (manual_split_) {
    if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
      return FAILED;
    }
    // when using manual_split, no need to check belowings.
    return SUCCESS;
  }

  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
  if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
    return FAILED;
  }

  if (CheckSplitAxisStrategy(strategy) != SUCCESS) {
    return FAILED;
  }

  // According to the strategy, set the private members.
  SetAttribute(strategy);

  return SUCCESS;
}
375
InferMirrorOps()376 Status GatherPInfo::InferMirrorOps() {
377 // There is no mirror operators for manual split
378 if (manual_split_) {
379 return SUCCESS;
380 }
381
382 mirror_ops_.clear();
383 Shape input_a_tensor_map = inputs_tensor_map_.at(0);
384 std::vector<Group> input_a_group;
385 if (CreateGroupByTensorMap(input_a_tensor_map, &input_a_group) != SUCCESS) {
386 MS_LOG(ERROR) << name_ << " : Create group for input a failed.";
387 return FAILED;
388 }
389
390 OperatorVector op_for_input_a, op_for_input_b, op_for_axis;
391 if (input_a_group.empty()) {
392 MS_LOG(INFO) << name_ << " : The mirror group is empty.";
393 return SUCCESS;
394 } else {
395 op_for_input_a = CreateMirrorOps(input_a_group[0].name(), input_a_group[0].GetDevNum());
396 MS_LOG(INFO) << name_ << " : Create the mirror ops for input a success, group is " << input_a_group[0].name();
397 }
398
399 mirror_ops_.push_back(op_for_input_a);
400 mirror_ops_.push_back(op_for_input_b);
401 mirror_ops_.push_back(op_for_axis);
402
403 return SUCCESS;
404 }
405
// Derive the input device matrix (dev_matrix_shape_) and the output device
// matrix (out_dev_matrix_shape_) from the strategy and the active split mode.
Status GatherPInfo::InferDevMatrixShape() {
  dev_matrix_shape_.clear();
  out_dev_matrix_shape_.clear();
  // infer input dev_matrix_shape
  auto param_strategy = strategy_->GetInputDim().at(0);
  auto index_strategy = strategy_->GetInputDim().at(1);

  if (manual_split_) {
    // Manual split: the param strategy alone describes the device layout.
    dev_matrix_shape_ = param_strategy;
    out_dev_matrix_shape_ = dev_matrix_shape_;
    return SUCCESS;
  }

  if (shard_batch_and_axis_) {
    // Indices batch split first, then the split gather axis of param.
    dev_matrix_shape_ = {index_strategy[0], param_strategy[0]};
    // if forward use reducescatter, the dev matrix is {index_strategy[0] * param_strategy[0]}
    out_dev_matrix_shape_ = dev_matrix_shape_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_
                 << ", out dev matrix is " << out_dev_matrix_shape_;
    return SUCCESS;
  }

  dev_matrix_shape_ = param_strategy;

  // param_strategy(axis) is 1
  if (param_strategy.at(LongToSize(axis_)) == 1) {
    // Axis not split: append the index strategy so indices dims are mapped too.
    dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
  }

  // infer out dev_matrix_shape
  // axis is not 0, split axis
  if (axis_ != 0 && param_strategy.at(LongToSize(axis_)) != 1) {
    // The output's leading split merges param dims 0 and axis; the axis dim
    // itself becomes 1 in the out matrix.
    for (size_t i = 1; i < param_strategy.size(); ++i) {
      if (i == LongToSize(axis_)) {
        out_dev_matrix_shape_.push_back(1);
      } else {
        out_dev_matrix_shape_.push_back(param_strategy.at(i));
      }
    }
    out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
  } else {
    out_dev_matrix_shape_ = dev_matrix_shape_;
  }
  // If the strategy does not use all devices of the stage, append/prepend the
  // repeated-calculation factor on the side chosen by SetAttribute().
  auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
  auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
  if (param_product * index_product < stage_device_size_) {
    auto repeated_calc_num = stage_device_size_ / (param_product * index_product);
    if (repeated_num_in_dev_matrix_right_) {
      out_dev_matrix_shape_.push_back(repeated_calc_num);
    } else {
      (void)out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), repeated_calc_num);
    }
  }

  return SUCCESS;
}
462
InferInputsTensorMap()463 void GatherPInfo::InferInputsTensorMap() {
464 // infer input tensor map
465 // param_strategy(axis) is not 1
466 size_t param_size = inputs_shape_.at(0).size();
467 size_t index_size = inputs_shape_.at(1).size();
468 size_t total_size = param_size + index_size;
469 Shape tensor_map_index;
470 Shape tensor_map_params;
471 auto param_strategy = strategy_->GetInputDim().at(0);
472 if (param_strategy.at(LongToSize(axis_)) != 1) {
473 tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
474 for (size_t i = 0; i < param_size; ++i) {
475 tensor_map_params.push_back(SizeToLong(param_size - i - 1));
476 }
477 } else {
478 // param_strategy(axis) is 1
479 for (size_t i = 0; i < param_size; ++i) {
480 tensor_map_params.push_back(SizeToLong(total_size - i - 1));
481 }
482 for (size_t i = 0; i < index_size; ++i) {
483 tensor_map_index.push_back(SizeToLong(index_size - i - 1));
484 }
485 }
486 inputs_tensor_map_.emplace_back(std::move(tensor_map_params));
487 inputs_tensor_map_.emplace_back(std::move(tensor_map_index));
488 }
489
// Build the output tensor map. The output shape is the param shape with the
// dim at `axis_` replaced by the full indices shape.
void GatherPInfo::InferOutputsTensorMap() {
  // infer output tensor map
  size_t param_size = inputs_shape_.at(0).size();
  size_t index_size = inputs_shape_.at(1).size();
  size_t total_size = param_size + index_size;
  Shape tensor_map_out;
  auto param_strategy = strategy_->GetInputDim().at(0);
  if (param_strategy.at(LongToSize(axis_)) == 1) {
    // param_strategy(axis) is 1
    // Dev matrix is param strategy + index strategy; the indices dims take
    // over the axis position in the output.
    for (size_t i = 0; i < param_size; ++i) {
      if (i == LongToSize(axis_)) {
        for (size_t j = 0; j < index_size; ++j) {
          tensor_map_out.push_back(SizeToLong(index_size - j - 1));
        }
      } else {
        tensor_map_out.push_back(SizeToLong(total_size - i - 1));
      }
    }
  } else {
    // param_strategy(axis) is not 1
    if (axis_ == 0) {
      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
        // the output is repeat calculation
        tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
      } else {
        // dim 0 of the output maps to the highest dev-matrix dim (param_size-1)
        tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
      }
      // the remaining indices dims are replicated
      tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
      for (size_t i = 1; i < param_size; ++i) {
        tensor_map_out.push_back(param_size - 1 - i);
      }
    } else {
      for (size_t i = 0; i < param_size; ++i) {
        if (i == LongToSize(axis_)) {
          // indices dims are replicated at the axis position
          tensor_map_out.insert(tensor_map_out.end(), index_size, MAP_NONE);
        } else {
          if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
            tensor_map_out.push_back(MAP_NONE);
          }
          tensor_map_out.push_back(SizeToLong(i));
        }
      }
    }
  }
  (void)outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
}
536
InferTensorMap()537 Status GatherPInfo::InferTensorMap() {
538 if (manual_split_) {
539 Shape param_map = {1, 0};
540 Shape indices_map = {-1, 1};
541 Shape out_map = {-1, 1, 0};
542 (void)inputs_tensor_map_.emplace_back(std::move(param_map));
543 (void)inputs_tensor_map_.emplace_back(std::move(indices_map));
544 (void)outputs_tensor_map_.emplace_back(std::move(out_map));
545 return SUCCESS;
546 }
547
548 if (shard_batch_and_axis_) {
549 Shape param_tensor_map = {0, -1};
550 Shape indices_tensor_map = {1, -1};
551 Shape out_tensor_map = {1, -1, -1};
552 (void)inputs_tensor_map_.emplace_back(std::move(param_tensor_map)); // param
553 (void)inputs_tensor_map_.emplace_back(std::move(indices_tensor_map)); // indices
554 (void)outputs_tensor_map_.emplace_back(
555 std::move(out_tensor_map)); // output, if forward use reducescatter, tensormap is {0, -1, -1}
556 return SUCCESS;
557 }
558 InferInputsTensorMap();
559 InferOutputsTensorMap();
560 return SUCCESS;
561 }
562
// Build TensorInfo/TensorLayout for both inputs and the output from the
// previously inferred device matrices and tensor maps.
Status GatherPInfo::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape input_index_shape = inputs_shape_.at(1);
  Shape output_shape = outputs_shape_.at(0);
  int64_t rank = g_device_manager->rank_index_in_stage();
  // infer tensor layout
  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
  if (manual_split_) {
    // Manual split slices are non-uniform. Reconstruct a per-rank param row
    // count: this rank's slice rows (dev_matrix_shape_[1] devices share one
    // slice) scaled by the dim-0 device count — presumably so the layout's
    // division works out; the layout is marked non-uniform below.
    input_shape[0] = param_split_shapes_[LongToSize(rank / dev_matrix_shape_[1])];
    input_shape[0] = input_shape[0] * dev_matrix_shape_[0];
  }
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) !=
       SUCCESS)) {
    return FAILED;
  }

  if (manual_split_) {
    // Slices differ in size, so the split is not uniform.
    input_tensor_layout.set_uniform_split(false);
  }
  // infer tensor info
  TensorInfo input_tensor_info(input_tensor_layout);
  TensorInfo input_index_info(input_index_layout);
  TensorInfo output_tensor_info(output_tensor_layout);

  inputs_tensor_info_.push_back(input_tensor_info);
  inputs_tensor_info_.push_back(input_index_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}
595
// Compute this rank's slice size (slice_size_) and starting offset (bias_)
// along the gather axis. Used by the replace graph / EmbeddingLookup offset.
Status GatherPInfo::InferBias() {
  CheckGlobalDeviceManager();
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  auto params_strategy = strategy_->GetInputDim().at(0);

  if (shard_batch_and_axis_) {
    // Dev matrix is {index_strategy[0], param_strategy[0]}, so the position
    // along the split axis is rank % param_strategy[0].
    slice_size_ = input_shape[0] / params_strategy[0];
    bias_ = rank % params_strategy[0] * slice_size_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
                 << ", bias is " << bias_;
    return SUCCESS;
  }

  // axis don't split
  if (params_strategy.at(LongToSize(axis_)) == 1) {
    bias_ = 0;
    return SUCCESS;
  }
  // params_size=1, axis=0
  if ((input_shape.size() == 1) && (axis_ == 0)) {
    slice_size_ = input_shape.at(0) / params_strategy.at(0);
    // if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
    if (repeated_calc_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calc_num_;
      } else {
        rank = rank % params_strategy[0];
      }
    }
    bias_ = rank * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=0
  if ((input_shape.size() == 2) && (axis_ == 0)) {
    slice_size_ = input_shape.at(0) / params_strategy.at(0);
    // if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num
    if (repeated_calc_num_ > 1) {
      if (repeated_num_in_dev_matrix_right_) {
        rank = rank / repeated_calc_num_;
      } else {
        rank = rank % (params_strategy[0] * params_strategy[1]);
      }
    }
#if ((defined ENABLE_CPU) && (!defined _WIN32))
    // With the parameter-server cache enabled, the offset comes from the
    // cache manager instead of the rank arithmetic above.
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound());
      return SUCCESS;
    }
#endif
    bias_ = rank / params_strategy.at(1) * slice_size_;
    return SUCCESS;
  }
  // params_size=2, axis=1
  if ((input_shape.size() == 2) && (axis_ == 1)) {
    slice_size_ = input_shape.at(1) / params_strategy.at(1);
    bias_ = rank % params_strategy.at(1) * slice_size_;
    return SUCCESS;
  }
  MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_;
  return FAILED;
}
658
InferOffset()659 Status GatherPInfo::InferOffset() {
660 CheckGlobalDeviceManager();
661 size_t rank = LongToSize(g_device_manager->rank_index_in_stage());
662
663 MS_EXCEPTION_IF_NULL(strategy_);
664 auto param_strategy = strategy_->GetInputDim()[0];
665 if (param_strategy.size() != 2) {
666 MS_LOG(ERROR) << "The size of param strategy must be 2";
667 return FAILED;
668 }
669 size_t index = rank / LongToSize(param_strategy[1]);
670 if (index < index_offsets_.size()) {
671 index_offset_ = index_offsets_[index];
672 MS_LOG(INFO) << name_ << ": Device rank " << rank << ", Index Offset: " << index_offset_;
673 return SUCCESS;
674 }
675
676 MS_LOG(ERROR) << name_ << ": Get index offset failed, index offset size is" << index_offsets_.size();
677 return FAILED;
678 }
679
InferGroup()680 Status GatherPInfo::InferGroup() {
681 size_t dim = LongToSize(axis_);
682
683 int64_t rank = g_device_manager->global_rank();
684 DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
685 RankList group_devices;
686
687 // the dev_matrix[0] is repeated_calc_num, so the dim need to add 1
688 if ((repeated_calc_num_ > 1) && !repeated_num_in_dev_matrix_right_) {
689 dim = dim + 1;
690 }
691
692 if (shard_batch_and_axis_) {
693 dim = 1;
694 MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
695 }
696
697 if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
698 MS_LOG(ERROR) << name_ << ": Create group failed.";
699 return FAILED;
700 }
701 if (group_devices.size() == 1) {
702 MS_LOG(INFO) << name_ << ": The group is empty";
703 return SUCCESS;
704 }
705
706 MS_LOG(INFO) << name_ << ": The group ranks is " << group_devices;
707 group_ = g_device_manager->CreateGroup(group_devices);
708 return SUCCESS;
709 }
710
InferForwardCommunication()711 Status GatherPInfo::InferForwardCommunication() {
712 if (manual_split_) {
713 return SUCCESS;
714 }
715
716 forward_op_.clear();
717 auto param_strategy = strategy_->GetInputDim().at(0);
718 // don't split axis or target is not CPU, no need forward communication
719 if (target_ != CPU || param_strategy.at(LongToSize(axis_)) == 1) {
720 return SUCCESS;
721 }
722 // split axis
723 OperatorName operator_name;
724 if (InferGroup() != SUCCESS) {
725 MS_LOG(ERROR) << name_ << ": Infer Group failed.";
726 return FAILED;
727 }
728 Attr attr_group;
729 operator_name = REDUCE_SCATTER;
730 if (InferGroup() != SUCCESS) {
731 MS_LOG(ERROR) << name_ << ": Infer Group failed.";
732 return FAILED;
733 }
734 if (group_.name().empty()) {
735 return SUCCESS;
736 }
737 attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
738 Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
739 OperatorAttrs attrs = {attr_op, attr_group};
740 OperatorParams params;
741 OperatorArgs args = std::make_pair(attrs, params);
742 Operator op = std::make_pair(operator_name, args);
743
744 forward_op_.push_back(op);
745 return SUCCESS;
746 }
747
// Rewrite the gather into a subgraph operating on this rank's param slice:
// shift indices into the local range, gather, mask out-of-slice rows, then
// combine the partial results across ranks (AllReduce or ReduceScatter).
Status GatherPInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  GenerateGraph gen_g = GenerateGraph(attrs_);
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << "GenerateGraph Init failed";
    return FAILED;
  }
  if (manual_split_ && target_ != CPU) {
    if (InferOffset() != SUCCESS) {
      // NOTE(review): InferOffset failed here, but the message says "Infer Bias".
      MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
      return FAILED;
    }
    // Manual split: just shift indices by this rank's offset and gather.
    auto sub_node =
      gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(index_offset_)});
    auto gather_v2_node =
      gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), sub_node, CreatInt64Imm(axis_)});
    std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub_node, 2),
                                                               std::make_pair(gather_v2_node, 1)};
    replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
      std::make_pair(input_nodes, gather_v2_node));
    return SUCCESS;
  }
  if (InferBias() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Bias failed.";
    return FAILED;
  }
  MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage() << ", the bias is " << bias_;
  // local index = indices - bias, clamped to [0, slice_size-1] via relu+minimum
  auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)});
  auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub});
  auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)});
  // equal is 1 exactly where the original index fell inside this rank's slice
  auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), sub, minimum});
  auto gather_v2 =
    gen_g.PushBack({gen_g.NewOpInst(replace_op_name_), gen_g.virtual_input_node(), minimum, CreatInt64Imm(axis_)});
  auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2});
  auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
  // broadcast the 0/1 mask over the gathered values (2-D param case)
  auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
  auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
  // don't need expand dim, if param_size = 1
  if (inputs_shape_.at(0).size() == 1) {
    mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
  }
  if (InferGroup() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer Group failed.";
    return FAILED;
  }
  Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM));
  Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
  OperatorAttrs attrs = {attr_op, attr_group};
  AnfNodePtr reduce_op;
  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
  } else {
    reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
  }
  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, reduce_op));

  return SUCCESS;
}
807
replace_graph(const CNodePtr & cnode)808 ReplaceGraphPtr GatherPInfo::replace_graph(const CNodePtr &cnode) {
809 if (manual_split_ && target_ != CPU) {
810 if (ComputeReplaceGraph(cnode) != SUCCESS) {
811 MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
812 }
813 return replace_graph_;
814 }
815
816 auto param_strategy = strategy_->GetInputDim().at(0);
817 // target_ == CPU, no need to replace graph
818 if (target_ == CPU) {
819 return nullptr;
820 }
821 if (param_strategy.at(LongToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
822 MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed.";
823 }
824 return replace_graph_;
825 }
826
ComputeReplaceOp()827 Status GatherPInfo::ComputeReplaceOp() {
828 int64_t bias = 0;
829 if (manual_split_) {
830 if (InferOffset() != SUCCESS) {
831 MS_LOG(ERROR) << name_ << ": Infer offset failed.";
832 return FAILED;
833 }
834 bias = index_offset_;
835 } else {
836 if (InferBias() != SUCCESS) {
837 MS_LOG(ERROR) << name_ << ": Infer offset failed.";
838 return FAILED;
839 }
840 bias = bias_;
841 }
842
843 OperatorName op_name = EMBEDDING_LOOKUP;
844 OperatorAttrs attrs;
845 Attr param_offset = std::make_pair("offset", MakeValue(bias));
846 OperatorParams params = {std::make_pair(param_offset, 3)};
847 OperatorArgs args = std::make_pair(attrs, params);
848 Operator op = std::make_pair(op_name, args);
849 replace_op_.push_back(op);
850
851 return SUCCESS;
852 }
853
Init(const StrategyPtr & strategy)854 Status GatherPInfo::Init(const StrategyPtr &strategy) {
855 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
856 MS_LOG(ERROR) << name_ << ": Init failed.";
857 return FAILED;
858 }
859 // only target_ == CPU, we need to replace op
860 if (target_ == CPU && ComputeReplaceOp() != SUCCESS) {
861 MS_LOG(ERROR) << name_ << ": ComputeReplaceOp failed.";
862 }
863 MS_LOG(INFO) << name_ << ": Init success.";
864 return SUCCESS;
865 }
866
InitForCostModel(const StrategyPtr & strategy)867 Status GatherPInfo::InitForCostModel(const StrategyPtr &strategy) {
868 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
869 if (is_auto_parallel_) {
870 MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
871 } else {
872 MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
873 }
874 return FAILED;
875 }
876 auto param_strategy = strategy_->GetInputDim().at(0);
877 // cost model set axis and strategy
878 auto gatherv2_2cost = std::dynamic_pointer_cast<GatherV2PCost>(operator_cost());
879 gatherv2_2cost->set_axis(axis_);
880 gatherv2_2cost->set_strategy(param_strategy);
881 MS_LOG(INFO) << name_ << ": Init for cost model success.";
882 return SUCCESS;
883 }
884
SetCostUnderStrategy(const StrategyPtr & strategy)885 Status GatherPInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
886
GenerateOpStrategies(int64_t stage_id)887 std::vector<StrategyPtr> GatherPInfo::GenerateOpStrategies(int64_t stage_id) {
888 if (manual_split_) {
889 MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to search strategy";
890 }
891 is_auto_parallel_ = true;
892 Shape input0_split(inputs_shape_[0].size(), 1);
893 Shape input1_split(inputs_shape_[1].size(), 1);
894 Shapes splittable_inputs = {input0_split, input1_split};
895
896 std::vector<StrategyPtr> sp_vector;
897 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
898 MS_LOG(EXCEPTION) << name_ << ": Generate strategies for independent inputs() failed.";
899 }
900 return sp_vector;
901 }
902
GenerateBatchStrategies()903 std::shared_ptr<Strategys> GatherPInfo::GenerateBatchStrategies() {
904 if (GetAttrs() != SUCCESS) {
905 MS_LOG(EXCEPTION) << name_ << ": Get attr failed";
906 }
907 if (manual_split_) {
908 MS_LOG(EXCEPTION) << name_ << ": Manual split does not support to generate batch strategy";
909 }
910
911 Dimensions param_strategy(inputs_shape_[0].size(), 1);
912 Dimensions index_strategy;
913 index_strategy.push_back(stage_device_size_);
914 for (size_t i = 1; i < inputs_shape_[1].size(); i++) {
915 index_strategy.push_back(1);
916 }
917 Strategys strategy_v = {param_strategy, index_strategy};
918 return std::make_shared<Strategys>(strategy_v);
919 }
920 } // namespace parallel
921 } // namespace mindspore
922