• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
26 #include "frontend/parallel/costmodel_context.h"
27 #include "frontend/parallel/ops_info/operator_info.h"
28 #include "frontend/parallel/ops_info/tmp_identity_info.h"
29 #include "utils/ms_utils.h"
30 
31 namespace mindspore {
32 namespace parallel {
// Forward declaration, so that the shared-pointer alias below can be named before the class is defined.
class CostGraph;
// Shared-ownership handle to a CostGraph.
using CostGraphPtr = std::shared_ptr<CostGraph>;
// Declared here only; the definition lives in another translation unit.
// NOTE(review): presumably the cost graph built for the whole network -- confirm against the definition.
extern CostGraphPtr entire_costgraph;
36 
37 class CostGraph {
38   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
39   // output-input dependency relationship.
40  public:
CostGraph()41   CostGraph() {}
42   ~CostGraph() = default;
43   void Init();
AddOperator(const OperatorInfoPtr & op)44   void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
FindOperatorByIndex(size_t index)45   OperatorInfoPtr FindOperatorByIndex(size_t index) {
46     if (index >= ops_.size()) {
47       MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << ".";
48       return nullptr;
49     }
50     return ops_[index];
51   }
52   void RemoveOperator(const OperatorInfoPtr &op);
53   bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
54   void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &);
55   void BFS(const OperatorInfoPtr &op, const StrategyPtr &op_stra,
56            const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops,
57            std::map<OperatorInfoPtr, bool> *visited) const;
58   void ProcessDiffStraParams(const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops);
59   void ParamPropagation(const OperatorInfoPtr &curr_op, const std::shared_ptr<Edge> edge,
60                         const std::map<OperatorInfoPtr, StrategyPtr, OpsPtrCompare> &configured_ops) const;
61   // the edge is in the form: u --> v
62   void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
GetOriginalPrevEdges(const OperatorInfoPtr & v_node)63   std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(const OperatorInfoPtr &v_node) { return in_edges_[v_node]; }
GetOriginalNextEdges(const OperatorInfoPtr & u_node)64   std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(const OperatorInfoPtr &u_node) { return out_edges_[u_node]; }
65   // An edge is uniquely identified by its name, and its output index and input index.
66   bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
67 
68   std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
69   void DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
70            const std::shared_ptr<CostGraph> &component);
71 
72   CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v) const;
73   CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u) const;
74   CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory) const;
75   CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) const;
76   CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist,
77                                                         double memory) const;
78   Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
79   Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node,OperatorInfoPtr v_node)80   std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
81     return edges_[{u_node, v_node}];
82   }
83 
84   // Search the cost_list in the final graph, and determine the optimal one
85   Status SearchStrategy();
86 
87   // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
88   OperatorInfoPtr CheckOpElimination() const;
89   // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges
90   // can be eliminated into one
91   std::vector<EdgePtr> CheckEdgeElimination() const;
92   // Given a graph which contains the following subgraph:
93   //        u
94   //        |
95   //  w --- v --- x
96   // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v.
97   // u is returned.
98   OperatorInfoPtr CheckMergeElimination() const;
99   // Given a graph which contains the following subgraph:
100   //        u
101   //        |
102   //        v --- x
103   // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted
104   // into v. u is returned.
105   OperatorInfoPtr CheckContractElimination() const;
106   /* Given a graph which contains the following subgraph:
107    *       u
108    *      / \
109    *     /   \
110    *    v --- w
111    * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v.
112    * The returned value includes u and the edge <u, <v, w>>.
113    */
114   std::pair<OperatorInfoPtr, EdgePtr> CheckTriangleElimination() const;
115   /* Given a graph which contains the following subgraph:
116    *  v <--- u ---> w
117    * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections,
118    * resulting in v and w can not be performed ContractElimination. u is returned.
119    * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
120    */
121   OperatorInfoPtr CheckStarElimination() const;
122   // Applying Operator Elimination in DP algorithm
123   EdgePtr EliminationOp(const OperatorInfoPtr &op) const;
124   // Applying Edge Elimination in DP algorithm
125   EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges) const;
126   // Applying Merge Elimination in DP algorithm
127   OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op) const;
128   void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
129                                          const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
130                                          const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) const;
131   // Applying Contract Elimination in DP algorithm
132   OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op) const;
133   void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr,
134                                             const CostPtrList &, CostPtrList *) const;
135 
136   // Applying Triangle Elimination in DP algorithm. return the left_node
137   OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right) const;
138   void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &,
139                                          const StrategyPtr &, const StrategyPtr &, const StrategyPtr &,
140                                          const CostPtrList &, const CostPtrList &, const CostPtrList &,
141                                          CostPtrList *) const;
142   // Given the relevant costlist, create the TriangleElimination cost
143   void CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra,
144                                             StrategyPtr right_op_stra, const CostPtr &right_op_cost,
145                                             const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist,
146                                             const CostPtr &right_edge_cost, const CostPtrList &left_node_clist_origin,
147                                             CostPtrList *left_node_clist_new) const;
148 
149   // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
150   // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
151   std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op) const;
152   void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &,
153                                      const CostPtrList &, const StrategyPtr &, const CostPtrList &,
154                                      CostPtrList *) const;
155   void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
156                                         const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
157                                         CostPtrList &, CostPtrList &, CostPtrList *) const;
158   // Return <op1, op2>. we merge 'op2' into 'op1'
159   std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const;
160   void CreateSourceEliminationSubCostList(StrategyPtr op1_old_stra, const CostPtrList &op1_old_clist,
161                                           StrategyPtr op2_old_stra, const CostPtrList &op2_old_clist,
162                                           CostPtrList *op1_new_clist) const;
163   // We merge 'op2' into op1. The returned value are '<Edges1, Edges2>'. 'Edges1' are newly updated edges for 'op1',
164   // 'Edges2' are newly updated edges for 'op2'.
165   std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources(
166     const OperatorInfoPtr op1, const OperatorInfoPtr op2) const;
167   // Calculate memory cost for training phase or inference phase.
168   Status CalculateMemoryCost();
169   // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
170   // the memory cost can be resused. This is used to calculate memory in the training phase.
171   Status CalculateOpsMemoryCost();
172   // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
173   // the memory cost can be reused. This is used to calculate memory in the training phase.
174   Status CalculateEdgesMemoryCost();
175   // Calculate memory cost of operators in the inference phase.
176   Status CalculateOpsMemoryCostForInference();
177   // Calculate memory cost of edges in the inference phase.
178   Status CalculateEdgesMemoryCostForInference();
179   Status ComputeOpsAndEdgesParameterInvolved();
180   // Compute for each operator whether the output is critical.
181   Status ComputeOpsAndEdgesOutputCritical();
182 
GetOperators()183   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
184   size_t GetNumEdges() const;
185   Status InitReshapeStrategy();
186   Status InitSelectedStrategy();
187   OperatorInfoPtr FindTmpIdentityByParameterName(const std::string &p_name) const;
188   // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
189   // once (instead of multiple times), this method is used to correct this.
190   Status CorrectOpsMemoryCost();
191   // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
192   // This method is to re-init those edge involved operators.
193   void CheckApproximateCostGraphEdges();
194   // Needed by rec_parser
add_inputs_tensor_name(const std::vector<std::string> & inputs_tensor_name)195   void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
196     inputs_tensor_name_list_.push_back(inputs_tensor_name);
197   }
get_inputs_tensor_name_list()198   const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> & inputs_tensor_name_list)199   void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) {
200     inputs_tensor_name_list_ = inputs_tensor_name_list;
201   }
202   // Needed by rec_parser 2
add_param_users_uniqueid(const std::vector<std::string> & param_users_uniqueid)203   void add_param_users_uniqueid(const std::vector<std::string> &param_users_uniqueid) {
204     param_users_uniqueid_list_.push_back(param_users_uniqueid);
205   }
get_param_users_uniqueid_list()206   const std::vector<std::vector<std::string>> get_param_users_uniqueid_list() const {
207     return param_users_uniqueid_list_;
208   }
add_tuple_getitem(const std::pair<std::string,std::string> & tuple_getitem)209   void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
210     auto ret = tuple_getitem_list_.insert(tuple_getitem);
211     if (ret.second == false) {
212       MS_LOG(EXCEPTION) << "The insert item is already exist.";
213     }
214   }
get_tuple_getitem_list()215   const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
216 
217  private:
218   void TopologyOrder(std::vector<OperatorInfoPtr> *topo_order);
219   void DFSForTopoOrder(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
220                        std::vector<OperatorInfoPtr> *topo_order);
221   Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &topo_order);
222   void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &candidate_ops);
223   // Needed by rec_parser
224   std::vector<std::vector<std::string>> inputs_tensor_name_list_;
225   // Needed by rec_parser 2
226   std::vector<std::vector<std::string>> param_users_uniqueid_list_;
227   std::map<std::string, std::string> tuple_getitem_list_;
228   std::vector<OperatorInfoPtr> ops_;
229   std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
230   std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
231   std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
232   std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
233 };
234 }  // namespace parallel
235 }  // namespace mindspore
236 
237 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
238