1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <string> 23 #include "ir/anf.h" 24 #include "ir/manager.h" 25 26 namespace mindspore { 27 namespace parallel { 28 using PipelinePair = std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>>; 29 using PipelinePairVector = std::vector<std::vector<mindspore::parallel::PipelinePair>>; 30 AnfNodePtr FindAccuGrad(const CNodePtr &cnode); 31 bool IsFirstStage(); 32 bool IsLastStage(); 33 int64_t InferStage(); 34 void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager, 35 const AnfNodePtr &accu_parameter, const NodeUsersMap &node_user_map); 36 void InsertVirtualAccuGrad(const AnfNodePtr &recv, const FuncGraphManagerPtr &manager, const AnfNodePtr ¶m); 37 AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> ¶meters, const std::string &name); 38 void HandleReceiveParam(const FuncGraphPtr &root); 39 void AddVirtualAssignAdd(const FuncGraphPtr &root); 40 void SetParameterStartForCellShare(const FuncGraphPtr &root); 41 bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2); 42 void ReorderForForward(const std::vector<AnfNodePtr> &forward_start, const std::vector<AnfNodePtr> &forward_end, 43 const FuncGraphPtr &root); 44 void ReorderForBackward(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, 45 const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, 46 const PipelinePair &forward_end_before_pair, const FuncGraphPtr &root); 47 void ReorderForParams(const PipelinePair &backward_params_pair, const PipelinePair &forward_params_pair, 48 const PipelinePair &backward_end_pair, const PipelinePair &forward_start_pair, 49 const FuncGraphPtr &root); 50 int64_t GetMicroBatch(const AnfNodePtr &node); 51 void InsertDepend(const AnfNodePtr &prior_node, const AnfNodePtr &post_node, const FuncGraphManagerPtr &manager, 52 const FuncGraphPtr &root, const std::string &attr_tag = ""); 53 AnfNodePtr GetActualOp(const AnfNodePtr &node); 54 void GetBorderNode(std::vector<AnfNodePtr> *forward_start, std::vector<AnfNodePtr> *forward_end, 55 std::vector<AnfNodePtr> *backward_start, std::vector<AnfNodePtr> *backward_end, 56 std::vector<AnfNodePtr> *forward_params, std::vector<AnfNodePtr> *backward_params, 57 std::vector<AnfNodePtr> *allreduce_params, const FuncGraphPtr &root); 58 void Reorder(const FuncGraphPtr &root); 59 void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); 60 void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager); 61 void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth); 62 void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root); 63 void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root); 64 void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager, 65 const FuncGraphPtr &root); 66 void SetStridedSliceStrategy(const AnfNodePtr &node); 67 void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager); 68 bool IsValidNode(const AnfNodePtr &node, const AnfNodePtr &return_node, const NodeUsersMap &node_user_map); 69 ValuePtr Micro(const CNodePtr &cnode, NodeUsersMap *node_users_map, size_t max_depth); 70 void CheckBorderNode(const PipelinePair &forward_start_pair, const PipelinePair &forward_end_pair, 71 const PipelinePair &backward_start_pair, const PipelinePair &backward_end_pair, 72 std::vector<int64_t> seg_micro_max); 73 void CommonDeduplicate(const std::vector<AnfNodePtr> &node_vector, std::vector<AnfNodePtr> *out_vec_begin, 74 std::vector<AnfNodePtr> *out_vec_end, const FuncGraphPtr &root, int64_t micro_max, 75 int64_t seg_max, int64_t h, bool is_train); 76 PipelinePair GetForwardEndBeforePair(const PipelinePair &forward_end_pair); 77 int64_t GetMicroMax(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &forward_end); 78 int64_t GetSegment(const AnfNodePtr &node); 79 std::string GetWorldGroup(); 80 int64_t GetRank(); 81 } // namespace parallel 82 } // namespace mindspore 83 84 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_PIPELINE_SPLIT_UTILS_H_ 85