1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ 19 20 #include <memory> 21 #include <set> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 #include "ir/anf.h" 27 #include "frontend/parallel/allreduce_fusion/allreduce_node.h" 28 #include "frontend/parallel/status.h" 29 30 namespace mindspore { 31 namespace parallel { 32 class AllreduceGraph { 33 public: AllreduceGraph()34 AllreduceGraph() 35 : head_cnode_(nullptr), 36 arnode_set_(), 37 arnode_vec_(), 38 cnode_set_(), 39 para_cnode_map_(), 40 para_cnodeset_map_(), 41 cnode_paraset_map_(), 42 cnode_arnode_map_(), 43 max_(0) {} 44 virtual ~AllreduceGraph() = default; 45 Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); 46 Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); 47 bool NodeInGraph(const CNodePtr &node) const; 48 std::vector<AnfNodePtr> GetParaByCost(double from, double to); 49 // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is 50 // over para_size. 51 // Return the parameter AnfNodePtr vector corresponding to these AllreduceNodes and the smallest depend_feat_size. 52 // If the sum of left AllreduceNode's parameter size is less than para_size, the returned depend_feat_size must be 0. 53 std::pair<std::vector<AnfNodePtr>, double> GetParaByParaSize(double to, double para_size); 54 // If one parameter is used by multiple AllreduceNode, parameter belong to the last node for backward computation 55 // is saved by the corresponding AllreduceNode, parameters belong to other AllreduceNode are removed. 56 // Called during precise optimization, not implemented temporarily. 57 void SortArnode(); 58 Status RemoveExtraParas(); 59 void PrintCNodeSet() const; 60 void PrintAllredueGraphInfo() const; 61 void PrintArnodeVec() const; 62 void PrintArnodeSet() const; cnode_set()63 const std::unordered_set<CNodePtr> &cnode_set() const { return cnode_set_; } head_cnode()64 CNodePtr head_cnode() const { return head_cnode_; } 65 Status set_head_cnode(const CNodePtr &node); max()66 double max() const { return max_; } 67 68 private: 69 CNodePtr head_cnode_; 70 std::set<AllreduceNodePtr> arnode_set_; 71 std::vector<AllreduceNode> arnode_vec_; 72 std::unordered_set<CNodePtr> cnode_set_; 73 // If One ParameterPtr is used by multiple CNode, the last node for backward computation is saved. 74 std::unordered_map<AnfNodePtr, std::vector<CNodePtr>> para_cnode_map_; 75 // One ParameterPtr may be used by multiple CNode 76 std::unordered_map<AnfNodePtr, std::unordered_set<CNodePtr>> para_cnodeset_map_; 77 // Multiple Parameter may be inputs to the same CNode 78 std::unordered_map<CNodePtr, std::unordered_set<AnfNodePtr>> cnode_paraset_map_; 79 std::unordered_map<CNodePtr, AllreduceNodePtr> cnode_arnode_map_; 80 double max_; 81 }; 82 } // namespace parallel 83 } // namespace mindspore 84 85 #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_ALLREDUCE_FUSION_ALLREDUCE_GRAPH_H_ 86