/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_

#include <set>
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
#include "utils/hash_map.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "ops/array_ops.h"

namespace mindspore {
namespace parallel {
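// Transforms the func graph for interleaved pipeline parallelism: colors nodes and
// parameters by pipeline stage, labels micro batches, and cuts stage borders by
// inserting Send/Receive pairs between adjacent stages.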
class PipelineInterleave {
 public:
  PipelineInterleave(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root)
      : manager_(manager), stage_(stage), root_(root), main_graph_(nullptr), virtual_dataset_(nullptr) {}
  virtual ~PipelineInterleave() = default;
  void Init();
  void Coloring();
  void BroadCastColoring();
  void CutBorder();
  void LabelGenMaskFusion();
  bool MainGraph();
  void LabelMicroBatch();
  void ParameterColoring();
  void ElimParameter();
  bool HasNoUpdateParameter();

 private:
  void CreateSendReceiveGroup();
  void RedundancyNode(const AnfNodePtr &node, mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map);
  bool IsRedundancyParameter(const AnfNodePtr &parameter, const std::vector<AnfNodePtr> &non_cloned_parameters);
  void InsertSendReceive(const AnfNodePtr &node, const AnfNodePtr &user_node, int64_t order);
  void RemoveMonadNode();
  std::vector<AnfNodePtr> GetLoadNodeByParam(const AnfNodePtr &param) const;
  ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const;
  void FreezeGradient();
  void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t *order);
  bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> &parameters,
                          const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage);
  size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const;
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  FuncGraphPtr root_;
  FuncGraphPtr main_graph_;
  FuncGraphPtr shared_cell_;
  AnfNodePtr virtual_dataset_;
  int64_t micro_size_ = 0;
  mindspore::HashMap<AnfNodePtr, std::set<int64_t>> parameter_color_map_ = {};
  std::string world_group_;
  std::vector<std::string> group_;
  bool is_train_{true};
  int64_t global_rank_ = 0;
  int64_t per_stage_rank_num_ = 0;
};

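// Post-processes the partitioned graph: adjusts Send/Receive attributes, splits the
// graph into per-chunk sub-graphs, and prunes graphs and parameters that do not
// belong to the current stage.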
class PipelinePostProcess {
 public:
  explicit PipelinePostProcess(const FuncGraphManagerPtr &manager, int64_t stage, int64_t stage_num, FuncGraphPtr root)
      : manager_(manager), stage_(stage), stage_num_(stage_num), root_(root) {}
  virtual ~PipelinePostProcess() = default;

  void Init(const std::vector<AnfNodePtr> &nodes);
  void ModifySendRecvAttr(const std::vector<AnfNodePtr> &all_nodes);
  void GraphPartition(const std::vector<AnfNodePtr> &all_nodes);
  void ElimGraphStage();
  void ModifyParameterList();
  void HandleSendParam();

 private:
  void LabelInterleaveIndex();
  std::vector<AnfNodePtr> PartitionChunkGraph(const FuncGraphPtr &fg, int64_t chunk);
  void GetSendsRecvs(const FuncGraphPtr &fg, int64_t chunk, std::vector<AnfNodePtr> *recvs,
                     std::vector<AnfNodePtr> *sends, std::vector<AnfNodePtr> *temp);
  void SetNodeAbstract(const std::vector<AnfNodePtr> &nodes);
  AnfNodePtr GetZeroOutputs(const FuncGraphPtr &graph);
  AnfNodePtr GenNewNodeFromOld(const AnfNodePtr &node, const AnfNodePtr &input, int64_t micro, int64_t index);
  std::vector<AnfNodePtr> GenerateMainGraphSend(const std::vector<AnfNodePtr> &nodes, const AnfNodePtr &node,
                                                const ValuePtr &micro, const ValuePtr &index);
  AnfNodePtr GenerateMainGraphRecv(const AnfNodePtr &fg_node, const AnfNodePtr &recv);
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  int64_t stage_num_;
  FuncGraphPtr root_;
  int64_t chunk_num_ = 1;
  FuncGraphPtr main_graph_;
  FuncGraphPtr shared_cell_;
  std::vector<AnfNodePtr> shared_cell_users_;
};

bool IsolatedNodeAttach(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_