• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <utility>
23 #include "frontend/parallel/auto_parallel/costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26 
27 namespace mindspore {
28 namespace parallel {
InitEdgeCost()29 Status Edge::InitEdgeCost() {
30   bool has_available_cost = false;
31   pre_op_output_.clear();
32   next_op_input_.clear();
33   cost_map_.clear();
34 
35   for (auto &swc : prev_op_->GetStrategyCost()) {
36     MS_EXCEPTION_IF_NULL(swc);
37     pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
38   }
39   for (auto &swc : next_op_->GetStrategyCost()) {
40     MS_EXCEPTION_IF_NULL(swc);
41     next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
42   }
43   if (is_identity_edge) {
44     for (auto &target_output : pre_op_output_) {
45       auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
46       auto target_output_str = target_output.first;
47       for (auto &target_input : next_op_input_) {
48         auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
49         auto target_input_str = target_input.first;
50         if (target_output_lyt == target_input_lyt) {
51           CostPtrKey ck = {target_output_str, target_input_str};
52           CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
53           MS_EXCEPTION_IF_NULL(cost);
54           cost->communication_without_parameter_ = 0.0;
55           cost->communication_with_partial_para_ = 0.0;
56           CostPtrList cl;
57           cl.push_back(cost);
58           (void)cost_map_.emplace(std::make_pair(ck, cl));
59           has_available_cost = true;
60         }
61       }
62     }
63   } else {
64     for (auto &target_output : pre_op_output_) {
65       auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
66       auto target_output_str = target_output.first;
67       auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
68       auto type = prev_op_->outputs_type()[prev_op_output_index_];
69       for (auto &target_input : next_op_input_) {
70         auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
71         auto target_input_str = target_input.first;
72         CostPtr cost;
73         if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
74           MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
75         }
76         MS_EXCEPTION_IF_NULL(cost);
77         MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
78                       << ", communication_cost: " << cost->communication_cost_
79                       << ", communication_without_parameter_: " << cost->communication_without_parameter_
80                       << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
81         // refine communication cost calculation for practice
82         RefineForPracticalCost(cost, true);
83         cost->communication_forward_ = cost->communication_redis_forward_;
84         CostPtrKey ck = {target_output_str, target_input_str};
85         CostPtrList cl;
86         cl.push_back(cost);
87         (void)cost_map_.emplace(std::make_pair(ck, cl));
88         has_available_cost = true;
89       }
90     }
91   }
92   if (!has_available_cost) {
93     const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
94     const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
95     if (fully_use) {
96       MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
97                         << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
98                            "'fully_use_devices' false.";
99     } else if (stra_follow) {
100       MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
101                         << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
102                            "Try to set 'elementwise_op_strategy_follow' false.";
103     }
104     if (edge_name_.find(RESHAPE) != std::string::npos) {
105       MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_
106                         << " failed, it may be caused by setting different strategies for operators following Reshape. "
107                            "Try to fix that.";
108     }
109     MS_LOG(EXCEPTION) << "Generating cost for edge: " << edge_name_ << " failed.";
110   }
111   return Status::SUCCESS;
112 }
113 
GetRedistributionCost(const TensorLayout & prev_op_output_layout,const TensorLayout & next_op_input_layout,size_t type_length,const TypePtr & type,CostPtr * cost)114 Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
115                                    size_t type_length, const TypePtr &type, CostPtr *cost) {
116   MS_EXCEPTION_IF_NULL(prev_op_);
117   MS_EXCEPTION_IF_NULL(cost);
118   RankList dev_list = prev_op_->stage_device_list();
119   TensorRedistribution tensor_redistribution(false);
120 
121   // Init TensorRedistribution
122   if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
123     MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
124   }
125 
126   if (tensor_redistribution.ComputeCost() == FAILED) {
127     MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
128   }
129 
130   double comm_cost = tensor_redistribution.comm_cost();
131   double forward_comm_cost = tensor_redistribution.forward_comm_cost();
132   double backward_comm_cost = tensor_redistribution.backward_comm_cost();
133   double computation_cost = tensor_redistribution.computation_cost();
134   double mem_cost = tensor_redistribution.memory_cost();
135   const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
136 
137   // Now AllGather, ReduceScatter, AlltoAll don't support bool type
138   MS_EXCEPTION_IF_NULL(type);
139   if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
140     computation_cost = INF;
141     comm_cost = INF;
142     MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
143   }
144   *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
145   (*cost)->communication_without_parameter_ = type_length * comm_cost;
146   (*cost)->communication_with_partial_para_ =
147     (*cost)->communication_without_parameter_ +
148     gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
149   (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
150   (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
151   (*cost)->memory_with_reuse_ = mem_cost;
152   return Status::SUCCESS;
153 }
154 
GetCostList(StrategyPtr output_str,StrategyPtr input_str)155 CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
156   CostPtrKey ck = {output_str, input_str};
157   CostPtrList result;
158   if (cost_map_.find(ck) != cost_map_.end()) {
159     return cost_map_.at(ck);
160   }
161   return result;
162 }
163 
CreateEdgeEliminationCostList(const StrategyPtr & output_st_ptr,const std::vector<EdgePtr> & edges,const StrategyPtr & input_st_ptr)164 CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
165                                                 const StrategyPtr &input_st_ptr) {
166   std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
167     MS_EXCEPTION_IF_NULL(edge);
168     return edge->GetCostList(output_st_ptr, input_st_ptr);
169   };
170   CostPtrList result;
171   std::vector<CostPtrList> all_cost_list;
172   all_cost_list.resize(edges.size());
173   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
174 
175   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
176   std::function<void(size_t, double, double, double, double, double)> recursive =
177     [&](size_t k, double computation, double memory, double communication, double communication_without_para,
178         double communication_forward) {
179       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
180       if (k == edges.size()) {
181         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
182         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
183         MS_EXCEPTION_IF_NULL(new_cost);
184         new_cost->communication_without_parameter_ = communication_without_para;
185         new_cost->communication_with_partial_para_ =
186           communication_without_para + gamma * (communication - communication_without_para);
187         new_cost->memory_with_reuse_ = memory;
188         new_cost->communication_forward_ = communication_forward;
189         new_cost->decision_ptr_ = decision;
190         result.push_back(new_cost);
191         return;
192       }
193       for (auto &c : all_cost_list[k]) {
194         MS_EXCEPTION_IF_NULL(c);
195         selected_cost_list[k] = c;
196         recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
197                   communication + c->communication_cost_,
198                   communication_without_para + c->communication_without_parameter_,
199                   communication_forward + c->communication_forward_);
200       }
201     };
202   recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
203   Simplify(&result);
204   return result;
205 }
206 
EdgeEliminationSetNewCost(OperatorInfoPtr,const std::vector<EdgePtr> & edges,OperatorInfoPtr)207 void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
208   bool valid = false;
209   for (const auto &output_pair : pre_op_output_) {
210     StrategyPtr output_st_ptr = output_pair.first;
211     for (const auto &input_pair : next_op_input_) {
212       StrategyPtr input_st_ptr = input_pair.first;
213       CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
214       CostPtrKey key = {output_st_ptr, input_st_ptr};
215       cost_map_[key] = clist;
216       if ((!valid) && (!clist.empty())) {
217         valid = true;
218       }
219     }
220   }
221   if (!valid) {
222     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
223   }
224 }
225 
CreateOpEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & left_cost_list,const CostPtrList & middle_cost_list,const CostPtrList & right_cost_list,CostPtrList * ret_cost_list)226 void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
227                                           const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
228                                           CostPtrList *ret_cost_list) {
229   for (auto &left_cost : left_cost_list) {
230     MS_EXCEPTION_IF_NULL(left_cost);
231     for (auto &middle_cost : middle_cost_list) {
232       MS_EXCEPTION_IF_NULL(middle_cost);
233       for (auto &right_cost : right_cost_list) {
234         MS_EXCEPTION_IF_NULL(right_cost);
235         double computation =
236           left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
237         double communication =
238           left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
239         double communication_forward =
240           left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
241         double communication_without_para = left_cost->communication_without_parameter_ +
242                                             middle_cost->communication_without_parameter_ +
243                                             right_cost->communication_without_parameter_;
244         double memory_cost =
245           left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
246 
247         auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
248         auto cost = std::make_shared<Cost>(computation, communication, decision);
249         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
250         MS_EXCEPTION_IF_NULL(cost);
251         cost->communication_without_parameter_ = communication_without_para;
252         cost->communication_with_partial_para_ =
253           communication_without_para + gamma * (communication - communication_without_para);
254         cost->memory_with_reuse_ = memory_cost;
255         cost->communication_forward_ = communication_forward;
256         ret_cost_list->emplace_back(std::move(cost));
257       }
258     }
259   }
260 }
261 
CreateOpEliminationCostList(const EdgePtr & e1,const StrategyPtr & output_st_ptr,const OperatorInfoPtr & op,const EdgePtr & e2,const StrategyPtr & input_st_ptr)262 CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
263                                               const OperatorInfoPtr &op, const EdgePtr &e2,
264                                               const StrategyPtr &input_st_ptr) {
265   MS_EXCEPTION_IF_NULL(op);
266   MS_EXCEPTION_IF_NULL(e1);
267   MS_EXCEPTION_IF_NULL(e2);
268   CostPtrList result;
269   for (const auto &op_strategy : op->GetStrategyCost()) {
270     MS_EXCEPTION_IF_NULL(op_strategy);
271     auto middle_strategy = op_strategy->strategy_ptr;
272     CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
273                                    op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
274   }
275   Simplify(&result);
276   return result;
277 }
278 
OpEliminationSetNewCost(const EdgePtr & e1,const OperatorInfoPtr & op,const EdgePtr & e2)279 void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
280   bool valid = false;
281   for (const auto &output_pair : pre_op_output_) {
282     StrategyPtr output_st_ptr = output_pair.first;
283     for (const auto &input_pair : next_op_input_) {
284       StrategyPtr input_st_ptr = input_pair.first;
285 
286       CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
287       CostPtrKey key = {output_st_ptr, input_st_ptr};
288       cost_map_[key] = clist;
289       if ((!valid) && (!clist.empty())) {
290         valid = true;
291       }
292     }
293   }
294   if (!valid) {
295     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
296   }
297 }
298 
CalculateMemoryCost()299 Status Edge::CalculateMemoryCost() {
300   if (is_output_parameter_involve_ == -1) {
301     MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
302     return FAILED;
303   }
304   if (is_output_parameter_involve_ == 0) {
305     // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
306     // unnecessary to keep them in memory.
307     for (auto &cost_kv : cost_map_) {
308       auto &cost_v = cost_kv.second;
309       if (!cost_v.empty()) {
310         cost_v[0]->memory_with_reuse_ = 0;
311       }
312     }
313   }
314 
315   return SUCCESS;
316 }
317 
CalculateMemoryCostForInference()318 Status Edge::CalculateMemoryCostForInference() {
319   // Currently, memory cost is NOT calculated for redistribution
320   if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
321     MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
322     return FAILED;
323   }
324   for (auto &cost_kv : cost_map_) {
325     auto &cost_v = cost_kv.second;
326     if (!cost_v.empty()) {
327       cost_v[0]->memory_with_reuse_ = 0;
328     }
329   }
330   return SUCCESS;
331 }
332 
GetCostByStrategyPair(const CostPtrKey & stra_pair)333 CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
334   if (cost_map_.find(stra_pair) == cost_map_.end()) {
335     return nullptr;
336   }
337   auto cost_vec = cost_map_[stra_pair];
338   if (cost_vec.empty()) {
339     PrintStrategy(stra_pair.first);
340     PrintStrategy(stra_pair.second);
341     MS_LOG(EXCEPTION) << "No available cost under current strategy pair of the edge: " << edge_name_;
342   }
343   if (cost_vec.size() > 1) {
344     PrintStrategy(stra_pair.first);
345     PrintStrategy(stra_pair.second);
346     MS_LOG(INFO) << "Multiple costs available under the stratey pair of the edge: " << edge_name_;
347   }
348   return cost_vec[0];
349 }
350 
// Among the next-op strategies that pair with `prev_op_stra` (compared via Strategy::IsEqual) at zero communication
// cost on this edge, returns the one with the smallest computation cost.
//
// Only the first cost of each cost_map_ entry is consulted, and entries with an empty cost list are skipped. Ties on
// computation cost resolve to the first candidate in cost_map_ iteration order. Returns nullptr (and logs an error)
// when no zero-communication pairing exists.
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithZeroComm(const StrategyPtr &prev_op_stra) {
  MS_EXCEPTION_IF_NULL(prev_op_stra);
  std::vector<std::pair<StrategyPtr, double>> next_op_stras;
  for (auto &key_value : cost_map_) {
    const auto &cost_list = key_value.second;
    // Edge/op elimination can store empty cost lists; skip them instead of indexing past the end.
    if (cost_list.empty()) {
      continue;
    }
    const auto &candidate_prev_op_stra = key_value.first.first;
    if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (cost_list[0]->communication_cost_ == 0.0)) {
      (void)next_op_stras.emplace_back(key_value.first.second, cost_list[0]->computation_cost_);
    }
  }
  if (next_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (next_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with"
                    " minimum computation costs.";
  }
  // Only the minimum is needed. The comparator must be a strict `<`: a `<=` comparator violates strict weak
  // ordering, which is undefined behaviour for the standard algorithms when equal costs occur.
  const auto best = std::min_element(
    next_op_stras.begin(), next_op_stras.end(),
    [](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) { return a.second < b.second; });
  return best->first;
}
373 
// Among the prev-op strategies that pair with `next_op_stra` (compared via Strategy::IsEqual) at zero communication
// cost on this edge, returns the one with the smallest computation cost.
//
// Only the first cost of each cost_map_ entry is consulted, and entries with an empty cost list are skipped. Ties on
// computation cost resolve to the first candidate in cost_map_ iteration order. Returns nullptr (and logs an error)
// when no zero-communication pairing exists.
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithZeroComm(const StrategyPtr &next_op_stra) {
  MS_EXCEPTION_IF_NULL(next_op_stra);
  std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
  for (auto &key_value : cost_map_) {
    const auto &cost_list = key_value.second;
    // Edge/op elimination can store empty cost lists; skip them instead of indexing past the end.
    if (cost_list.empty()) {
      continue;
    }
    const auto &candidate_next_op_stra = key_value.first.second;
    if (next_op_stra->IsEqual(candidate_next_op_stra) && (cost_list[0]->communication_cost_ == 0.0)) {
      (void)prev_op_stras.emplace_back(key_value.first.first, cost_list[0]->computation_cost_);
    }
  }
  if (prev_op_stras.empty()) {
    MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
    return nullptr;
  } else if (prev_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << ", choose the one with minimum "
                    "computation costs.";
  }
  // Only the minimum is needed. The comparator must be a strict `<`: a `<=` comparator violates strict weak
  // ordering, which is undefined behaviour for the standard algorithms when equal costs occur.
  const auto best = std::min_element(
    prev_op_stras.begin(), prev_op_stras.end(),
    [](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) { return a.second < b.second; });
  return best->first;
}
396 
SetCostMapAndInputOutput(std::map<CostPtrKey,CostPtrList> & cost_map)397 void Edge::SetCostMapAndInputOutput(std::map<CostPtrKey, CostPtrList> &cost_map) {
398   cost_map_ = cost_map;
399   pre_op_output_.clear();
400   next_op_input_.clear();
401 
402   for (auto &key_value : cost_map_) {
403     auto &key_pair = key_value.first;
404     pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
405     next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
406   }
407 }
408 
409 // Return true if there are available strategies in this edge.
CheckStrategyCostPossibility() const410 bool Edge::CheckStrategyCostPossibility() const { return !cost_map_.empty(); }
411 }  // namespace parallel
412 }  // namespace mindspore
413