1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <cfloat>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include <vector>
23
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "frontend/parallel/ops_info/reshape_info.h"
26 #include "frontend/parallel/step_auto_parallel.h"
27
28 namespace mindspore {
29 namespace parallel {
// Global cost graph object; nullptr until it is assigned.
// NOTE(review): presumably built and owned by the auto-parallel pass (step_auto_parallel.h) -- confirm.
CostGraphPtr entire_costgraph = nullptr;
// Global operator counter, initialized to 0.
// NOTE(review): meaning inferred from the name only; it is not read or written in the code shown here.
size_t TOTAL_OPS = 0;
32
Init()33 void CostGraph::Init() {
34 inputs_tensor_name_list_.clear();
35 tuple_getitem_list_.clear();
36 ops_.clear();
37 edges_.clear();
38 connected_compoents_.clear();
39 out_edges_.clear();
40 in_edges_.clear();
41 }
42
RemoveOperator(const OperatorInfoPtr & op)43 void CostGraph::RemoveOperator(const OperatorInfoPtr &op) {
44 for (auto it = ops_.begin(); it != ops_.end();) {
45 if ((*it) == op) {
46 it = ops_.erase(it);
47 } else {
48 ++it;
49 }
50 }
51 }
52
IsOperatorInCostGraph(const OperatorInfoPtr & op_test)53 bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) {
54 struct IsInGraph {
55 const OperatorInfoPtr test_;
56 explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {}
57 bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); }
58 };
59 return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test));
60 }
61
AddEdge(OperatorInfoPtr u_node,OperatorInfoPtr v_node,const EdgePtr & edge)62 void CostGraph::AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) {
63 std::vector<EdgePtr> curr_edges(edges_[{u_node, v_node}]);
64 curr_edges.push_back(edge);
65 edges_[{u_node, v_node}] = curr_edges;
66
67 std::vector<EdgePtr> curr_out_edges(out_edges_[u_node]);
68 curr_out_edges.push_back(edge);
69 out_edges_[u_node] = curr_out_edges;
70
71 std::vector<EdgePtr> curr_in_edges(in_edges_[v_node]);
72 curr_in_edges.push_back(edge);
73 in_edges_[v_node] = curr_in_edges;
74 }
75
IsEdgeInCostGraph(const std::string & test_edge_name,size_t output_index,size_t input_index)76 bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) {
77 for (auto &edge_pair : edges_) {
78 auto edges = edge_pair.second;
79 for (auto &edge : edges) {
80 MS_EXCEPTION_IF_NULL(edge);
81 bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) &&
82 (edge->next_op_input_index() == input_index);
83 if (bool_result) {
84 return true;
85 }
86 }
87 }
88 return false;
89 }
90
StrategyPropagate(const std::map<OperatorInfoPtr,StrategyPtr> & ops_stras)91 void CostGraph::StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &ops_stras) {
92 if (ops_stras.empty()) {
93 MS_LOG(EXCEPTION) << "There is no operator that is configured sharding strategy.";
94 }
95 std::map<OperatorInfoPtr, bool> visited;
96 for (auto &op : ops_) {
97 visited[op] = false;
98 }
99 for (auto &op_stra : ops_stras) {
100 BFS(op_stra.first, op_stra.second, ops_stras, &visited);
101 }
102 }
103
CheckShardingConsisitency(std::map<OperatorInfoPtr,StrategyPtr> configured_ops,const OperatorInfoPtr & curr_op,const OperatorInfoPtr & another_op,const CostPtr & cost,const EdgePtr & edge)104 void CheckShardingConsisitency(std::map<OperatorInfoPtr, StrategyPtr> configured_ops, const OperatorInfoPtr &curr_op,
105 const OperatorInfoPtr &another_op, const CostPtr &cost, const EdgePtr &edge) {
106 if ((configured_ops.find(another_op) == configured_ops.end()) &&
107 (cost == nullptr || cost->communication_cost_ != 0.0)) {
108 PrintStrategy(another_op->selected_strategy());
109 PrintStrategy(curr_op->selected_strategy());
110 MS_LOG(EXCEPTION) << "There are redistribution cost occurs at edge: " << edge->edge_name()
111 << ", consider configuring sharding strategies for two operators."
112 << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
113 << " and " << another_op->cnode()->fullname_with_scope();
114 }
115 }
116
// Breadth-first propagation of sharding strategies, starting at `op` with strategy `op_stra`.
//
// Each dequeued operator is assigned its strategy via SetSelectedStrategy(strategy, depth), where
// depth is the BFS level at which it was reached (0 for `op`). Strategies for unvisited neighbours
// are derived from the connecting edge so that the edge has zero communication cost. Operators
// listed in `configured_ops` are never enqueued. When such an operator is reached from a non-root
// operator, its configured strategy must equal the derived one; otherwise an exception is raised.
// Already-visited neighbours are only checked with CheckShardingConsisitency.
//
// @param op             operator to start from.
// @param op_stra        strategy assigned to `op`.
// @param configured_ops operators with user-configured strategies (taken by value, so each call copies the map).
// @param visited        visited flags shared across calls; must contain an entry for every operator
//                       reached, since std::map::at throws otherwise.
// NOTE(review): an operator is marked visited only when it is dequeued. It can therefore be enqueued, and
// have its strategy set, more than once if it is reached through several paths before being processed.
void CostGraph::BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
                    std::map<OperatorInfoPtr, StrategyPtr> configured_ops, std::map<OperatorInfoPtr, bool> *visited) {
  // Queue entries are ((operator, strategy to assign), BFS depth).
  std::queue<std::pair<std::pair<OperatorInfoPtr, StrategyPtr>, size_t>> next_level;
  (void)next_level.emplace(std::make_pair(op, op_stra), 0);
  while (!next_level.empty()) {
    // Copies (not references) of the front entry, so the pop() at the end of the loop body is safe.
    auto curr_op = next_level.front().first.first;
    auto configured_stra = next_level.front().first.second;
    auto curr_depth = next_level.front().second;
    visited->at(curr_op) = true;
    MS_LOG(INFO) << "curr_depth: " << curr_depth;
    curr_op->SetSelectedStrategy(configured_stra, curr_depth);
    for (auto &edge : curr_op->succ_edges()) {
      const auto &next_op = edge->next_operator();
      // Successor already processed: only verify that the edge needs no redistribution.
      if (visited->at(next_op)) {
        const auto cost = edge->GetCostByStrategyPair({curr_op->selected_strategy(), next_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, next_op, cost, edge);
        continue;
      }
      // A configured successor reached from a non-root operator must agree with the strategy that a
      // zero-communication edge implies.
      if ((curr_depth > 0) && (configured_ops.find(next_op) != configured_ops.end())) {
        const auto &next_op_conf_stra = configured_ops[next_op];
        const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((next_op_conf_stra == nullptr) || (!next_op_conf_stra->IsEqual(next_op_stra))) {
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << curr_op->name() << " and " << next_op->name() << "."
                            << " The full name of these two operators are: " << curr_op->cnode()->fullname_with_scope()
                            << " and " << next_op->cnode()->fullname_with_scope();
        }
      }
      // Configured operators are never enqueued.
      if (configured_ops.find(next_op) != configured_ops.end()) {
        continue;
      }
      // Derive the successor's strategy so that this edge has zero communication cost.
      const auto &next_op_stra = edge->GetNextOpStrategyByPrevOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (next_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << next_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      }
      (void)next_level.emplace(std::make_pair(next_op, next_op_stra), curr_depth + 1);
    }
    // Same procedure for predecessors, with the edge direction reversed.
    for (auto &edge : curr_op->prev_edges()) {
      const auto &prev_op = edge->prev_operator();
      if (visited->at(prev_op)) {
        const auto cost = edge->GetCostByStrategyPair({prev_op->selected_strategy(), curr_op->selected_strategy()});
        CheckShardingConsisitency(configured_ops, curr_op, prev_op, cost, edge);
        continue;
      }
      if ((curr_depth > 0) && (configured_ops.find(prev_op) != configured_ops.end())) {
        const auto &prev_op_conf_stra = configured_ops[prev_op];
        const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
        if ((prev_op_conf_stra == nullptr) || (!prev_op_conf_stra->IsEqual(prev_op_stra))) {
          MS_LOG(ERROR) << "curr_depth: " << curr_depth;
          MS_LOG(EXCEPTION) << "Sharding strategies should be configured on the boundary operators. "
                            << "Currently reaching " << prev_op->name() << " and " << curr_op->name() << "."
                            << " The full name of these two operators are: " << prev_op->cnode()->fullname_with_scope()
                            << " and " << curr_op->cnode()->fullname_with_scope();
        }
      }
      if (configured_ops.find(prev_op) != configured_ops.end()) {
        continue;
      }
      const auto &prev_op_stra = edge->GetPrevOpStrategyByNextOpStrategyWithZeroComm(curr_op->selected_strategy());
      if (prev_op_stra == nullptr) {
        PrintStrategy(curr_op->selected_strategy());
        MS_LOG(EXCEPTION) << prev_op->name() << "'s strategy is null in the edge: " << edge->edge_name();
      }
      (void)next_level.emplace(std::make_pair(prev_op, prev_op_stra), curr_depth + 1);
    }
    next_level.pop();
  }
}
186
ConstructConnectedComponents(std::vector<OperatorInfoPtr> alive_ops)187 std::vector<std::shared_ptr<CostGraph>> CostGraph::ConstructConnectedComponents(
188 std::vector<OperatorInfoPtr> alive_ops) {
189 std::map<OperatorInfoPtr, bool> visited;
190
191 for (auto &op : alive_ops) {
192 visited[op] = false;
193 }
194
195 MS_LOG(INFO) << "visited: " << visited.size() << ".";
196 for (auto &op : alive_ops) {
197 if ((!visited[op]) && op->is_alive()) {
198 std::shared_ptr<CostGraph> new_component = std::make_shared<CostGraph>();
199 MS_EXCEPTION_IF_NULL(new_component);
200 DFS(op, &visited, new_component);
201 connected_compoents_.push_back(new_component);
202 }
203 }
204 return connected_compoents_;
205 }
206
DFS(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,const std::shared_ptr<CostGraph> & component)207 void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
208 const std::shared_ptr<CostGraph> &component) {
209 MS_EXCEPTION_IF_NULL(visited);
210 MS_EXCEPTION_IF_NULL(component);
211 visited->at(current_op) = true;
212 component->AddOperator(current_op);
213
214 for (auto &edge : current_op->succ_edges()) {
215 bool bool_test = (visited->find(edge->next_operator()) != visited->end()) &&
216 (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive();
217 if (bool_test) {
218 component->AddEdge(current_op, edge->next_operator(), edge);
219 DFS(edge->next_operator(), visited, component);
220 }
221 }
222
223 for (auto &edge : current_op->prev_edges()) {
224 bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) &&
225 (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive();
226 if (bool_test) {
227 component->AddEdge(edge->prev_operator(), current_op, edge);
228 DFS(edge->prev_operator(), visited, component);
229 }
230 }
231 }
232
233 // Create final cost list for the graph: u --> v
CreateFinalCostList(const OperatorInfoPtr & u,const std::shared_ptr<Edge> & e,const OperatorInfoPtr & v)234 CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr<Edge> &e,
235 const OperatorInfoPtr &v) {
236 MS_EXCEPTION_IF_NULL(u);
237 MS_EXCEPTION_IF_NULL(v);
238 MS_EXCEPTION_IF_NULL(e);
239 CostPtrList ret;
240 for (const auto &u_strategy : u->GetStrategyCost()) {
241 for (const auto &v_strategy : v->GetStrategyCost()) {
242 MS_EXCEPTION_IF_NULL(u_strategy);
243 MS_EXCEPTION_IF_NULL(v_strategy);
244 auto u_strategy_ptr = u_strategy->strategy_ptr;
245 auto v_strategy_ptr = v_strategy->strategy_ptr;
246 CostPtrList clist1 = u_strategy->cost_list;
247 CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr);
248 CostPtrList clist3 = v_strategy->cost_list;
249 for (const auto &cost1 : clist1) {
250 for (const auto &cost2 : clist2) {
251 for (const auto &cost3 : clist3) {
252 MS_EXCEPTION_IF_NULL(cost1);
253 MS_EXCEPTION_IF_NULL(cost2);
254 MS_EXCEPTION_IF_NULL(cost3);
255 double computation = cost1->computation_cost_ + cost2->computation_cost_ + cost3->computation_cost_;
256 double memory = cost1->memory_with_reuse_ + cost2->memory_with_reuse_ + cost3->memory_with_reuse_;
257 double communication = cost1->communication_cost_ + cost2->communication_cost_ + cost3->communication_cost_;
258 double communication_forward =
259 cost1->communication_forward_ + cost2->communication_forward_ + cost3->communication_forward_;
260 double communication_without_para = cost1->communication_without_parameter_ +
261 cost2->communication_without_parameter_ +
262 cost3->communication_without_parameter_;
263 auto decision =
264 std::make_shared<FinalDecision>(u_strategy->strategy_ptr, v_strategy->strategy_ptr, cost1, cost2, cost3);
265 auto cost = std::make_shared<Cost>(computation, communication, decision);
266 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
267 MS_EXCEPTION_IF_NULL(cost);
268 cost->communication_without_parameter_ = communication_without_para;
269 cost->communication_with_partial_para_ =
270 communication_without_para + gamma * (communication - communication_without_para);
271 cost->memory_with_reuse_ = memory;
272 cost->communication_forward_ = communication_forward;
273 ret.push_back(cost);
274 }
275 }
276 }
277 }
278 }
279
280 Simplify(&ret);
281 return ret;
282 }
283
284 // Create final cost list for the graph containing a single node: u
CreateFinalSingleCostList(const OperatorInfoPtr & u)285 CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) {
286 MS_EXCEPTION_IF_NULL(u);
287 CostPtrList ret;
288 for (const auto &u_strategy : u->GetStrategyCost()) {
289 MS_EXCEPTION_IF_NULL(u_strategy);
290 auto u_strategy_ptr = u_strategy->strategy_ptr;
291 CostPtrList clist1 = u_strategy->cost_list;
292 for (const auto &cost1 : clist1) {
293 MS_EXCEPTION_IF_NULL(cost1);
294 auto decision = std::make_shared<FinalSingleDecision>(u_strategy_ptr, cost1);
295 auto new_cost = std::make_shared<Cost>(cost1->computation_cost_, cost1->communication_cost_, decision);
296 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
297 MS_EXCEPTION_IF_NULL(new_cost);
298 new_cost->communication_without_parameter_ = cost1->communication_without_parameter_;
299 new_cost->communication_with_partial_para_ =
300 cost1->communication_without_parameter_ +
301 gamma * (cost1->communication_cost_ - cost1->communication_without_parameter_);
302 new_cost->memory_with_reuse_ = cost1->memory_with_reuse_;
303 new_cost->communication_forward_ = cost1->communication_forward_;
304 ret.push_back(new_cost);
305 }
306 }
307
308 Simplify(&ret);
309 return ret;
310 }
311
SelectCostWithMinInferenceTime(const CostPtrList & cost_list,double memory)312 CostPtr CostGraph::SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) {
313 // Select the cost with minimum inference time. Currently, the inference time is modeled as =
314 // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_forward_
315 if (cost_list.empty()) {
316 MS_LOG(ERROR) << "Final cost list is null.";
317 return nullptr;
318 }
319 CostPtrList after_mem_filter;
320 double minimum_memory = DBL_MAX;
321 // Filter out the valid costs.
322 for (auto &a_cost : cost_list) {
323 if (a_cost->memory_with_reuse_ <= memory) {
324 after_mem_filter.emplace_back(std::move(a_cost));
325 } else if (a_cost->memory_with_reuse_ < minimum_memory) {
326 minimum_memory = a_cost->memory_with_reuse_;
327 }
328 }
329 if (after_mem_filter.empty()) {
330 MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
331 << ", the memory capacity is: " << memory << ".";
332 return nullptr;
333 }
334 // Init the returned value with first cost.
335 CostPtr ret = after_mem_filter[0];
336 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
337 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
338
339 double minimum = alpha * ret->computation_cost_ + beta * ret->communication_forward_;
340 MS_LOG(INFO) << "Cost 0: "
341 << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
342 << ", communication_forward_: " << ret->communication_forward_
343 << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
344 << ", communication_cost_: " << ret->communication_cost_
345 << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
346 MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
347 for (size_t i = 1; i < after_mem_filter.size(); ++i) {
348 MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
349 MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
350 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
351 << ", communication_forward_: " << after_mem_filter[i]->communication_forward_
352 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
353 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
354 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
355 << ".";
356 auto tmp = alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_forward_;
357 MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
358 if (minimum > tmp) {
359 minimum = tmp;
360 ret = after_mem_filter[i];
361 MS_LOG(INFO) << "Selected: " << i;
362 }
363 }
364 return ret;
365 }
366
SelectCostWithMinTrainingTime(const CostPtrList & cost_list,double memory)367 CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) {
368 // Select the cost with minimum training time. Currently, the training time is modeled as =
369 // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_
370 if (cost_list.empty()) {
371 MS_LOG(ERROR) << "Final cost list is null.";
372 return nullptr;
373 }
374 CostPtrList after_mem_filter;
375 double minimum_memory = DBL_MAX;
376 // Filter out the valid costs.
377 for (auto &a_cost : cost_list) {
378 if (a_cost->memory_with_reuse_ <= memory) {
379 after_mem_filter.emplace_back(std::move(a_cost));
380 } else if (a_cost->memory_with_reuse_ < minimum_memory) {
381 minimum_memory = a_cost->memory_with_reuse_;
382 }
383 }
384 if (after_mem_filter.empty()) {
385 MS_LOG(ERROR) << "No available cost. The minimum memory cost is: " << minimum_memory
386 << ", the memory capacity is: " << memory << ".";
387 return nullptr;
388 }
389 // Init the returned value with first cost.
390 CostPtr ret = after_mem_filter[0];
391 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
392 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
393
394 double minimum = alpha * ret->computation_cost_ + beta * ret->communication_with_partial_para_;
395 MS_LOG(INFO) << "Cost 0: "
396 << "memory_cost: " << ret->memory_with_reuse_ << ", computation_cost_: " << ret->computation_cost_
397 << ", communication_with_partial_para_: " << ret->communication_with_partial_para_
398 << ", communication_cost_: " << ret->communication_cost_
399 << ", communication_without_parameter_: " << ret->communication_without_parameter_ << ".";
400 MS_LOG(INFO) << "Cost 0: total_cost: " << minimum;
401 for (size_t i = 1; i < after_mem_filter.size(); ++i) {
402 MS_EXCEPTION_IF_NULL(after_mem_filter[i]);
403 MS_LOG(INFO) << "Cost " << i << ": memory_cost: " << after_mem_filter[i]->memory_with_reuse_
404 << ", computation_cost_: " << after_mem_filter[i]->computation_cost_
405 << ", communication_with_partial_para_: " << after_mem_filter[i]->communication_with_partial_para_
406 << ", communication_cost_: " << after_mem_filter[i]->communication_cost_
407 << ", communication_without_parameter_: " << after_mem_filter[i]->communication_without_parameter_
408 << ".";
409 auto tmp =
410 alpha * after_mem_filter[i]->computation_cost_ + beta * after_mem_filter[i]->communication_with_partial_para_;
411 MS_LOG(INFO) << "Cost " << i << ": total_cost: " << tmp;
412 if (minimum > tmp) {
413 minimum = tmp;
414 ret = after_mem_filter[i];
415 MS_LOG(INFO) << "Selected: " << i;
416 }
417 }
418 return ret;
419 }
420
SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> & all_cost_list,double available_memory) const421 CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_cost_list,
422 double available_memory) const {
423 CostPtrList selected_cost_list(all_cost_list.size(), nullptr);
424 double minimum = DBL_MAX, total_memory = 0.0;
425 CostPtrList ret(all_cost_list.size(), nullptr);
426 // Check whether valid costs exist.
427 for (size_t i = 0; i < all_cost_list.size(); ++i) {
428 if (all_cost_list[i][0] == nullptr) {
429 MS_LOG(ERROR) << "The cost list " << i << " is empty.";
430 return ret;
431 } else {
432 double memory_i_cost = DBL_MAX;
433 for (size_t j = 0; j < all_cost_list[i].size(); ++j) {
434 if (all_cost_list[i][j]->memory_with_reuse_ < memory_i_cost) {
435 memory_i_cost = all_cost_list[i][j]->memory_with_reuse_;
436 }
437 }
438 total_memory += memory_i_cost;
439 }
440 }
441 if (total_memory >= available_memory) {
442 MS_LOG(ERROR) << "No strategy can be found under current memory: " << available_memory
443 << ", minimum strategy cost: " << total_memory << ".";
444 return selected_cost_list;
445 }
446
447 std::function<void(size_t)> recursive = [&all_cost_list, &selected_cost_list, &minimum, &ret, &recursive,
448 &available_memory](size_t k) {
449 const auto alpha = CostModelContext::GetInstance()->costmodel_alpha();
450 const auto beta = CostModelContext::GetInstance()->costmodel_beta();
451 if (k == all_cost_list.size()) {
452 double tmp_memory = 0.0, tmp_minimum = 0.0;
453 for (size_t i = 0; i < selected_cost_list.size(); ++i) {
454 MS_EXCEPTION_IF_NULL(selected_cost_list[i]);
455 tmp_memory += selected_cost_list[i]->memory_with_reuse_;
456 tmp_minimum += alpha * selected_cost_list[i]->computation_cost_ +
457 beta * selected_cost_list[i]->communication_with_partial_para_;
458 }
459 MS_LOG(INFO) << "tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ", minimum: " << minimum
460 << ".";
461 if (tmp_memory < available_memory && tmp_minimum < minimum) {
462 ret = selected_cost_list;
463 minimum = tmp_minimum;
464 MS_LOG(INFO) << "selected tmp_memory: " << tmp_memory << ", tmp_minimum: " << tmp_minimum << ".";
465 }
466 return;
467 }
468
469 MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << ".";
470 for (auto &c : all_cost_list[k]) {
471 selected_cost_list[k] = c;
472 recursive(k + 1);
473 }
474 };
475 recursive(0);
476 return ret;
477 }
478
SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)479 Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
480 MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph.";
481 auto connected_components = ConstructConnectedComponents(alive_ops);
482 MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
483 std::vector<CostPtrList> all_list;
484 for (size_t j = 0; j < connected_components.size(); ++j) {
485 auto one_component = connected_components[j];
486 MS_EXCEPTION_IF_NULL(one_component);
487 if (one_component->GetOperators().size() == 1) {
488 MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
489 auto cost_list_1 = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
490 all_list.push_back(cost_list_1);
491 } else if (one_component->GetOperators().size() == 2) {
492 MS_LOG(INFO) << "There are 2 operators in a component in the final graph.";
493 OperatorInfoPtr u, v;
494 auto first_op = one_component->GetOperators()[0];
495 auto second_op = one_component->GetOperators()[1];
496 MS_EXCEPTION_IF_NULL(first_op);
497 MS_EXCEPTION_IF_NULL(second_op);
498 if (!first_op->GetAliveSuccEdges().empty() &&
499 first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
500 u = first_op;
501 v = second_op;
502 } else if (!second_op->GetAliveSuccEdges().empty() &&
503 second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
504 u = second_op;
505 v = first_op;
506 } else {
507 MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << first_op->GetAliveSuccEdges().size()
508 << ", " << second_op->GetAliveSuccEdges().size() << ".";
509 }
510 MS_EXCEPTION_IF_NULL(u);
511 auto e = u->GetAliveSuccEdges()[0];
512 auto cost_list = one_component->CreateFinalCostList(u, e, v);
513 all_list.push_back(cost_list);
514 } else {
515 MS_LOG(EXCEPTION) << "There are " << one_component->GetOperators().size()
516 << " operators in a component in the final graph.";
517 }
518 }
519 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
520 auto selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
521 for (size_t k = 0; k < selected_cost_list.size(); ++k) {
522 auto selected_cost = selected_cost_list[k];
523 if (selected_cost == nullptr) {
524 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
525 return FAILED;
526 }
527 MS_EXCEPTION_IF_NULL(connected_components[k]);
528 if (connected_components[k]->GetOperators().size() == 1) {
529 auto u = connected_components[k]->GetOperators()[0];
530 auto decision_f = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
531 u->SetSelectedStrategyAndCost(decision_f->u_strategy_, decision_f->u_cost_);
532 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
533 } else if (connected_components[k]->GetOperators().size() == 2) {
534 OperatorInfoPtr u = nullptr, v = nullptr;
535 auto first_op = connected_components[k]->GetOperators()[0];
536 auto second_op = connected_components[k]->GetOperators()[1];
537 MS_EXCEPTION_IF_NULL(first_op);
538 MS_EXCEPTION_IF_NULL(second_op);
539 if (!first_op->GetAliveSuccEdges().empty() &&
540 first_op->GetAliveSuccEdges()[0]->next_operator().get() == second_op.get()) {
541 u = first_op;
542 v = second_op;
543 } else if (!second_op->GetAliveSuccEdges().empty() &&
544 second_op->GetAliveSuccEdges()[0]->next_operator().get() == first_op.get()) {
545 u = second_op;
546 v = first_op;
547 }
548 MS_EXCEPTION_IF_NULL(u);
549 auto e = u->GetAliveSuccEdges()[0];
550 MS_EXCEPTION_IF_NULL(v);
551 MS_EXCEPTION_IF_NULL(e);
552 MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
553 auto decision = selected_cost->decision_ptr_->cast<FinalDecisionPtr>();
554 MS_EXCEPTION_IF_NULL(decision);
555 u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->left_cost_);
556 v->SetSelectedStrategyAndCost(decision->v_strategy_, decision->right_cost_);
557 e->set_selected_cost(decision->middle_cost_);
558 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
559 }
560 }
561 return SUCCESS;
562 }
563
SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> & alive_ops)564 Status CostGraph::SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &alive_ops) {
565 // In this case, the final graph should contains exactly 2 nodes.
566 if (alive_ops.empty()) {
567 MS_LOG(INFO) << "0 Operator in the final graph.";
568 return SUCCESS;
569 }
570 OperatorInfoPtr u, v;
571 MS_EXCEPTION_IF_NULL(alive_ops[0]);
572 MS_EXCEPTION_IF_NULL(alive_ops[1]);
573 const auto phase = CostModelContext::GetInstance()->run_phase();
574 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
575 if (!alive_ops[0]->GetAliveSuccEdges().empty() &&
576 alive_ops[0]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[1].get()) {
577 u = alive_ops[0];
578 v = alive_ops[1];
579 } else if (!alive_ops[1]->GetAliveSuccEdges().empty() &&
580 alive_ops[1]->GetAliveSuccEdges()[0]->next_operator().get() == alive_ops[0].get()) {
581 u = alive_ops[1];
582 v = alive_ops[0];
583 } else {
584 if (!alive_ops[0]->GetAliveSuccEdges().empty() || !alive_ops[1]->GetAliveSuccEdges().empty()) {
585 MS_LOG(EXCEPTION) << "The final graph is not the case of u --> v, " << alive_ops[0]->GetAliveSuccEdges().size()
586 << ", " << alive_ops[1]->GetAliveSuccEdges().size() << ".";
587 } else {
588 // In this case, the final graph consists of two single nodes
589 MS_LOG(INFO) << "There are 2 single nodes in the final graph.";
590 std::vector<CostPtrList> all_list;
591 auto connected_components = ConstructConnectedComponents(alive_ops);
592 MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph.";
593 for (size_t i = 0; i < connected_components.size(); ++i) {
594 MS_LOG(INFO) << "There are 1 operator in a component in the final graph.";
595 auto one_component = connected_components[i];
596 MS_EXCEPTION_IF_NULL(one_component);
597 auto cost_list = one_component->CreateFinalSingleCostList(one_component->GetOperators()[0]);
598 all_list.push_back(cost_list);
599 }
600 CostPtrList selected_cost_list;
601 if (phase == TRAINING_PHASE) {
602 // training phase
603 selected_cost_list = SelectCostListWithMinTrainingTimeMultiple(all_list, device_mem_capacity);
604 } else {
605 // inference phase
606 MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-separated-node final graph in the inference "
607 "phase is not supported.";
608 }
609 for (size_t k = 0; k < selected_cost_list.size(); ++k) {
610 auto selected_cost = selected_cost_list[k];
611 if (selected_cost == nullptr) {
612 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity
613 << ".";
614 return FAILED;
615 }
616 MS_EXCEPTION_IF_NULL(connected_components[k]);
617 auto one_operator = connected_components[k]->GetOperators()[0];
618 MS_EXCEPTION_IF_NULL(selected_cost->decision_ptr_);
619 auto decision = selected_cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
620 MS_EXCEPTION_IF_NULL(decision);
621 one_operator->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
622 MS_LOG(INFO) << "Searching the strategy for the component " << k << " final graph ended.";
623 }
624
625 return SUCCESS;
626 }
627 }
628 MS_LOG(INFO) << "There are 2 nodes in the final graph.";
629 // In this case, the finale graph is exactly of the form: u --> v
630 MS_EXCEPTION_IF_NULL(u);
631 MS_EXCEPTION_IF_NULL(v);
632 auto e = u->GetAliveSuccEdges()[0];
633 MS_EXCEPTION_IF_NULL(e);
634 auto f_cost_list = CreateFinalCostList(u, e, v);
635 CostPtr cost = nullptr;
636 if (phase == TRAINING_PHASE) {
637 // training phase
638 cost = SelectCostWithMinTrainingTime(f_cost_list, device_mem_capacity);
639 } else {
640 MS_LOG(EXCEPTION) << "Currently, searching strategy for the two-connected-node final graph in the inference "
641 "phase is not supported.";
642 }
643 if (cost == nullptr) {
644 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
645 return FAILED;
646 }
647 MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
648 auto f_decision = cost->decision_ptr_->cast<FinalDecisionPtr>();
649 MS_EXCEPTION_IF_NULL(f_decision);
650 u->SetSelectedStrategyAndCost(f_decision->u_strategy_, f_decision->left_cost_);
651 v->SetSelectedStrategyAndCost(f_decision->v_strategy_, f_decision->right_cost_);
652 e->set_selected_cost(f_decision->middle_cost_);
653 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
654 return SUCCESS;
655 }
656
657 // searching the strategy for the final eliminated graph
SearchStrategy()658 Status CostGraph::SearchStrategy() {
659 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph began.";
660 std::vector<OperatorInfoPtr> alive_ops;
661 (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) {
662 MS_EXCEPTION_IF_NULL(op);
663 if (op->is_alive()) {
664 alive_ops.push_back(op);
665 }
666 });
667 const auto phase = CostModelContext::GetInstance()->run_phase();
668 const auto device_mem_capacity = CostModelContext::GetInstance()->device_memory_capacity();
669
670 if (alive_ops.size() > 2) {
671 if (phase == TRAINING_PHASE) {
672 // training phase
673 return SearchStrategyForMultiNodeFinalGraph(alive_ops);
674 } else {
675 // inference phase
676 MS_LOG(EXCEPTION)
677 << "Currently, searching strategy for the multi-node final graph in inference phase is not supported.";
678 }
679 } else if (alive_ops.size() == 1) {
680 MS_LOG(INFO) << "There are 1 single node in the final graph.";
681 OperatorInfoPtr u = alive_ops[0];
682 auto cost_list = CreateFinalSingleCostList(u);
683 CostPtr cost = nullptr;
684 if (phase == TRAINING_PHASE) {
685 // training phase
686 cost = SelectCostWithMinTrainingTime(cost_list, device_mem_capacity);
687 } else {
688 // inference phase
689 cost = SelectCostWithMinInferenceTime(cost_list, device_mem_capacity);
690 }
691 if (cost == nullptr) {
692 MS_LOG(ERROR) << "No valid strategy can be found under the current device memory: " << device_mem_capacity << ".";
693 return FAILED;
694 }
695 MS_EXCEPTION_IF_NULL(u);
696 MS_EXCEPTION_IF_NULL(cost->decision_ptr_);
697 auto decision = cost->decision_ptr_->cast<FinalSingleDecisionPtr>();
698 MS_EXCEPTION_IF_NULL(decision);
699 u->SetSelectedStrategyAndCost(decision->u_strategy_, decision->u_cost_);
700 MS_LOG(INFO) << "Searching the strategy for the eliminated final graph ended.";
701 return SUCCESS;
702 } else {
703 return SearchStrategyForTwoNodeFinalGraph(alive_ops);
704 }
705 }
706
707 // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
708 // return the v and the edge u --> v
CheckOpElimination() const709 OperatorInfoPtr CostGraph::CheckOpElimination() const {
710 for (auto &op : ops_) {
711 bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1;
712 if (bool_test) {
713 if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) {
714 return op;
715 }
716 }
717 }
718 return nullptr;
719 }
720
721 // Check the graph whether an EdgeElimination can be performed
CheckEdgeElimination() const722 std::vector<std::shared_ptr<Edge>> CostGraph::CheckEdgeElimination() const {
723 for (auto &op : ops_) {
724 MS_EXCEPTION_IF_NULL(op);
725 if (!op->is_alive()) continue;
726 std::map<void *, int64_t> count;
727 for (auto &edge_su : op->GetAliveSuccEdges()) {
728 MS_EXCEPTION_IF_NULL(edge_su);
729 auto v = edge_su->next_operator();
730 count[v.get()]++;
731 }
732 for (auto &pair : count) {
733 auto *op_ptr = pair.first;
734 int64_t op_count = pair.second;
735 if (op_count > 1) {
736 std::vector<std::shared_ptr<Edge>> ret;
737 for (auto &edge : op->GetAliveSuccEdges()) {
738 MS_EXCEPTION_IF_NULL(edge);
739 if (edge->next_operator().get() == op_ptr) {
740 ret.push_back(edge);
741 }
742 }
743 return ret;
744 }
745 }
746 }
747 return {};
748 }
749
750 // Check the graph whether a MergeElimination can be performed
CheckMergeElimination() const751 OperatorInfoPtr CostGraph::CheckMergeElimination() const {
752 for (auto &op : ops_) {
753 MS_EXCEPTION_IF_NULL(op);
754 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1;
755 if (bool_test) {
756 auto next_op = op->GetAliveSuccEdges()[0]->next_operator();
757 MS_EXCEPTION_IF_NULL(next_op);
758 if (!next_op->GetAlivePrevEdges().empty()) {
759 return op;
760 }
761 }
762 }
763 return nullptr;
764 }
765
766 // Check the graph whether a ContractElimination can be performed
CheckContractElimination() const767 OperatorInfoPtr CostGraph::CheckContractElimination() const {
768 for (auto &op : ops_) {
769 MS_EXCEPTION_IF_NULL(op);
770 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty();
771 if (bool_test) {
772 auto edge = op->GetAlivePrevEdges()[0];
773 MS_EXCEPTION_IF_NULL(edge);
774 auto prev_op = edge->prev_operator();
775 MS_EXCEPTION_IF_NULL(prev_op);
776 if (!prev_op->GetAliveSuccEdges().empty()) {
777 return op;
778 }
779 }
780 }
781 return nullptr;
782 }
783
CheckSourceElimination() const784 std::pair<OperatorInfoPtr, OperatorInfoPtr> CostGraph::CheckSourceElimination() const {
785 size_t source_count = 0;
786 std::vector<OperatorInfoPtr> op_vector(2, nullptr);
787 for (auto &op : ops_) {
788 MS_EXCEPTION_IF_NULL(op);
789 bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() > 0;
790 if (bool_test) {
791 op_vector[source_count++] = op;
792 if (source_count == 2) {
793 return std::make_pair(op_vector[0], op_vector[1]);
794 }
795 }
796 }
797 return std::make_pair(nullptr, nullptr);
798 }
799
CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra,const CostPtrList & op1_old_clist,StrategyPtr op2_old_stra,const CostPtrList & op2_old_clist,CostPtrList * op1_new_clist)800 void CostGraph::CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
801 StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
802 CostPtrList *op1_new_clist) {
803 for (auto &op1_cost : op1_old_clist) {
804 for (auto &op2_cost : op2_old_clist) {
805 double computation = op1_cost->computation_cost_ + op2_cost->computation_cost_;
806 double memory = op1_cost->memory_with_reuse_ + op2_cost->memory_with_reuse_;
807 double communication = op1_cost->communication_cost_ + op2_cost->communication_cost_;
808 double communication_forward = op1_cost->communication_forward_ + op2_cost->communication_forward_;
809 double communication_without_para =
810 op1_cost->communication_without_parameter_ + op2_cost->communication_without_parameter_;
811 auto decision = std::make_shared<SourceEliminationDecision>(op1_old_stra, op1_cost, op2_old_stra, op2_cost);
812 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
813 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
814 MS_EXCEPTION_IF_NULL(new_cost);
815 new_cost->communication_without_parameter_ = communication_without_para;
816 new_cost->communication_with_partial_para_ =
817 communication_without_para + gamma * (communication - communication_without_para);
818 new_cost->memory_with_reuse_ = memory;
819 new_cost->communication_forward_ = communication_forward;
820 MS_EXCEPTION_IF_NULL(op1_new_clist);
821 op1_new_clist->emplace_back(std::move(new_cost));
822 }
823 }
824 }
825
UpdateEdgesIncidentToNodes(OperatorInfoPtr op1,std::vector<EdgePtr> * op1_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op1_new_edges_cost,std::vector<EdgePtr> * op1_new_succ_edges,const OperatorInfoPtr op2,std::vector<EdgePtr> * op2_old_succ_edges,std::vector<std::map<CostPtrKey,CostPtrList>> * op2_new_edges_cost,std::vector<EdgePtr> * op2_new_succ_edges)826 std::pair<std::vector<EdgePtr>, std::vector<EdgePtr>> UpdateEdgesIncidentToNodes(
827 OperatorInfoPtr op1, std::vector<EdgePtr> *op1_old_succ_edges,
828 std::vector<std::map<CostPtrKey, CostPtrList>> *op1_new_edges_cost, std::vector<EdgePtr> *op1_new_succ_edges,
829 const OperatorInfoPtr op2, std::vector<EdgePtr> *op2_old_succ_edges,
830 std::vector<std::map<CostPtrKey, CostPtrList>> *op2_new_edges_cost, std::vector<EdgePtr> *op2_new_succ_edges) {
831 for (size_t i = 0; i < op1_old_succ_edges->size(); ++i) {
832 auto &new_cost_map = op1_new_edges_cost->at(i);
833 auto ith_edge = op1_old_succ_edges->at(i);
834
835 std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + ith_edge->next_operator()->name();
836 std::shared_ptr<Edge> new_edge;
837 if (ith_edge->is_combined()) {
838 std::vector<size_t> output_indexs, input_indexs;
839 output_indexs = ith_edge->prev_op_output_indexs();
840 input_indexs = ith_edge->next_op_input_indexs();
841 new_edge =
842 std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_indexs, input_indexs, true);
843 } else {
844 size_t output_index, input_index;
845 output_index = ith_edge->prev_op_output_index();
846 input_index = ith_edge->next_op_input_index();
847 new_edge =
848 std::make_shared<Edge>(new_edge_name, op1, ith_edge->next_operator(), output_index, input_index, false);
849 }
850 new_edge->SetCostMapAndInputOutput(new_cost_map);
851 // replace the old successive edges with the new ones.
852 op1->ReplaceSuccEdge(ith_edge->next_operator(), new_edge);
853 ith_edge->next_operator()->ReplacePreEdge(op1, new_edge);
854 (void)op1_new_succ_edges->erase(op1_new_succ_edges->begin() + SizeToLong(i));
855 (void)op1_new_succ_edges->emplace(op1_new_succ_edges->begin() + SizeToLong(i), new_edge);
856 }
857 for (size_t i = 0; i < op2_old_succ_edges->size(); ++i) {
858 auto &new_cost_map = op2_new_edges_cost->at(i);
859 auto ith_edge = op2_old_succ_edges->at(i);
860 const auto &destination = ith_edge->next_operator();
861
862 std::string new_edge_name = op1->name() + OPERATOR_TO_OPERATOR_CONNECTOR + destination->name();
863 std::shared_ptr<Edge> new_edge;
864 if (ith_edge->is_combined()) {
865 std::vector<size_t> output_indexs, input_indexs;
866 output_indexs = ith_edge->prev_op_output_indexs();
867 input_indexs = ith_edge->next_op_input_indexs();
868 new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_indexs, input_indexs, true);
869 } else {
870 size_t output_index, input_index;
871 output_index = ith_edge->prev_op_output_index();
872 input_index = ith_edge->next_op_input_index();
873 new_edge = std::make_shared<Edge>(new_edge_name, op1, destination, output_index, input_index, false);
874 }
875 new_edge->SetCostMapAndInputOutput(new_cost_map);
876 // replace the old successive edges with the new ones.
877 destination->ReplacePreEdge(op2, new_edge);
878 op1->AddSuccEdge(new_edge);
879 (void)op2_new_succ_edges->erase(op2_new_succ_edges->begin() + SizeToLong(i));
880 (void)op2_new_succ_edges->emplace(op2_new_succ_edges->begin() + SizeToLong(i), new_edge);
881 }
882 return std::make_pair(*op1_new_succ_edges, *op2_new_succ_edges);
883 }
884
EliminationSources(const OperatorInfoPtr op1,const OperatorInfoPtr op2)885 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> CostGraph::EliminationSources(
886 const OperatorInfoPtr op1, const OperatorInfoPtr op2) {
887 MS_EXCEPTION_IF_NULL(op1);
888 MS_EXCEPTION_IF_NULL(op2);
889 MS_LOG(INFO) << "Now source eliminating node: " << op2->name() << " to node: " << op1->name();
890
891 auto op1_old_succ_edges = op1->GetAliveSuccEdges();
892 std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op1_edges_reorganised_cost(
893 op1_old_succ_edges.size());
894 std::vector<std::map<CostPtrKey, CostPtrList>> op1_new_edges_cost(op1_old_succ_edges.size());
895 std::vector<std::shared_ptr<Edge>> op1_new_succ_edges(op1_old_succ_edges.size());
896
897 auto op2_old_succ_edges = op2->GetAliveSuccEdges();
898 std::vector<std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>>> op2_edges_reorganised_cost(
899 op2_old_succ_edges.size());
900 std::vector<std::map<CostPtrKey, CostPtrList>> op2_new_edges_cost(op2_old_succ_edges.size());
901 std::vector<std::shared_ptr<Edge>> op2_new_succ_edges(op2_old_succ_edges.size());
902
903 // Construct cost_map for the data_structure of 'op1_edges_reorganised_cost' and 'op2_edges_reorganised_cost'
904 for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
905 const auto &op1_cost_map = op1_old_succ_edges[i]->GetCostMap();
906 std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
907 for (const auto &key_value : op1_cost_map) {
908 const auto &from_to_strategies = key_value.first;
909 const auto &costlist = key_value.second;
910 from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
911 }
912 op1_edges_reorganised_cost[i] = from_tocost;
913 }
914
915 for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
916 const auto &op2_cost_map = op2_old_succ_edges[i]->GetCostMap();
917 std::map<StrategyPtr, std::vector<std::pair<StrategyPtr, CostPtrList>>> from_tocost;
918 for (const auto &key_value : op2_cost_map) {
919 const auto &from_to_strategies = key_value.first;
920 const auto &costlist = key_value.second;
921 from_tocost[from_to_strategies.first].emplace_back(std::make_pair(from_to_strategies.second, costlist));
922 }
923 op2_edges_reorganised_cost[i] = from_tocost;
924 }
925
926 // Merge op2 into op1
927 const auto &op1_old_stra_cost = op1->GetStrategyCost();
928 const auto &op2_old_stra_cost = op2->GetStrategyCost();
929 std::vector<std::shared_ptr<StrategyWithCost>> op1_new_stra_cost;
930
931 for (auto &op1_stra_cost : op1_old_stra_cost) {
932 auto op1_old_stra = op1_stra_cost->strategy_ptr;
933 auto op1_old_costlist = op1_stra_cost->cost_list;
934
935 for (auto &op2_stra_cost : op2_old_stra_cost) {
936 auto op2_stra = op2_stra_cost->strategy_ptr;
937 auto op2_costlist = op2_stra_cost->cost_list;
938
939 StrategyPtr op1_new_stra = std::make_shared<Strategy>(*op1_old_stra);
940 op1_new_stra->CoverStrategy(op2_stra);
941 CostPtrList op1_new_costlist;
942 // Calculate new cost for 'op1_new_costlist'
943 CreateSourceEliminationSubCostList(op1_old_stra, op1_old_costlist, op2_stra, op2_costlist, &op1_new_costlist);
944 std::shared_ptr<StrategyWithCost> swc = std::make_shared<StrategyWithCost>(op1_new_stra, op1_new_costlist);
945 op1_new_stra_cost.emplace_back(swc);
946
947 // Set cost for new successive edges of op1 and op2
948 for (size_t i = 0; i < op1_old_succ_edges.size(); ++i) {
949 auto &from_tocost = op1_edges_reorganised_cost[i];
950 auto &to_cost = from_tocost[op1_old_stra];
951 auto &new_cost_map = op1_new_edges_cost[i];
952 for (auto &stra_costlit : to_cost) {
953 auto &to_strategy = stra_costlit.first;
954 auto &edge_costlist = stra_costlit.second;
955 CostPtrKey new_key = {op1_new_stra, to_strategy};
956 new_cost_map[new_key] = edge_costlist;
957 }
958 }
959 for (size_t i = 0; i < op2_old_succ_edges.size(); ++i) {
960 auto &from_tocost = op2_edges_reorganised_cost[i];
961 auto &to_cost = from_tocost[op2_stra];
962 auto &new_cost_map = op2_new_edges_cost[i];
963 for (auto &stra_costlist : to_cost) {
964 auto &to_strategy = stra_costlist.first;
965 auto &edge_costlist = stra_costlist.second;
966 CostPtrKey new_key = {op1_new_stra, to_strategy};
967 new_cost_map[new_key] = edge_costlist;
968 }
969 }
970 }
971 }
972 op1->SetStrategyCost(op1_new_stra_cost);
973 op2->SetNotAlive();
974
975 // Update the edges incident to op1, and edges incident to op2
976 MS_LOG(INFO) << "Source eliminating node: " << op2->name() << " to node: " << op1->name() + " succeeded.";
977 return UpdateEdgesIncidentToNodes(op1, &op1_old_succ_edges, &op1_new_edges_cost, &op1_new_succ_edges, op2,
978 &op2_old_succ_edges, &op2_new_edges_cost, &op2_new_succ_edges);
979 }
980
981 // Check the graph whether a TriangleElimination can be performed
CheckTriangleElimination() const982 std::pair<OperatorInfoPtr, std::shared_ptr<Edge>> CostGraph::CheckTriangleElimination() const {
983 for (auto &op : ops_) {
984 MS_EXCEPTION_IF_NULL(op);
985 bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2);
986 if (bool_test) {
987 auto edge1 = op->GetAliveSuccEdges()[0];
988 auto edge2 = op->GetAliveSuccEdges()[1];
989 MS_EXCEPTION_IF_NULL(edge1);
990 MS_EXCEPTION_IF_NULL(edge2);
991 auto first_op = edge1->next_operator();
992 auto second_op = edge2->next_operator();
993 MS_EXCEPTION_IF_NULL(first_op);
994 for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) {
995 if (first_op_succ_edge->next_operator() == second_op) {
996 return {op, first_op_succ_edge};
997 }
998 }
999 MS_EXCEPTION_IF_NULL(second_op);
1000 for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) {
1001 if (second_op_succ_edge->next_operator() == first_op) {
1002 return {op, second_op_succ_edge};
1003 }
1004 }
1005 }
1006 }
1007 return {nullptr, nullptr};
1008 }
1009
1010 // Check the graph whether a StarElimination can be performed.
1011 // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
CheckStarElimination() const1012 OperatorInfoPtr CostGraph::CheckStarElimination() const {
1013 for (auto &op : ops_) {
1014 MS_EXCEPTION_IF_NULL(op);
1015 bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1);
1016 if (bool_test) {
1017 return op;
1018 }
1019 }
1020 return nullptr;
1021 }
1022
1023 // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace
1024 // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge.
EliminationOp(const OperatorInfoPtr & op)1025 std::shared_ptr<Edge> CostGraph::EliminationOp(const OperatorInfoPtr &op) {
1026 // in this case, the operators are organised in the form of u-->op-->v, and the goal
1027 // is to eliminate 'op'.
1028 MS_EXCEPTION_IF_NULL(op);
1029 MS_LOG(INFO) << "Now eliminating node: " << op->name() << ".";
1030 auto edge_u_op = op->GetAlivePrevEdges()[0];
1031 auto edge_op_v = op->GetAliveSuccEdges()[0];
1032 MS_EXCEPTION_IF_NULL(edge_u_op);
1033 MS_EXCEPTION_IF_NULL(edge_op_v);
1034 auto u = edge_u_op->prev_operator();
1035 auto v = edge_op_v->next_operator();
1036 std::vector<size_t> output_indexs, input_indexs;
1037 size_t output_index, input_index;
1038 MS_EXCEPTION_IF_NULL(u);
1039 MS_EXCEPTION_IF_NULL(v);
1040 std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1041 std::shared_ptr<Edge> new_edge;
1042 if (edge_u_op->is_combined()) {
1043 output_indexs = edge_u_op->prev_op_output_indexs();
1044 } else {
1045 output_index = edge_u_op->prev_op_output_index();
1046 output_indexs.push_back(output_index);
1047 }
1048 if (edge_op_v->is_combined()) {
1049 input_indexs = edge_op_v->next_op_input_indexs();
1050 } else {
1051 input_index = edge_op_v->next_op_input_index();
1052 input_indexs.push_back(input_index);
1053 }
1054
1055 if (!edge_u_op->is_combined() && !edge_op_v->is_combined()) {
1056 new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_index, input_index, false);
1057 } else {
1058 new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1059 }
1060 MS_EXCEPTION_IF_NULL(new_edge);
1061 new_edge->set_pre_op_output(edge_u_op->prev_op_output());
1062 new_edge->set_next_op_input(edge_op_v->next_op_input());
1063 new_edge->OpEliminationSetNewCost(edge_u_op, op, edge_op_v);
1064 u->ReplaceSuccEdge(op, new_edge);
1065 v->ReplacePreEdge(op, new_edge);
1066 op->SetNotAlive();
1067 MS_LOG(INFO) << "Eliminating node: " << op->name() << " succeeded.";
1068 return new_edge;
1069 }
1070
1071 // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges',
1072 // and sets new costlist for the new edge.
EliminationEdges(const std::vector<std::shared_ptr<Edge>> & edges)1073 std::shared_ptr<Edge> CostGraph::EliminationEdges(const std::vector<std::shared_ptr<Edge>> &edges) {
1074 MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges.";
1075 MS_EXCEPTION_IF_NULL(edges[0]);
1076 auto u = edges[0]->prev_operator();
1077 auto v = edges[0]->next_operator();
1078 MS_EXCEPTION_IF_NULL(u);
1079 MS_EXCEPTION_IF_NULL(v);
1080 std::string new_edge_name = u->name() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name();
1081 std::vector<size_t> output_indexs, input_indexs;
1082
1083 for (auto &edge : edges) {
1084 MS_EXCEPTION_IF_NULL(edge);
1085 if (edge->is_combined()) {
1086 auto from_output_indexs = edge->prev_op_output_indexs();
1087 auto from_input_indexs = edge->next_op_input_indexs();
1088 (void)std::copy(from_output_indexs.begin(), from_output_indexs.end(), std::back_inserter(output_indexs));
1089 (void)std::copy(from_input_indexs.begin(), from_input_indexs.end(), std::back_inserter(input_indexs));
1090 } else {
1091 output_indexs.push_back(edge->prev_op_output_index());
1092 input_indexs.push_back(edge->next_op_input_index());
1093 }
1094 }
1095
1096 std::shared_ptr<Edge> new_edge = std::make_shared<Edge>(new_edge_name, u, v, output_indexs, input_indexs, true);
1097 MS_EXCEPTION_IF_NULL(new_edge);
1098 new_edge->set_pre_op_output(edges[0]->prev_op_output());
1099 new_edge->set_next_op_input(edges[0]->next_op_input());
1100
1101 new_edge->EdgeEliminationSetNewCost(u, edges, v);
1102
1103 u->ReplaceSuccEdges(v, new_edge);
1104 v->ReplacePreEdges(u, new_edge);
1105 MS_LOG(INFO) << "Eliminating " << edges.size() << " edges succeeded.";
1106 return new_edge;
1107 }
1108
1109 // Given 'op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1110 // for this contract under the strategy 'op_strategy'
CreateMergeEliminationSubCostList(StrategyPtr op_strategy,const CostPtrList & op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr tar_op_strategy,const CostPtrList & tar_cost_list,CostPtrList * const tar_cost_list_new)1111 void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
1112 const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
1113 const CostPtrList &tar_cost_list,
1114 CostPtrList *const tar_cost_list_new) {
1115 for (size_t i = 0; i < op_cost_list.size(); ++i) {
1116 auto &op_cost = op_cost_list[i];
1117 MS_EXCEPTION_IF_NULL(op_cost);
1118 for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1119 auto &edge_cost = edge_cost_list[j];
1120 MS_EXCEPTION_IF_NULL(edge_cost);
1121 for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1122 auto &tar_cost = tar_cost_list[k];
1123 MS_EXCEPTION_IF_NULL(tar_cost);
1124 double computation = op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1125 double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1126 double communication =
1127 op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1128 double communication_forward =
1129 op_cost->communication_forward_ + edge_cost->communication_forward_ + tar_cost->communication_forward_;
1130 double communication_without_para = op_cost->communication_without_parameter_ +
1131 edge_cost->communication_without_parameter_ +
1132 tar_cost->communication_without_parameter_;
1133
1134 auto decision =
1135 std::make_shared<MergeEliminationDecision>(op_strategy, op_cost, edge_cost, tar_op_strategy, tar_cost);
1136 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1137 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1138 MS_EXCEPTION_IF_NULL(new_cost);
1139 new_cost->communication_without_parameter_ = communication_without_para;
1140 new_cost->communication_with_partial_para_ =
1141 communication_without_para + gamma * (communication - communication_without_para);
1142 new_cost->memory_with_reuse_ = memory;
1143 new_cost->communication_forward_ = communication_forward;
1144 MS_EXCEPTION_IF_NULL(tar_cost_list_new);
1145 tar_cost_list_new->emplace_back(std::move(new_cost));
1146 }
1147 }
1148 }
1149 }
1150
1151 // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the
1152 // target_op
EliminationMerge(const OperatorInfoPtr & op)1153 OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) {
1154 MS_EXCEPTION_IF_NULL(op);
1155 auto target_op = op->GetAliveSuccEdges()[0]->next_operator();
1156 auto edge_ptr = op->GetAliveSuccEdges()[0];
1157 MS_EXCEPTION_IF_NULL(target_op);
1158 MS_EXCEPTION_IF_NULL(edge_ptr);
1159 MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << ".";
1160 bool valid = false;
1161
1162 for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1163 MS_EXCEPTION_IF_NULL(tar_stra_cost);
1164 auto tar_stra = tar_stra_cost->strategy_ptr;
1165 auto tar_clist_origin = tar_stra_cost->cost_list;
1166 CostPtrList tar_clist_new;
1167
1168 for (auto &op_stra_cost : op->GetStrategyCost()) {
1169 MS_EXCEPTION_IF_NULL(op_stra_cost);
1170 auto op_stra = op_stra_cost->strategy_ptr;
1171 auto op_clist = op_stra_cost->cost_list;
1172 auto edge_clist = edge_ptr->GetCostList(op_stra, tar_stra);
1173
1174 CreateMergeEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1175 }
1176 Simplify(&tar_clist_new);
1177 // Set the new costlist w.r.t the strategy
1178 tar_stra_cost->cost_list = tar_clist_new;
1179 if ((!valid) && (!tar_clist_new.empty())) {
1180 valid = true;
1181 }
1182 }
1183
1184 if (!valid) {
1185 MS_LOG(EXCEPTION) << "Merging " << op->name() << " into " << target_op->name() << " failed.";
1186 }
1187 op->SetNotAlive();
1188 MS_LOG(INFO) << "Merging " << op->name() << " into " << target_op->name() << " succeeded.";
1189 return target_op;
1190 }
1191
1192 // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new'
1193 // for this contract under the strategy 'contract_op_stra'
CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,const CostPtrList & contract_op_cost_list,const CostPtrList & edge_cost_list,StrategyPtr target_op_stra,const CostPtrList & tar_cost_list,CostPtrList * tar_cost_list_new)1194 void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra,
1195 const CostPtrList &contract_op_cost_list,
1196 const CostPtrList &edge_cost_list, StrategyPtr target_op_stra,
1197 const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) {
1198 for (size_t i = 0; i < contract_op_cost_list.size(); ++i) {
1199 auto &contract_op_cost = contract_op_cost_list[i];
1200 MS_EXCEPTION_IF_NULL(contract_op_cost);
1201 for (size_t j = 0; j < edge_cost_list.size(); ++j) {
1202 auto &edge_cost = edge_cost_list[j];
1203 MS_EXCEPTION_IF_NULL(edge_cost);
1204 for (size_t k = 0; k < tar_cost_list.size(); ++k) {
1205 auto &tar_cost = tar_cost_list[k];
1206 MS_EXCEPTION_IF_NULL(tar_cost);
1207 double computation =
1208 contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_;
1209 double memory =
1210 contract_op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_;
1211 double communication =
1212 contract_op_cost->communication_cost_ + edge_cost->communication_cost_ + tar_cost->communication_cost_;
1213 double communication_forward = contract_op_cost->communication_forward_ + edge_cost->communication_forward_ +
1214 tar_cost->communication_forward_;
1215 double communication_without_para = contract_op_cost->communication_without_parameter_ +
1216 edge_cost->communication_without_parameter_ +
1217 tar_cost->communication_without_parameter_;
1218
1219 auto decision = std::make_shared<ContractEliminationDecision>(contract_op_stra, contract_op_cost, edge_cost,
1220 target_op_stra, tar_cost);
1221 auto new_cost = std::make_shared<Cost>(computation, communication, decision);
1222 auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1223 new_cost->communication_without_parameter_ = communication_without_para;
1224 new_cost->communication_with_partial_para_ =
1225 communication_without_para + gamma * (communication - communication_without_para);
1226 new_cost->memory_with_reuse_ = memory;
1227 new_cost->communication_forward_ = communication_forward;
1228 tar_cost_list_new->emplace_back(std::move(new_cost));
1229 }
1230 }
1231 }
1232 }
1233
1234 // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the
1235 // target_op
EliminationContract(const OperatorInfoPtr & op)1236 OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) {
1237 MS_EXCEPTION_IF_NULL(op);
1238 auto target_op = op->GetAlivePrevEdges()[0]->prev_operator();
1239 auto edge_ptr = op->GetAlivePrevEdges()[0];
1240 MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << ".";
1241 bool valid = false;
1242
1243 for (auto &tar_stra_cost : target_op->GetStrategyCost()) {
1244 MS_EXCEPTION_IF_NULL(tar_stra_cost);
1245 auto tar_stra = tar_stra_cost->strategy_ptr;
1246 auto tar_clist_origin = tar_stra_cost->cost_list;
1247 CostPtrList tar_clist_new;
1248
1249 for (auto &op_stra_cost : op->GetStrategyCost()) {
1250 MS_EXCEPTION_IF_NULL(op_stra_cost);
1251 auto op_stra = op_stra_cost->strategy_ptr;
1252 auto op_clist = op_stra_cost->cost_list;
1253 auto edge_clist = edge_ptr->GetCostList(tar_stra, op_stra);
1254
1255 CreateContractEliminationSubCostList(op_stra, op_clist, edge_clist, tar_stra, tar_clist_origin, &tar_clist_new);
1256 }
1257 Simplify(&tar_clist_new);
1258 // Set the new costlist w.r.t the strategy
1259 tar_stra_cost->cost_list = tar_clist_new;
1260 if ((!valid) && (!tar_clist_new.empty())) {
1261 valid = true;
1262 }
1263 }
1264 if (!valid) {
1265 MS_LOG(EXCEPTION) << "Contracting " << op->name() << " into " << target_op->name() << " failed.";
1266 }
1267 op->SetNotAlive();
1268 MS_LOG(INFO) << "Contracting " << op->name() << " into " << target_op->name() << " succeeded.";
1269 return target_op;
1270 }
1271
CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra,StrategyPtr left_op_stra,StrategyPtr right_op_stra,const CostPtr & right_op_cost,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtr & right_edge_cost,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new)1272 void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
1273 StrategyPtr right_op_stra, const CostPtr &right_op_cost,
1274 const CostPtrList &elimi_op_clist,
1275 const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost,
1276 const CostPtrList &left_node_clist_origin,
1277 CostPtrList *left_node_clist_new) {
1278 MS_EXCEPTION_IF_NULL(right_edge_cost);
1279 MS_EXCEPTION_IF_NULL(right_op_cost);
1280 MS_EXCEPTION_IF_NULL(left_node_clist_new);
1281 for (auto &elimi_op_cost : elimi_op_clist) {
1282 MS_EXCEPTION_IF_NULL(elimi_op_cost);
1283 for (auto &left_edge_cost : left_edge_clist) {
1284 MS_EXCEPTION_IF_NULL(left_edge_cost);
1285 for (auto &left_node_cost : left_node_clist_origin) {
1286 MS_EXCEPTION_IF_NULL(left_node_cost);
1287 double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ +
1288 left_node_cost->computation_cost_ + right_edge_cost->computation_cost_;
1289 double new_memory = elimi_op_cost->memory_with_reuse_ + left_edge_cost->memory_with_reuse_ +
1290 left_node_cost->memory_with_reuse_ + right_edge_cost->memory_with_reuse_;
1291 double new_commu_cost = elimi_op_cost->communication_cost_ + left_edge_cost->communication_cost_ +
1292 left_node_cost->communication_cost_ + right_edge_cost->communication_cost_;
1293 double new_commu_forward = elimi_op_cost->communication_forward_ + left_edge_cost->communication_forward_ +
1294 left_node_cost->communication_forward_ + right_edge_cost->communication_forward_;
1295 double new_commu_without =
1296 elimi_op_cost->communication_without_parameter_ + left_edge_cost->communication_without_parameter_ +
1297 left_node_cost->communication_without_parameter_ + right_edge_cost->communication_without_parameter_;
1298 const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1299 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1300
1301 if (triangle_star_stra_overwrite) {
1302 new_computation += right_op_cost->computation_cost_;
1303 new_memory += right_op_cost->memory_with_reuse_;
1304 new_commu_cost += right_op_cost->communication_cost_;
1305 new_commu_forward += right_op_cost->communication_forward_;
1306 new_commu_without += right_op_cost->communication_without_parameter_;
1307 }
1308
1309 auto decision =
1310 std::make_shared<TriangleEliminationDecision>(elimi_op_stra, elimi_op_cost, left_edge_cost, right_edge_cost,
1311 left_op_stra, left_node_cost, right_op_stra, right_op_cost);
1312 auto new_cost = std::make_shared<Cost>(new_computation, new_commu_cost, decision);
1313 new_cost->communication_without_parameter_ = new_commu_without;
1314 new_cost->communication_with_partial_para_ = new_commu_without + gamma * (new_commu_cost - new_commu_without);
1315 new_cost->memory_with_reuse_ = new_memory;
1316 new_cost->communication_forward_ = new_commu_forward;
1317 left_node_clist_new->emplace_back(std::move(new_cost));
1318 }
1319 }
1320 }
1321 }
1322
CreateTriangleEliminationCostList(const OperatorInfoPtr & elimi_op,const CostPtrList & right_node_clist,const CostPtrList & right_edge_clist,const StrategyPtr & elimi_op_stra,const StrategyPtr & left_node_stra,const StrategyPtr & right_node_stra,const CostPtrList & elimi_op_clist,const CostPtrList & left_edge_clist,const CostPtrList & left_node_clist_origin,CostPtrList * left_node_clist_new)1323 void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist,
1324 const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra,
1325 const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra,
1326 const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
1327 const CostPtrList &left_node_clist_origin,
1328 CostPtrList *left_node_clist_new) {
1329 MS_EXCEPTION_IF_NULL(elimi_op);
1330 for (auto &right_node_cost : right_node_clist) {
1331 MS_EXCEPTION_IF_NULL(right_node_cost);
1332 for (auto &right_edge_cost : right_edge_clist) {
1333 MS_EXCEPTION_IF_NULL(right_edge_cost);
1334 CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost,
1335 elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin,
1336 left_node_clist_new);
1337 }
1338 }
1339 }
1340
// Eliminates the triangle formed by 'elimi_op' and the edge 'edge_left_right' (left_node -> right_node).
// 'elimi_op' must have (at least) two alive successive edges: one into left_node and one into right_node.
// For each strategy of left_node, the costs of 'elimi_op', its two outgoing edges, left_node and (when
// 'triangle_star_strategy_overwrite' is enabled, inside the sub-cost-list helper) right_node are combined
// over all strategy choices, simplified, and stored as left_node's new cost list. 'elimi_op' is then marked
// not alive.
// Returns: left_node, which absorbs the eliminated operator.
// Raises: an exception if 'elimi_op' has fewer than two alive successive edges, or if no strategy of
// left_node ends up with a non-empty cost list.
OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op,
                                               const std::shared_ptr<Edge> &edge_left_right) {
  MS_EXCEPTION_IF_NULL(edge_left_right);
  MS_EXCEPTION_IF_NULL(elimi_op);
  MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << ".";
  auto left_node = edge_left_right->prev_operator();
  auto right_node = edge_left_right->next_operator();
  // Query the alive successive edges once, and check the count before indexing into them.
  const auto &alive_succ_edges = elimi_op->GetAliveSuccEdges();
  if (alive_succ_edges.size() < 2) {
    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
                      << " failed. The operator should have 2 alive successive edges, but has "
                      << alive_succ_edges.size() << ".";
  }
  auto left_edge = alive_succ_edges[0];
  auto right_edge = alive_succ_edges[1];
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  MS_EXCEPTION_IF_NULL(left_edge);
  MS_EXCEPTION_IF_NULL(right_edge);
  MS_LOG(INFO) << "The left operator is: " << left_node->name() << ".";
  MS_LOG(INFO) << "The right operator is: " << right_node->name() << ".";

  // Make 'left_edge' the edge elimi_op -> left_node (and 'right_edge' the one into right_node).
  if (left_edge->next_operator() != left_node) {
    std::swap(left_edge, right_edge);
  }
  bool valid = false;

  for (auto &left_node_stra_cost : left_node->GetStrategyCost()) {
    MS_EXCEPTION_IF_NULL(left_node_stra_cost);
    auto left_node_stra = left_node_stra_cost->strategy_ptr;
    auto left_node_clist_origin = left_node_stra_cost->cost_list;
    CostPtrList left_node_clist_new;

    for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) {
      MS_EXCEPTION_IF_NULL(elimi_op_stra_cost);
      auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr;
      auto elimi_op_clist = elimi_op_stra_cost->cost_list;
      auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra);

      for (auto &right_node_stra_cost : right_node->GetStrategyCost()) {
        MS_EXCEPTION_IF_NULL(right_node_stra_cost);
        auto right_node_stra = right_node_stra_cost->strategy_ptr;
        auto right_node_clist = right_node_stra_cost->cost_list;
        auto right_edge_clist = right_edge->GetCostList(elimi_op_stra, right_node_stra);

        CreateTriangleEliminationCostList(elimi_op, right_node_clist, right_edge_clist, elimi_op_stra, left_node_stra,
                                          right_node_stra, elimi_op_clist, left_edge_clist, left_node_clist_origin,
                                          &left_node_clist_new);
      }
    }
    Simplify(&left_node_clist_new);
    // Set the new costlist w.r.t the strategy; the local list is no longer needed, so move it.
    left_node_stra_cost->cost_list = std::move(left_node_clist_new);
    // The elimination succeeds if at least one strategy of left_node keeps some cost.
    valid = valid || !left_node_stra_cost->cost_list.empty();
  }

  if (!valid) {
    MS_LOG(EXCEPTION) << "Eliminating triangle: " << elimi_op->name()
                      << " failed. It may be caused by "
                         "configuring inconsistent strategies for operators.";
  }
  elimi_op->SetNotAlive();
  MS_LOG(INFO) << "Eliminating triangle: " << elimi_op->name() << " succeeded.";
  return left_node;
}
1404
CreateStarEliminationSubCostList(const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,std::vector<StrategyPtr> succ_nodes_stras,CostPtrList & succ_edges_costs,CostPtrList & succ_nodes_costs,CostPtrList * first_succ_node_clist_new)1405 void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra,
1406 const CostPtrList &first_succ_node_clist,
1407 const CostPtrList &first_succ_edge_clist,
1408 const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1409 std::vector<StrategyPtr> succ_nodes_stras,
1410 CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs,
1411 CostPtrList *first_succ_node_clist_new) {
1412 for (auto &first_succ_node_cost : first_succ_node_clist) {
1413 for (auto &first_succ_edge_cost : first_succ_edge_clist) {
1414 for (auto &merged_node_cost : merged_op_clist) {
1415 MS_EXCEPTION_IF_NULL(merged_node_cost);
1416 succ_nodes_stras[0] = first_succ_node_stra;
1417 succ_edges_costs[0] = first_succ_edge_cost;
1418 succ_nodes_costs[0] = first_succ_node_cost;
1419
1420 double computation_cost = merged_node_cost->computation_cost_,
1421 memory_cost = merged_node_cost->memory_with_reuse_, commu_cost = merged_node_cost->communication_cost_,
1422 commu_without = merged_node_cost->communication_without_parameter_,
1423 commu_forward = merged_node_cost->communication_forward_;
1424 const auto triangle_star_stra_overwrite = CostModelContext::GetInstance()->triangle_star_strategy_overwrite();
1425 const auto gamma = CostModelContext::GetInstance()->costmodel_gamma();
1426 for (size_t i = 0; i < succ_nodes_stras.size(); ++i) {
1427 MS_EXCEPTION_IF_NULL(succ_edges_costs[i]);
1428 if (i == 0) {
1429 computation_cost += succ_edges_costs[i]->computation_cost_ + succ_nodes_costs[i]->computation_cost_;
1430 memory_cost += succ_edges_costs[i]->memory_with_reuse_ + succ_nodes_costs[i]->memory_with_reuse_;
1431 commu_cost += succ_edges_costs[i]->communication_cost_ + succ_nodes_costs[i]->communication_cost_;
1432 commu_forward += succ_edges_costs[i]->communication_forward_ + succ_nodes_costs[i]->communication_forward_;
1433 commu_without += succ_edges_costs[i]->communication_without_parameter_ +
1434 succ_nodes_costs[i]->communication_without_parameter_;
1435 } else {
1436 computation_cost += succ_edges_costs[i]->computation_cost_;
1437 memory_cost += succ_edges_costs[i]->memory_with_reuse_;
1438 commu_cost += succ_edges_costs[i]->communication_cost_;
1439 commu_forward += succ_edges_costs[i]->communication_forward_;
1440 commu_without += succ_edges_costs[i]->communication_without_parameter_;
1441 if (triangle_star_stra_overwrite) {
1442 computation_cost += succ_nodes_costs[i]->computation_cost_;
1443 memory_cost += succ_nodes_costs[i]->memory_with_reuse_;
1444 commu_cost += succ_nodes_costs[i]->communication_cost_;
1445 commu_forward += succ_nodes_costs[i]->communication_forward_;
1446 commu_without += succ_nodes_costs[i]->communication_without_parameter_;
1447 }
1448 }
1449 }
1450
1451 auto decision = std::make_shared<StarEliminationDecision>(merged_op_stra, merged_node_cost, succ_edges_costs,
1452 succ_nodes_stras, succ_nodes_costs);
1453 auto new_cost = std::make_shared<Cost>(computation_cost, commu_cost, decision);
1454 new_cost->communication_without_parameter_ = commu_without;
1455 new_cost->communication_with_partial_para_ = commu_without + gamma * (commu_cost - commu_without);
1456 new_cost->memory_with_reuse_ = memory_cost;
1457 new_cost->communication_forward_ = commu_forward;
1458 first_succ_node_clist_new->emplace_back(std::move(new_cost));
1459 }
1460 }
1461 }
1462 }
1463
CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> & succ_edges,const StrategyPtr & first_succ_node_stra,const CostPtrList & first_succ_node_clist,const CostPtrList & first_succ_edge_clist,const StrategyPtr & merged_op_stra,const CostPtrList & merged_op_clist,CostPtrList * first_succ_node_clist_new)1464 void CostGraph::CreateStarEliminationCostList(std::vector<std::shared_ptr<Edge>> &succ_edges,
1465 const StrategyPtr &first_succ_node_stra,
1466 const CostPtrList &first_succ_node_clist,
1467 const CostPtrList &first_succ_edge_clist,
1468 const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist,
1469 CostPtrList *first_succ_node_clist_new) {
1470 std::vector<StrategyPtr> succ_nodes_stras(succ_edges.size(), nullptr);
1471 CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr);
1472 std::function<void(size_t)> recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist,
1473 &merged_op_stra, &merged_op_clist, &succ_nodes_stras, &succ_edges_costs,
1474 &succ_nodes_costs, &first_succ_node_clist_new, &succ_edges, &recursive,
1475 this](size_t k) {
1476 if (k == succ_edges.size()) {
1477 CreateStarEliminationSubCostList(first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1478 merged_op_stra, merged_op_clist, succ_nodes_stras, succ_edges_costs,
1479 succ_nodes_costs, first_succ_node_clist_new);
1480 return;
1481 }
1482 MS_LOG(DEBUG) << "The size of first_succ_node_clist: " << first_succ_node_clist.size()
1483 << ", first_succ_edge_clist: " << first_succ_edge_clist.size()
1484 << ", merged_op_clist: " << merged_op_clist.size()
1485 << ", first_succ_node_clist_new: " << first_succ_node_clist_new->size() << ".";
1486 auto succ_edge = succ_edges[k];
1487 MS_EXCEPTION_IF_NULL(succ_edge);
1488 auto succ_node = succ_edge->next_operator();
1489 MS_EXCEPTION_IF_NULL(succ_node);
1490 for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) {
1491 MS_EXCEPTION_IF_NULL(succ_node_stra_cost);
1492 auto succ_node_stra = succ_node_stra_cost->strategy_ptr;
1493 auto succ_node_clist = succ_node_stra_cost->cost_list;
1494 auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra);
1495
1496 for (auto &succ_node_cost : succ_node_clist) {
1497 MS_EXCEPTION_IF_NULL(succ_node_cost);
1498 for (auto &succ_edge_cost : succ_edge_clist) {
1499 MS_EXCEPTION_IF_NULL(succ_edge_cost);
1500 succ_nodes_stras[k] = succ_node_stra;
1501 succ_edges_costs[k] = succ_edge_cost;
1502 succ_nodes_costs[k] = succ_node_cost;
1503 recursive(k + 1);
1504 }
1505 }
1506 }
1507 };
1508
1509 recursive(1);
1510 }
1511
EliminationStar(const OperatorInfoPtr & merged_op)1512 std::vector<std::shared_ptr<Edge>> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) {
1513 MS_EXCEPTION_IF_NULL(merged_op);
1514 auto succ_edges = merged_op->GetAliveSuccEdges();
1515 MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << ".";
1516 for (auto &succ_edge : succ_edges) {
1517 MS_EXCEPTION_IF_NULL(succ_edge->next_operator());
1518 MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << ".";
1519 }
1520
1521 MS_EXCEPTION_IF_NULL(succ_edges[0]);
1522 auto first_succ_node = succ_edges[0]->next_operator();
1523 auto first_succ_edge = succ_edges[0];
1524 bool valid = false;
1525
1526 // 'merged_op' is merged into first_node
1527 MS_EXCEPTION_IF_NULL(first_succ_node);
1528 for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) {
1529 MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost);
1530 auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr;
1531 auto first_succ_node_clist = first_succ_node_stra_cost->cost_list;
1532 CostPtrList first_succ_node_clist_new;
1533
1534 for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) {
1535 MS_EXCEPTION_IF_NULL(merged_op_stra_cost);
1536 auto merged_op_stra = merged_op_stra_cost->strategy_ptr;
1537 auto merged_op_clist = merged_op_stra_cost->cost_list;
1538 auto first_succ_edge_clist = first_succ_edge->GetCostList(merged_op_stra, first_succ_node_stra);
1539
1540 CreateStarEliminationCostList(succ_edges, first_succ_node_stra, first_succ_node_clist, first_succ_edge_clist,
1541 merged_op_stra, merged_op_clist, &first_succ_node_clist_new);
1542 }
1543 Simplify(&first_succ_node_clist_new);
1544 // Set the new costlist w.r.t the strategy
1545 first_succ_node_stra_cost->cost_list = first_succ_node_clist_new;
1546 if ((!valid) && (!first_succ_node_clist_new.empty())) {
1547 valid = true;
1548 }
1549 }
1550
1551 if (!valid) {
1552 MS_LOG(EXCEPTION) << "Eliminating star centered at: " << merged_op->name()
1553 << " failed. It may be caused by "
1554 "configuring inconsistent strategies for operators.";
1555 }
1556
1557 merged_op->SetNotAlive();
1558 MS_LOG(INFO) << "Eliminating star centered at: " << merged_op->name() << " succeeded.";
1559 return succ_edges;
1560 }
1561
GetNumEdges() const1562 size_t CostGraph::GetNumEdges() const {
1563 size_t sum = 0;
1564 for (const auto &kv : edges_) {
1565 auto &edges = kv.second;
1566 sum += edges.size();
1567 }
1568 return sum;
1569 }
1570
InitReshapeStrategy()1571 Status CostGraph::InitReshapeStrategy() {
1572 // reshape init should be apply after the init of it's previous node and next node.
1573 for (size_t i = 0; i < ops_.size(); ++i) {
1574 if (ops_[i]->name().find(RESHAPEINFO) != std::string::npos) {
1575 auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(ops_[i]);
1576 auto in_edges = GetOriginalPrevEdges(ops_[i]);
1577 auto pre_iter = std::find_if(in_edges.begin(), in_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1578 return edge->prev_operator()->name() == reshape_info->pre_operator_name();
1579 });
1580 auto out_edges = GetOriginalNextEdges(ops_[i]);
1581 auto next_iter = std::find_if(out_edges.begin(), out_edges.end(), [&](const std::shared_ptr<Edge> &edge) {
1582 return edge->next_operator()->name() == reshape_info->next_operator_name();
1583 });
1584 bool reshape_is_first_op = reshape_info->pre_operator_name() == reshape_info->name();
1585 if (reshape_is_first_op) {
1586 reshape_info->InitSelectedStrategy(reshape_info->selected_strategy());
1587 }
1588 if (pre_iter != in_edges.end() || reshape_is_first_op) {
1589 MS_LOG(DEBUG) << "Set reshape input layout by " << reshape_info->pre_operator_name();
1590 int64_t pre_index = reshape_info->pre_operator_index();
1591 TensorInfo pre_info;
1592 std::shared_ptr<OperatorInfo> pre_op_info;
1593 if (reshape_is_first_op) {
1594 pre_op_info = reshape_info;
1595 pre_info = pre_op_info->inputs_tensor_info()[LongToSize(pre_index)];
1596 } else {
1597 pre_op_info = (*pre_iter)->prev_operator();
1598 pre_info = pre_op_info->outputs_tensor_info()[LongToSize(pre_index)];
1599 }
1600 reshape_info->SetInputLayout(pre_info.tensor_layout());
1601 if (pre_iter != in_edges.end()) {
1602 Dimensions stra = pre_info.InferStrategy();
1603 if (stra.empty()) {
1604 MS_LOG(EXCEPTION) << "Infer strategy by tensor_info failed";
1605 }
1606 Strategys stra_inputs = {stra};
1607 StrategyPtr reshape_stra =
1608 std::make_shared<Strategy>((*pre_iter)->prev_operator()->strategy()->GetInputStage(), stra_inputs);
1609 reshape_info->set_strategy(reshape_stra);
1610 }
1611 }
1612 if (next_iter != out_edges.end()) {
1613 MS_LOG(DEBUG) << "Set reshape output layout by " << reshape_info->next_operator_name();
1614 int64_t next_index = reshape_info->next_operator_index();
1615 reshape_info->SetOutputLayout(
1616 (*next_iter)->next_operator()->inputs_tensor_info()[LongToSize(next_index)].tensor_layout());
1617 }
1618 if (reshape_info->Init(nullptr) != SUCCESS) {
1619 return FAILED;
1620 }
1621 }
1622 }
1623 return SUCCESS;
1624 }
1625
InitSelectedStrategy()1626 Status CostGraph::InitSelectedStrategy() {
1627 for (auto &op : ops_) {
1628 MS_EXCEPTION_IF_NULL(op);
1629 if (op->name().find(RESHAPEINFO) != std::string::npos) {
1630 continue;
1631 }
1632 auto result_op = op->InitSelectedStrategy(op->selected_strategy());
1633 if (result_op != SUCCESS) {
1634 return result_op;
1635 }
1636 }
1637 auto result = InitReshapeStrategy();
1638 return result;
1639 }
1640
ComputeOpsAndEdgesParameterInvolved()1641 Status CostGraph::ComputeOpsAndEdgesParameterInvolved() {
1642 for (auto &op : ops_) {
1643 MS_EXCEPTION_IF_NULL(op);
1644 const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved();
1645 if ((output_parameter != 0) && (output_parameter != 1)) {
1646 MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed.";
1647 return FAILED;
1648 }
1649 }
1650 return SUCCESS;
1651 }
1652
DFSForTopoOrder(const OperatorInfoPtr & current_op,std::map<OperatorInfoPtr,bool> * visited,std::vector<OperatorInfoPtr> * topo_order)1653 void CostGraph::DFSForTopoOrder(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited,
1654 std::vector<OperatorInfoPtr> *topo_order) {
1655 MS_EXCEPTION_IF_NULL(current_op);
1656 MS_EXCEPTION_IF_NULL(visited);
1657 MS_EXCEPTION_IF_NULL(topo_order);
1658
1659 visited->at(current_op) = true;
1660 for (const auto &s_edge : current_op->succ_edges()) {
1661 if (!visited->at(s_edge->next_operator())) {
1662 DFSForTopoOrder(s_edge->next_operator(), visited, topo_order);
1663 }
1664 }
1665 topo_order->push_back(current_op);
1666 }
1667
1668 // Compute a topological order of the costgraph
TopologyOrder(std::vector<OperatorInfoPtr> * topo_order)1669 void CostGraph::TopologyOrder(std::vector<OperatorInfoPtr> *topo_order) {
1670 std::map<OperatorInfoPtr, bool> visited;
1671 for (auto &op : ops_) {
1672 visited[op] = false;
1673 }
1674
1675 for (auto &op : ops_) {
1676 if (!visited[op]) {
1677 DFSForTopoOrder(op, &visited, topo_order);
1678 }
1679 }
1680 }
MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr,int64_t> & candidate_ops)1681 void CostGraph::MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &candidate_ops) {
1682 for (auto &op : ops_) {
1683 auto search = candidate_ops.find(op);
1684 if (search != candidate_ops.end()) {
1685 // Mark the critical operators
1686 op->mark_output_critical();
1687 // Mark the successive edges
1688 for (auto &s_edge : op->succ_edges()) {
1689 s_edge->mark_output_critical();
1690 }
1691 } else {
1692 op->mark_output_not_critical();
1693 }
1694 }
1695 }
1696
DetermineCriticalOps(const std::vector<OperatorInfoPtr> & topo_order)1697 Status CostGraph::DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order) {
1698 if (topo_order.size() == 0) {
1699 MS_LOG(ERROR) << "0 operator in costgraph.";
1700 return FAILED;
1701 }
1702 auto &first_op = topo_order[0];
1703 if (first_op->prev_edges().size() > 0) {
1704 MS_LOG(ERROR) << "The first operator in the first of topological order of "
1705 "costgraph should have 0 incoming edge, but has "
1706 << first_op->prev_edges() << "edges.";
1707 return FAILED;
1708 }
1709 // The 'curr_memory_state' records <OperatorInfo, remaining_output_cnt>, where remaining_output_cnt is the number
1710 // of the output of OperatorInfo that currently has not been used
1711 std::map<OperatorInfoPtr, int64_t> curr_memory_state;
1712 (void)curr_memory_state.emplace(std::make_pair(first_op, SizeToLong(first_op->succ_edges().size())));
1713 std::map<OperatorInfoPtr, int64_t> max_memory_state = curr_memory_state;
1714 // The 'curr_memory_size' records the current total memory size, which is the sum of outputs of operators that has
1715 // not been used
1716 double curr_memory_size = first_op->GetOutputsTotalSize();
1717 double max_memory_size = curr_memory_size;
1718
1719 for (size_t finished = 1; finished < topo_order.size(); ++finished) {
1720 // Produce
1721 (void)curr_memory_state.emplace(
1722 std::make_pair(topo_order[finished], SizeToLong(topo_order[finished]->succ_edges().size())));
1723 curr_memory_size += topo_order[finished]->GetOutputsTotalSize();
1724 // Consume
1725 for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1726 const auto &prev_op = prev_edge->prev_operator();
1727 curr_memory_state[prev_op]--;
1728 }
1729 for (const auto &prev_edge : topo_order[finished]->prev_edges()) {
1730 const auto &prev_op = prev_edge->prev_operator();
1731 if (curr_memory_state[prev_op] < 0) {
1732 MS_LOG(ERROR) << "Failure: " << prev_op->name() << "'s current output count: " << curr_memory_state[prev_op];
1733 return FAILED;
1734 } else if (curr_memory_state[prev_op] == 0) {
1735 curr_memory_state.erase(prev_op);
1736 curr_memory_size -= prev_op->GetOutputsTotalSize();
1737 }
1738 }
1739
1740 if (curr_memory_size < 0) {
1741 MS_LOG(ERROR) << "Memory size calculation failed: " << curr_memory_size;
1742 }
1743 // Modify the max
1744 if (curr_memory_size > max_memory_size) {
1745 max_memory_size = curr_memory_size;
1746 max_memory_state = curr_memory_state;
1747 }
1748 }
1749 // Mark those critical operators
1750 MarkCriticalOpsAndEdges(max_memory_state);
1751 return SUCCESS;
1752 }
1753
ComputeOpsAndEdgesOutputCritical()1754 Status CostGraph::ComputeOpsAndEdgesOutputCritical() {
1755 // Two steps to do:
1756 // 1. Compute a topological order of the costgraph
1757 // 2. Determine and mark the operators (and necessary edges) that are critical
1758 std::vector<OperatorInfoPtr> topo_order;
1759 TopologyOrder(&topo_order);
1760 std::reverse(std::begin(topo_order), std::end(topo_order));
1761
1762 if (DetermineCriticalOps(topo_order) != SUCCESS) {
1763 MS_LOG(ERROR) << "Determining critical operators failed.";
1764 return FAILED;
1765 }
1766 return SUCCESS;
1767 }
1768
CalculateOpsMemoryCost()1769 Status CostGraph::CalculateOpsMemoryCost() {
1770 for (auto &op : ops_) {
1771 MS_EXCEPTION_IF_NULL(op);
1772 if (op->CalculateMemoryCost() != SUCCESS) {
1773 MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
1774 return FAILED;
1775 }
1776 }
1777 return SUCCESS;
1778 }
1779
CalculateOpsMemoryCostForInference()1780 Status CostGraph::CalculateOpsMemoryCostForInference() {
1781 for (auto &op : ops_) {
1782 MS_EXCEPTION_IF_NULL(op);
1783 if (op->CalculateMemoryCostForInference() != SUCCESS) {
1784 MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed.";
1785 return FAILED;
1786 }
1787 }
1788 return SUCCESS;
1789 }
1790
CalculateEdgesMemoryCost()1791 Status CostGraph::CalculateEdgesMemoryCost() {
1792 for (auto &edge_pair : edges_) {
1793 const auto &edges = edge_pair.second;
1794 for (auto &one_edge : edges) {
1795 if (one_edge->CalculateMemoryCost() != SUCCESS) {
1796 MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
1797 return FAILED;
1798 }
1799 }
1800 }
1801 return SUCCESS;
1802 }
1803
CalculateEdgesMemoryCostForInference()1804 Status CostGraph::CalculateEdgesMemoryCostForInference() {
1805 for (auto &edge_pair : edges_) {
1806 const auto &edges = edge_pair.second;
1807 for (auto &one_edge : edges) {
1808 if (one_edge->CalculateMemoryCostForInference() != SUCCESS) {
1809 MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed.";
1810 return FAILED;
1811 }
1812 }
1813 }
1814 return SUCCESS;
1815 }
1816
FindTmpIdentityByParameterName(std::string & p_name) const1817 OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const {
1818 for (auto one_op : ops_) {
1819 if (one_op->name().find(IDENTITY_INFO) != std::string::npos) {
1820 if (one_op->refkey_parameter_name() == p_name) {
1821 return one_op;
1822 }
1823 }
1824 }
1825 return nullptr;
1826 }
CorrectOpsMemoryCost()1827 Status CostGraph::CorrectOpsMemoryCost() {
1828 for (auto &one_op : ops_) {
1829 if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) {
1830 if (one_op->GetAliveSuccEdges().size() > 1) {
1831 // Filter out the case when the TmpIdentity being used by multiple operators
1832 std::map<size_t, int64_t> output_count;
1833 for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
1834 auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
1835 output_count[output_index]++;
1836 }
1837 for (size_t i = 0; i < one_op->GetAliveSuccEdges().size(); ++i) {
1838 auto output_index = one_op->GetAliveSuccEdges()[i]->prev_op_output_index();
1839 if (output_count[output_index] <= 1) {
1840 continue;
1841 }
1842 auto next_op = one_op->GetAliveSuccEdges()[i]->next_operator();
1843 MS_EXCEPTION_IF_NULL(next_op);
1844 auto input_index = one_op->GetAliveSuccEdges()[i]->next_op_input_index();
1845 if (next_op->CorrectMemoryCost(input_index) != SUCCESS) {
1846 MS_LOG(ERROR) << "The operator name: " << one_op->name() << ", the next operator name: " << next_op->name()
1847 << ", the output_index: " << output_index << ", the input_index: " << input_index << ".";
1848 return FAILED;
1849 }
1850 output_count[output_index]--;
1851 }
1852 }
1853 }
1854 }
1855 return SUCCESS;
1856 }
1857
CalculateMemoryCost()1858 Status CostGraph::CalculateMemoryCost() {
1859 const auto phase = CostModelContext::GetInstance()->run_phase();
1860 if (phase == TRAINING_PHASE) {
1861 // training phase
1862 if (ComputeOpsAndEdgesParameterInvolved() == SUCCESS) {
1863 // Calculate operators' memory usage
1864 if (CalculateOpsMemoryCost() != SUCCESS) {
1865 MS_LOG(ERROR) << "Calculating operators' cost for memory cost failed.";
1866 return FAILED;
1867 }
1868 // Calculate edges' memory usage
1869 if (CalculateEdgesMemoryCost() != SUCCESS) {
1870 MS_LOG(ERROR) << "Calculating edges' cost for memory cost failed.";
1871 return FAILED;
1872 }
1873 // Correct memory usage caused by TmpIdentity
1874 if (CorrectOpsMemoryCost() != SUCCESS) {
1875 MS_LOG(ERROR) << "Correcting operators' cost for memory cost failed.";
1876 return FAILED;
1877 }
1878 } else {
1879 MS_LOG(ERROR) << "Computing operators' parameter_involved failed.";
1880 return FAILED;
1881 }
1882 } else {
1883 // inference phase
1884 if (ComputeOpsAndEdgesOutputCritical() == SUCCESS) {
1885 // Calculate operators' memory usage
1886 if (CalculateOpsMemoryCostForInference() != SUCCESS) {
1887 MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
1888 return FAILED;
1889 }
1890 // Calculate edges's memory usage
1891 if (CalculateEdgesMemoryCostForInference() != SUCCESS) {
1892 MS_LOG(ERROR) << "Calculating operators' memory cost for inference failed.";
1893 return FAILED;
1894 }
1895 } else {
1896 MS_LOG(ERROR) << "Computing operators' critical flag failed.";
1897 return FAILED;
1898 }
1899 }
1900 return SUCCESS;
1901 }
1902
CheckApproximateCostGraphEdges()1903 void CostGraph::CheckApproximateCostGraphEdges() {
1904 auto approximation = CostModelContext::GetInstance()->dp_algo_enable_approxi();
1905 if (!approximation) {
1906 return;
1907 }
1908 for (auto &s_edge : edges_) {
1909 auto &edges_vector = s_edge.second;
1910 for (auto &edge_ptr : edges_vector) {
1911 MS_EXCEPTION_IF_NULL(edge_ptr);
1912 if (edge_ptr->CheckStrategyCostPossibility()) {
1913 continue;
1914 }
1915 MS_LOG(INFO) << "Checking StrategyCost for edge: " << edge_ptr->edge_name()
1916 << " impossible, re-initing the operators and edges";
1917 auto prev_op = edge_ptr->prev_operator();
1918 MS_EXCEPTION_IF_NULL(prev_op);
1919 auto next_op = edge_ptr->next_operator();
1920 MS_EXCEPTION_IF_NULL(next_op);
1921 // Check the 'prev_op'
1922 prev_op->ExactStrategiesAndRelatedEdges();
1923 // Check the 'next_op'
1924 next_op->ExactStrategiesAndRelatedEdges();
1925 }
1926 }
1927 }
1928 } // namespace parallel
1929 } // namespace mindspore
1930