/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/reduce_method_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace parallel {
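// Validate the user-specified strategy against the input shapes.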
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

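// The device matrix is the strategy of the single input: each tensor
// dimension is split by the corresponding strategy value.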
Status ReduceMethod::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);

  dev_matrix_shape_ = input_strategy;

  return SUCCESS;
}

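// Return the dimensions to reduce, normalized to non-negative indices. The
// axis is taken from the last input value; e.g. for a 4-d input, axis (-1, 0)
// yields dim_list [3, 0], and an empty tuple reduces all dimensions.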
std::vector<int64_t> ReduceMethod::reduce_dim() {
  std::vector<int64_t> dim_list;
  if (input_value_.size() < 2) {
    MS_LOG(EXCEPTION) << name_ << ": Input value size is smaller than 2.";
  }
  if (input_value_.back() == nullptr) {
    MS_LOG(EXCEPTION) << name_ << ": Input value is nullptr.";
  }
  MS_ASSERT(inputs_shape_.size() == 1);
  auto input_dim = inputs_shape_.at(0).size();
  if (input_value_.back()->isa<ValueTuple>()) {
    auto attr_axis = GetValue<std::vector<int64_t>>(input_value_.back());
    // axis is an empty tuple (), so reduce all dimensions
    if (attr_axis.empty()) {
      for (size_t i = 0; i < input_dim; ++i) {
        dim_list.push_back(SizeToLong(i));
      }
    } else {
      for (auto &axis : attr_axis) {
        axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
      }
    }
  } else if (input_value_.back()->isa<Int64Imm>()) {
    int64_t axis = GetValue<int64_t>(input_value_.back());
    axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
  } else {
    MS_LOG(EXCEPTION) << "Axis type is invalid.";
  }

  return dim_list;
}

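// Read the keep_dims and cross_batch attributes and propagate cross_batch to
// the cost model.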
Status ReduceMethod::GetAttrs() {
  // get the keep_dims and cross_batch attributes
  auto keep_dims_iter = attrs_.find(KEEP_DIMS);
  if (keep_dims_iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Don't have attr keep_dims.";
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(keep_dims_iter->second);
  if (!keep_dims_iter->second->isa<BoolImm>()) {
    MS_LOG(ERROR) << name_ << ": Keep_dims is not a bool.";
    return FAILED;
  }
  keepdims_ = keep_dims_iter->second->cast<BoolImmPtr>()->value();

  auto cross_batch_iter = attrs_.find(CROSS_BATCH);
  if (cross_batch_iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(cross_batch_iter->second);
    if (!cross_batch_iter->second->isa<BoolImm>()) {
      MS_LOG(ERROR) << name_ << ": cross_batch is not a bool.";
      return FAILED;
    }
    cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value();
  }
  auto reducemethodcost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost());
  if (reducemethodcost == nullptr) {
    MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!";
    return FAILED;
  }
  reducemethodcost->set_cross_batch(cross_batch_);
  return SUCCESS;
}

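// Map tensor dimensions onto the device matrix. For a 4-d input the input
// tensor map is [3, 2, 1, 0]; reduced dimensions are dropped from the output
// map, or mapped to -1 (replicated) when keep_dims is true.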
Status ReduceMethod::InferTensorMap() {
  Shape tensor_map_index, output_tensor_map;
  std::vector<int64_t> dim_list;
  size_t size = inputs_shape_.at(0).size();
  // e.g. for size 4: tensor_map_index is [3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back((int64_t)(size - 1 - i));
  }
  dim_list = reduce_dim();
  for (size_t i = 0; i < size; ++i) {
    if (find(dim_list.begin(), dim_list.end(), SizeToLong(i)) != dim_list.end()) {
      if (keepdims_) {
        output_tensor_map.push_back(-1);
      } else {
        continue;
      }
    } else {
      output_tensor_map.push_back(tensor_map_index[i]);
    }
  }
  inputs_tensor_map_.push_back(tensor_map_index);
  outputs_tensor_map_.push_back(output_tensor_map);

  return SUCCESS;
}

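// A strategy is data-parallel when its first (batch) dimension is split
// across every device in the stage.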
bool IsDataParallelStrategy(const Dimensions &strategy, int32_t stage_id) {
  CheckGlobalDeviceManager();
  size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
  if (strategy.empty()) {
    MS_LOG(EXCEPTION) << "IsDataParallelStrategy: strategy is empty";
  }

  return (LongToSize(strategy[0]) == total_dev_num);
}

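// Build the forward AllReduce for partitioned reduce dimensions: a tensor map
// is assembled from the dimensions that are not both reduced and split, and
// the group created from it spans the devices that hold partial results.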
Status ReduceMethod::InferForwardCommunication() {
  Dimensions stra = strategy_->GetInputDim().at(0);
  if (cross_batch_) {
    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
    return SUCCESS;
  }
  forward_op_.clear();
  std::vector<int64_t> dim_list = reduce_dim();
  size_t size = stra.size();
  // determine whether any reduce dimension is partitioned
  Shape group_creat_map;

  // if repeated calculation is used and repeated_calc_num_ was inserted as the first
  // dimension of the device matrix, the first dimension of the map needs special handling
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }
  for (size_t index = 0; index < size; ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      continue;
    }
    group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
  }

  // if repeated calculation is used and repeated_calc_num_ was inserted as the last
  // dimension of the device matrix, shift the map entries and append 0 as the last dimension
  if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
    for (auto &ele : group_creat_map) {
      if (ele == MAP_NONE) {
        continue;
      }
      ele += 1;
    }
    group_creat_map.push_back(0);
  }
  std::vector<Group> forward_group;
  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
    return FAILED;
  }
  if (!forward_group.empty()) {
    Operator op = CreateAllReduceOp(reduce_method_, forward_group[0].name());
    forward_op_.push_back(op);
    std::string group_name = forward_group[0].name();
    MS_LOG(INFO) << name_ << ": Forward communication group is " << group_name;
  }

  return SUCCESS;
}

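// Build the forward ops for a partitioned ReduceMean: an AllReduce-sum over
// the group followed by a RealDiv by the group size, since the mean over n
// devices equals the summed partial results divided by n.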
ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
  // Create the AllReduceSum op
  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
  std::string group_name = forward_group[0].name();
  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;

  // Create the RealDiv op; the divisor is the number of devices in the group,
  // passed as the operator's second input (position 2)
  OperatorName operator1_name = REAL_DIV;
  std::vector<Device> device_list = forward_group[0].GetDevicesList();
  auto divisor = static_cast<float>(device_list.size());
  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(divisor, dtype);
  ValuePtr op1_param_value = MakeValue(tensor_ptr);
  Attr op1_param = std::make_pair("divisor", op1_param_value);
  OperatorParams operator1_params = {std::make_pair(op1_param, 2)};
  OperatorAttrs operator1_attrs;
  OperatorArgs operator1_args = std::make_pair(operator1_attrs, operator1_params);
  Operator op1 = std::make_pair(operator1_name, operator1_args);
  ForwardOp forward_op = {op0, op1};

  std::string dtype_name = dtype->ToString();
  MS_LOG(INFO) << "The divisor of Div op is " << device_list.size() << ", the dtype is " << dtype_name;
  return forward_op;
}

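// Same group construction as ReduceMethod::InferForwardCommunication, but the
// forward op is the AllReduce-sum + RealDiv pair from CreateReduceMeanForwardOp
// so that partial sums are averaged.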
Status ReduceMeanInfo::InferForwardCommunication() {
  Dimensions stra = strategy_->GetInputDim().at(0);
  if (cross_batch_ && IsDataParallelStrategy(stra, stage_id_)) {
    MS_LOG(INFO) << name_ << ": cross_batch is True, don't need to InferForwardCommunication";
    return SUCCESS;
  }
  forward_op_.clear();
  std::vector<int64_t> dim_list = reduce_dim();
  size_t size = stra.size();
  // determine whether any reduce dimension is partitioned
  Shape group_creat_map;

  // if repeated calculation is used and repeated_calc_num_ was inserted as the first
  // dimension of the device matrix, the first dimension of the map needs special handling
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }

  for (size_t index = 0; index < size; ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      continue;
    }
    group_creat_map.push_back(SizeToLong(size) - SizeToLong(index) - 1);
  }

  // if repeated calculation is used and repeated_calc_num_ was inserted as the last
  // dimension of the device matrix, shift the map entries and append 0 as the last dimension
  if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) {
    for (auto &ele : group_creat_map) {
      if (ele == MAP_NONE) {
        continue;
      }
      ele += 1;
    }
    group_creat_map.push_back(0);
  }

  std::vector<Group> forward_group;
  if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed.";
    return FAILED;
  }
  if (!forward_group.empty()) {
    if ((outputs_dtype_ == nullptr) || !outputs_dtype_->isa<mindspore::TensorType>()) {
      MS_LOG(ERROR) << name_ << ": The dtype of output is not Array";
      return FAILED;
    }

    auto element_type = outputs_dtype_->cast<mindspore::TensorTypePtr>()->element();
    forward_op_ = CreateReduceMeanForwardOp(forward_group, element_type);
  }

  return SUCCESS;
}

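// Create mirror (gradient-synchronizing) operators for the weight input when
// the input is repeated across devices; an empty OperatorVector is appended
// as a placeholder for the axis input.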
Status ReduceMethod::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed.";
    return FAILED;
  }

  OperatorVector op_for_weight;
  OperatorVector op_for_reduce_axis;  // placeholder for the axis input
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  } else {
    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
    mirror_ops_.push_back(op_for_weight);
    mirror_ops_.push_back(op_for_reduce_axis);
    std::string group_name = input_group[0].name();
    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success, the group is " << group_name;
  }

  return SUCCESS;
}

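// For ArgMaxWithValue the axis is an attribute rather than an input, so only
// the weight's mirror ops are appended and no placeholder entry is needed.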
Status ArgMaxWithValueInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape input_tensor_map = inputs_tensor_map_.at(0);
  std::vector<Group> input_group;
  if (CreateGroupByTensorMap(input_tensor_map, &input_group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer MirrorOps failed.";
    return FAILED;
  }

  OperatorVector op_for_weight;
  if (input_group.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror ops is empty.";
    return SUCCESS;
  } else {
    op_for_weight = CreateMirrorOps(input_group[0].name(), input_group[0].GetDevNum());
    mirror_ops_.push_back(op_for_weight);
    MS_LOG(INFO) << name_ << ": Create the mirror ops for weight success.";
  }

  return SUCCESS;
}

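// Derive the output strategy from the input strategy by dropping the reduced
// dimensions, or forcing them to 1 when keep_dims is true; e.g. input
// strategy [2, 4, 1] with reduce dim 1 gives [2, 1] (or [2, 1, 1] with
// keep_dims).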
Dimensions ReduceMethod::InferOutputStrategy() {
  std::vector<int64_t> dim_list = reduce_dim();
  Dimensions output_strategy;
  Dimensions stra = strategy_->GetInputDim().at(0);
  // if keepdims_ is true, the output strategy matches the input strategy
  // except that reduced dimensions get strategy 1
  for (size_t i = 0; i < stra.size(); ++i) {
    if (find(dim_list.begin(), dim_list.end(), SizeToLong(i)) != dim_list.end()) {
      if (keepdims_) {
        output_strategy.push_back(1);
      }
    } else {
      output_strategy.push_back(stra[i]);
    }
  }
  return output_strategy;
}

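// Build the input/output TensorInfo from the device matrix, tensor maps and
// slice shapes, and record the reduce dimensions on the input tensor info.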
Status ReduceMethod::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape output_shape = outputs_shape_.at(0);

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy = InferOutputStrategy();

  Strategys outputs_strategy = {output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  Shape output_slice_shape = outputs_slice_shape.at(0);

  TensorLayout input_tensor_layout, output_tensor_layout;
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
    return FAILED;
  }

  std::vector<int64_t> dim_list = reduce_dim();
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  input_tensor_info.set_reduce_dim(dim_list);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);

  return SUCCESS;
}

Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

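// Every dimension of the single input is splittable, so candidate strategies
// are generated independently over all input dimensions.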
std::vector<StrategyPtr> ReduceMethod::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }
  return sp_vector;
}

Status ReduceMethod::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init failed.";
    return FAILED;
  }

  return SUCCESS;
}

Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Init for cost model failed";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << ": Init for cost model success";
  return SUCCESS;
}

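// Same axis normalization as ReduceMethod::reduce_dim, but the axis comes
// from the AXIS attribute rather than from an input value.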
std::vector<int64_t> ArgMaxWithValueInfo::reduce_dim() {
  std::vector<int64_t> dim_list;
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(EXCEPTION) << name_ << ": Don't have attr axis.";
  }

  MS_ASSERT(inputs_shape_.size() == 1);
  auto input_dim = inputs_shape_.at(0).size();
  MS_EXCEPTION_IF_NULL(iter->second);
  if (iter->second->isa<ValueTuple>()) {
    auto attr_axis = GetValue<std::vector<int64_t>>(iter->second);
    if (attr_axis.empty()) {
      for (size_t i = 0; i < input_dim; ++i) {
        dim_list.push_back(SizeToLong(i));
      }
    } else {
      for (auto &axis : attr_axis) {
        axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
      }
    }
  } else if (iter->second->isa<Int64Imm>()) {
    int64_t axis = GetValue<int64_t>(iter->second);
    axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis);
  } else {
    MS_LOG(EXCEPTION) << "Axis type is invalid.";
  }

  return dim_list;
}

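// Splitting the argmax axis would make each device return indices local to
// its own slice, so a warning is issued when the strategy splits that axis.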
Status ArgMaxWithValueInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed";
    return FAILED;
  }
  std::vector<int64_t> dim_list = reduce_dim();
  MS_ASSERT(dim_list.size() == 1);

  Strategys stra = strategy->GetInputDim();
  MS_ASSERT(stra.size() == 1);
  Shape input_strategy = stra.at(0);
  MS_ASSERT(dim_list.at(0) < input_strategy.size());
  if (input_strategy.at(LongToSize(dim_list.at(0))) != 1) {
    MS_LOG(WARNING) << name_
                    << ": the strategy corresponding to the axis is not 1, the real strategy is "
                    << input_strategy.at(LongToSize(dim_list.at(0)))
                    << "; the output index may be incompatible with the stand-alone Primitive.";
  }
  return SUCCESS;
}

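// ArgMaxWithValue produces two outputs (index and value) with identical
// layouts, so the single output tensor map is duplicated.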
Status ArgMaxWithValueInfo::InferTensorMap() {
  if (ReduceMethod::InferTensorMap() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed";
    return FAILED;
  }
  MS_ASSERT(outputs_tensor_map_.size() == 1);
  outputs_tensor_map_.push_back(outputs_tensor_map_[0]);
  return SUCCESS;
}

Status ArgMaxWithValueInfo::InferTensorInfo() {
  // infer tensor shape
  Shape input_shape = inputs_shape_.at(0);
  Shape output_shape = outputs_shape_.at(0);

  // infer slice shape
  Shapes inputs_slice_shape, outputs_slice_shape;
  Strategys inputs_strategy = strategy_->GetInputDim();
  Dimensions output_strategy = InferOutputStrategy();

  Strategys outputs_strategy = {output_strategy, output_strategy};
  if (InferSliceShape(inputs_strategy, outputs_strategy, &inputs_slice_shape, &outputs_slice_shape) != SUCCESS) {
    return FAILED;
  }
  Shape input_slice_shape = inputs_slice_shape.at(0);
  Shape output_slice_shape = outputs_slice_shape.at(0);

  TensorLayout input_tensor_layout, output_tensor_layout;
  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], input_shape) != SUCCESS) ||
      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], output_shape) != SUCCESS)) {
    return FAILED;
  }

  std::vector<int64_t> dim_list = reduce_dim();
  TensorInfo input_tensor_info(input_tensor_layout, input_shape, input_slice_shape);
  TensorInfo output_tensor_info(output_tensor_layout, output_shape, output_slice_shape);
  input_tensor_info.set_reduce_dim(dim_list);

  inputs_tensor_info_.push_back(input_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  outputs_tensor_info_.push_back(output_tensor_info);
  return SUCCESS;
}

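// The loss divisor is derived from output[0]'s tensor map: it is the number
// of devices on which the output is repeated.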
Status ArgMaxWithValueInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer";
  if (outputs_tensor_map_[0].empty()) {
    as_loss_divisor_ = stage_device_size_;
    MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size " << as_loss_divisor_ << " as loss divisor.";
    return SUCCESS;
  }

  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);

  std::string dev_matrix_shape_str = ShapeToString(dev_matrix_shape_);
  std::string output_tensor_map_str = ShapeToString(outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and the loss divisor are "
               << dev_matrix_shape_str << ", " << output_tensor_map_str << ", " << as_loss_divisor_;
  return SUCCESS;
}

std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};
  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
  }

  return sp_vector;
}

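// ReduceAny performs a logical OR reduction and, per the check below, does
// not support splitting the reduced dimensions across devices.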
Status ReduceAnyInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (ReduceMethod::CheckStrategy(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": checking strategy failed.";
    return FAILED;
  }
  auto dim_list = ReduceMethod::reduce_dim();
  Dimensions stra = strategy->GetInputDim().at(0);
  for (size_t index = 0; index < stra.size(); ++index) {
    auto pos =
      std::find_if(dim_list.begin(), dim_list.end(), [index](const int64_t &dim) { return SizeToLong(index) == dim; });
    if (pos != dim_list.end() && stra[index] != 1) {
      MS_LOG(ERROR) << name_
                    << ": checking strategy failed. ReduceAny operator does not support reduced dimension split.";
      return FAILED;
    }
  }
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore