• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "ir/value.h"
26 
27 namespace mindspore {
28 namespace parallel {
// The Dynamic Programming (DP) algorithm has 3 meta phases. Its input is a CostGraph, and its goal
// is to compute a strategy for each operator in the CostGraph.
//
// Phase 1: Shrink the CostGraph with elimination operations, recording them in the order they are applied.
//       Using operations such as Operator Elimination, Edge Elimination, Merge Elimination, and Contract
//       Elimination (Source, Triangle, and Star Elimination records are defined below as well), each connected
//       component in the CostGraph can be shrunk into the final graph: u --> v. See the description of these
//       operations in costmodel.h.
// Phase 2: Search the cost_list of the final graph and determine the optimal entry.
//       Create the cost_list for the final graph and choose the entry with the minimum value of
//       COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost.
// Phase 3: Recover the original CostGraph and determine the strategy for each operator.
//       After determining the optimal cost for the final graph, the algorithm recovers the original graph by
//       applying the recorded operations in the reverse of the order used in Phase 1. Because each operation
//       decision contains the strategy, the strategies of all operators can be determined.
43 
44 struct Elimination : public Base {
45   enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR };
EliminationElimination46   Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
47 
48   EdgePtr new_edge_;
49   EliminationType type_;
50 };
51 
52 // Operator Elimination
53 struct OpElimination : public Elimination {
OpEliminationOpElimination54   OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge)
55       : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA),
56         left_edge_(std::move(l_edge)),
57         op_(std::move(op_info)),
58         right_edge_(std::move(r_edge)) {}
59 
60   EdgePtr left_edge_;
61   OperatorInfoPtr op_;
62   EdgePtr right_edge_;
63   MS_DECLARE_PARENT(OpElimination, Elimination);
64 };
65 
66 // Edge Elimination
67 struct EdgeElimination : public Elimination {
EdgeEliminationEdgeElimination68   EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds)
69       : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {}
70 
71   std::vector<EdgePtr> edges_;
72   MS_DECLARE_PARENT(EdgeElimination, Elimination);
73 };
74 
75 // Merge Elimination
76 struct MergeElimination : public Elimination {
MergeEliminationMergeElimination77   MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info)
78       : Elimination(nullptr, Elimination::EliminationType::MERGE),
79         merged_node_(std::move(u_info)),
80         dir_edge_(std::move(merged_target_edge)),
81         target_node_(std::move(v_info)) {}
82 
83   OperatorInfoPtr merged_node_;
84   EdgePtr dir_edge_;
85   OperatorInfoPtr target_node_;
86   MS_DECLARE_PARENT(MergeElimination, Elimination);
87 };
88 
89 // Contract Elimination
90 struct ContractElimination : public Elimination {
ContractEliminationContractElimination91   ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info)
92       : Elimination(nullptr, Elimination::EliminationType::CONTRACT),
93         contracted_node_(std::move(con_info)),
94         dir_edge_(std::move(tar_con_edge)),
95         target_node_(std::move(tar_info)) {}
96 
97   OperatorInfoPtr contracted_node_;
98   EdgePtr dir_edge_;
99   OperatorInfoPtr target_node_;
100   MS_DECLARE_PARENT(ContractElimination, Elimination);
101 };
102 
103 // Source Elimination
104 struct SourceElimination : public Elimination {
SourceEliminationSourceElimination105   SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges,
106                     OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges)
107       : Elimination(nullptr, Elimination::EliminationType::SOURCE),
108         primary_source_(std::move(p_source)),
109         primary_succ_edges_(std::move(p_succ_edges)),
110         primary_new_succ_edges_(std::move(p_new_succ_edges)),
111         secondary_source_(std::move(s_source)),
112         secondary_succ_edges_(std::move(s_succ_edges)),
113         secondary_new_succ_edges_(std::move(s_new_succ_edges)) {}
114   OperatorInfoPtr primary_source_;
115   std::vector<EdgePtr> primary_succ_edges_;
116   std::vector<EdgePtr> primary_new_succ_edges_;
117   OperatorInfoPtr secondary_source_;
118   std::vector<EdgePtr> secondary_succ_edges_;
119   std::vector<EdgePtr> secondary_new_succ_edges_;
120   MS_DECLARE_PARENT(SourceElimination, Elimination);
121 };
122 
123 // Triangle Elimination
124 struct TriangleElimination : public Elimination {
TriangleEliminationTriangleElimination125   TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
126                       OperatorInfoPtr r_node)
127       : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
128         eliminated_node_(std::move(elim_node)),
129         left_edge_(std::move(l_edge)),
130         left_node_(std::move(l_node)),
131         right_edge_(std::move(r_edge)),
132         right_node_(std::move(r_node)) {}
133 
134   OperatorInfoPtr eliminated_node_;
135   EdgePtr left_edge_;
136   OperatorInfoPtr left_node_;
137   EdgePtr right_edge_;
138   OperatorInfoPtr right_node_;
139   MS_DECLARE_PARENT(TriangleElimination, Elimination);
140 };
141 
142 // Star Elimination
143 struct StarElimination : public Elimination {
StarEliminationStarElimination144   StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops)
145       : Elimination(nullptr, Elimination::EliminationType::STAR),
146         eliminated_node_(std::move(elimi_node)),
147         succ_edges_(std::move(s_edges)),
148         succ_ops_(std::move(s_ops)) {}
149 
150   OperatorInfoPtr eliminated_node_;
151   std::vector<EdgePtr> succ_edges_;
152   std::vector<OperatorInfoPtr> succ_ops_;
153   MS_DECLARE_PARENT(StarElimination, Elimination);
154 };
155 
156 using EliminationPtr = std::shared_ptr<Elimination>;
157 using OpEliminationPtr = std::shared_ptr<OpElimination>;
158 using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
159 using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
160 using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
161 using SourceEliminationPtr = std::shared_ptr<SourceElimination>;
162 using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
163 using StarEliminationPtr = std::shared_ptr<StarElimination>;
164 
165 // Phase 1 and Phase 2
166 Status GetStrategy(const CostGraphPtr &graph);
167 
168 // Phase 3
169 Status RecoverStrategy(std::vector<EliminationPtr> eliminations);
170 }  // namespace parallel
171 }  // namespace mindspore
172 
173 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
174