/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_

#include <set>
#include <utility>
#include <string>
#include <memory>
#include <vector>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
#include "utils/hash_map.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "ops/array_ops.h"

namespace mindspore {
namespace parallel {
// Assigns graph nodes to pipeline stages (coloring), propagates the stage labels, labels
// micro-batches, and cuts the stage borders by inserting Send/Receive pairs for interleaved
// pipeline parallelism.
class PipelineInterleave {
 public:
  PipelineInterleave(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root)
      : manager_(manager), stage_(stage), root_(root), main_graph_(nullptr), virtual_dataset_(nullptr) {}
  virtual ~PipelineInterleave() = default;
  void Init();
  void Coloring();
  void BroadCastColoring();
  void CutBorder();
  void LabelGenMaskFusion();
  bool MainGraph();
  void LabelMicroBatch();
  void ParameterColoring();
  void ElimParameter();
  bool HasNoUpdateParameter();

 private:
  void CreateSendReceiveGroup();
  void RedundancyNode(const AnfNodePtr &node, mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map);
  bool IsRedundancyParameter(const AnfNodePtr &parameter, const std::vector<AnfNodePtr> &non_cloned_parameters);
  void InsertSendReceive(const AnfNodePtr &node, const AnfNodePtr &user_node, int64_t order);
  void RemoveMonadNode();
  std::vector<AnfNodePtr> GetLoadNodeByParam(const AnfNodePtr &param) const;
  ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const;
  void FreezeGradient();
  void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t *order);
  bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> &parameters,
                          const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage);
  size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const;
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  FuncGraphPtr root_;
  FuncGraphPtr main_graph_;
  FuncGraphPtr shared_cell_;
  AnfNodePtr virtual_dataset_;
  int64_t micro_size_ = 0;
  mindspore::HashMap<AnfNodePtr, std::set<int64_t>> parameter_color_map_ = {};
  std::string world_group_;
  std::vector<std::string> group_;
  bool is_train_{true};
  int64_t global_rank_ = 0;
  int64_t per_stage_rank_num_ = 0;
};

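// Post-processing after the pipeline graph is partitioned: labels interleave (chunk) indices,
// partitions the per-chunk graphs, rewrites Send/Receive attributes, and removes parameters and
// graph parts that do not belong to the current stage.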
class PipelinePostProcess {
 public:
  explicit PipelinePostProcess(const FuncGraphManagerPtr &manager, int64_t stage, int64_t stage_num, FuncGraphPtr root)
      : manager_(manager), stage_(stage), stage_num_(stage_num), root_(root) {}
  virtual ~PipelinePostProcess() = default;

  void Init(const std::vector<AnfNodePtr> &nodes);
  void ModifySendRecvAttr(const std::vector<AnfNodePtr> &all_nodes);
  void GraphPartition(const std::vector<AnfNodePtr> &all_nodes);
  void ElimGraphStage();
  void ModifyParameterList();
  void HandleSendParam();

 private:
  void LabelInterleaveIndex();
  std::vector<AnfNodePtr> PartitionChunkGraph(const FuncGraphPtr &fg, int64_t chunk);
  void GetSendsRecvs(const FuncGraphPtr &fg, int64_t chunk, std::vector<AnfNodePtr> *recvs,
                     std::vector<AnfNodePtr> *sends, std::vector<AnfNodePtr> *temp);
  void SetNodeAbstract(const std::vector<AnfNodePtr> &nodes);
  AnfNodePtr GetZeroOutputs(const FuncGraphPtr &graph);
  AnfNodePtr GenNewNodeFromOld(const AnfNodePtr &node, const AnfNodePtr &input, int64_t micro, int64_t index);
  std::vector<AnfNodePtr> GenerateMainGraphSend(const std::vector<AnfNodePtr> &nodes, const AnfNodePtr &node,
                                                const ValuePtr &micro, const ValuePtr &index);
  AnfNodePtr GenerateMainGraphRecv(const AnfNodePtr &fg_node, const AnfNodePtr &recv);
  FuncGraphManagerPtr manager_;
  int64_t stage_;
  int64_t stage_num_;
  FuncGraphPtr root_;
  int64_t chunk_num_ = 1;
  FuncGraphPtr main_graph_;
  FuncGraphPtr shared_cell_;
  std::vector<AnfNodePtr> shared_cell_users_;
};

bool IsolatedNodeAttach(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer);
}  // namespace parallel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_INTERLEAVE_H_