/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/activation_info.h"

#include <algorithm>
#include <memory>
#include <vector>
#include <utility>
#include <functional>
#include <numeric>

#include "ir/value.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"

namespace mindspore {
namespace parallel {
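// Elementwise activations impose no constraints of their own on how the input may be
// sliced, so cost registration and strategy validation delegate to the base-class helpers.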
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }

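// Validate the operator attributes: exactly one input shape and one output shape are
// expected, and the activation_type attribute (if present) must be relu, relu6 or sigmoid.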
Status ActivationInfo::GetAttrs() {
  if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
    return FAILED;
  }

  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
                  << outputs_shape_.size() << ") is wrong.";
    return FAILED;
  }

  auto iter = attrs_.find(ACTIVATION_TYPE);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<StringImm>()) {
      std::string val = iter->second->cast<StringImmPtr>()->value();
      if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) {
        MS_LOG(ERROR) << name_ << " : Activation type is wrong.";
        return FAILED;
      }
    } else {
      MS_LOG(ERROR) << name_ << " : The value of activation_type is not a string.";
      return FAILED;
    }
  }

  return SUCCESS;
}

Status ActivationOther::GetAttrs() {
  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
                  << outputs_shape_.size() << ") is wrong.";
    return FAILED;
  }
  return SUCCESS;
}

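// Every dimension of the single input is independently splittable (all entries of
// input0_split are 1), so enumerate all strategies over the input shape.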
std::vector<StrategyPtr> Activation::GenerateOpStrategies(int64_t stage_id) {
  std::vector<StrategyPtr> sp_vector;
  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
                      << outputs_shape_.size() << ") is wrong.";
  }

  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};

  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs failed.";
  }

  return sp_vector;
}

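// Dropout likewise allows every input dimension to be split independently.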
std::vector<StrategyPtr> DropoutInfo::GenerateOpStrategies(int64_t stage_id) {
  Shape input0_split(inputs_shape_[0].size(), 1);
  Shapes splittable_inputs = {input0_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs failed.";
  }
  return sp_vector;
}

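// Besides the generic checks, the dimension(s) that softmax normalizes along must not be
// split: the strategy value for every axis in axis_ has to be 1 (MIN_SLICE_NUM).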
Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Invalid strategy.";
    return FAILED;
  }

  Strategys stra = strategy->GetInputDim();
  Dimensions input_strategy = stra.at(0);

  for (auto &element : axis_) {
    int64_t axis_index = element;
    if (element < 0) {
      size_t input_dim = inputs_shape_.at(0).size();
      axis_index = static_cast<int64_t>(input_dim) + element;
    }

    int64_t axis_strategy = input_strategy.at(LongToSize(axis_index));
    // Dimension corresponding to axis is un-splittable
    if (axis_strategy != MIN_SLICE_NUM) {
      MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
      return FAILED;
    }
  }

  return SUCCESS;
}

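// Read the axis attribute, which may be a single int64_t or a tuple of them, and verify
// that every axis lies in [-dim, dim - 1] for an input of dimension dim.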
Status Softmax::GetAttrs() {
  if (attrs_.size() < SOFTMAX_ATTR_SIZE) {
    MS_LOG(ERROR) << name_ << " : The size of attrs is smaller than 1.";
    return FAILED;
  }

  auto iter = attrs_.find(AXIS);
  if (iter != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (iter->second->isa<Int64Imm>()) {  // the axis is a number
      int64_t axis_element = iter->second->cast<Int64ImmPtr>()->value();
      axis_.push_back(axis_element);
      MS_LOG(INFO) << name_ << " : The axis is int64_t, value is " << axis_element;
    } else if (iter->second->isa<ValueTuple>()) {  // the axis is a tuple
      ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
      if (value_tuple == nullptr) {
        MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr.";
        return FAILED;
      }
      std::vector<ValuePtr> value_vector = value_tuple->value();
      (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
                           [](const ValuePtr &value) { return GetValue<int64_t>(value); });
      if (axis_.empty()) {
        MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
        return FAILED;
      }
      MS_LOG(INFO) << name_ << " : The axis is a tuple, value is " << ListToString(axis_);
    } else {
      MS_LOG(ERROR) << name_ << " : The value of axis is neither int64_t nor a tuple of int64_t.";
      return FAILED;
    }
  }

  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
    return FAILED;
  }

  // for example: if the tensor dimension is 4, then the axis range is [-4, 3]
  int64_t dim = SizeToLong(inputs_shape_.at(0).size());
  auto it =
    std::find_if(axis_.begin(), axis_.end(), [dim](int64_t element) { return ((element >= dim) || (element < -dim)); });
  if (it != axis_.end()) {
    MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << (-dim) << ", " << (dim - 1) << "].";
    return FAILED;
  }

  return SUCCESS;
}

Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

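// Same as Activation::GenerateOpStrategies, except the softmax axes are marked
// un-splittable by setting their entries in input0_split to 0.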
std::vector<StrategyPtr> Softmax::GenerateOpStrategies(int64_t stage_id) {
  if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
    MS_LOG(EXCEPTION) << name_ << " : Inputs shape size or outputs shape size is wrong.";
  }

  Shape input0_split(inputs_shape_[0].size(), 1);
  for (auto &element : axis_) {
    int64_t axis_index = element;
    if (element < 0) {
      size_t input_dim = inputs_shape_.at(0).size();
      axis_index = static_cast<int64_t>(input_dim) + element;
    }
    input0_split[LongToSize(axis_index)] = 0;
  }
  Shapes splittable_inputs = {input0_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs failed.";
  }
  return sp_vector;
}

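// For one-input elementwise ops the device matrix is simply the input strategy itself.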
Status ActivationBase::InferDevMatrixShape() {
  Strategys stra = strategy_->GetInputDim();
  Dimensions input_strategy = stra.at(0);

  dev_matrix_shape_ = input_strategy;

  return SUCCESS;
}

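// If the input is replicated across devices (the tensor map yields a non-empty group),
// create a mirror operator so that gradients are reduced within that group.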
Status ActivationBase::InferMirrorOps() {
  mirror_ops_.clear();

  Shape tensor_map = inputs_tensor_map_[0];
  std::vector<Group> group;
  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Create group failed.";
    return FAILED;
  }

  OperatorVector mirror_op;
  if (group.empty()) {
    MS_LOG(INFO) << name_ << " : The mirror ops are empty.";
    return SUCCESS;
  } else {
    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
    mirror_ops_.push_back(mirror_op);
    std::string group_name = group[0].name();
    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
  }

  return SUCCESS;
}

Status ActivationBase::InferForwardCommunication() {
  // Elementwise ops need no forward communication.
  return SUCCESS;
}

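// The output layout follows the input layout exactly: dimension i of an n-dimensional
// tensor is mapped to device-matrix dimension n - 1 - i.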
Status ActivationBase::InferTensorMap() {
  Shape tensor_map_index;
  size_t size = inputs_shape_.at(0).size();
  // such as 4: tensor_map_index [3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_index.push_back(SizeToLong(size - i - 1));
  }

  inputs_tensor_map_.push_back(tensor_map_index);
  outputs_tensor_map_.push_back(tensor_map_index);
  return SUCCESS;
}

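// Read the optional seed0/seed1 attributes; both must be int64_t when present.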
Status DropoutInfo::GetAttrs() {
  auto iter0 = attrs_.find(SEED0);
  if (iter0 != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter0->second);
    if (iter0->second->isa<Int64Imm>()) {
      seed0_ = iter0->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of seed0 is not int64_t.";
      return FAILED;
    }
  }
  auto iter1 = attrs_.find(SEED1);
  if (iter1 != attrs_.end()) {
    MS_EXCEPTION_IF_NULL(iter1->second);
    if (iter1->second->isa<Int64Imm>()) {
      seed1_ = iter1->second->cast<Int64ImmPtr>()->value();
    } else {
      MS_LOG(ERROR) << name_ << " : The value of seed1 is not int64_t.";
      return FAILED;
    }
  }
  return SUCCESS;
}

Status DropoutInfo::InferTensorMap() {
  Shape tensor_map_in;
  size_t size = inputs_shape_.at(0).size();
  // such as 4: tensor_map_index [3,2,1,0]
  for (size_t i = 0; i < size; ++i) {
    tensor_map_in.push_back(SizeToLong(size - i - 1));
  }

  inputs_tensor_map_.push_back(tensor_map_in);
  outputs_tensor_map_.push_back(tensor_map_in);
  outputs_tensor_map_.push_back(tensor_map_in);  // the dropout has two outputs
  return SUCCESS;
}

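// Dropout has two outputs (commonly the result and its mask); the loss divisor is
// computed from the tensor map of output[0].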
Status DropoutInfo::InferAsLossDivisor() {
  if (outputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty";
    return FAILED;
  }
  as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
  MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
               << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
               << ", as_loss_divisor_ is " << as_loss_divisor_;
  return SUCCESS;
}

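// When no user seed is given (seed0 == seed1 == 0) and the computation is repeated across
// devices (repeated_calc_num_ > 1), replace the op with a Dropout carrying one generated
// seed, presumably so that all repeated replicas produce the same mask.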
Status DropoutInfo::InferReplaceOps() {
  if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) {
    return SUCCESS;
  }
  int64_t seed = get_seed();
  ValuePtr new_seed0 = MakeValue(seed);
  ValuePtr new_seed1 = MakeValue(seed);
  Attr attr_seed0 = std::make_pair(SEED0, new_seed0);
  Attr attr_seed1 = std::make_pair(SEED1, new_seed1);
  Attr attr_keep_probs = std::make_pair(KEEP_PROB, attrs_[KEEP_PROB]);
  OperatorAttrs attrs = {attr_keep_probs, attr_seed0, attr_seed1};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  replace_op_ = {std::make_pair(DROPOUT, args)};
  return SUCCESS;
}

Status DropoutInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }
  (void)InferReplaceOps();

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}

Status ActivationBase::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}

Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
    return FAILED;
  }

  MS_LOG(INFO) << name_ << " : Init for cost model success.";
  return SUCCESS;
}

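// Same as ActivationBase::InferMirrorOps, but Cast has a second (value) input, so an
// empty placeholder OperatorVector is appended for it.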
Status CastInfo::InferMirrorOps() {
  mirror_ops_.clear();

  Shape tensor_map = inputs_tensor_map_[0];
  std::vector<Group> group;
  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Create group failed.";
    return FAILED;
  }

  OperatorVector mirror_op;
  OperatorVector op_for_value;
  if (group.empty()) {
    MS_LOG(INFO) << name_ << " : The mirror ops are empty.";
    return SUCCESS;
  } else {
    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
    mirror_ops_.push_back(mirror_op);
    mirror_ops_.push_back(op_for_value);
    std::string group_name = group[0].name();
    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
  }

  return SUCCESS;
}

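// The axis of ExpandDims arrives as an input value rather than an attribute. Validate
// that it lies in [-dim - 1, dim] and normalize it into positive_axis_.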
Status ExpandDimsInfo::GetAttrs() {
  if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
    MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
    return FAILED;
  }

  if (!input_value_.back()->isa<Int64Imm>()) {
    MS_LOG(ERROR) << name_ << ": The type of axis is not int64_t";
    return FAILED;
  }

  int64_t axis = GetValue<int64_t>(input_value_.back());

  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  int64_t dim = SizeToLong(inputs_shape_[0].size());
  if ((axis > dim) || (axis < -dim - 1)) {
    MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << (-dim - 1) << ", " << dim << "]";
    return FAILED;
  }

  if (axis < 0) {
    positive_axis_ = dim + axis + 1;
  } else {
    positive_axis_ = axis;
  }
  MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
  return SUCCESS;
}

Status ExpandDimsInfo::InferTensorMap() {
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  // for example: if the dimension of the input is 3 and the axis is 2,
  // then the input_tensor_map is [2, 1, 0] and the output_tensor_map is [2, 1, -1, 0]
  Shape input_tensor_map, output_tensor_map;
  size_t size = inputs_shape_[0].size();
  for (size_t i = 0; i < size; ++i) {
    input_tensor_map.push_back(SizeToLong(size - i - 1));
  }

  inputs_tensor_map_.push_back(input_tensor_map);

  output_tensor_map = input_tensor_map;
  if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(size))) {
    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
    return FAILED;
  }
  (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
  outputs_tensor_map_.push_back(output_tensor_map);

  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
  return SUCCESS;
}

Status ExpandDimsInfo::InferTensorStrategy() {
  if (strategy_ == nullptr) {
    MS_LOG(ERROR) << name_ << ": The strategy is null";
    return FAILED;
  }

  inputs_strategy_ = strategy_->GetInputDim();
  if (inputs_strategy_.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }

  Shape output_strategy = inputs_strategy_[0];
  if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(output_strategy.size()))) {
    MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
    return FAILED;
  }
  (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
  outputs_strategy_ = {output_strategy};
  return SUCCESS;
}

Status ExpandDimsInfo::InferMirrorOps() {
  mirror_ops_.clear();

  if (inputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
    return FAILED;
  }

  std::vector<Group> group;
  if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed";
    return FAILED;
  }

  if (group.empty()) {
    MS_LOG(INFO) << name_ << ": No need to create mirror ops";
    return SUCCESS;
  }

  OperatorVector mirror_op, placeholder_op;
  mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
  mirror_ops_.push_back(mirror_op);
  mirror_ops_.push_back(placeholder_op);
  MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
  return SUCCESS;
}

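// Normalize the axis tuple: an empty tuple means squeeze every dimension whose size is 1,
// and negative axes are converted to their positive equivalents.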
Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
  std::vector<int64_t> axis;
  auto axis_list = value_tuple->value();
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  Shape input_shape = inputs_shape_.at(0);
  size_t input_size = input_shape.size();
  // if the axis tuple is empty, squeeze every dimension whose size is 1.
  if (axis_list.empty()) {
    for (size_t i = 0; i < input_size; ++i) {
      if (input_shape[i] == 1) {
        axis.push_back(SizeToLong(i));
      }
    }
    axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
    return SUCCESS;
  }

  // convert negative axes to positive.
  for (auto &dim : axis_list) {
    if (!dim->isa<Int64Imm>()) {
      MS_LOG(ERROR) << name_ << ": The type of axis is not int64_t";
      return FAILED;
    }
    int64_t dim_value = GetValue<int64_t>(dim);
    int64_t positive_value = (dim_value < 0) ? (dim_value + SizeToLong(input_size)) : dim_value;
    axis.push_back(positive_value);
  }
  axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
  return SUCCESS;
}

Status SqueezeInfo::GetAttrs() {
  auto iter = attrs_.find(AXIS);
  if (iter == attrs_.end()) {
    MS_LOG(ERROR) << name_ << ": Can't find the axis attribute.";
    return FAILED;
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  auto value_tuple = iter->second->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  if (InferAxis(value_tuple) != SUCCESS) {
    return FAILED;
  }
  attrs_[AXIS] = axis_;
  return SUCCESS;
}

Status SqueezeInfo::InferReplaceOps() {
  Attr attr = std::make_pair(AXIS, axis_);
  OperatorAttrs attrs = {attr};
  OperatorParams params;
  OperatorArgs args = std::make_pair(attrs, params);
  replace_op_ = {std::make_pair(SQUEEZE, args)};
  return SUCCESS;
}

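// The output tensor map is the input tensor map with the squeezed dimensions removed.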
Status SqueezeInfo::InferTensorMap() {
  // for example: if the shape of the input is [32, 32, 1] and the axis is (2,),
  // then the input_tensor_map is [2, 1, 0] and the output_tensor_map is [2, 1]
  Shape input_tensor_map, output_tensor_map;
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  size_t size = inputs_shape_[0].size();
  std::vector<int64_t> axis = GetValue<const std::vector<int64_t>>(axis_);
  for (size_t i = 0; i < size; ++i) {
    size_t index = size - i - 1;
    auto iter = std::find(axis.begin(), axis.end(), SizeToLong(i));
    if (iter == axis.end()) {
      output_tensor_map.push_back(SizeToLong(index));
    }
    input_tensor_map.push_back(SizeToLong(index));
  }
  inputs_tensor_map_.push_back(input_tensor_map);
  outputs_tensor_map_.push_back(output_tensor_map);
  MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
               << ", and the tensor map of output is " << ShapeToString(output_tensor_map);

  return SUCCESS;
}

Status SqueezeInfo::Init(const StrategyPtr &strategy) {
  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Init failed.";
    return FAILED;
  }

  (void)InferReplaceOps();

  MS_LOG(INFO) << name_ << " : Init success.";
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore