• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "frontend/parallel/allreduce_fusion/allreduce_graph.h"
#include <algorithm>
#include <functional>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "frontend/parallel/allreduce_fusion/allreduce_node.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "utils/log_adapter.h"
24 
25 namespace mindspore {
26 namespace parallel {
AddNode(const CNodePtr & node,const AnfNodePtr & para)27 Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr &para) {
28   AllreduceNodePtr arnode;
29   auto cnode_emplace_return = cnode_set_.emplace(node);
30   if (!cnode_emplace_return.second) {
31     MS_LOG(INFO) << "node: " << node->DebugString() << " has already been added!";
32     auto cnode_arnode_pair = cnode_arnode_map_.find(node);
33     if (cnode_arnode_pair == cnode_arnode_map_.end()) {
34       MS_LOG(EXCEPTION) << "node is not in cnode_arnode_map_!";
35     }
36     arnode = cnode_arnode_pair->second;
37   } else {
38     arnode = std::make_shared<AllreduceNode>(AllreduceNode());
39   }
40 
41   if (arnode->Init(node) != SUCCESS) {
42     MS_LOG(ERROR) << "AllreduceNode Init failed";
43     return FAILED;
44   }
45   if (arnode->AddPara(para) != SUCCESS) {
46     MS_LOG(ERROR) << "AllreduceNode AddPara failed";
47     return FAILED;
48   }
49   cnode_arnode_map_[node] = arnode;
50 
51   auto arnode_emplace_return = arnode_set_.insert(arnode);
52   if (!arnode_emplace_return.second) {
53     MS_LOG(INFO) << "node: " << node->DebugString() << "'s arnode has already been added!";
54   }
55   cnode_emplace_return = para_cnodeset_map_[para].emplace(node);
56   if (!cnode_emplace_return.second) {
57     MS_LOG(INFO) << "node: " << node->DebugString() << " already in para: " << para->fullname_with_scope()
58                  << "'s cnodeset!";
59   }
60   auto para_emplace_return = cnode_paraset_map_[node].emplace(para);
61   if (!para_emplace_return.second) {
62     MS_LOG(INFO) << "para: " << para->fullname_with_scope() << " already in node: " << node->DebugString()
63                  << "'s paraset!";
64   }
65   return SUCCESS;
66 }
67 
AddEdge(const CNodePtr & from,const CNodePtr & to,double dist)68 Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) {
69   auto from_arnode_iter = cnode_arnode_map_.find(from);
70   if (from_arnode_iter == cnode_arnode_map_.end()) {
71     MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added";
72     PrintCNodeSet();
73     return FAILED;
74   }
75   auto to_arnode_iter = cnode_arnode_map_.find(to);
76   if (to_arnode_iter == cnode_arnode_map_.end()) {
77     MS_LOG(ERROR) << "cnode to: " << to->DebugString() << "has not been added";
78     PrintCNodeSet();
79     return FAILED;
80   }
81   auto from_arnode = from_arnode_iter->second;
82   auto to_arnode = to_arnode_iter->second;
83   if (from_arnode->AddNext(to_arnode) != SUCCESS) {
84     MS_LOG(ERROR) << "from_arnode AddNext failed";
85     return FAILED;
86   }
87   if (to_arnode->AddPrev(from_arnode, dist, &max_) != SUCCESS) {
88     MS_LOG(ERROR) << "to_arnode AddPrev failed";
89     return FAILED;
90   }
91   max_ = std::max(max_, to_arnode->depend_feat_size());
92   MS_LOG(DEBUG) << "from " << from->DebugString() << ", to " << to->DebugString();
93   MS_LOG(DEBUG) << "from depend_feat_size: " << from_arnode->depend_feat_size()
94                 << ", to depend_feat_size: " << to_arnode->depend_feat_size();
95   return SUCCESS;
96 }
97 
NodeInGraph(const CNodePtr & node) const98 bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const {
99   auto cnode_iter = cnode_set_.find(node);
100   return !(cnode_iter == cnode_set_.end());
101 }
102 
GetParaByCost(double from,double to)103 std::vector<AnfNodePtr> AllreduceGraph::GetParaByCost(double from, double to) {
104   std::vector<AnfNodePtr> nodes;
105   for (auto &cnode_arnode : cnode_arnode_map_) {
106     MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString()
107                   << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size()
108                   << " curr_para_size: " << cnode_arnode.second->curr_para_size();
109     if ((cnode_arnode.second->depend_feat_size() <= to) && (cnode_arnode.second->depend_feat_size() > from)) {
110       (void)nodes.insert(nodes.end(), cnode_paraset_map_[cnode_arnode.first].begin(),
111                          cnode_paraset_map_[cnode_arnode.first].end());
112     }
113   }
114   return nodes;
115 }
116 
GetParaByParaSize(double to,double para_size)117 std::pair<std::vector<AnfNodePtr>, double> AllreduceGraph::GetParaByParaSize(double to, double para_size) {
118   std::vector<AnfNodePtr> nodes;
119   double cur_para_size = 0;
120   double from = to;
121   for (auto &arnode : arnode_vec_) {
122     if ((arnode.depend_feat_size() - max_ <= EPS) && arnode.depend_feat_size() >= to) {
123       continue;
124     }
125     if (para_size > 0 && cur_para_size >= para_size && arnode.depend_feat_size() < from) {
126       return std::make_pair(nodes, from);
127     }
128     (void)nodes.insert(nodes.end(), arnode.paras().begin(), arnode.paras().end());
129     cur_para_size += arnode.curr_para_size();
130     from = arnode.depend_feat_size();
131   }
132   MS_LOG(INFO) << "GetParaByParaSize has reached head node! para_size: " << para_size
133                << " cur_para_size: " << cur_para_size << " from: " << from;
134   return std::make_pair(nodes, from);
135 }
136 
PrintCNodeSet() const137 void AllreduceGraph::PrintCNodeSet() const {
138   MS_LOG(INFO) << "CNodeSet:";
139   for (auto &cnode : cnode_set_) {
140     MS_LOG(INFO) << cnode->DebugString();
141   }
142 }
143 
PrintAllredueGraphInfo() const144 void AllreduceGraph::PrintAllredueGraphInfo() const {
145   MS_LOG(INFO) << "max: " << max_;
146   for (auto &cnode_arnode : cnode_arnode_map_) {
147     MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString();
148     MS_LOG(INFO) << "arnode info: ";
149     cnode_arnode.second->ToString();
150   }
151 }
152 
PrintArnodeVec() const153 void AllreduceGraph::PrintArnodeVec() const {
154   MS_LOG(INFO) << "ArnodeVec:";
155   for (auto &arnode : arnode_vec_) {
156     arnode.ToString();
157   }
158 }
159 
PrintArnodeSet() const160 void AllreduceGraph::PrintArnodeSet() const {
161   MS_LOG(INFO) << "ArnodeSet:";
162   for (auto &arnode : arnode_set_) {
163     arnode->ToString();
164   }
165 }
166 
SortArnode()167 void AllreduceGraph::SortArnode() {
168   arnode_vec_.clear();
169   for (auto &node : arnode_set_) {
170     arnode_vec_.emplace_back(*node);
171   }
172   std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>());
173 }
174 
RemoveExtraParas()175 Status AllreduceGraph::RemoveExtraParas() {
176   std::unordered_set<AnfNodePtr> para_map;
177   for (auto &node : arnode_vec_) {
178     for (auto &para : node.paras()) {
179       auto emplac_result = para_map.emplace(para);
180       if (!emplac_result.second) {
181         MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode";
182         if (node.RemovePara(para) != SUCCESS) {
183           MS_LOG(ERROR) << "remove para failed";
184           return FAILED;
185         }
186       }
187     }
188   }
189   return SUCCESS;
190 }
191 
set_head_cnode(const CNodePtr & node)192 Status AllreduceGraph::set_head_cnode(const CNodePtr &node) {
193   auto arnode = std::make_shared<AllreduceNode>(AllreduceNode());
194   if (arnode->Init(node) != SUCCESS) {
195     MS_LOG(ERROR) << "AllreduceNode Init failed";
196   }
197   head_cnode_ = node;
198   cnode_arnode_map_[node] = arnode;
199   auto arnode_emplace_return = arnode_set_.insert(arnode);
200   if (!arnode_emplace_return.second) {
201     MS_LOG(WARNING) << "node: " << node->DebugString() << "'s arnode has already been added!";
202   }
203   auto cnode_emplace_return = cnode_set_.emplace(node);
204   if (!cnode_emplace_return.second) {
205     MS_LOG(WARNING) << "node: " << node->DebugString() << " has already been added!";
206   }
207   return SUCCESS;
208 }
209 }  // namespace parallel
210 }  // namespace mindspore
211