1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include "frontend/parallel/auto_parallel/edge_costmodel.h" 26 #include "frontend/parallel/costmodel_context.h" 27 #include "frontend/parallel/ops_info/operator_info.h" 28 #include "frontend/parallel/ops_info/tmp_identity_info.h" 29 #include "utils/ms_utils.h" 30 31 namespace mindspore { 32 namespace parallel { 33 class CostGraph; 34 using CostGraphPtr = std::shared_ptr<CostGraph>; 35 extern CostGraphPtr entire_costgraph; 36 extern size_t TOTAL_OPS; 37 38 class CostGraph { 39 // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have 40 // output-input dependency relationship. 41 public: CostGraph()42 CostGraph() {} 43 ~CostGraph() = default; 44 void Init(); AddOperator(const OperatorInfoPtr & op)45 void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } FindOperatorByIndex(size_t index)46 OperatorInfoPtr FindOperatorByIndex(size_t index) { 47 if (index >= ops_.size()) { 48 MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; 49 return nullptr; 50 } 51 return ops_[index]; 52 } 53 void RemoveOperator(const OperatorInfoPtr &op); 54 bool IsOperatorInCostGraph(const OperatorInfoPtr &op); 55 void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &); 56 void BFS(const OperatorInfoPtr &, const StrategyPtr &, std::map<OperatorInfoPtr, StrategyPtr>, 57 std::map<OperatorInfoPtr, bool> *); 58 // the edge is in the form: u --> v 59 void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge); GetOriginalPrevEdges(OperatorInfoPtr v_node)60 std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; } GetOriginalNextEdges(OperatorInfoPtr u_node)61 std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; } 62 // An edge is uniquely identified by its name, and its output index and input index. 63 bool IsEdgeInCostGraph(const std::string &, size_t, size_t); 64 65 std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>); 66 void DFS(const OperatorInfoPtr ¤t_op, std::map<OperatorInfoPtr, bool> *visited, 67 const std::shared_ptr<CostGraph> &component); 68 69 CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); 70 CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); 71 CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory); 72 CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); 73 CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist, 74 double memory) const; 75 Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &); 76 Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &); GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node,OperatorInfoPtr v_node)77 std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { 78 return edges_[{u_node, v_node}]; 79 } 80 81 // Search the cost_list in the final graph, and determine the optimal one 82 Status SearchStrategy(); 83 84 // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated 85 OperatorInfoPtr CheckOpElimination() const; 86 // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges 87 // can be eliminated into one 88 std::vector<EdgePtr> CheckEdgeElimination() const; 89 // Given a graph which contains the following subgraph: 90 // u 91 // | 92 // w --- v --- x 93 // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v. 94 // u is returned. 95 OperatorInfoPtr CheckMergeElimination() const; 96 // Given a graph which contains the following subgraph: 97 // u 98 // | 99 // v --- x 100 // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted 101 // into v. u is returned. 102 OperatorInfoPtr CheckContractElimination() const; 103 /* Given a graph which contains the following subgraph: 104 * u 105 * / \ 106 * / \ 107 * v --- w 108 * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v. 109 * The returned value includes u and the edge <u, <v, w>>. 110 */ 111 std::pair<OperatorInfoPtr, EdgePtr> CheckTriangleElimination() const; 112 /* Given a graph which contains the following subgraph: 113 * v <--- u ---> w 114 * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections, 115 * resulting in v and w can not be performed ContractElimination. u is returned. 116 * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 117 */ 118 OperatorInfoPtr CheckStarElimination() const; 119 // Applying Operator Elimination in DP algorithm 120 EdgePtr EliminationOp(const OperatorInfoPtr &op); 121 // Applying Edge Elimination in DP algorithm 122 EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges); 123 // Applying Merge Elimination in DP algorithm 124 OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); 125 void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, 126 const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, 127 const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); 128 // Applying Contract Elimination in DP algorithm 129 OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); 130 void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, 131 const CostPtrList &, CostPtrList *); 132 133 // Applying Triangle Elimination in DP algorithm. return the left_node 134 OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); 135 void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, 136 const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, 137 const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); 138 // Given the relevant costlist, create the TriangleElimination cost 139 void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, 140 const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); 141 142 // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op 143 // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. 144 std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op); 145 void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &, 146 const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); 147 void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, 148 const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>, 149 CostPtrList &, CostPtrList &, CostPtrList *); 150 // Return <op1, op2>. we merge 'op2' into 'op1' 151 std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const; 152 void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &, 153 CostPtrList *); 154 // We merge 'op2' into op1. The returned value are '<Edges1, Edges2>'. 'Edges1' are newly updated edges for 'op1', 155 // 'Edges2' are newly updated edges for 'op2'. 156 std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources( 157 const OperatorInfoPtr op1, const OperatorInfoPtr op2); 158 // Calculate memory cost for training phase or inference phase. 159 Status CalculateMemoryCost(); 160 // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then 161 // the memory cost can be resused. This is used to calculate memory in the training phase. 162 Status CalculateOpsMemoryCost(); 163 // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then 164 // the memory cost can be reused. This is used to calculate memory in the training phase. 165 Status CalculateEdgesMemoryCost(); 166 // Calculate memory cost of operators in the inference phase. 167 Status CalculateOpsMemoryCostForInference(); 168 // Calculate memory cost of edges in the inference phase. 169 Status CalculateEdgesMemoryCostForInference(); 170 Status ComputeOpsAndEdgesParameterInvolved(); 171 // Compute for each operator whether the output is critical. 172 Status ComputeOpsAndEdgesOutputCritical(); 173 GetOperators()174 std::vector<OperatorInfoPtr> GetOperators() const { return ops_; } 175 size_t GetNumEdges() const; 176 Status InitReshapeStrategy(); 177 Status InitSelectedStrategy(); 178 OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; 179 // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only 180 // once (instead of multiple times), this method is used to correct this. 181 Status CorrectOpsMemoryCost(); 182 // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies. 183 // This method is to re-init those edge involved operators. 184 void CheckApproximateCostGraphEdges(); 185 // Needed by rec_parser add_inputs_tensor_name(const std::vector<std::string> & inputs_tensor_name)186 void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) { 187 inputs_tensor_name_list_.push_back(inputs_tensor_name); 188 } get_inputs_tensor_name_list()189 const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> & inputs_tensor_name_list)190 void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) { 191 inputs_tensor_name_list_ = inputs_tensor_name_list; 192 } add_tuple_getitem(const std::pair<std::string,std::string> & tuple_getitem)193 void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) { 194 auto ret = tuple_getitem_list_.insert(tuple_getitem); 195 if (ret.second == false) { 196 MS_LOG(EXCEPTION) << "The insert item is already exist."; 197 } 198 } get_tuple_getitem_list()199 const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; } 200 201 private: 202 void TopologyOrder(std::vector<OperatorInfoPtr> *); 203 void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *); 204 Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &); 205 void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &); 206 // Needed by rec_parser 207 std::vector<std::vector<std::string>> inputs_tensor_name_list_; 208 std::map<std::string, std::string> tuple_getitem_list_; 209 std::vector<OperatorInfoPtr> ops_; 210 std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_; 211 std::vector<std::shared_ptr<CostGraph>> connected_compoents_; 212 std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_; 213 std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_; 214 }; 215 } // namespace parallel 216 } // namespace mindspore 217 218 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_ 219