1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/activation_info.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <vector>
22 #include <utility>
23 #include <functional>
24 #include <numeric>
25
26 #include "ir/value.h"
27 #include "frontend/parallel/auto_parallel/costmodel.h"
28 #include "frontend/parallel/device_matrix.h"
29 #include "frontend/parallel/strategy.h"
30
31 namespace mindspore {
32 namespace parallel {
SetCostUnderStrategy(const StrategyPtr & strategy)33 Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
34
CheckStrategy(const StrategyPtr & strategy)35 Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
36
GetAttrs()37 Status ActivationInfo::GetAttrs() {
38 if (attrs_.size() < ACTIVATION_ATTR_SIZE) {
39 MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
40 return FAILED;
41 }
42
43 if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
44 MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
45 << outputs_shape_.size() << "is wrong.";
46 return FAILED;
47 }
48
49 auto iter = attrs_.find(ACTIVATION_TYPE);
50 if (iter != attrs_.end()) {
51 MS_EXCEPTION_IF_NULL(iter->second);
52 if (iter->second->isa<StringImm>()) {
53 std::string val = iter->second->cast<StringImmPtr>()->value();
54 if ((val != RELU_TYPE) && (val != RELU6_TYPE) && (val != SIGMOID_TYPE)) {
55 MS_LOG(ERROR) << name_ << " : Activation type is wrong.";
56 return FAILED;
57 }
58 } else {
59 MS_LOG(ERROR) << name_ << " : The value of activation_type is not string.";
60 return FAILED;
61 }
62 }
63
64 return SUCCESS;
65 }
66
GetAttrs()67 Status ActivationOther::GetAttrs() {
68 if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
69 MS_LOG(ERROR) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
70 << outputs_shape_.size() << "is wrong.";
71 return FAILED;
72 }
73 return SUCCESS;
74 }
75
GenerateOpStrategies(int64_t stage_id)76 std::vector<StrategyPtr> Activation::GenerateOpStrategies(int64_t stage_id) {
77 std::vector<StrategyPtr> sp_vector;
78 if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
79 MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
80 << outputs_shape_.size() << "is wrong.";
81 }
82
83 Shape input0_split(inputs_shape_[0].size(), 1);
84 Shapes splittable_inputs = {input0_split};
85
86 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
87 MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
88 }
89
90 return sp_vector;
91 }
92
GenerateOpStrategies(int64_t stage_id)93 std::vector<StrategyPtr> DropoutInfo::GenerateOpStrategies(int64_t stage_id) {
94 Shape input0_split(inputs_shape_[0].size(), 1);
95 Shapes splittable_inputs = {input0_split};
96
97 std::vector<StrategyPtr> sp_vector;
98 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
99 MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
100 }
101 return sp_vector;
102 }
103
CheckStrategy(const StrategyPtr & strategy)104 Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
105 if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
106 MS_LOG(ERROR) << name_ << " : Invalid strategy.";
107 return FAILED;
108 }
109
110 Strategys stra = strategy->GetInputDim();
111 Dimensions input_strategy = stra.at(0);
112
113 for (auto &element : axis_) {
114 int64_t axis_index = element;
115 if (element < 0) {
116 size_t input_dim = inputs_shape_.at(0).size();
117 axis_index = static_cast<int64_t>(input_dim) + element;
118 }
119
120 int64_t axis_strategy = input_strategy.at(LongToSize(axis_index));
121 // Dimension corresponding to axis is un-splittable
122 if (axis_strategy != MIN_SLICE_NUM) {
123 MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
124 return FAILED;
125 }
126 }
127
128 return SUCCESS;
129 }
130
GetAttrs()131 Status Softmax::GetAttrs() {
132 if (attrs_.size() < SOFTMAX_ATTR_SIZE) {
133 MS_LOG(ERROR) << name_ << " : The size of attrs small than 1.";
134 return FAILED;
135 }
136
137 auto iter = attrs_.find(AXIS);
138 if (iter != attrs_.end()) {
139 MS_EXCEPTION_IF_NULL(iter->second);
140 if (iter->second->isa<Int64Imm>()) { // the axis is a number
141 int64_t axis_element = iter->second->cast<Int64ImmPtr>()->value();
142 axis_.push_back(axis_element);
143 MS_LOG(INFO) << name_ << " : The axis is int64_t, value is " << axis_element;
144 } else if (iter->second->isa<ValueTuple>()) { // the axis is a tuple
145 ValueTuplePtr value_tuple = iter->second->cast<ValueTuplePtr>();
146 if (value_tuple == nullptr) {
147 MS_LOG(ERROR) << name_ << " : The value_tuple is nullptr.";
148 return FAILED;
149 }
150 std::vector<ValuePtr> value_vector = value_tuple->value();
151 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_),
152 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
153 if (axis_.empty()) {
154 MS_LOG(ERROR) << name_ << " : The axis tuple is empty.";
155 return FAILED;
156 }
157 MS_LOG(INFO) << name_ << " : The axis is tuple, value is " << ListToString(axis_);
158 } else {
159 MS_LOG(ERROR) << name_ << " : The value of axis is not int64_t or tuple int64_t.";
160 return FAILED;
161 }
162 }
163
164 if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
165 MS_LOG(ERROR) << name_ << " : Inputs shape size or outputs shape size is wrong.";
166 return FAILED;
167 }
168
169 // for example: tensor dimension is 4, then axis range [-4, 3]
170 int64_t dim = SizeToLong(inputs_shape_.at(0).size());
171 auto it =
172 std::find_if(axis_.begin(), axis_.end(), [dim](int64_t element) { return ((element >= dim) || (element < -dim)); });
173 if (it != axis_.end()) {
174 MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << (-dim) << ", " << (dim - 1) << "].";
175 return FAILED;
176 }
177
178 return SUCCESS;
179 }
180
SetCostUnderStrategy(const StrategyPtr & strategy)181 Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
182
GenerateOpStrategies(int64_t stage_id)183 std::vector<StrategyPtr> Softmax::GenerateOpStrategies(int64_t stage_id) {
184 if ((inputs_shape_.size() != ACTIVATION_INPUTS_SIZE) || (outputs_shape_.size() != ACTIVATION_OUTPUTS_SIZE)) {
185 MS_LOG(EXCEPTION) << name_ << " : Inputs shape size or outputs shape size is wrong.";
186 }
187
188 Shape input0_split;
189 (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
190 for (auto &element : axis_) {
191 int64_t axis_index = element;
192 if (element < 0) {
193 size_t input_dim = inputs_shape_.at(0).size();
194 axis_index = static_cast<int64_t>(input_dim) + element;
195 }
196 input0_split[LongToSize(axis_index)] = 0;
197 }
198 Shapes splittable_inputs = {input0_split};
199
200 std::vector<StrategyPtr> sp_vector;
201 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
202 MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs failed.";
203 }
204 return sp_vector;
205 }
206
InferDevMatrixShape()207 Status ActivationBase::InferDevMatrixShape() {
208 Strategys stra = strategy_->GetInputDim();
209 Dimensions input_strategy = stra.at(0);
210
211 dev_matrix_shape_ = input_strategy;
212
213 return SUCCESS;
214 }
215
InferMirrorOps()216 Status ActivationBase::InferMirrorOps() {
217 mirror_ops_.clear();
218
219 Shape tensor_map = inputs_tensor_map_[0];
220 std::vector<Group> group;
221 if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
222 MS_LOG(ERROR) << name_ << " : Create group failed.";
223 return FAILED;
224 }
225
226 OperatorVector mirror_op;
227 if (group.empty()) {
228 MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
229 return SUCCESS;
230 } else {
231 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
232 mirror_ops_.push_back(mirror_op);
233 std::string group_name = group[0].name();
234 MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
235 }
236
237 return SUCCESS;
238 }
239
InferForwardCommunication()240 Status ActivationBase::InferForwardCommunication() {
241 // do nothing
242 return SUCCESS;
243 }
244
InferTensorMap()245 Status ActivationBase::InferTensorMap() {
246 Shape tensor_map_index;
247 size_t size = inputs_shape_.at(0).size();
248 // such as 4: tensor_map_index [3,2,1,0]
249 for (size_t i = 0; i < size; ++i) {
250 tensor_map_index.push_back((int64_t)(size - i - 1));
251 }
252
253 inputs_tensor_map_.push_back(tensor_map_index);
254 outputs_tensor_map_.push_back(tensor_map_index);
255 return SUCCESS;
256 }
257
GetAttrs()258 Status DropoutInfo::GetAttrs() {
259 auto iter0 = attrs_.find(SEED0);
260 if (iter0 != attrs_.end()) {
261 MS_EXCEPTION_IF_NULL(iter0->second);
262 if (iter0->second->isa<Int64Imm>()) {
263 seed0_ = iter0->second->cast<Int64ImmPtr>()->value();
264 } else {
265 MS_LOG(ERROR) << name_ << " : The value of seed0 is not int64_t.";
266 return FAILED;
267 }
268 }
269 auto iter1 = attrs_.find(SEED1);
270 if (iter1 != attrs_.end()) {
271 MS_EXCEPTION_IF_NULL(iter1->second);
272 if (iter1->second->isa<Int64Imm>()) {
273 seed1_ = iter1->second->cast<Int64ImmPtr>()->value();
274 } else {
275 MS_LOG(ERROR) << name_ << " : The value of seed1 is not int64_t.";
276 return FAILED;
277 }
278 }
279 return SUCCESS;
280 }
281
InferTensorMap()282 Status DropoutInfo::InferTensorMap() {
283 Shape tensor_map_in;
284 size_t size = inputs_shape_.at(0).size();
285 // such as 4: tensor_map_index [3,2,1,0]
286 for (size_t i = 0; i < size; ++i) {
287 tensor_map_in.push_back((int64_t)(size - i - 1));
288 }
289
290 inputs_tensor_map_.push_back(tensor_map_in);
291 outputs_tensor_map_.push_back(tensor_map_in);
292 outputs_tensor_map_.push_back(tensor_map_in); // the dropout has two outputs
293 return SUCCESS;
294 }
295
InferAsLossDivisor()296 Status DropoutInfo::InferAsLossDivisor() {
297 if (outputs_tensor_map_.empty()) {
298 MS_LOG(ERROR) << name_ << ": The size of outputs tensor map is empty";
299 return FAILED;
300 }
301 as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
302 MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
303 << ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
304 << ", as_loss_divisor_ is " << as_loss_divisor_;
305 return SUCCESS;
306 }
307
InferReplaceOps()308 Status DropoutInfo::InferReplaceOps() {
309 if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) {
310 return SUCCESS;
311 }
312 int64_t seed = get_seed();
313 ValuePtr new_seed0 = MakeValue(seed);
314 ValuePtr new_seed1 = MakeValue(seed);
315 Attr attr_seed0 = std::make_pair(SEED0, new_seed0);
316 Attr attr_seed1 = std::make_pair(SEED1, new_seed1);
317 Attr attr_keep_probs = std::make_pair(KEEP_PROB, attrs_[KEEP_PROB]);
318 OperatorAttrs attrs = {attr_keep_probs, attr_seed0, attr_seed1};
319 OperatorParams params;
320 OperatorArgs args = std::make_pair(attrs, params);
321 replace_op_ = {std::make_pair(DROPOUT, args)};
322 return SUCCESS;
323 }
324
Init(const StrategyPtr & strategy)325 Status DropoutInfo::Init(const StrategyPtr &strategy) {
326 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
327 MS_LOG(ERROR) << name_ << " : Init failed";
328 return FAILED;
329 }
330 (void)InferReplaceOps();
331
332 MS_LOG(INFO) << name_ << " : Init success";
333 return SUCCESS;
334 }
335
Init(const StrategyPtr & strategy)336 Status ActivationBase::Init(const StrategyPtr &strategy) {
337 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
338 MS_LOG(ERROR) << name_ << " : Init failed.";
339 return FAILED;
340 }
341
342 MS_LOG(INFO) << name_ << " : Init success.";
343 return SUCCESS;
344 }
345
InitForCostModel(const StrategyPtr & strategy)346 Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
347 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
348 MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
349 return FAILED;
350 }
351
352 MS_LOG(INFO) << name_ << " : Init for cost model success.";
353 return SUCCESS;
354 }
355
InferMirrorOps()356 Status CastInfo::InferMirrorOps() {
357 mirror_ops_.clear();
358
359 Shape tensor_map = inputs_tensor_map_[0];
360 std::vector<Group> group;
361 if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
362 MS_LOG(ERROR) << name_ << " : Create group failed.";
363 return FAILED;
364 }
365
366 OperatorVector mirror_op;
367 OperatorVector op_for_value;
368 if (group.empty()) {
369 MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
370 return SUCCESS;
371 } else {
372 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
373 mirror_ops_.push_back(mirror_op);
374 mirror_ops_.push_back(op_for_value);
375 std::string group_name = group[0].name();
376 MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
377 }
378
379 return SUCCESS;
380 }
381
GetAttrs()382 Status ExpandDimsInfo::GetAttrs() {
383 if (input_value_.size() != EXPANDDIMS_INPUT_SIZE) {
384 MS_LOG(ERROR) << name_ << ": Invalid inputs size " << input_value_.size();
385 return FAILED;
386 }
387
388 if (!input_value_.back()->isa<Int64Imm>()) {
389 MS_LOG(ERROR) << name_ << ": The type of axis is not int64_t";
390 return FAILED;
391 }
392
393 int64_t axis = GetValue<int64_t>(input_value_.back());
394
395 if (inputs_shape_.empty()) {
396 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
397 return FAILED;
398 }
399
400 int64_t dim = SizeToLong(inputs_shape_[0].size());
401 if ((axis > dim) || (axis < -dim - 1)) {
402 MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << (-dim - 1) << ", " << dim << "]";
403 return FAILED;
404 }
405
406 if (axis < 0) {
407 positive_axis_ = dim + axis + 1;
408 } else {
409 positive_axis_ = axis;
410 }
411 MS_LOG(INFO) << name_ << ": The axis is " << axis << ", and the positive axis is " << positive_axis_;
412 return SUCCESS;
413 }
414
InferTensorMap()415 Status ExpandDimsInfo::InferTensorMap() {
416 if (inputs_shape_.empty()) {
417 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
418 return FAILED;
419 }
420
421 // for example: if the dimension of input is 3, and the axis is 2,
422 // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1, -1, 0]
423 Shape input_tensor_map, output_tensor_map;
424 size_t size = inputs_shape_[0].size();
425 for (size_t i = 0; i < size; ++i) {
426 input_tensor_map.push_back(SizeToLong(size - i - 1));
427 }
428
429 inputs_tensor_map_.push_back(input_tensor_map);
430
431 output_tensor_map = input_tensor_map;
432 if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(size))) {
433 MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
434 return FAILED;
435 }
436 (void)output_tensor_map.insert(output_tensor_map.begin() + positive_axis_, NO_SPLIT_MAP);
437 outputs_tensor_map_.push_back(output_tensor_map);
438
439 MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
440 << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
441 return SUCCESS;
442 }
443
InferTensorStrategy()444 Status ExpandDimsInfo::InferTensorStrategy() {
445 if (strategy_ == nullptr) {
446 MS_LOG(ERROR) << name_ << ": The strategy is null";
447 return FAILED;
448 }
449
450 inputs_strategy_ = strategy_->GetInputDim();
451 if (inputs_strategy_.empty()) {
452 MS_LOG(ERROR) << name_ << ": The strategy is empty";
453 return FAILED;
454 }
455
456 Shape output_strategy = inputs_strategy_[0];
457 if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(output_strategy.size()))) {
458 MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
459 return FAILED;
460 }
461 (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
462 outputs_strategy_ = {output_strategy};
463 return SUCCESS;
464 }
465
InferMirrorOps()466 Status ExpandDimsInfo::InferMirrorOps() {
467 mirror_ops_.clear();
468
469 if (inputs_tensor_map_.empty()) {
470 MS_LOG(ERROR) << name_ << ": The tensor map of inputs is empty";
471 return FAILED;
472 }
473
474 std::vector<Group> group;
475 if (CreateGroupByTensorMap(inputs_tensor_map_[0], &group) != SUCCESS) {
476 MS_LOG(ERROR) << name_ << ": Create group failed";
477 return FAILED;
478 }
479
480 if (group.empty()) {
481 MS_LOG(INFO) << name_ << ": No need to create mirror ops";
482 return SUCCESS;
483 }
484
485 OperatorVector mirror_op, placeholder_op;
486 mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
487 mirror_ops_.push_back(mirror_op);
488 mirror_ops_.push_back(placeholder_op);
489 MS_LOG(INFO) << name_ << ": Create mirror ops success, the group name is " << group[0].name();
490 return SUCCESS;
491 }
492
InferAxis(const ValueTuplePtr & value_tuple)493 Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) {
494 std::vector<int64_t> axis;
495 auto axis_list = value_tuple->value();
496 if (inputs_shape_.empty()) {
497 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
498 return FAILED;
499 }
500 Shape input_shape = inputs_shape_.at(0);
501 size_t input_size = input_shape.size();
502 // if axis tuple is empty, we should exclude the axis that the corresponding slice shape is 1.
503 if (axis_list.empty()) {
504 for (size_t i = 0; i < input_size; ++i) {
505 if (input_shape[i] == 1) {
506 axis.push_back(i);
507 }
508 }
509 axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
510 return SUCCESS;
511 }
512
513 // convert negative axis to positive.
514 for (auto &dim : axis_list) {
515 if (!dim->isa<Int64Imm>()) {
516 MS_LOG(ERROR) << name_ << ": The type of axis is not int64_t";
517 return FAILED;
518 }
519 int64_t dim_value = GetValue<int64_t>(dim);
520 int64_t positive_value = (dim_value < 0) ? (dim_value + SizeToLong(input_size)) : dim_value;
521 axis.push_back(positive_value);
522 }
523 axis_ = MakeValue(axis)->cast<ValueTuplePtr>();
524 return SUCCESS;
525 }
526
GetAttrs()527 Status SqueezeInfo::GetAttrs() {
528 auto iter = attrs_.find(AXIS);
529 if (iter == attrs_.end()) {
530 MS_LOG(ERROR) << name_ << ": Can't find axis attribute.";
531 return FAILED;
532 }
533 MS_EXCEPTION_IF_NULL(iter->second);
534 auto value_tuple = iter->second->cast<ValueTuplePtr>();
535 MS_EXCEPTION_IF_NULL(value_tuple);
536 InferAxis(value_tuple);
537 attrs_[AXIS] = axis_;
538 return SUCCESS;
539 }
540
InferReplaceOps()541 Status SqueezeInfo::InferReplaceOps() {
542 Attr attr = std::make_pair(AXIS, axis_);
543 OperatorAttrs attrs = {attr};
544 OperatorParams params;
545 OperatorArgs args = std::make_pair(attrs, params);
546 replace_op_ = {std::make_pair(SQUEEZE, args)};
547 return SUCCESS;
548 }
549
InferTensorMap()550 Status SqueezeInfo::InferTensorMap() {
551 // for example: if the shape of input is [32, 32, 1], and the axis is (2, ),
552 // then the input_tensor_map is [2, 1, 0], the output_tensor_map is [2, 1]
553 Shape input_tensor_map, output_tensor_map;
554 if (inputs_shape_.empty()) {
555 MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
556 return FAILED;
557 }
558 size_t size = inputs_shape_[0].size();
559 std::vector<int64_t> axis = GetValue<const std::vector<int64_t>>(axis_);
560 for (size_t i = 0; i < size; ++i) {
561 size_t index = size - i - 1;
562 auto iter = std::find(axis.begin(), axis.end(), SizeToLong(i));
563 if (iter == axis.end()) {
564 output_tensor_map.push_back(SizeToLong(index));
565 }
566 input_tensor_map.push_back(SizeToLong(index));
567 }
568 inputs_tensor_map_.push_back(input_tensor_map);
569 outputs_tensor_map_.push_back(output_tensor_map);
570 MS_LOG(INFO) << name_ << ": The tensor map of input is " << ShapeToString(input_tensor_map)
571 << ", and the tensor map of output is " << ShapeToString(output_tensor_map);
572
573 return SUCCESS;
574 }
575
Init(const StrategyPtr & strategy)576 Status SqueezeInfo::Init(const StrategyPtr &strategy) {
577 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
578 MS_LOG(ERROR) << name_ << " : Init failed.";
579 }
580
581 (void)InferReplaceOps();
582
583 MS_LOG(INFO) << name_ << " : Init success.";
584 return SUCCESS;
585 }
586 } // namespace parallel
587 } // namespace mindspore
588