1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ 19 20 #include <set> 21 #include <utility> 22 #include <string> 23 #include <memory> 24 #include <vector> 25 #include "ir/value.h" 26 #include "ir/graph_utils.h" 27 #include "base/base.h" 28 #include "utils/hash_map.h" 29 #include "frontend/parallel/step_parallel.h" 30 #include "frontend/parallel/graph_util/generate_graph.h" 31 #include "ops/array_ops.h" 32 33 namespace mindspore { 34 namespace parallel { 35 using TensorInfoPtr = std::shared_ptr<TensorInfo>; 36 37 typedef struct { 38 ValueListPtr shape; 39 TypePtr type; 40 AnfNodePtr depend; 41 } SendAttr; 42 43 class PipelineTransformer { 44 public: PipelineTransformer(const FuncGraphManagerPtr & manager,int stage,const FuncGraphPtr & root,int64_t global_rank,int64_t per_stage_rank_num)45 PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, 46 int64_t per_stage_rank_num) 47 : manager_(manager), 48 stage_(stage), 49 root_(root), 50 main_graph_(nullptr), 51 virtual_dataset_(nullptr), 52 global_rank_(global_rank), 53 per_stage_rank_num_(per_stage_rank_num) {} 54 virtual ~PipelineTransformer() = default; 55 virtual void Coloring(); 56 virtual void BroadCastColoring(); 57 virtual void CutGraph(); 58 void LabelGenMaskFusion(); 59 bool MainGraph(); 60 void LabelMicroBatch(); 61 void ParameterColoring(); 62 void ElimGraphStage(); 63 void ModifyParameterList(); 64 65 AnfNodePtr GetArgumentsByParameter(const AnfNodePtr ¶meter); 66 void RemoveMonadNode(); 67 bool HasNoUpdateParameter(); 68 AnfNodePtr CreateTupleZeroTensor(const AnfNodePtr &node, size_t index); 69 std::vector<AnfNodePtr> GetLoadNodeByParam(const AnfNodePtr ¶m) const; 70 AnfNodePtr ActualOp(const AnfNodePtr &node); 71 bool IsParameterGraph(const AnfNodePtr &node) const; 72 virtual AnfNodePtr HandleParameterGraph(const AnfNodePtr &node, const AnfNodePtr &use_node, int64_t stage, 73 int64_t user_stage, const ValuePtr µ, size_t pos, 74 const std::vector<AnfNodePtr> &ops); 75 AnfNodeIndexSet GetParameterLoadUsers(const AnfNodePtr &node, const NodeUsersMap &node_users_map) const; 76 ValuePtr SetMicroBatch(const AnfNodePtr &node, int64_t micro_size, size_t batch_axis) const; 77 virtual std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> HandleSharedParameter(); 78 SendAttr InsertSend(const AnfNodePtr ¶meter, int64_t user_node_stage, int64_t node_stage, const ValuePtr &value); 79 AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, 80 int64_t user_node_stage, int64_t node_stage, const ValuePtr &value, 81 const AnfNodePtr &graph_param); 82 virtual std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> CutBorder(const FuncGraphPtr &graph); 83 void CutBorderForNode(const FuncGraphPtr &graph, const AnfNodePtr &node, std::vector<AnfNodePtr> *send_ops, 84 std::vector<AnfNodePtr> *receive_ops); 85 AnfNodePtr Reuse(const AnfNodePtr &node, int64_t stage, const std::vector<AnfNodePtr> &out_input, 86 const std::string &tag) const; 87 AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node) const; 88 std::pair<OperatorInfoPtr, int> GetOpInfo(const AnfNodePtr &node); 89 TensorInfo GetTensorInfo(const std::pair<OperatorInfoPtr, int> &op_info_pair, bool is_param); 90 std::pair<OperatorInfoPtr, int> GetParameterPair(const AnfNodePtr &node); 91 OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode, int tuple_index); 92 bool LabelParameterStart(const FuncGraphPtr &graph); 93 bool NeedGrad(const CNodePtr &cnode); 94 CNodePtr GraphOutNode(const AnfNodePtr &node, int tuple_index); 95 bool IsPipelineCareNode(const CNodePtr &cnode) const; 96 void RedundancyNode(const AnfNodePtr &node, mindspore::HashMap<CNodePtr, std::vector<AnfNodePtr>> *make_tuple_map); 97 bool IsRedundancyParameter(const AnfNodePtr ¶meter, const std::vector<AnfNodePtr> &non_cloned_parameters); 98 void ElimParameter(); 99 void FreezeGradient(); 100 AnfNodePtr CreateZeroseOutput(const AnfNodePtr &node, size_t index); 101 AnfNodePtr GetZeroOutputs(const FuncGraphPtr &graph); 102 103 std::pair<OperatorInfoPtr, int> GetOpInfoPair(const AnfNodePtr &node, const AnfNodePtr &graph_param, 104 AnfNodePtr *care_node, bool *is_param); 105 void SetNodeAbstract(const std::vector<AnfNodePtr> &nodes); 106 std::vector<AnfNodePtr> FetchSend(const AnfNodePtr &node, bool pipeline_param, bool single_pipeline_end, 107 size_t end_index); 108 AnfNodePtr GenNewSendFromOld(const AnfNodePtr &node, const AnfNodePtr &send_input, const ValuePtr &value); 109 void HandleGraphOutputs(const std::vector<AnfNodePtr> &nodes); 110 111 std::vector<AnfNodePtr> FetchRecv(const AnfNodePtr &node, bool pipeline_param); 112 AnfNodePtr GenNewRecvFromOld(const AnfNodePtr &node, const AnfNodePtr &input, const ValuePtr &value); 113 void ResetSharedCellParamAndArgu(const std::vector<std::vector<AnfNodePtr>> &pipeline_begins_fetched, 114 const std::vector<AnfNodePtr> &newly_added_params, 115 const std::vector<AnfNodePtr> &reserved_inputs); 116 // set shared_cell_ parameters, and call_input 117 void HandleGraphInputs(const std::vector<AnfNodePtr> &recv_ops); 118 bool GetStageByArgument(const CNodePtr &node, size_t index, const std::vector<AnfNodePtr> ¶meters, 119 const NodeUsersMap &node_users_map, std::set<int64_t> *const parameter_stage); 120 size_t GetBatchAxisForInput(const AnfNodeIndexSet &input_node_users) const; 121 void UpdateParameterSharedInfo(const AnfNodePtr &node, const AnfNodePtr &communcate_op, bool is_send); 122 void FillParameterStage(const CNodePtr &node, std::set<int64_t> *const parameter_stage); 123 FuncGraphManagerPtr manager_; 124 int64_t stage_ = 0; 125 FuncGraphPtr root_; 126 FuncGraphPtr main_graph_; 127 FuncGraphPtr shared_cell_; 128 AnfNodePtr virtual_dataset_; 129 int64_t global_rank_ = 0; 130 int64_t per_stage_rank_num_ = 1; 131 TypePtr type_ptr_; 132 ValueListPtr shape_; 133 AnfNodePtr virtual_param_; 134 int64_t micro_size_ = 0; 135 mindspore::HashMap<AnfNodePtr, std::set<int64_t>> parameter_color_map_ = {}; 136 bool is_train_{true}; 137 std::vector<AnfNodePtr> shared_cell_users_; 138 bool enable_share_cell_ = false; 139 std::string world_group_; 140 }; 141 142 bool IsInWhiteList(const CNodePtr &cnode); 143 std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape, size_t index); 144 145 class NodeStageInfo { 146 public: stage_(stage)147 explicit NodeStageInfo(int64_t stage, int64_t chunk = 0) : stage_(stage), chunk_(chunk) {} 148 ~NodeStageInfo() = default; 149 stage()150 int64_t stage() const { return stage_; } chunk()151 int64_t chunk() const { return chunk_; } set_chunk(int64_t chunk)152 void set_chunk(int64_t chunk) { chunk_ = chunk; } 153 154 // Key for user data. 155 constexpr static char key[] = "NodeStageInfo"; 156 157 private: 158 int64_t stage_; 159 int64_t chunk_; 160 }; 161 size_t MicroSize(const AnfNodeIndexSet &input_node_users); 162 } // namespace parallel 163 } // namespace mindspore 164 165 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ 166