/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/reduce_method_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

Status ReduceMethod::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);

  dev_matrix_shape_ = input_strategy;

  return SUCCESS;
}

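// For illustration: reduce_dim() normalizes the axis value into a list of non-negative dimension indices.
// With a 4-dim input, axis = -1 becomes 3, axis = (0, -2) becomes {0, 2}, and an empty tuple () expands
// to {0, 1, 2, 3}, i.e. reduce over all dims.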
std::vector<int64_t> ReduceMethod::reduce_dim() {
  std::vector<int64_t> dim_list;
  if (input_value_.size() < 2) {
    MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2.";
  }
  if (input_value_.back() == nullptr) {
    MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr.";
  }
  MS_ASSERT(inputs_shape_.size() == 1);
  auto input_dim = inputs_shape_.at(0).size();
  if (input_value_.back()->isa<ValueTuple>()) {
    auto attr_axis = GetValue<std::vector<int64_t>>(input_value_.back());
    // axis is (), reduce all dims
    if (attr_axis.empty()) {
      for (size_t i = 0; i < input_dim; ++i) {
        dim_list.push_back(SizeToLong(i));
      }
    } else {
      for (auto &axis : attr_axis) {
        axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
      }
    }
  } else if (input_value_.back()->isa<Int64Imm>()) {
    int64_t axis = GetValue<int64_t>(input_value_.back());
    axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
  } else {
    MS_LOG(EXCEPTION) << "Axis type is invalid.";
  }

  return dim_list;
}

Status ReduceMethod::GetAttrs() {
  // get attrs cross_batch and keep_dims
  auto keep_dims_iter = attrs_.find(KEEP_DIMS);
  if (keep_dims_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims.";
    return FAILED;
  }

  if (keep_dims_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
    if (!keep_dims_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool.";
      return FAILED;
    }
    keepdims_ = keep_dims_iter->second->cast<BoolImmPtr>()->value();
  }

  auto cross_batch_iter = attrs_.find(CROSS_BATCH);
  if (cross_batch_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
    if (!cross_batch_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
      return FAILED;
    }
    cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
  }
  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
  if (reducemethodcost == nullptr) {
    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
    return FAILED;
  }
  reducemethodcost->set_cross_batch(cross_batch_);
  return SUCCESS;
}

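// Worked example for the tensor-map inference below: for a 4-dim input with reduce dims {1}, the input
// tensor map is [3, 2, 1, 0]; the output tensor map is [3, 1, 0] when keep_dims is false, or [3, -1, 1, 0]
// when keep_dims is true (-1 marks the kept, unsplit dimension).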
Status ReduceMethod::InferTensorMap() {
  Shape tensor_map_index, output_tensor_map;
  std::vector<int64_t> dim_list;
  size_t size = inputs_shape_.at(0).size();
  // e.g. for a 4-dim input, tensor_map_index is [3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back((int64_t)(size - 1 - i));
  }
  dim_list = reduce_dim();
  for (size_t i = 0; i < size; ++i) {
    if (find(dim_list.begin(), dim_list.end(), SizeToLong(i)) != dim_list.end()) {
      if (keepdims_) {
        output_tensor_map.push_back(-1);
      } else {
        continue;
      }
    } else {
      output_tensor_map.push_back(tensor_map_index[i]);
    }
  }
  inputs_tensor_map_.push_back(tensor_map_index);
  outputs_tensor_map_.push_back(output_tensor_map);

  return SUCCESS;
}

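// A strategy counts as pure data parallelism when its first (batch) dimension equals the number of devices
// in the stage, e.g. strategy (8, 1, 1) on an 8-device stage.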
bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) {
  CheckGlobalDeviceManager();
  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
  if (strategy.empty()) {
    MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty";
  }

  return (LongToSize(strategy[0]) == total_dev_num);
}

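// Sketch of the forward-communication inference below: dimensions that are both reduced and split are left
// out of group_creat_map, so the created group spans exactly the device-matrix dimensions holding partial
// results and a single AllReduce (using reduce_method_) combines them. For example, with input strategy
// (2, 1, 4) and reduce dim {2}, only dim 2 is excluded and the AllReduce runs across the 4 devices that each
// hold a slice of that dim; if no reduced dim is split, the group is empty and no communication is inserted.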
Status ReduceMethod::InferForwardCommunication() {
  Dimensions stra = strategy_->GetInputDim().at(0);
  if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
    return SUCCESS;
  }
  if (cross_batch_) {
    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
    return SUCCESS;
  }
  forward_op_.clear();
  std::vector<int64_t> dim_list = reduce_dim();
  size_t size = stra.size();
  // Check whether the reduce dim is partitioned.
  Shape group_creat_map;

  // If repeated calculation is used and repeated_calc_num_ is inserted as the first dimension of the dev matrix,
  // the first dimension of the map needs to be handled.
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }
  for (size_t index = 0; index < size; ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      continue;
    }
    group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
  }

  // If repeated calculation is used and repeated_calc_num_ is inserted as the last dimension of the dev matrix,
  // shift the existing entries of group_creat_map by one and append 0 as its last dimension.
  if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
    for (auto &ele : group_creat_map) {
      if (ele == MAP_NONE) {
        continue;
      }
      ele += 1;
    }
    group_creat_map.push_back(0);
  }
  std::vector<Group> forward_group;
  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
    return FAILED;
  }
  if (!forward_group.empty()) {
    Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name());
    forward_op_.push_back(op);
    std::string group_name = forward_group[0].name();
    MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name;
  }

  return SUCCESS;
}

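// ReduceMean cannot simply AllReduce the local results. Assuming the reduced dimension is split evenly
// across the group, each device holds the mean of its own slice, so the global mean is recovered as
// AllReduceSum(local mean) / group_size; the helper below builds that AllReduceSum + RealDiv pair.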
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
  // Create AllReduceSum op
  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
  std::string group_name = forward_group[0].name();
  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;

  // Create RealDiv op
  OperatorName operator1_name = REAL_DIV;
  std::vector<Device> device_list = forward_group[0].GetDevicesList();
  auto divisor = static_cast<float>(device_list.size());
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
  ValuePtr op1_param_value = MakeValue(tensor_ptr);
  Attr op1_param = std::make_pair("divisor", op1_param_value);
  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
  OperatorAttrs operator1_attrs;
  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
  Operator op1 = std::make_pair(operator1_name, operator1_args);
  ForwardOp forward_op = {op0, op1};

  std::string dtype_name = dtype->ToString();
  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
  return forward_op;
}

Status ReduceMeanInfo::InferForwardCommunication() {
  Dimensions stra = strategy_->GetInputDim().at(0);
  if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
    return SUCCESS;
  }
  forward_op_.clear();
  std::vector<int64_t> dim_list = reduce_dim();
  size_t size = stra.size();
  // Check whether the reduce dim is partitioned.
  Shape group_creat_map;

  // If repeated calculation is used and repeated_calc_num_ is inserted as the first dimension of the dev matrix,
  // the first dimension of the map needs to be handled.
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }

  for (size_t index = 0; index < size; ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      continue;
    }
    group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
  }

  // If repeated calculation is used and repeated_calc_num_ is inserted as the last dimension of the dev matrix,
  // shift the existing entries of group_creat_map by one and append 0 as its last dimension.
  if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
    for (auto &ele : group_creat_map) {
      if (ele == MAP_NONE) {
        continue;
      }
      ele += 1;
    }
    group_creat_map.push_back(0);
  }

  std::vector<Group> forward_group;
  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
    return FAILED;
  }
  if (!forward_group.empty()) {
    if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa<mindspore::TensorType>()) {
      MS_LOG(ERROR) << name_ << ": The dtype of output is not Array";
      return FAILED;
    }

    auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
    forward_op_ = CreateReduceMeanForwardOp(forward_group, element_type);
  }

  return SUCCESS;
}

Status ReduceMethod::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " Infer MirrorOps failed.";
    return FAILED;
  }

  OperatorVector op_for_weight;
  OperatorVector op_for_reduce_axis;  // helper node
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  } else {
    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
    mirror_ops_.push_back(op_for_weight);
    mirror_ops_.push_back(op_for_reduce_axis);
    std::string group_name = input_group[0].name();
    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name;
  }

  return SUCCESS;
}

Status ArgMaxWithValueInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed.";
    return FAILED;
  }

  OperatorVector op_for_weight;
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  } else {
    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
    mirror_ops_.push_back(op_for_weight);
    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success.";
  }

  return SUCCESS;
}

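// Worked example for the output-strategy inference below: input strategy (2, 4, 1, 1) with reduce dims {1}
// yields (2, 1, 1, 1) when keep_dims is true and (2, 1, 1) when keep_dims is false.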
Dimensions ReduceMethod::InferOutputStrategy() {
  std::vector<int64_t> dim_list = reduce_dim();
  Dimensions output_strategy;
  Dimensions stra = strategy_->GetInputDim().at(0);
  // If keepdims_ is true, the reduced dims keep a strategy of 1 so the output has the same rank as the input.
  for (size_t i = 0; i < stra.size(); ++i) {
    if (find(dim_list.begin(), dim_list.end(), SizeToLong(i)) != dim_list.end()) {
      if (keepdims_) {
        output_strategy.push_back(1);
      }
    } else {
      output_strategy.push_back(stra[i]);
    }
  }
  return output_strategy;
}

Status ReduceMethod::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape output_shape = outputs_shape_.at(0);

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy = InferOutputStrategy();

  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  Shape output_slice_shape = outputs_slice_shape.at(0);

  TensorLayout input_tensor_layout, output_tensor_layout;
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
    return FAILED;
  }

  std::vector<int64_t> dim_list = reduce_dim();
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  input_tensor_info.set_reduce_dim(dim_list);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);

  return SUCCESS;
}

Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

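// Every input dimension of a reduce op may be split, so input0_split is all ones and the candidate
// strategies cover the ways of partitioning the stage devices over the input dimensions (e.g. on 4 devices
// with a 2-dim input, candidates such as (4, 1), (2, 2), (1, 4) and (1, 1) may be generated).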
std::vector<StrategyPtr> ReduceMethod::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }
  return sp_vector;
}

Status ReduceMethod::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }

  return SUCCESS;
}

Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success";
  return SUCCESS;
}

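// Unlike ReduceMethod::reduce_dim(), which reads the axis from the operator's last input value,
// ArgMaxWithValue carries the axis as the AXIS attribute; negative axes are normalized the same way
// (e.g. axis = -1 on a 4-dim input becomes 3).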
std::vector<int64_t> ArgMaxWithValueInfo::reduce_dim() {
  std::vector<int64_t> dim_list;
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis.";
  }

  MS_ASSERT(inputs_shape_.size() == 1);
  auto input_dim = inputs_shape_.at(0).size();
  MS_EXCEPTION_IF_NULL(iter->second);
  if (iter->second->isa<ValueTuple>()) {
    auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
    if (attr_axis.empty()) {
      for (size_t i = 0; i < input_dim; ++i) {
        dim_list.push_back(SizeToLong(i));
      }
    } else {
      for (auto &axis : attr_axis) {
        axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
      }
    }
  } else if (iter->second->isa<Int64Imm>()) {
    int64_t axis = GetValue<int64_t>(iter->second);
    axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
  } else {
    MS_LOG(EXCEPTION) << "Axis type is invalid.";
  }

  return dim_list;
}

Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed";
    return FAILED;
  }
  std::vector<int64_t> dim_list = reduce_dim();
  MS_ASSERT(dim_list.size() == 1);

  Strategys stra = strategy->GetInputDim();
  MS_ASSERT(stra.size() == 1);
  Shape input_strategy = stra.at(0);
  MS_ASSERT(dim_list.at(0) < input_strategy.size());
  if (input_strategy.at(LongToSize(dim_list.at(0))) != 1) {
    MS_LOG(WARNING)
      << name_
      << " CheckStrategy for ArgMaxWithValueInfo, the strategy corresponding to axis is not one, real strategy "
         "is "
      << input_strategy.at(LongToSize(dim_list.at(0)))
      << ", the output index may be not compatible with the stand alone Primitive";
  }
  return SUCCESS;
}

Status ArgMaxWithValueInfo::InferTensorMap() {
  if (ReduceMethod::InferTensorMap() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed";
    return FAILED;
  }
  MS_ASSERT(outputs_tensor_map_.size() == 1);
  outputs_tensor_map_.push_back(outputs_tensor_map_[0]);
  return SUCCESS;
}

Status ArgMaxWithValueInfo::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape output_shape = outputs_shape_.at(0);

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy = InferOutputStrategy();

  Strategys outputs_strategy = {output_strategy, output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  Shape output_slice_shape = outputs_slice_shape.at(0);

  TensorLayout input_tensor_layout, output_tensor_layout;
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
    return FAILED;
  }

  std::vector<int64_t> dim_list = reduce_dim();
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  input_tensor_info.set_reduce_dim(dim_list);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

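// ArgMaxWithValue has two outputs, and only output[0] is used to infer the loss divisor. If that output is
// a scalar, every device holds a full copy and the divisor is the stage device size; otherwise the divisor
// is the number of devices holding identical copies of the output slice, e.g. dev matrix [2, 4] with output
// tensor map [1] gives 4, assuming dims absent from the tensor map are replicated.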
Status ArgMaxWithValueInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
  if (outputs_tensor_map_[0].empty()) {
    as_loss_divisor_ = stage_device_size_;
    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << " as loss divisor.";
    return SUCCESS;
  }

  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);

  std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_);
  std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and the loss divisor are "
               << dev_matrix_shape_str << ", " << output_tensor_map_str << ", " << as_loss_divisor_;
  return SUCCESS;
}

std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }

  return sp_vector;
}

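// ReduceAny adds one constraint on top of the base check: a reduced dimension must not be split.
// For example, for an input of shape [8, 8] with axis = 1, strategy (4, 1) passes while (4, 2) is rejected.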
Status ReduceAnyInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": checking strategy failed.";
    return FAILED;
  }
  auto dim_list = ReduceMethod::reduce_dim();
  Dimensions stra = strategy->GetInputDim().at(0);
  for (size_t index = 0; index < stra.size(); ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      MS_LOG(ERROR) << name_
                    << ": checking strategy failed. ReduceAny operator does not support reduced dimension split.";
      return FAILED;
    }
  }
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore