1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
18
19 #include <algorithm>
20 #include <functional>
21 #include <iterator>
22 #include <utility>
23 #include "frontend/parallel/auto_parallel/costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
26 #include "frontend/parallel/ops_info/reshape_info.h"
27
28 namespace mindspore {
29 namespace parallel {
// Builds cost_map_ for this edge: for every (prev-op strategy, next-op strategy)
// pair, computes the cost of redistributing the tensor from the previous
// operator's output layout to the next operator's input layout.
// Returns FAILED when no strategy pair produces a valid cost.
Status Edge::InitEdgeCost() {
  bool has_available_cost = false;
  pre_op_output_.clear();
  next_op_input_.clear();
  cost_map_.clear();

  // Cache (strategy, output tensor infos) for each strategy of the previous operator.
  for (auto &swc : prev_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    (void)pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr));
  }
  // Cache (strategy, input tensor infos) for each strategy of the next operator.
  for (auto &swc : next_op_->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(swc);
    (void)next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr));
  }
  if (is_identity_edge) {
    // Identity edge: a zero cost is registered only for strategy pairs whose
    // layouts are consistent, so inconsistent pairs get no entry at all.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        // for identity_info ops, no need to compare device_matrix
        if ((target_output_lyt == target_input_lyt) || (target_output_lyt.IsSameWithoutSplit(target_input_lyt) &&
                                                        edge_name().find(IDENTITY_INFO) != std::string::npos)) {
          CostPtrKey ck = {target_output_str, target_input_str};
          // Identical layouts need no redistribution: all cost components are zero.
          CostPtr cost = std::make_shared<Cost>(0.0, 0.0);
          MS_EXCEPTION_IF_NULL(cost);
          cost->communication_without_parameter_ = 0.0;
          cost->communication_with_partial_para_ = 0.0;
          CostPtrList cl;
          cl.push_back(cost);
          (void)cost_map_.emplace(std::make_pair(ck, cl));
          has_available_cost = true;
        }
      }
    }
  } else {
    // General edge: compute a real redistribution cost for every strategy pair.
    for (auto &target_output : pre_op_output_) {
      auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout();
      auto target_output_str = target_output.first;
      auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_];
      auto type = prev_op_->outputs_type()[prev_op_output_index_];
      for (auto &target_input : next_op_input_) {
        auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout();
        auto target_input_str = target_input.first;
        CostPtr cost;
        if (GetRedistributionCost(target_output_lyt, target_input_lyt, type_length, type, &cost) != SUCCESS) {
          MS_LOG(EXCEPTION) << "Failure: redistribution cost calculation failed";
        }
        MS_EXCEPTION_IF_NULL(cost);
        MS_LOG(DEBUG) << "The redistribution cost: computation_cost: " << cost->computation_cost_
                      << ", communication_cost: " << cost->communication_cost_
                      << ", communication_without_parameter_: " << cost->communication_without_parameter_
                      << ", communication_with_partial_para_: " << cost->communication_with_partial_para_ << ".";
        // refine communication cost calculation for practice
        RefineForPracticalCost(cost, true);
        cost->communication_forward_ = cost->communication_redis_forward_;
        CostPtrKey ck = {target_output_str, target_input_str};
        CostPtrList cl;
        cl.push_back(cost);
        (void)cost_map_.emplace(std::make_pair(ck, cl));
        has_available_cost = true;
      }
    }
  }
  if (!has_available_cost) {
    // No valid pair: emit hints about the configuration flags that most often
    // cause this before failing.
    const auto fully_use = CostModelContext::GetInstance()->fully_use_device();
    const auto stra_follow = CostModelContext::GetInstance()->elementwise_stra_follow();
    if (fully_use) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting 'fully_use_devices' true. Try to set "
                       "'fully_use_devices' false.";
    } else if (stra_follow) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting 'elementwise_op_strategy_follow' true. "
                       "Try to set 'elementwise_op_strategy_follow' false.";
    }
    if (edge_name_.find(RESHAPE) != std::string::npos) {
      MS_LOG(ERROR) << "Generating cost for edge: " << edge_name_
                    << " failed, it may be caused by setting different strategies for operators following Reshape. "
                       "Try to fix that.";
    }
    MS_LOG(INFO) << "Generating cost for edge: " << edge_name_ << " failed.";
    return Status::FAILED;
  }
  return Status::SUCCESS;
}
117
// Computes the cost of redistributing a tensor from prev_op_output_layout to
// next_op_input_layout on the previous operator's stage devices.
// type_length: byte width of the tensor element type, used to scale the costs.
// On success, *cost holds the filled Cost object.
Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout,
                                   size_t type_length, const TypePtr &type, CostPtr *cost) {
  MS_EXCEPTION_IF_NULL(prev_op_);
  MS_EXCEPTION_IF_NULL(cost);
  RankList dev_list = prev_op_->stage_device_list();
  TensorRedistribution tensor_redistribution(false);

  // Init TensorRedistribution
  if (tensor_redistribution.Init(prev_op_output_layout, next_op_input_layout, dev_list) == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution init failed.";
  }

  if (tensor_redistribution.ComputeCost() == FAILED) {
    MS_LOG(EXCEPTION) << "Failure: tensor_redistribution ComputeCost failed.";
  }

  // Raw (per-element) cost components reported by the redistribution planner.
  double comm_cost = tensor_redistribution.comm_cost();
  double forward_comm_cost = tensor_redistribution.forward_comm_cost();
  double backward_comm_cost = tensor_redistribution.backward_comm_cost();
  double computation_cost = tensor_redistribution.computation_cost();
  double mem_cost = tensor_redistribution.memory_cost();
  const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();

  // Now AllGather, ReduceScatter, AlltoAll don't support bool type
  MS_EXCEPTION_IF_NULL(type);
  if ((type->type_id() == kNumberTypeBool) && (comm_cost > 0)) {
    // Make the strategy pair effectively unusable rather than generating an
    // invalid communication plan for bool tensors.
    computation_cost = INF;
    comm_cost = INF;
    MS_LOG(WARNING) << "Communication Operators don't support bool dtype!";
  }
  // Scale per-element costs by the element byte width.
  *cost = std::make_shared<Cost>(type_length * computation_cost, type_length * comm_cost);
  (*cost)->communication_without_parameter_ = type_length * comm_cost;
  // Blend full and parameter-free communication using the gamma factor.
  (*cost)->communication_with_partial_para_ =
    (*cost)->communication_without_parameter_ +
    gamma * ((*cost)->communication_cost_ - (*cost)->communication_without_parameter_);
  (*cost)->communication_redis_forward_ = type_length * forward_comm_cost;
  (*cost)->communication_redis_backward_ = type_length * backward_comm_cost;
  (*cost)->memory_with_reuse_ = mem_cost;
  return Status::SUCCESS;
}
158
GetCostList(StrategyPtr output_str,StrategyPtr input_str)159 CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) {
160 CostPtrKey ck = {output_str, input_str};
161 CostPtrList result;
162 if (cost_map_.find(ck) != cost_map_.end()) {
163 return cost_map_.at(ck);
164 }
165 return result;
166 }
167
CreateEdgeEliminationCostList(const StrategyPtr & output_st_ptr,const std::vector<EdgePtr> & edges,const StrategyPtr & input_st_ptr) const168 CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector<EdgePtr> &edges,
169 const StrategyPtr &input_st_ptr) const {
170 std::function<CostPtrList(EdgePtr)> LocalGetCostList = [&](const EdgePtr &edge) {
171 MS_EXCEPTION_IF_NULL(edge);
172 return edge->GetCostList(output_st_ptr, input_st_ptr);
173 };
174 CostPtrList result;
175 std::vector<CostPtrList> all_cost_list;
176 all_cost_list.resize(edges.size());
177 (void)std::transform(edges.begin(), edges.end(), all_cost_list.begin(), LocalGetCostList);
178
179 CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
180 std::function<void(size_t, double, double, double, double, double)> recursive =
181 [&](size_t k, double computation, double memory, double communication, double communication_without_para,
182 double communication_forward) {
183 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
184 if (k == edges.size()) {
185 auto decision = std::make_shared<EdgeEliminationDecision>(selected_cost_list);
186 CostPtr new_cost = std::make_shared<Cost>(computation, communication);
187 MS_EXCEPTION_IF_NULL(new_cost);
188 new_cost->communication_without_parameter_ = communication_without_para;
189 new_cost->communication_with_partial_para_ =
190 communication_without_para + gamma * (communication - communication_without_para);
191 new_cost->memory_with_reuse_ = memory;
192 new_cost->communication_forward_ = communication_forward;
193 new_cost->decision_ptr_ = decision;
194 result.push_back(new_cost);
195 return;
196 }
197 for (auto &c : all_cost_list[k]) {
198 MS_EXCEPTION_IF_NULL(c);
199 selected_cost_list[k] = c;
200 recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_,
201 communication + c->communication_cost_,
202 communication_without_para + c->communication_without_parameter_,
203 communication_forward + c->communication_forward_);
204 }
205 };
206 recursive(0, 0.0, 0.0, 0.0, 0.0, 0.0);
207 Simplify(&result);
208 return result;
209 }
210
EdgeEliminationSetNewCost(OperatorInfoPtr,const std::vector<EdgePtr> & edges,OperatorInfoPtr)211 void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector<EdgePtr> &edges, OperatorInfoPtr) {
212 bool valid = false;
213 for (const auto &output_pair : pre_op_output_) {
214 StrategyPtr output_st_ptr = output_pair.first;
215 for (const auto &input_pair : next_op_input_) {
216 StrategyPtr input_st_ptr = input_pair.first;
217 CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr);
218 CostPtrKey key = {output_st_ptr, input_st_ptr};
219 cost_map_[key] = clist;
220 if ((!valid) && (!clist.empty())) {
221 valid = true;
222 }
223 }
224 }
225 if (!valid) {
226 MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
227 }
228 }
229
CreateOpEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & left_cost_list,const CostPtrList & middle_cost_list,const CostPtrList & right_cost_list,CostPtrList * ret_cost_list) const230 void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list,
231 const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list,
232 CostPtrList *ret_cost_list) const {
233 for (auto &left_cost : left_cost_list) {
234 MS_EXCEPTION_IF_NULL(left_cost);
235 for (auto &middle_cost : middle_cost_list) {
236 MS_EXCEPTION_IF_NULL(middle_cost);
237 for (auto &right_cost : right_cost_list) {
238 MS_EXCEPTION_IF_NULL(right_cost);
239 double computation =
240 left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_;
241 double communication =
242 left_cost->communication_cost_ + middle_cost->communication_cost_ + right_cost->communication_cost_;
243 double communication_forward =
244 left_cost->communication_forward_ + middle_cost->communication_forward_ + right_cost->communication_forward_;
245 double communication_without_para = left_cost->communication_without_parameter_ +
246 middle_cost->communication_without_parameter_ +
247 right_cost->communication_without_parameter_;
248 double memory_cost =
249 left_cost->memory_with_reuse_ + middle_cost->memory_with_reuse_ + right_cost->memory_with_reuse_;
250
251 auto decision = std::make_shared<OpEliminationDecision>(op_strategy, left_cost, middle_cost, right_cost);
252 auto cost = std::make_shared<Cost>(computation, communication, decision);
253 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
254 MS_EXCEPTION_IF_NULL(cost);
255 cost->communication_without_parameter_ = communication_without_para;
256 cost->communication_with_partial_para_ =
257 communication_without_para + gamma * (communication - communication_without_para);
258 cost->memory_with_reuse_ = memory_cost;
259 cost->communication_forward_ = communication_forward;
260 (void)ret_cost_list->emplace_back(std::move(cost));
261 }
262 }
263 }
264 }
265
// For operator elimination: builds the cost list of the new edge that replaces
// (e1 -> op -> e2), by enumerating every strategy of the eliminated middle
// operator and combining the corresponding sub-costs, then simplifying.
CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr,
                                              const OperatorInfoPtr &op, const EdgePtr &e2,
                                              const StrategyPtr &input_st_ptr) const {
  MS_EXCEPTION_IF_NULL(op);
  MS_EXCEPTION_IF_NULL(e1);
  MS_EXCEPTION_IF_NULL(e2);
  CostPtrList result;
  for (const auto &op_strategy : op->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(op_strategy);
    auto middle_strategy = op_strategy->strategy_ptr;
    // Combine: e1's costs (prev-op strategy -> middle), op's own costs, and
    // e2's costs (middle -> next-op strategy).
    CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy),
                                   op_strategy->cost_list, e2->GetCostList(middle_strategy, input_st_ptr), &result);
  }
  // Drop dominated costs.
  Simplify(&result);
  return result;
}
282
OpEliminationSetNewCost(const EdgePtr & e1,const OperatorInfoPtr & op,const EdgePtr & e2)283 void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) {
284 bool valid = false;
285 for (const auto &output_pair : pre_op_output_) {
286 StrategyPtr output_st_ptr = output_pair.first;
287 for (const auto &input_pair : next_op_input_) {
288 StrategyPtr input_st_ptr = input_pair.first;
289
290 CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr);
291 CostPtrKey key = {output_st_ptr, input_st_ptr};
292 cost_map_[key] = clist;
293 if ((!valid) && (!clist.empty())) {
294 valid = true;
295 }
296 }
297 }
298 if (!valid) {
299 MS_LOG(EXCEPTION) << "Creating edge: " << edge_name_ << " failed.";
300 }
301 }
302
CalculateMemoryCost()303 Status Edge::CalculateMemoryCost() {
304 if (is_output_parameter_involve_ == -1) {
305 MS_LOG(ERROR) << "is_output_parameter_involve_ is unset.";
306 return FAILED;
307 }
308 if (is_output_parameter_involve_ == 0) {
309 // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is
310 // unnecessary to keep them in memory.
311 for (auto &cost_kv : cost_map_) {
312 auto &cost_v = cost_kv.second;
313 if (!cost_v.empty()) {
314 cost_v[0]->memory_with_reuse_ = 0;
315 }
316 }
317 }
318
319 return SUCCESS;
320 }
321
CalculateMemoryCostForInference()322 Status Edge::CalculateMemoryCostForInference() {
323 // Currently, memory cost is NOT calculated for redistribution
324 if ((is_output_critical_ != 0) && (is_output_critical_ != 1)) {
325 MS_LOG(ERROR) << "Failure: unexpected output critical flag value: " << is_output_critical_;
326 return FAILED;
327 }
328 for (const auto &cost_kv : cost_map_) {
329 auto &cost_v = cost_kv.second;
330 if (!cost_v.empty()) {
331 cost_v[0]->memory_with_reuse_ = 0;
332 }
333 }
334 return SUCCESS;
335 }
336
GetCostByStrategyPair(const CostPtrKey & stra_pair)337 CostPtr Edge::GetCostByStrategyPair(const CostPtrKey &stra_pair) {
338 if (cost_map_.find(stra_pair) == cost_map_.end()) {
339 return nullptr;
340 }
341 auto cost_vec = cost_map_[stra_pair];
342 if (cost_vec.empty()) {
343 MS_LOG(EXCEPTION) << "stra_pair.first: " << stra_pair.first->ToString() << ", "
344 << "stra_pair.second: " << stra_pair.second->ToString() << ". "
345 << "No available cost under current strategy pair of the edge: " << edge_name_;
346 }
347 if (cost_vec.size() > 1) {
348 MS_LOG(INFO) << "stra_pair.first: " << stra_pair.first->ToString() << ", "
349 << "stra_pair.second: " << stra_pair.second->ToString() << ". "
350 << "Multiple costs available under the stratey pair of the edge: " << edge_name_;
351 }
352 return cost_vec[0];
353 }
354
// Given the previous operator's chosen strategy, selects the next operator's
// strategy that minimizes communication on this edge:
// 1) prefer strategies with zero communication cost (ties broken by minimum
//    computation cost, then the next op's parameter-free communication cost);
// 2) otherwise fall back to the minimum-communication strategy (layout
//    inconsistency is logged as a warning);
// Returns nullptr when no strategy pair matches prev_op_stra at all.
StrategyPtr Edge::GetNextOpStrategyByPrevOpStrategyWithMiniComm(const StrategyPtr &prev_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> next_op_stras;
  // First, try to find the strategy with zero communication cost.
  for (const auto &key_value : cost_map_) {
    const auto &candidate_prev_op_stra = key_value.first.first;
    if (prev_op_stra->IsEqual(candidate_prev_op_stra) && (key_value.second[0]->communication_cost_ < EPS)) {
      (void)next_op_stras.emplace_back(key_value.first.second, key_value.second[0]->computation_cost_);
    }
  }
  if (next_op_stras.empty()) {
    // Second, if there is not strategy with zero communication cost, find the one with minimum communication cost.
    std::vector<std::pair<StrategyPtr, double>> next_stras;
    for (auto &key_value : cost_map_) {
      const auto &candidate_prev_op_stra = key_value.first.first;
      if (prev_op_stra->IsEqual(candidate_prev_op_stra)) {
        (void)next_stras.emplace_back(key_value.first.second, key_value.second[0]->communication_cost_);
      }
    }
    if (next_stras.empty()) {
      // Fixed misleading message: this branch means no strategy pair matches
      // prev_op_stra at all, not merely none with zero communication cost.
      MS_LOG(ERROR) << "There is no available strategy for edge: " << edge_name_;
      return nullptr;
    }
    MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
    auto min_stra =
      std::min_element(next_stras.begin(), next_stras.end(),
                       [this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
                         return !IsDoubleEqual(a.second, b.second) ? a.second < b.second : a.first->Compare(b.first);
                       });
    return min_stra->first;
  }
  if (next_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << " with zero communication cost, choose the one with minimum computation costs.";
  }
  auto next_op = next_op_;
  // Tie-break: computation cost, then the next op's own parameter-free
  // communication cost, then a deterministic strategy ordering.
  auto min_next_op_stra = std::min_element(
    next_op_stras.begin(), next_op_stras.end(),
    [this, &next_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      if (!IsDoubleEqual(a.second, b.second)) {
        return a.second < b.second;
      }
      auto cost_a = next_op->GetCostByStrategyPtr(a.first)[0]->communication_without_parameter_;
      auto cost_b = next_op->GetCostByStrategyPtr(b.first)[0]->communication_without_parameter_;
      if (!IsDoubleEqual(cost_a, cost_b)) {
        return cost_a < cost_b;
      }
      return a.first->Compare(b.first);
    });
  return min_next_op_stra->first;
}
405
// Given the next operator's chosen strategy, selects the previous operator's
// strategy that minimizes communication on this edge (mirror of
// GetNextOpStrategyByPrevOpStrategyWithMiniComm):
// 1) prefer zero-communication strategies (ties broken by computation cost,
//    then the prev op's parameter-free communication cost);
// 2) otherwise fall back to minimum communication cost with a warning.
// Returns nullptr when no strategy pair matches next_op_stra at all.
StrategyPtr Edge::GetPrevOpStrategyByNextOpStrategyWithMiniComm(const StrategyPtr &next_op_stra) {
  std::vector<std::pair<StrategyPtr, double>> prev_op_stras;
  // First, try to find the strategy with zero communication cost.
  for (const auto &key_value : cost_map_) {
    const auto &candidate_next_op_stra = key_value.first.second;
    if (next_op_stra->IsEqual(candidate_next_op_stra) && (key_value.second[0]->communication_cost_ < EPS)) {
      (void)prev_op_stras.emplace_back(key_value.first.first, key_value.second[0]->computation_cost_);
    }
  }
  if (prev_op_stras.empty()) {
    // Second, if there is no strategy with zero communication cost, find the one with minimum communication cost.
    std::vector<std::pair<StrategyPtr, double>> prev_stras;
    for (auto &key_value : cost_map_) {
      const auto &candidate_next_op_stra = key_value.first.second;
      if (next_op_stra->IsEqual(candidate_next_op_stra)) {
        (void)prev_stras.emplace_back(key_value.first.first, key_value.second[0]->communication_cost_);
      }
    }
    if (prev_stras.empty()) {
      // Fixed misleading message: this branch means no strategy pair matches
      // next_op_stra at all, not merely none with zero communication cost.
      MS_LOG(ERROR) << "There is no available strategy for edge: " << edge_name_;
      return nullptr;
    }
    MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
    auto min_prev_stra =
      std::min_element(prev_stras.begin(), prev_stras.end(),
                       [this](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
                         return !IsDoubleEqual(a.second, b.second) ? a.second < b.second : a.first->Compare(b.first);
                       });
    return min_prev_stra->first;
  }
  if (prev_op_stras.size() > 1) {
    MS_LOG(INFO) << "There are multiple strategies for edge: " << edge_name_
                 << " with zero communication costs, choose the one with minimum computation costs.";
  }
  auto prev_op = prev_op_;
  // Tie-break: computation cost, then the prev op's own parameter-free
  // communication cost, then a deterministic strategy ordering.
  auto min_prev_op_stra = std::min_element(
    prev_op_stras.begin(), prev_op_stras.end(),
    [this, &prev_op](const std::pair<StrategyPtr, double> &a, const std::pair<StrategyPtr, double> &b) {
      if (!IsDoubleEqual(a.second, b.second)) {
        return a.second < b.second;
      }
      auto cost_a = prev_op->GetCostByStrategyPtr(a.first)[0]->communication_without_parameter_;
      auto cost_b = prev_op->GetCostByStrategyPtr(b.first)[0]->communication_without_parameter_;
      if (!IsDoubleEqual(cost_a, cost_b)) {
        return cost_a < cost_b;
      }
      return a.first->Compare(b.first);
    });
  return min_prev_op_stra->first;
}
456
GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr & next_op_stra)457 int64_t Edge::GetReshapeSWCIndexByNextOpStrategy(const StrategyPtr &next_op_stra) {
458 if (!prev_op_->IsReshape()) {
459 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s prev_op is not a Reshape.";
460 }
461 if (next_op_->IsReshape()) {
462 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
463 }
464 const auto &reshape_output_layout = next_op_->GetInputLayoutFromSWCByStrategy(next_op_stra, next_op_input_index_);
465 MS_LOG(INFO) << prev_op_->name() << "'s output layout: " << reshape_output_layout.ToString();
466 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
467 // First, try to find the zero communication strategy.
468 auto swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithZeroComm(reshape_output_layout);
469 if (swc_index == -1) {
470 // Second, if there is no strategy with zero communication cost, find the strategy with minimum cost.
471 swc_index = reshape_ptr->GetSWCIndexByOutputLayoutWithMiniComm(reshape_output_layout);
472 if (swc_index != -1) {
473 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
474 }
475 }
476 if (swc_index == -1) {
477 MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
478 }
479 return swc_index;
480 }
481
GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr & prev_op_stra)482 int64_t Edge::GetReshapeSWCIndexByPrevOpStrategy(const StrategyPtr &prev_op_stra) {
483 if (!next_op_->IsReshape()) {
484 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
485 }
486 if (prev_op_->IsReshape()) {
487 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
488 }
489 const auto &reshape_input_lyt = prev_op_->GetOutputLayoutFromSWCByStrategy(prev_op_stra, prev_op_output_index_);
490 MS_LOG(INFO) << next_op_->name() << "'s input layout: " << reshape_input_lyt.ToString();
491 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
492 // First, try to find the zero communication strategy.
493 auto swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithZeroComm(reshape_input_lyt);
494 if (swc_index == -1) {
495 // Second, if there is no zero communication strategy, find the strategy with minimum cost.
496 swc_index = reshape_ptr->GetSWCIndexByInputLayoutWithMiniComm(reshape_input_lyt);
497 if (swc_index != -1) {
498 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge_name();
499 }
500 }
501 if (swc_index == -1) {
502 MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << next_op_->name();
503 }
504 return swc_index;
505 }
506
GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index)507 StrategyPtr Edge::GetPrevOpStrategyByReshapeSWCIndex(int64_t swc_index) {
508 if (!next_op_->IsReshape()) {
509 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
510 }
511 if (prev_op_->IsReshape()) {
512 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
513 }
514 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op_);
515 const auto &reshape_input_lyt = reshape_ptr->GetInputLayoutBySWCIndex(swc_index);
516 auto stra = prev_op_->GetStrategyFromSWCByOutputLayout(reshape_input_lyt, prev_op_output_index_);
517 if (stra == nullptr) {
518 MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
519 }
520 return stra;
521 }
522
GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index)523 StrategyPtr Edge::GetNextOpStrategyByReshapeSWCIndex(int64_t swc_index) {
524 if (!prev_op_->IsReshape()) {
525 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << "'s next_op is not a Reshape.";
526 }
527 if (next_op_->IsReshape()) {
528 MS_LOG(EXCEPTION) << "The edge: " << edge_name_ << " has two Reshapes, which is not supported currently.";
529 }
530 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op_);
531 const auto &reshape_output_lyt = reshape_ptr->GetOutputLayoutBySWCIndex(swc_index);
532 auto stra = next_op_->GetStrategyFromSWCByInputLayout(reshape_output_lyt, next_op_input_index_);
533 if (stra == nullptr) {
534 MS_LOG(EXCEPTION) << "No available strategy found at edge: " << edge_name_ << " for: " << prev_op_->name();
535 }
536 return stra;
537 }
538
CheckStrategyConsistency(StrategyPtr prev_stra,StrategyPtr next_stra,std::set<OperatorInfoPtr> * _diff_stra_params)539 bool Edge::CheckStrategyConsistency(StrategyPtr prev_stra, StrategyPtr next_stra,
540 std::set<OperatorInfoPtr> *_diff_stra_params) {
541 if (prev_stra == nullptr) {
542 MS_LOG(EXCEPTION) << prev_op_->name() << "'s selected strategy is null!";
543 }
544 if (next_stra == nullptr) {
545 MS_LOG(EXCEPTION) << next_op_->name() << "'s selected strategy is null!";
546 }
547 auto cost = GetCostByStrategyPair({prev_stra, next_stra});
548 if (cost == nullptr || cost->communication_cost_ > 0.0) {
549 MS_LOG(INFO) << "The edge " << edge_name_ << "'s strategy: prev_stra is " << prev_stra->ToString()
550 << ", next_stra is " << next_stra->ToString();
551 if (prev_op_->IsTmpIdentity()) {
552 if (_diff_stra_params->count(prev_op_) == 0) {
553 _diff_stra_params->insert(prev_op_);
554 }
555 MS_LOG(INFO) << "The parameter: " << prev_op_->refkey_parameter_name()
556 << " has been used by operators with "
557 "different sharding strategies. These operators are: ";
558 auto const &succ_edges = prev_op_->succ_edges();
559 for (auto const &succ_edge : succ_edges) {
560 if (succ_edge->next_operator()->cnodes().empty()) {
561 MS_LOG(INFO) << "No CNODE info has been set in operator: " << succ_edge->next_operator()->name();
562 }
563 MS_LOG(INFO) << succ_edge->next_operator()->name() << ", the corresponding fullname is: "
564 << succ_edge->next_operator()->cnodes()[0]->fullname_with_scope();
565 }
566 MS_LOG(INFO) << "Configure these operators with consistent sharding strategies.";
567 }
568 MS_LOG(WARNING) << "There are redistribution cost occurs at edge: " << edge_name() << ".";
569 return false;
570 }
571 return true;
572 }
573
SetCostMapAndInputOutput(const std::map<CostPtrKey,CostPtrList> & cost_map)574 void Edge::SetCostMapAndInputOutput(const std::map<CostPtrKey, CostPtrList> &cost_map) {
575 cost_map_ = cost_map;
576 pre_op_output_.clear();
577 next_op_input_.clear();
578
579 for (const auto &key_value : cost_map_) {
580 auto &key_pair = key_value.first;
581 (void)pre_op_output_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.first, {}));
582 (void)next_op_input_.emplace_back(std::pair<StrategyPtr, std::vector<TensorInfo>>(key_pair.second, {}));
583 }
584 }
585
// Return true if there are available strategies in this edge, i.e. at least
// one (prev strategy, next strategy) pair has an entry in cost_map_.
bool Edge::CheckStrategyCostPossibility() const { return !cost_map_.empty(); }
588 } // namespace parallel
589 } // namespace mindspore
590