1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
18 #include <functional>
19 #include <numeric>
20 #include "utils/ms_utils.h"
21 #include "frontend/parallel/status.h"
22 #include "frontend/parallel/tensor_layout/shape_util.h"
23
24 namespace mindspore {
25 namespace parallel {
Init(const TensorLayout & from,const TensorLayout & to,const RankList & dev_list)26 Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) {
27 from_origin_ = from;
28 to_origin_ = to;
29 if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) {
30 MS_LOG(ERROR) << "from shape size must be equal to to shape size!";
31 MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString();
32 MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString();
33 return Status::FAILED;
34 }
35
36 dev_list_ = dev_list;
37 from_ = from_origin_.SqueezeShape();
38 to_ = to_origin_.SqueezeShape();
39 return Status::SUCCESS;
40 }
41
InferTensorRedistributionOperatorListUnExpand(bool is_cost_model)42 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
43 TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
44 TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
45 MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
46 MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
47 MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
48 MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
49 MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
50 MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
51 OperatorVector operator_vector;
52 OutPutInfoVector output_info_vector;
53 if (InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
54 Status::FAILED) {
55 return nullptr;
56 }
57 if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
58 reshape_flag_ = true;
59 ConstructOperator constructor;
60 constructor.UpdateTensorShape(from_repeat.slice_shape().array());
61 Arrangement shape = to_repeat.slice_shape();
62 MS_LOG(DEBUG) << "reshape " << shape.ToString();
63 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
64 return nullptr;
65 } else {
66 operator_vector.push_back(constructor.GetOperator());
67 output_info_vector.push_back(std::make_pair(false, 0));
68 }
69 }
70 if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==
71 Status::FAILED) {
72 return nullptr;
73 }
74 return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
75 std::make_pair(operator_vector, output_info_vector));
76 }
77
InferTensorRedistributionOperatorList(bool is_cost_model)78 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
79 // Step 1: Match device arrangement between from_ and to_
80 RedistributionLayoutTransfer layout_transfer;
81 Status status = layout_transfer.Init(from_, to_);
82 if (status != Status::SUCCESS) {
83 return nullptr;
84 }
85 TensorLayout from_layout;
86 TensorLayout to_layout;
87 if (layout_transfer.IsDynamicShape()) {
88 from_layout = layout_transfer.from_in();
89 to_layout = layout_transfer.to_in();
90 } else {
91 std::shared_ptr<ReshapeLayoutTransfer> ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape();
92 if (ptr == nullptr) {
93 MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
94 return nullptr;
95 }
96 if (!ptr->ExpandAble()) {
97 expand_able_ = false;
98 return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
99 }
100 from_layout = ptr->from_in();
101 to_layout = ptr->to_in();
102 }
103 MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
104 MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString();
105 MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
106 MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
107 MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
108 MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
109 // Step 2: Infer redistribution and insert operators
110 OperatorVector operator_vector;
111 OutPutInfoVector output_info_vector;
112 if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
113 Status::SUCCESS) {
114 return nullptr;
115 }
116 // Step 3: Infer reshape and insert operators
117 if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
118 MS_LOG(ERROR) << "Construct Reshape operator failed!";
119 return nullptr;
120 }
121 return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
122 std::make_pair(operator_vector, output_info_vector));
123 }
124
InferReshape(const TensorLayout & from_layout,const TensorLayout & to_layout,OperatorVector * const operator_vector,OutPutInfoVector * const output_info_vector)125 Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
126 OperatorVector *const operator_vector,
127 OutPutInfoVector *const output_info_vector) {
128 MS_EXCEPTION_IF_NULL(operator_vector);
129 MS_EXCEPTION_IF_NULL(output_info_vector);
130 ConstructOperator constructor;
131 if (operator_list_.empty()) {
132 if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) {
133 reshape_flag_ = true;
134 constructor.UpdateTensorShape(from_origin_.slice_shape().array());
135 Arrangement shape = to_origin_.slice_shape();
136 MS_LOG(DEBUG) << "reshape " << shape.ToString();
137 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
138 return Status::FAILED;
139 } else {
140 (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
141 (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
142 }
143 }
144 return Status::SUCCESS;
145 }
146
147 if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) {
148 reshape_flag_ = true;
149 constructor.UpdateTensorShape(from_origin_.slice_shape().array());
150 Arrangement shape = from_layout.slice_shape();
151 MS_LOG(DEBUG) << "reshape " << shape.ToString();
152 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
153 return Status::FAILED;
154 } else {
155 (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
156 (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
157 }
158 }
159
160 if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) {
161 reshape_flag_ = true;
162 constructor.UpdateTensorShape(to_layout.slice_shape().array());
163 Arrangement shape = to_origin_.slice_shape();
164 MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString();
165 if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
166 return Status::FAILED;
167 } else {
168 (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator());
169 (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0));
170 }
171 }
172 return Status::SUCCESS;
173 }
174
InferRedistribution(const TensorLayout & from_layout,const TensorLayout & to_layout,OperatorVector * const operator_vector,OutPutInfoVector * const output_info_vector,bool is_cost_model)175 Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
176 OperatorVector *const operator_vector,
177 OutPutInfoVector *const output_info_vector, bool is_cost_model) {
178 MS_EXCEPTION_IF_NULL(operator_vector);
179 MS_EXCEPTION_IF_NULL(output_info_vector);
180 RedistributionOperatorInfer operator_infer(construct_op_flag_);
181 if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
182 MS_LOG(ERROR) << "Init operatorInfer failed";
183 return Status::FAILED;
184 }
185 if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
186 MS_LOG(ERROR) << "Infer redistribution failed";
187 return Status::FAILED;
188 } else {
189 for (auto op : operator_infer.operator_vector()) {
190 operator_vector->insert(operator_vector->end(), op);
191 }
192 for (auto info : operator_infer.output_info_vector()) {
193 output_info_vector->insert(output_info_vector->end(), info);
194 }
195 for (auto opc : operator_infer.operator_list()) {
196 operator_list_.insert(operator_list_.end(), opc);
197 }
198 }
199 return Status::SUCCESS;
200 }
201
ComputeCost()202 Status TensorRedistribution::ComputeCost() {
203 RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
204 if (redistribution_oplist_ptr == nullptr) {
205 MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
206 return Status::FAILED;
207 }
208 // Compute redistribution communication cost and computation cost
209 for (auto &op_cost : operator_list_) {
210 OperatorR op = op_cost.first;
211 Shape slice_shape = op_cost.second;
212 double prod =
213 std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
214 std::string str = op.first;
215 if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
216 return Status::FAILED;
217 } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
218 return Status::FAILED;
219 } else {
220 // There is only computation cost in SplitByAxis.
221 // computation cost = before_slice_shape
222 computation_cost_ += prod;
223 // This addition may be erroneous
224 memory_cost_ += prod;
225 }
226 }
227 if (reshape_flag()) {
228 Shape prev_shape;
229 if (expand_able_) {
230 prev_shape = from_.slice_shape().array();
231 } else {
232 prev_shape = from_.tensor_shape().array();
233 }
234 double prev_prod =
235 std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
236 computation_cost_ += COST_FACTOR * prev_prod;
237 memory_cost_ += COST_FACTOR * prev_prod;
238 }
239 return Status::SUCCESS;
240 }
241
ComputePermuteCost(double input_size,const Shape & attrs)242 Status TensorRedistribution::ComputePermuteCost(double input_size, const Shape &attrs) {
243 // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
244 // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
245 if (attrs.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
246 MS_LOG(ERROR) << "attrs size should not be less than 5!";
247 return Status::FAILED;
248 }
249 forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
250 backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
251 comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
252 int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
253 if (concat_dim == 0) {
254 // memory cost = all_gather
255 computation_cost_ += input_size;
256 memory_cost_ += input_size;
257 } else {
258 // memory cost = all_gather + split + concat
259 int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
260 computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
261 memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
262 }
263 return Status::SUCCESS;
264 }
265
ComputeConcatCost(double input_size,const Shape & attrs)266 Status TensorRedistribution::ComputeConcatCost(double input_size, const Shape &attrs) {
267 // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
268 // computation cost = before_slice_shape
269 if (attrs.size() < TRANSFER_CONCAT_ARGS_SIZE) {
270 MS_LOG(ERROR) << "op.second size should not be less than 3!";
271 return Status::FAILED;
272 }
273 double dev_num = attrs[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
274 // here, communication cost = all_gather + reduce_scatter
275 forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
276 backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
277 comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
278 int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
279 if (concat_dim == 0) {
280 // computation cost = all_gather
281 computation_cost_ += input_size;
282 memory_cost_ += input_size * dev_num;
283 } else {
284 // computation cost = all_gather + split + concat
285 computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
286 memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
287 }
288 return Status::SUCCESS;
289 }
290 } // namespace parallel
291 } // namespace mindspore
292