1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17 #include <cstdlib>
18 #include <numeric>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 #include <queue>
23 #include <set>
24
25 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
26 #include "frontend/parallel/ops_info/reshape_info.h"
27 #include "frontend/parallel/step_auto_parallel.h"
28
29 namespace mindspore {
30 namespace parallel {
31 CostGraphPtr entire_costgraph = nullptr;
32
Init()33 void CostGraph::Init() {
34 inputs_tensor_name_list_.clear();
35 tuple_getitem_list_.clear();
36 ops_.clear();
37 edges_.clear();
38 connected_compoents_.clear();
39 out_edges_.clear();
40 in_edges_.clear();
41 }
42
RemoveOperator(const OperatorInfoPtr & op)43 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
44 for (auto it = ops_.begin(); it != ops_.end();) {
45 if ((*it) == op) {
46 it = ops_.erase(it);
47 } else {
48 ++it;
49 }
50 }
51 }
52
IsOperatorInCostGraph(const OperatorInfoPtr & op_test)53 bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
54 struct IsInGraph {
55 const OperatorInfoPtr test_;
56 explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
57 bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
58 };
59 return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
60 }
61
AddEdge(OperatorInfoPtr u_node,OperatorInfoPtr v_node,const EdgePtr & edge)62 void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
63 std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
64 curr_edges.push_back(edge);
65 edges_[{u_node, v_node}] = curr_edges;
66
67 std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
68 curr_out_edges.push_back(edge);
69 out_edges_[u_node] = curr_out_edges;
70
71 std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
72 curr_in_edges.push_back(edge);
73 in_edges_[v_node] = curr_in_edges;
74 }
75
IsEdgeInCostGraph(const std::string & test_edge_name,size_t output_index,size_t input_index)76 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
77 for (auto &edge_pair : edges_) {
78 auto edges = edge_pair.second;
79 for (auto &edge : edges) {
80 MS_EXCEPTION_IF_NULL(edge);
81 bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
82 (edge->next_op_input_index() == input_index);
83 if (bool_result) {
84 return true;
85 }
86 }
87 }
88 return false;
89 }
90
91 // '_diff_stra_params' contains all TmpIdentity operators to be processed(i.e. params), the users of
92 // these operators have different strategies, which need an additional propagation process.
93 std::set<OperatorInfoPtr> _diff_stra_params;
StrategyPropagate(const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & ops_stras)94 void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &ops_stras) {
95 if (ops_stras.empty()) {
96 MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
97 }
98 _diff_stra_params.clear();
99 std::map<OperatorInfoPtr, bool> visited;
100 for (auto &op : ops_) {
101 visited[op] = false;
102 }
103 for (auto &op_stra : ops_stras) {
104 BFS(op_stra.first, op_stra.second, ops_stras, &visited);
105 }
106
107 ProcessDiffStraParams(ops_stras);
108 _diff_stra_params.clear();
109 // GetNext as a isolate op can not be propagated
110 for (auto &op : entire_costgraph->GetOperators()) {
111 if (op->selected_strategy() == nullptr) {
112 MS_LOG(INFO) << "Set default strategy for op: " << op->name();
113 op->SetSelectedStrategy(op->strategy_cost()[0]->strategy_ptr, 0);
114 }
115 }
116 }
117
CheckVisitedEdgeConsistency(const EdgePtr & edge)118 bool CheckVisitedEdgeConsistency(const EdgePtr &edge) {
119 MS_EXCEPTION_IF_NULL(edge);
120 auto prev_op = edge->prev_operator();
121 auto next_op = edge->next_operator();
122 bool consistency = true;
123 if (prev_op->IsReshape()) {
124 const auto &reshape_output_lyt =
125 next_op->GetInputLayoutFromSWCByStrategy(next_op->selected_strategy(), edge->next_op_input_index());
126 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(prev_op);
127 consistency = reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
128 if (!consistency) {
129 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
130 }
131 } else if (next_op->IsReshape()) {
132 const auto &reshape_input_lyt =
133 prev_op->GetOutputLayoutFromSWCByStrategy(prev_op->selected_strategy(), edge->prev_op_output_index());
134 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(next_op);
135 consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
136 if (!consistency) {
137 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
138 }
139 } else {
140 consistency =
141 edge->CheckStrategyConsistency(prev_op->selected_strategy(), next_op->selected_strategy(), &_diff_stra_params);
142 if (!consistency) {
143 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
144 }
145 }
146 return consistency;
147 }
148
CheckConfiguredSuccEdgeConsistency(const EdgePtr & edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)149 bool CheckConfiguredSuccEdgeConsistency(const EdgePtr &edge,
150 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
151 MS_EXCEPTION_IF_NULL(edge);
152 auto curr_op = edge->prev_operator();
153 auto next_op = edge->next_operator();
154 auto next_op_conf_stra = configured_ops.at(next_op);
155 bool consistency = true;
156 if (curr_op->IsReshape()) {
157 const auto &reshape_output_lyt =
158 next_op->GetInputLayoutFromSWCByStrategy(next_op_conf_stra, edge->next_op_input_index());
159 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
160 consistency = reshape_ptr->CheckStrategyConsistencyByOutputLayout(reshape_ptr->swc_index(), reshape_output_lyt);
161 if (!consistency) {
162 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
163 }
164 return consistency;
165 } else {
166 consistency = edge->CheckStrategyConsistency(curr_op->selected_strategy(), next_op_conf_stra, &_diff_stra_params);
167 if (!consistency) {
168 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
169 }
170 }
171 return consistency;
172 }
173
CheckConfiguredPrevEdgeConsistency(const EdgePtr & edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)174 void CheckConfiguredPrevEdgeConsistency(const EdgePtr &edge,
175 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
176 auto curr_op = edge->next_operator();
177 auto prev_op = edge->prev_operator();
178 auto prev_op_conf_stra = configured_ops.at(prev_op);
179 if (curr_op->IsReshape()) {
180 const auto &reshape_input_lyt =
181 prev_op->GetOutputLayoutFromSWCByStrategy(prev_op_conf_stra, edge->prev_op_output_index());
182 auto reshape_ptr = std::dynamic_pointer_cast<ReshapeInfo>(curr_op);
183 auto consistency = reshape_ptr->CheckStrategyConsistencyByInputLayout(reshape_ptr->swc_index(), reshape_input_lyt);
184 if (!consistency) {
185 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
186 }
187 } else {
188 auto consistency =
189 edge->CheckStrategyConsistency(prev_op_conf_stra, curr_op->selected_strategy(), &_diff_stra_params);
190 if (!consistency) {
191 MS_LOG(WARNING) << "Inconsistency occurred at edge: " << edge->edge_name();
192 }
193 }
194 }
195
BFSPreNode(const std::shared_ptr<Edge> & edge,std::map<OperatorInfoPtr,bool> * visited,int64_t curr_depth,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,const OperatorInfoPtr & curr_op,std::queue<std::pair<std::pair<OperatorInfoPtr,std::pair<StrategyPtr,int64_t>>,int64_t>> * next_level)196 void BFSPreNode(
197 const std::shared_ptr<Edge> &edge, std::map<OperatorInfoPtr, bool> *visited, int64_t curr_depth,
198 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops, const OperatorInfoPtr &curr_op,
199 std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> *next_level) {
200 MS_EXCEPTION_IF_NULL(edge);
201 MS_EXCEPTION_IF_NULL(visited);
202 MS_EXCEPTION_IF_NULL(curr_op);
203 MS_EXCEPTION_IF_NULL(next_level);
204 const auto &prev_op = edge->prev_operator();
205 bool is_has_stra_op = visited->at(prev_op) || configured_ops.find(prev_op) != configured_ops.end();
206 if (edge->InitEdgeCost() != SUCCESS && !is_has_stra_op) {
207 MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
208 }
209 if (visited->at(prev_op)) {
210 (void)CheckVisitedEdgeConsistency(edge);
211 return;
212 }
213 if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
214 CheckConfiguredPrevEdgeConsistency(edge, configured_ops);
215 }
216 if (configured_ops.find(prev_op) != configured_ops.end()) {
217 return;
218 }
219 visited->at(prev_op) = true;
220 if (prev_op->IsReshape()) {
221 auto swc_index = edge->GetReshapeSWCIndexByNextOpStrategy(curr_op->selected_strategy());
222 (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
223 prev_op->set_swc_index(swc_index, curr_depth + 1);
224 } else if (curr_op->IsReshape()) {
225 auto prev_stra = edge->GetPrevOpStrategyByReshapeSWCIndex(curr_op->swc_index());
226 (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(prev_stra, -1)), curr_depth + 1);
227 prev_op->SetSelectedStrategy(prev_stra, LongToSize(curr_depth + 1));
228 prev_op->ClearStrategyCost();
229 if (prev_op->SetCostUnderStrategy(prev_stra) != SUCCESS) {
230 MS_LOG(EXCEPTION) << "Failure: operator " << prev_op->name() << " SetCostUnderStrategy failed";
231 }
232 } else {
233 const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
234 if (prev_op_stra == nullptr) {
235 MS_LOG(EXCEPTION) << "Current op " << curr_op->name() << "'s strategy is "
236 << curr_op->selected_strategy()->ToString() << ". " << prev_op->name()
237 << "'s strategy is null in the edge: " << edge->edge_name();
238 }
239 (void)next_level->emplace(std::make_pair(prev_op, std::make_pair(prev_op_stra, -1)), curr_depth + 1);
240 prev_op->SetSelectedStrategy(prev_op_stra, LongToSize(curr_depth + 1));
241 prev_op->ClearStrategyCost();
242 if (prev_op->SetCostUnderStrategy(prev_op_stra) != SUCCESS) {
243 MS_LOG(EXCEPTION) << "Failure: operator " << prev_op->name() << " SetCostUnderStrategy failed";
244 }
245 }
246 }
247
BFSNextNode(const std::shared_ptr<Edge> & edge,std::map<OperatorInfoPtr,bool> * visited,int64_t curr_depth,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,const OperatorInfoPtr & curr_op,std::queue<std::pair<std::pair<OperatorInfoPtr,std::pair<StrategyPtr,int64_t>>,int64_t>> * next_level)248 void BFSNextNode(
249 const std::shared_ptr<Edge> &edge, std::map<OperatorInfoPtr, bool> *visited, int64_t curr_depth,
250 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops, const OperatorInfoPtr &curr_op,
251 std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> *next_level) {
252 MS_EXCEPTION_IF_NULL(edge);
253 MS_EXCEPTION_IF_NULL(visited);
254 MS_EXCEPTION_IF_NULL(curr_op);
255 MS_EXCEPTION_IF_NULL(next_level);
256 const auto &next_op = edge->next_operator();
257 bool is_has_stra_op = visited->at(next_op) || configured_ops.find(next_op) != configured_ops.end();
258 if (edge->InitEdgeCost() != SUCCESS && !is_has_stra_op) {
259 MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
260 }
261 if (visited->at(next_op)) {
262 (void)CheckVisitedEdgeConsistency(edge);
263 return;
264 }
265 if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
266 (void)CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
267 }
268 if (configured_ops.find(next_op) != configured_ops.end()) {
269 return;
270 }
271 visited->at(next_op) = true;
272 if (curr_op->IsReshape()) {
273 auto stra = edge->GetNextOpStrategyByReshapeSWCIndex(curr_op->swc_index());
274 (void)next_level->emplace(std::make_pair(next_op, std::make_pair(stra, -1)), curr_depth + 1);
275 next_op->SetSelectedStrategy(stra, LongToSize(curr_depth + 1));
276 next_op->ClearStrategyCost();
277 if (next_op->SetCostUnderStrategy(stra) != SUCCESS) {
278 MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
279 }
280 } else if (next_op->IsReshape()) {
281 auto swc_index = edge->GetReshapeSWCIndexByPrevOpStrategy(curr_op->selected_strategy());
282 (void)next_level->emplace(std::make_pair(next_op, std::make_pair(nullptr, swc_index)), curr_depth + 1);
283 next_op->set_swc_index(swc_index, curr_depth + 1);
284 } else {
285 const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
286 if (next_op_stra == nullptr) {
287 MS_LOG(INFO) << "The strategy is: " << curr_op->selected_strategy()->ToString();
288 MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
289 }
290 (void)next_level->emplace(std::make_pair(next_op, std::make_pair(next_op_stra, -1)), curr_depth + 1);
291 next_op->SetSelectedStrategy(next_op_stra, LongToSize(curr_depth + 1));
292 next_op->ClearStrategyCost();
293 if (next_op->SetCostUnderStrategy(next_op_stra) != SUCCESS) {
294 MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
295 }
296 }
297 }
298
BFS(const OperatorInfoPtr & op,const StrategyPtr & op_stra,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops,std::map<OperatorInfoPtr,bool> * visited) const299 void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
300 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops,
301 std::map<OperatorInfoPtr, bool> *visited) const {
302 std::queue<std::pair<std::pair<OperatorInfoPtr, std::pair<StrategyPtr, int64_t>>, int64_t>> next_level;
303 (void)next_level.emplace(std::make_pair(op, std::make_pair(op_stra, -1)), 0);
304 op->SetSelectedStrategy(op_stra, 0);
305 while (!next_level.empty()) {
306 auto curr_op = next_level.front().first.first;
307 auto configured_stra = next_level.front().first.second.first;
308 auto curr_depth = next_level.front().second;
309 visited->at(curr_op) = true;
310 MS_LOG(INFO) << "curr_depth: " << curr_depth << " curr_op: " << curr_op->name();
311 for (auto &edge : curr_op->succ_edges()) {
312 BFSNextNode(edge, visited, curr_depth, configured_ops, curr_op, &next_level);
313 }
314 for (auto &edge : curr_op->prev_edges()) {
315 BFSPreNode(edge, visited, curr_depth, configured_ops, curr_op, &next_level);
316 }
317 next_level.pop();
318 }
319 }
320
321 // An additional propagation to deal with incontinuous strategy between the ops that share params.
ProcessDiffStraParams(const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops)322 void CostGraph::ProcessDiffStraParams(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) {
323 for (auto &curr_identity_op : _diff_stra_params) {
324 auto succ_edges = curr_identity_op->succ_edges();
325 auto is_target = [&](const std::shared_ptr<Edge> &edge) {
326 return configured_ops.find(edge->next_operator()) != configured_ops.end();
327 };
328 auto it = std::find_if(succ_edges.begin(), succ_edges.end(), is_target);
329 // if there are configured ops in the users, update the strategy of identity_op
330 if (it != succ_edges.end()) {
331 auto consistency = CheckVisitedEdgeConsistency(*it);
332 if (!consistency) {
333 curr_identity_op->ClearStrategyCost();
334 (void)curr_identity_op->GenerateStrategies(0);
335 if ((*it)->InitEdgeCost() != SUCCESS) {
336 MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
337 }
338 const auto curr_op = (*it)->next_operator();
339 const auto &prev_op_stra = (*it)->GetPrevOpStrategyByNextOpStrategyWithMiniComm(curr_op->selected_strategy());
340 if (prev_op_stra == nullptr) {
341 MS_LOG(EXCEPTION) << curr_identity_op->name() << "'s strategy is null in the edge: " << (*it)->edge_name();
342 }
343 curr_identity_op->SetSelectedStrategy(prev_op_stra, 1);
344 curr_identity_op->ClearStrategyCost();
345 if (curr_identity_op->SetCostUnderStrategy(prev_op_stra) != SUCCESS) {
346 MS_LOG(EXCEPTION) << "Failure: operator " << curr_identity_op->name() << " SetCostUnderStrategy failed";
347 }
348 }
349 }
350 for (auto &edge : succ_edges) {
351 const auto &next_op = edge->next_operator();
352 ParamPropagation(curr_identity_op, edge, configured_ops);
353 if (next_op->name().find(CAST) != std::string::npos) {
354 for (auto &cast_edge : next_op->succ_edges()) {
355 ParamPropagation(next_op, cast_edge, configured_ops);
356 }
357 }
358 }
359 }
360 }
361
ParamPropagation(const OperatorInfoPtr & curr_op,const std::shared_ptr<Edge> edge,const std::map<OperatorInfoPtr,StrategyPtr,OpsPtrCompare> & configured_ops) const362 void CostGraph::ParamPropagation(const OperatorInfoPtr &curr_op, const std::shared_ptr<Edge> edge,
363 const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) const {
364 const auto &next_op = edge->next_operator();
365 MS_LOG(INFO) << "params propagation at " << curr_op->name() << "->" << next_op->name();
366 if ((configured_ops.find(next_op) != configured_ops.end())) {
367 auto consistency = CheckConfiguredSuccEdgeConsistency(edge, configured_ops);
368 if (!consistency) {
369 MS_LOG(EXCEPTION) << "Incorrect strategy configuration.";
370 }
371 return;
372 }
373 auto consistency = CheckVisitedEdgeConsistency(edge);
374 if (consistency) {
375 MS_LOG(INFO) << "Jump the consistency edge.";
376 return;
377 }
378 next_op->ClearStrategyCost();
379 (void)next_op->GenerateStrategies(0);
380 if (edge->InitEdgeCost() != SUCCESS) {
381 MS_LOG(EXCEPTION) << "Edge cost initialization failed.";
382 }
383 const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithMiniComm(curr_op->selected_strategy());
384 if (next_op_stra == nullptr) {
385 MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
386 }
387 next_op->SetSelectedStrategy(next_op_stra, 1);
388 next_op->ClearStrategyCost();
389 if (next_op->SetCostUnderStrategy(next_op_stra) != SUCCESS) {
390 MS_LOG(EXCEPTION) << "Failure: operator " << next_op->name() << " SetCostUnderStrategy failed";
391 }
392 }
393
ConstructConnectedComponents(std::vector<OperatorInfoPtr> alive_ops)394 std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
395 std::vector<OperatorInfoPtr> alive_ops) {
396 std::map<OperatorInfoPtr, bool> visited;
397
398 for (auto &op : alive_ops) {
399 visited[op] = false;
400 }
401
402 MS_LOG(INFO) << "visited: " << visited.size() << ".";
403 for (auto &op : alive_ops) {
404 if ((!visited[op]) && op->is_alive()) {
405 std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
406 MS_EXCEPTION_IF_NULL(new_component);
407 DFS(op, &visited, new_component);
408 connected_compoents_.push_back(new_component);
409 }
410 }
411 return connected_compoents_;
412 }
413
DFS(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,const std::shared_ptr<CostGraph> & component)414 void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
415 const std::shared_ptr<CostGraph> &component) {
416 MS_EXCEPTION_IF_NULL(visited);
417 MS_EXCEPTION_IF_NULL(component);
418 visited->at(current_op) = true;
419 component->AddOperator(current_op);
420
421 for (auto &edge : current_op->succ_edges()) {
422 bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
423 (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
424 if (bool_test) {
425 component->AddEdge(current_op, edge->next_operator(), edge);
426 DFS(edge->next_operator(), visited, component);
427 }
428 }
429
430 for (auto &edge : current_op->prev_edges()) {
431 bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
432 (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
433 if (bool_test) {
434 component->AddEdge(edge->prev_operator(), current_op, edge);
435 DFS(edge->prev_operator(), visited, component);
436 }
437 }
438 }
439
440 // Create final cost list for the graph: u --> v
CreateFinalCostList(const OperatorInfoPtr & u,const std::shared_ptr<Edge> & e,const OperatorInfoPtr & v) const441 CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
442 const OperatorInfoPtr &v) const {
443 MS_EXCEPTION_IF_NULL(u);
444 MS_EXCEPTION_IF_NULL(v);
445 MS_EXCEPTION_IF_NULL(e);
446 CostPtrList ret;
447 for (const auto &u_strategy : u->GetStrategyCost()) {
448 for (const auto &v_strategy : v->GetStrategyCost()) {
449 MS_EXCEPTION_IF_NULL(u_strategy);
450 MS_EXCEPTION_IF_NULL(v_strategy);
451 auto u_strategy_ptr = u_strategy->strategy_ptr;
452 auto v_strategy_ptr = v_strategy->strategy_ptr;
453 CostPtrList clist1 = u_strategy->cost_list;
454 CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
455 CostPtrList clist3 = v_strategy->cost_list;
456 for (const auto &cost1 : clist1) {
457 for (const auto &cost2 : clist2) {
458 for (const auto &cost3 : clist3) {
459 MS_EXCEPTION_IF_NULL(cost1);
460 MS_EXCEPTION_IF_NULL(cost2);
461 MS_EXCEPTION_IF_NULL(cost3);
462 double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
463 double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
464 double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
465 double communication_forward =
466 cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
467 double communication_without_para = cost1->communication_without_parameter_ +
468 cost2->communication_without_parameter_ +
469 cost3->communication_without_parameter_;
470 auto decision =
471 std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
472 auto cost = std::make_shared<Cost>(computation, communication, decision);
473 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
474 MS_EXCEPTION_IF_NULL(cost);
475 cost->communication_without_parameter_ = communication_without_para;
476 cost->communication_with_partial_para_ =
477 communication_without_para + gamma * (communication - communication_without_para);
478 cost->memory_with_reuse_ = memory;
479 cost->communication_forward_ = communication_forward;
480 ret.push_back(cost);
481 }
482 }
483 }
484 }
485 }
486
487 Simplify(&ret);
488 return ret;
489 }
490
491 // Create final cost list for the graph containing a single node: u
CreateFinalSingleCostList(const OperatorInfoPtr & u) const492 CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) const {
493 MS_EXCEPTION_IF_NULL(u);
494 CostPtrList ret;
495 for (const auto &u_strategy : u->GetStrategyCost()) {
496 MS_EXCEPTION_IF_NULL(u_strategy);
497 auto u_strategy_ptr = u_strategy->strategy_ptr;
498 CostPtrList clist1 = u_strategy->cost_list;
499 for (const auto &cost1 : clist1) {
500 MS_EXCEPTION_IF_NULL(cost1);
501 auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
502 auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
503 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
504 MS_EXCEPTION_IF_NULL(new_cost);
505 new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
506 new_cost->communication_with_partial_para_ =
507 cost1->communication_without_parameter_ +
508 gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
509 new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
510 new_cost->communication_forward_ = cost1->communication_forward_;
511 ret.push_back(new_cost);
512 }
513 }
514
515 Simplify(&ret);
516 return ret;
517 }
518
SelectCostWithMinInferenceTime(const CostPtrList & cost_list,double memory) const519 CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) const {
520 // Select the cost with minimum inference time. Currently, the inference time is modeled as =
521 // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
522 if (cost_list.empty()) {
523 MS_LOG(ERROR) << "Final cost list is null.";
524 return nullptr;
525 }
526 CostPtrList after_mem_filter;
527 double minimum_memory = DBL_MAX;
528 // Filter out the valid costs.
529 for (auto &a_cost : cost_list) {
530 if (a_cost->memory_with_reuse_ <= memory) {
531 after_mem_filter.push_back(a_cost);
532 } else if (a_cost->memory_with_reuse_ < minimum_memory) {
533 minimum_memory = a_cost->memory_with_reuse_;
534 }
535 }
536 if (after_mem_filter.empty()) {
537 MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
538 << ", the memory capacity is: " << memory << ".";
539 return nullptr;
540 }
541 // Init the returned value with first cost.
542 CostPtr ret = after_mem_filter[0];
543 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
544 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
545
546 double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
547 MS_LOG(INFO) << "Cost 0: "
548 << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
549 << ", communication_forward_: " << ret->communication_forward_
550 << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
551 << ", communication_cost_: " << ret->communication_cost_
552 << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
553 MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
554 for (size_t i = 1; i < after_mem_filter.size(); ++i) {
555 MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
556 MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
557 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
558 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
559 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
560 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
561 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
562 << ".";
563 auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
564 MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
565 if (minimum > tmp) {
566 minimum = tmp;
567 ret = after_mem_filter[i];
568 MS_LOG(INFO) << "Selected: " << i;
569 }
570 }
571 return ret;
572 }
573
SelectCostWithMinTrainingTime(const CostPtrList & cost_list,double memory) const574 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) const {
575 // Select the cost with minimum training time. Currently, the training time is modeled as =
576 // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
577 if (cost_list.empty()) {
578 MS_LOG(ERROR) << "Final cost list is null.";
579 return nullptr;
580 }
581 CostPtrList after_mem_filter;
582 double minimum_memory = DBL_MAX;
583 // Filter out the valid costs.
584 for (auto &a_cost : cost_list) {
585 if (a_cost->memory_with_reuse_ <= memory) {
586 after_mem_filter.push_back(a_cost);
587 } else if (a_cost->memory_with_reuse_ < minimum_memory) {
588 minimum_memory = a_cost->memory_with_reuse_;
589 }
590 }
591 if (after_mem_filter.empty()) {
592 MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
593 << ", the memory capacity is: " << memory << ".";
594 return nullptr;
595 }
596 // Init the returned value with first cost.
597 CostPtr ret = after_mem_filter[0];
598 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
599 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
600
601 double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
602 MS_LOG(INFO) << "Cost 0: "
603 << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
604 << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
605 << ", communication_cost_: " << ret->communication_cost_
606 << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
607 MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
608 for (size_t i = 1; i < after_mem_filter.size(); ++i) {
609 MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
610 MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
611 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
612 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
613 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
614 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
615 << ".";
616 auto tmp =
617 alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
618 MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
619 if (minimum > tmp) {
620 minimum = tmp;
621 ret = after_mem_filter[i];
622 MS_LOG(INFO) << "Selected: " << i;
623 }
624 }
625 return ret;
626 }
627
SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> & all_cost_list,double available_memory) const628 CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
629 double available_memory) const {
630 CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
631 double minimum = DBL_MAX, total_memory = 0.0;
632 CostPtrList ret(all_cost_list.size(), nullptr);
633 // Check whether valid costs exist.
634 for (size_t i = 0; i < all_cost_list.size(); ++i) {
635 if (all_cost_list[i][0] == nullptr) {
636 MS_LOG(ERROR) << "The cost list " << i << " is empty.";
637 return ret;
638 } else {
639 double memory_i_cost = DBL_MAX;
640 for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
641 if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
642 memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
643 }
644 }
645 total_memory += memory_i_cost;
646 }
647 }
648 if (total_memory >= available_memory) {
649 MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
650 << ", minimum strategy cost: " << total_memory << ".";
651 return selected_cost_list;
652 }
653
654 std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
655 &available_memory](size_t k) {
656 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
657 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
658 if (k == all_cost_list.size()) {
659 double tmp_memory = 0.0, tmp_minimum = 0.0;
660 for (size_t i = 0; i < selected_cost_list.size(); ++i) {
661 MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
662 tmp_memory += selected_cost_list[i]->memory_with_reuse_;
663 tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
664 beta * selected_cost_list[i]->communication_with_partial_para_;
665 }
666 MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
667 << ".";
668 if (tmp_memory < available_memory && tmp_minimum < minimum) {
669 ret = selected_cost_list;
670 minimum = tmp_minimum;
671 MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
672 }
673 return;
674 }
675
676 MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
677 for (auto &c : all_cost_list[k]) {
678 selected_cost_list[k] = c;
679 recursive(k + 1);
680 }
681 };
682 recursive(0);
683 return ret;
684 }
685
SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)686 Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
687 MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
688 auto connected_components = ConstructConnectedComponents(alive_ops);
689 MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
690 std::vector<CostPtrList> all_list;
691 for (size_t j = 0; j < connected_components.size(); ++j) {
692 auto one_component = connected_components[j];
693 MS_EXCEPTION_IF_NULL(one_component);
694 if (one_component->GetOperators().size() == 1) {
695 MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
696 auto cost_list_1 = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
697 all_list.push_back(cost_list_1);
698 } else if (one_component->GetOperators().size() == 2) {
699 MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
700 OperatorInfoPtr u, v;
701 auto first_op = one_component->GetOperators()[0];
702 auto second_op = one_component->GetOperators()[1];
703 MS_EXCEPTION_IF_NULL(first_op);
704 MS_EXCEPTION_IF_NULL(second_op);
705 if (!first_op->GetAliveSuccEdges().empty() &&
706 first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
707 u = first_op;
708 v = second_op;
709 } else if (!second_op->GetAliveSuccEdges().empty() &&
710 second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
711 u = second_op;
712 v = first_op;
713 } else {
714 MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
715 << ", " << second_op->GetAliveSuccEdges().size() << ".";
716 }
717 MS_EXCEPTION_IF_NULL(u);
718 auto e = u->GetAliveSuccEdges()[0];
719 auto cost_list = one_component->CreateFinalCostList(u, e, v);
720 all_list.push_back(cost_list);
721 } else {
722 MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
723 << " operators in a component in the final graph.";
724 }
725 }
726 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
727 auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
728 for (size_t k = 0; k < selected_cost_list.size(); ++k) {
729 auto selected_cost = selected_cost_list[k];
730 if (selected_cost == nullptr) {
731 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
732 return FAILED;
733 }
734 MS_EXCEPTION_IF_NULL(connected_components[k]);
735 if (connected_components[k]->GetOperators().size() == 1) {
736 auto u = connected_components[k]->GetOperators()[0];
737 auto decision_f = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
738 u->SetSelectedStrategyAndCost(decision_f->u_strategy_, decision_f->u_cost_);
739 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
740 } else if (connected_components[k]->GetOperators().size() == 2) {
741 OperatorInfoPtr u = nullptr;
742 OperatorInfoPtr v = nullptr;
743 auto first_op = connected_components[k]->GetOperators()[0];
744 auto second_op = connected_components[k]->GetOperators()[1];
745 MS_EXCEPTION_IF_NULL(first_op);
746 MS_EXCEPTION_IF_NULL(second_op);
747 if (!first_op->GetAliveSuccEdges().empty() &&
748 first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
749 u = first_op;
750 v = second_op;
751 } else if (!second_op->GetAliveSuccEdges().empty() &&
752 second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
753 u = second_op;
754 v = first_op;
755 }
756 MS_EXCEPTION_IF_NULL(u);
757 auto e = u->GetAliveSuccEdges()[0];
758 MS_EXCEPTION_IF_NULL(v);
759 MS_EXCEPTION_IF_NULL(e);
760 MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
761 auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
762 MS_EXCEPTION_IF_NULL(decision);
763 u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
764 v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
765 e->set_selected_cost(decision->middle_cost_);
766 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
767 }
768 }
769 return SUCCESS;
770 }
771
SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)772 Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
773 // In this case, the final graph should contains exactly 2 nodes.
774 if (alive_ops.empty()) {
775 MS_LOG(INFO) << "0 Operator in the final graph.";
776 return SUCCESS;
777 }
778 OperatorInfoPtr u, v;
779 MS_EXCEPTION_IF_NULL(alive_ops[0]);
780 MS_EXCEPTION_IF_NULL(alive_ops[1]);
781 const auto phase = CostModelContext::GetInstance()->run_phase();
782 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
783 if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
784 alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
785 u = alive_ops[0];
786 v = alive_ops[1];
787 } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
788 alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
789 u = alive_ops[1];
790 v = alive_ops[0];
791 } else {
792 if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
793 MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
794 << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
795 } else {
796 // In this case, the final graph consists of two single nodes
797 MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
798 std::vector<CostPtrList> all_list;
799 auto connected_components = ConstructConnectedComponents(alive_ops);
800 MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
801 for (size_t i = 0; i < connected_components.size(); ++i) {
802 MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
803 auto one_component = connected_components[i];
804 MS_EXCEPTION_IF_NULL(one_component);
805 auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
806 all_list.push_back(cost_list);
807 }
808 CostPtrList selected_cost_list;
809 if (phase == TRAINING_PHASE) {
810 // training phase
811 selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
812 } else {
813 // inference phase
814 MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
815 "phase is not supported.";
816 }
817 for (size_t k = 0; k < selected_cost_list.size(); ++k) {
818 auto selected_cost = selected_cost_list[k];
819 if (selected_cost == nullptr) {
820 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
821 << ".";
822 return FAILED;
823 }
824 MS_EXCEPTION_IF_NULL(connected_components[k]);
825 auto one_operator = connected_components[k]->GetOperators()[0];
826 MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
827 auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
828 MS_EXCEPTION_IF_NULL(decision);
829 one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
830 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
831 }
832
833 return SUCCESS;
834 }
835 }
836 MS_LOG(INFO) << "There are 2 nodes in the final graph.";
837 // In this case, the finale graph is exactly of the form: u --> v
838 MS_EXCEPTION_IF_NULL(u);
839 MS_EXCEPTION_IF_NULL(v);
840 auto e = u->GetAliveSuccEdges()[0];
841 MS_EXCEPTION_IF_NULL(e);
842 auto f_cost_list = CreateFinalCostList(u, e, v);
843 CostPtr cost = nullptr;
844 if (phase == TRAINING_PHASE) {
845 // training phase
846 cost = SelectCostWithMinTrainingTime(f_cost_list, device_mem_capacity);
847 } else {
848 MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
849 "phase is not supported.";
850 }
851 if (cost == nullptr) {
852 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
853 return FAILED;
854 }
855 MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
856 auto f_decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
857 MS_EXCEPTION_IF_NULL(f_decision);
858 u->SetSelectedStrategyAndCost(f_decision->u_strategy_, f_decision->left_cost_);
859 v->SetSelectedStrategyAndCost(f_decision->v_strategy_, f_decision->right_cost_);
860 e->set_selected_cost(f_decision->middle_cost_);
861 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
862 return SUCCESS;
863 }
864
865 // searching the strategy for the final eliminated graph
SearchStrategy()866 Status CostGraph::SearchStrategy() {
867 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
868 std::vector<OperatorInfoPtr> alive_ops;
869 (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
870 MS_EXCEPTION_IF_NULL(op);
871 if (op->is_alive()) {
872 alive_ops.push_back(op);
873 }
874 });
875 const auto phase = CostModelContext::GetInstance()->run_phase();
876 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
877
878 if (alive_ops.size() > 2) {
879 if (phase == TRAINING_PHASE) {
880 // training phase
881 return SearchStrategyForMultiNodeFinalGraph(alive_ops);
882 } else {
883 // inference phase
884 MS_LOG(EXCEPTION)
885 << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
886 }
887 } else if (alive_ops.size() == 1) {
888 MS_LOG(INFO) << "There are 1 single node in the final graph.";
889 OperatorInfoPtr u = alive_ops[0];
890 auto cost_list = CreateFinalSingleCostList(u);
891 CostPtr cost = nullptr;
892 if (phase == TRAINING_PHASE) {
893 // training phase
894 cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
895 } else {
896 // inference phase
897 cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
898 }
899 if (cost == nullptr) {
900 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
901 return FAILED;
902 }
903 MS_EXCEPTION_IF_NULL(u);
904 MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
905 auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
906 MS_EXCEPTION_IF_NULL(decision);
907 u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
908 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
909 return SUCCESS;
910 } else {
911 return SearchStrategyForTwoNodeFinalGraph(alive_ops);
912 }
913 }
914
915 // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
916 // return the v and the edge u --> v
CheckOpElimination() const917 OperatorInfoPtr CostGraph::CheckOpElimination() const {
918 for (auto &op : ops_) {
919 bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
920 if (bool_test) {
921 if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
922 return op;
923 }
924 }
925 }
926 return nullptr;
927 }
928
929 // Check the graph whether an EdgeElimination can be performed
CheckEdgeElimination() const930 std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
931 for (auto &op : ops_) {
932 MS_EXCEPTION_IF_NULL(op);
933 if (!op->is_alive()) {
934 continue;
935 }
936 std::map<void *, int64_t> count;
937 for (auto &edge_su : op->GetAliveSuccEdges()) {
938 MS_EXCEPTION_IF_NULL(edge_su);
939 auto v = edge_su->next_operator();
940 count[v.get()]++;
941 }
942 for (auto &pair : count) {
943 auto *op_ptr = pair.first;
944 int64_t op_count = pair.second;
945 if (op_count > 1) {
946 std::vector<std::shared_ptr<Edge>> ret;
947 for (auto &edge : op->GetAliveSuccEdges()) {
948 MS_EXCEPTION_IF_NULL(edge);
949 if (edge->next_operator().get() == op_ptr) {
950 ret.push_back(edge);
951 }
952 }
953 return ret;
954 }
955 }
956 }
957 return {};
958 }
959
960 // Check the graph whether a MergeElimination can be performed
CheckMergeElimination() const961 OperatorInfoPtr CostGraph::CheckMergeElimination() const {
962 for (auto &op : ops_) {
963 MS_EXCEPTION_IF_NULL(op);
964 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
965 if (bool_test) {
966 auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
967 MS_EXCEPTION_IF_NULL(next_op);
968 if (!next_op->GetAlivePrevEdges().empty()) {
969 return op;
970 }
971 }
972 }
973 return nullptr;
974 }
975
976 // Check the graph whether a ContractElimination can be performed
CheckContractElimination() const977 OperatorInfoPtr CostGraph::CheckContractElimination() const {
978 for (auto &op : ops_) {
979 MS_EXCEPTION_IF_NULL(op);
980 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
981 if (bool_test) {
982 auto edge = op->GetAlivePrevEdges()[0];
983 MS_EXCEPTION_IF_NULL(edge);
984 auto prev_op = edge->prev_operator();
985 MS_EXCEPTION_IF_NULL(prev_op);
986 if (!prev_op->GetAliveSuccEdges().empty()) {
987 return op;
988 }
989 }
990 }
991 return nullptr;
992 }
993
CheckSourceElimination() const994 std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
995 size_t source_count = 0;
996 std::vector<OperatorInfoPtr> op_vector(2, nullptr);
997 for (auto &op : ops_) {
998 MS_EXCEPTION_IF_NULL(op);
999 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
1000 if (bool_test) {
1001 op_vector[source_count++] = op;
1002 if (source_count == 2) {
1003 return std::make_pair(op_vector[0], op_vector[1]);
1004 }
1005 }
1006 }
1007 return std::make_pair(nullptr, nullptr);
1008 }
1009
CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra,const CostPtrList & op1_old_clist,StrategyPtr op2_old_stra,const CostPtrList & op2_old_clist,CostPtrList * op1_new_clist) const1010 void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
1011 StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
1012 CostPtrList *op1_new_clist) const {
1013 for (auto &op1_cost : op1_old_clist) {
1014 for (auto &op2_cost : op2_old_clist) {
1015 double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
1016 double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
1017 double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
1018 double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
1019 double communication_without_para =
1020 op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
1021 auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
1022 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1023 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1024 MS_EXCEPTION_IF_NULL(new_cost);
1025 new_cost->communication_without_parameter_ = communication_without_para;
1026 new_cost->communication_with_partial_para_ =
1027 communication_without_para + gamma * (communication - communication_without_para);
1028 new_cost->memory_with_reuse_ = memory;
1029 new_cost->communication_forward_ = communication_forward;
1030 MS_EXCEPTION_IF_NULL(op1_new_clist);
1031 op1_new_clist->push_back(std::move(new_cost));
1032 }
1033 }
1034 }
1035
UpdateEdgesIncidentToNodes(OperatorInfoPtr op1,std::vector<EdgePtr> * op1_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op1_new_edges_cost,std::vector<EdgePtr> * op1_new_succ_edges,const OperatorInfoPtr op2,std::vector<EdgePtr> * op2_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op2_new_edges_cost,std::vector<EdgePtr> * op2_new_succ_edges)1036 std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
1037 OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
1038 std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
1039 const OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
1040 std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
1041 for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
1042 auto &new_cost_map = op1_new_edges_cost->at(i);
1043 auto ith_edge = op1_old_succ_edges->at(i);
1044
1045 std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
1046 std::shared_ptr<Edge> new_edge;
1047 if (ith_edge->is_combined()) {
1048 std::vector<size_t> output_indexs, input_indexs;
1049 output_indexs = ith_edge->prev_op_output_indexs();
1050 input_indexs = ith_edge->next_op_input_indexs();
1051 new_edge =
1052 std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
1053 } else {
1054 size_t output_index;
1055 size_t input_index;
1056 output_index = ith_edge->prev_op_output_index();
1057 input_index = ith_edge->next_op_input_index();
1058 new_edge =
1059 std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
1060 }
1061 new_edge->SetCostMapAndInputOutput(new_cost_map);
1062 // replace the old successive edges with the new ones.
1063 op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
1064 ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
1065 (void)op1_new_succ_edges->erase(op1_new_succ_edges->cbegin() + SizeToLong(i));
1066 (void)op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + SizeToLong(i), new_edge);
1067 }
1068 for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
1069 auto &new_cost_map = op2_new_edges_cost->at(i);
1070 auto ith_edge = op2_old_succ_edges->at(i);
1071 const auto &destination = ith_edge->next_operator();
1072
1073 std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
1074 std::shared_ptr<Edge> new_edge;
1075 if (ith_edge->is_combined()) {
1076 std::vector<size_t> output_indexs, input_indexs;
1077 output_indexs = ith_edge->prev_op_output_indexs();
1078 input_indexs = ith_edge->next_op_input_indexs();
1079 new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
1080 } else {
1081 size_t output_index;
1082 size_t input_index;
1083 output_index = ith_edge->prev_op_output_index();
1084 input_index = ith_edge->next_op_input_index();
1085 new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
1086 }
1087 new_edge->SetCostMapAndInputOutput(new_cost_map);
1088 // replace the old successive edges with the new ones.
1089 destination->ReplacePreEdge(op2, new_edge);
1090 op1->AddSuccEdge(new_edge);
1091 (void)op2_new_succ_edges->erase(op2_new_succ_edges->cbegin() + SizeToLong(i));
1092 (void)op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + SizeToLong(i), new_edge);
1093 }
1094 return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
1095 }
1096
EliminationSources(const OperatorInfoPtr op1,const OperatorInfoPtr op2) const1097 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
1098 const OperatorInfoPtr op1, const OperatorInfoPtr op2) const {
1099 MS_EXCEPTION_IF_NULL(op1);
1100 MS_EXCEPTION_IF_NULL(op2);
1101 MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
1102
1103 auto op1_old_succ_edges = op1->GetAliveSuccEdges();
1104 std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
1105 op1_old_succ_edges.size());
1106 std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
1107 std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
1108
1109 auto op2_old_succ_edges = op2->GetAliveSuccEdges();
1110 std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
1111 op2_old_succ_edges.size());
1112 std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
1113 std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
1114
1115 // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
1116 for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
1117 const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
1118 std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
1119 for (const auto &key_value : op1_cost_map) {
1120 const auto &from_to_strategies = key_value.first;
1121 const auto &costlist = key_value.second;
1122 from_tocost[from_to_strategies.first].push_back(std::make_pair(from_to_strategies.second, costlist));
1123 }
1124 op1_edges_reorganised_cost[i] = from_tocost;
1125 }
1126
1127 for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
1128 const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
1129 std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
1130 for (const auto &key_value : op2_cost_map) {
1131 const auto &from_to_strategies = key_value.first;
1132 const auto &costlist = key_value.second;
1133 from_tocost[from_to_strategies.first].push_back(std::make_pair(from_to_strategies.second, costlist));
1134 }
1135 op2_edges_reorganised_cost[i] = from_tocost;
1136 }
1137
1138 // Merge op2 into op1
1139 const auto &op1_old_stra_cost = op1->GetStrategyCost();
1140 const auto &op2_old_stra_cost = op2->GetStrategyCost();
1141 std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
1142
1143 for (auto &op1_stra_cost : op1_old_stra_cost) {
1144 auto op1_old_stra = op1_stra_cost->strategy_ptr;
1145 auto op1_old_costlist = op1_stra_cost->cost_list;
1146
1147 for (auto &op2_stra_cost : op2_old_stra_cost) {
1148 auto op2_stra = op2_stra_cost->strategy_ptr;
1149 auto op2_costlist = op2_stra_cost->cost_list;
1150
1151 StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
1152 op1_new_stra->CoverStrategy(op2_stra);
1153 CostPtrList op1_new_costlist;
1154 // Calculate new cost for 'op1_new_costlist'
1155 CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
1156 std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
1157 op1_new_stra_cost.push_back(swc);
1158
1159 // Set cost for new successive edges of op1 and op2
1160 for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
1161 auto &from_tocost = op1_edges_reorganised_cost[i];
1162 auto &to_cost = from_tocost[op1_old_stra];
1163 auto &new_cost_map = op1_new_edges_cost[i];
1164 for (auto &stra_costlit : to_cost) {
1165 auto &to_strategy = stra_costlit.first;
1166 auto &edge_costlist = stra_costlit.second;
1167 CostPtrKey new_key = {op1_new_stra, to_strategy};
1168 new_cost_map[new_key] = edge_costlist;
1169 }
1170 }
1171 for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
1172 auto &from_tocost = op2_edges_reorganised_cost[i];
1173 auto &to_cost = from_tocost[op2_stra];
1174 auto &new_cost_map = op2_new_edges_cost[i];
1175 for (auto &stra_costlist : to_cost) {
1176 auto &to_strategy = stra_costlist.first;
1177 auto &edge_costlist = stra_costlist.second;
1178 CostPtrKey new_key = {op1_new_stra, to_strategy};
1179 new_cost_map[new_key] = edge_costlist;
1180 }
1181 }
1182 }
1183 }
1184 op1->SetStrategyCost(op1_new_stra_cost);
1185 op2->SetNotAlive();
1186
1187 // Update the edges incident to op1, and edges incident to op2
1188 MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
1189 return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
1190 &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
1191 }
1192
1193 // Check the graph whether a TriangleElimination can be performed
CheckTriangleElimination() const1194 std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
1195 for (auto &op : ops_) {
1196 MS_EXCEPTION_IF_NULL(op);
1197 bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
1198 if (bool_test) {
1199 auto edge1 = op->GetAliveSuccEdges()[0];
1200 auto edge2 = op->GetAliveSuccEdges()[1];
1201 MS_EXCEPTION_IF_NULL(edge1);
1202 MS_EXCEPTION_IF_NULL(edge2);
1203 auto first_op = edge1->next_operator();
1204 auto second_op = edge2->next_operator();
1205 MS_EXCEPTION_IF_NULL(first_op);
1206 for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
1207 if (first_op_succ_edge->next_operator() == second_op) {
1208 return {op, first_op_succ_edge};
1209 }
1210 }
1211 MS_EXCEPTION_IF_NULL(second_op);
1212 for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
1213 if (second_op_succ_edge->next_operator() == first_op) {
1214 return {op, second_op_succ_edge};
1215 }
1216 }
1217 }
1218 }
1219 return {nullptr, nullptr};
1220 }
1221
1222 // Check the graph whether a StarElimination can be performed.
1223 // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
CheckStarElimination() const1224 OperatorInfoPtr CostGraph::CheckStarElimination() const {
1225 for (auto &op : ops_) {
1226 MS_EXCEPTION_IF_NULL(op);
1227 bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
1228 if (bool_test) {
1229 return op;
1230 }
1231 }
1232 return nullptr;
1233 }
1234
1235 // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
1236 // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
EliminationOp(const OperatorInfoPtr & op) const1237 std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) const {
1238 // in this case, the operators are organised in the form of u-->op-->v, and the goal
1239 // is to eliminate 'op'.
1240 MS_EXCEPTION_IF_NULL(op);
1241 MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
1242 auto edge_u_op = op->GetAlivePrevEdges()[0];
1243 auto edge_op_v = op->GetAliveSuccEdges()[0];
1244 MS_EXCEPTION_IF_NULL(edge_u_op);
1245 MS_EXCEPTION_IF_NULL(edge_op_v);
1246 auto u = edge_u_op->prev_operator();
1247 auto v = edge_op_v->next_operator();
1248 std::vector<size_t> output_indexs;
1249 std::vector<size_t> input_indexs;
1250 size_t output_index, input_index;
1251 MS_EXCEPTION_IF_NULL(u);
1252 MS_EXCEPTION_IF_NULL(v);
1253 std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1254 std::shared_ptr<Edge> new_edge;
1255 if (edge_u_op->is_combined()) {
1256 output_indexs = edge_u_op->prev_op_output_indexs();
1257 } else {
1258 output_index = edge_u_op->prev_op_output_index();
1259 output_indexs.push_back(output_index);
1260 }
1261 if (edge_op_v->is_combined()) {
1262 input_indexs = edge_op_v->next_op_input_indexs();
1263 } else {
1264 input_index = edge_op_v->next_op_input_index();
1265 input_indexs.push_back(input_index);
1266 }
1267
1268 if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
1269 new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
1270 } else {
1271 new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1272 }
1273 MS_EXCEPTION_IF_NULL(new_edge);
1274 new_edge->set_pre_op_output(edge_u_op->prev_op_output());
1275 new_edge->set_next_op_input(edge_op_v->next_op_input());
1276 new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
1277 u->ReplaceSuccEdge(op, new_edge);
1278 v->ReplacePreEdge(op, new_edge);
1279 op->SetNotAlive();
1280 MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
1281 return new_edge;
1282 }
1283
1284 // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
1285 // and sets new costlist for the new edge.
EliminationEdges(const std::vector<std::shared_ptr<Edge>> & edges) const1286 std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) const {
1287 MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
1288 MS_EXCEPTION_IF_NULL(edges[0]);
1289 auto u = edges[0]->prev_operator();
1290 auto v = edges[0]->next_operator();
1291 MS_EXCEPTION_IF_NULL(u);
1292 MS_EXCEPTION_IF_NULL(v);
1293 std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1294 std::vector<size_t> output_indexs, input_indexs;
1295
1296 for (auto &edge : edges) {
1297 MS_EXCEPTION_IF_NULL(edge);
1298 if (edge->is_combined()) {
1299 auto from_output_indexs = edge->prev_op_output_indexs();
1300 auto from_input_indexs = edge->next_op_input_indexs();
1301 (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
1302 (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
1303 } else {
1304 output_indexs.push_back(edge->prev_op_output_index());
1305 input_indexs.push_back(edge->next_op_input_index());
1306 }
1307 }
1308
1309 std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1310 MS_EXCEPTION_IF_NULL(new_edge);
1311 new_edge->set_pre_op_output(edges[0]->prev_op_output());
1312 new_edge->set_next_op_input(edges[0]->next_op_input());
1313
1314 new_edge->EdgeEliminationSetNewCost(u, edges, v);
1315
1316 u->ReplaceSuccEdges(v, new_edge);
1317 v->ReplacePreEdges(u, new_edge);
1318 MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
1319 return new_edge;
1320 }
1321
1322 // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1323 // for this contract under the strategy 'op_strategy'
CreateMergeEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr tar_op_strategy,const CostPtrList & tar_cost_list,CostPtrList * const tar_cost_list_new) const1324 void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
1325 const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
1326 const CostPtrList &tar_cost_list,
1327 CostPtrList *const tar_cost_list_new) const {
1328 for (size_t i = 0; i < op_cost_list.size(); ++i) {
1329 auto &op_cost = op_cost_list[i];
1330 MS_EXCEPTION_IF_NULL(op_cost);
1331 for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1332 auto &edge_cost = edge_cost_list[j];
1333 MS_EXCEPTION_IF_NULL(edge_cost);
1334 for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1335 auto &tar_cost = tar_cost_list[k];
1336 MS_EXCEPTION_IF_NULL(tar_cost);
1337 double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1338 double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1339 double communication =
1340 op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1341 double communication_forward =
1342 op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
1343 double communication_without_para = op_cost->communication_without_parameter_ +
1344 edge_cost->communication_without_parameter_ +
1345 tar_cost->communication_without_parameter_;
1346
1347 auto decision =
1348 std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
1349 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1350 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1351 MS_EXCEPTION_IF_NULL(new_cost);
1352 new_cost->communication_without_parameter_ = communication_without_para;
1353 new_cost->communication_with_partial_para_ =
1354 communication_without_para + gamma * (communication - communication_without_para);
1355 new_cost->memory_with_reuse_ = memory;
1356 new_cost->communication_forward_ = communication_forward;
1357 MS_EXCEPTION_IF_NULL(tar_cost_list_new);
1358 tar_cost_list_new->push_back(std::move(new_cost));
1359 }
1360 }
1361 }
1362 }
1363
1364 // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
1365 // target_op
EliminationMerge(const OperatorInfoPtr & op) const1366 OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) const {
1367 MS_EXCEPTION_IF_NULL(op);
1368 auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
1369 auto edge_ptr = op->GetAliveSuccEdges()[0];
1370 MS_EXCEPTION_IF_NULL(target_op);
1371 MS_EXCEPTION_IF_NULL(edge_ptr);
1372 MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
1373 bool valid = false;
1374
1375 for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1376 MS_EXCEPTION_IF_NULL(tar_stra_cost);
1377 auto tar_stra = tar_stra_cost->strategy_ptr;
1378 auto tar_clist_origin = tar_stra_cost->cost_list;
1379 CostPtrList tar_clist_new;
1380
1381 for (auto &op_stra_cost : op->GetStrategyCost()) {
1382 MS_EXCEPTION_IF_NULL(op_stra_cost);
1383 auto op_stra = op_stra_cost->strategy_ptr;
1384 auto op_clist = op_stra_cost->cost_list;
1385 auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
1386
1387 CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1388 }
1389 Simplify(&tar_clist_new);
1390 // Set the new costlist w.r.t the strategy
1391 tar_stra_cost->cost_list = tar_clist_new;
1392 if ((!valid) && (!tar_clist_new.empty())) {
1393 valid = true;
1394 }
1395 }
1396
1397 if (!valid) {
1398 MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
1399 }
1400 op->SetNotAlive();
1401 MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
1402 return target_op;
1403 }
1404
1405 // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1406 // for this contract under the strategy 'contract_op_stra'
CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,const CostPtrList & contract_op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr target_op_stra,const CostPtrList & tar_cost_list,CostPtrList * tar_cost_list_new) const1407 void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
1408 const CostPtrList &contract_op_cost_list,
1409 const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
1410 const CostPtrList &tar_cost_list,
1411 CostPtrList *tar_cost_list_new) const {
1412 for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
1413 auto &contract_op_cost = contract_op_cost_list[i];
1414 MS_EXCEPTION_IF_NULL(contract_op_cost);
1415 for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1416 auto &edge_cost = edge_cost_list[j];
1417 MS_EXCEPTION_IF_NULL(edge_cost);
1418 for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1419 auto &tar_cost = tar_cost_list[k];
1420 MS_EXCEPTION_IF_NULL(tar_cost);
1421 double computation =
1422 contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1423 double memory =
1424 contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1425 double communication =
1426 contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1427 double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
1428 tar_cost->communication_forward_;
1429 double communication_without_para = contract_op_cost->communication_without_parameter_ +
1430 edge_cost->communication_without_parameter_ +
1431 tar_cost->communication_without_parameter_;
1432
1433 auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
1434 target_op_stra, tar_cost);
1435 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1436 auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1437 new_cost->communication_without_parameter_ = communication_without_para;
1438 new_cost->communication_with_partial_para_ =
1439 communication_without_para + gamma * (communication - communication_without_para);
1440 new_cost->memory_with_reuse_ = memory;
1441 new_cost->communication_forward_ = communication_forward;
1442 tar_cost_list_new->push_back(std::move(new_cost));
1443 }
1444 }
1445 }
1446 }
1447
1448 // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
1449 // target_op
EliminationContract(const OperatorInfoPtr & op) const1450 OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) const {
1451 MS_EXCEPTION_IF_NULL(op);
1452 auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
1453 auto edge_ptr = op->GetAlivePrevEdges()[0];
1454 MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
1455 bool valid = false;
1456
1457 for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1458 MS_EXCEPTION_IF_NULL(tar_stra_cost);
1459 auto tar_stra = tar_stra_cost->strategy_ptr;
1460 auto tar_clist_origin = tar_stra_cost->cost_list;
1461 CostPtrList tar_clist_new;
1462
1463 for (auto &op_stra_cost : op->GetStrategyCost()) {
1464 MS_EXCEPTION_IF_NULL(op_stra_cost);
1465 auto op_stra = op_stra_cost->strategy_ptr;
1466 auto op_clist = op_stra_cost->cost_list;
1467 auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
1468
1469 CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1470 }
1471 Simplify(&tar_clist_new);
1472 // Set the new costlist w.r.t the strategy
1473 tar_stra_cost->cost_list = tar_clist_new;
1474 if ((!valid) && (!tar_clist_new.empty())) {
1475 valid = true;
1476 }
1477 }
1478 if (!valid) {
1479 MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
1480 }
1481 op->SetNotAlive();
1482 MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
1483 return target_op;
1484 }
1485
CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,StrategyPtr left_op_stra,StrategyPtr right_op_stra,const CostPtr & right_op_cost,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtr & right_edge_cost,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new) const1486 void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
1487 StrategyPtr right_op_stra, const CostPtr &right_op_cost,
1488 const CostPtrList &elimi_op_clist,
1489 const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
1490 const CostPtrList &left_node_clist_origin,
1491 CostPtrList *left_node_clist_new) const {
1492 MS_EXCEPTION_IF_NULL(right_edge_cost);
1493 MS_EXCEPTION_IF_NULL(right_op_cost);
1494 MS_EXCEPTION_IF_NULL(left_node_clist_new);
1495 for (auto &elimi_op_cost : elimi_op_clist) {
1496 MS_EXCEPTION_IF_NULL(elimi_op_cost);
1497 for (auto &left_edge_cost : left_edge_clist) {
1498 MS_EXCEPTION_IF_NULL(left_edge_cost);
1499 for (auto &left_node_cost : left_node_clist_origin) {
1500 MS_EXCEPTION_IF_NULL(left_node_cost);
1501 double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
1502 left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
1503 double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
1504 left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
1505 double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
1506 left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
1507 double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
1508 left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
1509 double new_commu_without =
1510 elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
1511 left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
1512 const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1513 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1514
1515 if (triangle_star_stra_overwrite) {
1516 new_computation += right_op_cost->computation_cost_;
1517 new_memory += right_op_cost->memory_with_reuse_;
1518 new_commu_cost += right_op_cost->communication_cost_;
1519 new_commu_forward += right_op_cost->communication_forward_;
1520 new_commu_without += right_op_cost->communication_without_parameter_;
1521 }
1522
1523 auto decision =
1524 std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
1525 left_op_stra, left_node_cost, right_op_stra, right_op_cost);
1526 auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
1527 new_cost->communication_without_parameter_ = new_commu_without;
1528 new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
1529 new_cost->memory_with_reuse_ = new_memory;
1530 new_cost->communication_forward_ = new_commu_forward;
1531 left_node_clist_new->push_back(std::move(new_cost));
1532 }
1533 }
1534 }
1535 }
1536
CreateTriangleEliminationCostList(const OperatorInfoPtr & elimi_op,const CostPtrList & right_node_clist,const CostPtrList & right_edge_clist,const StrategyPtr & elimi_op_stra,const StrategyPtr & left_node_stra,const StrategyPtr & right_node_stra,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new) const1537 void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
1538 const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
1539 const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
1540 const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
1541 const CostPtrList &left_node_clist_origin,
1542 CostPtrList *left_node_clist_new) const {
1543 MS_EXCEPTION_IF_NULL(elimi_op);
1544 for (auto &right_node_cost : right_node_clist) {
1545 MS_EXCEPTION_IF_NULL(right_node_cost);
1546 for (auto &right_edge_cost : right_edge_clist) {
1547 MS_EXCEPTION_IF_NULL(right_edge_cost);
1548 CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
1549 elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
1550 left_node_clist_new);
1551 }
1552 }
1553 }
1554
// This method is for the 'Triangle' operation in DP algorithm. 'elimi_op' is eliminated into the previous operator
// of 'edge_left_right' (the left node); the next operator of 'edge_left_right' is the right node. The first two
// alive successive edges of 'elimi_op' are used as the left and right edges, exchanged if the first one does not
// lead to the left node. For each strategy of the left node a new cost list is created over all strategies of
// 'elimi_op' and of the right node, simplified, and stored back. 'elimi_op' is then marked as not alive and the left
// node is returned.
// An exception is raised if 'elimi_op' has fewer than two alive successive edges, or if no strategy of the left
// node ends up with a non-empty cost list.
OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
                                               const std::shared_ptr<Edge> &edge_left_right) const {
  MS_EXCEPTION_IF_NULL(edge_left_right);
  MS_EXCEPTION_IF_NULL(elimi_op);
  MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
  auto left_node = edge_left_right->prev_operator();
  auto right_node = edge_left_right->next_operator();
  // Indices 0 and 1 are read below, so at least two alive successive edges are required.
  constexpr size_t kTriangleSuccEdgeNum = 2;
  const auto alive_succ_edges = elimi_op->GetAliveSuccEdges();
  if (alive_succ_edges.size() < kTriangleSuccEdgeNum) {
    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name() << " failed. It needs at least "
                      << kTriangleSuccEdgeNum << " alive successive edges, but has " << alive_succ_edges.size()
                      << ".";
  }
  auto left_edge = alive_succ_edges[0];
  auto right_edge = alive_succ_edges[1];
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  MS_EXCEPTION_IF_NULL(left_edge);
  MS_EXCEPTION_IF_NULL(right_edge);
  MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
  MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";

  // Make 'left_edge' the one that points to the left node.
  if (left_edge->next_operator() != left_node) {
    std::swap(left_edge, right_edge);
  }
  bool valid = false;

  for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(left_node_stra_cost);
    auto left_node_stra = left_node_stra_cost->strategy_ptr;
    auto left_node_clist_origin = left_node_stra_cost->cost_list;
    CostPtrList left_node_clist_new;

    for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
      MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
      auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
      auto elimi_op_clist = elimi_op_stra_cost->cost_list;
      auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);

      for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
        MS_EXCEPTION_IF_NULL(right_node_stra_cost);
        auto right_node_stra = right_node_stra_cost->strategy_ptr;
        auto right_node_clist = right_node_stra_cost->cost_list;
        auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);

        CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
                                          right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
                                          &left_node_clist_new);
      }
    }
    Simplify(&left_node_clist_new);
    // Set the new costlist w.r.t the strategy
    left_node_stra_cost->cost_list = left_node_clist_new;
    // The elimination is valid as soon as one left-node strategy keeps at least one cost.
    if ((!valid) && (!left_node_clist_new.empty())) {
      valid = true;
    }
  }

  if (!valid) {
    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
                      << " failed. It may be caused by "
                         "configuring inconsistent strategies for operators.";
  }
  elimi_op->SetNotAlive();
  MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
  return left_node;
}
1618
CreateStarEliminationSubCostList(const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,std::vector<StrategyPtr> succ_nodes_stras,CostPtrList & succ_edges_costs,CostPtrList & succ_nodes_costs,CostPtrList * first_succ_node_clist_new) const1619 void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
1620 const CostPtrList &first_succ_node_clist,
1621 const CostPtrList &first_succ_edge_clist,
1622 const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1623 std::vector<StrategyPtr> succ_nodes_stras,
1624 CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
1625 CostPtrList *first_succ_node_clist_new) const {
1626 for (auto &first_succ_node_cost : first_succ_node_clist) {
1627 for (auto &first_succ_edge_cost : first_succ_edge_clist) {
1628 for (auto &merged_node_cost : merged_op_clist) {
1629 MS_EXCEPTION_IF_NULL(merged_node_cost);
1630 succ_nodes_stras[0] = first_succ_node_stra;
1631 succ_edges_costs[0] = first_succ_edge_cost;
1632 succ_nodes_costs[0] = first_succ_node_cost;
1633
1634 double computation_cost = merged_node_cost->computation_cost_,
1635 memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
1636 commu_without = merged_node_cost->communication_without_parameter_,
1637 commu_forward = merged_node_cost->communication_forward_;
1638 const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1639 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1640 for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
1641 MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
1642 if (i == 0) {
1643 computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
1644 memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
1645 commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
1646 commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
1647 commu_without += succ_edges_costs[i]->communication_without_parameter_ +
1648 succ_nodes_costs[i]->communication_without_parameter_;
1649 } else {
1650 computation_cost += succ_edges_costs[i]->computation_cost_;
1651 memory_cost += succ_edges_costs[i]->memory_with_reuse_;
1652 commu_cost += succ_edges_costs[i]->communication_cost_;
1653 commu_forward += succ_edges_costs[i]->communication_forward_;
1654 commu_without += succ_edges_costs[i]->communication_without_parameter_;
1655 if (triangle_star_stra_overwrite) {
1656 computation_cost += succ_nodes_costs[i]->computation_cost_;
1657 memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
1658 commu_cost += succ_nodes_costs[i]->communication_cost_;
1659 commu_forward += succ_nodes_costs[i]->communication_forward_;
1660 commu_without += succ_nodes_costs[i]->communication_without_parameter_;
1661 }
1662 }
1663 }
1664
1665 auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
1666 succ_nodes_stras, succ_nodes_costs);
1667 auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
1668 new_cost->communication_without_parameter_ = commu_without;
1669 new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
1670 new_cost->memory_with_reuse_ = memory_cost;
1671 new_cost->communication_forward_ = commu_forward;
1672 first_succ_node_clist_new->push_back(std::move(new_cost));
1673 }
1674 }
1675 }
1676 }
1677
CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> & succ_edges,const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,CostPtrList * first_succ_node_clist_new) const1678 void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
1679 const StrategyPtr &first_succ_node_stra,
1680 const CostPtrList &first_succ_node_clist,
1681 const CostPtrList &first_succ_edge_clist,
1682 const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1683 CostPtrList *first_succ_node_clist_new) const {
1684 std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
1685 CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
1686 std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
1687 &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
1688 &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
1689 this](size_t k) {
1690 if (k == succ_edges.size()) {
1691 CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1692 merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
1693 succ_nodes_costs, first_succ_node_clist_new);
1694 return;
1695 }
1696 MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
1697 << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
1698 << ", merged_op_clist: " << merged_op_clist.size()
1699 << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
1700 auto succ_edge = succ_edges[k];
1701 MS_EXCEPTION_IF_NULL(succ_edge);
1702 auto succ_node = succ_edge->next_operator();
1703 MS_EXCEPTION_IF_NULL(succ_node);
1704 for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
1705 MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
1706 auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
1707 auto succ_node_clist = succ_node_stra_cost->cost_list;
1708 auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
1709
1710 for (auto &succ_node_cost : succ_node_clist) {
1711 MS_EXCEPTION_IF_NULL(succ_node_cost);
1712 for (auto &succ_edge_cost : succ_edge_clist) {
1713 MS_EXCEPTION_IF_NULL(succ_edge_cost);
1714 succ_nodes_stras[k] = succ_node_stra;
1715 succ_edges_costs[k] = succ_edge_cost;
1716 succ_nodes_costs[k] = succ_node_cost;
1717 recursive(k + 1);
1718 }
1719 }
1720 }
1721 };
1722
1723 recursive(1);
1724 }
1725
EliminationStar(const OperatorInfoPtr & merged_op) const1726 std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) const {
1727 MS_EXCEPTION_IF_NULL(merged_op);
1728 auto succ_edges = merged_op->GetAliveSuccEdges();
1729 MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
1730 for (auto &succ_edge : succ_edges) {
1731 MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
1732 MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
1733 }
1734
1735 MS_EXCEPTION_IF_NULL(succ_edges[0]);
1736 auto first_succ_node = succ_edges[0]->next_operator();
1737 auto first_succ_edge = succ_edges[0];
1738 bool valid = false;
1739
1740 // 'merged_op' is merged into first_node
1741 MS_EXCEPTION_IF_NULL(first_succ_node);
1742 for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
1743 MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
1744 auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
1745 auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
1746 CostPtrList first_succ_node_clist_new;
1747
1748 for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
1749 MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
1750 auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
1751 auto merged_op_clist = merged_op_stra_cost->cost_list;
1752 auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
1753
1754 CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1755 merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
1756 }
1757 Simplify(&first_succ_node_clist_new);
1758 // Set the new costlist w.r.t the strategy
1759 first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
1760 if ((!valid) && (!first_succ_node_clist_new.empty())) {
1761 valid = true;
1762 }
1763 }
1764
1765 if (!valid) {
1766 MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name()
1767 << " failed. It may be caused by "
1768 "configuring inconsistent strategies for operators.";
1769 }
1770
1771 merged_op->SetNotAlive();
1772 MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
1773 return succ_edges;
1774 }
1775
GetNumEdges() const1776 size_t CostGraph::GetNumEdges() const {
1777 size_t sum = 0;
1778 for (const auto &kv : edges_) {
1779 auto &edges = kv.second;
1780 sum += edges.size();
1781 }
1782 return sum;
1783 }
1784
InitReshapeStrategy()1785 Status CostGraph::InitReshapeStrategy() {
1786 // reshape init should be apply after the init of it's previous node and next node.
1787 for (size_t i = 0; i < ops_.size(); ++i) {
1788 if (ops_[i]->IsReshape()) {
1789 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
1790 auto in_edges = GetOriginalPrevEdges(ops_[i]);
1791 auto out_edges = GetOriginalNextEdges(ops_[i]);
1792 // deal with consecutive reshape op
1793 auto pre_reshape_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1794 return edge->prev_operator()->IsReshape();
1795 });
1796 auto next_reshape_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1797 return edge->next_operator()->IsReshape();
1798 });
1799 if (pre_reshape_iter != in_edges.end() || next_reshape_iter != out_edges.end()) {
1800 reshape_info->set_strategy(nullptr);
1801 continue;
1802 }
1803 auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1804 return edge->prev_operator()->name() == reshape_info->pre_operator_name();
1805 });
1806 auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1807 return edge->next_operator()->name() == reshape_info->next_operator_name();
1808 });
1809 bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
1810 if (reshape_is_first_op) {
1811 (void)reshape_info->InitSelectedStrategy(reshape_info->selected_strategy(), nullptr);
1812 }
1813 if (pre_iter != in_edges.end() || reshape_is_first_op) {
1814 MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
1815 int64_t pre_index = reshape_info->pre_operator_index();
1816 TensorInfo pre_info;
1817 std::shared_ptr<OperatorInfo> pre_op_info;
1818 if (reshape_is_first_op) {
1819 pre_op_info = reshape_info;
1820 pre_info = pre_op_info->inputs_tensor_info()[LongToSize(pre_index)];
1821 } else {
1822 pre_op_info = (*pre_iter)->prev_operator();
1823 pre_info = pre_op_info->outputs_tensor_info()[LongToSize(pre_index)];
1824 }
1825 reshape_info->SetInputLayout(pre_info.tensor_layout());
1826 if (pre_iter != in_edges.end()) {
1827 Dimensions stra = pre_info.InferStrategy();
1828 if (stra.empty()) {
1829 MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
1830 }
1831 Strategies stra_inputs = {stra};
1832 StrategyPtr reshape_stra =
1833 std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
1834 reshape_info->set_strategy(reshape_stra);
1835 }
1836 }
1837 if (next_iter != out_edges.end()) {
1838 MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
1839 int64_t next_index = reshape_info->next_operator_index();
1840 reshape_info->SetOutputLayout(
1841 (*next_iter)->next_operator()->inputs_tensor_info()[LongToSize(next_index)].tensor_layout());
1842 }
1843 if (reshape_info->Init(nullptr, nullptr) != SUCCESS) {
1844 return FAILED;
1845 }
1846 }
1847 }
1848 return SUCCESS;
1849 }
1850
InitSelectedStrategy()1851 Status CostGraph::InitSelectedStrategy() {
1852 for (auto &op : ops_) {
1853 MS_EXCEPTION_IF_NULL(op);
1854 if (op->IsReshape()) {
1855 continue;
1856 }
1857 StrategyPtr out_strategy = nullptr;
1858 if (op->type() == MATMUL && op->out_strategy() != nullptr) {
1859 out_strategy = op->out_strategy();
1860 }
1861 auto result_op = op->InitSelectedStrategy(op->selected_strategy(), out_strategy);
1862 if (result_op != SUCCESS) {
1863 return result_op;
1864 }
1865 }
1866 auto result = InitReshapeStrategy();
1867 return result;
1868 }
1869
ComputeOpsAndEdgesParameterInvolved()1870 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
1871 for (auto &op : ops_) {
1872 MS_EXCEPTION_IF_NULL(op);
1873 const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
1874 if ((output_parameter != 0) && (output_parameter != 1)) {
1875 MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
1876 return FAILED;
1877 }
1878 }
1879 return SUCCESS;
1880 }
1881
DFSForTopoOrder(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,std::vector<OperatorInfoPtr> * topo_order)1882 void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
1883 std::vector<OperatorInfoPtr> *topo_order) {
1884 MS_EXCEPTION_IF_NULL(current_op);
1885 MS_EXCEPTION_IF_NULL(visited);
1886 MS_EXCEPTION_IF_NULL(topo_order);
1887
1888 visited->at(current_op) = true;
1889 for (const auto &s_edge : current_op->succ_edges()) {
1890 if (!visited->at(s_edge->next_operator())) {
1891 DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
1892 }
1893 }
1894 topo_order->push_back(current_op);
1895 }
1896
1897 // Compute a topological order of the costgraph
TopologyOrder(std::vector<OperatorInfoPtr> * topo_order)1898 void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
1899 std::map<OperatorInfoPtr, bool> visited;
1900 for (auto &op : ops_) {
1901 visited[op] = false;
1902 }
1903
1904 for (auto &op : ops_) {
1905 if (!visited[op]) {
1906 DFSForTopoOrder(op, &visited, topo_order);
1907 }
1908 }
1909 }
MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr,int64_t> & candidate_ops)1910 void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &candidate_ops) {
1911 for (auto &op : ops_) {
1912 auto search = candidate_ops.find(op);
1913 if (search != candidate_ops.end()) {
1914 // Mark the critical operators
1915 op->mark_output_critical();
1916 // Mark the successive edges
1917 for (auto &s_edge : op->succ_edges()) {
1918 s_edge->mark_output_critical();
1919 }
1920 } else {
1921 op->mark_output_not_critical();
1922 }
1923 }
1924 }
1925
DetermineCriticalOps(const std::vector<OperatorInfoPtr> & topo_order)1926 Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
1927 if (topo_order.size() == 0) {
1928 MS_LOG(ERROR) << "0 operator in costgraph.";
1929 return FAILED;
1930 }
1931 auto &first_op = topo_order[0];
1932 if (first_op->prev_edges().size() > 0) {
1933 MS_LOG(ERROR) << "The first operator in the first of topological order of "
1934 "costgraph should have 0 incoming edge, but has "
1935 << first_op->prev_edges() << "edges.";
1936 return FAILED;
1937 }
1938 // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
1939 // of the output of OperatorInfo that currently has not been used
1940 std::map<OperatorInfoPtr, int64_t> curr_memory_state;
1941 (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToLong(first_op->succ_edges().size())));
1942 std::map<OperatorInfoPtr, int64_t> max_memory_state = curr_memory_state;
1943 // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
1944 // not been used
1945 double curr_memory_size = first_op->GetOutputsTotalSize();
1946 double max_memory_size = curr_memory_size;
1947
1948 for (size_t finished = 1; finished < topo_order.size(); ++finished) {
1949 // Produce
1950 (void)curr_memory_state.emplace(
1951 std::make_pair(topo_order[finished], SizeToLong(topo_order[finished]->succ_edges().size())));
1952 curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
1953 // Consume
1954 for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1955 const auto &prev_op = prev_edge->prev_operator();
1956 curr_memory_state[prev_op]--;
1957 }
1958 for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1959 const auto &prev_op = prev_edge->prev_operator();
1960 if (curr_memory_state[prev_op] < 0) {
1961 MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
1962 return FAILED;
1963 } else if (curr_memory_state[prev_op] == 0) {
1964 (void)curr_memory_state.erase(prev_op);
1965 curr_memory_size -= prev_op->GetOutputsTotalSize();
1966 }
1967 }
1968
1969 if (curr_memory_size < 0) {
1970 MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
1971 }
1972 // Modify the max
1973 if (curr_memory_size > max_memory_size) {
1974 max_memory_size = curr_memory_size;
1975 max_memory_state = curr_memory_state;
1976 }
1977 }
1978 // Mark those critical operators
1979 MarkCriticalOpsAndEdges(max_memory_state);
1980 return SUCCESS;
1981 }
1982
ComputeOpsAndEdgesOutputCritical()1983 Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
1984 // Two steps to do:
1985 // 1. Compute a topological order of the costgraph
1986 // 2. Determine and mark the operators (and necessary edges) that are critical
1987 std::vector<OperatorInfoPtr> topo_order;
1988 TopologyOrder(&topo_order);
1989 std::reverse(std::begin(topo_order), std::end(topo_order));
1990
1991 if (DetermineCriticalOps(topo_order) != SUCCESS) {
1992 MS_LOG(ERROR) << "Determining critical operators failed.";
1993 return FAILED;
1994 }
1995 return SUCCESS;
1996 }
1997
CalculateOpsMemoryCost()1998 Status CostGraph::CalculateOpsMemoryCost() {
1999 for (auto &op : ops_) {
2000 MS_EXCEPTION_IF_NULL(op);
2001 if (op->CalculateMemoryCost() != SUCCESS) {
2002 MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
2003 return FAILED;
2004 }
2005 }
2006 return SUCCESS;
2007 }
2008
CalculateOpsMemoryCostForInference()2009 Status CostGraph::CalculateOpsMemoryCostForInference() {
2010 for (auto &op : ops_) {
2011 MS_EXCEPTION_IF_NULL(op);
2012 if (op->CalculateMemoryCostForInference() != SUCCESS) {
2013 MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
2014 return FAILED;
2015 }
2016 }
2017 return SUCCESS;
2018 }
2019
CalculateEdgesMemoryCost()2020 Status CostGraph::CalculateEdgesMemoryCost() {
2021 for (const auto &edge_pair : edges_) {
2022 const auto &edges = edge_pair.second;
2023 for (auto &one_edge : edges) {
2024 if (one_edge->CalculateMemoryCost() != SUCCESS) {
2025 MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
2026 return FAILED;
2027 }
2028 }
2029 }
2030 return SUCCESS;
2031 }
2032
CalculateEdgesMemoryCostForInference()2033 Status CostGraph::CalculateEdgesMemoryCostForInference() {
2034 for (const auto &edge_pair : edges_) {
2035 const auto &edges = edge_pair.second;
2036 for (auto &one_edge : edges) {
2037 if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
2038 MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
2039 return FAILED;
2040 }
2041 }
2042 }
2043 return SUCCESS;
2044 }
2045
FindTmpIdentityByParameterName(const std::string & p_name) const2046 OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(const std::string &p_name) const {
2047 for (auto one_op : ops_) {
2048 if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
2049 if (one_op->refkey_parameter_name() == p_name) {
2050 return one_op;
2051 }
2052 }
2053 }
2054 return nullptr;
2055 }
CorrectOpsMemoryCost()2056 Status CostGraph::CorrectOpsMemoryCost() {
2057 for (auto &one_op : ops_) {
2058 if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
2059 if (one_op->GetAliveSuccEdges().size() > 1) {
2060 // Filter out the case when the TmpIdentity being used by multiple operators
2061 std::map<size_t, int64_t> output_count;
2062 for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
2063 auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
2064 output_count[output_index]++;
2065 }
2066 for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
2067 auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
2068 if (output_count[output_index] <= 1) {
2069 continue;
2070 }
2071 auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
2072 MS_EXCEPTION_IF_NULL(next_op);
2073 auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
2074 if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
2075 MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
2076 << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
2077 return FAILED;
2078 }
2079 output_count[output_index]--;
2080 }
2081 }
2082 }
2083 }
2084 return SUCCESS;
2085 }
2086
CalculateMemoryCost()2087 Status CostGraph::CalculateMemoryCost() {
2088 const auto phase = CostModelContext::GetInstance()->run_phase();
2089 if (phase == TRAINING_PHASE) {
2090 // training phase
2091 if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
2092 // Calculate operators' memory usage
2093 if (CalculateOpsMemoryCost() != SUCCESS) {
2094 MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
2095 return FAILED;
2096 }
2097 // Calculate edges' memory usage
2098 if (CalculateEdgesMemoryCost() != SUCCESS) {
2099 MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
2100 return FAILED;
2101 }
2102 // Correct memory usage caused by TmpIdentity
2103 if (CorrectOpsMemoryCost() != SUCCESS) {
2104 MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
2105 return FAILED;
2106 }
2107 } else {
2108 MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
2109 return FAILED;
2110 }
2111 } else {
2112 // inference phase
2113 if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
2114 // Calculate operators' memory usage
2115 if (CalculateOpsMemoryCostForInference() != SUCCESS) {
2116 MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
2117 return FAILED;
2118 }
2119 // Calculate edges's memory usage
2120 if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
2121 MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
2122 return FAILED;
2123 }
2124 } else {
2125 MS_LOG(ERROR) << "Computing operators' critical flag failed.";
2126 return FAILED;
2127 }
2128 }
2129 return SUCCESS;
2130 }
2131
CheckApproximateCostGraphEdges()2132 void CostGraph::CheckApproximateCostGraphEdges() {
2133 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
2134 if (!approximation) {
2135 return;
2136 }
2137 for (const auto &s_edge : edges_) {
2138 auto &edges_vector = s_edge.second;
2139 for (auto &edge_ptr : edges_vector) {
2140 MS_EXCEPTION_IF_NULL(edge_ptr);
2141 if (edge_ptr->CheckStrategyCostPossibility()) {
2142 continue;
2143 }
2144 MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name()
2145 << " impossible, re-initing the operators and edges";
2146 auto prev_op = edge_ptr->prev_operator();
2147 MS_EXCEPTION_IF_NULL(prev_op);
2148 auto next_op = edge_ptr->next_operator();
2149 MS_EXCEPTION_IF_NULL(next_op);
2150 // Check the 'prev_op'
2151 prev_op->ExactStrategiesAndRelatedEdges();
2152 // Check the 'next_op'
2153 next_op->ExactStrategiesAndRelatedEdges();
2154 }
2155 }
2156 }
2157 } // namespace parallel
2158 } // namespace mindspore
2159