1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/auto_parallel/dp_algo_costmodel.h"
18
19 #include <memory>
20 #include <utility>
21 #include <vector>
22
23 namespace mindspore {
24 namespace parallel {
GetStrategy(const CostGraphPtr & graph)25 Status GetStrategy(const CostGraphPtr &graph) {
26 MS_LOG(INFO) << "Searching strategies begins.";
27 MS_EXCEPTION_IF_NULL(graph);
28 std::vector<EliminationPtr> eliminations;
29 bool flag = true;
30
31 // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order.
32 // Note: the checking and applying of the 6 operations MUST in current order.
33 while (flag) {
34 flag = false;
35 auto node = graph->CheckOpElimination();
36 if (node != nullptr) {
37 // Applying the Operator Elimination
38 flag = true;
39 auto l_edge = node->GetAlivePrevEdges()[0];
40 auto r_edge = node->GetAliveSuccEdges()[0];
41 auto n_edge = graph->EliminationOp(node);
42 auto elimi_op = std::make_shared<OpElimination>(n_edge, l_edge, node, r_edge);
43 (void)eliminations.emplace_back(std::move(elimi_op));
44 }
45 if (!flag) {
46 auto edges = graph->CheckEdgeElimination();
47 if (!edges.empty()) {
48 // Applying the Edge Elimination
49 flag = true;
50 auto new_edge = graph->EliminationEdges(edges);
51 auto elimi_edge = std::make_shared<EdgeElimination>(new_edge, edges);
52 (void)eliminations.emplace_back(std::move(elimi_edge));
53 }
54 }
55 if (!flag) {
56 auto merge_node = graph->CheckMergeElimination();
57 if (merge_node != nullptr) {
58 // Applying the Merge Elimination
59 flag = true;
60 auto succ_edge = merge_node->GetAliveSuccEdges()[0];
61 auto target_node = graph->EliminationMerge(merge_node);
62 auto elimi_merge = std::make_shared<MergeElimination>(merge_node, succ_edge, target_node);
63 (void)eliminations.emplace_back(std::move(elimi_merge));
64 }
65 }
66 if (!flag) {
67 auto contracted_node = graph->CheckContractElimination();
68 if ((contracted_node != nullptr)) {
69 // Applying the Contract Elimination
70 flag = true;
71 auto prev_edge = contracted_node->GetAlivePrevEdges()[0];
72 auto target_node = graph->EliminationContract(contracted_node);
73 auto elimi_contract = std::make_shared<ContractElimination>(target_node, prev_edge, contracted_node);
74 (void)eliminations.emplace_back(std::move(elimi_contract));
75 }
76 }
77 if (!flag) {
78 auto triangle_pair = graph->CheckTriangleElimination();
79 if (triangle_pair.first != nullptr) {
80 // Applying the Triangle Elimination
81 flag = true;
82 auto eliminated_node = triangle_pair.first;
83 auto l_r_edge = triangle_pair.second;
84
85 auto left_node = l_r_edge->prev_operator();
86 auto left_edge = eliminated_node->GetAliveSuccEdges()[0];
87 auto right_edge = eliminated_node->GetAliveSuccEdges()[1];
88 MS_EXCEPTION_IF_NULL(left_edge);
89 if (left_edge->next_operator() != left_node) {
90 auto tmp = left_edge;
91 left_edge = right_edge;
92 right_edge = tmp;
93 }
94 auto left_node_cpy = graph->EliminationTriangle(eliminated_node, l_r_edge);
95 auto right_node = l_r_edge->next_operator();
96 auto elimi_tri =
97 std::make_shared<TriangleElimination>(eliminated_node, left_edge, left_node_cpy, right_edge, right_node);
98 (void)eliminations.emplace_back(std::move(elimi_tri));
99 }
100 }
101 if (!flag) {
102 auto star_center = graph->CheckStarElimination();
103 if (star_center != nullptr) {
104 // Applying the Star Elimination
105 flag = true;
106 auto succ_edges = graph->EliminationStar(star_center);
107 std::vector<OperatorInfoPtr> succ_nodes;
108 for (size_t i = 0; i < succ_edges.size(); ++i) {
109 MS_EXCEPTION_IF_NULL(succ_edges[i]);
110 succ_nodes.push_back(succ_edges[i]->next_operator());
111 }
112 auto elimi_star = std::make_shared<StarElimination>(star_center, succ_edges, succ_nodes);
113 (void)eliminations.emplace_back(std::move(elimi_star));
114 }
115 }
116 }
117
118 // Phase 2: Search the cost_list in the final graph, and determine the optimal one
119 if (graph->SearchStrategy() != SUCCESS) {
120 MS_LOG(ERROR) << "Searching strategy for the final failed.";
121 return FAILED;
122 }
123
124 // Phase 3: Recover the original CostGraph, the determine strategy for each operator
125 if (RecoverStrategy(eliminations) == SUCCESS) {
126 MS_LOG(INFO) << "Searching strategies ends.";
127 return SUCCESS;
128 } else {
129 MS_LOG(EXCEPTION) << "Searching strategies failed.";
130 }
131 }
132
RecoverStrategy(std::vector<EliminationPtr> eliminations)133 Status RecoverStrategy(std::vector<EliminationPtr> eliminations) {
134 std::vector<EliminationPtr>::reverse_iterator rit;
135 const auto triangle_star_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
136 for (rit = eliminations.rbegin(); rit != eliminations.rend(); ++rit) {
137 if ((*rit)->isa<OpElimination>()) {
138 auto elimination_op = (*rit)->cast<OpEliminationPtr>();
139 auto e = elimination_op->new_edge_;
140 auto w = elimination_op->op_;
141 auto left_edge_op = elimination_op->left_edge_;
142 auto right_edge_op = elimination_op->right_edge_;
143 auto decision_op = e->selected_cost()->decision_ptr_->cast<OpEliminationDecisionPtr>();
144 w->SetSelectedStrategyAndCost(decision_op->op_strategy_, decision_op->middle_cost_);
145 left_edge_op->set_selected_cost(decision_op->left_cost_);
146 right_edge_op->set_selected_cost(decision_op->right_cost_);
147 MS_LOG(INFO) << "Recover opElimination succeeded.";
148 } else if ((*rit)->isa<EdgeElimination>()) {
149 auto elimination_edge = (*rit)->cast<EdgeEliminationPtr>();
150 auto new_edge = elimination_edge->new_edge_;
151 auto &edges = elimination_edge->edges_;
152 auto decision_edge = new_edge->selected_cost()->decision_ptr_->cast<EdgeEliminationDecisionPtr>();
153 for (size_t j = 0; j < edges.size(); ++j) {
154 MS_EXCEPTION_IF_NULL(edges[j]);
155 edges[j]->set_selected_cost(decision_edge->edges_cost_list_[j]);
156 }
157 MS_LOG(INFO) << "Recover edgeElimination succeeded.";
158 } else if ((*rit)->isa<MergeElimination>()) {
159 auto elimination_merge = (*rit)->cast<MergeEliminationPtr>();
160 auto target_node_m = elimination_merge->target_node_;
161 auto merged_node = elimination_merge->merged_node_;
162 auto merged_edge = elimination_merge->dir_edge_;
163 MS_EXCEPTION_IF_NULL(target_node_m->selected_cost());
164 MS_EXCEPTION_IF_NULL(target_node_m->selected_cost()->decision_ptr_);
165 auto decision = target_node_m->selected_cost()->decision_ptr_->cast<MergeEliminationDecisionPtr>();
166 merged_node->SetSelectedStrategyAndCost(decision->merged_op_strategy_, decision->merged_op_cost_);
167 merged_edge->set_selected_cost(decision->edge_cost_);
168 target_node_m->SetSelectedStrategyAndCost(decision->target_op_strategy_, decision->target_op_cost_);
169 MS_LOG(INFO) << "Recover mergeElimination succeeded.";
170 } else if ((*rit)->isa<ContractElimination>()) {
171 auto elimination_cont = (*rit)->cast<ContractEliminationPtr>();
172 auto target_node = elimination_cont->target_node_;
173 auto contracted_node = elimination_cont->contracted_node_;
174 auto contracted_edge = elimination_cont->dir_edge_;
175 auto decision_cont = target_node->selected_cost()->decision_ptr_->cast<ContractEliminationDecisionPtr>();
176 contracted_node->SetSelectedStrategyAndCost(decision_cont->contracted_op_strategy_,
177 decision_cont->contracted_op_cost_);
178 contracted_edge->set_selected_cost(decision_cont->edge_cost_);
179 target_node->SetSelectedStrategyAndCost(decision_cont->target_op_strategy_, decision_cont->target_cost_);
180 MS_LOG(INFO) << "Recover contractElimination succeeded.";
181 } else if ((*rit)->isa<TriangleElimination>()) {
182 auto elimination_tri = (*rit)->cast<TriangleEliminationPtr>();
183 auto left_node = elimination_tri->left_node_;
184 auto left_edge = elimination_tri->left_edge_;
185 auto eliminated_node = elimination_tri->eliminated_node_;
186 auto right_edge_tri = elimination_tri->right_edge_;
187 auto right_node = elimination_tri->right_node_;
188 auto decision_tri = left_node->selected_cost()->decision_ptr_->cast<TriangleEliminationDecisionPtr>();
189
190 eliminated_node->SetSelectedStrategyAndCost(decision_tri->eliminated_op_strategy_,
191 decision_tri->eliminated_op_cost_);
192 left_edge->set_selected_cost(decision_tri->left_edge_cost_);
193 right_edge_tri->set_selected_cost(decision_tri->right_edge_cost_);
194 // 'left_node' recovers the strategy.
195 left_node->SetSelectedStrategyAndCost(decision_tri->left_node_strategy_, decision_tri->left_node_cost_);
196 if (triangle_star_overwrite) {
197 // 'right_node' recovers the strategy.
198 MS_LOG(INFO) << "Overwrite the right-node: " << right_node->name() << " in recovering triangle elimination.";
199 right_node->SetSelectedStrategyAndCost(decision_tri->right_node_strategy_, decision_tri->right_node_cost_);
200 } else {
201 // In this case, 'right_node' is not overwritten strategy, and it checks strategy consistency.
202 right_node->CheckSelectedStrategy(decision_tri->right_node_strategy_);
203 }
204 MS_LOG(INFO) << "Recover triangleElimination succeeded.";
205 } else if ((*rit)->isa<StarElimination>()) {
206 auto elimination_star = (*rit)->cast<StarEliminationPtr>();
207 auto merged_node_star = elimination_star->eliminated_node_;
208 auto succ_edges = elimination_star->succ_edges_;
209 auto succ_nodes = elimination_star->succ_ops_;
210 // decision is hidden in succ_nodes[0]
211 auto decision_star = succ_nodes[0]->selected_cost()->decision_ptr_->cast<StarEliminationDecisionPtr>();
212 merged_node_star->SetSelectedStrategyAndCost(decision_star->eliminated_op_strategy_,
213 decision_star->eliminated_op_cost_);
214 for (size_t i = 0; i < succ_edges.size(); ++i) {
215 succ_edges[i]->set_selected_cost(decision_star->succ_edges_cost_list_[i]);
216 }
217 MS_EXCEPTION_IF_NULL(succ_nodes[0]);
218 MS_EXCEPTION_IF_NULL(decision_star->succ_ops_stra_list_[0]);
219 MS_EXCEPTION_IF_NULL(decision_star->succ_ops_cost_list_[0]);
220 // Star is eliminated into 'succ_nodes[0]'
221 succ_nodes[0]->SetSelectedStrategyAndCost(decision_star->succ_ops_stra_list_[0],
222 decision_star->succ_ops_cost_list_[0]);
223 for (size_t k = 1; k < succ_nodes.size(); ++k) {
224 if (triangle_star_overwrite) {
225 // 'succ_nodes[k]' is overwritten strategy and cost.
226 succ_nodes[k]->SetSelectedStrategyAndCost(decision_star->succ_ops_stra_list_[k],
227 decision_star->succ_ops_cost_list_[k]);
228 } else {
229 // In this case, 'succ_nodes[k]' is NOT overwritten strategy and cost, however, it checks the strategy.
230 succ_nodes[k]->CheckSelectedStrategy(decision_star->succ_ops_stra_list_[k]);
231 }
232 }
233 MS_LOG(INFO) << "Recover starElimination succeeded.";
234 } else {
235 MS_LOG(ERROR) << "Unknown Elimination type.";
236 return FAILED;
237 }
238 }
239
240 return SUCCESS;
241 }
242 } // namespace parallel
243 } // namespace mindspore
244