1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ 19 20 #include <set> 21 #include <utility> 22 #include <string> 23 #include <memory> 24 #include <vector> 25 #include "ir/value.h" 26 #include "ir/graph_utils.h" 27 #include "base/base.h" 28 #include "utils/hash_map.h" 29 #include "frontend/parallel/step_parallel.h" 30 #include "frontend/parallel/graph_util/generate_graph.h" 31 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" 32 33 namespace mindspore { 34 namespace parallel { 35 const int32_t DEPEND_NODE_SOURCE_INDEX = 2; 36 37 class FoldPipelineTransformer : public PipelineTransformer { 38 public: FoldPipelineTransformer(const FuncGraphManagerPtr & manager,int stage,const FuncGraphPtr & root,int64_t global_rank,int64_t per_stage_rank_num)39 FoldPipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, 40 int64_t per_stage_rank_num) 41 : PipelineTransformer(manager, stage, root, global_rank, per_stage_rank_num) {} 42 ~FoldPipelineTransformer() = default; 43 void Coloring() override; 44 void BroadCastColoring() override; 45 void CutGraph() override; 46 47 SendAttr InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, 48 int64_t segment); 49 AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, 50 int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, 51 const AnfNodePtr &graph_param, int64_t segment); 52 53 void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops, 54 std::vector<int64_t> *send_ops_segment, std::vector<AnfNodePtr> *receive_ops); 55 AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, int64_t node_segment, 56 const std::vector<AnfNodePtr> &out_input, const std::vector<int64_t> &out_input_segment, 57 const std::string &tag); 58 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph) override; 59 AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, int64_t user_stage, 60 const ValuePtr µ, size_t pos, const std::vector<AnfNodePtr> &ops) override; 61 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> HandleSharedParameter() override; 62 63 private: 64 void CreateForwardGroup2(); 65 int64_t ComputeRecvTag(int64_t node_stage, int64_t user_node_stage, int64_t stage_num, int64_t src_rank); 66 void ColorForNodes(); 67 std::vector<std::string> group_ = {}; 68 }; 69 70 class NodeSegmentInfo { 71 public: NodeSegmentInfo(int64_t segment)72 explicit NodeSegmentInfo(int64_t segment) : segment_(segment) {} 73 ~NodeSegmentInfo() = default; 74 segment()75 int64_t segment() const { return segment_; } 76 77 // Key for user data. 78 constexpr static char key[] = "NodeSegmentInfo"; 79 80 private: 81 int64_t segment_; 82 }; 83 84 } // namespace parallel 85 } // namespace mindspore 86 87 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_FOLD_PIPELINE_TRANSFORMER_H_ 88