1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/ops_info/gather_v2_info.h"
18
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 #include "ir/tensor.h"
24 #include "ir/value.h"
25 #include "frontend/parallel/auto_parallel/costmodel.h"
26 #include "frontend/parallel/device_matrix.h"
27 #include "frontend/parallel/graph_util/generate_graph.h"
28 #include "frontend/parallel/strategy.h"
29 #include "utils/log_adapter.h"
30
31 namespace mindspore {
32 namespace parallel {
GetAttrs()33 Status GatherInfo::GetAttrs() {
34 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
35 MS_LOG(ERROR) << name_ << ": inputs shape size must be 2, but is " << inputs_shape_.size();
36 return FAILED;
37 }
38 if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
39 MS_LOG(ERROR) << name_ << ": outputs shape size must be 1, but is " << outputs_shape_.size();
40 return FAILED;
41 }
42 if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
43 MS_LOG(ERROR) << name_ << ": input value size must be 3, but is " << input_value_.size();
44 return FAILED;
45 }
46 // the second input is the index tensor
47
48 // the third input is the axis, is a ValueNode
49 if (input_value_.at(2) == nullptr) {
50 MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!";
51 return FAILED;
52 }
53
54 if (inputs_shape_.at(0).size() == 0) {
55 MS_LOG(ERROR) << name_ << ": input can not be a scalar!";
56 return FAILED;
57 }
58 int64_t axis = GetValue<int64_t>(input_value_.at(2));
59 if (axis >= SizeToLong(inputs_shape_.at(0).size()) || axis < -SizeToLong(inputs_shape_.at(0).size())) {
60 MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
61 << inputs_shape_.at(0).size() << ").";
62 }
63 if (axis < 0) {
64 axis += SizeToLong(inputs_shape_[0].size());
65 }
66 axis_ = axis;
67
68 index_size_ = inputs_shape_.at(1).size();
69
70 return SUCCESS;
71 }
72
CheckStrategy(const StrategyPtr & strategy)73 Status GatherInfo::CheckStrategy(const StrategyPtr &strategy) {
74 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
75 MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
76 << inputs_shape_.size();
77 return FAILED;
78 }
79 if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
80 MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
81 << outputs_shape_.size();
82 return FAILED;
83 }
84 // Only strategy of the first input should be set.
85 if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) {
86 MS_LOG(ERROR) << name_ << ": Invalid strategy.";
87 return FAILED;
88 }
89 axis_strategy_ = strategy->GetInputDim().at(0).at(LongToSize(axis_));
90 if (index_size_ != 1 && axis_strategy_ != 1) {
91 MS_LOG(ERROR) << name_
92 << ": Invalid strategy. If the index is a scalar or a more than 1 dimension vector, the strategy "
93 "corresponding to axis must be 1, but is "
94 << axis_strategy_;
95 return FAILED;
96 }
97 if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
98 MS_LOG(ERROR) << name_
99 << ": Invalid strategy. The first dimension of index can not be divided by strategy corresponding to "
100 "axis. The first dimension of index is "
101 << inputs_shape_.at(1).at(0) << " strategy corresponding to axis is " << axis_strategy_;
102 return FAILED;
103 }
104 return SUCCESS;
105 }
106
InferDevMatrixShape()107 Status GatherInfo::InferDevMatrixShape() {
108 Strategys stra = strategy_->GetInputDim();
109 dev_matrix_shape_ = stra.at(0);
110 return SUCCESS;
111 }
112
113 // If index is a scalar, output dimension is input dimension minus 1;
114 // If index is a n dimension tensor, output dimension is input dimension plus (n - 1).
115 // Tensor map dimension is equal to the corresponding input and output dimension.
116 // If index's dimension is more than 1, we insert -1 for the output tensor map.
InferTensorMap()117 Status GatherInfo::InferTensorMap() {
118 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
119 MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
120 << inputs_shape_.size();
121 return FAILED;
122 }
123 if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
124 MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
125 << outputs_shape_.size();
126 return FAILED;
127 }
128 Shape tensor_map_in;
129 Shape tensor_map_out;
130 size_t size = inputs_shape_.at(0).size();
131 // such as 4: tensor_map_index [3,2,1,0]
132 for (size_t i = 0; i < size; ++i) {
133 tensor_map_in.push_back(SizeToLong(size - i - 1));
134 tensor_map_out.push_back(SizeToLong(size - i - 1));
135 }
136
137 if (index_size_ == 0) {
138 (void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
139 } else if (index_size_ > 1) {
140 (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
141 }
142 if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
143 MS_LOG(ERROR) << "Out tensor map size is not equal to output size! Out tensor map size is " << tensor_map_out.size()
144 << " output size is " << outputs_shape_.at(0).size();
145 return FAILED;
146 }
147
148 Shape tensor_map_in_index;
149 if (index_size_ >= 1) {
150 tensor_map_in_index.push_back(SizeToLong(size) - axis_ - 1);
151 }
152 for (size_t i = 1; i < index_size_; ++i) {
153 tensor_map_in_index.push_back(-1);
154 }
155 inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
156 inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
157 outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
158 return SUCCESS;
159 }
160
InferTensorInfo()161 Status GatherInfo::InferTensorInfo() {
162 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
163 MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
164 << inputs_shape_.size();
165 return FAILED;
166 }
167 if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
168 MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
169 << outputs_shape_.size();
170 return FAILED;
171 }
172 if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
173 MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
174 << inputs_tensor_map_.size();
175 return FAILED;
176 }
177 if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
178 MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
179 << outputs_tensor_map_.size();
180 return FAILED;
181 }
182 // infer tensor shape
183 Shape input_shape = inputs_shape_.at(0);
184 Shape input_index_shape = inputs_shape_.at(1);
185 Shape output_shape = outputs_shape_.at(0);
186
187 TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
188 if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
189 (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) ||
190 (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
191 return FAILED;
192 }
193
194 TensorInfo input_tensor_info(input_tensor_layout);
195 TensorInfo input_index_info(input_index_layout);
196 TensorInfo output_tensor_info(output_tensor_layout);
197
198 inputs_tensor_info_.push_back(input_tensor_info);
199 inputs_tensor_info_.push_back(input_index_info);
200 outputs_tensor_info_.push_back(output_tensor_info);
201 return SUCCESS;
202 }
203
CreateSubOp(int64_t sub_value)204 OperatorVector CreateSubOp(int64_t sub_value) {
205 OperatorVector ops;
206 OperatorName operator_name = SUB;
207 OperatorAttrs operator_attrs;
208
209 std::vector<int64_t> tensor_data = {sub_value};
210 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tensor_data, kInt32);
211 ValuePtr op_param_value = MakeValue(tensor_ptr);
212
213 Attr op1_param = std::make_pair("", op_param_value);
214 OperatorParams operator_param = {std::make_pair(op1_param, 2)};
215
216 OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
217 Operator op = std::make_pair(operator_name, operator_args);
218 ops.push_back(op);
219 return ops;
220 }
221
InferTensorSubOps()222 Status GatherInfo::InferTensorSubOps() {
223 sub_ops_.clear();
224 if ((index_size_ == 0) || (axis_strategy_ == 1)) {
225 return SUCCESS;
226 }
227 int64_t mod_n = 1;
228 for (size_t i = LongToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
229 mod_n *= dev_matrix_shape_.at(i);
230 }
231 if ((axis_ >= SizeToLong(dev_matrix_shape_.size())) || axis_ < 0) {
232 MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
233 }
234 int64_t mod_p = mod_n * dev_matrix_shape_.at(LongToSize(axis_));
235 int64_t rank = g_device_manager->rank_index_in_stage();
236 int64_t mod_rank = rank % mod_p;
237 mod_rank = static_cast<int64_t>(mod_rank / mod_n);
238 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
239 MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
240 << inputs_shape_.size();
241 return FAILED;
242 }
243 if ((axis_ >= SizeToLong(inputs_shape_.at(0).size())) || axis_ < 0) {
244 MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
245 }
246 int64_t sub_value = inputs_shape_[0][LongToSize(axis_)] / dev_matrix_shape_[LongToSize(axis_)] * mod_rank;
247
248 OperatorVector sub_op;
249 sub_ops_.emplace_back(std::move(sub_op));
250 sub_op = CreateSubOp(sub_value);
251 sub_ops_.emplace_back(std::move(sub_op));
252 return SUCCESS;
253 }
254
Init(const StrategyPtr & strategy)255 Status GatherInfo::Init(const StrategyPtr &strategy) {
256 if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
257 MS_LOG(ERROR) << name_ << ": Init failed.";
258 return FAILED;
259 }
260 Status status = InferTensorSubOps();
261 if (status != SUCCESS) {
262 MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
263 return status;
264 }
265 MS_LOG(INFO) << name_ << ": Init success.";
266 return SUCCESS;
267 }
268
InitForCostModel(const StrategyPtr & strategy)269 Status GatherInfo::InitForCostModel(const StrategyPtr &strategy) {
270 if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
271 MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
272 return FAILED;
273 }
274 MS_LOG(INFO) << name_ << ": Init for cost model success.";
275 return SUCCESS;
276 }
277
GenerateOpStrategies(int64_t stage_id)278 std::vector<StrategyPtr> GatherInfo::GenerateOpStrategies(int64_t stage_id) {
279 if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
280 MS_LOG(EXCEPTION) << name_ << " : Inputs shape size(" << inputs_shape_.size() << ") or outputs shape size("
281 << outputs_shape_.size() << "is wrong.";
282 }
283 Shape input0_split(inputs_shape_[0].size(), 1);
284 Shapes splittable_inputs = {input0_split};
285
286 std::vector<StrategyPtr> sp_vector;
287 if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
288 SUCCESS) {
289 MS_LOG(EXCEPTION) << name_ << " : Generate strategies for independent inputs() failed.";
290 }
291 return sp_vector;
292 }
293
SetCostUnderStrategy(const StrategyPtr & strategy)294 Status GatherInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
295
GenerateBatchStrategies()296 std::shared_ptr<Strategys> GatherInfo::GenerateBatchStrategies() {
297 if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
298 MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
299 << inputs_shape_.size();
300 }
301 if (GetAttrs() != SUCCESS) {
302 MS_LOG(EXCEPTION) << "GetAttrs failed!";
303 }
304
305 Dimensions strategy;
306 if (index_size_ != 1) {
307 strategy.push_back(1);
308 } else {
309 strategy.push_back(stage_device_size_);
310 }
311 for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
312 strategy.push_back(1);
313 }
314 Strategys strategy_v = {strategy};
315 return std::make_shared<Strategys>(strategy_v);
316 }
317 } // namespace parallel
318 } // namespace mindspore
319