1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ 19 20 #include <algorithm> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include "frontend/parallel/strategy.h" 26 #include "frontend/parallel/tensor_layout/tensor_info.h" 27 #include "frontend/parallel/costmodel_context.h" 28 29 namespace mindspore { 30 namespace parallel { 31 struct Decision; 32 using OperatorName = std::string; 33 using Attr = std::pair<std::string, ValuePtr>; 34 using Param = std::pair<std::pair<std::string, ValuePtr>, int64_t>; 35 using OperatorParams = std::vector<Param>; 36 using OperatorAttrs = std::vector<Attr>; 37 // OutPutInfo.fist: true if the operator's output is a tuple 38 // OutPutInfo.second: elements number of the tuple output. Only meaningful if OutPutInfo.fist is true. 39 using OutPutInfo = std::pair<bool, uint64_t>; 40 using OutPutInfoVector = std::vector<OutPutInfo>; 41 using OperatorArgs = std::pair<OperatorAttrs, OperatorParams>; 42 using Operator = std::pair<OperatorName, OperatorArgs>; 43 using OperatorVector = std::vector<Operator>; 44 using RedistributionOpListPtr = std::shared_ptr<std::pair<OperatorVector, OutPutInfoVector>>; 45 46 struct Cost { 47 Cost(); 48 Cost(double computation, double communication, const std::shared_ptr<Decision> &decision_ = nullptr) computation_cost_Cost49 : computation_cost_(computation), 50 communication_cost_(communication), 51 communication_without_parameter_(0.0), 52 communication_with_partial_para_(0.0), 53 communication_forward_(0.0), 54 communication_redis_forward_(0.0), 55 communication_redis_backward_(0.0), 56 memory_with_reuse_(0.0), 57 decision_ptr_(decision_) {} 58 // 'computation_cost_' models the training time of an iteration in a training phase. Currently, this is calculated 59 // by ONLY forward phase 60 double computation_cost_; 61 // 'communication_cost_' includes communications from operators (forward and backward) and edges (redistribution) 62 double communication_cost_; 63 // communication_without_parameter_ = communication_cost_ - (backward communication from operators) 64 double communication_without_parameter_; 65 // communication_with_partial_para_ = 66 // communication_without_parameter_ + COST_MODEL_GAMMA * (communication_cost_ - communication_without_parameter_ ) 67 double communication_with_partial_para_; 68 // communication_forward_ = communication cost from operators (only forward phase) and forward redistribution. 69 double communication_forward_; 70 double communication_redis_forward_; 71 double communication_redis_backward_; 72 // 'memory_with_reuse_' calculates the peak memory usage in a training (or inference) phase 73 double memory_with_reuse_; 74 std::shared_ptr<Decision> decision_ptr_; 75 }; 76 77 using CostPtr = std::shared_ptr<Cost>; 78 using CostPtrList = std::vector<std::shared_ptr<Cost>>; 79 80 class StrategyWithCost { 81 public: StrategyWithCost(StrategyPtr strategy,std::vector<TensorInfo> inputs_,std::vector<TensorInfo> outputs_)82 StrategyWithCost(StrategyPtr strategy, std::vector<TensorInfo> inputs_, std::vector<TensorInfo> outputs_) 83 : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} StrategyWithCost(StrategyPtr strategy,CostPtrList c_list)84 StrategyWithCost(StrategyPtr strategy, CostPtrList c_list) 85 : strategy_ptr(std::move(strategy)), cost_list(std::move(c_list)) {} 86 87 StrategyWithCost(const StrategyWithCost &swc) = delete; 88 StrategyWithCost &operator=(const StrategyWithCost &swc) = delete; StrategyWithCost(StrategyWithCost && swc)89 StrategyWithCost(StrategyWithCost &&swc) 90 : strategy_ptr(swc.strategy_ptr), 91 inputs_ptr(swc.inputs_ptr), 92 outputs_ptr(swc.outputs_ptr), 93 cost_list(swc.cost_list) {} 94 StrategyWithCost &operator=(StrategyWithCost &&swc) { 95 if (&swc != this) { 96 strategy_ptr = swc.strategy_ptr; 97 inputs_ptr = swc.inputs_ptr; 98 outputs_ptr = swc.outputs_ptr; 99 cost_list = swc.cost_list; 100 } 101 return *this; 102 } 103 ~StrategyWithCost() = default; 104 105 StrategyPtr strategy_ptr; 106 std::vector<TensorInfo> inputs_ptr; 107 std::vector<TensorInfo> outputs_ptr; 108 CostPtrList cost_list; 109 }; 110 111 enum DecisionType { 112 OP_ELIMINATION, 113 EDGE_ELIMINATION, 114 MERGE_ELIMINATION, 115 CONTRACT_ELIMINATION, 116 SOURCE_ELIMINATION, 117 TRIANGLE_ELIMINATION, 118 STAR_ELIMINATION, 119 FINAL_TYPE, 120 FINAL_SINGLE 121 }; 122 123 struct Decision : public Base { 124 ~Decision() override = default; 125 DecisionType type_; 126 }; 127 128 // 'OpEliminationDecision' is for the Operator Elimination in DP algorithm: u --> v --> w ==> u --> w. 129 // This data structure records the strategy 'op_strategy_' for v, the edge cost 'left_cost_' for 'u --> v', the 130 // operator cost 'middle_cost_' for v, and the edge cost 'right_cost_' for 'v --> w' 131 struct OpEliminationDecision : public Decision { OpEliminationDecisionOpEliminationDecision132 OpEliminationDecision(StrategyPtr op_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) 133 : op_strategy_(std::move(op_stra)), 134 left_cost_(std::move(l_cost)), 135 middle_cost_(std::move(m_cost)), 136 right_cost_(std::move(r_cost)) { 137 type_ = DecisionType::OP_ELIMINATION; 138 } 139 ~OpEliminationDecision() override = default; 140 141 StrategyPtr op_strategy_; 142 CostPtr left_cost_; 143 CostPtr middle_cost_; 144 CostPtr right_cost_; 145 MS_DECLARE_PARENT(OpEliminationDecision, Decision); 146 }; 147 148 /* 'EdgeEliminationDecision' is for the Edge Elimination in DP algorithm: 149 ____ 150 / \ 151 u v ==> u --> v, which replace the multi-edges by a single edge. 152 \____/ 153 This data structure records the cost list for all edges 'edges_cost_list_' 154 */ 155 struct EdgeEliminationDecision : public Decision { EdgeEliminationDecisionEdgeEliminationDecision156 explicit EdgeEliminationDecision(CostPtrList cost_list) : edges_cost_list_(std::move(cost_list)) { 157 type_ = DecisionType::EDGE_ELIMINATION; 158 } 159 ~EdgeEliminationDecision() override = default; 160 161 CostPtrList edges_cost_list_; 162 MS_DECLARE_PARENT(EdgeEliminationDecision, Decision); 163 }; 164 165 // 'MergeEliminationDecision' is for the Merge Elimination in DP algorithm: 166 // w 167 // | 168 // | ==> u --> v 169 // u --> v In the original graph, v has two alive incoming edges, w has one alive outgoing edge, 170 // and w has zero alive incoming edges. After the Merge Elimination, the result graph contains only 'u -- >v'. 171 // This data structure records the strategy 'merged_op_strategy_' for operator 'w', 172 // the cost 'merged_op_cost_' for operator 'w', and the edge cost 'edge_cost_' for 'w --> v'. 173 struct MergeEliminationDecision : public Decision { MergeEliminationDecisionMergeEliminationDecision174 MergeEliminationDecision(StrategyPtr op_stra, CostPtr op_cost, CostPtr edge_c, StrategyPtr tar_op_stra, 175 CostPtr target_op_c) 176 : merged_op_strategy_(std::move(op_stra)), 177 merged_op_cost_(std::move(op_cost)), 178 edge_cost_(std::move(edge_c)), 179 target_op_strategy_(std::move(tar_op_stra)), 180 target_op_cost_(std::move(target_op_c)) { 181 type_ = DecisionType::MERGE_ELIMINATION; 182 } 183 ~MergeEliminationDecision() override = default; 184 185 StrategyPtr merged_op_strategy_; 186 CostPtr merged_op_cost_; 187 CostPtr edge_cost_; 188 StrategyPtr target_op_strategy_; 189 CostPtr target_op_cost_; 190 MS_DECLARE_PARENT(MergeEliminationDecision, Decision); 191 }; 192 193 // 'ContractEliminationDecision' is for the Contract Elimination in DP algorithm: 194 // u --> v 195 // | 196 // | ==> u --> w 197 // w In the original graph, u has two alive outgoing edges, v has one alive incoming edge, 198 // and v has zero outgoing edge. After the Contract Elimination, the result graph contains only 'u --> w'. 199 // This data structure records the strategy 'contracted_op_strategy_' for operator 'v', the cost for 200 // operator 'contracted_op_cost_', and the edge cost for 'edge_cost_'. 201 struct ContractEliminationDecision : public Decision { ContractEliminationDecisionContractEliminationDecision202 ContractEliminationDecision(StrategyPtr contra_stra, CostPtr contra_op_cost, CostPtr edge_cost, 203 StrategyPtr target_stra, CostPtr tar_cost) 204 : contracted_op_strategy_(std::move(contra_stra)), 205 contracted_op_cost_(std::move(contra_op_cost)), 206 edge_cost_(std::move(edge_cost)), 207 target_op_strategy_(std::move(target_stra)), 208 target_cost_(std::move(tar_cost)) { 209 type_ = DecisionType::CONTRACT_ELIMINATION; 210 } 211 ~ContractEliminationDecision() override = default; 212 213 StrategyPtr contracted_op_strategy_; 214 CostPtr contracted_op_cost_; 215 CostPtr edge_cost_; 216 StrategyPtr target_op_strategy_; 217 CostPtr target_cost_; 218 MS_DECLARE_PARENT(ContractEliminationDecision, Decision); 219 }; 220 221 /* 'SourceEliminationDecision' is for the source Elimination in DP algorithm: 222 * 1 1,5 223 * / \ // \\ 224 * / \ // \\ 225 * / \ // \\ 226 * / \ // \\ 227 * 2 <- 5 -> 3 ==> 2 3 228 * \ / \ / 229 * \ / \ / 230 * \ / \ / 231 * 4 4 232 * 233 * In the original graph, '1' has two alive outgoing edges and no incoming edges. '5' has two alive outgoing edges and 234 * no incoming edges. '4' has two alive incoming edges and no outgoing edges. Source Elimination will merge '5' into 235 * '1' new edges are generated to replace the old ones incident to '1' and '5'. 236 * 237 */ 238 struct SourceEliminationDecision : public Decision { SourceEliminationDecisionSourceEliminationDecision239 SourceEliminationDecision(StrategyPtr op1_stra, CostPtr op1_c, StrategyPtr op2_stra, CostPtr op2_c) 240 : op1_strategy_(std::move(op1_stra)), 241 op1_cost_(std::move(op1_c)), 242 op2_strategy_(std::move(op2_stra)), 243 op2_cost_(std::move(op2_c)) { 244 type_ = DecisionType::SOURCE_ELIMINATION; 245 } 246 ~SourceEliminationDecision() override = default; 247 248 StrategyPtr op1_strategy_; 249 CostPtr op1_cost_; 250 StrategyPtr op2_strategy_; 251 CostPtr op2_cost_; 252 MS_DECLARE_PARENT(SourceEliminationDecision, Decision); 253 }; 254 255 /* 'TriangleEliminationDecision' is for the Triangle Elimination in DP algorithm: 256 * 257 * u 258 * / \ 259 * / \ 260 * v --- w ==> v --- w In the original graph, u has 2 outgoing edges, v has 1 outgoing edge, 261 * and w has 2 incoming edges, u can be eliminated into v. 262 * 'eliminated_op_strategy_' is for u, 'eliminated_op_cost_' is for u, 'eliminated_left_edge_' is for edge u --> v, 263 * 'eliminated_right_edge_' is for edge u --> w. 264 */ 265 struct TriangleEliminationDecision : public Decision { TriangleEliminationDecisionTriangleEliminationDecision266 TriangleEliminationDecision(StrategyPtr elimi_stra, CostPtr elimi_op_cost, CostPtr l_edge_cost, CostPtr r_edge_cost, 267 StrategyPtr left_stra, CostPtr l_node_cost, StrategyPtr right_stra, CostPtr r_node_cost) 268 : eliminated_op_strategy_(std::move(elimi_stra)), 269 eliminated_op_cost_(std::move(elimi_op_cost)), 270 left_edge_cost_(std::move(l_edge_cost)), 271 right_edge_cost_(std::move(r_edge_cost)), 272 left_node_strategy_(std::move(left_stra)), 273 left_node_cost_(std::move(l_node_cost)), 274 right_node_strategy_(std::move(right_stra)), 275 right_node_cost_(std::move(r_node_cost)) { 276 type_ = DecisionType::TRIANGLE_ELIMINATION; 277 } 278 ~TriangleEliminationDecision() override = default; 279 280 StrategyPtr eliminated_op_strategy_; 281 CostPtr eliminated_op_cost_; 282 CostPtr left_edge_cost_; 283 CostPtr right_edge_cost_; 284 StrategyPtr left_node_strategy_; 285 CostPtr left_node_cost_; 286 StrategyPtr right_node_strategy_; 287 CostPtr right_node_cost_; 288 MS_DECLARE_PARENT(TriangleEliminationDecision, Decision); 289 }; 290 291 /* 'StarEliminationDecision' is for the Star Elimination in DP algorithm: 292 * 293 * v <--- u ---> w ==> v w In the original graph, u has 0 incoming edges, and multiple outgoing edges. 294 * In addition, v and w have other complicated connections, resulting in v and w can not be performed other 295 * eliminations. After the StarElimination, u is merged into v, and the resulting graph is split into multiple 296 * connected components. 297 * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 298 */ 299 struct StarEliminationDecision : public Decision { StarEliminationDecisionStarEliminationDecision300 StarEliminationDecision(StrategyPtr elimi_op_stra, CostPtr elimi_op_cost, CostPtrList succ_edges_clist, 301 std::vector<StrategyPtr> succ_ops_stra_list, CostPtrList succ_ops_clist) 302 : eliminated_op_strategy_(std::move(elimi_op_stra)), 303 eliminated_op_cost_(std::move(elimi_op_cost)), 304 succ_edges_cost_list_(std::move(succ_edges_clist)), 305 succ_ops_stra_list_(std::move(succ_ops_stra_list)), 306 succ_ops_cost_list_(std::move(succ_ops_clist)) { 307 type_ = DecisionType::STAR_ELIMINATION; 308 } 309 ~StarEliminationDecision() override = default; 310 311 StrategyPtr eliminated_op_strategy_; 312 CostPtr eliminated_op_cost_; 313 CostPtrList succ_edges_cost_list_; 314 std::vector<StrategyPtr> succ_ops_stra_list_; 315 CostPtrList succ_ops_cost_list_; 316 MS_DECLARE_PARENT(StarEliminationDecision, Decision); 317 }; 318 319 // This data structure records the decision for the graph which contains two nodes: u --> v. This includes 320 // the strategy 'u_strategy_' for 'u', the strategy 'v_strategy_' for 'v', the cost 'left_cost_' for 'u'. 321 struct FinalDecision : public Decision { FinalDecisionFinalDecision322 FinalDecision(StrategyPtr u_stra, StrategyPtr v_stra, CostPtr l_cost, CostPtr m_cost, CostPtr r_cost) 323 : u_strategy_(std::move(u_stra)), 324 v_strategy_(std::move(v_stra)), 325 left_cost_(std::move(l_cost)), 326 middle_cost_(std::move(m_cost)), 327 right_cost_(std::move(r_cost)) { 328 type_ = DecisionType::FINAL_TYPE; 329 } 330 ~FinalDecision() override = default; 331 332 StrategyPtr u_strategy_; 333 StrategyPtr v_strategy_; 334 CostPtr left_cost_; 335 CostPtr middle_cost_; 336 CostPtr right_cost_; 337 MS_DECLARE_PARENT(FinalDecision, Decision); 338 }; 339 340 // This data structure records the final decision for the graph containing a single node: u. This includes 341 // the strategy 'u_strategy_' for 'u', the cost 'u_cost_' for 'u'. 342 struct FinalSingleDecision : public Decision { FinalSingleDecisionFinalSingleDecision343 FinalSingleDecision(StrategyPtr u_stra, CostPtr u_cost) : u_strategy_(std::move(u_stra)), u_cost_(std::move(u_cost)) { 344 type_ = DecisionType::FINAL_SINGLE; 345 } 346 ~FinalSingleDecision() override = default; 347 348 StrategyPtr u_strategy_; 349 CostPtr u_cost_; 350 MS_DECLARE_PARENT(FinalSingleDecision, Decision); 351 }; 352 353 using DecisionPtr = std::shared_ptr<Decision>; 354 using OpEliminationDecisionPtr = std::shared_ptr<OpEliminationDecision>; 355 using EdgeEliminationDecisionPtr = std::shared_ptr<EdgeEliminationDecision>; 356 using MergeEliminationDecisionPtr = std::shared_ptr<MergeEliminationDecision>; 357 using ContractEliminationDecisionPtr = std::shared_ptr<ContractEliminationDecision>; 358 using SourceEliminationDecisionPtr = std::shared_ptr<SourceEliminationDecision>; 359 using TriangleEliminationDecisionPtr = std::shared_ptr<TriangleEliminationDecision>; 360 using StarEliminationDecisionPtr = std::shared_ptr<StarEliminationDecision>; 361 using FinalDecisionPtr = std::shared_ptr<FinalDecision>; 362 using FinalSingleDecisionPtr = std::shared_ptr<FinalSingleDecision>; 363 364 void Simplify(CostPtrList *clist); 365 void SimplifyForDecreasingCommunicationForward(CostPtrList *clist_ptrs); 366 void SimplifyForDecreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs); 367 void RefineForPracticalCost(const CostPtr &, bool is_redistribution); 368 } // namespace parallel 369 } // namespace mindspore 370 371 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_COSTMODEL_H_ 372