• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <algorithm>
17 #include <cstdlib>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 #include <queue>
23 #include <set>
24 
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/ops_info/reshape_info.h"
27 #include "frontend/parallel/step_auto_parallel.h"
28 
29 namespace mindspore {
30 namespace parallel {
31 CostGraphPtr entire_costgraph = nullptr;
32 
Init()33 void CostGraph::Init() {
34   inputs_tensor_name_list_.clear();
35   tuple_getitem_list_.clear();
36   ops_.clear();
37   edges_.clear();
38   connected_compoents_.clear();
39   out_edges_.clear();
40   in_edges_.clear();
41 }
42 
RemoveOperator(const OperatorInfoPtr & op)43 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
44   for (auto it = ops_.begin(); it != ops_.end();) {
45     if ((*it) == op) {
46       it = ops_.erase(it);
47     } else {
48       ++it;
49     }
50   }
51 }
52 
IsOperatorInCostGraph(const OperatorInfoPtr & op_test)53 bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
54   struct IsInGraph {
55     const OperatorInfoPtr test_;
56     explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
57     bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
58   };
59   return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
60 }
61 
AddEdge(OperatorInfoPtr u_node,OperatorInfoPtr v_node,const EdgePtr & edge)62 void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
63   std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
64   curr_edges.push_back(edge);
65   edges_[{u_node, v_node}] = curr_edges;
66 
67   std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
68   curr_out_edges.push_back(edge);
69   out_edges_[u_node] = curr_out_edges;
70 
71   std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
72   curr_in_edges.push_back(edge);
73   in_edges_[v_node] = curr_in_edges;
74 }
75 
IsEdgeInCostGraph(const std::string & test_edge_name,size_t output_index,size_t input_index)76 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
77   for (auto &edge_pair : edges_) {
78     auto edges = edge_pair.second;
79     for (auto &edge : edges) {
80       MS_EXCEPTION_IF_NULL(edge);
81       bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
82                          (edge->next_op_input_index() == input_index);
83       if (bool_result) {
84         return true;
85       }
86     }
87   }
88   return false;
89 }
90 
// '_diff_stra_params' holds the TmpIdentity operators (i.e. parameters) that still have to be processed:
// the users of such an operator have different strategies, so an additional propagation pass is needed.
93 std::set<OperatorInfoPtr> _diff_stra_params;
StrategyPropagate(const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & ops_stras)94 void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &ops_stras) {
95   if (ops_stras.empty()) {
96     MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
97   }
98   _diff_stra_params.clear();
99   std::map<OperatorInfoPtr, bool> visited;
100   for (auto &op : ops_) {
101     visited[op] = false;
102   }
103   for (auto &op_stra : ops_stras) {
104     BFS(op_stra.first, op_stra.second, ops_stras, &visited);
105   }
106 
107   ProcessDiffStraParams(ops_stras);
108   _diff_stra_params.clear();
109   // GetNext as a isolate op can not be propagated
110   for (auto &op : entire_costgraph->GetOperators()) {
111     if (op->selected_strategy() == nullptr) {
112       MS_LOG(INFO) << "Set default strategy for op: " << op->name();
113       op->SetSelectedStrategy(op->strategy_cost()[0]->strategy_ptr, 0);
114     }
115   }
116 }
117 
CheckVisitedEdgeConsistency(const EdgePtr & edge)118 bool CheckVisitedEdgeConsistency(const EdgePtr &edge) {
119   MS_EXCEPTION_IF_NULL(edge);
120   auto prev_op = edge->prev_operator();
121   auto next_op = edge->next_operator();
122   bool consistency = true;
123   if (prev_op->IsReshape()) {
124     const auto &reshape_output_lyt =
125       next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
126     auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
127     consistency = reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
128     if (!consistency) {
129       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
130     }
131   } else if (next_op->IsReshape()) {
132     const auto &reshape_input_lyt =
133       prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
134     auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
135     consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
136     if (!consistency) {
137       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
138     }
139   } else {
140     consistency =
141       edge->CheckStrategyConsistency(prev_op->selected_strategy(), next_op->selected_strategy(), &_diff_stra_params);
142     if (!consistency) {
143       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
144     }
145   }
146   return consistency;
147 }
148 
CheckConfiguredSuccEdgeConsistency(const EdgePtr & edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)149 bool CheckConfiguredSuccEdgeConsistency(const EdgePtr &edge,
150                                         const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
151   MS_EXCEPTION_IF_NULL(edge);
152   auto curr_op = edge->prev_operator();
153   auto next_op = edge->next_operator();
154   auto next_op_conf_stra = configured_ops.at(next_op);
155   bool consistency = true;
156   if (curr_op->IsReshape()) {
157     const auto &reshape_output_lyt =
158       next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
159     auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
160     consistency = reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
161     if (!consistency) {
162       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
163     }
164     return consistency;
165   } else {
166     consistency = edge->CheckStrategyConsistency(curr_op->selected_strategy(), next_op_conf_stra, &_diff_stra_params);
167     if (!consistency) {
168       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
169     }
170   }
171   return consistency;
172 }
173 
CheckConfiguredPrevEdgeConsistency(const EdgePtr & edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)174 void CheckConfiguredPrevEdgeConsistency(const EdgePtr &edge,
175                                         const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
176   auto curr_op = edge->next_operator();
177   auto prev_op = edge->prev_operator();
178   auto prev_op_conf_stra = configured_ops.at(prev_op);
179   if (curr_op->IsReshape()) {
180     const auto &reshape_input_lyt =
181       prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
182     auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
183     auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
184     if (!consistency) {
185       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
186     }
187   } else {
188     auto consistency =
189       edge->CheckStrategyConsistency(prev_op_conf_stra, curr_op->selected_strategy(), &_diff_stra_params);
190     if (!consistency) {
191       MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
192     }
193   }
194 }
195 
BFSPreNode(const std::shared_ptr<Edge> & edge,std::map<OperatorInfoPtr,bool> * visited,int64_t curr_depth,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,const OperatorInfoPtr & curr_op,std::queue<std::pair<std::pair<OperatorInfoPtr,std::pair<StrategyPtr,int64_t>>,int64_t>> * next_level)196 void BFSPreNode(
197   const std::shared_ptr<Edge> &edge, std::map<OperatorInfoPtr, bool> *visited, int64_t curr_depth,
198   const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops, const OperatorInfoPtr &curr_op,
199   std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> *next_level) {
200   MS_EXCEPTION_IF_NULL(edge);
201   MS_EXCEPTION_IF_NULL(visited);
202   MS_EXCEPTION_IF_NULL(curr_op);
203   MS_EXCEPTION_IF_NULL(next_level);
204   const auto &prev_op = edge->prev_operator();
205   bool is_has_stra_op = visited->at(prev_op) || configured_ops.find(prev_op) != configured_ops.end();
206   if (edge->InitEdgeCost() != SUCCESS && !is_has_stra_op) {
207     MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
208   }
209   if (visited->at(prev_op)) {
210     (void)CheckVisitedEdgeConsistency(edge);
211     return;
212   }
213   if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
214     CheckConfiguredPrevEdgeConsistency(edge, configured_ops);
215   }
216   if (configured_ops.find(prev_op) != configured_ops.end()) {
217     return;
218   }
219   visited->at(prev_op) = true;
220   if (prev_op->IsReshape()) {
221     auto swc_index = edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy());
222     (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
223     prev_op->set_swc_index(swc_index, curr_depth + 1);
224   } else if (curr_op->IsReshape()) {
225     auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
226     (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
227     prev_op->SetSelectedStrategy(prev_stra, LongToSize(curr_depth + 1));
228     prev_op->ClearStrategyCost();
229     if (prev_op->SetCostUnderStrategy(prev_stra) != SUCCESS) {
230       MS_LOG(EXCEPTION) << "Failure: operator " << prev_op->name() << " SetCostUnderStrategy failed";
231     }
232   } else {
233     const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
234     if (prev_op_stra == nullptr) {
235       MS_LOG(EXCEPTION) << "Current op " << curr_op->name() << "'s strategy is "
236                         << curr_op->selected_strategy()->ToString() << ". " << prev_op->name()
237                         << "'s strategy is null in the edge: " << edge->edge_name();
238     }
239     (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(prev_op_stra, -1)), curr_depth + 1);
240     prev_op->SetSelectedStrategy(prev_op_stra, LongToSize(curr_depth + 1));
241     prev_op->ClearStrategyCost();
242     if (prev_op->SetCostUnderStrategy(prev_op_stra) != SUCCESS) {
243       MS_LOG(EXCEPTION) << "Failure: operator " << prev_op->name() << " SetCostUnderStrategy failed";
244     }
245   }
246 }
247 
BFSNextNode(const std::shared_ptr<Edge> & edge,std::map<OperatorInfoPtr,bool> * visited,int64_t curr_depth,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,const OperatorInfoPtr & curr_op,std::queue<std::pair<std::pair<OperatorInfoPtr,std::pair<StrategyPtr,int64_t>>,int64_t>> * next_level)248 void BFSNextNode(
249   const std::shared_ptr<Edge> &edge, std::map<OperatorInfoPtr, bool> *visited, int64_t curr_depth,
250   const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops, const OperatorInfoPtr &curr_op,
251   std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> *next_level) {
252   MS_EXCEPTION_IF_NULL(edge);
253   MS_EXCEPTION_IF_NULL(visited);
254   MS_EXCEPTION_IF_NULL(curr_op);
255   MS_EXCEPTION_IF_NULL(next_level);
256   const auto &next_op = edge->next_operator();
257   bool is_has_stra_op = visited->at(next_op) || configured_ops.find(next_op) != configured_ops.end();
258   if (edge->InitEdgeCost() != SUCCESS && !is_has_stra_op) {
259     MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
260   }
261   if (visited->at(next_op)) {
262     (void)CheckVisitedEdgeConsistency(edge);
263     return;
264   }
265   if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
266     (void)CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
267   }
268   if (configured_ops.find(next_op) != configured_ops.end()) {
269     return;
270   }
271   visited->at(next_op) = true;
272   if (curr_op->IsReshape()) {
273     auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
274     (void)next_level->emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
275     next_op->SetSelectedStrategy(stra, LongToSize(curr_depth + 1));
276     next_op->ClearStrategyCost();
277     if (next_op->SetCostUnderStrategy(stra) != SUCCESS) {
278       MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
279     }
280   } else if (next_op->IsReshape()) {
281     auto swc_index = edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy());
282     (void)next_level->emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
283     next_op->set_swc_index(swc_index, curr_depth + 1);
284   } else {
285     const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
286     if (next_op_stra == nullptr) {
287       MS_LOG(INFO) << "The strategy is: " << curr_op->selected_strategy()->ToString();
288       MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
289     }
290     (void)next_level->emplace(std::make_pair(next_op, std::make_pair(next_op_stra, -1)), curr_depth + 1);
291     next_op->SetSelectedStrategy(next_op_stra, LongToSize(curr_depth + 1));
292     next_op->ClearStrategyCost();
293     if (next_op->SetCostUnderStrategy(next_op_stra) != SUCCESS) {
294       MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
295     }
296   }
297 }
298 
BFS(const OperatorInfoPtr & op,const StrategyPtr & op_stra,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,std::map<OperatorInfoPtr,bool> * visited) const299 void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
300                     const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops,
301                     std::map<OperatorInfoPtr, bool> *visited) const {
302   std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
303   (void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
304   op->SetSelectedStrategy(op_stra, 0);
305   while (!next_level.empty()) {
306     auto curr_op = next_level.front().first.first;
307     auto configured_stra = next_level.front().first.second.first;
308     auto curr_depth = next_level.front().second;
309     visited->at(curr_op) = true;
310     MS_LOG(INFO) << "curr_depth: " << curr_depth << " curr_op: " << curr_op->name();
311     for (auto &edge : curr_op->succ_edges()) {
312       BFSNextNode(edge, visited, curr_depth, configured_ops, curr_op, &next_level);
313     }
314     for (auto &edge : curr_op->prev_edges()) {
315       BFSPreNode(edge, visited, curr_depth, configured_ops, curr_op, &next_level);
316     }
317     next_level.pop();
318   }
319 }
320 
321 // An additional propagation to deal with incontinuous strategy between the ops that share params.
ProcessDiffStraParams(const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)322 void CostGraph::ProcessDiffStraParams(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
323   for (auto &curr_identity_op : _diff_stra_params) {
324     auto succ_edges = curr_identity_op->succ_edges();
325     auto is_target = [&](const std::shared_ptr<Edge> &edge) {
326       return configured_ops.find(edge->next_operator()) != configured_ops.end();
327     };
328     auto it = std::find_if(succ_edges.begin(), succ_edges.end(), is_target);
329     // if there are configured ops in the users, update the strategy of identity_op
330     if (it != succ_edges.end()) {
331       auto consistency = CheckVisitedEdgeConsistency(*it);
332       if (!consistency) {
333         curr_identity_op->ClearStrategyCost();
334         (void)curr_identity_op->GenerateStrategies(0);
335         if ((*it)->InitEdgeCost() != SUCCESS) {
336           MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
337         }
338         const auto curr_op = (*it)->next_operator();
339         const auto &prev_op_stra = (*it)->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
340         if (prev_op_stra == nullptr) {
341           MS_LOG(EXCEPTION) << curr_identity_op->name() << "'s strategy is null in the edge: " << (*it)->edge_name();
342         }
343         curr_identity_op->SetSelectedStrategy(prev_op_stra, 1);
344         curr_identity_op->ClearStrategyCost();
345         if (curr_identity_op->SetCostUnderStrategy(prev_op_stra) != SUCCESS) {
346           MS_LOG(EXCEPTION) << "Failure: operator " << curr_identity_op->name() << " SetCostUnderStrategy failed";
347         }
348       }
349     }
350     for (auto &edge : succ_edges) {
351       const auto &next_op = edge->next_operator();
352       ParamPropagation(curr_identity_op, edge, configured_ops);
353       if (next_op->name().find(CAST) != std::string::npos) {
354         for (auto &cast_edge : next_op->succ_edges()) {
355           ParamPropagation(next_op, cast_edge, configured_ops);
356         }
357       }
358     }
359   }
360 }
361 
ParamPropagation(const OperatorInfoPtr & curr_op,const std::shared_ptr<Edge> edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops) const362 void CostGraph::ParamPropagation(const OperatorInfoPtr &curr_op, const std::shared_ptr<Edge> edge,
363                                  const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) const {
364   const auto &next_op = edge->next_operator();
365   MS_LOG(INFO) << "params propagation at " << curr_op->name() << "->" << next_op->name();
366   if ((configured_ops.find(next_op) != configured_ops.end())) {
367     auto consistency = CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
368     if (!consistency) {
369       MS_LOG(EXCEPTION) << "Incorrect strategy configuration.";
370     }
371     return;
372   }
373   auto consistency = CheckVisitedEdgeConsistency(edge);
374   if (consistency) {
375     MS_LOG(INFO) << "Jump the consistency edge.";
376     return;
377   }
378   next_op->ClearStrategyCost();
379   (void)next_op->GenerateStrategies(0);
380   if (edge->InitEdgeCost() != SUCCESS) {
381     MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
382   }
383   const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
384   if (next_op_stra == nullptr) {
385     MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
386   }
387   next_op->SetSelectedStrategy(next_op_stra, 1);
388   next_op->ClearStrategyCost();
389   if (next_op->SetCostUnderStrategy(next_op_stra) != SUCCESS) {
390     MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
391   }
392 }
393 
ConstructConnectedComponents(std::vector<OperatorInfoPtr> alive_ops)394 std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
395   std::vector<OperatorInfoPtr> alive_ops) {
396   std::map<OperatorInfoPtr, bool> visited;
397 
398   for (auto &op : alive_ops) {
399     visited[op] = false;
400   }
401 
402   MS_LOG(INFO) << "visited: " << visited.size() << ".";
403   for (auto &op : alive_ops) {
404     if ((!visited[op]) && op->is_alive()) {
405       std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
406       MS_EXCEPTION_IF_NULL(new_component);
407       DFS(op, &visited, new_component);
408       connected_compoents_.push_back(new_component);
409     }
410   }
411   return connected_compoents_;
412 }
413 
DFS(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,const std::shared_ptr<CostGraph> & component)414 void CostGraph::DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
415                     const std::shared_ptr<CostGraph> &component) {
416   MS_EXCEPTION_IF_NULL(visited);
417   MS_EXCEPTION_IF_NULL(component);
418   visited->at(current_op) = true;
419   component->AddOperator(current_op);
420 
421   for (auto &edge : current_op->succ_edges()) {
422     bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
423                      (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
424     if (bool_test) {
425       component->AddEdge(current_op, edge->next_operator(), edge);
426       DFS(edge->next_operator(), visited, component);
427     }
428   }
429 
430   for (auto &edge : current_op->prev_edges()) {
431     bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
432                      (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
433     if (bool_test) {
434       component->AddEdge(edge->prev_operator(), current_op, edge);
435       DFS(edge->prev_operator(), visited, component);
436     }
437   }
438 }
439 
440 // Create final cost list for the graph: u --> v
CreateFinalCostList(const OperatorInfoPtr & u,const std::shared_ptr<Edge> & e,const OperatorInfoPtr & v) const441 CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
442                                            const OperatorInfoPtr &v) const {
443   MS_EXCEPTION_IF_NULL(u);
444   MS_EXCEPTION_IF_NULL(v);
445   MS_EXCEPTION_IF_NULL(e);
446   CostPtrList ret;
447   for (const auto &u_strategy : u->GetStrategyCost()) {
448     for (const auto &v_strategy : v->GetStrategyCost()) {
449       MS_EXCEPTION_IF_NULL(u_strategy);
450       MS_EXCEPTION_IF_NULL(v_strategy);
451       auto u_strategy_ptr = u_strategy->strategy_ptr;
452       auto v_strategy_ptr = v_strategy->strategy_ptr;
453       CostPtrList clist1 = u_strategy->cost_list;
454       CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
455       CostPtrList clist3 = v_strategy->cost_list;
456       for (const auto &cost1 : clist1) {
457         for (const auto &cost2 : clist2) {
458           for (const auto &cost3 : clist3) {
459             MS_EXCEPTION_IF_NULL(cost1);
460             MS_EXCEPTION_IF_NULL(cost2);
461             MS_EXCEPTION_IF_NULL(cost3);
462             double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
463             double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
464             double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
465             double communication_forward =
466               cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
467             double communication_without_para = cost1->communication_without_parameter_ +
468                                                 cost2->communication_without_parameter_ +
469                                                 cost3->communication_without_parameter_;
470             auto decision =
471               std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
472             auto cost = std::make_shared<Cost>(computation, communication, decision);
473             const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
474             MS_EXCEPTION_IF_NULL(cost);
475             cost->communication_without_parameter_ = communication_without_para;
476             cost->communication_with_partial_para_ =
477               communication_without_para + gamma * (communication - communication_without_para);
478             cost->memory_with_reuse_ = memory;
479             cost->communication_forward_ = communication_forward;
480             ret.push_back(cost);
481           }
482         }
483       }
484     }
485   }
486 
487   Simplify(&ret);
488   return ret;
489 }
490 
491 // Create final cost list for the graph containing a single node: u
CreateFinalSingleCostList(const OperatorInfoPtr & u) const492 CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) const {
493   MS_EXCEPTION_IF_NULL(u);
494   CostPtrList ret;
495   for (const auto &u_strategy : u->GetStrategyCost()) {
496     MS_EXCEPTION_IF_NULL(u_strategy);
497     auto u_strategy_ptr = u_strategy->strategy_ptr;
498     CostPtrList clist1 = u_strategy->cost_list;
499     for (const auto &cost1 : clist1) {
500       MS_EXCEPTION_IF_NULL(cost1);
501       auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
502       auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
503       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
504       MS_EXCEPTION_IF_NULL(new_cost);
505       new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
506       new_cost->communication_with_partial_para_ =
507         cost1->communication_without_parameter_ +
508         gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
509       new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
510       new_cost->communication_forward_ = cost1->communication_forward_;
511       ret.push_back(new_cost);
512     }
513   }
514 
515   Simplify(&ret);
516   return ret;
517 }
518 
SelectCostWithMinInferenceTime(const CostPtrList & cost_list,double memory) const519 CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) const {
520   // Select the cost with minimum inference time. Currently, the inference time is modeled as =
521   // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
522   if (cost_list.empty()) {
523     MS_LOG(ERROR) << "Final cost list is null.";
524     return nullptr;
525   }
526   CostPtrList after_mem_filter;
527   double minimum_memory = DBL_MAX;
528   // Filter out the valid costs.
529   for (auto &a_cost : cost_list) {
530     if (a_cost->memory_with_reuse_ <= memory) {
531       after_mem_filter.push_back(a_cost);
532     } else if (a_cost->memory_with_reuse_ < minimum_memory) {
533       minimum_memory = a_cost->memory_with_reuse_;
534     }
535   }
536   if (after_mem_filter.empty()) {
537     MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
538                   << ", the memory capacity is: " << memory << ".";
539     return nullptr;
540   }
541   // Init the returned value with first cost.
542   CostPtr ret = after_mem_filter[0];
543   const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
544   const auto beta = CostModelContext::GetInstance()->costmodel_beta();
545 
546   double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
547   MS_LOG(INFO) << "Cost 0: "
548                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
549                << ", communication_forward_: " << ret->communication_forward_
550                << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
551                << ", communication_cost_: " << ret->communication_cost_
552                << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
553   MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
554   for (size_t i = 1; i < after_mem_filter.size(); ++i) {
555     MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
556     MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
557                  << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
558                  << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
559                  << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
560                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
561                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
562                  << ".";
563     auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
564     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
565     if (minimum > tmp) {
566       minimum = tmp;
567       ret = after_mem_filter[i];
568       MS_LOG(INFO) << "Selected: " << i;
569     }
570   }
571   return ret;
572 }
573 
SelectCostWithMinTrainingTime(const CostPtrList & cost_list,double memory) const574 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) const {
575   // Select the cost with minimum training time. Currently, the training time is modeled as =
576   // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
577   if (cost_list.empty()) {
578     MS_LOG(ERROR) << "Final cost list is null.";
579     return nullptr;
580   }
581   CostPtrList after_mem_filter;
582   double minimum_memory = DBL_MAX;
583   // Filter out the valid costs.
584   for (auto &a_cost : cost_list) {
585     if (a_cost->memory_with_reuse_ <= memory) {
586       after_mem_filter.push_back(a_cost);
587     } else if (a_cost->memory_with_reuse_ < minimum_memory) {
588       minimum_memory = a_cost->memory_with_reuse_;
589     }
590   }
591   if (after_mem_filter.empty()) {
592     MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
593                   << ", the memory capacity is: " << memory << ".";
594     return nullptr;
595   }
596   // Init the returned value with first cost.
597   CostPtr ret = after_mem_filter[0];
598   const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
599   const auto beta = CostModelContext::GetInstance()->costmodel_beta();
600 
601   double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
602   MS_LOG(INFO) << "Cost 0: "
603                << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
604                << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
605                << ", communication_cost_: " << ret->communication_cost_
606                << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
607   MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
608   for (size_t i = 1; i < after_mem_filter.size(); ++i) {
609     MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
610     MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
611                  << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
612                  << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
613                  << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
614                  << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
615                  << ".";
616     auto tmp =
617       alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
618     MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
619     if (minimum > tmp) {
620       minimum = tmp;
621       ret = after_mem_filter[i];
622       MS_LOG(INFO) << "Selected: " << i;
623     }
624   }
625   return ret;
626 }
627 
SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> & all_cost_list,double available_memory) const628 CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
629                                                                  double available_memory) const {
630   CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
631   double minimum = DBL_MAX, total_memory = 0.0;
632   CostPtrList ret(all_cost_list.size(), nullptr);
633   // Check whether valid costs exist.
634   for (size_t i = 0; i < all_cost_list.size(); ++i) {
635     if (all_cost_list[i][0] == nullptr) {
636       MS_LOG(ERROR) << "The cost list " << i << " is empty.";
637       return ret;
638     } else {
639       double memory_i_cost = DBL_MAX;
640       for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
641         if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
642           memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
643         }
644       }
645       total_memory += memory_i_cost;
646     }
647   }
648   if (total_memory >= available_memory) {
649     MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
650                   << ", minimum strategy cost: " << total_memory << ".";
651     return selected_cost_list;
652   }
653 
654   std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
655                                            &available_memory](size_t k) {
656     const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
657     const auto beta = CostModelContext::GetInstance()->costmodel_beta();
658     if (k == all_cost_list.size()) {
659       double tmp_memory = 0.0, tmp_minimum = 0.0;
660       for (size_t i = 0; i < selected_cost_list.size(); ++i) {
661         MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
662         tmp_memory += selected_cost_list[i]->memory_with_reuse_;
663         tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
664                        beta * selected_cost_list[i]->communication_with_partial_para_;
665       }
666       MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
667                    << ".";
668       if (tmp_memory < available_memory && tmp_minimum < minimum) {
669         ret = selected_cost_list;
670         minimum = tmp_minimum;
671         MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
672       }
673       return;
674     }
675 
676     MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
677     for (auto &c : all_cost_list[k]) {
678       selected_cost_list[k] = c;
679       recursive(k + 1);
680     }
681   };
682   recursive(0);
683   return ret;
684 }
685 
SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)686 Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
687   MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
688   auto connected_components = ConstructConnectedComponents(alive_ops);
689   MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
690   std::vector<CostPtrList> all_list;
691   for (size_t j = 0; j < connected_components.size(); ++j) {
692     auto one_component = connected_components[j];
693     MS_EXCEPTION_IF_NULL(one_component);
694     if (one_component->GetOperators().size() == 1) {
695       MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
696       auto cost_list_1 = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
697       all_list.push_back(cost_list_1);
698     } else if (one_component->GetOperators().size() == 2) {
699       MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
700       OperatorInfoPtr u, v;
701       auto first_op = one_component->GetOperators()[0];
702       auto second_op = one_component->GetOperators()[1];
703       MS_EXCEPTION_IF_NULL(first_op);
704       MS_EXCEPTION_IF_NULL(second_op);
705       if (!first_op->GetAliveSuccEdges().empty() &&
706           first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
707         u = first_op;
708         v = second_op;
709       } else if (!second_op->GetAliveSuccEdges().empty() &&
710                  second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
711         u = second_op;
712         v = first_op;
713       } else {
714         MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
715                           << ", " << second_op->GetAliveSuccEdges().size() << ".";
716       }
717       MS_EXCEPTION_IF_NULL(u);
718       auto e = u->GetAliveSuccEdges()[0];
719       auto cost_list = one_component->CreateFinalCostList(u, e, v);
720       all_list.push_back(cost_list);
721     } else {
722       MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
723                         << " operators in a component in the final graph.";
724     }
725   }
726   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
727   auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
728   for (size_t k = 0; k < selected_cost_list.size(); ++k) {
729     auto selected_cost = selected_cost_list[k];
730     if (selected_cost == nullptr) {
731       MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
732       return FAILED;
733     }
734     MS_EXCEPTION_IF_NULL(connected_components[k]);
735     if (connected_components[k]->GetOperators().size() == 1) {
736       auto u = connected_components[k]->GetOperators()[0];
737       auto decision_f = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
738       u->SetSelectedStrategyAndCost(decision_f->u_strategy_, decision_f->u_cost_);
739       MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
740     } else if (connected_components[k]->GetOperators().size() == 2) {
741       OperatorInfoPtr u = nullptr;
742       OperatorInfoPtr v = nullptr;
743       auto first_op = connected_components[k]->GetOperators()[0];
744       auto second_op = connected_components[k]->GetOperators()[1];
745       MS_EXCEPTION_IF_NULL(first_op);
746       MS_EXCEPTION_IF_NULL(second_op);
747       if (!first_op->GetAliveSuccEdges().empty() &&
748           first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
749         u = first_op;
750         v = second_op;
751       } else if (!second_op->GetAliveSuccEdges().empty() &&
752                  second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
753         u = second_op;
754         v = first_op;
755       }
756       MS_EXCEPTION_IF_NULL(u);
757       auto e = u->GetAliveSuccEdges()[0];
758       MS_EXCEPTION_IF_NULL(v);
759       MS_EXCEPTION_IF_NULL(e);
760       MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
761       auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
762       MS_EXCEPTION_IF_NULL(decision);
763       u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
764       v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
765       e->set_selected_cost(decision->middle_cost_);
766       MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
767     }
768   }
769   return SUCCESS;
770 }
771 
SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)772 Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
773   // In this case, the final graph should contains exactly 2 nodes.
774   if (alive_ops.empty()) {
775     MS_LOG(INFO) << "0 Operator in the final graph.";
776     return SUCCESS;
777   }
778   OperatorInfoPtr u, v;
779   MS_EXCEPTION_IF_NULL(alive_ops[0]);
780   MS_EXCEPTION_IF_NULL(alive_ops[1]);
781   const auto phase = CostModelContext::GetInstance()->run_phase();
782   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
783   if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
784       alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
785     u = alive_ops[0];
786     v = alive_ops[1];
787   } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
788              alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
789     u = alive_ops[1];
790     v = alive_ops[0];
791   } else {
792     if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
793       MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
794                         << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
795     } else {
796       // In this case, the final graph consists of two single nodes
797       MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
798       std::vector<CostPtrList> all_list;
799       auto connected_components = ConstructConnectedComponents(alive_ops);
800       MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
801       for (size_t i = 0; i < connected_components.size(); ++i) {
802         MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
803         auto one_component = connected_components[i];
804         MS_EXCEPTION_IF_NULL(one_component);
805         auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
806         all_list.push_back(cost_list);
807       }
808       CostPtrList selected_cost_list;
809       if (phase == TRAINING_PHASE) {
810         // training phase
811         selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
812       } else {
813         // inference phase
814         MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
815                              "phase is not supported.";
816       }
817       for (size_t k = 0; k < selected_cost_list.size(); ++k) {
818         auto selected_cost = selected_cost_list[k];
819         if (selected_cost == nullptr) {
820           MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
821                         << ".";
822           return FAILED;
823         }
824         MS_EXCEPTION_IF_NULL(connected_components[k]);
825         auto one_operator = connected_components[k]->GetOperators()[0];
826         MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
827         auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
828         MS_EXCEPTION_IF_NULL(decision);
829         one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
830         MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
831       }
832 
833       return SUCCESS;
834     }
835   }
836   MS_LOG(INFO) << "There are 2 nodes in the final graph.";
837   // In this case, the finale graph is exactly of the form: u --> v
838   MS_EXCEPTION_IF_NULL(u);
839   MS_EXCEPTION_IF_NULL(v);
840   auto e = u->GetAliveSuccEdges()[0];
841   MS_EXCEPTION_IF_NULL(e);
842   auto f_cost_list = CreateFinalCostList(u, e, v);
843   CostPtr cost = nullptr;
844   if (phase == TRAINING_PHASE) {
845     // training phase
846     cost = SelectCostWithMinTrainingTime(f_cost_list, device_mem_capacity);
847   } else {
848     MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
849                          "phase is not supported.";
850   }
851   if (cost == nullptr) {
852     MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
853     return FAILED;
854   }
855   MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
856   auto f_decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
857   MS_EXCEPTION_IF_NULL(f_decision);
858   u->SetSelectedStrategyAndCost(f_decision->u_strategy_, f_decision->left_cost_);
859   v->SetSelectedStrategyAndCost(f_decision->v_strategy_, f_decision->right_cost_);
860   e->set_selected_cost(f_decision->middle_cost_);
861   MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
862   return SUCCESS;
863 }
864 
865 // searching the strategy for the final eliminated graph
SearchStrategy()866 Status CostGraph::SearchStrategy() {
867   MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
868   std::vector<OperatorInfoPtr> alive_ops;
869   (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
870     MS_EXCEPTION_IF_NULL(op);
871     if (op->is_alive()) {
872       alive_ops.push_back(op);
873     }
874   });
875   const auto phase = CostModelContext::GetInstance()->run_phase();
876   const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
877 
878   if (alive_ops.size() > 2) {
879     if (phase == TRAINING_PHASE) {
880       // training phase
881       return SearchStrategyForMultiNodeFinalGraph(alive_ops);
882     } else {
883       // inference phase
884       MS_LOG(EXCEPTION)
885         << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
886     }
887   } else if (alive_ops.size() == 1) {
888     MS_LOG(INFO) << "There are 1 single node in the final graph.";
889     OperatorInfoPtr u = alive_ops[0];
890     auto cost_list = CreateFinalSingleCostList(u);
891     CostPtr cost = nullptr;
892     if (phase == TRAINING_PHASE) {
893       // training phase
894       cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
895     } else {
896       // inference phase
897       cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
898     }
899     if (cost == nullptr) {
900       MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
901       return FAILED;
902     }
903     MS_EXCEPTION_IF_NULL(u);
904     MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
905     auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
906     MS_EXCEPTION_IF_NULL(decision);
907     u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
908     MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
909     return SUCCESS;
910   } else {
911     return SearchStrategyForTwoNodeFinalGraph(alive_ops);
912   }
913 }
914 
915 // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
916 // return the v and the edge u --> v
CheckOpElimination() const917 OperatorInfoPtr CostGraph::CheckOpElimination() const {
918   for (auto &op : ops_) {
919     bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
920     if (bool_test) {
921       if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
922         return op;
923       }
924     }
925   }
926   return nullptr;
927 }
928 
929 // Check the graph whether an EdgeElimination can be performed
CheckEdgeElimination() const930 std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
931   for (auto &op : ops_) {
932     MS_EXCEPTION_IF_NULL(op);
933     if (!op->is_alive()) {
934       continue;
935     }
936     std::map<void *, int64_t> count;
937     for (auto &edge_su : op->GetAliveSuccEdges()) {
938       MS_EXCEPTION_IF_NULL(edge_su);
939       auto v = edge_su->next_operator();
940       count[v.get()]++;
941     }
942     for (auto &pair : count) {
943       auto *op_ptr = pair.first;
944       int64_t op_count = pair.second;
945       if (op_count > 1) {
946         std::vector<std::shared_ptr<Edge>> ret;
947         for (auto &edge : op->GetAliveSuccEdges()) {
948           MS_EXCEPTION_IF_NULL(edge);
949           if (edge->next_operator().get() == op_ptr) {
950             ret.push_back(edge);
951           }
952         }
953         return ret;
954       }
955     }
956   }
957   return {};
958 }
959 
960 // Check the graph whether a MergeElimination can be performed
CheckMergeElimination() const961 OperatorInfoPtr CostGraph::CheckMergeElimination() const {
962   for (auto &op : ops_) {
963     MS_EXCEPTION_IF_NULL(op);
964     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
965     if (bool_test) {
966       auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
967       MS_EXCEPTION_IF_NULL(next_op);
968       if (!next_op->GetAlivePrevEdges().empty()) {
969         return op;
970       }
971     }
972   }
973   return nullptr;
974 }
975 
976 // Check the graph whether a ContractElimination can be performed
CheckContractElimination() const977 OperatorInfoPtr CostGraph::CheckContractElimination() const {
978   for (auto &op : ops_) {
979     MS_EXCEPTION_IF_NULL(op);
980     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
981     if (bool_test) {
982       auto edge = op->GetAlivePrevEdges()[0];
983       MS_EXCEPTION_IF_NULL(edge);
984       auto prev_op = edge->prev_operator();
985       MS_EXCEPTION_IF_NULL(prev_op);
986       if (!prev_op->GetAliveSuccEdges().empty()) {
987         return op;
988       }
989     }
990   }
991   return nullptr;
992 }
993 
CheckSourceElimination() const994 std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
995   size_t source_count = 0;
996   std::vector<OperatorInfoPtr> op_vector(2, nullptr);
997   for (auto &op : ops_) {
998     MS_EXCEPTION_IF_NULL(op);
999     bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
1000     if (bool_test) {
1001       op_vector[source_count++] = op;
1002       if (source_count == 2) {
1003         return std::make_pair(op_vector[0], op_vector[1]);
1004       }
1005     }
1006   }
1007   return std::make_pair(nullptr, nullptr);
1008 }
1009 
CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra,const CostPtrList & op1_old_clist,StrategyPtr op2_old_stra,const CostPtrList & op2_old_clist,CostPtrList * op1_new_clist) const1010 void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
1011                                                    StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
1012                                                    CostPtrList *op1_new_clist) const {
1013   for (auto &op1_cost : op1_old_clist) {
1014     for (auto &op2_cost : op2_old_clist) {
1015       double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
1016       double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
1017       double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
1018       double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
1019       double communication_without_para =
1020         op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
1021       auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
1022       auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1023       const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1024       MS_EXCEPTION_IF_NULL(new_cost);
1025       new_cost->communication_without_parameter_ = communication_without_para;
1026       new_cost->communication_with_partial_para_ =
1027         communication_without_para + gamma * (communication - communication_without_para);
1028       new_cost->memory_with_reuse_ = memory;
1029       new_cost->communication_forward_ = communication_forward;
1030       MS_EXCEPTION_IF_NULL(op1_new_clist);
1031       op1_new_clist->push_back(std::move(new_cost));
1032     }
1033   }
1034 }
1035 
UpdateEdgesIncidentToNodes(OperatorInfoPtr op1,std::vector<EdgePtr> * op1_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op1_new_edges_cost,std::vector<EdgePtr> * op1_new_succ_edges,const OperatorInfoPtr op2,std::vector<EdgePtr> * op2_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op2_new_edges_cost,std::vector<EdgePtr> * op2_new_succ_edges)1036 std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
1037   OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
1038   std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
1039   const OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
1040   std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
1041   for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
1042     auto &new_cost_map = op1_new_edges_cost->at(i);
1043     auto ith_edge = op1_old_succ_edges->at(i);
1044 
1045     std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
1046     std::shared_ptr<Edge> new_edge;
1047     if (ith_edge->is_combined()) {
1048       std::vector<size_t> output_indexs, input_indexs;
1049       output_indexs = ith_edge->prev_op_output_indexs();
1050       input_indexs = ith_edge->next_op_input_indexs();
1051       new_edge =
1052         std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
1053     } else {
1054       size_t output_index;
1055       size_t input_index;
1056       output_index = ith_edge->prev_op_output_index();
1057       input_index = ith_edge->next_op_input_index();
1058       new_edge =
1059         std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
1060     }
1061     new_edge->SetCostMapAndInputOutput(new_cost_map);
1062     // replace the old successive edges with the new ones.
1063     op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
1064     ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
1065     (void)op1_new_succ_edges->erase(op1_new_succ_edges->cbegin() + SizeToLong(i));
1066     (void)op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + SizeToLong(i), new_edge);
1067   }
1068   for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
1069     auto &new_cost_map = op2_new_edges_cost->at(i);
1070     auto ith_edge = op2_old_succ_edges->at(i);
1071     const auto &destination = ith_edge->next_operator();
1072 
1073     std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
1074     std::shared_ptr<Edge> new_edge;
1075     if (ith_edge->is_combined()) {
1076       std::vector<size_t> output_indexs, input_indexs;
1077       output_indexs = ith_edge->prev_op_output_indexs();
1078       input_indexs = ith_edge->next_op_input_indexs();
1079       new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
1080     } else {
1081       size_t output_index;
1082       size_t input_index;
1083       output_index = ith_edge->prev_op_output_index();
1084       input_index = ith_edge->next_op_input_index();
1085       new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
1086     }
1087     new_edge->SetCostMapAndInputOutput(new_cost_map);
1088     // replace the old successive edges with the new ones.
1089     destination->ReplacePreEdge(op2, new_edge);
1090     op1->AddSuccEdge(new_edge);
1091     (void)op2_new_succ_edges->erase(op2_new_succ_edges->cbegin() + SizeToLong(i));
1092     (void)op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + SizeToLong(i), new_edge);
1093   }
1094   return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
1095 }
1096 
EliminationSources(const OperatorInfoPtr op1,const OperatorInfoPtr op2) const1097 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
1098   const OperatorInfoPtr op1, const OperatorInfoPtr op2) const {
1099   MS_EXCEPTION_IF_NULL(op1);
1100   MS_EXCEPTION_IF_NULL(op2);
1101   MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
1102 
1103   auto op1_old_succ_edges = op1->GetAliveSuccEdges();
1104   std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
1105     op1_old_succ_edges.size());
1106   std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
1107   std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
1108 
1109   auto op2_old_succ_edges = op2->GetAliveSuccEdges();
1110   std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
1111     op2_old_succ_edges.size());
1112   std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
1113   std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
1114 
1115   // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
1116   for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
1117     const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
1118     std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
1119     for (const auto &key_value : op1_cost_map) {
1120       const auto &from_to_strategies = key_value.first;
1121       const auto &costlist = key_value.second;
1122       from_tocost[from_to_strategies.first].push_back(std::make_pair(from_to_strategies.second, costlist));
1123     }
1124     op1_edges_reorganised_cost[i] = from_tocost;
1125   }
1126 
1127   for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
1128     const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
1129     std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
1130     for (const auto &key_value : op2_cost_map) {
1131       const auto &from_to_strategies = key_value.first;
1132       const auto &costlist = key_value.second;
1133       from_tocost[from_to_strategies.first].push_back(std::make_pair(from_to_strategies.second, costlist));
1134     }
1135     op2_edges_reorganised_cost[i] = from_tocost;
1136   }
1137 
1138   // Merge op2 into op1
1139   const auto &op1_old_stra_cost = op1->GetStrategyCost();
1140   const auto &op2_old_stra_cost = op2->GetStrategyCost();
1141   std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
1142 
1143   for (auto &op1_stra_cost : op1_old_stra_cost) {
1144     auto op1_old_stra = op1_stra_cost->strategy_ptr;
1145     auto op1_old_costlist = op1_stra_cost->cost_list;
1146 
1147     for (auto &op2_stra_cost : op2_old_stra_cost) {
1148       auto op2_stra = op2_stra_cost->strategy_ptr;
1149       auto op2_costlist = op2_stra_cost->cost_list;
1150 
1151       StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
1152       op1_new_stra->CoverStrategy(op2_stra);
1153       CostPtrList op1_new_costlist;
1154       // Calculate new cost for 'op1_new_costlist'
1155       CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
1156       std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
1157       op1_new_stra_cost.push_back(swc);
1158 
1159       // Set cost for new successive edges of op1 and op2
1160       for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
1161         auto &from_tocost = op1_edges_reorganised_cost[i];
1162         auto &to_cost = from_tocost[op1_old_stra];
1163         auto &new_cost_map = op1_new_edges_cost[i];
1164         for (auto &stra_costlit : to_cost) {
1165           auto &to_strategy = stra_costlit.first;
1166           auto &edge_costlist = stra_costlit.second;
1167           CostPtrKey new_key = {op1_new_stra, to_strategy};
1168           new_cost_map[new_key] = edge_costlist;
1169         }
1170       }
1171       for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
1172         auto &from_tocost = op2_edges_reorganised_cost[i];
1173         auto &to_cost = from_tocost[op2_stra];
1174         auto &new_cost_map = op2_new_edges_cost[i];
1175         for (auto &stra_costlist : to_cost) {
1176           auto &to_strategy = stra_costlist.first;
1177           auto &edge_costlist = stra_costlist.second;
1178           CostPtrKey new_key = {op1_new_stra, to_strategy};
1179           new_cost_map[new_key] = edge_costlist;
1180         }
1181       }
1182     }
1183   }
1184   op1->SetStrategyCost(op1_new_stra_cost);
1185   op2->SetNotAlive();
1186 
1187   // Update the edges incident to op1, and edges incident to op2
1188   MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
1189   return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
1190                                     &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
1191 }
1192 
1193 // Check the graph whether a TriangleElimination can be performed
CheckTriangleElimination() const1194 std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
1195   for (auto &op : ops_) {
1196     MS_EXCEPTION_IF_NULL(op);
1197     bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
1198     if (bool_test) {
1199       auto edge1 = op->GetAliveSuccEdges()[0];
1200       auto edge2 = op->GetAliveSuccEdges()[1];
1201       MS_EXCEPTION_IF_NULL(edge1);
1202       MS_EXCEPTION_IF_NULL(edge2);
1203       auto first_op = edge1->next_operator();
1204       auto second_op = edge2->next_operator();
1205       MS_EXCEPTION_IF_NULL(first_op);
1206       for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
1207         if (first_op_succ_edge->next_operator() == second_op) {
1208           return {op, first_op_succ_edge};
1209         }
1210       }
1211       MS_EXCEPTION_IF_NULL(second_op);
1212       for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
1213         if (second_op_succ_edge->next_operator() == first_op) {
1214           return {op, second_op_succ_edge};
1215         }
1216       }
1217     }
1218   }
1219   return {nullptr, nullptr};
1220 }
1221 
1222 // Check the graph whether a StarElimination can be performed.
1223 // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
CheckStarElimination() const1224 OperatorInfoPtr CostGraph::CheckStarElimination() const {
1225   for (auto &op : ops_) {
1226     MS_EXCEPTION_IF_NULL(op);
1227     bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
1228     if (bool_test) {
1229       return op;
1230     }
1231   }
1232   return nullptr;
1233 }
1234 
1235 // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
1236 // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
EliminationOp(const OperatorInfoPtr & op) const1237 std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) const {
1238   // in this case, the operators are organised in the form of u-->op-->v, and the goal
1239   // is to eliminate 'op'.
1240   MS_EXCEPTION_IF_NULL(op);
1241   MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
1242   auto edge_u_op = op->GetAlivePrevEdges()[0];
1243   auto edge_op_v = op->GetAliveSuccEdges()[0];
1244   MS_EXCEPTION_IF_NULL(edge_u_op);
1245   MS_EXCEPTION_IF_NULL(edge_op_v);
1246   auto u = edge_u_op->prev_operator();
1247   auto v = edge_op_v->next_operator();
1248   std::vector<size_t> output_indexs;
1249   std::vector<size_t> input_indexs;
1250   size_t output_index, input_index;
1251   MS_EXCEPTION_IF_NULL(u);
1252   MS_EXCEPTION_IF_NULL(v);
1253   std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1254   std::shared_ptr<Edge> new_edge;
1255   if (edge_u_op->is_combined()) {
1256     output_indexs = edge_u_op->prev_op_output_indexs();
1257   } else {
1258     output_index = edge_u_op->prev_op_output_index();
1259     output_indexs.push_back(output_index);
1260   }
1261   if (edge_op_v->is_combined()) {
1262     input_indexs = edge_op_v->next_op_input_indexs();
1263   } else {
1264     input_index = edge_op_v->next_op_input_index();
1265     input_indexs.push_back(input_index);
1266   }
1267 
1268   if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
1269     new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
1270   } else {
1271     new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1272   }
1273   MS_EXCEPTION_IF_NULL(new_edge);
1274   new_edge->set_pre_op_output(edge_u_op->prev_op_output());
1275   new_edge->set_next_op_input(edge_op_v->next_op_input());
1276   new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
1277   u->ReplaceSuccEdge(op, new_edge);
1278   v->ReplacePreEdge(op, new_edge);
1279   op->SetNotAlive();
1280   MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
1281   return new_edge;
1282 }
1283 
1284 // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
1285 // and sets new costlist for the new edge.
EliminationEdges(const std::vector<std::shared_ptr<Edge>> & edges) const1286 std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) const {
1287   MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
1288   MS_EXCEPTION_IF_NULL(edges[0]);
1289   auto u = edges[0]->prev_operator();
1290   auto v = edges[0]->next_operator();
1291   MS_EXCEPTION_IF_NULL(u);
1292   MS_EXCEPTION_IF_NULL(v);
1293   std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1294   std::vector<size_t> output_indexs, input_indexs;
1295 
1296   for (auto &edge : edges) {
1297     MS_EXCEPTION_IF_NULL(edge);
1298     if (edge->is_combined()) {
1299       auto from_output_indexs = edge->prev_op_output_indexs();
1300       auto from_input_indexs = edge->next_op_input_indexs();
1301       (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
1302       (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
1303     } else {
1304       output_indexs.push_back(edge->prev_op_output_index());
1305       input_indexs.push_back(edge->next_op_input_index());
1306     }
1307   }
1308 
1309   std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1310   MS_EXCEPTION_IF_NULL(new_edge);
1311   new_edge->set_pre_op_output(edges[0]->prev_op_output());
1312   new_edge->set_next_op_input(edges[0]->next_op_input());
1313 
1314   new_edge->EdgeEliminationSetNewCost(u, edges, v);
1315 
1316   u->ReplaceSuccEdges(v, new_edge);
1317   v->ReplacePreEdges(u, new_edge);
1318   MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
1319   return new_edge;
1320 }
1321 
1322 // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1323 // for this contract under the strategy 'op_strategy'
CreateMergeEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr tar_op_strategy,const CostPtrList & tar_cost_list,CostPtrList * const tar_cost_list_new) const1324 void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
1325                                                   const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
1326                                                   const CostPtrList &tar_cost_list,
1327                                                   CostPtrList *const tar_cost_list_new) const {
1328   for (size_t i = 0; i < op_cost_list.size(); ++i) {
1329     auto &op_cost = op_cost_list[i];
1330     MS_EXCEPTION_IF_NULL(op_cost);
1331     for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1332       auto &edge_cost = edge_cost_list[j];
1333       MS_EXCEPTION_IF_NULL(edge_cost);
1334       for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1335         auto &tar_cost = tar_cost_list[k];
1336         MS_EXCEPTION_IF_NULL(tar_cost);
1337         double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1338         double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1339         double communication =
1340           op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1341         double communication_forward =
1342           op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
1343         double communication_without_para = op_cost->communication_without_parameter_ +
1344                                             edge_cost->communication_without_parameter_ +
1345                                             tar_cost->communication_without_parameter_;
1346 
1347         auto decision =
1348           std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
1349         auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1350         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1351         MS_EXCEPTION_IF_NULL(new_cost);
1352         new_cost->communication_without_parameter_ = communication_without_para;
1353         new_cost->communication_with_partial_para_ =
1354           communication_without_para + gamma * (communication - communication_without_para);
1355         new_cost->memory_with_reuse_ = memory;
1356         new_cost->communication_forward_ = communication_forward;
1357         MS_EXCEPTION_IF_NULL(tar_cost_list_new);
1358         tar_cost_list_new->push_back(std::move(new_cost));
1359       }
1360     }
1361   }
1362 }
1363 
1364 // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
1365 // target_op
EliminationMerge(const OperatorInfoPtr & op) const1366 OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) const {
1367   MS_EXCEPTION_IF_NULL(op);
1368   auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
1369   auto edge_ptr = op->GetAliveSuccEdges()[0];
1370   MS_EXCEPTION_IF_NULL(target_op);
1371   MS_EXCEPTION_IF_NULL(edge_ptr);
1372   MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
1373   bool valid = false;
1374 
1375   for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1376     MS_EXCEPTION_IF_NULL(tar_stra_cost);
1377     auto tar_stra = tar_stra_cost->strategy_ptr;
1378     auto tar_clist_origin = tar_stra_cost->cost_list;
1379     CostPtrList tar_clist_new;
1380 
1381     for (auto &op_stra_cost : op->GetStrategyCost()) {
1382       MS_EXCEPTION_IF_NULL(op_stra_cost);
1383       auto op_stra = op_stra_cost->strategy_ptr;
1384       auto op_clist = op_stra_cost->cost_list;
1385       auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
1386 
1387       CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1388     }
1389     Simplify(&tar_clist_new);
1390     // Set the new costlist w.r.t the strategy
1391     tar_stra_cost->cost_list = tar_clist_new;
1392     if ((!valid) && (!tar_clist_new.empty())) {
1393       valid = true;
1394     }
1395   }
1396 
1397   if (!valid) {
1398     MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
1399   }
1400   op->SetNotAlive();
1401   MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
1402   return target_op;
1403 }
1404 
1405 // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1406 // for this contract under the strategy 'contract_op_stra'
CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,const CostPtrList & contract_op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr target_op_stra,const CostPtrList & tar_cost_list,CostPtrList * tar_cost_list_new) const1407 void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
1408                                                      const CostPtrList &contract_op_cost_list,
1409                                                      const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
1410                                                      const CostPtrList &tar_cost_list,
1411                                                      CostPtrList *tar_cost_list_new) const {
1412   for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
1413     auto &contract_op_cost = contract_op_cost_list[i];
1414     MS_EXCEPTION_IF_NULL(contract_op_cost);
1415     for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1416       auto &edge_cost = edge_cost_list[j];
1417       MS_EXCEPTION_IF_NULL(edge_cost);
1418       for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1419         auto &tar_cost = tar_cost_list[k];
1420         MS_EXCEPTION_IF_NULL(tar_cost);
1421         double computation =
1422           contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1423         double memory =
1424           contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1425         double communication =
1426           contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1427         double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
1428                                        tar_cost->communication_forward_;
1429         double communication_without_para = contract_op_cost->communication_without_parameter_ +
1430                                             edge_cost->communication_without_parameter_ +
1431                                             tar_cost->communication_without_parameter_;
1432 
1433         auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
1434                                                                       target_op_stra, tar_cost);
1435         auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1436         auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1437         new_cost->communication_without_parameter_ = communication_without_para;
1438         new_cost->communication_with_partial_para_ =
1439           communication_without_para + gamma * (communication - communication_without_para);
1440         new_cost->memory_with_reuse_ = memory;
1441         new_cost->communication_forward_ = communication_forward;
1442         tar_cost_list_new->push_back(std::move(new_cost));
1443       }
1444     }
1445   }
1446 }
1447 
1448 // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
1449 // target_op
EliminationContract(const OperatorInfoPtr & op) const1450 OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) const {
1451   MS_EXCEPTION_IF_NULL(op);
1452   auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
1453   auto edge_ptr = op->GetAlivePrevEdges()[0];
1454   MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
1455   bool valid = false;
1456 
1457   for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1458     MS_EXCEPTION_IF_NULL(tar_stra_cost);
1459     auto tar_stra = tar_stra_cost->strategy_ptr;
1460     auto tar_clist_origin = tar_stra_cost->cost_list;
1461     CostPtrList tar_clist_new;
1462 
1463     for (auto &op_stra_cost : op->GetStrategyCost()) {
1464       MS_EXCEPTION_IF_NULL(op_stra_cost);
1465       auto op_stra = op_stra_cost->strategy_ptr;
1466       auto op_clist = op_stra_cost->cost_list;
1467       auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
1468 
1469       CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1470     }
1471     Simplify(&tar_clist_new);
1472     // Set the new costlist w.r.t the strategy
1473     tar_stra_cost->cost_list = tar_clist_new;
1474     if ((!valid) && (!tar_clist_new.empty())) {
1475       valid = true;
1476     }
1477   }
1478   if (!valid) {
1479     MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
1480   }
1481   op->SetNotAlive();
1482   MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
1483   return target_op;
1484 }
1485 
CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,StrategyPtr left_op_stra,StrategyPtr right_op_stra,const CostPtr & right_op_cost,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtr & right_edge_cost,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new) const1486 void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
1487                                                      StrategyPtr right_op_stra, const CostPtr &right_op_cost,
1488                                                      const CostPtrList &elimi_op_clist,
1489                                                      const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
1490                                                      const CostPtrList &left_node_clist_origin,
1491                                                      CostPtrList *left_node_clist_new) const {
1492   MS_EXCEPTION_IF_NULL(right_edge_cost);
1493   MS_EXCEPTION_IF_NULL(right_op_cost);
1494   MS_EXCEPTION_IF_NULL(left_node_clist_new);
1495   for (auto &elimi_op_cost : elimi_op_clist) {
1496     MS_EXCEPTION_IF_NULL(elimi_op_cost);
1497     for (auto &left_edge_cost : left_edge_clist) {
1498       MS_EXCEPTION_IF_NULL(left_edge_cost);
1499       for (auto &left_node_cost : left_node_clist_origin) {
1500         MS_EXCEPTION_IF_NULL(left_node_cost);
1501         double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
1502                                  left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
1503         double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
1504                             left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
1505         double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
1506                                 left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
1507         double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
1508                                    left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
1509         double new_commu_without =
1510           elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
1511           left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
1512         const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1513         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1514 
1515         if (triangle_star_stra_overwrite) {
1516           new_computation += right_op_cost->computation_cost_;
1517           new_memory += right_op_cost->memory_with_reuse_;
1518           new_commu_cost += right_op_cost->communication_cost_;
1519           new_commu_forward += right_op_cost->communication_forward_;
1520           new_commu_without += right_op_cost->communication_without_parameter_;
1521         }
1522 
1523         auto decision =
1524           std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
1525                                                         left_op_stra, left_node_cost, right_op_stra, right_op_cost);
1526         auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
1527         new_cost->communication_without_parameter_ = new_commu_without;
1528         new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
1529         new_cost->memory_with_reuse_ = new_memory;
1530         new_cost->communication_forward_ = new_commu_forward;
1531         left_node_clist_new->push_back(std::move(new_cost));
1532       }
1533     }
1534   }
1535 }
1536 
CreateTriangleEliminationCostList(const OperatorInfoPtr & elimi_op,const CostPtrList & right_node_clist,const CostPtrList & right_edge_clist,const StrategyPtr & elimi_op_stra,const StrategyPtr & left_node_stra,const StrategyPtr & right_node_stra,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new) const1537 void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
1538                                                   const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
1539                                                   const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
1540                                                   const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
1541                                                   const CostPtrList &left_node_clist_origin,
1542                                                   CostPtrList *left_node_clist_new) const {
1543   MS_EXCEPTION_IF_NULL(elimi_op);
1544   for (auto &right_node_cost : right_node_clist) {
1545     MS_EXCEPTION_IF_NULL(right_node_cost);
1546     for (auto &right_edge_cost : right_edge_clist) {
1547       MS_EXCEPTION_IF_NULL(right_edge_cost);
1548       CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
1549                                            elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
1550                                            left_node_clist_new);
1551     }
1552   }
1553 }
1554 
EliminationTriangle(const OperatorInfoPtr & elimi_op,const std::shared_ptr<Edge> & edge_left_right) const1555 OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
1556                                                const std::shared_ptr<Edge> &edge_left_right) const {
1557   MS_EXCEPTION_IF_NULL(edge_left_right);
1558   MS_EXCEPTION_IF_NULL(elimi_op);
1559   MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
1560   auto left_node = edge_left_right->prev_operator();
1561   auto right_node = edge_left_right->next_operator();
1562   auto left_edge = elimi_op->GetAliveSuccEdges()[0];
1563   auto right_edge = elimi_op->GetAliveSuccEdges()[1];
1564   MS_EXCEPTION_IF_NULL(left_node);
1565   MS_EXCEPTION_IF_NULL(right_node);
1566   MS_EXCEPTION_IF_NULL(left_edge);
1567   MS_EXCEPTION_IF_NULL(right_edge);
1568   MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
1569   MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";
1570 
1571   if (left_edge->next_operator() != left_node) {
1572     auto tmp = left_edge;
1573     left_edge = right_edge;
1574     right_edge = tmp;
1575   }
1576   bool valid = false;
1577 
1578   for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
1579     MS_EXCEPTION_IF_NULL(left_node_stra_cost);
1580     auto left_node_stra = left_node_stra_cost->strategy_ptr;
1581     auto left_node_clist_origin = left_node_stra_cost->cost_list;
1582     CostPtrList left_node_clist_new;
1583 
1584     for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
1585       MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
1586       auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
1587       auto elimi_op_clist = elimi_op_stra_cost->cost_list;
1588       auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);
1589 
1590       for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
1591         MS_EXCEPTION_IF_NULL(right_node_stra_cost);
1592         auto right_node_stra = right_node_stra_cost->strategy_ptr;
1593         auto right_node_clist = right_node_stra_cost->cost_list;
1594         auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);
1595 
1596         CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
1597                                           right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
1598                                           &left_node_clist_new);
1599       }
1600     }
1601     Simplify(&left_node_clist_new);
1602     // Set the new costlist w.r.t the strategy
1603     left_node_stra_cost->cost_list = left_node_clist_new;
1604     if ((!valid) && (!left_node_clist_new.empty())) {
1605       valid = true;
1606     }
1607   }
1608 
1609   if (!valid) {
1610     MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
1611                       << " failed. It may be caused by "
1612                          "configuring inconsistent strategies for operators.";
1613   }
1614   elimi_op->SetNotAlive();
1615   MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
1616   return left_node;
1617 }
1618 
CreateStarEliminationSubCostList(const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,std::vector<StrategyPtr> succ_nodes_stras,CostPtrList & succ_edges_costs,CostPtrList & succ_nodes_costs,CostPtrList * first_succ_node_clist_new) const1619 void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
1620                                                  const CostPtrList &first_succ_node_clist,
1621                                                  const CostPtrList &first_succ_edge_clist,
1622                                                  const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1623                                                  std::vector<StrategyPtr> succ_nodes_stras,
1624                                                  CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
1625                                                  CostPtrList *first_succ_node_clist_new) const {
1626   for (auto &first_succ_node_cost : first_succ_node_clist) {
1627     for (auto &first_succ_edge_cost : first_succ_edge_clist) {
1628       for (auto &merged_node_cost : merged_op_clist) {
1629         MS_EXCEPTION_IF_NULL(merged_node_cost);
1630         succ_nodes_stras[0] = first_succ_node_stra;
1631         succ_edges_costs[0] = first_succ_edge_cost;
1632         succ_nodes_costs[0] = first_succ_node_cost;
1633 
1634         double computation_cost = merged_node_cost->computation_cost_,
1635                memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
1636                commu_without = merged_node_cost->communication_without_parameter_,
1637                commu_forward = merged_node_cost->communication_forward_;
1638         const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1639         const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1640         for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
1641           MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
1642           if (i == 0) {
1643             computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
1644             memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
1645             commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
1646             commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
1647             commu_without += succ_edges_costs[i]->communication_without_parameter_ +
1648                              succ_nodes_costs[i]->communication_without_parameter_;
1649           } else {
1650             computation_cost += succ_edges_costs[i]->computation_cost_;
1651             memory_cost += succ_edges_costs[i]->memory_with_reuse_;
1652             commu_cost += succ_edges_costs[i]->communication_cost_;
1653             commu_forward += succ_edges_costs[i]->communication_forward_;
1654             commu_without += succ_edges_costs[i]->communication_without_parameter_;
1655             if (triangle_star_stra_overwrite) {
1656               computation_cost += succ_nodes_costs[i]->computation_cost_;
1657               memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
1658               commu_cost += succ_nodes_costs[i]->communication_cost_;
1659               commu_forward += succ_nodes_costs[i]->communication_forward_;
1660               commu_without += succ_nodes_costs[i]->communication_without_parameter_;
1661             }
1662           }
1663         }
1664 
1665         auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
1666                                                                   succ_nodes_stras, succ_nodes_costs);
1667         auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
1668         new_cost->communication_without_parameter_ = commu_without;
1669         new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
1670         new_cost->memory_with_reuse_ = memory_cost;
1671         new_cost->communication_forward_ = commu_forward;
1672         first_succ_node_clist_new->push_back(std::move(new_cost));
1673       }
1674     }
1675   }
1676 }
1677 
CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> & succ_edges,const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,CostPtrList * first_succ_node_clist_new) const1678 void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
1679                                               const StrategyPtr &first_succ_node_stra,
1680                                               const CostPtrList &first_succ_node_clist,
1681                                               const CostPtrList &first_succ_edge_clist,
1682                                               const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1683                                               CostPtrList *first_succ_node_clist_new) const {
1684   std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
1685   CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
1686   std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
1687                                            &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
1688                                            &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
1689                                            this](size_t k) {
1690     if (k == succ_edges.size()) {
1691       CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1692                                        merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
1693                                        succ_nodes_costs, first_succ_node_clist_new);
1694       return;
1695     }
1696     MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
1697                   << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
1698                   << ", merged_op_clist: " << merged_op_clist.size()
1699                   << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
1700     auto succ_edge = succ_edges[k];
1701     MS_EXCEPTION_IF_NULL(succ_edge);
1702     auto succ_node = succ_edge->next_operator();
1703     MS_EXCEPTION_IF_NULL(succ_node);
1704     for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
1705       MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
1706       auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
1707       auto succ_node_clist = succ_node_stra_cost->cost_list;
1708       auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
1709 
1710       for (auto &succ_node_cost : succ_node_clist) {
1711         MS_EXCEPTION_IF_NULL(succ_node_cost);
1712         for (auto &succ_edge_cost : succ_edge_clist) {
1713           MS_EXCEPTION_IF_NULL(succ_edge_cost);
1714           succ_nodes_stras[k] = succ_node_stra;
1715           succ_edges_costs[k] = succ_edge_cost;
1716           succ_nodes_costs[k] = succ_node_cost;
1717           recursive(k + 1);
1718         }
1719       }
1720     }
1721   };
1722 
1723   recursive(1);
1724 }
1725 
EliminationStar(const OperatorInfoPtr & merged_op) const1726 std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) const {
1727   MS_EXCEPTION_IF_NULL(merged_op);
1728   auto succ_edges = merged_op->GetAliveSuccEdges();
1729   MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
1730   for (auto &succ_edge : succ_edges) {
1731     MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
1732     MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
1733   }
1734 
1735   MS_EXCEPTION_IF_NULL(succ_edges[0]);
1736   auto first_succ_node = succ_edges[0]->next_operator();
1737   auto first_succ_edge = succ_edges[0];
1738   bool valid = false;
1739 
1740   // 'merged_op' is merged into first_node
1741   MS_EXCEPTION_IF_NULL(first_succ_node);
1742   for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
1743     MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
1744     auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
1745     auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
1746     CostPtrList first_succ_node_clist_new;
1747 
1748     for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
1749       MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
1750       auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
1751       auto merged_op_clist = merged_op_stra_cost->cost_list;
1752       auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
1753 
1754       CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1755                                     merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
1756     }
1757     Simplify(&first_succ_node_clist_new);
1758     // Set the new costlist w.r.t the strategy
1759     first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
1760     if ((!valid) && (!first_succ_node_clist_new.empty())) {
1761       valid = true;
1762     }
1763   }
1764 
1765   if (!valid) {
1766     MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name()
1767                       << " failed. It may be caused by "
1768                          "configuring inconsistent strategies for operators.";
1769   }
1770 
1771   merged_op->SetNotAlive();
1772   MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
1773   return succ_edges;
1774 }
1775 
GetNumEdges() const1776 size_t CostGraph::GetNumEdges() const {
1777   size_t sum = 0;
1778   for (const auto &kv : edges_) {
1779     auto &edges = kv.second;
1780     sum += edges.size();
1781   }
1782   return sum;
1783 }
1784 
InitReshapeStrategy()1785 Status CostGraph::InitReshapeStrategy() {
1786   // reshape init should be apply after the init of it's previous node and next node.
1787   for (size_t i = 0; i < ops_.size(); ++i) {
1788     if (ops_[i]->IsReshape()) {
1789       auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
1790       auto in_edges = GetOriginalPrevEdges(ops_[i]);
1791       auto out_edges = GetOriginalNextEdges(ops_[i]);
1792       // deal with consecutive reshape op
1793       auto pre_reshape_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1794         return edge->prev_operator()->IsReshape();
1795       });
1796       auto next_reshape_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1797         return edge->next_operator()->IsReshape();
1798       });
1799       if (pre_reshape_iter != in_edges.end() || next_reshape_iter != out_edges.end()) {
1800         reshape_info->set_strategy(nullptr);
1801         continue;
1802       }
1803       auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1804         return edge->prev_operator()->name() == reshape_info->pre_operator_name();
1805       });
1806       auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1807         return edge->next_operator()->name() == reshape_info->next_operator_name();
1808       });
1809       bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
1810       if (reshape_is_first_op) {
1811         (void)reshape_info->InitSelectedStrategy(reshape_info->selected_strategy(), nullptr);
1812       }
1813       if (pre_iter != in_edges.end() || reshape_is_first_op) {
1814         MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
1815         int64_t pre_index = reshape_info->pre_operator_index();
1816         TensorInfo pre_info;
1817         std::shared_ptr<OperatorInfo> pre_op_info;
1818         if (reshape_is_first_op) {
1819           pre_op_info = reshape_info;
1820           pre_info = pre_op_info->inputs_tensor_info()[LongToSize(pre_index)];
1821         } else {
1822           pre_op_info = (*pre_iter)->prev_operator();
1823           pre_info = pre_op_info->outputs_tensor_info()[LongToSize(pre_index)];
1824         }
1825         reshape_info->SetInputLayout(pre_info.tensor_layout());
1826         if (pre_iter != in_edges.end()) {
1827           Dimensions stra = pre_info.InferStrategy();
1828           if (stra.empty()) {
1829             MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
1830           }
1831           Strategies stra_inputs = {stra};
1832           StrategyPtr reshape_stra =
1833             std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
1834           reshape_info->set_strategy(reshape_stra);
1835         }
1836       }
1837       if (next_iter != out_edges.end()) {
1838         MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
1839         int64_t next_index = reshape_info->next_operator_index();
1840         reshape_info->SetOutputLayout(
1841           (*next_iter)->next_operator()->inputs_tensor_info()[LongToSize(next_index)].tensor_layout());
1842       }
1843       if (reshape_info->Init(nullptr, nullptr) != SUCCESS) {
1844         return FAILED;
1845       }
1846     }
1847   }
1848   return SUCCESS;
1849 }
1850 
InitSelectedStrategy()1851 Status CostGraph::InitSelectedStrategy() {
1852   for (auto &op : ops_) {
1853     MS_EXCEPTION_IF_NULL(op);
1854     if (op->IsReshape()) {
1855       continue;
1856     }
1857     StrategyPtr out_strategy = nullptr;
1858     if (op->type() == MATMUL && op->out_strategy() != nullptr) {
1859       out_strategy = op->out_strategy();
1860     }
1861     auto result_op = op->InitSelectedStrategy(op->selected_strategy(), out_strategy);
1862     if (result_op != SUCCESS) {
1863       return result_op;
1864     }
1865   }
1866   auto result = InitReshapeStrategy();
1867   return result;
1868 }
1869 
ComputeOpsAndEdgesParameterInvolved()1870 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
1871   for (auto &op : ops_) {
1872     MS_EXCEPTION_IF_NULL(op);
1873     const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
1874     if ((output_parameter != 0) && (output_parameter != 1)) {
1875       MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
1876       return FAILED;
1877     }
1878   }
1879   return SUCCESS;
1880 }
1881 
DFSForTopoOrder(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,std::vector<OperatorInfoPtr> * topo_order)1882 void CostGraph::DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
1883                                 std::vector<OperatorInfoPtr> *topo_order) {
1884   MS_EXCEPTION_IF_NULL(current_op);
1885   MS_EXCEPTION_IF_NULL(visited);
1886   MS_EXCEPTION_IF_NULL(topo_order);
1887 
1888   visited->at(current_op) = true;
1889   for (const auto &s_edge : current_op->succ_edges()) {
1890     if (!visited->at(s_edge->next_operator())) {
1891       DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
1892     }
1893   }
1894   topo_order->push_back(current_op);
1895 }
1896 
1897 // Compute a topological order of the costgraph
TopologyOrder(std::vector<OperatorInfoPtr> * topo_order)1898 void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
1899   std::map<OperatorInfoPtr, bool> visited;
1900   for (auto &op : ops_) {
1901     visited[op] = false;
1902   }
1903 
1904   for (auto &op : ops_) {
1905     if (!visited[op]) {
1906       DFSForTopoOrder(op, &visited, topo_order);
1907     }
1908   }
1909 }
MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr,int64_t> & candidate_ops)1910 void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &candidate_ops) {
1911   for (auto &op : ops_) {
1912     auto search = candidate_ops.find(op);
1913     if (search != candidate_ops.end()) {
1914       // Mark the critical operators
1915       op->mark_output_critical();
1916       // Mark the successive edges
1917       for (auto &s_edge : op->succ_edges()) {
1918         s_edge->mark_output_critical();
1919       }
1920     } else {
1921       op->mark_output_not_critical();
1922     }
1923   }
1924 }
1925 
DetermineCriticalOps(const std::vector<OperatorInfoPtr> & topo_order)1926 Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
1927   if (topo_order.size() == 0) {
1928     MS_LOG(ERROR) << "0 operator in costgraph.";
1929     return FAILED;
1930   }
1931   auto &first_op = topo_order[0];
1932   if (first_op->prev_edges().size() > 0) {
1933     MS_LOG(ERROR) << "The first operator in the first of topological order of "
1934                      "costgraph should have 0 incoming edge, but has "
1935                   << first_op->prev_edges() << "edges.";
1936     return FAILED;
1937   }
1938   // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
1939   // of the output of OperatorInfo that currently has not been used
1940   std::map<OperatorInfoPtr, int64_t> curr_memory_state;
1941   (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToLong(first_op->succ_edges().size())));
1942   std::map<OperatorInfoPtr, int64_t> max_memory_state = curr_memory_state;
1943   // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
1944   // not been used
1945   double curr_memory_size = first_op->GetOutputsTotalSize();
1946   double max_memory_size = curr_memory_size;
1947 
1948   for (size_t finished = 1; finished < topo_order.size(); ++finished) {
1949     // Produce
1950     (void)curr_memory_state.emplace(
1951       std::make_pair(topo_order[finished], SizeToLong(topo_order[finished]->succ_edges().size())));
1952     curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
1953     // Consume
1954     for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1955       const auto &prev_op = prev_edge->prev_operator();
1956       curr_memory_state[prev_op]--;
1957     }
1958     for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1959       const auto &prev_op = prev_edge->prev_operator();
1960       if (curr_memory_state[prev_op] < 0) {
1961         MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
1962         return FAILED;
1963       } else if (curr_memory_state[prev_op] == 0) {
1964         (void)curr_memory_state.erase(prev_op);
1965         curr_memory_size -= prev_op->GetOutputsTotalSize();
1966       }
1967     }
1968 
1969     if (curr_memory_size < 0) {
1970       MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
1971     }
1972     // Modify the max
1973     if (curr_memory_size > max_memory_size) {
1974       max_memory_size = curr_memory_size;
1975       max_memory_state = curr_memory_state;
1976     }
1977   }
1978   // Mark those critical operators
1979   MarkCriticalOpsAndEdges(max_memory_state);
1980   return SUCCESS;
1981 }
1982 
ComputeOpsAndEdgesOutputCritical()1983 Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
1984   // Two steps to do:
1985   // 1. Compute a topological order of the costgraph
1986   // 2. Determine and mark the operators (and necessary edges) that are critical
1987   std::vector<OperatorInfoPtr> topo_order;
1988   TopologyOrder(&topo_order);
1989   std::reverse(std::begin(topo_order), std::end(topo_order));
1990 
1991   if (DetermineCriticalOps(topo_order) != SUCCESS) {
1992     MS_LOG(ERROR) << "Determining critical operators failed.";
1993     return FAILED;
1994   }
1995   return SUCCESS;
1996 }
1997 
CalculateOpsMemoryCost()1998 Status CostGraph::CalculateOpsMemoryCost() {
1999   for (auto &op : ops_) {
2000     MS_EXCEPTION_IF_NULL(op);
2001     if (op->CalculateMemoryCost() != SUCCESS) {
2002       MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
2003       return FAILED;
2004     }
2005   }
2006   return SUCCESS;
2007 }
2008 
CalculateOpsMemoryCostForInference()2009 Status CostGraph::CalculateOpsMemoryCostForInference() {
2010   for (auto &op : ops_) {
2011     MS_EXCEPTION_IF_NULL(op);
2012     if (op->CalculateMemoryCostForInference() != SUCCESS) {
2013       MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
2014       return FAILED;
2015     }
2016   }
2017   return SUCCESS;
2018 }
2019 
CalculateEdgesMemoryCost()2020 Status CostGraph::CalculateEdgesMemoryCost() {
2021   for (const auto &edge_pair : edges_) {
2022     const auto &edges = edge_pair.second;
2023     for (auto &one_edge : edges) {
2024       if (one_edge->CalculateMemoryCost() != SUCCESS) {
2025         MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
2026         return FAILED;
2027       }
2028     }
2029   }
2030   return SUCCESS;
2031 }
2032 
CalculateEdgesMemoryCostForInference()2033 Status CostGraph::CalculateEdgesMemoryCostForInference() {
2034   for (const auto &edge_pair : edges_) {
2035     const auto &edges = edge_pair.second;
2036     for (auto &one_edge : edges) {
2037       if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
2038         MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
2039         return FAILED;
2040       }
2041     }
2042   }
2043   return SUCCESS;
2044 }
2045 
FindTmpIdentityByParameterName(const std::string & p_name) const2046 OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(const std::string &p_name) const {
2047   for (auto one_op : ops_) {
2048     if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
2049       if (one_op->refkey_parameter_name() == p_name) {
2050         return one_op;
2051       }
2052     }
2053   }
2054   return nullptr;
2055 }
CorrectOpsMemoryCost()2056 Status CostGraph::CorrectOpsMemoryCost() {
2057   for (auto &one_op : ops_) {
2058     if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
2059       if (one_op->GetAliveSuccEdges().size() > 1) {
2060         // Filter out the case when the TmpIdentity being used by multiple operators
2061         std::map<size_t, int64_t> output_count;
2062         for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
2063           auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
2064           output_count[output_index]++;
2065         }
2066         for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
2067           auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
2068           if (output_count[output_index] <= 1) {
2069             continue;
2070           }
2071           auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
2072           MS_EXCEPTION_IF_NULL(next_op);
2073           auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
2074           if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
2075             MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
2076                           << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
2077             return FAILED;
2078           }
2079           output_count[output_index]--;
2080         }
2081       }
2082     }
2083   }
2084   return SUCCESS;
2085 }
2086 
CalculateMemoryCost()2087 Status CostGraph::CalculateMemoryCost() {
2088   const auto phase = CostModelContext::GetInstance()->run_phase();
2089   if (phase == TRAINING_PHASE) {
2090     // training phase
2091     if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
2092       // Calculate operators' memory usage
2093       if (CalculateOpsMemoryCost() != SUCCESS) {
2094         MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
2095         return FAILED;
2096       }
2097       // Calculate edges' memory usage
2098       if (CalculateEdgesMemoryCost() != SUCCESS) {
2099         MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
2100         return FAILED;
2101       }
2102       // Correct memory usage caused by TmpIdentity
2103       if (CorrectOpsMemoryCost() != SUCCESS) {
2104         MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
2105         return FAILED;
2106       }
2107     } else {
2108       MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
2109       return FAILED;
2110     }
2111   } else {
2112     // inference phase
2113     if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
2114       // Calculate operators' memory usage
2115       if (CalculateOpsMemoryCostForInference() != SUCCESS) {
2116         MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
2117         return FAILED;
2118       }
2119       // Calculate edges's memory usage
2120       if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
2121         MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
2122         return FAILED;
2123       }
2124     } else {
2125       MS_LOG(ERROR) << "Computing operators' critical flag failed.";
2126       return FAILED;
2127     }
2128   }
2129   return SUCCESS;
2130 }
2131 
CheckApproximateCostGraphEdges()2132 void CostGraph::CheckApproximateCostGraphEdges() {
2133   auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
2134   if (!approximation) {
2135     return;
2136   }
2137   for (const auto &s_edge : edges_) {
2138     auto &edges_vector = s_edge.second;
2139     for (auto &edge_ptr : edges_vector) {
2140       MS_EXCEPTION_IF_NULL(edge_ptr);
2141       if (edge_ptr->CheckStrategyCostPossibility()) {
2142         continue;
2143       }
2144       MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name()
2145                    << " impossible, re-initing the operators and edges";
2146       auto prev_op = edge_ptr->prev_operator();
2147       MS_EXCEPTION_IF_NULL(prev_op);
2148       auto next_op = edge_ptr->next_operator();
2149       MS_EXCEPTION_IF_NULL(next_op);
2150       // Check the 'prev_op'
2151       prev_op->ExactStrategiesAndRelatedEdges();
2152       // Check the 'next_op'
2153       next_op->ExactStrategiesAndRelatedEdges();
2154     }
2155   }
2156 }
2157 }  // namespace parallel
2158 }  // namespace mindspore
2159