1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/session/exec_order_builder.h"
17 #include <algorithm>
18 #include <string>
19 #include "ops/ascend_op_name.h"
20 #include "include/common/utils/anfalgo.h"
21 #include "utils/ms_context.h"
22
23 namespace mindspore::session {
// Initial capacity reserved for the link-info hash maps and the execution order vector.
const size_t kDefaultContainerSize{5000};
25
26 namespace {
GetNodeGroup(const AnfNodePtr & node)27 std::string GetNodeGroup(const AnfNodePtr &node) {
28 MS_EXCEPTION_IF_NULL(node);
29 auto cnode = node->cast<CNodePtr>();
30 if (common::AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
31 return common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
32 }
33 return "";
34 }
35
NeedOptimize(const AnfNodePtr & node,const std::string & optimized_comm_group)36 bool NeedOptimize(const AnfNodePtr &node, const std::string &optimized_comm_group) {
37 bool is_fused_comm = common::AnfAlgo::IsFusedCommunicationOp(node);
38 if (!is_fused_comm) {
39 return false;
40 }
41 auto node_group = GetNodeGroup(node);
42 if (node_group.find(kSyncBnGroup) == string::npos) {
43 if (optimized_comm_group.empty() || node_group == optimized_comm_group) {
44 return true;
45 }
46 }
47 return false;
48 }
49 } // namespace
50
~ExecOrderBuilder()51 ExecOrderBuilder::~ExecOrderBuilder() {}
52
Build(FuncGraph * graph,std::vector<CNodePtr> * execution_order,NodeUser * node_user)53 void ExecOrderBuilder::Build(FuncGraph *graph, std::vector<CNodePtr> *execution_order, NodeUser *node_user) {
54 MS_EXCEPTION_IF_NULL(graph);
55 MS_EXCEPTION_IF_NULL(execution_order);
56 MS_EXCEPTION_IF_NULL(node_user);
57 graph_ = graph;
58 is_pynative_kernel_graph_ = graph_->has_flag(kFlagIsPyNativeBpropKernelGraph);
59 execution_order_ = execution_order;
60 node_output_edges_ = node_user;
61 node_output_edges_->clear();
62 ClearLinkInfo();
63 BuildLinkInfo();
64 FindIndependentNodes();
65 Build();
66 }
67
ClearLinkInfo()68 void ExecOrderBuilder::ClearLinkInfo() {
69 if (node_input_num_.empty()) {
70 node_input_num_.reserve(kDefaultContainerSize);
71 node_output_num_.reserve(kDefaultContainerSize);
72 node_input_edges_.reserve(kDefaultContainerSize);
73 trivial_nodes_.reserve(kDefaultContainerSize);
74 } else {
75 node_input_num_.clear();
76 node_output_num_.clear();
77 node_input_edges_.clear();
78 trivial_nodes_.clear();
79 node_output_edges_->clear();
80 }
81 }
82
IsTrivialNode(const AnfNodePtr & node)83 bool ExecOrderBuilder::IsTrivialNode(const AnfNodePtr &node) {
84 MS_EXCEPTION_IF_NULL(node);
85 if (!node->isa<CNode>()) {
86 return true;
87 }
88
89 const auto iter = trivial_nodes_.find(node);
90 if (iter != trivial_nodes_.end()) {
91 return iter->second;
92 }
93
94 if (AnfUtils::IsRealKernel(node)) {
95 (void)trivial_nodes_.emplace(node, false);
96 return false;
97 }
98
99 auto cnode = node->cast<CNodePtr>();
100 MS_EXCEPTION_IF_NULL(cnode);
101 if (std::all_of(cnode->inputs().begin(), cnode->inputs().end(),
102 [this](const auto &input) { return IsTrivialNode(input); })) {
103 (void)trivial_nodes_.emplace(node, true);
104 return true;
105 } else {
106 (void)trivial_nodes_.emplace(node, false);
107 return false;
108 }
109 }
110
BuildLinkInfo()111 void ExecOrderBuilder::BuildLinkInfo() {
112 std::queue<AnfNodePtr> to_visit;
113 auto output = graph_->get_return();
114 if (!output->isa<CNode>()) {
115 return;
116 }
117 to_visit.emplace(output);
118 auto seen = NewSeenGeneration();
119 while (!to_visit.empty()) {
120 auto node = to_visit.front();
121 to_visit.pop();
122 MS_EXCEPTION_IF_NULL(node);
123 auto cnode = node->cast<CNodePtr>();
124 MS_EXCEPTION_IF_NULL(cnode);
125 for (auto &input : cnode->inputs()) {
126 MS_EXCEPTION_IF_NULL(input);
127 (void)(*node_output_edges_)[input].emplace_back(node);
128 if (IsTrivialNode(input)) {
129 GetTrivialInputNode(input, seen);
130 continue;
131 }
132 if (!is_pynative_kernel_graph_) {
133 (void)node_input_edges_[node].emplace_back(input);
134 }
135 node_input_num_[node] += 1;
136 node_output_num_[input] += 1;
137 if (input->seen_ == seen || !input->isa<CNode>() || AnfUtils::IsCustomActorNode(input)) {
138 continue;
139 }
140 to_visit.emplace(input);
141 input->seen_ = seen;
142 }
143 }
144 }
145
GetTrivialInputNode(const AnfNodePtr & node,SeenNum seen)146 void ExecOrderBuilder::GetTrivialInputNode(const AnfNodePtr &node, SeenNum seen) {
147 MS_EXCEPTION_IF_NULL(node);
148 if (!node->isa<CNode>()) {
149 return;
150 }
151 auto cnode = node->cast<CNodePtr>();
152 for (auto &in : cnode->inputs()) {
153 (void)(*node_output_edges_)[in].emplace_back(node);
154 if (in->seen_ != seen && IsTrivialNode(in)) {
155 GetTrivialInputNode(in, seen);
156 in->seen_ = seen;
157 }
158 }
159 }
160
CanVisitInput(bool visit_with_refcount,const AnfNodePtr & input,SeenNum seen)161 bool ExecOrderBuilder::CanVisitInput(bool visit_with_refcount, const AnfNodePtr &input, SeenNum seen) {
162 MS_EXCEPTION_IF_NULL(input);
163 if (visit_with_refcount) {
164 auto output_iter = node_output_num_.find(input);
165 if (output_iter != node_output_num_.end()) {
166 output_iter->second--;
167 if (output_iter->second != 0) {
168 return false;
169 }
170 }
171 } else {
172 if (input->seen_ == seen) {
173 return false;
174 }
175 input->seen_ = seen;
176 }
177 return true;
178 }
179
// Walks the graph backwards from the return node and pushes onto independent_nodes_ every CNode that has
// no non-trivial input, plus every custom actor node (whose inputs are not explored).
//
// Non-real-kernel inputs go to vnode_to_visit, which is always drained before to_visit, so they are
// expanded ahead of real kernels. Each node's inputs are scanned in reverse order.
//
// Whether an input may be expanded is decided by CanVisitInput:
//   - GPU target: once per seen generation;
//   - other targets: only after all of its consumers were visited (node_output_num_ refcount).
// In the refcount mode, whenever a real-kernel input becomes visitable and independent_nodes_ is not empty,
// the node at independent_nodes_.top() is popped and given an extra input edge from that kernel
// (node_output_edges_ / node_input_edges_ / node_input_num_). It is then scheduled by Build() once that
// kernel has executed, instead of being used as an independent starting point.
void ExecOrderBuilder::FindIndependentNodes() {
  std::queue<AnfNodePtr> to_visit;
  std::queue<AnfNodePtr> vnode_to_visit;
  vnode_to_visit.emplace(graph_->get_return());
  bool visit_with_refcount = true;
  auto ms_context = MsContext::GetInstance();
  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  // GPU falls back to seen-generation visiting instead of output refcounting.
  if (target == kGPUDevice) {
    visit_with_refcount = false;
  }
  auto seen = NewSeenGeneration();
  while (!to_visit.empty() || !vnode_to_visit.empty()) {
    // Prefer non-real-kernel nodes; fall back to real kernels only when none are queued.
    AnfNodePtr node;
    if (vnode_to_visit.empty()) {
      node = to_visit.front();
      to_visit.pop();
    } else {
      node = vnode_to_visit.front();
      vnode_to_visit.pop();
    }

    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<CNode>()) {
      continue;
    }

    // Custom actor nodes are treated as independent; their inputs are not traversed
    // (BuildLinkInfo does not explore them either).
    if (AnfUtils::IsCustomActorNode(node)) {
      independent_nodes_.push(node);
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    bool independent = true;
    auto &inputs = cnode->inputs();
    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
      auto &input = *iter;
      MS_EXCEPTION_IF_NULL(input);
      if (IsTrivialNode(input)) {
        continue;
      }
      // Any non-trivial input means this node depends on something that must run first.
      independent = false;

      if (!CanVisitInput(visit_with_refcount, input, seen)) {
        continue;
      }

      if (AnfUtils::IsRealKernel(input)) {
        to_visit.emplace(input);
        // Attach the current top independent node behind this kernel via an extra control edge.
        if (!independent_nodes_.empty() && visit_with_refcount) {
          auto inode = independent_nodes_.top();
          (void)(*node_output_edges_)[input].emplace_back(inode);
          if (!is_pynative_kernel_graph_) {
            (void)node_input_edges_[inode].emplace_back(input);
          }
          node_input_num_[inode] += 1;
          independent_nodes_.pop();
        }
      } else {
        vnode_to_visit.emplace(input);
      }
    }

    if (independent) {
      independent_nodes_.push(node);
    }
  }
}
247
EnqueueReadyNodes(const AnfNodePtr & node,std::deque<AnfNodePtr> * visit_queue,bool comm_first)248 void ExecOrderBuilder::EnqueueReadyNodes(const AnfNodePtr &node, std::deque<AnfNodePtr> *visit_queue, bool comm_first) {
249 MS_EXCEPTION_IF_NULL(visit_queue);
250 MS_EXCEPTION_IF_NULL(visit_queue);
251 MS_EXCEPTION_IF_NULL(node_output_edges_);
252 auto it = node_output_edges_->find(node);
253 if (it == node_output_edges_->end()) {
254 return;
255 }
256
257 std::vector<AnfNodePtr> active_nodes;
258 for (const auto &output_node : it->second) {
259 MS_EXCEPTION_IF_NULL(output_node);
260 auto input_num_iter = node_input_num_.find(output_node);
261 if (input_num_iter == node_input_num_.end() || input_num_iter->second == 0) {
262 continue;
263 }
264 input_num_iter->second--;
265 if (input_num_iter->second > 0) {
266 continue;
267 }
268
269 bool is_comm_node = common::AnfAlgo::IsCommunicationOp(output_node);
270 if (!AnfUtils::IsRealKernel(output_node) || it->second.size() == 1) {
271 visit_queue->push_front(output_node);
272 } else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
273 visit_queue->push_back(output_node);
274 } else {
275 (void)active_nodes.emplace_back(output_node);
276 }
277 }
278
279 (void)std::copy(active_nodes.begin(), active_nodes.end(), std::back_inserter(*visit_queue));
280 }
281
// Schedules the graph into execution_order_ using the link info collected by BuildLinkInfo and
// FindIndependentNodes.
//
// Each outer iteration seeds the work queues from exactly one source, in this priority:
//   1. the oldest delayed communication node: its consumers are released into high_priority_to_visit;
//   2. the pending optimized communication node: its consumers are released into high_priority_to_visit;
//   3. otherwise the node at independent_nodes_.top(), pushed to to_visit.
// The inner loop then drains high_priority_to_visit before to_visit. Every visited real-kernel CNode is
// appended to execution_order_. After a node is visited:
//   - if NeedOptimize() selects it (fused comm op), it becomes pending_node. Any previous pending node's
//     consumers are released first, and optimized_comm_group is set to its group;
//   - else if it is a communication op, it is appended to delay_visit;
//   - else its ready consumers are released into the queue it was taken from.
// For non-PyNative bprop graphs, CheckLoop() finally reports nodes that never became ready.
void ExecOrderBuilder::Build() {
  MS_EXCEPTION_IF_NULL(execution_order_);
  execution_order_->clear();
  execution_order_->reserve(kDefaultContainerSize);
  std::deque<AnfNodePtr> to_visit;
  std::deque<AnfNodePtr> delay_visit;
  std::deque<AnfNodePtr> high_priority_to_visit;
  // Queue the current node was popped from; assigned before every use in the inner loop.
  std::deque<AnfNodePtr> *handle_queue_ptr;
  std::string optimized_comm_group;
  AnfNodePtr pending_node = nullptr;
  while (!independent_nodes_.empty() || pending_node != nullptr || !delay_visit.empty()) {
    if (!delay_visit.empty()) {
      EnqueueReadyNodes(delay_visit.front(), &high_priority_to_visit, false);
      delay_visit.pop_front();
    } else if (pending_node != nullptr) {
      EnqueueReadyNodes(pending_node, &high_priority_to_visit, false);
      pending_node = nullptr;
    } else {
      to_visit.push_back(independent_nodes_.top());
      independent_nodes_.pop();
    }
    // comm descendant first, then common queue
    while (!to_visit.empty() || !high_priority_to_visit.empty()) {
      AnfNodePtr node;
      if (!high_priority_to_visit.empty()) {
        handle_queue_ptr = &high_priority_to_visit;
        node = high_priority_to_visit.front();
        high_priority_to_visit.pop_front();
      } else {
        handle_queue_ptr = &to_visit;
        node = to_visit.front();
        to_visit.pop_front();
      }
      // add execute node: only real-kernel CNodes appear in the execution order
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
        (void)execution_order_->emplace_back(node->cast<CNodePtr>());
      }
      // delay execute comm ops that need optimize: their consumers are released by a later outer iteration
      bool is_comm = common::AnfAlgo::IsCommunicationOp(node);
      bool optimize_comm = NeedOptimize(node, optimized_comm_group);
      if (optimize_comm) {
        optimized_comm_group = GetNodeGroup(node);
        // Only one pending node at a time: release the previous one before replacing it.
        if (pending_node != nullptr) {
          EnqueueReadyNodes(pending_node, &high_priority_to_visit, false);
        }
        pending_node = node;
      } else if (is_comm) {
        delay_visit.push_back(node);
      } else {
        EnqueueReadyNodes(node, handle_queue_ptr);
      }
    }
  }
  if (!is_pynative_kernel_graph_) {
    CheckLoop();
  }
}
340
PrintLoopNodesIfExist(const AnfNodePtr & node,std::set<AnfNodePtr> * visited_nodes,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * next_nodes)341 bool ExecOrderBuilder::PrintLoopNodesIfExist(const AnfNodePtr &node, std::set<AnfNodePtr> *visited_nodes,
342 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *next_nodes) {
343 MS_EXCEPTION_IF_NULL(node);
344 MS_EXCEPTION_IF_NULL(visited_nodes);
345 MS_EXCEPTION_IF_NULL(next_nodes);
346
347 (void)visited_nodes->insert(node);
348 for (auto &input_node : node_input_edges_[node]) {
349 size_t input_num = node_input_num_[input_node];
350 if (input_num == 0) {
351 continue;
352 }
353 if (visited_nodes->find(input_node) == visited_nodes->end()) {
354 MS_EXCEPTION_IF_NULL(input_node);
355 (*next_nodes)[input_node] = node;
356 if (PrintLoopNodesIfExist(input_node, visited_nodes, next_nodes)) {
357 return true;
358 }
359 } else {
360 auto cur_node = node;
361 std::queue<AnfNodePtr> loop_nodes;
362 while (cur_node != input_node && cur_node != nullptr) {
363 loop_nodes.push(cur_node);
364 cur_node = (*next_nodes)[cur_node];
365 }
366
367 if (cur_node == input_node) {
368 loop_nodes.push(cur_node);
369 MS_LOG(INFO) << "Print loop nodes start:";
370 while (!loop_nodes.empty()) {
371 cur_node = loop_nodes.front();
372 node_input_num_[cur_node]--;
373 MS_LOG(INFO) << "Get loop node:" << cur_node->DebugString();
374 loop_nodes.pop();
375 }
376 MS_LOG(INFO) << "Print loop nodes end.";
377 return true;
378 }
379 }
380 }
381 return false;
382 }
383
CheckLoop()384 void ExecOrderBuilder::CheckLoop() {
385 std::vector<AnfNodePtr> unvisited_nodes;
386 for (auto &node_ref : node_input_num_) {
387 MS_EXCEPTION_IF_NULL(node_ref.first);
388 if (node_ref.second == 0) {
389 continue;
390 }
391 std::string info;
392 for (const auto &input_node : node_input_edges_[node_ref.first]) {
393 MS_EXCEPTION_IF_NULL(input_node);
394 info = info.append(input_node->DebugString()).append("|");
395 }
396 MS_LOG(WARNING) << "Node:" << node_ref.first->DebugString() << ",inputs:" << info
397 << ",input num:" << node_ref.second;
398 (void)unvisited_nodes.emplace_back(node_ref.first);
399 }
400
401 if (unvisited_nodes.empty()) {
402 return;
403 }
404
405 for (auto &node : unvisited_nodes) {
406 MS_EXCEPTION_IF_NULL(node);
407 std::set<AnfNodePtr> visited_nodes;
408 mindspore::HashMap<AnfNodePtr, AnfNodePtr> next_nodes;
409 if (PrintLoopNodesIfExist(node, &visited_nodes, &next_nodes)) {
410 break;
411 }
412 }
413 MS_LOG(EXCEPTION) << "Graph has unvisited nodes and the number is :" << unvisited_nodes.size();
414 }
415 } // namespace mindspore::session
416