• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <algorithm>
17 #include <cstdlib>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 #include <queue>
23 
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/ops_info/reshape_info.h"
26 #include "frontend/parallel/step_auto_parallel.h"
27 
28 namespace mindspore {
29 namespace parallel {
30 CostGraphPtr entire_costgraph = nullptr;
31 size_t TOTAL_OPS = 0;
32 
Init()33 void CostGraph::Init() {
34   inputs_tensor_name_list_.clear();
35   tuple_getitem_list_.clear();
36   ops_.clear();
37   edges_.clear();
38   connected_compoents_.clear();
39   out_edges_.clear();
40   in_edges_.clear();
41 }
42 
RemoveOperator(const OperatorInfoPtr & op)43 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
44   for (auto it = ops_.begin(); it != ops_.end();) {
45     if ((*it) == op) {
46       it = ops_.erase(it);
47     } else {
48       ++it;
49     }
50   }
51 }
52 
IsOperatorInCostGraph(const OperatorInfoPtr & op_test)53 bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
54   struct IsInGraph {
55     const OperatorInfoPtr test_;
56     explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
57     bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
58   };
59   return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
60 }
61 
AddEdge(OperatorInfoPtr u_node,OperatorInfoPtr v_node,const EdgePtr & edge)62 void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
63   std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
64   curr_edges.push_back(edge);
65   edges_[{u_node, v_node}] = curr_edges;
66 
67   std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
68   curr_out_edges.push_back(edge);
69   out_edges_[u_node] = curr_out_edges;
70 
71   std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
72   curr_in_edges.push_back(edge);
73   in_edges_[v_node] = curr_in_edges;
74 }
75 
IsEdgeInCostGraph(const std::string & test_edge_name,size_t output_index,size_t input_index)76 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
77   for (auto &edge_pair : edges_) {
78     auto edges = edge_pair.second;
79     for (auto &edge : edges) {
80       MS_EXCEPTION_IF_NULL(edge);
81       bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
82                          (edge->next_op_input_index() == input_index);
83       if (bool_result) {
84         return true;
85       }
86     }
87   }
88   return false;
89 }
90 
StrategyPropagate(const std::map<OperatorInfoPtr,StrategyPtr> & ops_stras)91 void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &ops_stras) {
92   if (ops_stras.empty()) {
93     MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
94   }
95   std::map<OperatorInfoPtr, bool> visited;
96   for (auto &op : ops_) {
97     visited[op] = false;
98   }
99   for (auto &op_stra : ops_stras) {
100     BFS(op_stra.first, op_stra.second, ops_stras, &visited);
101   }
102 }
103 
CheckShardingConsisitency(std::map<OperatorInfoPtr,StrategyPtr> configured_ops,const OperatorInfoPtr & curr_op,const OperatorInfoPtr & another_op,const CostPtr & cost,const EdgePtr & edge)104 void CheckShardingConsisitency(std::map<OperatorInfoPtr, StrategyPtr> configured_ops, const OperatorInfoPtr &curr_op,
105                                const OperatorInfoPtr &another_op, const CostPtr &cost, const EdgePtr &edge) {
106   if ((configured_ops.find(another_op) == configured_ops.end()) &&
107       (cost == nullptr || cost->communication_cost_ != 0.0)) {
108     PrintStrategy(another_op->selected_strategy());
109     PrintStrategy(curr_op->selected_strategy());
110     MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name()
111                       << ", consider configuring sharding strategies for two operators."
112                       << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
113                       << " and " << another_op->cnode()->fullname_with_scope();
114   }
115 }
116 
BFS(const OperatorInfoPtr & op,const StrategyPtr & op_stra,std::map<OperatorInfoPtr,StrategyPtr> configured_ops,std::map<OperatorInfoPtr,bool> * visited)117 void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
118                     std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
119   std::queue<std::pair<std::pair<OperatorInfoPtr, StrategyPtr>, size_t>> next_level;
120   (void)next_level.emplace(std::make_pair(op, op_stra), 0);
121   while (!next_level.empty()) {
122     auto curr_op = next_level.front().first.first;
123     auto configured_stra = next_level.front().first.second;
124     auto curr_depth = next_level.front().second;
125     visited->at(curr_op) = true;
126     MS_LOG(INFO) << "curr_depth: " << curr_depth;
127     curr_op->SetSelectedStrategy(configured_stra, curr_depth);
128     for (auto &edge : curr_op->succ_edges()) {
129       const auto &next_op = edge->next_operator();
130       if (visited->at(next_op)) {
131         const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()});
132         CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge);
133         continue;
134       }
135       if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
136         const auto &next_op_conf_stra = configured_ops[next_op];
137         const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
138         if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
139           MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
140                             << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
141                             << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
142                             << " and " << next_op->cnode()->fullname_with_scope();
143         }
144       }
145       if (configured_ops.find(next_op) != configured_ops.end()) {
146         continue;
147       }
148       const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
149       if (next_op_stra == nullptr) {
150         PrintStrategy(curr_op->selected_strategy());
151         MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
152       }
153       (void)next_level.emplace(std::make_pair(next_op, next_op_stra), curr_depth + 1);
154     }
155     for (auto &edge : curr_op->prev_edges()) {
156       const auto &prev_op = edge->prev_operator();
157       if (visited->at(prev_op)) {
158         const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()});
159         CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge);
160         continue;
161       }
162       if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
163         const auto &prev_op_conf_stra = configured_ops[prev_op];
164         const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
165         if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
166           MS_LOG(ERROR) << "curr_depth: " << curr_depth;
167           MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
168                             << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
169                             << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
170                             << " and " << curr_op->cnode()->fullname_with_scope();
171         }
172       }
173       if (configured_ops.find(prev_op) != configured_ops.end()) {
174         continue;
175       }
176       const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
177       if (prev_op_stra == nullptr) {
178         PrintStrategy(curr_op->selected_strategy());
179         MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
180       }
181       (void)next_level.emplace(std::make_pair(prev_op, prev_op_stra), curr_depth + 1);
182     }
183     next_level.pop();
184   }
185 }
186 
ConstructConnectedComponents(std::vector<OperatorInfoPtr> alive_ops)187 std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
188   std::vector<OperatorInfoPtr> alive_ops) {
189   std::map<OperatorInfoPtr, bool> visited;
190 
191   for (auto &op : alive_ops) {
192     visited[op] = false;
193   }
194 
195   MS_LOG(INFO) << "visited: " << visited.size() << ".";
196   for (auto &op : alive_ops) {
197     if ((!visited[op]) && op->is_alive()) {
198       std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
199       MS_EXCEPTION_IF_NULL(new_component);
200       DFS(op, &visited, new_component);
201       connected_compoents_.push_back(new_component);
202     }
203   }
204   return connected_compoents_;
205 }
206 
DFS(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,const std::shared_ptr<CostGraph> & component)207 void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
208                     const std::shared_ptr<CostGraph> &component) {
209   MS_EXCEPTION_IF_NULL(visited);
210   MS_EXCEPTION_IF_NULL(component);
211   visited->at(current_op) = true;
212   component->AddOperator(current_op);
213 
214   for (auto &edge : current_op->succ_edges()) {
215     bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
216                      (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
217     if (bool_test) {
218       component->AddEdge(current_op, edge->next_operator(), edge);
219       DFS(edge->next_operator(), visited, component);
220     }
221   }
222 
223   for (auto &edge : current_op->prev_edges()) {
224     bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
225                      (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
226     if (bool_test) {
227       component->AddEdge(edge->prev_operator(), current_op, edge);
228       DFS(edge->prev_operator(), visited, component);
229     }
230   }
231 }
232 
233 // Create final cost list for the graph: u --> v
CreateFinalCostList(const OperatorInfoPtr & u,const std::shared_ptr<Edge> & e,const OperatorInfoPtr & v)234 CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
235                                            const OperatorInfoPtr &v) {
236   MS_EXCEPTION_IF_NULL(u);
237   MS_EXCEPTION_IF_NULL(v);
238   MS_EXCEPTION_IF_NULL(e);
239   CostPtrList ret;
240   for (const auto &u_strategy : u->GetStrategyCost()) {
241     for (const auto &v_strategy : v->GetStrategyCost()) {
242       MS_EXCEPTION_IF_NULL(u_strategy);
243       MS_EXCEPTION_IF_NULL(v_strategy);
244       auto u_strategy_ptr = u_strategy->strategy_ptr;
245       auto v_strategy_ptr = v_strategy->strategy_ptr;
246       CostPtrList clist1 = u_strategy->cost_list;
247       CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
248       CostPtrList clist3 = v_strategy->cost_list;
249       for (const auto &cost1 : clist1) {
250         for (const auto &cost2 : clist2) {
251           for (const auto &cost3 : clist3) {
252             MS_EXCEPTION_IF_NULL(cost1);
253             MS_EXCEPTION_IF_NULL(cost2);
254             MS_EXCEPTION_IF_NULL(cost3);
255             double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
256             double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
257             double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
258             double communication_forward =
259               cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
260             double communication_without_para = cost1->communication_without_parameter_ +
261                                                 cost2->communication_without_parameter_ +
262                                                 cost3->communication_without_parameter_;
263             auto decision =
264               std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
265             auto cost = std::make_shared<Cost>(computation, communication, decision);
266             const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
267             MS_EXCEPTION_IF_NULL(cost);
268             cost->communication_without_parameter_ = communication_without_para;
269             cost->communication_with_partial_para_ =
270               communication_without_para + gamma * (communication - communication_without_para);
271             cost->memory_with_reuse_ = memory;
272             cost->communication_forward_ = communication_forward;
273             ret.push_back(cost);
274           }
275         }
276       }
277     }
278   }
279 
280   Simplify(&ret);
281   return ret;
282 }
283 
284 // Create final cost list for the graph containing a single node: u
CreateFinalSingleCostList(const OperatorInfoPtr & u)285 CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
286   MS_EXCEPTION_IF_NULL(u);
287   CostPtrList ret;
288   for (const auto &u_strategy : u->GetStrategyCost()) {
289     MS_EXCEPTION_IF_NULL(u_strategy);
290     auto u_strategy_ptr = u_strategy->strategy_ptr;
291     CostPtrList clist1 = u_strategy->cost_list;
292     for (const auto &cost1 : clist1) {
293       MS_EXCEPTION_IF_NULL(cost1);
294       auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
295       auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
296       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
297       MS_EXCEPTION_IF_NULL(new_cost);
298       new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
299       new_cost->communication_with_partial_para_ =
300         cost1->communication_without_parameter_ +
301         gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
302       new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
303       new_cost->communication_forward_ = cost1->communication_forward_;
304       ret.push_back(new_cost);
305     }
306   }
307 
308   Simplify(&ret);
309   return ret;
310 }
311 
SelectCostWithMinInferenceTime(const CostPtrList & cost_list,double memory)312 CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
313   // Select the cost with minimum inference time. Currently, the inference time is modeled as =
314   // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
315   if (cost_list.empty()) {
316     MS_LOG(ERROR) << "Final cost list is null.";
317     return nullptr;
318   }
319   CostPtrList after_mem_filter;
320   double minimum_memory = DBL_MAX;
321   // Filter out the valid costs.
322   for (auto &a_cost : cost_list) {
323     if (a_cost->memory_with_reuse_ <= memory) {
324       after_mem_filter.emplace_back(std::move(a_cost));
325     } else if (a_cost->memory_with_reuse_ < minimum_memory) {
326       minimum_memory = a_cost->memory_with_reuse_;
327     }
328   }
329   if (after_mem_filter.empty()) {
330     MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
331                   << ", the memory capacity is: " << memory << ".";
332     return nullptr;
333   }
334   // Init the returned value with first cost.
335   CostPtr ret = after_mem_filter[0];
336   const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
337   const auto beta = CostModelContext::GetInstance()->costmodel_beta();
338 
339   double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
340   MS_LOG(INFO) << "Cost 0: "
341                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
342                << ", communication_forward_: " << ret->communication_forward_
343                << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
344                << ", communication_cost_: " << ret->communication_cost_
345                << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
346   MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
347   for (size_t i = 1; i < after_mem_filter.size(); ++i) {
348     MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
349     MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
350                  << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
351                  << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
352                  << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
353                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
354                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
355                  << ".";
356     auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
357     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
358     if (minimum > tmp) {
359       minimum = tmp;
360       ret = after_mem_filter[i];
361       MS_LOG(INFO) << "Selected: " << i;
362     }
363   }
364   return ret;
365 }
366 
SelectCostWithMinTrainingTime(const CostPtrList & cost_list,double memory)367 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
368   // Select the cost with minimum training time. Currently, the training time is modeled as =
369   // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
370   if (cost_list.empty()) {
371     MS_LOG(ERROR) << "Final cost list is null.";
372     return nullptr;
373   }
374   CostPtrList after_mem_filter;
375   double minimum_memory = DBL_MAX;
376   // Filter out the valid costs.
377   for (auto &a_cost : cost_list) {
378     if (a_cost->memory_with_reuse_ <= memory) {
379       after_mem_filter.emplace_back(std::move(a_cost));
380     } else if (a_cost->memory_with_reuse_ < minimum_memory) {
381       minimum_memory = a_cost->memory_with_reuse_;
382     }
383   }
384   if (after_mem_filter.empty()) {
385     MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
386                   << ", the memory capacity is: " << memory << ".";
387     return nullptr;
388   }
389   // Init the returned value with first cost.
390   CostPtr ret = after_mem_filter[0];
391   const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
392   const auto beta = CostModelContext::GetInstance()->costmodel_beta();
393 
394   double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
395   MS_LOG(INFO) << "Cost 0: "
396                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
397                << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
398                << ", communication_cost_: " << ret->communication_cost_
399                << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
400   MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
401   for (size_t i = 1; i < after_mem_filter.size(); ++i) {
402     MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
403     MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
404                  << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
405                  << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
406                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
407                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
408                  << ".";
409     auto tmp =
410       alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
411     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
412     if (minimum > tmp) {
413       minimum = tmp;
414       ret = after_mem_filter[i];
415       MS_LOG(INFO) << "Selected: " << i;
416     }
417   }
418   return ret;
419 }
420 
SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> & all_cost_list,double available_memory) const421 CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
422                                                                  double available_memory) const {
423   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
424   double minimum = DBL_MAX, total_memory = 0.0;
425   CostPtrList ret(all_cost_list.size(), nullptr);
426   // Check whether valid costs exist.
427   for (size_t i = 0; i < all_cost_list.size(); ++i) {
428     if (all_cost_list[i][0] == nullptr) {
429       MS_LOG(ERROR) << "The cost list " << i << " is empty.";
430       return ret;
431     } else {
432       double memory_i_cost = DBL_MAX;
433       for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
434         if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
435           memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
436         }
437       }
438       total_memory += memory_i_cost;
439     }
440   }
441   if (total_memory >= available_memory) {
442     MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
443                   << ", minimum strategy cost: " << total_memory << ".";
444     return selected_cost_list;
445   }
446 
447   std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
448                                            &available_memory](size_t k) {
449     const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
450     const auto beta = CostModelContext::GetInstance()->costmodel_beta();
451     if (k == all_cost_list.size()) {
452       double tmp_memory = 0.0, tmp_minimum = 0.0;
453       for (size_t i = 0; i < selected_cost_list.size(); ++i) {
454         MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
455         tmp_memory += selected_cost_list[i]->memory_with_reuse_;
456         tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
457                        beta * selected_cost_list[i]->communication_with_partial_para_;
458       }
459       MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
460                    << ".";
461       if (tmp_memory < available_memory && tmp_minimum < minimum) {
462         ret = selected_cost_list;
463         minimum = tmp_minimum;
464         MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
465       }
466       return;
467     }
468 
469     MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
470     for (auto &c : all_cost_list[k]) {
471       selected_cost_list[k] = c;
472       recursive(k + 1);
473     }
474   };
475   recursive(0);
476   return ret;
477 }
478 
SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)479 Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
480   MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
481   auto connected_components = ConstructConnectedComponents(alive_ops);
482   MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
483   std::vector<CostPtrList> all_list;
484   for (size_t j = 0; j < connected_components.size(); ++j) {
485     auto one_component = connected_components[j];
486     MS_EXCEPTION_IF_NULL(one_component);
487     if (one_component->GetOperators().size() == 1) {
488       MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
489       auto cost_list_1 = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
490       all_list.push_back(cost_list_1);
491     } else if (one_component->GetOperators().size() == 2) {
492       MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
493       OperatorInfoPtr u, v;
494       auto first_op = one_component->GetOperators()[0];
495       auto second_op = one_component->GetOperators()[1];
496       MS_EXCEPTION_IF_NULL(first_op);
497       MS_EXCEPTION_IF_NULL(second_op);
498       if (!first_op->GetAliveSuccEdges().empty() &&
499           first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
500         u = first_op;
501         v = second_op;
502       } else if (!second_op->GetAliveSuccEdges().empty() &&
503                  second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
504         u = second_op;
505         v = first_op;
506       } else {
507         MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
508                           << ", " << second_op->GetAliveSuccEdges().size() << ".";
509       }
510       MS_EXCEPTION_IF_NULL(u);
511       auto e = u->GetAliveSuccEdges()[0];
512       auto cost_list = one_component->CreateFinalCostList(u, e, v);
513       all_list.push_back(cost_list);
514     } else {
515       MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
516                         << " operators in a component in the final graph.";
517     }
518   }
519   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
520   auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
521   for (size_t k = 0; k < selected_cost_list.size(); ++k) {
522     auto selected_cost = selected_cost_list[k];
523     if (selected_cost == nullptr) {
524       MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
525       return FAILED;
526     }
527     MS_EXCEPTION_IF_NULL(connected_components[k]);
528     if (connected_components[k]->GetOperators().size() == 1) {
529       auto u = connected_components[k]->GetOperators()[0];
530       auto decision_f = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
531       u->SetSelectedStrategyAndCost(decision_f->u_strategy_, decision_f->u_cost_);
532       MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
533     } else if (connected_components[k]->GetOperators().size() == 2) {
534       OperatorInfoPtr u = nullptr, v = nullptr;
535       auto first_op = connected_components[k]->GetOperators()[0];
536       auto second_op = connected_components[k]->GetOperators()[1];
537       MS_EXCEPTION_IF_NULL(first_op);
538       MS_EXCEPTION_IF_NULL(second_op);
539       if (!first_op->GetAliveSuccEdges().empty() &&
540           first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
541         u = first_op;
542         v = second_op;
543       } else if (!second_op->GetAliveSuccEdges().empty() &&
544                  second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
545         u = second_op;
546         v = first_op;
547       }
548       MS_EXCEPTION_IF_NULL(u);
549       auto e = u->GetAliveSuccEdges()[0];
550       MS_EXCEPTION_IF_NULL(v);
551       MS_EXCEPTION_IF_NULL(e);
552       MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
553       auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
554       MS_EXCEPTION_IF_NULL(decision);
555       u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
556       v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
557       e->set_selected_cost(decision->middle_cost_);
558       MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
559     }
560   }
561   return SUCCESS;
562 }
563 
SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)564 Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
565   // In this case, the final graph should contains exactly 2 nodes.
566   if (alive_ops.empty()) {
567     MS_LOG(INFO) << "0 Operator in the final graph.";
568     return SUCCESS;
569   }
570   OperatorInfoPtr u, v;
571   MS_EXCEPTION_IF_NULL(alive_ops[0]);
572   MS_EXCEPTION_IF_NULL(alive_ops[1]);
573   const auto phase = CostModelContext::GetInstance()->run_phase();
574   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
575   if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
576       alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
577     u = alive_ops[0];
578     v = alive_ops[1];
579   } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
580              alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
581     u = alive_ops[1];
582     v = alive_ops[0];
583   } else {
584     if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
585       MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
586                         << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
587     } else {
588       // In this case, the final graph consists of two single nodes
589       MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
590       std::vector<CostPtrList> all_list;
591       auto connected_components = ConstructConnectedComponents(alive_ops);
592       MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
593       for (size_t i = 0; i < connected_components.size(); ++i) {
594         MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
595         auto one_component = connected_components[i];
596         MS_EXCEPTION_IF_NULL(one_component);
597         auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
598         all_list.push_back(cost_list);
599       }
600       CostPtrList selected_cost_list;
601       if (phase == TRAINING_PHASE) {
602         // training phase
603         selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
604       } else {
605         // inference phase
606         MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
607                              "phase is not supported.";
608       }
609       for (size_t k = 0; k < selected_cost_list.size(); ++k) {
610         auto selected_cost = selected_cost_list[k];
611         if (selected_cost == nullptr) {
612           MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
613                         << ".";
614           return FAILED;
615         }
616         MS_EXCEPTION_IF_NULL(connected_components[k]);
617         auto one_operator = connected_components[k]->GetOperators()[0];
618         MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
619         auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
620         MS_EXCEPTION_IF_NULL(decision);
621         one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
622         MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
623       }
624 
625       return SUCCESS;
626     }
627   }
628   MS_LOG(INFO) << "There are 2 nodes in the final graph.";
629   // In this case, the finale graph is exactly of the form: u --> v
630   MS_EXCEPTION_IF_NULL(u);
631   MS_EXCEPTION_IF_NULL(v);
632   auto e = u->GetAliveSuccEdges()[0];
633   MS_EXCEPTION_IF_NULL(e);
634   auto f_cost_list = CreateFinalCostList(u, e, v);
635   CostPtr cost = nullptr;
636   if (phase == TRAINING_PHASE) {
637     // training phase
638     cost = SelectCostWithMinTrainingTime(f_cost_list, device_mem_capacity);
639   } else {
640     MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
641                          "phase is not supported.";
642   }
643   if (cost == nullptr) {
644     MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
645     return FAILED;
646   }
647   MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
648   auto f_decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
649   MS_EXCEPTION_IF_NULL(f_decision);
650   u->SetSelectedStrategyAndCost(f_decision->u_strategy_, f_decision->left_cost_);
651   v->SetSelectedStrategyAndCost(f_decision->v_strategy_, f_decision->right_cost_);
652   e->set_selected_cost(f_decision->middle_cost_);
653   MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
654   return SUCCESS;
655 }
656 
657 // searching the strategy for the final eliminated graph
SearchStrategy()658 Status CostGraph::SearchStrategy() {
659   MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
660   std::vector<OperatorInfoPtr> alive_ops;
661   (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
662     MS_EXCEPTION_IF_NULL(op);
663     if (op->is_alive()) {
664       alive_ops.push_back(op);
665     }
666   });
667   const auto phase = CostModelContext::GetInstance()->run_phase();
668   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
669 
670   if (alive_ops.size() > 2) {
671     if (phase == TRAINING_PHASE) {
672       // training phase
673       return SearchStrategyForMultiNodeFinalGraph(alive_ops);
674     } else {
675       // inference phase
676       MS_LOG(EXCEPTION)
677         << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
678     }
679   } else if (alive_ops.size() == 1) {
680     MS_LOG(INFO) << "There are 1 single node in the final graph.";
681     OperatorInfoPtr u = alive_ops[0];
682     auto cost_list = CreateFinalSingleCostList(u);
683     CostPtr cost = nullptr;
684     if (phase == TRAINING_PHASE) {
685       // training phase
686       cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
687     } else {
688       // inference phase
689       cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
690     }
691     if (cost == nullptr) {
692       MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
693       return FAILED;
694     }
695     MS_EXCEPTION_IF_NULL(u);
696     MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
697     auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
698     MS_EXCEPTION_IF_NULL(decision);
699     u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
700     MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
701     return SUCCESS;
702   } else {
703     return SearchStrategyForTwoNodeFinalGraph(alive_ops);
704   }
705 }
706 
707 // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
708 // return the v and the edge u --> v
CheckOpElimination() const709 OperatorInfoPtr CostGraph::CheckOpElimination() const {
710   for (auto &op : ops_) {
711     bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
712     if (bool_test) {
713       if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
714         return op;
715       }
716     }
717   }
718   return nullptr;
719 }
720 
721 // Check the graph whether an EdgeElimination can be performed
CheckEdgeElimination() const722 std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
723   for (auto &op : ops_) {
724     MS_EXCEPTION_IF_NULL(op);
725     if (!op->is_alive()) continue;
726     std::map<void *, int64_t> count;
727     for (auto &edge_su : op->GetAliveSuccEdges()) {
728       MS_EXCEPTION_IF_NULL(edge_su);
729       auto v = edge_su->next_operator();
730       count[v.get()]++;
731     }
732     for (auto &pair : count) {
733       auto *op_ptr = pair.first;
734       int64_t op_count = pair.second;
735       if (op_count > 1) {
736         std::vector<std::shared_ptr<Edge>> ret;
737         for (auto &edge : op->GetAliveSuccEdges()) {
738           MS_EXCEPTION_IF_NULL(edge);
739           if (edge->next_operator().get() == op_ptr) {
740             ret.push_back(edge);
741           }
742         }
743         return ret;
744       }
745     }
746   }
747   return {};
748 }
749 
750 // Check the graph whether a MergeElimination can be performed
CheckMergeElimination() const751 OperatorInfoPtr CostGraph::CheckMergeElimination() const {
752   for (auto &op : ops_) {
753     MS_EXCEPTION_IF_NULL(op);
754     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
755     if (bool_test) {
756       auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
757       MS_EXCEPTION_IF_NULL(next_op);
758       if (!next_op->GetAlivePrevEdges().empty()) {
759         return op;
760       }
761     }
762   }
763   return nullptr;
764 }
765 
766 // Check the graph whether a ContractElimination can be performed
CheckContractElimination() const767 OperatorInfoPtr CostGraph::CheckContractElimination() const {
768   for (auto &op : ops_) {
769     MS_EXCEPTION_IF_NULL(op);
770     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
771     if (bool_test) {
772       auto edge = op->GetAlivePrevEdges()[0];
773       MS_EXCEPTION_IF_NULL(edge);
774       auto prev_op = edge->prev_operator();
775       MS_EXCEPTION_IF_NULL(prev_op);
776       if (!prev_op->GetAliveSuccEdges().empty()) {
777         return op;
778       }
779     }
780   }
781   return nullptr;
782 }
783 
CheckSourceElimination() const784 std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
785   size_t source_count = 0;
786   std::vector<OperatorInfoPtr> op_vector(2, nullptr);
787   for (auto &op : ops_) {
788     MS_EXCEPTION_IF_NULL(op);
789     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
790     if (bool_test) {
791       op_vector[source_count++] = op;
792       if (source_count == 2) {
793         return std::make_pair(op_vector[0], op_vector[1]);
794       }
795     }
796   }
797   return std::make_pair(nullptr, nullptr);
798 }
799 
CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra,const CostPtrList & op1_old_clist,StrategyPtr op2_old_stra,const CostPtrList & op2_old_clist,CostPtrList * op1_new_clist)800 void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
801                                                    StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
802                                                    CostPtrList *op1_new_clist) {
803   for (auto &op1_cost : op1_old_clist) {
804     for (auto &op2_cost : op2_old_clist) {
805       double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
806       double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
807       double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
808       double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
809       double communication_without_para =
810         op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
811       auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
812       auto new_cost = std::make_shared<Cost>(computation, communication, decision);
813       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
814       MS_EXCEPTION_IF_NULL(new_cost);
815       new_cost->communication_without_parameter_ = communication_without_para;
816       new_cost->communication_with_partial_para_ =
817         communication_without_para + gamma * (communication - communication_without_para);
818       new_cost->memory_with_reuse_ = memory;
819       new_cost->communication_forward_ = communication_forward;
820       MS_EXCEPTION_IF_NULL(op1_new_clist);
821       op1_new_clist->emplace_back(std::move(new_cost));
822     }
823   }
824 }
825 
UpdateEdgesIncidentToNodes(OperatorInfoPtr op1,std::vector<EdgePtr> * op1_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op1_new_edges_cost,std::vector<EdgePtr> * op1_new_succ_edges,const OperatorInfoPtr op2,std::vector<EdgePtr> * op2_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op2_new_edges_cost,std::vector<EdgePtr> * op2_new_succ_edges)826 std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
827   OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
828   std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
829   const OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
830   std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
831   for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
832     auto &new_cost_map = op1_new_edges_cost->at(i);
833     auto ith_edge = op1_old_succ_edges->at(i);
834 
835     std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
836     std::shared_ptr<Edge> new_edge;
837     if (ith_edge->is_combined()) {
838       std::vector<size_t> output_indexs, input_indexs;
839       output_indexs = ith_edge->prev_op_output_indexs();
840       input_indexs = ith_edge->next_op_input_indexs();
841       new_edge =
842         std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
843     } else {
844       size_t output_index, input_index;
845       output_index = ith_edge->prev_op_output_index();
846       input_index = ith_edge->next_op_input_index();
847       new_edge =
848         std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
849     }
850     new_edge->SetCostMapAndInputOutput(new_cost_map);
851     // replace the old successive edges with the new ones.
852     op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
853     ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
854     (void)op1_new_succ_edges->erase(op1_new_succ_edges->begin() + SizeToLong(i));
855     (void)op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + SizeToLong(i), new_edge);
856   }
857   for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
858     auto &new_cost_map = op2_new_edges_cost->at(i);
859     auto ith_edge = op2_old_succ_edges->at(i);
860     const auto &destination = ith_edge->next_operator();
861 
862     std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
863     std::shared_ptr<Edge> new_edge;
864     if (ith_edge->is_combined()) {
865       std::vector<size_t> output_indexs, input_indexs;
866       output_indexs = ith_edge->prev_op_output_indexs();
867       input_indexs = ith_edge->next_op_input_indexs();
868       new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
869     } else {
870       size_t output_index, input_index;
871       output_index = ith_edge->prev_op_output_index();
872       input_index = ith_edge->next_op_input_index();
873       new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
874     }
875     new_edge->SetCostMapAndInputOutput(new_cost_map);
876     // replace the old successive edges with the new ones.
877     destination->ReplacePreEdge(op2, new_edge);
878     op1->AddSuccEdge(new_edge);
879     (void)op2_new_succ_edges->erase(op2_new_succ_edges->begin() + SizeToLong(i));
880     (void)op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + SizeToLong(i), new_edge);
881   }
882   return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
883 }
884 
EliminationSources(const OperatorInfoPtr op1,const OperatorInfoPtr op2)885 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
886   const OperatorInfoPtr op1, const OperatorInfoPtr op2) {
887   MS_EXCEPTION_IF_NULL(op1);
888   MS_EXCEPTION_IF_NULL(op2);
889   MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
890 
891   auto op1_old_succ_edges = op1->GetAliveSuccEdges();
892   std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
893     op1_old_succ_edges.size());
894   std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
895   std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
896 
897   auto op2_old_succ_edges = op2->GetAliveSuccEdges();
898   std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
899     op2_old_succ_edges.size());
900   std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
901   std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
902 
903   // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
904   for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
905     const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
906     std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
907     for (const auto &key_value : op1_cost_map) {
908       const auto &from_to_strategies = key_value.first;
909       const auto &costlist = key_value.second;
910       from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
911     }
912     op1_edges_reorganised_cost[i] = from_tocost;
913   }
914 
915   for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
916     const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
917     std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
918     for (const auto &key_value : op2_cost_map) {
919       const auto &from_to_strategies = key_value.first;
920       const auto &costlist = key_value.second;
921       from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
922     }
923     op2_edges_reorganised_cost[i] = from_tocost;
924   }
925 
926   // Merge op2 into op1
927   const auto &op1_old_stra_cost = op1->GetStrategyCost();
928   const auto &op2_old_stra_cost = op2->GetStrategyCost();
929   std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
930 
931   for (auto &op1_stra_cost : op1_old_stra_cost) {
932     auto op1_old_stra = op1_stra_cost->strategy_ptr;
933     auto op1_old_costlist = op1_stra_cost->cost_list;
934 
935     for (auto &op2_stra_cost : op2_old_stra_cost) {
936       auto op2_stra = op2_stra_cost->strategy_ptr;
937       auto op2_costlist = op2_stra_cost->cost_list;
938 
939       StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
940       op1_new_stra->CoverStrategy(op2_stra);
941       CostPtrList op1_new_costlist;
942       // Calculate new cost for 'op1_new_costlist'
943       CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
944       std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
945       op1_new_stra_cost.emplace_back(swc);
946 
947       // Set cost for new successive edges of op1 and op2
948       for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
949         auto &from_tocost = op1_edges_reorganised_cost[i];
950         auto &to_cost = from_tocost[op1_old_stra];
951         auto &new_cost_map = op1_new_edges_cost[i];
952         for (auto &stra_costlit : to_cost) {
953           auto &to_strategy = stra_costlit.first;
954           auto &edge_costlist = stra_costlit.second;
955           CostPtrKey new_key = {op1_new_stra, to_strategy};
956           new_cost_map[new_key] = edge_costlist;
957         }
958       }
959       for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
960         auto &from_tocost = op2_edges_reorganised_cost[i];
961         auto &to_cost = from_tocost[op2_stra];
962         auto &new_cost_map = op2_new_edges_cost[i];
963         for (auto &stra_costlist : to_cost) {
964           auto &to_strategy = stra_costlist.first;
965           auto &edge_costlist = stra_costlist.second;
966           CostPtrKey new_key = {op1_new_stra, to_strategy};
967           new_cost_map[new_key] = edge_costlist;
968         }
969       }
970     }
971   }
972   op1->SetStrategyCost(op1_new_stra_cost);
973   op2->SetNotAlive();
974 
975   // Update the edges incident to op1, and edges incident to op2
976   MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
977   return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
978                                     &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
979 }
980 
981 // Check the graph whether a TriangleElimination can be performed
CheckTriangleElimination() const982 std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
983   for (auto &op : ops_) {
984     MS_EXCEPTION_IF_NULL(op);
985     bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
986     if (bool_test) {
987       auto edge1 = op->GetAliveSuccEdges()[0];
988       auto edge2 = op->GetAliveSuccEdges()[1];
989       MS_EXCEPTION_IF_NULL(edge1);
990       MS_EXCEPTION_IF_NULL(edge2);
991       auto first_op = edge1->next_operator();
992       auto second_op = edge2->next_operator();
993       MS_EXCEPTION_IF_NULL(first_op);
994       for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
995         if (first_op_succ_edge->next_operator() == second_op) {
996           return {op, first_op_succ_edge};
997         }
998       }
999       MS_EXCEPTION_IF_NULL(second_op);
1000       for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
1001         if (second_op_succ_edge->next_operator() == first_op) {
1002           return {op, second_op_succ_edge};
1003         }
1004       }
1005     }
1006   }
1007   return {nullptr, nullptr};
1008 }
1009 
1010 // Check the graph whether a StarElimination can be performed.
1011 // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
CheckStarElimination() const1012 OperatorInfoPtr CostGraph::CheckStarElimination() const {
1013   for (auto &op : ops_) {
1014     MS_EXCEPTION_IF_NULL(op);
1015     bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
1016     if (bool_test) {
1017       return op;
1018     }
1019   }
1020   return nullptr;
1021 }
1022 
1023 // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
1024 // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
EliminationOp(const OperatorInfoPtr & op)1025 std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
1026   // in this case, the operators are organised in the form of u-->op-->v, and the goal
1027   // is to eliminate 'op'.
1028   MS_EXCEPTION_IF_NULL(op);
1029   MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
1030   auto edge_u_op = op->GetAlivePrevEdges()[0];
1031   auto edge_op_v = op->GetAliveSuccEdges()[0];
1032   MS_EXCEPTION_IF_NULL(edge_u_op);
1033   MS_EXCEPTION_IF_NULL(edge_op_v);
1034   auto u = edge_u_op->prev_operator();
1035   auto v = edge_op_v->next_operator();
1036   std::vector<size_t> output_indexs, input_indexs;
1037   size_t output_index, input_index;
1038   MS_EXCEPTION_IF_NULL(u);
1039   MS_EXCEPTION_IF_NULL(v);
1040   std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1041   std::shared_ptr<Edge> new_edge;
1042   if (edge_u_op->is_combined()) {
1043     output_indexs = edge_u_op->prev_op_output_indexs();
1044   } else {
1045     output_index = edge_u_op->prev_op_output_index();
1046     output_indexs.push_back(output_index);
1047   }
1048   if (edge_op_v->is_combined()) {
1049     input_indexs = edge_op_v->next_op_input_indexs();
1050   } else {
1051     input_index = edge_op_v->next_op_input_index();
1052     input_indexs.push_back(input_index);
1053   }
1054 
1055   if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
1056     new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
1057   } else {
1058     new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1059   }
1060   MS_EXCEPTION_IF_NULL(new_edge);
1061   new_edge->set_pre_op_output(edge_u_op->prev_op_output());
1062   new_edge->set_next_op_input(edge_op_v->next_op_input());
1063   new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
1064   u->ReplaceSuccEdge(op, new_edge);
1065   v->ReplacePreEdge(op, new_edge);
1066   op->SetNotAlive();
1067   MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
1068   return new_edge;
1069 }
1070 
1071 // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
1072 // and sets new costlist for the new edge.
EliminationEdges(const std::vector<std::shared_ptr<Edge>> & edges)1073 std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) {
1074   MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
1075   MS_EXCEPTION_IF_NULL(edges[0]);
1076   auto u = edges[0]->prev_operator();
1077   auto v = edges[0]->next_operator();
1078   MS_EXCEPTION_IF_NULL(u);
1079   MS_EXCEPTION_IF_NULL(v);
1080   std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1081   std::vector<size_t> output_indexs, input_indexs;
1082 
1083   for (auto &edge : edges) {
1084     MS_EXCEPTION_IF_NULL(edge);
1085     if (edge->is_combined()) {
1086       auto from_output_indexs = edge->prev_op_output_indexs();
1087       auto from_input_indexs = edge->next_op_input_indexs();
1088       (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
1089       (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
1090     } else {
1091       output_indexs.push_back(edge->prev_op_output_index());
1092       input_indexs.push_back(edge->next_op_input_index());
1093     }
1094   }
1095 
1096   std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1097   MS_EXCEPTION_IF_NULL(new_edge);
1098   new_edge->set_pre_op_output(edges[0]->prev_op_output());
1099   new_edge->set_next_op_input(edges[0]->next_op_input());
1100 
1101   new_edge->EdgeEliminationSetNewCost(u, edges, v);
1102 
1103   u->ReplaceSuccEdges(v, new_edge);
1104   v->ReplacePreEdges(u, new_edge);
1105   MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
1106   return new_edge;
1107 }
1108 
1109 // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1110 // for this contract under the strategy 'op_strategy'
CreateMergeEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr tar_op_strategy,const CostPtrList & tar_cost_list,CostPtrList * const tar_cost_list_new)1111 void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
1112                                                   const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
1113                                                   const CostPtrList &tar_cost_list,
1114                                                   CostPtrList *const tar_cost_list_new) {
1115   for (size_t i = 0; i < op_cost_list.size(); ++i) {
1116     auto &op_cost = op_cost_list[i];
1117     MS_EXCEPTION_IF_NULL(op_cost);
1118     for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1119       auto &edge_cost = edge_cost_list[j];
1120       MS_EXCEPTION_IF_NULL(edge_cost);
1121       for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1122         auto &tar_cost = tar_cost_list[k];
1123         MS_EXCEPTION_IF_NULL(tar_cost);
1124         double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1125         double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1126         double communication =
1127           op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1128         double communication_forward =
1129           op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
1130         double communication_without_para = op_cost->communication_without_parameter_ +
1131                                             edge_cost->communication_without_parameter_ +
1132                                             tar_cost->communication_without_parameter_;
1133 
1134         auto decision =
1135           std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
1136         auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1137         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1138         MS_EXCEPTION_IF_NULL(new_cost);
1139         new_cost->communication_without_parameter_ = communication_without_para;
1140         new_cost->communication_with_partial_para_ =
1141           communication_without_para + gamma * (communication - communication_without_para);
1142         new_cost->memory_with_reuse_ = memory;
1143         new_cost->communication_forward_ = communication_forward;
1144         MS_EXCEPTION_IF_NULL(tar_cost_list_new);
1145         tar_cost_list_new->emplace_back(std::move(new_cost));
1146       }
1147     }
1148   }
1149 }
1150 
1151 // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
1152 // target_op
EliminationMerge(const OperatorInfoPtr & op)1153 OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
1154   MS_EXCEPTION_IF_NULL(op);
1155   auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
1156   auto edge_ptr = op->GetAliveSuccEdges()[0];
1157   MS_EXCEPTION_IF_NULL(target_op);
1158   MS_EXCEPTION_IF_NULL(edge_ptr);
1159   MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
1160   bool valid = false;
1161 
1162   for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1163     MS_EXCEPTION_IF_NULL(tar_stra_cost);
1164     auto tar_stra = tar_stra_cost->strategy_ptr;
1165     auto tar_clist_origin = tar_stra_cost->cost_list;
1166     CostPtrList tar_clist_new;
1167 
1168     for (auto &op_stra_cost : op->GetStrategyCost()) {
1169       MS_EXCEPTION_IF_NULL(op_stra_cost);
1170       auto op_stra = op_stra_cost->strategy_ptr;
1171       auto op_clist = op_stra_cost->cost_list;
1172       auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
1173 
1174       CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1175     }
1176     Simplify(&tar_clist_new);
1177     // Set the new costlist w.r.t the strategy
1178     tar_stra_cost->cost_list = tar_clist_new;
1179     if ((!valid) && (!tar_clist_new.empty())) {
1180       valid = true;
1181     }
1182   }
1183 
1184   if (!valid) {
1185     MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
1186   }
1187   op->SetNotAlive();
1188   MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
1189   return target_op;
1190 }
1191 
1192 // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1193 // for this contract under the strategy 'contract_op_stra'
CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,const CostPtrList & contract_op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr target_op_stra,const CostPtrList & tar_cost_list,CostPtrList * tar_cost_list_new)1194 void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
1195                                                      const CostPtrList &contract_op_cost_list,
1196                                                      const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
1197                                                      const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) {
1198   for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
1199     auto &contract_op_cost = contract_op_cost_list[i];
1200     MS_EXCEPTION_IF_NULL(contract_op_cost);
1201     for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1202       auto &edge_cost = edge_cost_list[j];
1203       MS_EXCEPTION_IF_NULL(edge_cost);
1204       for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1205         auto &tar_cost = tar_cost_list[k];
1206         MS_EXCEPTION_IF_NULL(tar_cost);
1207         double computation =
1208           contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1209         double memory =
1210           contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1211         double communication =
1212           contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1213         double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
1214                                        tar_cost->communication_forward_;
1215         double communication_without_para = contract_op_cost->communication_without_parameter_ +
1216                                             edge_cost->communication_without_parameter_ +
1217                                             tar_cost->communication_without_parameter_;
1218 
1219         auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
1220                                                                       target_op_stra, tar_cost);
1221         auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1222         auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1223         new_cost->communication_without_parameter_ = communication_without_para;
1224         new_cost->communication_with_partial_para_ =
1225           communication_without_para + gamma * (communication - communication_without_para);
1226         new_cost->memory_with_reuse_ = memory;
1227         new_cost->communication_forward_ = communication_forward;
1228         tar_cost_list_new->emplace_back(std::move(new_cost));
1229       }
1230     }
1231   }
1232 }
1233 
1234 // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
1235 // target_op
EliminationContract(const OperatorInfoPtr & op)1236 OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
1237   MS_EXCEPTION_IF_NULL(op);
1238   auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
1239   auto edge_ptr = op->GetAlivePrevEdges()[0];
1240   MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
1241   bool valid = false;
1242 
1243   for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1244     MS_EXCEPTION_IF_NULL(tar_stra_cost);
1245     auto tar_stra = tar_stra_cost->strategy_ptr;
1246     auto tar_clist_origin = tar_stra_cost->cost_list;
1247     CostPtrList tar_clist_new;
1248 
1249     for (auto &op_stra_cost : op->GetStrategyCost()) {
1250       MS_EXCEPTION_IF_NULL(op_stra_cost);
1251       auto op_stra = op_stra_cost->strategy_ptr;
1252       auto op_clist = op_stra_cost->cost_list;
1253       auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
1254 
1255       CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1256     }
1257     Simplify(&tar_clist_new);
1258     // Set the new costlist w.r.t the strategy
1259     tar_stra_cost->cost_list = tar_clist_new;
1260     if ((!valid) && (!tar_clist_new.empty())) {
1261       valid = true;
1262     }
1263   }
1264   if (!valid) {
1265     MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
1266   }
1267   op->SetNotAlive();
1268   MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
1269   return target_op;
1270 }
1271 
CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,StrategyPtr left_op_stra,StrategyPtr right_op_stra,const CostPtr & right_op_cost,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtr & right_edge_cost,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new)1272 void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
1273                                                      StrategyPtr right_op_stra, const CostPtr &right_op_cost,
1274                                                      const CostPtrList &elimi_op_clist,
1275                                                      const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
1276                                                      const CostPtrList &left_node_clist_origin,
1277                                                      CostPtrList *left_node_clist_new) {
1278   MS_EXCEPTION_IF_NULL(right_edge_cost);
1279   MS_EXCEPTION_IF_NULL(right_op_cost);
1280   MS_EXCEPTION_IF_NULL(left_node_clist_new);
1281   for (auto &elimi_op_cost : elimi_op_clist) {
1282     MS_EXCEPTION_IF_NULL(elimi_op_cost);
1283     for (auto &left_edge_cost : left_edge_clist) {
1284       MS_EXCEPTION_IF_NULL(left_edge_cost);
1285       for (auto &left_node_cost : left_node_clist_origin) {
1286         MS_EXCEPTION_IF_NULL(left_node_cost);
1287         double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
1288                                  left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
1289         double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
1290                             left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
1291         double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
1292                                 left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
1293         double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
1294                                    left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
1295         double new_commu_without =
1296           elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
1297           left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
1298         const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1299         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1300 
1301         if (triangle_star_stra_overwrite) {
1302           new_computation += right_op_cost->computation_cost_;
1303           new_memory += right_op_cost->memory_with_reuse_;
1304           new_commu_cost += right_op_cost->communication_cost_;
1305           new_commu_forward += right_op_cost->communication_forward_;
1306           new_commu_without += right_op_cost->communication_without_parameter_;
1307         }
1308 
1309         auto decision =
1310           std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
1311                                                         left_op_stra, left_node_cost, right_op_stra, right_op_cost);
1312         auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
1313         new_cost->communication_without_parameter_ = new_commu_without;
1314         new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
1315         new_cost->memory_with_reuse_ = new_memory;
1316         new_cost->communication_forward_ = new_commu_forward;
1317         left_node_clist_new->emplace_back(std::move(new_cost));
1318       }
1319     }
1320   }
1321 }
1322 
CreateTriangleEliminationCostList(const OperatorInfoPtr & elimi_op,const CostPtrList & right_node_clist,const CostPtrList & right_edge_clist,const StrategyPtr & elimi_op_stra,const StrategyPtr & left_node_stra,const StrategyPtr & right_node_stra,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new)1323 void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
1324                                                   const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
1325                                                   const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
1326                                                   const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
1327                                                   const CostPtrList &left_node_clist_origin,
1328                                                   CostPtrList *left_node_clist_new) {
1329   MS_EXCEPTION_IF_NULL(elimi_op);
1330   for (auto &right_node_cost : right_node_clist) {
1331     MS_EXCEPTION_IF_NULL(right_node_cost);
1332     for (auto &right_edge_cost : right_edge_clist) {
1333       MS_EXCEPTION_IF_NULL(right_edge_cost);
1334       CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
1335                                            elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
1336                                            left_node_clist_new);
1337     }
1338   }
1339 }
1340 
EliminationTriangle(const OperatorInfoPtr & elimi_op,const std::shared_ptr<Edge> & edge_left_right)1341 OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
1342                                                const std::shared_ptr<Edge> &edge_left_right) {
1343   MS_EXCEPTION_IF_NULL(edge_left_right);
1344   MS_EXCEPTION_IF_NULL(elimi_op);
1345   MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
1346   auto left_node = edge_left_right->prev_operator();
1347   auto right_node = edge_left_right->next_operator();
1348   auto left_edge = elimi_op->GetAliveSuccEdges()[0];
1349   auto right_edge = elimi_op->GetAliveSuccEdges()[1];
1350   MS_EXCEPTION_IF_NULL(left_node);
1351   MS_EXCEPTION_IF_NULL(right_node);
1352   MS_EXCEPTION_IF_NULL(left_edge);
1353   MS_EXCEPTION_IF_NULL(right_edge);
1354   MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
1355   MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";
1356 
1357   if (left_edge->next_operator() != left_node) {
1358     auto tmp = left_edge;
1359     left_edge = right_edge;
1360     right_edge = tmp;
1361   }
1362   bool valid = false;
1363 
1364   for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
1365     MS_EXCEPTION_IF_NULL(left_node_stra_cost);
1366     auto left_node_stra = left_node_stra_cost->strategy_ptr;
1367     auto left_node_clist_origin = left_node_stra_cost->cost_list;
1368     CostPtrList left_node_clist_new;
1369 
1370     for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
1371       MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
1372       auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
1373       auto elimi_op_clist = elimi_op_stra_cost->cost_list;
1374       auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);
1375 
1376       for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
1377         MS_EXCEPTION_IF_NULL(right_node_stra_cost);
1378         auto right_node_stra = right_node_stra_cost->strategy_ptr;
1379         auto right_node_clist = right_node_stra_cost->cost_list;
1380         auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);
1381 
1382         CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
1383                                           right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
1384                                           &left_node_clist_new);
1385       }
1386     }
1387     Simplify(&left_node_clist_new);
1388     // Set the new costlist w.r.t the strategy
1389     left_node_stra_cost->cost_list = left_node_clist_new;
1390     if ((!valid) && (!left_node_clist_new.empty())) {
1391       valid = true;
1392     }
1393   }
1394 
1395   if (!valid) {
1396     MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
1397                       << " failed. It may be caused by "
1398                          "configuring inconsistent strategies for operators.";
1399   }
1400   elimi_op->SetNotAlive();
1401   MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
1402   return left_node;
1403 }
1404 
CreateStarEliminationSubCostList(const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,std::vector<StrategyPtr> succ_nodes_stras,CostPtrList & succ_edges_costs,CostPtrList & succ_nodes_costs,CostPtrList * first_succ_node_clist_new)1405 void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
1406                                                  const CostPtrList &first_succ_node_clist,
1407                                                  const CostPtrList &first_succ_edge_clist,
1408                                                  const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1409                                                  std::vector<StrategyPtr> succ_nodes_stras,
1410                                                  CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
1411                                                  CostPtrList *first_succ_node_clist_new) {
1412   for (auto &first_succ_node_cost : first_succ_node_clist) {
1413     for (auto &first_succ_edge_cost : first_succ_edge_clist) {
1414       for (auto &merged_node_cost : merged_op_clist) {
1415         MS_EXCEPTION_IF_NULL(merged_node_cost);
1416         succ_nodes_stras[0] = first_succ_node_stra;
1417         succ_edges_costs[0] = first_succ_edge_cost;
1418         succ_nodes_costs[0] = first_succ_node_cost;
1419 
1420         double computation_cost = merged_node_cost->computation_cost_,
1421                memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
1422                commu_without = merged_node_cost->communication_without_parameter_,
1423                commu_forward = merged_node_cost->communication_forward_;
1424         const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1425         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1426         for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
1427           MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
1428           if (i == 0) {
1429             computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
1430             memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
1431             commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
1432             commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
1433             commu_without += succ_edges_costs[i]->communication_without_parameter_ +
1434                              succ_nodes_costs[i]->communication_without_parameter_;
1435           } else {
1436             computation_cost += succ_edges_costs[i]->computation_cost_;
1437             memory_cost += succ_edges_costs[i]->memory_with_reuse_;
1438             commu_cost += succ_edges_costs[i]->communication_cost_;
1439             commu_forward += succ_edges_costs[i]->communication_forward_;
1440             commu_without += succ_edges_costs[i]->communication_without_parameter_;
1441             if (triangle_star_stra_overwrite) {
1442               computation_cost += succ_nodes_costs[i]->computation_cost_;
1443               memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
1444               commu_cost += succ_nodes_costs[i]->communication_cost_;
1445               commu_forward += succ_nodes_costs[i]->communication_forward_;
1446               commu_without += succ_nodes_costs[i]->communication_without_parameter_;
1447             }
1448           }
1449         }
1450 
1451         auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
1452                                                                   succ_nodes_stras, succ_nodes_costs);
1453         auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
1454         new_cost->communication_without_parameter_ = commu_without;
1455         new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
1456         new_cost->memory_with_reuse_ = memory_cost;
1457         new_cost->communication_forward_ = commu_forward;
1458         first_succ_node_clist_new->emplace_back(std::move(new_cost));
1459       }
1460     }
1461   }
1462 }
1463 
CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> & succ_edges,const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,CostPtrList * first_succ_node_clist_new)1464 void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
1465                                               const StrategyPtr &first_succ_node_stra,
1466                                               const CostPtrList &first_succ_node_clist,
1467                                               const CostPtrList &first_succ_edge_clist,
1468                                               const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1469                                               CostPtrList *first_succ_node_clist_new) {
1470   std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
1471   CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
1472   std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
1473                                            &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
1474                                            &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
1475                                            this](size_t k) {
1476     if (k == succ_edges.size()) {
1477       CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1478                                        merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
1479                                        succ_nodes_costs, first_succ_node_clist_new);
1480       return;
1481     }
1482     MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
1483                   << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
1484                   << ", merged_op_clist: " << merged_op_clist.size()
1485                   << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
1486     auto succ_edge = succ_edges[k];
1487     MS_EXCEPTION_IF_NULL(succ_edge);
1488     auto succ_node = succ_edge->next_operator();
1489     MS_EXCEPTION_IF_NULL(succ_node);
1490     for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
1491       MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
1492       auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
1493       auto succ_node_clist = succ_node_stra_cost->cost_list;
1494       auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
1495 
1496       for (auto &succ_node_cost : succ_node_clist) {
1497         MS_EXCEPTION_IF_NULL(succ_node_cost);
1498         for (auto &succ_edge_cost : succ_edge_clist) {
1499           MS_EXCEPTION_IF_NULL(succ_edge_cost);
1500           succ_nodes_stras[k] = succ_node_stra;
1501           succ_edges_costs[k] = succ_edge_cost;
1502           succ_nodes_costs[k] = succ_node_cost;
1503           recursive(k + 1);
1504         }
1505       }
1506     }
1507   };
1508 
1509   recursive(1);
1510 }
1511 
EliminationStar(const OperatorInfoPtr & merged_op)1512 std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) {
1513   MS_EXCEPTION_IF_NULL(merged_op);
1514   auto succ_edges = merged_op->GetAliveSuccEdges();
1515   MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
1516   for (auto &succ_edge : succ_edges) {
1517     MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
1518     MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
1519   }
1520 
1521   MS_EXCEPTION_IF_NULL(succ_edges[0]);
1522   auto first_succ_node = succ_edges[0]->next_operator();
1523   auto first_succ_edge = succ_edges[0];
1524   bool valid = false;
1525 
1526   // 'merged_op' is merged into first_node
1527   MS_EXCEPTION_IF_NULL(first_succ_node);
1528   for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
1529     MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
1530     auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
1531     auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
1532     CostPtrList first_succ_node_clist_new;
1533 
1534     for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
1535       MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
1536       auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
1537       auto merged_op_clist = merged_op_stra_cost->cost_list;
1538       auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
1539 
1540       CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1541                                     merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
1542     }
1543     Simplify(&first_succ_node_clist_new);
1544     // Set the new costlist w.r.t the strategy
1545     first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
1546     if ((!valid) && (!first_succ_node_clist_new.empty())) {
1547       valid = true;
1548     }
1549   }
1550 
1551   if (!valid) {
1552     MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name()
1553                       << " failed. It may be caused by "
1554                          "configuring inconsistent strategies for operators.";
1555   }
1556 
1557   merged_op->SetNotAlive();
1558   MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
1559   return succ_edges;
1560 }
1561 
GetNumEdges() const1562 size_t CostGraph::GetNumEdges() const {
1563   size_t sum = 0;
1564   for (const auto &kv : edges_) {
1565     auto &edges = kv.second;
1566     sum += edges.size();
1567   }
1568   return sum;
1569 }
1570 
InitReshapeStrategy()1571 Status CostGraph::InitReshapeStrategy() {
1572   // reshape init should be apply after the init of it's previous node and next node.
1573   for (size_t i = 0; i < ops_.size(); ++i) {
1574     if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
1575       auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
1576       auto in_edges = GetOriginalPrevEdges(ops_[i]);
1577       auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1578         return edge->prev_operator()->name() == reshape_info->pre_operator_name();
1579       });
1580       auto out_edges = GetOriginalNextEdges(ops_[i]);
1581       auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1582         return edge->next_operator()->name() == reshape_info->next_operator_name();
1583       });
1584       bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
1585       if (reshape_is_first_op) {
1586         reshape_info->InitSelectedStrategy(reshape_info->selected_strategy());
1587       }
1588       if (pre_iter != in_edges.end() || reshape_is_first_op) {
1589         MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
1590         int64_t pre_index = reshape_info->pre_operator_index();
1591         TensorInfo pre_info;
1592         std::shared_ptr<OperatorInfo> pre_op_info;
1593         if (reshape_is_first_op) {
1594           pre_op_info = reshape_info;
1595           pre_info = pre_op_info->inputs_tensor_info()[LongToSize(pre_index)];
1596         } else {
1597           pre_op_info = (*pre_iter)->prev_operator();
1598           pre_info = pre_op_info->outputs_tensor_info()[LongToSize(pre_index)];
1599         }
1600         reshape_info->SetInputLayout(pre_info.tensor_layout());
1601         if (pre_iter != in_edges.end()) {
1602           Dimensions stra = pre_info.InferStrategy();
1603           if (stra.empty()) {
1604             MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
1605           }
1606           Strategys stra_inputs = {stra};
1607           StrategyPtr reshape_stra =
1608             std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
1609           reshape_info->set_strategy(reshape_stra);
1610         }
1611       }
1612       if (next_iter != out_edges.end()) {
1613         MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
1614         int64_t next_index = reshape_info->next_operator_index();
1615         reshape_info->SetOutputLayout(
1616           (*next_iter)->next_operator()->inputs_tensor_info()[LongToSize(next_index)].tensor_layout());
1617       }
1618       if (reshape_info->Init(nullptr) != SUCCESS) {
1619         return FAILED;
1620       }
1621     }
1622   }
1623   return SUCCESS;
1624 }
1625 
InitSelectedStrategy()1626 Status CostGraph::InitSelectedStrategy() {
1627   for (auto &op : ops_) {
1628     MS_EXCEPTION_IF_NULL(op);
1629     if (op->name().find(RESHAPEINFO) != std::string::npos) {
1630       continue;
1631     }
1632     auto result_op = op->InitSelectedStrategy(op->selected_strategy());
1633     if (result_op != SUCCESS) {
1634       return result_op;
1635     }
1636   }
1637   auto result = InitReshapeStrategy();
1638   return result;
1639 }
1640 
ComputeOpsAndEdgesParameterInvolved()1641 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
1642   for (auto &op : ops_) {
1643     MS_EXCEPTION_IF_NULL(op);
1644     const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
1645     if ((output_parameter != 0) && (output_parameter != 1)) {
1646       MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
1647       return FAILED;
1648     }
1649   }
1650   return SUCCESS;
1651 }
1652 
DFSForTopoOrder(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,std::vector<OperatorInfoPtr> * topo_order)1653 void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
1654                                 std::vector<OperatorInfoPtr> *topo_order) {
1655   MS_EXCEPTION_IF_NULL(current_op);
1656   MS_EXCEPTION_IF_NULL(visited);
1657   MS_EXCEPTION_IF_NULL(topo_order);
1658 
1659   visited->at(current_op) = true;
1660   for (const auto &s_edge : current_op->succ_edges()) {
1661     if (!visited->at(s_edge->next_operator())) {
1662       DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
1663     }
1664   }
1665   topo_order->push_back(current_op);
1666 }
1667 
1668 // Compute a topological order of the costgraph
TopologyOrder(std::vector<OperatorInfoPtr> * topo_order)1669 void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
1670   std::map<OperatorInfoPtr, bool> visited;
1671   for (auto &op : ops_) {
1672     visited[op] = false;
1673   }
1674 
1675   for (auto &op : ops_) {
1676     if (!visited[op]) {
1677       DFSForTopoOrder(op, &visited, topo_order);
1678     }
1679   }
1680 }
MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr,int64_t> & candidate_ops)1681 void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &candidate_ops) {
1682   for (auto &op : ops_) {
1683     auto search = candidate_ops.find(op);
1684     if (search != candidate_ops.end()) {
1685       // Mark the critical operators
1686       op->mark_output_critical();
1687       // Mark the successive edges
1688       for (auto &s_edge : op->succ_edges()) {
1689         s_edge->mark_output_critical();
1690       }
1691     } else {
1692       op->mark_output_not_critical();
1693     }
1694   }
1695 }
1696 
DetermineCriticalOps(const std::vector<OperatorInfoPtr> & topo_order)1697 Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
1698   if (topo_order.size() == 0) {
1699     MS_LOG(ERROR) << "0 operator in costgraph.";
1700     return FAILED;
1701   }
1702   auto &first_op = topo_order[0];
1703   if (first_op->prev_edges().size() > 0) {
1704     MS_LOG(ERROR) << "The first operator in the first of topological order of "
1705                      "costgraph should have 0 incoming edge, but has "
1706                   << first_op->prev_edges() << "edges.";
1707     return FAILED;
1708   }
1709   // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
1710   // of the output of OperatorInfo that currently has not been used
1711   std::map<OperatorInfoPtr, int64_t> curr_memory_state;
1712   (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToLong(first_op->succ_edges().size())));
1713   std::map<OperatorInfoPtr, int64_t> max_memory_state = curr_memory_state;
1714   // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
1715   // not been used
1716   double curr_memory_size = first_op->GetOutputsTotalSize();
1717   double max_memory_size = curr_memory_size;
1718 
1719   for (size_t finished = 1; finished < topo_order.size(); ++finished) {
1720     // Produce
1721     (void)curr_memory_state.emplace(
1722       std::make_pair(topo_order[finished], SizeToLong(topo_order[finished]->succ_edges().size())));
1723     curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
1724     // Consume
1725     for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1726       const auto &prev_op = prev_edge->prev_operator();
1727       curr_memory_state[prev_op]--;
1728     }
1729     for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1730       const auto &prev_op = prev_edge->prev_operator();
1731       if (curr_memory_state[prev_op] < 0) {
1732         MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
1733         return FAILED;
1734       } else if (curr_memory_state[prev_op] == 0) {
1735         curr_memory_state.erase(prev_op);
1736         curr_memory_size -= prev_op->GetOutputsTotalSize();
1737       }
1738     }
1739 
1740     if (curr_memory_size < 0) {
1741       MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
1742     }
1743     // Modify the max
1744     if (curr_memory_size > max_memory_size) {
1745       max_memory_size = curr_memory_size;
1746       max_memory_state = curr_memory_state;
1747     }
1748   }
1749   // Mark those critical operators
1750   MarkCriticalOpsAndEdges(max_memory_state);
1751   return SUCCESS;
1752 }
1753 
ComputeOpsAndEdgesOutputCritical()1754 Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
1755   // Two steps to do:
1756   // 1. Compute a topological order of the costgraph
1757   // 2. Determine and mark the operators (and necessary edges) that are critical
1758   std::vector<OperatorInfoPtr> topo_order;
1759   TopologyOrder(&topo_order);
1760   std::reverse(std::begin(topo_order), std::end(topo_order));
1761 
1762   if (DetermineCriticalOps(topo_order) != SUCCESS) {
1763     MS_LOG(ERROR) << "Determining critical operators failed.";
1764     return FAILED;
1765   }
1766   return SUCCESS;
1767 }
1768 
CalculateOpsMemoryCost()1769 Status CostGraph::CalculateOpsMemoryCost() {
1770   for (auto &op : ops_) {
1771     MS_EXCEPTION_IF_NULL(op);
1772     if (op->CalculateMemoryCost() != SUCCESS) {
1773       MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
1774       return FAILED;
1775     }
1776   }
1777   return SUCCESS;
1778 }
1779 
CalculateOpsMemoryCostForInference()1780 Status CostGraph::CalculateOpsMemoryCostForInference() {
1781   for (auto &op : ops_) {
1782     MS_EXCEPTION_IF_NULL(op);
1783     if (op->CalculateMemoryCostForInference() != SUCCESS) {
1784       MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
1785       return FAILED;
1786     }
1787   }
1788   return SUCCESS;
1789 }
1790 
CalculateEdgesMemoryCost()1791 Status CostGraph::CalculateEdgesMemoryCost() {
1792   for (auto &edge_pair : edges_) {
1793     const auto &edges = edge_pair.second;
1794     for (auto &one_edge : edges) {
1795       if (one_edge->CalculateMemoryCost() != SUCCESS) {
1796         MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
1797         return FAILED;
1798       }
1799     }
1800   }
1801   return SUCCESS;
1802 }
1803 
CalculateEdgesMemoryCostForInference()1804 Status CostGraph::CalculateEdgesMemoryCostForInference() {
1805   for (auto &edge_pair : edges_) {
1806     const auto &edges = edge_pair.second;
1807     for (auto &one_edge : edges) {
1808       if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
1809         MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
1810         return FAILED;
1811       }
1812     }
1813   }
1814   return SUCCESS;
1815 }
1816 
FindTmpIdentityByParameterName(std::string & p_name) const1817 OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
1818   for (auto one_op : ops_) {
1819     if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
1820       if (one_op->refkey_parameter_name() == p_name) {
1821         return one_op;
1822       }
1823     }
1824   }
1825   return nullptr;
1826 }
CorrectOpsMemoryCost()1827 Status CostGraph::CorrectOpsMemoryCost() {
1828   for (auto &one_op : ops_) {
1829     if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
1830       if (one_op->GetAliveSuccEdges().size() > 1) {
1831         // Filter out the case when the TmpIdentity being used by multiple operators
1832         std::map<size_t, int64_t> output_count;
1833         for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
1834           auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
1835           output_count[output_index]++;
1836         }
1837         for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
1838           auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
1839           if (output_count[output_index] <= 1) {
1840             continue;
1841           }
1842           auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
1843           MS_EXCEPTION_IF_NULL(next_op);
1844           auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
1845           if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
1846             MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
1847                           << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
1848             return FAILED;
1849           }
1850           output_count[output_index]--;
1851         }
1852       }
1853     }
1854   }
1855   return SUCCESS;
1856 }
1857 
CalculateMemoryCost()1858 Status CostGraph::CalculateMemoryCost() {
1859   const auto phase = CostModelContext::GetInstance()->run_phase();
1860   if (phase == TRAINING_PHASE) {
1861     // training phase
1862     if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
1863       // Calculate operators' memory usage
1864       if (CalculateOpsMemoryCost() != SUCCESS) {
1865         MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
1866         return FAILED;
1867       }
1868       // Calculate edges' memory usage
1869       if (CalculateEdgesMemoryCost() != SUCCESS) {
1870         MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
1871         return FAILED;
1872       }
1873       // Correct memory usage caused by TmpIdentity
1874       if (CorrectOpsMemoryCost() != SUCCESS) {
1875         MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
1876         return FAILED;
1877       }
1878     } else {
1879       MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
1880       return FAILED;
1881     }
1882   } else {
1883     // inference phase
1884     if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
1885       // Calculate operators' memory usage
1886       if (CalculateOpsMemoryCostForInference() != SUCCESS) {
1887         MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
1888         return FAILED;
1889       }
1890       // Calculate edges's memory usage
1891       if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
1892         MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
1893         return FAILED;
1894       }
1895     } else {
1896       MS_LOG(ERROR) << "Computing operators' critical flag failed.";
1897       return FAILED;
1898     }
1899   }
1900   return SUCCESS;
1901 }
1902 
CheckApproximateCostGraphEdges()1903 void CostGraph::CheckApproximateCostGraphEdges() {
1904   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
1905   if (!approximation) {
1906     return;
1907   }
1908   for (auto &s_edge : edges_) {
1909     auto &edges_vector = s_edge.second;
1910     for (auto &edge_ptr : edges_vector) {
1911       MS_EXCEPTION_IF_NULL(edge_ptr);
1912       if (edge_ptr->CheckStrategyCostPossibility()) {
1913         continue;
1914       }
1915       MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name()
1916                    << " impossible, re-initing the operators and edges";
1917       auto prev_op = edge_ptr->prev_operator();
1918       MS_EXCEPTION_IF_NULL(prev_op);
1919       auto next_op = edge_ptr->next_operator();
1920       MS_EXCEPTION_IF_NULL(next_op);
1921       // Check the 'prev_op'
1922       prev_op->ExactStrategiesAndRelatedEdges();
1923       // Check the 'next_op'
1924       next_op->ExactStrategiesAndRelatedEdges();
1925     }
1926   }
1927 }
1928 }  // namespace parallel
1929 }  // namespace mindspore
1930