• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/session/exec_order_builder.h"
17 #include <algorithm>
18 #include <string>
19 #include "ops/ascend_op_name.h"
20 #include "include/common/utils/anfalgo.h"
21 #include "utils/ms_context.h"
22 
23 namespace mindspore::session {
// Initial capacity reserved for the builder's per-graph containers and for the execution order vector.
constexpr size_t kDefaultContainerSize = 5000;
25 
26 namespace {
GetNodeGroup(const AnfNodePtr & node)27 std::string GetNodeGroup(const AnfNodePtr &node) {
28   MS_EXCEPTION_IF_NULL(node);
29   auto cnode = node->cast<CNodePtr>();
30   if (common::AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
31     return common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
32   }
33   return "";
34 }
35 
NeedOptimize(const AnfNodePtr & node,const std::string & optimized_comm_group)36 bool NeedOptimize(const AnfNodePtr &node, const std::string &optimized_comm_group) {
37   bool is_fused_comm = common::AnfAlgo::IsFusedCommunicationOp(node);
38   if (!is_fused_comm) {
39     return false;
40   }
41   auto node_group = GetNodeGroup(node);
42   if (node_group.find(kSyncBnGroup) == string::npos) {
43     if (optimized_comm_group.empty() || node_group == optimized_comm_group) {
44       return true;
45     }
46   }
47   return false;
48 }
49 }  // namespace
50 
~ExecOrderBuilder()51 ExecOrderBuilder::~ExecOrderBuilder() {}
52 
Build(FuncGraph * graph,std::vector<CNodePtr> * execution_order,NodeUser * node_user)53 void ExecOrderBuilder::Build(FuncGraph *graph, std::vector<CNodePtr> *execution_order, NodeUser *node_user) {
54   MS_EXCEPTION_IF_NULL(graph);
55   MS_EXCEPTION_IF_NULL(execution_order);
56   MS_EXCEPTION_IF_NULL(node_user);
57   graph_ = graph;
58   is_pynative_kernel_graph_ = graph_->has_flag(kFlagIsPyNativeBpropKernelGraph);
59   execution_order_ = execution_order;
60   node_output_edges_ = node_user;
61   node_output_edges_->clear();
62   ClearLinkInfo();
63   BuildLinkInfo();
64   FindIndependentNodes();
65   Build();
66 }
67 
ClearLinkInfo()68 void ExecOrderBuilder::ClearLinkInfo() {
69   if (node_input_num_.empty()) {
70     node_input_num_.reserve(kDefaultContainerSize);
71     node_output_num_.reserve(kDefaultContainerSize);
72     node_input_edges_.reserve(kDefaultContainerSize);
73     trivial_nodes_.reserve(kDefaultContainerSize);
74   } else {
75     node_input_num_.clear();
76     node_output_num_.clear();
77     node_input_edges_.clear();
78     trivial_nodes_.clear();
79     node_output_edges_->clear();
80   }
81 }
82 
IsTrivialNode(const AnfNodePtr & node)83 bool ExecOrderBuilder::IsTrivialNode(const AnfNodePtr &node) {
84   MS_EXCEPTION_IF_NULL(node);
85   if (!node->isa<CNode>()) {
86     return true;
87   }
88 
89   const auto iter = trivial_nodes_.find(node);
90   if (iter != trivial_nodes_.end()) {
91     return iter->second;
92   }
93 
94   if (AnfUtils::IsRealKernel(node)) {
95     (void)trivial_nodes_.emplace(node, false);
96     return false;
97   }
98 
99   auto cnode = node->cast<CNodePtr>();
100   MS_EXCEPTION_IF_NULL(cnode);
101   if (std::all_of(cnode->inputs().begin(), cnode->inputs().end(),
102                   [this](const auto &input) { return IsTrivialNode(input); })) {
103     (void)trivial_nodes_.emplace(node, true);
104     return true;
105   } else {
106     (void)trivial_nodes_.emplace(node, false);
107     return false;
108   }
109 }
110 
BuildLinkInfo()111 void ExecOrderBuilder::BuildLinkInfo() {
112   std::queue<AnfNodePtr> to_visit;
113   auto output = graph_->get_return();
114   if (!output->isa<CNode>()) {
115     return;
116   }
117   to_visit.emplace(output);
118   auto seen = NewSeenGeneration();
119   while (!to_visit.empty()) {
120     auto node = to_visit.front();
121     to_visit.pop();
122     MS_EXCEPTION_IF_NULL(node);
123     auto cnode = node->cast<CNodePtr>();
124     MS_EXCEPTION_IF_NULL(cnode);
125     for (auto &input : cnode->inputs()) {
126       MS_EXCEPTION_IF_NULL(input);
127       (void)(*node_output_edges_)[input].emplace_back(node);
128       if (IsTrivialNode(input)) {
129         GetTrivialInputNode(input, seen);
130         continue;
131       }
132       if (!is_pynative_kernel_graph_) {
133         (void)node_input_edges_[node].emplace_back(input);
134       }
135       node_input_num_[node] += 1;
136       node_output_num_[input] += 1;
137       if (input->seen_ == seen || !input->isa<CNode>() || AnfUtils::IsCustomActorNode(input)) {
138         continue;
139       }
140       to_visit.emplace(input);
141       input->seen_ = seen;
142     }
143   }
144 }
145 
GetTrivialInputNode(const AnfNodePtr & node,SeenNum seen)146 void ExecOrderBuilder::GetTrivialInputNode(const AnfNodePtr &node, SeenNum seen) {
147   MS_EXCEPTION_IF_NULL(node);
148   if (!node->isa<CNode>()) {
149     return;
150   }
151   auto cnode = node->cast<CNodePtr>();
152   for (auto &in : cnode->inputs()) {
153     (void)(*node_output_edges_)[in].emplace_back(node);
154     if (in->seen_ != seen && IsTrivialNode(in)) {
155       GetTrivialInputNode(in, seen);
156       in->seen_ = seen;
157     }
158   }
159 }
160 
CanVisitInput(bool visit_with_refcount,const AnfNodePtr & input,SeenNum seen)161 bool ExecOrderBuilder::CanVisitInput(bool visit_with_refcount, const AnfNodePtr &input, SeenNum seen) {
162   MS_EXCEPTION_IF_NULL(input);
163   if (visit_with_refcount) {
164     auto output_iter = node_output_num_.find(input);
165     if (output_iter != node_output_num_.end()) {
166       output_iter->second--;
167       if (output_iter->second != 0) {
168         return false;
169       }
170     }
171   } else {
172     if (input->seen_ == seen) {
173       return false;
174     }
175     input->seen_ = seen;
176   }
177   return true;
178 }
179 
FindIndependentNodes()180 void ExecOrderBuilder::FindIndependentNodes() {
181   std::queue<AnfNodePtr> to_visit;
182   std::queue<AnfNodePtr> vnode_to_visit;
183   vnode_to_visit.emplace(graph_->get_return());
184   bool visit_with_refcount = true;
185   auto ms_context = MsContext::GetInstance();
186   auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
187   if (target == kGPUDevice) {
188     visit_with_refcount = false;
189   }
190   auto seen = NewSeenGeneration();
191   while (!to_visit.empty() || !vnode_to_visit.empty()) {
192     AnfNodePtr node;
193     if (vnode_to_visit.empty()) {
194       node = to_visit.front();
195       to_visit.pop();
196     } else {
197       node = vnode_to_visit.front();
198       vnode_to_visit.pop();
199     }
200 
201     MS_EXCEPTION_IF_NULL(node);
202     if (!node->isa<CNode>()) {
203       continue;
204     }
205 
206     if (AnfUtils::IsCustomActorNode(node)) {
207       independent_nodes_.push(node);
208       continue;
209     }
210     auto cnode = node->cast<CNodePtr>();
211     MS_EXCEPTION_IF_NULL(cnode);
212     bool independent = true;
213     auto &inputs = cnode->inputs();
214     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
215       auto &input = *iter;
216       MS_EXCEPTION_IF_NULL(input);
217       if (IsTrivialNode(input)) {
218         continue;
219       }
220       independent = false;
221 
222       if (!CanVisitInput(visit_with_refcount, input, seen)) {
223         continue;
224       }
225 
226       if (AnfUtils::IsRealKernel(input)) {
227         to_visit.emplace(input);
228         if (!independent_nodes_.empty() && visit_with_refcount) {
229           auto inode = independent_nodes_.top();
230           (void)(*node_output_edges_)[input].emplace_back(inode);
231           if (!is_pynative_kernel_graph_) {
232             (void)node_input_edges_[inode].emplace_back(input);
233           }
234           node_input_num_[inode] += 1;
235           independent_nodes_.pop();
236         }
237       } else {
238         vnode_to_visit.emplace(input);
239       }
240     }
241 
242     if (independent) {
243       independent_nodes_.push(node);
244     }
245   }
246 }
247 
EnqueueReadyNodes(const AnfNodePtr & node,std::deque<AnfNodePtr> * visit_queue,bool comm_first)248 void ExecOrderBuilder::EnqueueReadyNodes(const AnfNodePtr &node, std::deque<AnfNodePtr> *visit_queue, bool comm_first) {
249   MS_EXCEPTION_IF_NULL(visit_queue);
250   MS_EXCEPTION_IF_NULL(visit_queue);
251   MS_EXCEPTION_IF_NULL(node_output_edges_);
252   auto it = node_output_edges_->find(node);
253   if (it == node_output_edges_->end()) {
254     return;
255   }
256 
257   std::vector<AnfNodePtr> active_nodes;
258   for (const auto &output_node : it->second) {
259     MS_EXCEPTION_IF_NULL(output_node);
260     auto input_num_iter = node_input_num_.find(output_node);
261     if (input_num_iter == node_input_num_.end() || input_num_iter->second == 0) {
262       continue;
263     }
264     input_num_iter->second--;
265     if (input_num_iter->second > 0) {
266       continue;
267     }
268 
269     bool is_comm_node = common::AnfAlgo::IsCommunicationOp(output_node);
270     if (!AnfUtils::IsRealKernel(output_node) || it->second.size() == 1) {
271       visit_queue->push_front(output_node);
272     } else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
273       visit_queue->push_back(output_node);
274     } else {
275       (void)active_nodes.emplace_back(output_node);
276     }
277   }
278 
279   (void)std::copy(active_nodes.begin(), active_nodes.end(), std::back_inserter(*visit_queue));
280 }
281 
Build()282 void ExecOrderBuilder::Build() {
283   MS_EXCEPTION_IF_NULL(execution_order_);
284   execution_order_->clear();
285   execution_order_->reserve(kDefaultContainerSize);
286   std::deque<AnfNodePtr> to_visit;
287   std::deque<AnfNodePtr> delay_visit;
288   std::deque<AnfNodePtr> high_priority_to_visit;
289   std::deque<AnfNodePtr> *handle_queue_ptr;
290   std::string optimized_comm_group;
291   AnfNodePtr pending_node = nullptr;
292   while (!independent_nodes_.empty() || pending_node != nullptr || !delay_visit.empty()) {
293     if (!delay_visit.empty()) {
294       EnqueueReadyNodes(delay_visit.front(), &high_priority_to_visit, false);
295       delay_visit.pop_front();
296     } else if (pending_node != nullptr) {
297       EnqueueReadyNodes(pending_node, &high_priority_to_visit, false);
298       pending_node = nullptr;
299     } else {
300       to_visit.push_back(independent_nodes_.top());
301       independent_nodes_.pop();
302     }
303     // comm descendant first, then common queue
304     while (!to_visit.empty() || !high_priority_to_visit.empty()) {
305       AnfNodePtr node;
306       if (!high_priority_to_visit.empty()) {
307         handle_queue_ptr = &high_priority_to_visit;
308         node = high_priority_to_visit.front();
309         high_priority_to_visit.pop_front();
310       } else {
311         handle_queue_ptr = &to_visit;
312         node = to_visit.front();
313         to_visit.pop_front();
314       }
315       // add execute node
316       MS_EXCEPTION_IF_NULL(node);
317       if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
318         (void)execution_order_->emplace_back(node->cast<CNodePtr>());
319       }
320       // delay execute comm ops that need optimize
321       bool is_comm = common::AnfAlgo::IsCommunicationOp(node);
322       bool optimize_comm = NeedOptimize(node, optimized_comm_group);
323       if (optimize_comm) {
324         optimized_comm_group = GetNodeGroup(node);
325         if (pending_node != nullptr) {
326           EnqueueReadyNodes(pending_node, &high_priority_to_visit, false);
327         }
328         pending_node = node;
329       } else if (is_comm) {
330         delay_visit.push_back(node);
331       } else {
332         EnqueueReadyNodes(node, handle_queue_ptr);
333       }
334     }
335   }
336   if (!is_pynative_kernel_graph_) {
337     CheckLoop();
338   }
339 }
340 
PrintLoopNodesIfExist(const AnfNodePtr & node,std::set<AnfNodePtr> * visited_nodes,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * next_nodes)341 bool ExecOrderBuilder::PrintLoopNodesIfExist(const AnfNodePtr &node, std::set<AnfNodePtr> *visited_nodes,
342                                              mindspore::HashMap<AnfNodePtr, AnfNodePtr> *next_nodes) {
343   MS_EXCEPTION_IF_NULL(node);
344   MS_EXCEPTION_IF_NULL(visited_nodes);
345   MS_EXCEPTION_IF_NULL(next_nodes);
346 
347   (void)visited_nodes->insert(node);
348   for (auto &input_node : node_input_edges_[node]) {
349     size_t input_num = node_input_num_[input_node];
350     if (input_num == 0) {
351       continue;
352     }
353     if (visited_nodes->find(input_node) == visited_nodes->end()) {
354       MS_EXCEPTION_IF_NULL(input_node);
355       (*next_nodes)[input_node] = node;
356       if (PrintLoopNodesIfExist(input_node, visited_nodes, next_nodes)) {
357         return true;
358       }
359     } else {
360       auto cur_node = node;
361       std::queue<AnfNodePtr> loop_nodes;
362       while (cur_node != input_node && cur_node != nullptr) {
363         loop_nodes.push(cur_node);
364         cur_node = (*next_nodes)[cur_node];
365       }
366 
367       if (cur_node == input_node) {
368         loop_nodes.push(cur_node);
369         MS_LOG(INFO) << "Print loop nodes start:";
370         while (!loop_nodes.empty()) {
371           cur_node = loop_nodes.front();
372           node_input_num_[cur_node]--;
373           MS_LOG(INFO) << "Get loop node:" << cur_node->DebugString();
374           loop_nodes.pop();
375         }
376         MS_LOG(INFO) << "Print loop nodes end.";
377         return true;
378       }
379     }
380   }
381   return false;
382 }
383 
CheckLoop()384 void ExecOrderBuilder::CheckLoop() {
385   std::vector<AnfNodePtr> unvisited_nodes;
386   for (auto &node_ref : node_input_num_) {
387     MS_EXCEPTION_IF_NULL(node_ref.first);
388     if (node_ref.second == 0) {
389       continue;
390     }
391     std::string info;
392     for (const auto &input_node : node_input_edges_[node_ref.first]) {
393       MS_EXCEPTION_IF_NULL(input_node);
394       info = info.append(input_node->DebugString()).append("|");
395     }
396     MS_LOG(WARNING) << "Node:" << node_ref.first->DebugString() << ",inputs:" << info
397                     << ",input num:" << node_ref.second;
398     (void)unvisited_nodes.emplace_back(node_ref.first);
399   }
400 
401   if (unvisited_nodes.empty()) {
402     return;
403   }
404 
405   for (auto &node : unvisited_nodes) {
406     MS_EXCEPTION_IF_NULL(node);
407     std::set<AnfNodePtr> visited_nodes;
408     mindspore::HashMap<AnfNodePtr, AnfNodePtr> next_nodes;
409     if (PrintLoopNodesIfExist(node, &visited_nodes, &next_nodes)) {
410       break;
411     }
412   }
413   MS_LOG(EXCEPTION) << "Graph has unvisited nodes and the number is :" << unvisited_nodes.size();
414 }
415 }  // namespace mindspore::session
416