• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include "frontend/parallel/auto_parallel/edge_costmodel.h"
26 #include "frontend/parallel/costmodel_context.h"
27 #include "frontend/parallel/ops_info/operator_info.h"
28 #include "frontend/parallel/ops_info/tmp_identity_info.h"
29 #include "utils/ms_utils.h"
30 
31 namespace mindspore {
32 namespace parallel {
// Forward declaration so the shared-pointer alias below can be introduced before the class body.
class CostGraph;
// Shared-ownership handle to a CostGraph.
using CostGraphPtr = std::shared_ptr<CostGraph>;
// Declared here, defined in another translation unit (the definition is not visible in this header).
// NOTE(review): presumably the cost graph built for the whole network -- confirm against the definition.
extern CostGraphPtr entire_costgraph;
// Declared here, defined elsewhere. NOTE(review): presumably a running count of operators -- confirm.
extern size_t TOTAL_OPS;
37 
38 class CostGraph {
39   // 'CostGraph' consists of Operators and edges between them. An edge is created between two Operators if they have
40   // output-input dependency relationship.
41  public:
CostGraph()42   CostGraph() {}
43   ~CostGraph() = default;
44   void Init();
AddOperator(const OperatorInfoPtr & op)45   void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); }
FindOperatorByIndex(size_t index)46   OperatorInfoPtr FindOperatorByIndex(size_t index) {
47     if (index >= ops_.size()) {
48       MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << ".";
49       return nullptr;
50     }
51     return ops_[index];
52   }
53   void RemoveOperator(const OperatorInfoPtr &op);
54   bool IsOperatorInCostGraph(const OperatorInfoPtr &op);
55   void StrategyPropagate(const std::map<OperatorInfoPtr, StrategyPtr> &);
56   void BFS(const OperatorInfoPtr &, const StrategyPtr &, std::map<OperatorInfoPtr, StrategyPtr>,
57            std::map<OperatorInfoPtr, bool> *);
58   // the edge is in the form: u --> v
59   void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge);
GetOriginalPrevEdges(OperatorInfoPtr v_node)60   std::vector<std::shared_ptr<Edge>> GetOriginalPrevEdges(OperatorInfoPtr v_node) { return in_edges_[v_node]; }
GetOriginalNextEdges(OperatorInfoPtr u_node)61   std::vector<std::shared_ptr<Edge>> GetOriginalNextEdges(OperatorInfoPtr u_node) { return out_edges_[u_node]; }
62   // An edge is uniquely identified by its name, and its output index and input index.
63   bool IsEdgeInCostGraph(const std::string &, size_t, size_t);
64 
65   std::vector<std::shared_ptr<CostGraph>> ConstructConnectedComponents(std::vector<OperatorInfoPtr>);
66   void DFS(const OperatorInfoPtr &current_op, std::map<OperatorInfoPtr, bool> *visited,
67            const std::shared_ptr<CostGraph> &component);
68 
69   CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v);
70   CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u);
71   CostPtr SelectCostWithMinInferenceTime(const CostPtrList &cost_list, double memory);
72   CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory);
73   CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector<CostPtrList> &all_costlist,
74                                                         double memory) const;
75   Status SearchStrategyForMultiNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
76   Status SearchStrategyForTwoNodeFinalGraph(const std::vector<OperatorInfoPtr> &);
GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node,OperatorInfoPtr v_node)77   std::vector<std::shared_ptr<Edge>> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) {
78     return edges_[{u_node, v_node}];
79   }
80 
81   // Search the cost_list in the final graph, and determine the optimal one
82   Status SearchStrategy();
83 
84   // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated
85   OperatorInfoPtr CheckOpElimination() const;
86   // Given a graph which contains the following subgraph where there are multiple edges between u and v, these edges
87   // can be eliminated into one
88   std::vector<EdgePtr> CheckEdgeElimination() const;
89   // Given a graph which contains the following subgraph:
90   //        u
91   //        |
92   //  w --- v --- x
93   // where u has 0 incoming edge, u has 1 outgoing edge, and v has > 1 incoming edges, u can be merged into v.
94   // u is returned.
95   OperatorInfoPtr CheckMergeElimination() const;
96   // Given a graph which contains the following subgraph:
97   //        u
98   //        |
99   //        v --- x
100   // where v has 2 outgoing edges, and u has 1 incoming edges and no outgoing edges. In this case, u can be contracted
101   // into v. u is returned.
102   OperatorInfoPtr CheckContractElimination() const;
103   /* Given a graph which contains the following subgraph:
104    *       u
105    *      / \
106    *     /   \
107    *    v --- w
108    * where u has 2 outgoing edges, v has 1 outgoing edge, and w has 2 incoming edges, u can be eliminated into v.
109    * The returned value includes u and the edge <u, <v, w>>.
110    */
111   std::pair<OperatorInfoPtr, EdgePtr> CheckTriangleElimination() const;
112   /* Given a graph which contains the following subgraph:
113    *  v <--- u ---> w
114    * where u has 0 incoming edges, and multiple outgoing edges. In addition, v and w have other complicated connections,
115    * resulting in v and w can not be performed ContractElimination. u is returned.
116    * NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
117    */
118   OperatorInfoPtr CheckStarElimination() const;
119   // Applying Operator Elimination in DP algorithm
120   EdgePtr EliminationOp(const OperatorInfoPtr &op);
121   // Applying Edge Elimination in DP algorithm
122   EdgePtr EliminationEdges(const std::vector<EdgePtr> &edges);
123   // Applying Merge Elimination in DP algorithm
124   OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op);
125   void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list,
126                                          const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy,
127                                          const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new);
128   // Applying Contract Elimination in DP algorithm
129   OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op);
130   void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr,
131                                             const CostPtrList &, CostPtrList *);
132 
133   // Applying Triangle Elimination in DP algorithm. return the left_node
134   OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right);
135   void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &,
136                                          const StrategyPtr &, const StrategyPtr &, const StrategyPtr &,
137                                          const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *);
138   // Given the relevant costlist, create the TriangleElimination cost
139   void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &,
140                                             const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *);
141 
142   // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op
143   // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied.
144   std::vector<EdgePtr> EliminationStar(const OperatorInfoPtr &op);
145   void CreateStarEliminationCostList(std::vector<EdgePtr> &, const StrategyPtr &, const CostPtrList &,
146                                      const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *);
147   void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &,
148                                         const StrategyPtr &, const CostPtrList &, std::vector<StrategyPtr>,
149                                         CostPtrList &, CostPtrList &, CostPtrList *);
150   // Return <op1, op2>. we merge 'op2' into 'op1'
151   std::pair<OperatorInfoPtr, OperatorInfoPtr> CheckSourceElimination() const;
152   void CreateSourceEliminationSubCostList(StrategyPtr, const CostPtrList &, StrategyPtr, const CostPtrList &,
153                                           CostPtrList *);
154   // We merge 'op2' into op1. The returned value are '<Edges1, Edges2>'. 'Edges1' are newly updated edges for 'op1',
155   // 'Edges2' are newly updated edges for 'op2'.
156   std::pair<std::vector<std::shared_ptr<Edge>>, std::vector<std::shared_ptr<Edge>>> EliminationSources(
157     const OperatorInfoPtr op1, const OperatorInfoPtr op2);
158   // Calculate memory cost for training phase or inference phase.
159   Status CalculateMemoryCost();
160   // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
161   // the memory cost can be resused. This is used to calculate memory in the training phase.
162   Status CalculateOpsMemoryCost();
163   // When the input of the edge is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then
164   // the memory cost can be reused. This is used to calculate memory in the training phase.
165   Status CalculateEdgesMemoryCost();
166   // Calculate memory cost of operators in the inference phase.
167   Status CalculateOpsMemoryCostForInference();
168   // Calculate memory cost of edges in the inference phase.
169   Status CalculateEdgesMemoryCostForInference();
170   Status ComputeOpsAndEdgesParameterInvolved();
171   // Compute for each operator whether the output is critical.
172   Status ComputeOpsAndEdgesOutputCritical();
173 
GetOperators()174   std::vector<OperatorInfoPtr> GetOperators() const { return ops_; }
175   size_t GetNumEdges() const;
176   Status InitReshapeStrategy();
177   Status InitSelectedStrategy();
178   OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const;
179   // When TmpIdentity is used by multiple operators, the corresponding parameter's memory cost should be calculated only
180   // once (instead of multiple times), this method is used to correct this.
181   Status CorrectOpsMemoryCost();
182   // When APPROXIMATION is enabled in the DP algorithm, some edges may have no valid strategies.
183   // This method is to re-init those edge involved operators.
184   void CheckApproximateCostGraphEdges();
185   // Needed by rec_parser
add_inputs_tensor_name(const std::vector<std::string> & inputs_tensor_name)186   void add_inputs_tensor_name(const std::vector<std::string> &inputs_tensor_name) {
187     inputs_tensor_name_list_.push_back(inputs_tensor_name);
188   }
get_inputs_tensor_name_list()189   const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> & inputs_tensor_name_list)190   void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) {
191     inputs_tensor_name_list_ = inputs_tensor_name_list;
192   }
add_tuple_getitem(const std::pair<std::string,std::string> & tuple_getitem)193   void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
194     auto ret = tuple_getitem_list_.insert(tuple_getitem);
195     if (ret.second == false) {
196       MS_LOG(EXCEPTION) << "The insert item is already exist.";
197     }
198   }
get_tuple_getitem_list()199   const std::map<std::string, std::string> get_tuple_getitem_list() const { return tuple_getitem_list_; }
200 
201  private:
202   void TopologyOrder(std::vector<OperatorInfoPtr> *);
203   void DFSForTopoOrder(const OperatorInfoPtr &, std::map<OperatorInfoPtr, bool> *, std::vector<OperatorInfoPtr> *);
204   Status DetermineCriticalOps(const std::vector<OperatorInfoPtr> &);
205   void MarkCriticalOpsAndEdges(const std::map<OperatorInfoPtr, int64_t> &);
206   // Needed by rec_parser
207   std::vector<std::vector<std::string>> inputs_tensor_name_list_;
208   std::map<std::string, std::string> tuple_getitem_list_;
209   std::vector<OperatorInfoPtr> ops_;
210   std::map<std::pair<OperatorInfoPtr, OperatorInfoPtr>, std::vector<EdgePtr>> edges_;
211   std::vector<std::shared_ptr<CostGraph>> connected_compoents_;
212   std::map<OperatorInfoPtr, std::vector<EdgePtr>> out_edges_;
213   std::map<OperatorInfoPtr, std::vector<EdgePtr>> in_edges_;
214 };
215 }  // namespace parallel
216 }  // namespace mindspore
217 
218 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_AUTO_PARALLEL_GRAPH_COSTMODEL_H_
219