1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 #include "frontend/parallel/auto_parallel/edge_costmodel.h" 24 #include "frontend/parallel/auto_parallel/graph_costmodel.h" 25 #include "ir/value.h" 26 27 namespace mindspore { 28 namespace parallel { 29 // There are 3 meta phases of the Dynamic Programming (DP) algorithm. The input is a CostGraph, and the goal 30 // is to compute the strategy for each operator in the CostGraph. 31 // 32 // Phase 1: Shrink the CostGraph using 6 operations, and record them in the order 33 // Using for operations: Operator Elimination, Edge Elimination, Merge Elimination, and Contract Elimination, 34 // each connected component in the CostGraph can be shrunk in to the final graph: u --> v. See the 35 // interpretation of 6 operations in costmodel.h. 36 // Phase 2: Search the cost_list in the final graph, and determine the optimal one 37 // Create the cost_list for the final graph, and choose the optimal one: one the minimum quantity 38 // COST_MODEL_ALPHA * computation_cost + COST_MODEL_BETA * communication_cost 39 // Phase 3: Recover the original CostGraph, the determine strategy for each operator 40 // After determining the optimal cost for the final graph, the algorithm recovers the original graph by applying 41 // the 4 operations in the reverse order in the Phase 1. Because each operation decision contains the strategy, 42 // the operators' strategies can be all determined. 43 44 struct Elimination : public Base { 45 enum EliminationType { OPERA, EDGE, MERGE, CONTRACT, SOURCE, TRIANGLE, STAR }; EliminationElimination46 Elimination(EdgePtr n_edge, EliminationType ty) : new_edge_(std::move(n_edge)), type_(ty) {} 47 48 EdgePtr new_edge_; 49 EliminationType type_; 50 }; 51 52 // Operator Elimination 53 struct OpElimination : public Elimination { OpEliminationOpElimination54 OpElimination(EdgePtr n_edge, EdgePtr l_edge, OperatorInfoPtr op_info, EdgePtr r_edge) 55 : Elimination(std::move(n_edge), Elimination::EliminationType::OPERA), 56 left_edge_(std::move(l_edge)), 57 op_(std::move(op_info)), 58 right_edge_(std::move(r_edge)) {} 59 60 EdgePtr left_edge_; 61 OperatorInfoPtr op_; 62 EdgePtr right_edge_; 63 MS_DECLARE_PARENT(OpElimination, Elimination); 64 }; 65 66 // Edge Elimination 67 struct EdgeElimination : public Elimination { EdgeEliminationEdgeElimination68 EdgeElimination(const EdgePtr &n_edge, std::vector<EdgePtr> eds) 69 : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} 70 71 std::vector<EdgePtr> edges_; 72 MS_DECLARE_PARENT(EdgeElimination, Elimination); 73 }; 74 75 // Merge Elimination 76 struct MergeElimination : public Elimination { MergeEliminationMergeElimination77 MergeElimination(OperatorInfoPtr u_info, EdgePtr merged_target_edge, OperatorInfoPtr v_info) 78 : Elimination(nullptr, Elimination::EliminationType::MERGE), 79 merged_node_(std::move(u_info)), 80 dir_edge_(std::move(merged_target_edge)), 81 target_node_(std::move(v_info)) {} 82 83 OperatorInfoPtr merged_node_; 84 EdgePtr dir_edge_; 85 OperatorInfoPtr target_node_; 86 MS_DECLARE_PARENT(MergeElimination, Elimination); 87 }; 88 89 // Contract Elimination 90 struct ContractElimination : public Elimination { ContractEliminationContractElimination91 ContractElimination(OperatorInfoPtr tar_info, EdgePtr tar_con_edge, OperatorInfoPtr con_info) 92 : Elimination(nullptr, Elimination::EliminationType::CONTRACT), 93 contracted_node_(std::move(con_info)), 94 dir_edge_(std::move(tar_con_edge)), 95 target_node_(std::move(tar_info)) {} 96 97 OperatorInfoPtr contracted_node_; 98 EdgePtr dir_edge_; 99 OperatorInfoPtr target_node_; 100 MS_DECLARE_PARENT(ContractElimination, Elimination); 101 }; 102 103 // Source Elimination 104 struct SourceElimination : public Elimination { SourceEliminationSourceElimination105 SourceElimination(OperatorInfoPtr p_source, std::vector<EdgePtr> p_succ_edges, std::vector<EdgePtr> p_new_succ_edges, 106 OperatorInfoPtr s_source, std::vector<EdgePtr> s_succ_edges, std::vector<EdgePtr> s_new_succ_edges) 107 : Elimination(nullptr, Elimination::EliminationType::SOURCE), 108 primary_source_(std::move(p_source)), 109 primary_succ_edges_(std::move(p_succ_edges)), 110 primary_new_succ_edges_(std::move(p_new_succ_edges)), 111 secondary_source_(std::move(s_source)), 112 secondary_succ_edges_(std::move(s_succ_edges)), 113 secondary_new_succ_edges_(std::move(s_new_succ_edges)) {} 114 OperatorInfoPtr primary_source_; 115 std::vector<EdgePtr> primary_succ_edges_; 116 std::vector<EdgePtr> primary_new_succ_edges_; 117 OperatorInfoPtr secondary_source_; 118 std::vector<EdgePtr> secondary_succ_edges_; 119 std::vector<EdgePtr> secondary_new_succ_edges_; 120 MS_DECLARE_PARENT(SourceElimination, Elimination); 121 }; 122 123 // Triangle Elimination 124 struct TriangleElimination : public Elimination { TriangleEliminationTriangleElimination125 TriangleElimination(OperatorInfoPtr elim_node, EdgePtr l_edge, OperatorInfoPtr l_node, EdgePtr r_edge, 126 OperatorInfoPtr r_node) 127 : Elimination(nullptr, Elimination::EliminationType::TRIANGLE), 128 eliminated_node_(std::move(elim_node)), 129 left_edge_(std::move(l_edge)), 130 left_node_(std::move(l_node)), 131 right_edge_(std::move(r_edge)), 132 right_node_(std::move(r_node)) {} 133 134 OperatorInfoPtr eliminated_node_; 135 EdgePtr left_edge_; 136 OperatorInfoPtr left_node_; 137 EdgePtr right_edge_; 138 OperatorInfoPtr right_node_; 139 MS_DECLARE_PARENT(TriangleElimination, Elimination); 140 }; 141 142 // Star Elimination 143 struct StarElimination : public Elimination { StarEliminationStarElimination144 StarElimination(OperatorInfoPtr elimi_node, std::vector<EdgePtr> s_edges, std::vector<OperatorInfoPtr> s_ops) 145 : Elimination(nullptr, Elimination::EliminationType::STAR), 146 eliminated_node_(std::move(elimi_node)), 147 succ_edges_(std::move(s_edges)), 148 succ_ops_(std::move(s_ops)) {} 149 150 OperatorInfoPtr eliminated_node_; 151 std::vector<EdgePtr> succ_edges_; 152 std::vector<OperatorInfoPtr> succ_ops_; 153 MS_DECLARE_PARENT(StarElimination, Elimination); 154 }; 155 156 using EliminationPtr = std::shared_ptr<Elimination>; 157 using OpEliminationPtr = std::shared_ptr<OpElimination>; 158 using EdgeEliminationPtr = std::shared_ptr<EdgeElimination>; 159 using MergeEliminationPtr = std::shared_ptr<MergeElimination>; 160 using ContractEliminationPtr = std::shared_ptr<ContractElimination>; 161 using SourceEliminationPtr = std::shared_ptr<SourceElimination>; 162 using TriangleEliminationPtr = std::shared_ptr<TriangleElimination>; 163 using StarEliminationPtr = std::shared_ptr<StarElimination>; 164 165 // Phase 1 and Phase 2 166 Status GetStrategy(const CostGraphPtr &graph); 167 168 // Phase 3 169 Status RecoverStrategy(std::vector<EliminationPtr> eliminations); 170 } // namespace parallel 171 } // namespace mindspore 172 173 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_DP_ALGO_COSTMODEL_H_ 174