• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 namespace mindspore {
24 namespace parallel {
GetStrategy(const CostGraphPtr & graph)25 Status GetStrategy(const CostGraphPtr &graph) {
26   MS_LOG(INFO) << "Searching strategies begins.";
27   MS_EXCEPTION_IF_NULL(graph);
28   std::vector<EliminationPtr> eliminations;
29   bool flag = true;
30 
31   // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order.
32   // Note: the checking and applying of the 6 operations MUST in current order.
33   while (flag) {
34     flag = false;
35     auto node = graph->CheckOpElimination();
36     if (node != nullptr) {
37       // Applying the Operator Elimination
38       flag = true;
39       auto l_edge = node->GetAlivePrevEdges()[0];
40       auto r_edge = node->GetAliveSuccEdges()[0];
41       auto n_edge = graph->EliminationOp(node);
42       auto elimi_op = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
43       (void)eliminations.emplace_back(std::move(elimi_op));
44     }
45     if (!flag) {
46       auto edges = graph->CheckEdgeElimination();
47       if (!edges.empty()) {
48         // Applying the Edge Elimination
49         flag = true;
50         auto new_edge = graph->EliminationEdges(edges);
51         auto elimi_edge = std::make_shared<EdgeElimination>(new_edge, edges);
52         (void)eliminations.emplace_back(std::move(elimi_edge));
53       }
54     }
55     if (!flag) {
56       auto merge_node = graph->CheckMergeElimination();
57       if (merge_node != nullptr) {
58         // Applying the Merge Elimination
59         flag = true;
60         auto succ_edge = merge_node->GetAliveSuccEdges()[0];
61         auto target_node = graph->EliminationMerge(merge_node);
62         auto elimi_merge = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
63         (void)eliminations.emplace_back(std::move(elimi_merge));
64       }
65     }
66     if (!flag) {
67       auto contracted_node = graph->CheckContractElimination();
68       if ((contracted_node != nullptr)) {
69         // Applying the Contract Elimination
70         flag = true;
71         auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
72         auto target_node = graph->EliminationContract(contracted_node);
73         auto elimi_contract = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
74         (void)eliminations.emplace_back(std::move(elimi_contract));
75       }
76     }
77     if (!flag) {
78       auto triangle_pair = graph->CheckTriangleElimination();
79       if (triangle_pair.first != nullptr) {
80         // Applying the Triangle Elimination
81         flag = true;
82         auto eliminated_node = triangle_pair.first;
83         auto l_r_edge = triangle_pair.second;
84 
85         auto left_node = l_r_edge->prev_operator();
86         auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
87         auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
88         MS_EXCEPTION_IF_NULL(left_edge);
89         if (left_edge->next_operator() != left_node) {
90           auto tmp = left_edge;
91           left_edge = right_edge;
92           right_edge = tmp;
93         }
94         auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
95         auto right_node = l_r_edge->next_operator();
96         auto elimi_tri =
97           std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
98         (void)eliminations.emplace_back(std::move(elimi_tri));
99       }
100     }
101     if (!flag) {
102       auto star_center = graph->CheckStarElimination();
103       if (star_center != nullptr) {
104         // Applying the Star Elimination
105         flag = true;
106         auto succ_edges = graph->EliminationStar(star_center);
107         std::vector<OperatorInfoPtr> succ_nodes;
108         for (size_t i = 0; i < succ_edges.size(); ++i) {
109           MS_EXCEPTION_IF_NULL(succ_edges[i]);
110           succ_nodes.push_back(succ_edges[i]->next_operator());
111         }
112         auto elimi_star = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
113         (void)eliminations.emplace_back(std::move(elimi_star));
114       }
115     }
116   }
117 
118   // Phase 2: Search the cost_list in the final graph, and determine the optimal one
119   if (graph->SearchStrategy() != SUCCESS) {
120     MS_LOG(ERROR) << "Searching strategy for the final failed.";
121     return FAILED;
122   }
123 
124   // Phase 3: Recover the original CostGraph, the determine strategy for each operator
125   if (RecoverStrategy(eliminations) == SUCCESS) {
126     MS_LOG(INFO) << "Searching strategies ends.";
127     return SUCCESS;
128   } else {
129     MS_LOG(EXCEPTION) << "Searching strategies failed.";
130   }
131 }
132 
RecoverStrategy(std::vector<EliminationPtr> eliminations)133 Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
134   std::vector<EliminationPtr>::reverse_iterator rit;
135   const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
136   for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
137     if ((*rit)->isa<OpElimination>()) {
138       auto elimination_op = (*rit)->cast<OpEliminationPtr>();
139       auto e = elimination_op->new_edge_;
140       auto w = elimination_op->op_;
141       auto left_edge_op = elimination_op->left_edge_;
142       auto right_edge_op = elimination_op->right_edge_;
143       auto decision_op = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
144       w->SetSelectedStrategyAndCost(decision_op->op_strategy_, decision_op->middle_cost_);
145       left_edge_op->set_selected_cost(decision_op->left_cost_);
146       right_edge_op->set_selected_cost(decision_op->right_cost_);
147       MS_LOG(INFO) << "Recover opElimination succeeded.";
148     } else if ((*rit)->isa<EdgeElimination>()) {
149       auto elimination_edge = (*rit)->cast<EdgeEliminationPtr>();
150       auto new_edge = elimination_edge->new_edge_;
151       auto &edges = elimination_edge->edges_;
152       auto decision_edge = new_edge->selected_cost()->decision_ptr_->cast<EdgeEliminationDecisionPtr>();
153       for (size_t j = 0; j < edges.size(); ++j) {
154         MS_EXCEPTION_IF_NULL(edges[j]);
155         edges[j]->set_selected_cost(decision_edge->edges_cost_list_[j]);
156       }
157       MS_LOG(INFO) << "Recover edgeElimination succeeded.";
158     } else if ((*rit)->isa<MergeElimination>()) {
159       auto elimination_merge = (*rit)->cast<MergeEliminationPtr>();
160       auto target_node_m = elimination_merge->target_node_;
161       auto merged_node = elimination_merge->merged_node_;
162       auto merged_edge = elimination_merge->dir_edge_;
163       MS_EXCEPTION_IF_NULL(target_node_m->selected_cost());
164       MS_EXCEPTION_IF_NULL(target_node_m->selected_cost()->decision_ptr_);
165       auto decision = target_node_m->selected_cost()->decision_ptr_->cast<MergeEliminationDecisionPtr>();
166       merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_);
167       merged_edge->set_selected_cost(decision->edge_cost_);
168       target_node_m->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_);
169       MS_LOG(INFO) << "Recover mergeElimination succeeded.";
170     } else if ((*rit)->isa<ContractElimination>()) {
171       auto elimination_cont = (*rit)->cast<ContractEliminationPtr>();
172       auto target_node = elimination_cont->target_node_;
173       auto contracted_node = elimination_cont->contracted_node_;
174       auto contracted_edge = elimination_cont->dir_edge_;
175       auto decision_cont = target_node->selected_cost()->decision_ptr_->cast<ContractEliminationDecisionPtr>();
176       contracted_node->SetSelectedStrategyAndCost(decision_cont->contracted_op_strategy_,
177                                                   decision_cont->contracted_op_cost_);
178       contracted_edge->set_selected_cost(decision_cont->edge_cost_);
179       target_node->SetSelectedStrategyAndCost(decision_cont->target_op_strategy_, decision_cont->target_cost_);
180       MS_LOG(INFO) << "Recover contractElimination succeeded.";
181     } else if ((*rit)->isa<TriangleElimination>()) {
182       auto elimination_tri = (*rit)->cast<TriangleEliminationPtr>();
183       auto left_node = elimination_tri->left_node_;
184       auto left_edge = elimination_tri->left_edge_;
185       auto eliminated_node = elimination_tri->eliminated_node_;
186       auto right_edge_tri = elimination_tri->right_edge_;
187       auto right_node = elimination_tri->right_node_;
188       auto decision_tri = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
189 
190       eliminated_node->SetSelectedStrategyAndCost(decision_tri->eliminated_op_strategy_,
191                                                   decision_tri->eliminated_op_cost_);
192       left_edge->set_selected_cost(decision_tri->left_edge_cost_);
193       right_edge_tri->set_selected_cost(decision_tri->right_edge_cost_);
194       // 'left_node' recovers the strategy.
195       left_node->SetSelectedStrategyAndCost(decision_tri->left_node_strategy_, decision_tri->left_node_cost_);
196       if (triangle_star_overwrite) {
197         // 'right_node' recovers the strategy.
198         MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
199         right_node->SetSelectedStrategyAndCost(decision_tri->right_node_strategy_, decision_tri->right_node_cost_);
200       } else {
201         // In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency.
202         right_node->CheckSelectedStrategy(decision_tri->right_node_strategy_);
203       }
204       MS_LOG(INFO) << "Recover triangleElimination succeeded.";
205     } else if ((*rit)->isa<StarElimination>()) {
206       auto elimination_star = (*rit)->cast<StarEliminationPtr>();
207       auto merged_node_star = elimination_star->eliminated_node_;
208       auto succ_edges = elimination_star->succ_edges_;
209       auto succ_nodes = elimination_star->succ_ops_;
210       // decision is hidden in succ_nodes[0]
211       auto decision_star = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
212       merged_node_star->SetSelectedStrategyAndCost(decision_star->eliminated_op_strategy_,
213                                                    decision_star->eliminated_op_cost_);
214       for (size_t i = 0; i < succ_edges.size(); ++i) {
215         succ_edges[i]->set_selected_cost(decision_star->succ_edges_cost_list_[i]);
216       }
217       MS_EXCEPTION_IF_NULL(succ_nodes[0]);
218       MS_EXCEPTION_IF_NULL(decision_star->succ_ops_stra_list_[0]);
219       MS_EXCEPTION_IF_NULL(decision_star->succ_ops_cost_list_[0]);
220       // Star is eliminated into 'succ_nodes[0]'
221       succ_nodes[0]->SetSelectedStrategyAndCost(decision_star->succ_ops_stra_list_[0],
222                                                 decision_star->succ_ops_cost_list_[0]);
223       for (size_t k = 1; k < succ_nodes.size(); ++k) {
224         if (triangle_star_overwrite) {
225           // 'succ_nodes[k]' is overwritten strategy and cost.
226           succ_nodes[k]->SetSelectedStrategyAndCost(decision_star->succ_ops_stra_list_[k],
227                                                     decision_star->succ_ops_cost_list_[k]);
228         } else {
229           // In this case, 'succ_nodes[k]' is NOT overwritten strategy and cost, however, it checks the strategy.
230           succ_nodes[k]->CheckSelectedStrategy(decision_star->succ_ops_stra_list_[k]);
231         }
232       }
233       MS_LOG(INFO) << "Recover starElimination succeeded.";
234     } else {
235       MS_LOG(ERROR) << "Unknown Elimination type.";
236       return FAILED;
237     }
238   }
239 
240   return SUCCESS;
241 }
242 }  // namespace parallel
243 }  // namespace mindspore
244