• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_
18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_
19 
20 #include <set>
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <vector>
25 #include "ir/value.h"
26 #include "ir/graph_utils.h"
27 #include "base/base.h"
28 #include "utils/hash_map.h"
29 #include "frontend/parallel/step_parallel.h"
30 #include "frontend/parallel/graph_util/generate_graph.h"
31 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
32 
33 namespace mindspore {
34 namespace parallel {
// Compile-time index constant shared by the fold-pipeline code.
// NOTE(review): judging by the name, this is presumably the input position of the
// "source" operand inside a Depend node - confirm against the use sites.
constexpr int32_t DEPEND_NODE_SOURCE_INDEX = 2;
36 
37 class FoldPipelineTransformer : public PipelineTransformer {
38  public:
FoldPipelineTransformer(const FuncGraphManagerPtr & manager,int stage,const FuncGraphPtr & root,int64_t global_rank,int64_t per_stage_rank_num)39   FoldPipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank,
40                           int64_t per_stage_rank_num)
41       : PipelineTransformer(manager, stage, root, global_rank, per_stage_rank_num) {}
42   ~FoldPipelineTransformer() = default;
43   void Coloring() override;
44   void BroadCastColoring() override;
45   void CutGraph() override;
46 
47   SendAttr InsertSend(const AnfNodePtr &parameter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
48                       int64_t segment);
49   AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
50                            int64_t user_node_stage, int64_t node_stage, const ValuePtr &value,
51                            const AnfNodePtr &graph_param, int64_t segment);
52 
53   void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops,
54                         std::vector<int64_t> *send_ops_segment, std::vector<AnfNodePtr> *receive_ops);
55   AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment,
56                    const std::vector<AnfNodePtr> &out_input, const std::vector<int64_t> &out_input_segment,
57                    const std::string &tag);
58   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph) override;
59   AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage,
60                                   const ValuePtr &micro, size_t pos, const std::vector<AnfNodePtr> &ops) override;
61   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> HandleSharedParameter() override;
62 
63  private:
64   void CreateForwardGroup2();
65   int64_t ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, int64_t src_rank);
66   void ColorForNodes();
67   std::vector<std::string> group_ = {};
68 };
69 
// Small immutable value holder for a segment index.
// The static `key` string is the identifier under which an instance is stored
// as user data.
class NodeSegmentInfo {
 public:
  // Stores `segment`; `explicit` prevents implicit conversion from an integer.
  explicit NodeSegmentInfo(int64_t segment) : segment_(segment) {}
  ~NodeSegmentInfo() = default;

  // Returns the segment index given at construction.
  int64_t segment() const { return segment_; }

  // Key for user data.
  constexpr static char key[] = "NodeSegmentInfo";

 private:
  int64_t segment_;
};
83 
84 }  // namespace parallel
85 }  // namespace mindspore
86 
87 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_
88