1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/transpose_info.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "frontend/parallel/device_manager.h"
23 #include "frontend/parallel/device_matrix.h"
24 #include "frontend/parallel/step_parallel.h"
25 #include "utils/convert_utils.h"
26 #include "utils/log_adapter.h"
27
28 namespace mindspore {
29 namespace parallel {
CheckStrategy(const StrategyPtr & strategy)30 Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }
31
InferDevMatrixShape()32 Status TransposeInfo::InferDevMatrixShape() {
33 Strategys stra = strategy_->GetInputDim();
34 input_strategy_ = stra.at(0);
35 for (auto &iter : input_strategy_) {
36 dev_matrix_shape_.push_back(iter);
37 }
38 return SUCCESS;
39 }
40
41 // there is no Parameter for Transpose Primitive, so no need to do all reduce
InferMirrorOps()42 Status TransposeInfo::InferMirrorOps() { return SUCCESS; }
43
44 // there is no reduction dimension for forward computation of Transpose Primitive, so no need to do all reduce
InferForwardCommunication()45 Status TransposeInfo::InferForwardCommunication() { return SUCCESS; }
46
47 /*
48 * get perm input of Transpose Primitive
49 * perm is a permutation of the dimensions of input
50 * the result is saved in axis_v_
51 */
ComputeAxis()52 Status TransposeInfo::ComputeAxis() {
53 if (input_value_[1] == nullptr) {
54 MS_LOG(ERROR) << name_ << ": input_value_[1] is nullptr.";
55 return FAILED;
56 }
57 std::vector<ValuePtr> elements;
58 ValueTuplePtr dim_tuple = input_value_[1]->cast<ValueTuplePtr>();
59 if (dim_tuple == nullptr) {
60 MS_LOG(ERROR) << name_ << ": input_value_[1] must be ValueTuplePtr.";
61 return FAILED;
62 }
63 elements = dim_tuple->value();
64 if (elements.size() != inputs_shape_[0].size()) {
65 MS_LOG(ERROR) << name_ << ": elements size must equal to inputs shape 0 size.";
66 return FAILED;
67 }
68 axis_v_.clear();
69 for (auto &element : elements) {
70 MS_EXCEPTION_IF_NULL(element);
71 if (element->isa<Int64Imm>()) {
72 int64_t axis = element->cast<Int64ImmPtr>()->value();
73 axis_v_.push_back(axis);
74 } else {
75 MS_LOG(ERROR) << name_ << ": The value of axis must be int32.";
76 return FAILED;
77 }
78 }
79
80 for (int64_t i = 0; i < SizeToLong(axis_v_.size()); i++) {
81 auto iter = std::find(axis_v_.begin(), axis_v_.end(), i);
82 if (iter == axis_v_.end()) {
83 MS_LOG(ERROR) << name_ << ": axis_v_ must be a permutation.";
84 }
85 }
86 return SUCCESS;
87 }
88
89 // the output tensor map is the permutation of input tensor map, the permutation is axis_v
InferTensorMap()90 Status TransposeInfo::InferTensorMap() {
91 if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {
92 MS_LOG(ERROR) << name_ << ": inputs_shape_ and outputs_shape_ size must be 1, inputs shape and outputs shape is "
93 << inputs_shape_.size() << ", " << outputs_shape_.size();
94 return FAILED;
95 }
96
97 Shape tensor_map_index_input;
98 for (size_t j = 0; j < inputs_shape_[0].size(); ++j) {
99 tensor_map_index_input.push_back(SizeToLong(inputs_shape_[0].size() - j - 1));
100 }
101 inputs_tensor_map_.push_back(tensor_map_index_input);
102
103 Shape tensor_map_index_output = tensor_map_index_input;
104 for (uint64_t i = 0; i < tensor_map_index_output.size(); i++) {
105 tensor_map_index_output[i] = tensor_map_index_input[LongToUlong(axis_v_[i])];
106 }
107 outputs_tensor_map_.push_back(tensor_map_index_output);
108 return SUCCESS;
109 }
110
111 // compute axis_v_ during this method
GetAttrs()112 Status TransposeInfo::GetAttrs() { return ComputeAxis(); }
113
Init(const StrategyPtr & strategy)114 Status TransposeInfo::Init(const StrategyPtr &strategy) {
115 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
116 MS_LOG(ERROR) << name_ << ": Init failed.";
117 return FAILED;
118 }
119 MS_LOG(INFO) << name_ << ": Init success.";
120 return SUCCESS;
121 }
122
InitForCostModel(const StrategyPtr & strategy)123 Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
124 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
125 MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
126 return FAILED;
127 }
128
129 MS_LOG(INFO) << name_ << ": Init for cost model success.";
130 return SUCCESS;
131 }
132
SetCostUnderStrategy(const mindspore::parallel::StrategyPtr & strategy)133 Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
134 return SetCostUnderStrategyBase(strategy);
135 }
136
GenerateOpStrategies(int64_t stage_id)137 std::vector<StrategyPtr> TransposeInfo::GenerateOpStrategies(int64_t stage_id) {
138 Shape input0_split(inputs_shape_[0].size(), 1);
139 Shapes splittable_inputs = {input0_split};
140 std::vector<StrategyPtr> sp_vector;
141 if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
142 MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed";
143 }
144
145 return sp_vector;
146 }
147 } // namespace parallel
148 } // namespace mindspore
149