• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
19 
20 #include <utility>
21 #include <string>
22 #include <memory>
23 #include <vector>
24 #include "ir/value.h"
25 #include "ir/graph_utils.h"
26 #include "base/base.h"
27 #include "frontend/parallel/step_parallel.h"
28 #include "frontend/parallel/graph_util/generate_graph.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
33 using TensorInfoPtr = std::shared_ptr<TensorInfo>;
34 
35 typedef struct {
36   ValueListPtr shape;
37   TypePtr type;
38   AnfNodePtr depend;
39 } SendAttr;
40 
41 class PipelineTransformer {
42  public:
PipelineTransformer(const FuncGraphManagerPtr & manager,int stage,const FuncGraphPtr & root,int64_t global_rank,int64_t per_stage_rank_num)43   PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank,
44                       int64_t per_stage_rank_num)
45       : manager_(manager),
46         stage_(stage),
47         root_(root),
48         main_graph_(nullptr),
49         virtual_dataset_(nullptr),
50         global_rank_(global_rank),
51         per_stage_rank_num_(per_stage_rank_num) {}
52   virtual ~PipelineTransformer() = default;
53   void Coloring();
54   void MainGraph();
55   void LabelMicroBatch();
56   void BroadCastColoring();
57   void CutGraph();
58   void ParameterColoring();
59   void CoverSensShape();
60   void ElimGraphStage();
61   void ElimParameter();
62 
63  private:
64   void CreateForwardGroup();
65   AnfNodePtr ActualOp(const AnfNodePtr &node);
66   bool IsParameterGraph(const AnfNodePtr &node);
67   AnfNodeIndexSet GetActualOpUsers(const std::pair<AnfNodePtr, int> &node_pair, NodeUsersMap *node_users_map);
68   AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
69                                   const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> ops);
70   ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size);
71   std::vector<AnfNodePtr> HandleSharedParameter();
72   SendAttr InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value);
73   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
74                            int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
75                            const AnfNodePtr &graph_param);
76   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph);
77   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input,
78                    const std::string &tag);
79   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
80   std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node);
81   std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node);
82   OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode, int tuple_index);
83   bool LabelParameterStart(const FuncGraphPtr &graph, const CNodePtr &graph_cnode);
84   bool NeedGrad(const CNodePtr &cnode, const CNodePtr &graph_cnode);
85   CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index);
86   bool IsPipelineCareNode(const CNodePtr &cnode);
87   std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
88   FuncGraphManagerPtr manager_;
89   int64_t stage_;
90   FuncGraphPtr root_;
91   FuncGraphPtr main_graph_;
92   AnfNodePtr virtual_dataset_;
93   int64_t global_rank_;
94   int64_t per_stage_rank_num_;
95   TypePtr type_ptr_;
96   ValueListPtr shape_;
97   AnfNodePtr virtual_param_;
98   int64_t micro_size_ = 0;
99   std::vector<std::string> group_ = {};
100 };
101 }  // namespace parallel
102 }  // namespace mindspore
103 
104 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
105