• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
18 
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <utility>
23 #include "frontend/parallel/auto_parallel/costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26 #include "frontend/parallel/ops_info/reshape_info.h"
27 
28 namespace mindspore {
29 namespace parallel {
// Initializes this edge's cost model: caches the (strategy, tensor-info) pairs of
// prev_op_'s outputs and next_op_'s inputs, then fills cost_map_ with one cost list
// per feasible (prev-strategy, next-strategy) pair.
// - For identity edges, only layout-compatible pairs are feasible and cost zero.
// - Otherwise, every pair gets a tensor-redistribution cost between the two layouts.
// Returns FAILED (with diagnostics about likely configuration causes) when no
// strategy pair produced a cost; SUCCESS otherwise.
Status Edge::InitEdgeCost() {
  bool has_available_cost = false;
  // Rebuild all cached data from scratch so repeated calls stay consistent.
  pre_op_output_.clear();
  next_op_input_.clear();
  cost_map_.clear();

  // Cache each strategy of the preceding operator with its output tensor infos.
  for (auto &swc : prev_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    (void)pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
  }
  // Cache each strategy of the following operator with its input tensor infos.
  for (auto &swc : next_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    (void)next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
  }
  if (is_identity_edge) {
    // Identity edge: a (prev, next) strategy pair is feasible only when the producing
    // layout matches the consuming layout; feasible pairs carry zero cost.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        // for identity_info ops, no need to compare device_matrix
        if ((target_output_lyt == target_input_lyt) || (target_output_lyt.IsSameWithoutSplit(target_input_lyt) &&
                                                        edge_name().find(IDENTITY_INFO) != std::string::npos)) {
          CostPtrKey ck = {target_output_str, target_input_str};
          CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
          MS_EXCEPTION_IF_NULL(cost);
          cost->communication_without_parameter_ = 0.0;
          cost->communication_with_partial_para_ = 0.0;
          CostPtrList cl;
          cl.push_back(cost);
          (void)cost_map_.emplace(std::make_pair(ck, cl));
          has_available_cost = true;
        }
      }
    }
  } else {
    // General edge: every (prev-strategy, next-strategy) pair gets the cost of
    // redistributing prev_op_'s output layout into next_op_'s input layout.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
      auto type = prev_op_->outputs_type()[prev_op_output_index_];
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        CostPtr cost;
        if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
        }
        MS_EXCEPTION_IF_NULL(cost);
        MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                      << ", communication_cost: " << cost->communication_cost_
                      << ", communication_without_parameter_: " << cost->communication_without_parameter_
                      << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
        // refine communication cost calculation for practice
        RefineForPracticalCost(cost, true);
        cost->communication_forward_ = cost->communication_redis_forward_;
        CostPtrKey ck = {target_output_str, target_input_str};
        CostPtrList cl;
        cl.push_back(cost);
        (void)cost_map_.emplace(std::make_pair(ck, cl));
        has_available_cost = true;
      }
    }
  }
  if (!has_available_cost) {
    // No feasible pair: point the user at the configuration options that
    // commonly cause an empty cost map before reporting failure.
    const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
    const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
    if (fully_use) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
                       "'fully_use_devices' false.";
    } else if (stra_follow) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
                       "Try to set 'elementwise_op_strategy_follow' false.";
    }
    if (edge_name_.find(RESHAPE) != std::string::npos) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting different strategies for operators following Reshape. "
                       "Try to fix that.";
    }
    MS_LOG(INFO) << "Generating cost for edge: " << edge_name_ << " failed.";
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
117 
// Computes the cost of redistributing a tensor from prev_op_output_layout to
// next_op_input_layout over prev_op_'s stage device list, writing a freshly
// allocated Cost into *cost.
// - type_length (bytes per element) scales both computation and communication costs.
// - Bool tensors that require any communication get INF costs, because the
//   communication operators (AllGather/ReduceScatter/AlltoAll) do not support bool.
// Note: init/compute failures raise via MS_LOG(EXCEPTION); this function never
// returns FAILED.
Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
                                   size_t type_length, const TypePtr &type, CostPtr *cost) {
  MS_EXCEPTION_IF_NULL(prev_op_);
  MS_EXCEPTION_IF_NULL(cost);
  RankList dev_list = prev_op_->stage_device_list();
  TensorRedistribution tensor_redistribution(false);

  // Init TensorRedistribution
  if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }

  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }

  double comm_cost = tensor_redistribution.comm_cost();
  double forward_comm_cost = tensor_redistribution.forward_comm_cost();
  double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  double computation_cost = tensor_redistribution.computation_cost();
  double mem_cost = tensor_redistribution.memory_cost();
  // gamma weighs how much of the total communication counts beyond the
  // non-parameter part (see communication_with_partial_para_ below).
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();

  // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  MS_EXCEPTION_IF_NULL(type);
  if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
    computation_cost = INF;
    comm_cost = INF;
    MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
  }
  *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
  (*cost)->communication_without_parameter_ = type_length * comm_cost;
  (*cost)->communication_with_partial_para_ =
    (*cost)->communication_without_parameter_ +
    gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
  (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
  (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
  (*cost)->memory_with_reuse_ = mem_cost;
  return Status::SUCCESS;
}
158 
GetCostList(StrategyPtr output_str,StrategyPtr input_str)159 CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
160   CostPtrKey ck = {output_str, input_str};
161   CostPtrList result;
162   if (cost_map_.find(ck) != cost_map_.end()) {
163     return cost_map_.at(ck);
164   }
165   return result;
166 }
167 
CreateEdgeEliminationCostList(const StrategyPtr & output_st_ptr,const std::vector<EdgePtr> & edges,const StrategyPtr & input_st_ptr) const168 CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
169                                                 const StrategyPtr &input_st_ptr) const {
170   std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
171     MS_EXCEPTION_IF_NULL(edge);
172     return edge->GetCostList(output_st_ptr, input_st_ptr);
173   };
174   CostPtrList result;
175   std::vector<CostPtrList> all_cost_list;
176   all_cost_list.resize(edges.size());
177   (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
178 
179   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
180   std::function<void(size_t, double, double, double, double, double)> recursive =
181     [&](size_t k, double computation, double memory, double communication, double communication_without_para,
182         double communication_forward) {
183       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
184       if (k == edges.size()) {
185         auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
186         CostPtr new_cost = std::make_shared<Cost>(computation, communication);
187         MS_EXCEPTION_IF_NULL(new_cost);
188         new_cost->communication_without_parameter_ = communication_without_para;
189         new_cost->communication_with_partial_para_ =
190           communication_without_para + gamma * (communication - communication_without_para);
191         new_cost->memory_with_reuse_ = memory;
192         new_cost->communication_forward_ = communication_forward;
193         new_cost->decision_ptr_ = decision;
194         result.push_back(new_cost);
195         return;
196       }
197       for (auto &c : all_cost_list[k]) {
198         MS_EXCEPTION_IF_NULL(c);
199         selected_cost_list[k] = c;
200         recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
201                   communication + c->communication_cost_,
202                   communication_without_para + c->communication_without_parameter_,
203                   communication_forward + c->communication_forward_);
204       }
205     };
206   recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
207   Simplify(&result);
208   return result;
209 }
210 
EdgeEliminationSetNewCost(OperatorInfoPtr,const std::vector<EdgePtr> & edges,OperatorInfoPtr)211 void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
212   bool valid = false;
213   for (const auto &output_pair : pre_op_output_) {
214     StrategyPtr output_st_ptr = output_pair.first;
215     for (const auto &input_pair : next_op_input_) {
216       StrategyPtr input_st_ptr = input_pair.first;
217       CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
218       CostPtrKey key = {output_st_ptr, input_st_ptr};
219       cost_map_[key] = clist;
220       if ((!valid) && (!clist.empty())) {
221         valid = true;
222       }
223     }
224   }
225   if (!valid) {
226     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
227   }
228 }
229 
CreateOpEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & left_cost_list,const CostPtrList & middle_cost_list,const CostPtrList & right_cost_list,CostPtrList * ret_cost_list) const230 void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
231                                           const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
232                                           CostPtrList *ret_cost_list) const {
233   for (auto &left_cost : left_cost_list) {
234     MS_EXCEPTION_IF_NULL(left_cost);
235     for (auto &middle_cost : middle_cost_list) {
236       MS_EXCEPTION_IF_NULL(middle_cost);
237       for (auto &right_cost : right_cost_list) {
238         MS_EXCEPTION_IF_NULL(right_cost);
239         double computation =
240           left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
241         double communication =
242           left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
243         double communication_forward =
244           left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
245         double communication_without_para = left_cost->communication_without_parameter_ +
246                                             middle_cost->communication_without_parameter_ +
247                                             right_cost->communication_without_parameter_;
248         double memory_cost =
249           left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
250 
251         auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
252         auto cost = std::make_shared<Cost>(computation, communication, decision);
253         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
254         MS_EXCEPTION_IF_NULL(cost);
255         cost->communication_without_parameter_ = communication_without_para;
256         cost->communication_with_partial_para_ =
257           communication_without_para + gamma * (communication - communication_without_para);
258         cost->memory_with_reuse_ = memory_cost;
259         cost->communication_forward_ = communication_forward;
260         (void)ret_cost_list->emplace_back(std::move(cost));
261       }
262     }
263   }
264 }
265 
CreateOpEliminationCostList(const EdgePtr & e1,const StrategyPtr & output_st_ptr,const OperatorInfoPtr & op,const EdgePtr & e2,const StrategyPtr & input_st_ptr) const266 CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
267                                               const OperatorInfoPtr &op, const EdgePtr &e2,
268                                               const StrategyPtr &input_st_ptr) const {
269   MS_EXCEPTION_IF_NULL(op);
270   MS_EXCEPTION_IF_NULL(e1);
271   MS_EXCEPTION_IF_NULL(e2);
272   CostPtrList result;
273   for (const auto &op_strategy : op->GetStrategyCost()) {
274     MS_EXCEPTION_IF_NULL(op_strategy);
275     auto middle_strategy = op_strategy->strategy_ptr;
276     CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
277                                    op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
278   }
279   Simplify(&result);
280   return result;
281 }
282 
OpEliminationSetNewCost(const EdgePtr & e1,const OperatorInfoPtr & op,const EdgePtr & e2)283 void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
284   bool valid = false;
285   for (const auto &output_pair : pre_op_output_) {
286     StrategyPtr output_st_ptr = output_pair.first;
287     for (const auto &input_pair : next_op_input_) {
288       StrategyPtr input_st_ptr = input_pair.first;
289 
290       CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
291       CostPtrKey key = {output_st_ptr, input_st_ptr};
292       cost_map_[key] = clist;
293       if ((!valid) && (!clist.empty())) {
294         valid = true;
295       }
296     }
297   }
298   if (!valid) {
299     MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
300   }
301 }
302 
CalculateMemoryCost()303 Status Edge::CalculateMemoryCost() {
304   if (is_output_parameter_involve_ == -1) {
305     MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
306     return FAILED;
307   }
308   if (is_output_parameter_involve_ == 0) {
309     // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
310     // unnecessary to keep them in memory.
311     for (auto &cost_kv : cost_map_) {
312       auto &cost_v = cost_kv.second;
313       if (!cost_v.empty()) {
314         cost_v[0]->memory_with_reuse_ = 0;
315       }
316     }
317   }
318 
319   return SUCCESS;
320 }
321 
CalculateMemoryCostForInference()322 Status Edge::CalculateMemoryCostForInference() {
323   // Currently, memory cost is NOT calculated for redistribution
324   if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
325     MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
326     return FAILED;
327   }
328   for (const auto &cost_kv : cost_map_) {
329     auto &cost_v = cost_kv.second;
330     if (!cost_v.empty()) {
331       cost_v[0]->memory_with_reuse_ = 0;
332     }
333   }
334   return SUCCESS;
335 }
336 
GetCostByStrategyPair(const CostPtrKey & stra_pair)337 CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
338   if (cost_map_.find(stra_pair) == cost_map_.end()) {
339     return nullptr;
340   }
341   auto cost_vec = cost_map_[stra_pair];
342   if (cost_vec.empty()) {
343     MS_LOG(EXCEPTION) << "stra_pair.first: " << stra_pair.first->ToString() << ", "
344                       << "stra_pair.second: " << stra_pair.second->ToString() << ". "
345                       << "No available cost under current strategy pair of the edge: " << edge_name_;
346   }
347   if (cost_vec.size() > 1) {
348     MS_LOG(INFO) << "stra_pair.first: " << stra_pair.first->ToString() << ", "
349                  << "stra_pair.second: " << stra_pair.second->ToString() << ". "
350                  << "Multiple costs available under the stratey pair of the edge: " << edge_name_;
351   }
352   return cost_vec[0];
353 }
354 
// Given prev_op_'s chosen strategy, selects next_op_'s strategy along this edge:
// 1) prefer candidates whose edge communication cost is zero, breaking ties by
//    minimum computation cost, then by next_op_'s own
//    communication_without_parameter_, then by Strategy::Compare;
// 2) otherwise fall back to the candidate with minimum communication cost
//    (logging an inconsistency warning).
// Returns nullptr when no candidate matches prev_op_stra at all.
// NOTE(review): indexes key_value.second[0] directly — assumes every cost list
// in cost_map_ is non-empty.
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &prev_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> next_op_stras;
  // First, try to find the strategy with zero communication cost.
  for (const auto &key_value : cost_map_) {
    const auto &candidate_prev_op_stra = key_value.first.first;
    if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (key_value.second[0]->communication_cost_ < EPS)) {
      (void)next_op_stras.emplace_back(key_value.first.second, key_value.second[0]->computation_cost_);
    }
  }
  if (next_op_stras.empty()) {
    // Second, if there is not strategy with zero communication cost, find the one with minimum communication cost.
    std::vector<std::pair<StrategyPtr, double>> next_stras;
    for (auto &key_value : cost_map_) {
      const auto &candidate_prev_op_stra = key_value.first.first;
      if (prev_op_stra->IsEqual(candidate_prev_op_stra)) {
        (void)next_stras.emplace_back(key_value.first.second, key_value.second[0]->communication_cost_);
      }
    }
    if (next_stras.empty()) {
      MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
      return nullptr;
    }
    MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
    // Tie-break equal communication costs deterministically via Strategy::Compare.
    auto min_stra =
      std::min_element(next_stras.begin(), next_stras.end(),
                       [this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
                         return !IsDoubleEqual(a.second, b.second) ? a.second < b.second : a.first->Compare(b.first);
                       });
    return min_stra->first;
  }
  if (next_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << " with zero communication cost, choose the one with minimum computation costs.";
  }
  auto next_op = next_op_;
  // Among zero-communication candidates: minimum computation cost, then the
  // operator's own communication_without_parameter_, then Strategy::Compare.
  auto min_next_op_stra = std::min_element(
    next_op_stras.begin(), next_op_stras.end(),
    [this, &next_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      if (!IsDoubleEqual(a.second, b.second)) {
        return a.second < b.second;
      }
      auto cost_a = next_op->GetCostByStrategyPtr(a.first)[0]->communication_without_parameter_;
      auto cost_b = next_op->GetCostByStrategyPtr(b.first)[0]->communication_without_parameter_;
      if (!IsDoubleEqual(cost_a, cost_b)) {
        return cost_a < cost_b;
      }
      return a.first->Compare(b.first);
    });
  return min_next_op_stra->first;
}
405 
// Mirror of GetNextOpStrategyByPrevOpStrategyWithMiniComm: given next_op_'s
// chosen strategy, selects prev_op_'s strategy along this edge:
// 1) prefer candidates whose edge communication cost is zero, breaking ties by
//    minimum computation cost, then by prev_op_'s own
//    communication_without_parameter_, then by Strategy::Compare;
// 2) otherwise fall back to the candidate with minimum communication cost
//    (logging an inconsistency warning).
// Returns nullptr when no candidate matches next_op_stra at all.
// NOTE(review): indexes key_value.second[0] directly — assumes every cost list
// in cost_map_ is non-empty.
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &next_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
  // First, try to find the strategy with zero communication cost.
  for (const auto &key_value : cost_map_) {
    const auto &candidate_next_op_stra = key_value.first.second;
    if (next_op_stra->IsEqual(candidate_next_op_stra) && (key_value.second[0]->communication_cost_ < EPS)) {
      (void)prev_op_stras.emplace_back(key_value.first.first, key_value.second[0]->computation_cost_);
    }
  }
  if (prev_op_stras.empty()) {
    // Second, if there is no strategy with zero communication cost, find the one with minimum communication cost.
    std::vector<std::pair<StrategyPtr, double>> prev_stras;
    for (auto &key_value : cost_map_) {
      const auto &candidate_next_op_stra = key_value.first.second;
      if (next_op_stra->IsEqual(candidate_next_op_stra)) {
        (void)prev_stras.emplace_back(key_value.first.first, key_value.second[0]->communication_cost_);
      }
    }
    if (prev_stras.empty()) {
      MS_LOG(ERROR) << "There are no available strategy for zero communication cost for edge: " << edge_name_;
      return nullptr;
    }
    MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
    // Tie-break equal communication costs deterministically via Strategy::Compare.
    auto min_prev_stra =
      std::min_element(prev_stras.begin(), prev_stras.end(),
                       [this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
                         return !IsDoubleEqual(a.second, b.second) ? a.second < b.second : a.first->Compare(b.first);
                       });
    return min_prev_stra->first;
  }
  if (prev_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << " with zero communication costs, choose the one with minimum computation costs.";
  }
  auto prev_op = prev_op_;
  // Among zero-communication candidates: minimum computation cost, then the
  // operator's own communication_without_parameter_, then Strategy::Compare.
  auto min_prev_op_stra = std::min_element(
    prev_op_stras.begin(), prev_op_stras.end(),
    [this, &prev_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      if (!IsDoubleEqual(a.second, b.second)) {
        return a.second < b.second;
      }
      auto cost_a = prev_op->GetCostByStrategyPtr(a.first)[0]->communication_without_parameter_;
      auto cost_b = prev_op->GetCostByStrategyPtr(b.first)[0]->communication_without_parameter_;
      if (!IsDoubleEqual(cost_a, cost_b)) {
        return cost_a < cost_b;
      }
      return a.first->Compare(b.first);
    });
  return min_prev_op_stra->first;
}
456 
GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr & next_op_stra)457 int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra) {
458   if (!prev_op_->IsReshape()) {
459     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
460   }
461   if (next_op_->IsReshape()) {
462     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
463   }
464   const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
465   MS_LOG(INFO) << prev_op_->name() << "'s output layout: " << reshape_output_layout.ToString();
466   auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
467   // First, try to find the zero communication strategy.
468   auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
469   if (swc_index == -1) {
470     // Second, if there is no strategy with zero communication cost, find the strategy with minimum cost.
471     swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithMiniComm(reshape_output_layout);
472     if (swc_index != -1) {
473       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
474     }
475   }
476   if (swc_index == -1) {
477     MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
478   }
479   return swc_index;
480 }
481 
GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr & prev_op_stra)482 int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra) {
483   if (!next_op_->IsReshape()) {
484     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
485   }
486   if (prev_op_->IsReshape()) {
487     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
488   }
489   const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
490   MS_LOG(INFO) << next_op_->name() << "'s input layout: " << reshape_input_lyt.ToString();
491   auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
492   // First, try to find the zero communication strategy.
493   auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
494   if (swc_index == -1) {
495     // Second, if there is no zero communication strategy, find the strategy with minimum cost.
496     swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithMiniComm(reshape_input_lyt);
497     if (swc_index != -1) {
498       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
499     }
500   }
501   if (swc_index == -1) {
502     MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
503   }
504   return swc_index;
505 }
506 
GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index)507 StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
508   if (!next_op_->IsReshape()) {
509     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
510   }
511   if (prev_op_->IsReshape()) {
512     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
513   }
514   auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
515   const auto &reshape_input_lyt = reshape_ptr->GetInputLayoutBySWCIndex(swc_index);
516   auto stra = prev_op_->GetStrategyFromSWCByOutputLayout(reshape_input_lyt, prev_op_output_index_);
517   if (stra == nullptr) {
518     MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
519   }
520   return stra;
521 }
522 
GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index)523 StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
524   if (!prev_op_->IsReshape()) {
525     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
526   }
527   if (next_op_->IsReshape()) {
528     MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
529   }
530   auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
531   const auto &reshape_output_lyt = reshape_ptr->GetOutputLayoutBySWCIndex(swc_index);
532   auto stra = next_op_->GetStrategyFromSWCByInputLayout(reshape_output_lyt, next_op_input_index_);
533   if (stra == nullptr) {
534     MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
535   }
536   return stra;
537 }
538 
CheckStrategyConsistency(StrategyPtr prev_stra,StrategyPtr next_stra,std::set<OperatorInfoPtr> * _diff_stra_params)539 bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra,
540                                     std::set<OperatorInfoPtr> *_diff_stra_params) {
541   if (prev_stra == nullptr) {
542     MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
543   }
544   if (next_stra == nullptr) {
545     MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
546   }
547   auto cost = GetCostByStrategyPair({prev_stra, next_stra});
548   if (cost == nullptr || cost->communication_cost_ > 0.0) {
549     MS_LOG(INFO) << "The edge " << edge_name_ << "'s strategy: prev_stra is " << prev_stra->ToString()
550                  << ", next_stra is " << next_stra->ToString();
551     if (prev_op_->IsTmpIdentity()) {
552       if (_diff_stra_params->count(prev_op_) == 0) {
553         _diff_stra_params->insert(prev_op_);
554       }
555       MS_LOG(INFO) << "The parameter: " << prev_op_->refkey_parameter_name()
556                    << " has been used by operators with "
557                       "different sharding strategies. These operators are: ";
558       auto const &succ_edges = prev_op_->succ_edges();
559       for (auto const &succ_edge : succ_edges) {
560         if (succ_edge->next_operator()->cnodes().empty()) {
561           MS_LOG(INFO) << "No CNODE info has been set in operator: " << succ_edge->next_operator()->name();
562         }
563         MS_LOG(INFO) << succ_edge->next_operator()->name() << ", the corresponding fullname is: "
564                      << succ_edge->next_operator()->cnodes()[0]->fullname_with_scope();
565       }
566       MS_LOG(INFO) << "Configure these operators with consistent sharding strategies.";
567     }
568     MS_LOG(WARNING) << "There are redistribution cost occurs at edge: " << edge_name() << ".";
569     return false;
570   }
571   return true;
572 }
573 
SetCostMapAndInputOutput(const std::map<CostPtrKey,CostPtrList> & cost_map)574 void Edge::SetCostMapAndInputOutput(const std::map<CostPtrKey, CostPtrList> &cost_map) {
575   cost_map_ = cost_map;
576   pre_op_output_.clear();
577   next_op_input_.clear();
578 
579   for (const auto &key_value : cost_map_) {
580     auto &key_pair = key_value.first;
581     (void)pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
582     (void)next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
583   }
584 }
585 
586 // Return true if there are available strategies in this edge.
CheckStrategyCostPossibility() const587 bool Edge::CheckStrategyCostPossibility() const { return !cost_map_.empty(); }
588 }  // namespace parallel
589 }  // namespace mindspore
590