1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <utility>
23 #include "frontend/parallel/auto_parallel/costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26
27 namespace mindspore {
28 namespace parallel {
InitEdgeCost()29 Status Edge::InitEdgeCost() {
30 bool has_available_cost = false;
31 pre_op_output_.clear();
32 next_op_input_.clear();
33 cost_map_.clear();
34
35 for (auto &swc : prev_op_->GetStrategyCost()) {
36 MS_EXCEPTION_IF_NULL(swc);
37 pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
38 }
39 for (auto &swc : next_op_->GetStrategyCost()) {
40 MS_EXCEPTION_IF_NULL(swc);
41 next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
42 }
43 if (is_identity_edge) {
44 for (auto &target_output : pre_op_output_) {
45 auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
46 auto target_output_str = target_output.first;
47 for (auto &target_input : next_op_input_) {
48 auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
49 auto target_input_str = target_input.first;
50 if (target_output_lyt == target_input_lyt) {
51 CostPtrKey ck = {target_output_str, target_input_str};
52 CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
53 MS_EXCEPTION_IF_NULL(cost);
54 cost->communication_without_parameter_ = 0.0;
55 cost->communication_with_partial_para_ = 0.0;
56 CostPtrList cl;
57 cl.push_back(cost);
58 (void)cost_map_.emplace(std::make_pair(ck, cl));
59 has_available_cost = true;
60 }
61 }
62 }
63 } else {
64 for (auto &target_output : pre_op_output_) {
65 auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
66 auto target_output_str = target_output.first;
67 auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
68 auto type = prev_op_->outputs_type()[prev_op_output_index_];
69 for (auto &target_input : next_op_input_) {
70 auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
71 auto target_input_str = target_input.first;
72 CostPtr cost;
73 if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
74 MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
75 }
76 MS_EXCEPTION_IF_NULL(cost);
77 MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
78 << ", communication_cost: " << cost->communication_cost_
79 << ", communication_without_parameter_: " << cost->communication_without_parameter_
80 << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
81 // refine communication cost calculation for practice
82 RefineForPracticalCost(cost, true);
83 cost->communication_forward_ = cost->communication_redis_forward_;
84 CostPtrKey ck = {target_output_str, target_input_str};
85 CostPtrList cl;
86 cl.push_back(cost);
87 (void)cost_map_.emplace(std::make_pair(ck, cl));
88 has_available_cost = true;
89 }
90 }
91 }
92 if (!has_available_cost) {
93 const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
94 const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
95 if (fully_use) {
96 MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
97 << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
98 "'fully_use_devices' false.";
99 } else if (stra_follow) {
100 MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
101 << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
102 "Try to set 'elementwise_op_strategy_follow' false.";
103 }
104 if (edge_name_.find(RESHAPE) != std::string::npos) {
105 MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
106 << " failed, it may be caused by setting different strategies for operators following Reshape. "
107 "Try to fix that.";
108 }
109 MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
110 }
111 return Status::SUCCESS;
112 }
113
GetRedistributionCost(const TensorLayout & prev_op_output_layout,const TensorLayout & next_op_input_layout,size_t type_length,const TypePtr & type,CostPtr * cost)114 Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
115 size_t type_length, const TypePtr &type, CostPtr *cost) {
116 MS_EXCEPTION_IF_NULL(prev_op_);
117 MS_EXCEPTION_IF_NULL(cost);
118 RankList dev_list = prev_op_->stage_device_list();
119 TensorRedistribution tensor_redistribution(false);
120
121 // Init TensorRedistribution
122 if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
123 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
124 }
125
126 if (tensor_redistribution.ComputeCost() == FAILED) {
127 MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
128 }
129
130 double comm_cost = tensor_redistribution.comm_cost();
131 double forward_comm_cost = tensor_redistribution.forward_comm_cost();
132 double backward_comm_cost = tensor_redistribution.backward_comm_cost();
133 double computation_cost = tensor_redistribution.computation_cost();
134 double mem_cost = tensor_redistribution.memory_cost();
135 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
136
137 // Now AllGather, ReduceScatter, AlltoAll don't support bool type
138 MS_EXCEPTION_IF_NULL(type);
139 if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
140 computation_cost = INF;
141 comm_cost = INF;
142 MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
143 }
144 *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
145 (*cost)->communication_without_parameter_ = type_length * comm_cost;
146 (*cost)->communication_with_partial_para_ =
147 (*cost)->communication_without_parameter_ +
148 gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
149 (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
150 (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
151 (*cost)->memory_with_reuse_ = mem_cost;
152 return Status::SUCCESS;
153 }
154
GetCostList(StrategyPtr output_str,StrategyPtr input_str)155 CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
156 CostPtrKey ck = {output_str, input_str};
157 CostPtrList result;
158 if (cost_map_.find(ck) != cost_map_.end()) {
159 return cost_map_.at(ck);
160 }
161 return result;
162 }
163
CreateEdgeEliminationCostList(const StrategyPtr & output_st_ptr,const std::vector<EdgePtr> & edges,const StrategyPtr & input_st_ptr)164 CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
165 const StrategyPtr &input_st_ptr) {
166 std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
167 MS_EXCEPTION_IF_NULL(edge);
168 return edge->GetCostList(output_st_ptr, input_st_ptr);
169 };
170 CostPtrList result;
171 std::vector<CostPtrList> all_cost_list;
172 all_cost_list.resize(edges.size());
173 (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
174
175 CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
176 std::function<void(size_t, double, double, double, double, double)> recursive =
177 [&](size_t k, double computation, double memory, double communication, double communication_without_para,
178 double communication_forward) {
179 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
180 if (k == edges.size()) {
181 auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
182 CostPtr new_cost = std::make_shared<Cost>(computation, communication);
183 MS_EXCEPTION_IF_NULL(new_cost);
184 new_cost->communication_without_parameter_ = communication_without_para;
185 new_cost->communication_with_partial_para_ =
186 communication_without_para + gamma * (communication - communication_without_para);
187 new_cost->memory_with_reuse_ = memory;
188 new_cost->communication_forward_ = communication_forward;
189 new_cost->decision_ptr_ = decision;
190 result.push_back(new_cost);
191 return;
192 }
193 for (auto &c : all_cost_list[k]) {
194 MS_EXCEPTION_IF_NULL(c);
195 selected_cost_list[k] = c;
196 recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
197 communication + c->communication_cost_,
198 communication_without_para + c->communication_without_parameter_,
199 communication_forward + c->communication_forward_);
200 }
201 };
202 recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
203 Simplify(&result);
204 return result;
205 }
206
EdgeEliminationSetNewCost(OperatorInfoPtr,const std::vector<EdgePtr> & edges,OperatorInfoPtr)207 void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
208 bool valid = false;
209 for (const auto &output_pair : pre_op_output_) {
210 StrategyPtr output_st_ptr = output_pair.first;
211 for (const auto &input_pair : next_op_input_) {
212 StrategyPtr input_st_ptr = input_pair.first;
213 CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
214 CostPtrKey key = {output_st_ptr, input_st_ptr};
215 cost_map_[key] = clist;
216 if ((!valid) && (!clist.empty())) {
217 valid = true;
218 }
219 }
220 }
221 if (!valid) {
222 MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
223 }
224 }
225
CreateOpEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & left_cost_list,const CostPtrList & middle_cost_list,const CostPtrList & right_cost_list,CostPtrList * ret_cost_list)226 void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
227 const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
228 CostPtrList *ret_cost_list) {
229 for (auto &left_cost : left_cost_list) {
230 MS_EXCEPTION_IF_NULL(left_cost);
231 for (auto &middle_cost : middle_cost_list) {
232 MS_EXCEPTION_IF_NULL(middle_cost);
233 for (auto &right_cost : right_cost_list) {
234 MS_EXCEPTION_IF_NULL(right_cost);
235 double computation =
236 left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
237 double communication =
238 left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
239 double communication_forward =
240 left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
241 double communication_without_para = left_cost->communication_without_parameter_ +
242 middle_cost->communication_without_parameter_ +
243 right_cost->communication_without_parameter_;
244 double memory_cost =
245 left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
246
247 auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
248 auto cost = std::make_shared<Cost>(computation, communication, decision);
249 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
250 MS_EXCEPTION_IF_NULL(cost);
251 cost->communication_without_parameter_ = communication_without_para;
252 cost->communication_with_partial_para_ =
253 communication_without_para + gamma * (communication - communication_without_para);
254 cost->memory_with_reuse_ = memory_cost;
255 cost->communication_forward_ = communication_forward;
256 ret_cost_list->emplace_back(std::move(cost));
257 }
258 }
259 }
260 }
261
CreateOpEliminationCostList(const EdgePtr & e1,const StrategyPtr & output_st_ptr,const OperatorInfoPtr & op,const EdgePtr & e2,const StrategyPtr & input_st_ptr)262 CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
263 const OperatorInfoPtr &op, const EdgePtr &e2,
264 const StrategyPtr &input_st_ptr) {
265 MS_EXCEPTION_IF_NULL(op);
266 MS_EXCEPTION_IF_NULL(e1);
267 MS_EXCEPTION_IF_NULL(e2);
268 CostPtrList result;
269 for (const auto &op_strategy : op->GetStrategyCost()) {
270 MS_EXCEPTION_IF_NULL(op_strategy);
271 auto middle_strategy = op_strategy->strategy_ptr;
272 CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
273 op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
274 }
275 Simplify(&result);
276 return result;
277 }
278
OpEliminationSetNewCost(const EdgePtr & e1,const OperatorInfoPtr & op,const EdgePtr & e2)279 void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
280 bool valid = false;
281 for (const auto &output_pair : pre_op_output_) {
282 StrategyPtr output_st_ptr = output_pair.first;
283 for (const auto &input_pair : next_op_input_) {
284 StrategyPtr input_st_ptr = input_pair.first;
285
286 CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
287 CostPtrKey key = {output_st_ptr, input_st_ptr};
288 cost_map_[key] = clist;
289 if ((!valid) && (!clist.empty())) {
290 valid = true;
291 }
292 }
293 }
294 if (!valid) {
295 MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
296 }
297 }
298
CalculateMemoryCost()299 Status Edge::CalculateMemoryCost() {
300 if (is_output_parameter_involve_ == -1) {
301 MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
302 return FAILED;
303 }
304 if (is_output_parameter_involve_ == 0) {
305 // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
306 // unnecessary to keep them in memory.
307 for (auto &cost_kv : cost_map_) {
308 auto &cost_v = cost_kv.second;
309 if (!cost_v.empty()) {
310 cost_v[0]->memory_with_reuse_ = 0;
311 }
312 }
313 }
314
315 return SUCCESS;
316 }
317
CalculateMemoryCostForInference()318 Status Edge::CalculateMemoryCostForInference() {
319 // Currently, memory cost is NOT calculated for redistribution
320 if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
321 MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
322 return FAILED;
323 }
324 for (auto &cost_kv : cost_map_) {
325 auto &cost_v = cost_kv.second;
326 if (!cost_v.empty()) {
327 cost_v[0]->memory_with_reuse_ = 0;
328 }
329 }
330 return SUCCESS;
331 }
332
GetCostByStrategyPair(const CostPtrKey & stra_pair)333 CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
334 if (cost_map_.find(stra_pair) == cost_map_.end()) {
335 return nullptr;
336 }
337 auto cost_vec = cost_map_[stra_pair];
338 if (cost_vec.empty()) {
339 PrintStrategy(stra_pair.first);
340 PrintStrategy(stra_pair.second);
341 MS_LOG(EXCEPTION) << "No available cost under current strategy pair of the edge: " << edge_name_;
342 }
343 if (cost_vec.size() > 1) {
344 PrintStrategy(stra_pair.first);
345 PrintStrategy(stra_pair.second);
346 MS_LOG(INFO) << "Multiple costs available under the stratey pair of the edge: " << edge_name_;
347 }
348 return cost_vec[0];
349 }
350
// Among the strategy pairs of this edge whose previous-operator strategy equals `prev_op_stra` and whose first
// cost has a communication cost of exactly 0, returns the next-operator strategy with the smallest computation
// cost. Ties resolve to the first candidate in cost_map_ iteration order. Pairs with an empty cost list are
// skipped. Returns nullptr (and logs an error) when no candidate exists.
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &prev_op_stra) {
  MS_EXCEPTION_IF_NULL(prev_op_stra);
  std::vector<std::pair<StrategyPtr, double>> next_op_stras;
  for (const auto &key_value : cost_map_) {
    const auto &candidate_prev_op_stra = key_value.first.first;
    const CostPtrList &cost_list = key_value.second;
    // Elimination passes may store empty cost lists; reading [0] of those would be undefined behavior.
    if (cost_list.empty() || !prev_op_stra->IsEqual(candidate_prev_op_stra)) {
      continue;
    }
    const CostPtr &cost = cost_list[0];
    MS_EXCEPTION_IF_NULL(cost);
    if (cost->communication_cost_ == 0.0) {
      (void)next_op_stras.emplace_back(key_value.first.second, cost->computation_cost_);
    }
  }
  if (next_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (next_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with"
                    " minimum computation costs.";
  }
  // Only the minimum is needed. A strict '<' keeps the comparator a valid ordering; a '<=' comparator violates
  // the strict-weak-ordering requirement of the standard algorithms.
  const auto best = std::min_element(
    next_op_stras.begin(), next_op_stras.end(),
    [](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      return a.second < b.second;
    });
  return best->first;
}
373
// Among the strategy pairs of this edge whose next-operator strategy equals `next_op_stra` and whose first cost
// has a communication cost of exactly 0, returns the previous-operator strategy with the smallest computation
// cost. Ties resolve to the first candidate in cost_map_ iteration order. Pairs with an empty cost list are
// skipped. Returns nullptr (and logs an error) when no candidate exists.
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &next_op_stra) {
  MS_EXCEPTION_IF_NULL(next_op_stra);
  std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
  for (const auto &key_value : cost_map_) {
    const auto &candidate_next_op_stra = key_value.first.second;
    const CostPtrList &cost_list = key_value.second;
    // Elimination passes may store empty cost lists; reading [0] of those would be undefined behavior.
    if (cost_list.empty() || !next_op_stra->IsEqual(candidate_next_op_stra)) {
      continue;
    }
    const CostPtr &cost = cost_list[0];
    MS_EXCEPTION_IF_NULL(cost);
    if (cost->communication_cost_ == 0.0) {
      (void)prev_op_stras.emplace_back(key_value.first.first, cost->computation_cost_);
    }
  }
  if (prev_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (prev_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with minimum "
                    "computation costs.";
  }
  // Only the minimum is needed. A strict '<' keeps the comparator a valid ordering; a '<=' comparator violates
  // the strict-weak-ordering requirement of the standard algorithms.
  const auto best = std::min_element(
    prev_op_stras.begin(), prev_op_stras.end(),
    [](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      return a.second < b.second;
    });
  return best->first;
}
396
SetCostMapAndInputOutput(std::map<CostPtrKey,CostPtrList> & cost_map)397 void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
398 cost_map_ = cost_map;
399 pre_op_output_.clear();
400 next_op_input_.clear();
401
402 for (auto &key_value : cost_map_) {
403 auto &key_pair = key_value.first;
404 pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
405 next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
406 }
407 }
408
409 // Return true if there are available strategies in this edge.
CheckStrategyCostPossibility() const410 bool Edge::CheckStrategyCostPossibility() const { return !cost_map_.empty(); }
411 } // namespace parallel
412 } // namespace mindspore
413