• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
19 
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
24 #include "frontend/parallel/auto_parallel/graph_costmodel.h"
25 #include "ir/value.h"
26 
27 namespace mindspore {
28 namespace parallel {
// There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal
// is to compute the strategy for each operator in the CostGraph.
//
// Phase 1: Shrink the CostGraph using the elimination operations, and record them in the order they are applied.
//       Using these operations (Operator, Edge, Merge, Contract, Source, Triangle, and Star Elimination; see
//       Elimination::EliminationType), each connected component in the CostGraph can be shrunk into the final
//       graph: u --> v. See costmodel.h for the interpretation of these operations.
// Phase 2: Search the cost_list of the final graph, and determine the optimal entry.
//       Create the cost_list for the final graph, and choose the optimal entry: the one with the minimum value of
//       COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost.
// Phase 3: Recover the original CostGraph, then determine the strategy for each operator.
//       After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying
//       the recorded operations in the reverse order of Phase 1. Because each recorded operation carries its strategy
//       decision, the strategies of all operators can be determined.
43 
44 struct Elimination : public Base {
45   enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR };
EliminationElimination46   Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {}
47   ~Elimination() override = default;
48 
49   EdgePtr new_edge_;
50   EliminationType type_;
51 };
52 
53 // Operator Elimination
54 struct OpElimination : public Elimination {
OpEliminationOpElimination55   OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge)
56       : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA),
57         left_edge_(std::move(l_edge)),
58         op_(std::move(op_info)),
59         right_edge_(std::move(r_edge)) {}
60   ~OpElimination() override = default;
61 
62   EdgePtr left_edge_;
63   OperatorInfoPtr op_;
64   EdgePtr right_edge_;
65   MS_DECLARE_PARENT(OpElimination, Elimination);
66 };
67 
68 // Edge Elimination
69 struct EdgeElimination : public Elimination {
EdgeEliminationEdgeElimination70   EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds)
71       : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {}
72   ~EdgeElimination() override = default;
73 
74   std::vector<EdgePtr> edges_;
75   MS_DECLARE_PARENT(EdgeElimination, Elimination);
76 };
77 
78 // Merge Elimination
79 struct MergeElimination : public Elimination {
MergeEliminationMergeElimination80   MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info)
81       : Elimination(nullptr, Elimination::EliminationType::MERGE),
82         merged_node_(std::move(u_info)),
83         dir_edge_(std::move(merged_target_edge)),
84         target_node_(std::move(v_info)) {}
85   ~MergeElimination() override = default;
86 
87   OperatorInfoPtr merged_node_;
88   EdgePtr dir_edge_;
89   OperatorInfoPtr target_node_;
90   MS_DECLARE_PARENT(MergeElimination, Elimination);
91 };
92 
93 // Contract Elimination
94 struct ContractElimination : public Elimination {
ContractEliminationContractElimination95   ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info)
96       : Elimination(nullptr, Elimination::EliminationType::CONTRACT),
97         contracted_node_(std::move(con_info)),
98         dir_edge_(std::move(tar_con_edge)),
99         target_node_(std::move(tar_info)) {}
100   ~ContractElimination() override = default;
101 
102   OperatorInfoPtr contracted_node_;
103   EdgePtr dir_edge_;
104   OperatorInfoPtr target_node_;
105   MS_DECLARE_PARENT(ContractElimination, Elimination);
106 };
107 
108 // Source Elimination
109 struct SourceElimination : public Elimination {
SourceEliminationSourceElimination110   SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges,
111                     OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges)
112       : Elimination(nullptr, Elimination::EliminationType::SOURCE),
113         primary_source_(std::move(p_source)),
114         primary_succ_edges_(std::move(p_succ_edges)),
115         primary_new_succ_edges_(std::move(p_new_succ_edges)),
116         secondary_source_(std::move(s_source)),
117         secondary_succ_edges_(std::move(s_succ_edges)),
118         secondary_new_succ_edges_(std::move(s_new_succ_edges)) {}
119   ~SourceElimination() override = default;
120 
121   OperatorInfoPtr primary_source_;
122   std::vector<EdgePtr> primary_succ_edges_;
123   std::vector<EdgePtr> primary_new_succ_edges_;
124   OperatorInfoPtr secondary_source_;
125   std::vector<EdgePtr> secondary_succ_edges_;
126   std::vector<EdgePtr> secondary_new_succ_edges_;
127   MS_DECLARE_PARENT(SourceElimination, Elimination);
128 };
129 
130 // Triangle Elimination
131 struct TriangleElimination : public Elimination {
TriangleEliminationTriangleElimination132   TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge,
133                       OperatorInfoPtr r_node)
134       : Elimination(nullptr, Elimination::EliminationType::TRIANGLE),
135         eliminated_node_(std::move(elim_node)),
136         left_edge_(std::move(l_edge)),
137         left_node_(std::move(l_node)),
138         right_edge_(std::move(r_edge)),
139         right_node_(std::move(r_node)) {}
140   ~TriangleElimination() override = default;
141 
142   OperatorInfoPtr eliminated_node_;
143   EdgePtr left_edge_;
144   OperatorInfoPtr left_node_;
145   EdgePtr right_edge_;
146   OperatorInfoPtr right_node_;
147   MS_DECLARE_PARENT(TriangleElimination, Elimination);
148 };
149 
150 // Star Elimination
151 struct StarElimination : public Elimination {
StarEliminationStarElimination152   StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops)
153       : Elimination(nullptr, Elimination::EliminationType::STAR),
154         eliminated_node_(std::move(elimi_node)),
155         succ_edges_(std::move(s_edges)),
156         succ_ops_(std::move(s_ops)) {}
157   ~StarElimination() override = default;
158 
159   OperatorInfoPtr eliminated_node_;
160   std::vector<EdgePtr> succ_edges_;
161   std::vector<OperatorInfoPtr> succ_ops_;
162   MS_DECLARE_PARENT(StarElimination, Elimination);
163 };
164 
165 using EliminationPtr = std::shared_ptr<Elimination>;
166 using OpEliminationPtr = std::shared_ptr<OpElimination>;
167 using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>;
168 using MergeEliminationPtr = std::shared_ptr<MergeElimination>;
169 using ContractEliminationPtr = std::shared_ptr<ContractElimination>;
170 using SourceEliminationPtr = std::shared_ptr<SourceElimination>;
171 using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>;
172 using StarEliminationPtr = std::shared_ptr<StarElimination>;
173 
174 // Phase 1 and Phase 2
175 Status GetStrategy(const CostGraphPtr &graph);
176 
177 // Phase 3
178 Status RecoverStrategy(std::vector<EliminationPtr> eliminations);
179 }  // namespace parallel
180 }  // namespace mindspore
181 
182 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_
183