• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
18 #include <functional>
19 #include <numeric>
20 #include "utils/ms_utils.h"
21 #include "frontend/parallel/status.h"
22 #include "frontend/parallel/tensor_layout/shape_util.h"
23 
24 namespace mindspore {
25 namespace parallel {
Init(const TensorLayout & from,const TensorLayout & to,const RankList & dev_list)26 Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) {
27   from_origin_ = from;
28   to_origin_ = to;
29   if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) {
30     MS_LOG(ERROR) << "from shape size must be equal to to shape size!";
31     MS_LOG(ERROR) << "reshape from_origin_ " << from_origin_.ToString();
32     MS_LOG(ERROR) << "reshape to_origin_ " << to_origin_.ToString();
33     return Status::FAILED;
34   }
35 
36   dev_list_ = dev_list;
37   from_ = from_origin_.SqueezeShape();
38   to_ = to_origin_.SqueezeShape();
39   return Status::SUCCESS;
40 }
41 
InferTensorRedistributionOperatorListUnExpand(bool is_cost_model)42 RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorListUnExpand(bool is_cost_model) {
43   TensorLayout from_repeat = from_origin_.TransferRepeatLayout();
44   TensorLayout to_repeat = to_origin_.TransferRepeatLayout();
45   MS_LOG(DEBUG) << "reshape from_repeat " << from_repeat.ToString();
46   MS_LOG(DEBUG) << "reshape to_repeat " << to_repeat.ToString();
47   MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
48   MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
49   MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
50   MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
51   OperatorVector operator_vector;
52   OutPutInfoVector output_info_vector;
53   if (InferRedistribution(from_origin_, from_repeat, &operator_vector, &output_info_vector, is_cost_model) ==
54       Status::FAILED) {
55     return nullptr;
56   }
57   if (from_repeat.slice_shape().array() != to_repeat.slice_shape().array()) {
58     reshape_flag_ = true;
59     ConstructOperator constructor;
60     constructor.UpdateTensorShape(from_repeat.slice_shape().array());
61     Arrangement shape = to_repeat.slice_shape();
62     MS_LOG(DEBUG) << "reshape " << shape.ToString();
63     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
64       return nullptr;
65     } else {
66       operator_vector.push_back(constructor.GetOperator());
67       output_info_vector.push_back(std::make_pair(false, 0));
68     }
69   }
70   if (InferRedistribution(to_repeat, to_origin_, &operator_vector, &output_info_vector, is_cost_model) ==
71       Status::FAILED) {
72     return nullptr;
73   }
74   return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
75     std::make_pair(operator_vector, output_info_vector));
76 }
77 
// Build the complete operator list that transforms a tensor laid out as from_
// into the layout to_: first unify device arrangement/tensor shape, then infer
// the redistribution operators, and finally wrap them with any needed Reshape
// operators. Returns nullptr on failure. is_cost_model is forwarded to
// InferRedistribution (cost-model vs. graph-construction mode).
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
  // Step 1: Match device arrangement between from_ and to_
  RedistributionLayoutTransfer layout_transfer;
  Status status = layout_transfer.Init(from_, to_);
  if (status != Status::SUCCESS) {
    return nullptr;
  }
  TensorLayout from_layout;
  TensorLayout to_layout;
  if (layout_transfer.IsDynamicShape()) {
    // Dynamic shapes: skip unification and use the transfer's layouts as-is.
    from_layout = layout_transfer.from_in();
    to_layout = layout_transfer.to_in();
  } else {
    std::shared_ptr<ReshapeLayoutTransfer> ptr = layout_transfer.UnifyDeviceArrangementAndTensorShape();
    if (ptr == nullptr) {
      MS_LOG(ERROR) << "Infer tensor layout return nullptr!";
      return nullptr;
    }
    if (!ptr->ExpandAble()) {
      // Layouts cannot be expanded to a common shape; remember this (read by
      // ComputeCost) and fall back to the repeat-layout based inference.
      expand_able_ = false;
      return InferTensorRedistributionOperatorListUnExpand(is_cost_model);
    }
    from_layout = ptr->from_in();
    to_layout = ptr->to_in();
  }
  MS_LOG(DEBUG) << "reshape from_layout " << from_layout.ToString();
  MS_LOG(DEBUG) << "reshape to_layout " << to_layout.ToString();
  MS_LOG(DEBUG) << "reshape from_origin_ " << from_origin_.ToString();
  MS_LOG(DEBUG) << "reshape to_origin_ " << to_origin_.ToString();
  MS_LOG(DEBUG) << "reshape from_ " << from_.ToString();
  MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
  // Step 2: Infer redistribution and insert operators
  OperatorVector operator_vector;
  OutPutInfoVector output_info_vector;
  if (InferRedistribution(from_layout, to_layout, &operator_vector, &output_info_vector, is_cost_model) !=
      Status::SUCCESS) {
    return nullptr;
  }
  // Step 3: Infer reshape and insert operators
  if (InferReshape(from_layout, to_layout, &operator_vector, &output_info_vector) != Status::SUCCESS) {
    MS_LOG(ERROR) << "Construct Reshape operator failed!";
    return nullptr;
  }
  return std::make_shared<std::pair<OperatorVector, OutPutInfoVector>>(
    std::make_pair(operator_vector, output_info_vector));
}
124 
InferReshape(const TensorLayout & from_layout,const TensorLayout & to_layout,OperatorVector * const operator_vector,OutPutInfoVector * const output_info_vector)125 Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
126                                           OperatorVector *const operator_vector,
127                                           OutPutInfoVector *const output_info_vector) {
128   MS_EXCEPTION_IF_NULL(operator_vector);
129   MS_EXCEPTION_IF_NULL(output_info_vector);
130   ConstructOperator constructor;
131   if (operator_list_.empty()) {
132     if (from_origin_.slice_shape().array() != to_origin_.slice_shape().array() || keep_reshape_) {
133       reshape_flag_ = true;
134       constructor.UpdateTensorShape(from_origin_.slice_shape().array());
135       Arrangement shape = to_origin_.slice_shape();
136       MS_LOG(DEBUG) << "reshape " << shape.ToString();
137       if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
138         return Status::FAILED;
139       } else {
140         (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
141         (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
142       }
143     }
144     return Status::SUCCESS;
145   }
146 
147   if (from_origin_.slice_shape().array() != from_layout.slice_shape().array()) {
148     reshape_flag_ = true;
149     constructor.UpdateTensorShape(from_origin_.slice_shape().array());
150     Arrangement shape = from_layout.slice_shape();
151     MS_LOG(DEBUG) << "reshape " << shape.ToString();
152     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
153       return Status::FAILED;
154     } else {
155       (void)operator_vector->insert(operator_vector->begin(), constructor.GetOperator());
156       (void)output_info_vector->insert(output_info_vector->begin(), std::make_pair(false, 0));
157     }
158   }
159 
160   if (to_origin_.slice_shape().array() != to_layout.slice_shape().array()) {
161     reshape_flag_ = true;
162     constructor.UpdateTensorShape(to_layout.slice_shape().array());
163     Arrangement shape = to_origin_.slice_shape();
164     MS_LOG(DEBUG) << "step_parallel to reshape " << shape.ToString();
165     if (constructor.ReshapeOP(shape.array()) == Status::FAILED) {
166       return Status::FAILED;
167     } else {
168       (void)operator_vector->insert(operator_vector->end(), constructor.GetOperator());
169       (void)output_info_vector->insert(output_info_vector->end(), std::make_pair(false, 0));
170     }
171   }
172   return Status::SUCCESS;
173 }
174 
InferRedistribution(const TensorLayout & from_layout,const TensorLayout & to_layout,OperatorVector * const operator_vector,OutPutInfoVector * const output_info_vector,bool is_cost_model)175 Status TensorRedistribution::InferRedistribution(const TensorLayout &from_layout, const TensorLayout &to_layout,
176                                                  OperatorVector *const operator_vector,
177                                                  OutPutInfoVector *const output_info_vector, bool is_cost_model) {
178   MS_EXCEPTION_IF_NULL(operator_vector);
179   MS_EXCEPTION_IF_NULL(output_info_vector);
180   RedistributionOperatorInfer operator_infer(construct_op_flag_);
181   if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
182     MS_LOG(ERROR) << "Init operatorInfer failed";
183     return Status::FAILED;
184   }
185   if (operator_infer.InferRedistributionOperator() != Status::SUCCESS) {
186     MS_LOG(ERROR) << "Infer redistribution failed";
187     return Status::FAILED;
188   } else {
189     for (auto op : operator_infer.operator_vector()) {
190       operator_vector->insert(operator_vector->end(), op);
191     }
192     for (auto info : operator_infer.output_info_vector()) {
193       output_info_vector->insert(output_info_vector->end(), info);
194     }
195     for (auto opc : operator_infer.operator_list()) {
196       operator_list_.insert(operator_list_.end(), opc);
197     }
198   }
199   return Status::SUCCESS;
200 }
201 
ComputeCost()202 Status TensorRedistribution::ComputeCost() {
203   RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
204   if (redistribution_oplist_ptr == nullptr) {
205     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
206     return Status::FAILED;
207   }
208   // Compute redistribution communication cost and computation cost
209   for (auto &op_cost : operator_list_) {
210     OperatorR op = op_cost.first;
211     Shape slice_shape = op_cost.second;
212     double prod =
213       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
214     std::string str = op.first;
215     if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
216       return Status::FAILED;
217     } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
218       return Status::FAILED;
219     } else {
220       // There is only computation cost in SplitByAxis.
221       // computation cost = before_slice_shape
222       computation_cost_ += prod;
223       // This addition may be erroneous
224       memory_cost_ += prod;
225     }
226   }
227   if (reshape_flag()) {
228     Shape prev_shape;
229     if (expand_able_) {
230       prev_shape = from_.slice_shape().array();
231     } else {
232       prev_shape = from_.tensor_shape().array();
233     }
234     double prev_prod =
235       std::accumulate(prev_shape.begin(), prev_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
236     computation_cost_ += COST_FACTOR * prev_prod;
237     memory_cost_ += COST_FACTOR * prev_prod;
238   }
239   return Status::SUCCESS;
240 }
241 
ComputePermuteCost(double input_size,const Shape & attrs)242 Status TensorRedistribution::ComputePermuteCost(double input_size, const Shape &attrs) {
243   // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
244   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
245   if (attrs.size() < TRANSFER_PERMUTE_ARGS_SIZE) {
246     MS_LOG(ERROR) << "attrs size should not be less than 5!";
247     return Status::FAILED;
248   }
249   forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
250   backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
251   comm_cost_ += COST_FACTOR * input_size * ALLTOALL_SCALE_FACTOR;
252   int64_t concat_dim = attrs[TRANSFER_PERMUTE_CONCAT_DIM_INDEX];
253   if (concat_dim == 0) {
254     // memory cost = all_gather
255     computation_cost_ += input_size;
256     memory_cost_ += input_size;
257   } else {
258     // memory cost = all_gather + split + concat
259     int64_t dev_num = attrs[TRANSFER_PERMUTE_DEV_NUM_INDEX];
260     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
261     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
262   }
263   return Status::SUCCESS;
264 }
265 
ComputeConcatCost(double input_size,const Shape & attrs)266 Status TensorRedistribution::ComputeConcatCost(double input_size, const Shape &attrs) {
267   // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
268   // computation cost = before_slice_shape
269   if (attrs.size() < TRANSFER_CONCAT_ARGS_SIZE) {
270     MS_LOG(ERROR) << "op.second size should not be less than 3!";
271     return Status::FAILED;
272   }
273   double dev_num = attrs[TRANSFER_CONCAT_SPLIT_COUNT_INDEX];
274   // here, communication cost = all_gather + reduce_scatter
275   forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
276   backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
277   comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
278   int64_t concat_dim = attrs[TRANSFER_CONCAT_TENSOR_DIM_INDEX];
279   if (concat_dim == 0) {
280     // computation cost = all_gather
281     computation_cost_ += input_size;
282     memory_cost_ += input_size * dev_num;
283   } else {
284     // computation cost = all_gather + split + concat
285     computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
286     memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
287   }
288   return Status::SUCCESS;
289 }
290 }  // namespace parallel
291 }  // namespace mindspore
292