• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
19 
20 #include <set>
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <vector>
25 #include "ir/value.h"
26 #include "ir/graph_utils.h"
27 #include "base/base.h"
28 #include "utils/hash_map.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/graph_util/generate_graph.h"
31 #include "ops/array_ops.h"
32 
33 namespace mindspore {
34 namespace parallel {
35 using TensorInfoPtr = std::shared_ptr<TensorInfo>;
36 
37 typedef struct {
38   ValueListPtr shape;
39   TypePtr type;
40   AnfNodePtr depend;
41 } SendAttr;
42 
43 class PipelineTransformer {
44  public:
PipelineTransformer(const FuncGraphManagerPtr & manager,int stage,const FuncGraphPtr & root,int64_t global_rank,int64_t per_stage_rank_num)45   PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank,
46                       int64_t per_stage_rank_num)
47       : manager_(manager),
48         stage_(stage),
49         root_(root),
50         main_graph_(nullptr),
51         virtual_dataset_(nullptr),
52         global_rank_(global_rank),
53         per_stage_rank_num_(per_stage_rank_num) {}
54   virtual ~PipelineTransformer() = default;
55   virtual void Coloring();
56   virtual void BroadCastColoring();
57   virtual void CutGraph();
58   void LabelGenMaskFusion();
59   bool MainGraph();
60   void LabelMicroBatch();
61   void ParameterColoring();
62   void ElimGraphStage();
63   void ModifyParameterList();
64 
65   AnfNodePtr GetArgumentsByParameter(const AnfNodePtr &parameter);
66   void RemoveMonadNode();
67   bool HasNoUpdateParameter();
68   AnfNodePtr CreateTupleZeroTensor(const AnfNodePtr &node, size_t index);
69   std::vector<AnfNodePtr> GetLoadNodeByParam(const AnfNodePtr &param) const;
70   AnfNodePtr ActualOp(const AnfNodePtr &node);
71   bool IsParameterGraph(const AnfNodePtr &node) const;
72   virtual AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage,
73                                           int64_t user_stage, const ValuePtr &micro, size_t pos,
74                                           const std::vector<AnfNodePtr> &ops);
75   AnfNodeIndexSet GetParameterLoadUsers(const AnfNodePtr &node, const NodeUsersMap &node_users_map) const;
76   ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const;
77   virtual std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> HandleSharedParameter();
78   SendAttr InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
79   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
80                            int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
81                            const AnfNodePtr &graph_param);
82   virtual std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
83   void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops,
84                         std::vector<AnfNodePtr> *receive_ops);
85   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
86                    const std::string &tag) const;
87   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node) const;
88   std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
89   TensorInfo GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param);
90   std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
91   OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode, int tuple_index);
92   bool LabelParameterStart(const FuncGraphPtr &graph);
93   bool NeedGrad(const CNodePtr &cnode);
94   CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index);
95   bool IsPipelineCareNode(const CNodePtr &cnode) const;
96   void RedundancyNode(const AnfNodePtr &node, mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map);
97   bool IsRedundancyParameter(const AnfNodePtr &parameter, const std::vector<AnfNodePtr> &non_cloned_parameters);
98   void ElimParameter();
99   void FreezeGradient();
100   AnfNodePtr CreateZeroseOutput(const AnfNodePtr &node, size_t index);
101   AnfNodePtr GetZeroOutputs(const FuncGraphPtr &graph);
102 
103   std::pair<OperatorInfoPtr, int> GetOpInfoPair(const AnfNodePtr &node, const AnfNodePtr &graph_param,
104                                                 AnfNodePtr *care_node, bool *is_param);
105   void SetNodeAbstract(const std::vector<AnfNodePtr> &nodes);
106   std::vector<AnfNodePtr> FetchSend(const AnfNodePtr &node, bool pipeline_param, bool single_pipeline_end,
107                                     size_t end_index);
108   AnfNodePtr GenNewSendFromOld(const AnfNodePtr &node, const AnfNodePtr &send_input, const ValuePtr &value);
109   void HandleGraphOutputs(const std::vector<AnfNodePtr> &nodes);
110 
111   std::vector<AnfNodePtr> FetchRecv(const AnfNodePtr &node, bool pipeline_param);
112   AnfNodePtr GenNewRecvFromOld(const AnfNodePtr &node, const AnfNodePtr &input, const ValuePtr &value);
113   void ResetSharedCellParamAndArgu(const std::vector<std::vector<AnfNodePtr>> &pipeline_begins_fetched,
114                                    const std::vector<AnfNodePtr> &newly_added_params,
115                                    const std::vector<AnfNodePtr> &reserved_inputs);
116   // set shared_cell_ parameters, and call_input
117   void HandleGraphInputs(const std::vector<AnfNodePtr> &recv_ops);
118   bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> &parameters,
119                           const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage);
120   size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const;
121   void UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communcate_op, bool is_send);
122   void FillParameterStage(const CNodePtr &node, std::set<int64_t> *const parameter_stage);
123   FuncGraphManagerPtr manager_;
124   int64_t stage_ = 0;
125   FuncGraphPtr root_;
126   FuncGraphPtr main_graph_;
127   FuncGraphPtr shared_cell_;
128   AnfNodePtr virtual_dataset_;
129   int64_t global_rank_ = 0;
130   int64_t per_stage_rank_num_ = 1;
131   TypePtr type_ptr_;
132   ValueListPtr shape_;
133   AnfNodePtr virtual_param_;
134   int64_t micro_size_ = 0;
135   mindspore::HashMap<AnfNodePtr, std::set<int64_t>> parameter_color_map_ = {};
136   bool is_train_{true};
137   std::vector<AnfNodePtr> shared_cell_users_;
138   bool enable_share_cell_ = false;
139   std::string world_group_;
140 };
141 
142 bool IsInWhiteList(const CNodePtr &cnode);
143 std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape, size_t index);
144 
// Records the pipeline stage of a node together with a chunk index (0 unless given).
// Instances are attached to nodes as user data under `key`.
//
// Rule of Zero: the class owns no resources, so no special member functions are declared.
// A user-declared destructor, even a defaulted one, would suppress the implicitly generated
// move constructor and move assignment.
class NodeStageInfo {
 public:
  // stage: stage index returned by stage(); chunk: initial value returned by chunk().
  explicit NodeStageInfo(int64_t stage, int64_t chunk = 0) : stage_(stage), chunk_(chunk) {}

  int64_t stage() const { return stage_; }
  int64_t chunk() const { return chunk_; }
  void set_chunk(int64_t chunk) { chunk_ = chunk; }

  // Key for user data.
  static constexpr char key[] = "NodeStageInfo";

 private:
  int64_t stage_;
  int64_t chunk_;
};
161 size_t MicroSize(const AnfNodeIndexSet &input_node_users);
162 }  // namespace parallel
163 }  // namespace mindspore
164 
165 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
166