1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 #include "frontend/parallel/auto_parallel/edge_costmodel.h" 24 #include "frontend/parallel/auto_parallel/graph_costmodel.h" 25 #include "ir/value.h" 26 27 namespace mindspore { 28 namespace parallel { 29 // There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal 30 // is to compute the strategy for each operator in the CostGraph. 31 // 32 // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order 33 // Using for operations: Operator Elimination, Edge Elimination, Merge Elimination, and Contract Elimination, 34 // each connected component in the CostGraph can be shrunk in to the final graph: u --> v. See the 35 // interpretation of 6 operations in costmodel.h. 36 // Phase 2: Search the cost_list in the final graph, and determine the optimal one 37 // Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity 38 // COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost 39 // Phase 3: Recover the original CostGraph, the determine strategy for each operator 40 // After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying 41 // the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy, 42 // the operators' strategies can be all determined. 43 44 struct Elimination : public Base { 45 enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR }; EliminationElimination46 Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} 47 ~Elimination() override = default; 48 49 EdgePtr new_edge_; 50 EliminationType type_; 51 }; 52 53 // Operator Elimination 54 struct OpElimination : public Elimination { OpEliminationOpElimination55 OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge) 56 : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA), 57 left_edge_(std::move(l_edge)), 58 op_(std::move(op_info)), 59 right_edge_(std::move(r_edge)) {} 60 ~OpElimination() override = default; 61 62 EdgePtr left_edge_; 63 OperatorInfoPtr op_; 64 EdgePtr right_edge_; 65 MS_DECLARE_PARENT(OpElimination, Elimination); 66 }; 67 68 // Edge Elimination 69 struct EdgeElimination : public Elimination { EdgeEliminationEdgeElimination70 EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds) 71 : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} 72 ~EdgeElimination() override = default; 73 74 std::vector<EdgePtr> edges_; 75 MS_DECLARE_PARENT(EdgeElimination, Elimination); 76 }; 77 78 // Merge Elimination 79 struct MergeElimination : public Elimination { MergeEliminationMergeElimination80 MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info) 81 : Elimination(nullptr, Elimination::EliminationType::MERGE), 82 merged_node_(std::move(u_info)), 83 dir_edge_(std::move(merged_target_edge)), 84 target_node_(std::move(v_info)) {} 85 ~MergeElimination() override = default; 86 87 OperatorInfoPtr merged_node_; 88 EdgePtr dir_edge_; 89 OperatorInfoPtr target_node_; 90 MS_DECLARE_PARENT(MergeElimination, Elimination); 91 }; 92 93 // Contract Elimination 94 struct ContractElimination : public Elimination { ContractEliminationContractElimination95 ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info) 96 : Elimination(nullptr, Elimination::EliminationType::CONTRACT), 97 contracted_node_(std::move(con_info)), 98 dir_edge_(std::move(tar_con_edge)), 99 target_node_(std::move(tar_info)) {} 100 ~ContractElimination() override = default; 101 102 OperatorInfoPtr contracted_node_; 103 EdgePtr dir_edge_; 104 OperatorInfoPtr target_node_; 105 MS_DECLARE_PARENT(ContractElimination, Elimination); 106 }; 107 108 // Source Elimination 109 struct SourceElimination : public Elimination { SourceEliminationSourceElimination110 SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges, 111 OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges) 112 : Elimination(nullptr, Elimination::EliminationType::SOURCE), 113 primary_source_(std::move(p_source)), 114 primary_succ_edges_(std::move(p_succ_edges)), 115 primary_new_succ_edges_(std::move(p_new_succ_edges)), 116 secondary_source_(std::move(s_source)), 117 secondary_succ_edges_(std::move(s_succ_edges)), 118 secondary_new_succ_edges_(std::move(s_new_succ_edges)) {} 119 ~SourceElimination() override = default; 120 121 OperatorInfoPtr primary_source_; 122 std::vector<EdgePtr> primary_succ_edges_; 123 std::vector<EdgePtr> primary_new_succ_edges_; 124 OperatorInfoPtr secondary_source_; 125 std::vector<EdgePtr> secondary_succ_edges_; 126 std::vector<EdgePtr> secondary_new_succ_edges_; 127 MS_DECLARE_PARENT(SourceElimination, Elimination); 128 }; 129 130 // Triangle Elimination 131 struct TriangleElimination : public Elimination { TriangleEliminationTriangleElimination132 TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, 133 OperatorInfoPtr r_node) 134 : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), 135 eliminated_node_(std::move(elim_node)), 136 left_edge_(std::move(l_edge)), 137 left_node_(std::move(l_node)), 138 right_edge_(std::move(r_edge)), 139 right_node_(std::move(r_node)) {} 140 ~TriangleElimination() override = default; 141 142 OperatorInfoPtr eliminated_node_; 143 EdgePtr left_edge_; 144 OperatorInfoPtr left_node_; 145 EdgePtr right_edge_; 146 OperatorInfoPtr right_node_; 147 MS_DECLARE_PARENT(TriangleElimination, Elimination); 148 }; 149 150 // Star Elimination 151 struct StarElimination : public Elimination { StarEliminationStarElimination152 StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops) 153 : Elimination(nullptr, Elimination::EliminationType::STAR), 154 eliminated_node_(std::move(elimi_node)), 155 succ_edges_(std::move(s_edges)), 156 succ_ops_(std::move(s_ops)) {} 157 ~StarElimination() override = default; 158 159 OperatorInfoPtr eliminated_node_; 160 std::vector<EdgePtr> succ_edges_; 161 std::vector<OperatorInfoPtr> succ_ops_; 162 MS_DECLARE_PARENT(StarElimination, Elimination); 163 }; 164 165 using EliminationPtr = std::shared_ptr<Elimination>; 166 using OpEliminationPtr = std::shared_ptr<OpElimination>; 167 using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>; 168 using MergeEliminationPtr = std::shared_ptr<MergeElimination>; 169 using ContractEliminationPtr = std::shared_ptr<ContractElimination>; 170 using SourceEliminationPtr = std::shared_ptr<SourceElimination>; 171 using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>; 172 using StarEliminationPtr = std::shared_ptr<StarElimination>; 173 174 // Phase 1 and Phase 2 175 Status GetStrategy(const CostGraphPtr &graph); 176 177 // Phase 3 178 Status RecoverStrategy(std::vector<EliminationPtr> eliminations); 179 } // namespace parallel 180 } // namespace mindspore 181 182 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 183