• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/control_node_scheduler.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/framework_ops.h"
20 #include "runtime/graph_scheduler/control_node_parser.h"
21 #include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
22 #include "runtime/graph_scheduler/scheduler_helper.h"
23 
24 namespace mindspore {
25 namespace runtime {
26 namespace {
GetActorName(const AnfNodePtr & node)27 std::string GetActorName(const AnfNodePtr &node) {
28   MS_EXCEPTION_IF_NULL(node);
29   auto debug_name = node->DebugString();
30   auto index = debug_name.find('{');
31   if ((index != std::string::npos) && (index > 0)) {
32     debug_name = debug_name.substr(0, index);
33   }
34 
35   if (common::AnfAlgo::IsCallNode(node)) {
36     return "Call_" + node->UniqueName() + "_" + debug_name;
37   } else {
38     return node->UniqueName() + "_" + debug_name;
39   }
40 }
41 
GetStackActorNameByExitName(const std::string & exit_name)42 std::string GetStackActorNameByExitName(const std::string &exit_name) {
43   size_t pos = exit_name.find(kExitActorNameSuffix);
44   if (pos == std::string::npos) {
45     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid exit actor name:" << exit_name;
46   }
47 
48   return exit_name.substr(0, pos) + kStackActorNameSuffix;
49 }
50 
51 // Parameter and ref node can not copy the device tensor.
is_need_copy_device_tensor(const AnfNodePtr & backend_node,size_t index)52 bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) {
53   MS_EXCEPTION_IF_NULL(backend_node);
54   // Skip the parameter and Load node.
55   const auto &real_backend_node = common::AnfAlgo::VisitKernelWithReturnType(backend_node, index, false).first;
56   if (real_backend_node != nullptr && (!real_backend_node->isa<CNode>())) {
57     return false;
58   }
59 
60   if (common::AnfAlgo::HasAbstractRef(backend_node)) {
61     return false;
62   }
63 
64   auto kernel_graph = AnfAlgo::FetchKernelGraph(backend_node.get());
65   MS_EXCEPTION_IF_NULL(kernel_graph);
66   if (kernel_graph->IsInRefOutputMap({backend_node, index})) {
67     if (!kernel_graph->is_graph_run_mode()) {
68       return false;
69     }
70     const auto &origin_node = kernel_graph->GetRefCorrespondOutput({backend_node, index}).first;
71     MS_EXCEPTION_IF_NULL(origin_node);
72     if (origin_node->isa<ValueNode>() || origin_node->isa<Parameter>()) {
73       return false;
74     }
75   }
76   return true;
77 }
78 
79 // Check whether the exit actor corresponding to the call node to the to actor already exists control arrow.
IsControlArrowExistForCallNode(const AnfNodePtr & node,const AbstractActor * const to_actor,const ControlNodeParserPtr & parser)80 bool IsControlArrowExistForCallNode(const AnfNodePtr &node, const AbstractActor *const to_actor,
81                                     const ControlNodeParserPtr &parser) {
82   MS_EXCEPTION_IF_NULL(node);
83   MS_EXCEPTION_IF_NULL(to_actor);
84   MS_EXCEPTION_IF_NULL(parser);
85   if (!common::AnfAlgo::IsCallNode(node)) {
86     MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, node)
87       << "#dmsg#Runtime error info:#dmsg#Invalid call node:" << node->DebugString();
88   }
89   int branch_id = parser->FetchBranchIDByCallNode(node);
90 
91   const auto &func_graphs = parser->FetchFuncGraphbyCallNode(node);
92   if (func_graphs.empty()) {
93     MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, node)
94       << "#dmsg#Runtime error info:#dmsg#Failed to get funcgraph by call node:" << node->DebugString();
95   }
96   MS_EXCEPTION_IF_NULL(*(func_graphs.begin()));
97   auto actor_name = (*(func_graphs.begin()))->ToString() + kExitActorNameSuffix;
98   const auto &actor = FetchActor(actor_name);
99   MS_EXCEPTION_IF_NULL(actor);
100   const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
101   MS_EXCEPTION_IF_NULL(exit_actor);
102 
103   const auto &branch_arrows = exit_actor->output_branch_control_arrows();
104   const auto &arrow_iter = branch_arrows.find(branch_id);
105   if (arrow_iter == branch_arrows.end()) {
106     return false;
107   }
108   const auto &arrows = arrow_iter->second;
109   return std::find(arrows.begin(), arrows.end(), to_actor->GetAID()) != arrows.end();
110 }
111 
IsNotCut(const AnfNodePtr & node)112 bool IsNotCut(const AnfNodePtr &node) {
113   MS_EXCEPTION_IF_NULL(node);
114   if (!node->isa<CNode>()) {
115     return false;
116   }
117   auto cnode = node->cast<CNodePtr>();
118   MS_EXCEPTION_IF_NULL(cnode);
119   return cnode->HasPrimalAttr(kAttrNotCut);
120 }
121 }  // namespace
122 
Build(const GraphCompilerInfo & graph_compiler_info,const AID & memory_manager_aid)123 ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
124                                                const AID &memory_manager_aid) {
125   const auto &control_nodes = graph_compiler_info.control_nodes_;
126   if (control_nodes.size() <= kSingleControlNode) {
127     return nullptr;
128   }
129 
130   memory_manager_aid_ = memory_manager_aid;
131   ControlActorSetPtr control_actors = std::make_shared<ControlActorSet>();
132   MS_EXCEPTION_IF_NULL(control_actors);
133   control_actors->switch_actors_ = BuildSwitchActor(graph_compiler_info);
134   control_actors->gather_actors_ = BuildGatherActor(graph_compiler_info);
135   control_actors->entrance_actors_ = BuildEntranceActor(graph_compiler_info);
136   control_actors->exit_actors_ = BuildExitActor(graph_compiler_info);
137   control_actors->stack_actors_ = BuildStackActor(graph_compiler_info);
138   return control_actors;
139 }
140 
BuildSwitchActor(const GraphCompilerInfo & graph_compiler_info) const141 std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) const {
142   std::vector<SwitchActorPtr> switch_actors;
143   const auto &control_nodes = graph_compiler_info.control_nodes_;
144 
145   for (const auto &control_node : control_nodes) {
146     // Switch node and switch layer node will be converted to switch actor.
147     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
148         common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
149       const auto &actor_name = GetActorName(control_node);
150       const auto &parameters = FetchInputNodeByCNode(control_node);
151       const auto &switch_actor =
152         std::make_shared<SwitchActor>(actor_name, memory_manager_aid_, parameters, control_node);
153       (void)switch_actors.emplace_back(switch_actor);
154       for (const auto &parameter : parameters) {
155         MS_EXCEPTION_IF_NULL(parameter.first);
156         MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
157                       << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
158       }
159       InsertActor(switch_actor.get());
160     }
161   }
162   return switch_actors;
163 }
164 
BuildDataSourceActorForControlNode(const GraphCompilerInfo & graph_compiler_info,const HostTensorQueuePtr & host_queue,const HostQueueDSActorPtr & host_queue_ds_actor,const AID & memory_manager_aid,std::vector<DataSourceActorPtr> * data_source_actors) const165 void ControlNodeScheduler::BuildDataSourceActorForControlNode(
166   const GraphCompilerInfo &graph_compiler_info, const HostTensorQueuePtr &host_queue,
167   const HostQueueDSActorPtr &host_queue_ds_actor, const AID &memory_manager_aid,
168   std::vector<DataSourceActorPtr> *data_source_actors) const {
169   HostQueueDSActorPtr control_node_ds_actor = host_queue_ds_actor;
170   const auto parser = graph_compiler_info.control_node_parser_;
171   MS_EXCEPTION_IF_NULL(parser);
172   MS_EXCEPTION_IF_NULL(data_source_actors);
173 
174   // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
175   // the corresponding backend parameter from the map, and insert it into the host data source actor.
176   const auto &control_node_parameters = parser->control_node_parameters();
177   for (const auto &parameter_with_index : control_node_parameters) {
178     MS_EXCEPTION_IF_NULL(parameter_with_index.first);
179     if (IsPersistentDeviceTensor(parameter_with_index.first)) {
180       continue;
181     }
182     if (control_node_ds_actor == nullptr) {
183       auto actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
184       MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
185       control_node_ds_actor =
186         std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid, nullptr, nullptr, host_queue);
187       MS_EXCEPTION_IF_NULL(control_node_ds_actor);
188       InsertActor(control_node_ds_actor.get());
189       (void)data_source_actors->emplace_back(control_node_ds_actor);
190     }
191 
192     auto &node_map = control_node_ds_actor->data_node_position_map_;
193     if (node_map.find(parameter_with_index) != node_map.end()) {
194       continue;
195     }
196     graph_compiler_info.origin_parameters_to_backend_parameters_[parameter_with_index.first].emplace_back(
197       std::make_pair(parameter_with_index, parameter_with_index));
198 
199     const auto &node_with_index_with_context =
200       parser->FetchBackendParameterWithContextByFrontParameter(parameter_with_index);
201     const auto &node_with_index = node_with_index_with_context.first;
202     const auto &device_context = node_with_index_with_context.second;
203     MS_EXCEPTION_IF_NULL(node_with_index.first);
204     MS_EXCEPTION_IF_NULL(device_context);
205     MS_LOG(DEBUG) << "Control node parameter front node:" << parameter_with_index.first->DebugString()
206                   << " index:" << parameter_with_index.second
207                   << " backend node:" << node_with_index.first->DebugString() << " index:" << node_with_index.second;
208     auto iter = find(control_node_ds_actor->data_node_with_indexs_.begin(),
209                      control_node_ds_actor->data_node_with_indexs_.end(), node_with_index);
210     if (iter != control_node_ds_actor->data_node_with_indexs_.end()) {
211       (void)node_map.emplace(parameter_with_index, iter - control_node_ds_actor->data_node_with_indexs_.begin());
212       MS_LOG(DEBUG) << "Insert front node:" << parameter_with_index.first->DebugString()
213                     << " index:" << parameter_with_index.second << " to host queue data source actor.";
214     } else {
215       CreateBuildInfoForFrontNode(parameter_with_index, node_with_index.first);
216       // Create device tensor.
217       const auto &device_address = AnfAlgo::GetMutableOutputAddr(node_with_index.first, node_with_index.second, false);
218       MS_EXCEPTION_IF_NULL(device_address);
219       const auto &sub_abstract =
220         common::AnfAlgo::FetchAbstractByIndex(parameter_with_index.first->abstract(), parameter_with_index.second);
221       MS_EXCEPTION_IF_NULL(sub_abstract);
222       const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
223         sub_abstract->BuildShape(), sub_abstract->BuildType(), nullptr, nullptr, device_address->GetSize(),
224         device_address->format(), device_address->type_id(), device_address->host_shape(),
225         device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
226       MS_EXCEPTION_IF_NULL(kernel_tensor);
227       kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(parameter_with_index.first));
228       auto new_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
229       MS_EXCEPTION_IF_NULL(new_address);
230       MS_LOG(DEBUG) << "Create new address for node that has no corresponding backend node:"
231                     << parameter_with_index.first->DebugString() << " index:" << parameter_with_index.second
232                     << " addr:" << new_address << " size:" << device_address->GetSize()
233                     << ", type id:" << device_address->type_id()
234                     << " type:" << (kernel_tensor->GetType() == nullptr ? "null" : kernel_tensor->GetType()->ToString())
235                     << " shape:"
236                     << (kernel_tensor->GetShape() == nullptr ? "null" : kernel_tensor->GetShape()->ToString());
237       AnfAlgo::SetOutputAddr(new_address, parameter_with_index.second, parameter_with_index.first.get());
238 
239       (void)node_map.emplace(parameter_with_index, control_node_ds_actor->data_node_with_indexs_.size());
240       (void)control_node_ds_actor->data_node_with_indexs_.emplace_back(parameter_with_index);
241       (void)control_node_ds_actor->device_contexts_.emplace_back(device_context);
242     }
243   }
244 }
245 
BuildGatherActor(const GraphCompilerInfo & graph_compiler_info) const246 std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) const {
247   std::vector<GatherActorPtr> gather_actors;
248   const auto &control_nodes = graph_compiler_info.control_nodes_;
249   const auto &parser = graph_compiler_info.control_node_parser_;
250   MS_EXCEPTION_IF_NULL(parser);
251 
252   for (const auto &control_node : control_nodes) {
253     MS_EXCEPTION_IF_NULL(control_node);
254     // Partial node and call node will be converted to gather actor.
255     if ((common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) && (!IsInvalidPartial(control_node))) ||
256         common::AnfAlgo::IsCallNode(control_node)) {
257       const auto &actor_name = GetActorName(control_node);
258       const auto &parameters = FetchInputNodeByCNode(control_node);
259       const auto &gather_actor =
260         std::make_shared<GatherActor>(actor_name, memory_manager_aid_, parameters, control_node);
261       MS_EXCEPTION_IF_NULL(gather_actor);
262       (void)gather_actors.emplace_back(gather_actor);
263       for (const auto &parameter : parameters) {
264         MS_EXCEPTION_IF_NULL(parameter.first);
265         MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
266                       << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
267       }
268       InsertActor(gather_actor.get());
269 
270       // The gather actor corresponding to a call node needs to set the branch id.
271       if (common::AnfAlgo::IsCallNode(control_node)) {
272         gather_actor->output_branch_id_ = parser->FetchBranchIDByCallNode(control_node);
273       }
274 
275       // Fetch device contexts for gather actor.
276       const auto &iter = parser->control_node_to_device_contexts_.find(control_node);
277       if (iter == parser->control_node_to_device_contexts_.end()) {
278         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, control_node)
279           << "#dmsg#Runtime error info:#dmsg#Failed to get device contexts for node:" << control_node->DebugString();
280       }
281       gather_actor->device_contexts_ = iter->second;
282     }
283   }
284   return gather_actors;
285 }
286 
BuildEntranceActor(const GraphCompilerInfo & graph_compiler_info) const287 std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(
288   const GraphCompilerInfo &graph_compiler_info) const {
289   const auto &parser = graph_compiler_info.control_node_parser_;
290   MS_EXCEPTION_IF_NULL(parser);
291   const auto &call_node_to_func_graphs = parser->call_node_to_func_graphs_;
292   std::unordered_map<FuncGraphPtr, std::set<KernelWithIndex>> func_graph_to_call_nodes;
293   for (const auto &call_node_to_func_graph : call_node_to_func_graphs) {
294     const auto &node = call_node_to_func_graph.first;
295     for (const auto &func_graph : call_node_to_func_graph.second) {
296       (void)func_graph_to_call_nodes[func_graph].emplace(node, 0);
297     }
298   }
299 
300   std::vector<EntranceActorPtr> entrance_actors;
301   const auto &control_nodes = graph_compiler_info.control_nodes_;
302   for (const auto &control_node : control_nodes) {
303     MS_EXCEPTION_IF_NULL(control_node);
304     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
305       const auto &func_graph = control_node->func_graph();
306       MS_EXCEPTION_IF_NULL(func_graph);
307       const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
308       std::vector<KernelWithIndex> formal_parameters;
309 
310       // The entrance actor has two parts of node members :
311       // 1. The formal parameters of the subgraph are used to connect the actor's output arrows.
312       for (const auto &parameter : func_graph->parameters()) {
313         MS_EXCEPTION_IF_NULL(parameter);
314         const auto &abstract = parameter->abstract();
315         MS_EXCEPTION_IF_NULL(abstract);
316         size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
317         for (size_t i = 0; i < output_num; ++i) {
318           (void)formal_parameters.emplace_back(parameter, i);
319         }
320       }
321 
322       // 2. The caller of the subgraph, namely call nodes, is used to connect the input arrows.
323       std::set<KernelWithIndex> call_nodes;
324       const auto &iter = func_graph_to_call_nodes.find(func_graph);
325       if (iter != func_graph_to_call_nodes.end()) {
326         call_nodes = iter->second;
327       }
328       for (const auto &formal_parameter : formal_parameters) {
329         MS_EXCEPTION_IF_NULL(formal_parameter.first);
330         MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
331                       << " parameter:" << formal_parameter.first->DebugString() << " index:" << formal_parameter.second;
332       }
333       const auto &entrance_actor =
334         std::make_shared<EntranceActor>(actor_name, memory_manager_aid_, formal_parameters, call_nodes, control_node);
335       MS_EXCEPTION_IF_NULL(entrance_actor);
336       auto context_iter = parser->func_graph_to_device_contexts_.find(func_graph);
337       if (context_iter == parser->func_graph_to_device_contexts_.end() ||
338           context_iter->second.size() < formal_parameters.size()) {
339         MS_LOG(INTERNAL_EXCEPTION)
340           << "#dmsg#Runtime error info:#dmsg#Invalid device contexts for funcgraph:" << func_graph->ToString()
341           << " parameter num:" << formal_parameters.size() << " device contexts num:"
342           << (context_iter == parser->func_graph_to_device_contexts_.end() ? 0 : context_iter->second.size());
343       }
344       entrance_actor->device_contexts_.clear();
345       (void)entrance_actor->device_contexts_.insert(
346         entrance_actor->device_contexts_.begin(), context_iter->second.begin(),
347         context_iter->second.begin() + SizeToLong(formal_parameters.size()));
348       (void)entrance_actors.emplace_back(entrance_actor);
349       InsertActor(entrance_actor.get());
350     }
351   }
352 
353   return entrance_actors;
354 }
355 
BuildExitActor(const GraphCompilerInfo & graph_compiler_info) const356 std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompilerInfo &graph_compiler_info) const {
357   std::vector<ExitActorPtr> exit_actors;
358   const auto &control_nodes = graph_compiler_info.control_nodes_;
359   const auto &parser = graph_compiler_info.control_node_parser_;
360   MS_EXCEPTION_IF_NULL(parser);
361 
362   // The exit actor is used in 2 places:
363   // 1.funcgraph output, that is the output of the return node.
364   for (const auto &control_node : control_nodes) {
365     MS_EXCEPTION_IF_NULL(control_node);
366     if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
367       const auto &func_graph = control_node->func_graph();
368       MS_EXCEPTION_IF_NULL(func_graph);
369       const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix;
370       const auto &parameters = FetchInputNodeByCNode(control_node);
371       const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, parameters, control_node);
372       MS_EXCEPTION_IF_NULL(exit_actor);
373       for (const auto &parameter : parameters) {
374         MS_EXCEPTION_IF_NULL(parameter.first);
375         MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
376                       << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
377       }
378       auto context_iter = parser->control_node_to_device_contexts_.find(control_node);
379       if (context_iter == parser->control_node_to_device_contexts_.end() ||
380           context_iter->second.size() != parameters.size()) {
381         MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Failed to get device contexts for funcgraph:"
382                                    << func_graph->ToString();
383       }
384       exit_actor->device_contexts_ = context_iter->second;
385       (void)exit_actors.emplace_back(exit_actor);
386       InsertActor(exit_actor.get());
387     }
388   }
389 
390   if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
391     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid graphs num:"
392                                << graph_compiler_info.graphs_.size()
393                                << " and contexts num:" << graph_compiler_info.device_contexts_.size();
394   }
395 
396   // 2. Replace the device address in the kernel actor when calling funcgraph, that is to say in the data exchange
397   // between kernel graph and the control node, in fact, it is the output of the kernel graph.
398   for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
399     MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
400     if (kernel_graph_group_info->graphs_.empty()) {
401       continue;
402     }
403 
404     std::vector<bool> is_need_copy_device_tensors;
405     std::vector<bool> is_need_dynamic_checks;
406     std::vector<bool> is_dynamic_shapes;
407     std::vector<KernelWithIndex> formal_parameters;
408     std::vector<const DeviceContext *> device_contexts;
409 
410     for (const auto &node_with_context : kernel_graph_group_info->front_output_nodes_) {
411       if (HasAbstractMonad(node_with_context.first.first) || IsCsrNode(node_with_context.first.first)) {
412         continue;
413       }
414       // Collect inputs of exit actor.
415       (void)formal_parameters.emplace_back(node_with_context.first);
416       // Get the device contexts of the exit actor's cnode inputs.
417       const AnfNodePtr &backend_node = node_with_context.second.first.first;
418       MS_EXCEPTION_IF_NULL(backend_node);
419       (void)is_need_copy_device_tensors.emplace_back(
420         is_need_copy_device_tensor(backend_node, node_with_context.second.first.second));
421       (void)is_need_dynamic_checks.emplace_back(
422         common::AnfAlgo::CheckPrimitiveType(backend_node, prim::kPrimConditionGather));
423       auto is_dynamic_shape =
424         common::AnfAlgo::IsDynamicShape(backend_node) || common::AnfAlgo::IsDynamicSequence(backend_node);
425       (void)is_dynamic_shapes.emplace_back(is_dynamic_shape);
426       (void)device_contexts.emplace_back(node_with_context.second.second);
427     }
428     const auto &actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix;
429     const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, formal_parameters, nullptr);
430     MS_EXCEPTION_IF_NULL(exit_actor);
431     exit_actor->is_need_copy_device_tensors_.swap(is_need_copy_device_tensors);
432     exit_actor->is_need_dynamic_checks_.swap(is_need_dynamic_checks);
433     exit_actor->is_dynamic_shapes_.swap(is_dynamic_shapes);
434     exit_actor->device_contexts_.swap(device_contexts);
435     for (const auto &graph : kernel_graph_group_info->graphs_) {
436       MS_EXCEPTION_IF_NULL(graph);
437       std::for_each(graph->GetRefMap().begin(), graph->GetRefMap().end(),
438                     [&exit_actor, &graph](const std::pair<KernelWithIndex, KernelWithIndex> &pair) {
439                       exit_actor->ref_out_in_map_[pair.first] = graph->GetRefNodeRecursive(pair.first);
440                     });
441     }
442     (void)exit_actors.emplace_back(exit_actor);
443     InsertActor(exit_actor.get());
444   }
445 
446   return exit_actors;
447 }
448 
BuildStackActor(const GraphCompilerInfo & graph_compiler_info) const449 std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphCompilerInfo &graph_compiler_info) const {
450   std::vector<StackActorPtr> stack_actors;
451   const auto &parser = graph_compiler_info.control_node_parser_;
452   MS_EXCEPTION_IF_NULL(parser);
453 
454   // Create a corresponding stack actor for each kernel graph that has a call node as input.
455   for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
456     MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
457     if (!kernel_graph_group_info->need_stack_) {
458       continue;
459     }
460     const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
461     size_t input_parameter_data_num = 0;
462     std::vector<const DeviceContext *> device_contexts;
463     std::vector<KernelWithIndex> formal_parameters;
464     // Collect inputs of stack actor.
465     for (const auto &node_with_context : kernel_graph_group_info->front_input_nodes_) {
466       // If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
467       const auto &from_node = node_with_context.first.first;
468       MS_EXCEPTION_IF_NULL(from_node);
469       auto iter = parser->node_to_level_.find(from_node);
470       if (iter == parser->node_to_level_.end()) {
471         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, from_node)
472           << "#dmsg#Runtime error info:#dmsg#Failed to get level by from node:" << from_node->DebugString()
473           << " in graph:" << kernel_graph_group_info->group_name_;
474       }
475       if (iter->second == kernel_graph_group_info->level_ && (!parser->IsRootGraphPersistentDeviceTensor(from_node))) {
476         (void)formal_parameters.emplace_back(node_with_context.first);
477         (void)device_contexts.emplace_back(node_with_context.second);
478         MS_LOG(DEBUG) << "Add normal parameter for actor:" << actor_name << " node:" << from_node->DebugString()
479                       << " index:" << node_with_context.first.second;
480       } else {
481         (void)formal_parameters.insert(formal_parameters.begin(), node_with_context.first);
482         (void)device_contexts.insert(device_contexts.begin(), node_with_context.second);
483         MS_LOG(DEBUG) << "Add stack parameter for actor:" << actor_name << " node:" << from_node->DebugString()
484                       << " index:" << node_with_context.first.second;
485         input_parameter_data_num++;
486       }
487     }
488     const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters);
489     MS_EXCEPTION_IF_NULL(stack_actor);
490     (void)stack_actors.emplace_back(stack_actor);
491     stack_actor->device_contexts_.swap(device_contexts);
492     stack_actor->input_stack_data_num_ = input_parameter_data_num;
493     InsertActor(stack_actor.get());
494   }
495   // Create stack actors for control nodes.
496   BuildStackActorForControlNode(graph_compiler_info, &stack_actors);
497 
498   return stack_actors;
499 }
500 
BuildStackActorForControlNode(const GraphCompilerInfo & graph_compiler_info,std::vector<StackActorPtr> * const stack_actors) const501 void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
502                                                          std::vector<StackActorPtr> *const stack_actors) const {
503   const auto &parser = graph_compiler_info.control_node_parser_;
504   MS_EXCEPTION_IF_NULL(parser);
505 
506   for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
507     MS_EXCEPTION_IF_NULL(need_stack_control_node);
508     MS_LOG(DEBUG) << "Build stack actor for control node:" << need_stack_control_node->DebugString();
509     const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
510     std::vector<KernelWithIndex> formal_parameters;
511     std::vector<const DeviceContext *> device_contexts;
512     size_t input_parameter_data_num = 0;
513     size_t input_parameter_partials_num = 0;
514 
515     // Fetch the control actor of control node.
516     std::string control_actor_name = "";
517     if (common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimReturn)) {
518       const auto &func_graph = need_stack_control_node->func_graph();
519       MS_EXCEPTION_IF_NULL(func_graph);
520       control_actor_name = func_graph->ToString() + kExitActorNameSuffix;
521     } else if (common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimPartial) ||
522                common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitch) ||
523                common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitchLayer) ||
524                common::AnfAlgo::IsCallNode(need_stack_control_node)) {
525       control_actor_name = GetActorName(need_stack_control_node);
526     } else {
527       MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, need_stack_control_node)
528         << "#dmsg#Runtime error info:#dmsg#Invalid control node:" << need_stack_control_node->DebugString();
529     }
530 
531     auto iter = parser->node_to_level_.find(need_stack_control_node);
532     if (iter == parser->node_to_level_.end()) {
533       MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, need_stack_control_node)
534         << "#dmsg#Runtime error info:#dmsg#Failed to get level for need stack control node:"
535         << need_stack_control_node->DebugString();
536     }
537     size_t control_node_level = iter->second;
538 
539     auto actor = FetchActor(control_actor_name);
540     if (actor == nullptr) {
541       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid actor name:" << control_actor_name;
542     }
543     auto control_actor = dynamic_cast<ControlActor *>(actor);
544     MS_EXCEPTION_IF_NULL(control_actor);
545     if (control_actor->formal_parameters_.size() > control_actor->device_contexts_.size()) {
546       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid device context size:"
547                                  << control_actor->device_contexts_.size()
548                                  << " and formal parameter size:" << control_actor->formal_parameters_.size()
549                                  << " for actor:" << control_actor->GetAID();
550     }
551 
552     // Collect formal parameters and device contexts, skip the value nodes.
553     for (size_t i = 0; i < control_actor->formal_parameters_.size(); ++i) {
554       const auto &parameter = control_actor->formal_parameters_[i];
555       auto device_context = control_actor->device_contexts_[i];
556       MS_EXCEPTION_IF_NULL(parameter.first);
557       if (parameter.first->isa<ValueNode>()) {
558         continue;
559       }
560 
561       iter = parser->node_to_level_.find(parameter.first);
562       if (iter == parser->node_to_level_.end()) {
563         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, parameter.first)
564           << "#dmsg#Runtime error info:#dmsg#Failed to get level for formal parameter:"
565           << parameter.first->DebugString()
566           << " for need stack control node:" << need_stack_control_node->DebugString();
567       }
568 
569       if (control_node_level == iter->second && (!parser->IsRootGraphPersistentDeviceTensor(parameter.first))) {
570         MS_LOG(DEBUG) << "Add normal parameter:" << parameter.first->DebugString()
571                       << " for stack actor:" << stack_actor_name;
572         (void)formal_parameters.emplace_back(parameter);
573         (void)device_contexts.emplace_back(device_context);
574       } else {
575         MS_LOG(DEBUG) << "Add stack parameter:" << parameter.first->DebugString()
576                       << " for stack actor:" << stack_actor_name;
577         (void)formal_parameters.insert(formal_parameters.begin(), parameter);
578         (void)device_contexts.insert(device_contexts.begin(), device_context);
579 
580         const auto &abstract = parameter.first->abstract();
581         MS_EXCEPTION_IF_NULL(abstract);
582         const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, parameter.second);
583         MS_EXCEPTION_IF_NULL(real_abstract);
584         if (real_abstract->isa<abstract::AbstractFunction>()) {
585           input_parameter_partials_num++;
586         } else {
587           input_parameter_data_num++;
588         }
589       }
590     }
591     // Create stack actor.
592     const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters);
593     MS_EXCEPTION_IF_NULL(stack_actor);
594     stack_actor->device_contexts_ = device_contexts;
595     stack_actor->input_stack_data_num_ = input_parameter_data_num;
596     stack_actor->input_stack_partials_num_ = input_parameter_partials_num;
597     stack_actor->node_ = need_stack_control_node;
598     InsertActor(stack_actor.get());
599     (void)stack_actors->emplace_back(stack_actor);
600   }
601 }
602 
603 namespace {
ParseRealIndex(const mindspore::HashMap<size_t,size_t> & dynamic_len_index,size_t formal_input_num,std::vector<std::pair<std::vector<size_t>,bool>> * real_indexes,AbstractActor * actor)604 void ParseRealIndex(const mindspore::HashMap<size_t, size_t> &dynamic_len_index, size_t formal_input_num,
605                     std::vector<std::pair<std::vector<size_t>, bool>> *real_indexes, AbstractActor *actor) {
606   MS_EXCEPTION_IF_NULL(real_indexes);
607   MS_EXCEPTION_IF_NULL(actor);
608   auto tmp_dynamic_len_index = dynamic_len_index;
609   size_t real_output_num = formal_input_num + tmp_dynamic_len_index.size();
610   for (const auto &index_pair : tmp_dynamic_len_index) {
611     if (real_output_num < index_pair.second) {
612       MS_LOG(EXCEPTION) << "Invalid dynamic len:" << std::to_string(index_pair.second)
613                         << " start index:" << std::to_string(index_pair.first)
614                         << " real input num:" << std::to_string(real_output_num)
615                         << " for actor:" << actor->GetAID().Name();
616     }
617     real_output_num -= index_pair.second;
618   }
619   MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real output num:" << real_output_num;
620   size_t start_index = 0;
621   for (const auto &pair : tmp_dynamic_len_index) {
622     MS_LOG(DEBUG) << "start index:" << pair.first << " len:" << pair.second;
623   }
624   for (size_t i = 0; i < real_output_num; ++i) {
625     MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real output index:" << i;
626     if (tmp_dynamic_len_index.find(start_index) != tmp_dynamic_len_index.end()) {
627       std::vector<size_t> indexes;
628       for (size_t j = 0; j < tmp_dynamic_len_index[start_index]; ++j) {
629         indexes.emplace_back(j + start_index);
630       }
631       real_indexes->emplace_back(indexes, true);
632       start_index += tmp_dynamic_len_index[start_index];
633       tmp_dynamic_len_index.erase(start_index);
634     } else {
635       std::vector<size_t> indexes{start_index};
636       real_indexes->emplace_back(indexes, false);
637       start_index++;
638     }
639   }
640   for (size_t i = 0; i < real_indexes->size(); ++i) {
641     std::string index_str = "index_";
642     for (const auto &index : (*real_indexes)[i].first) {
643       index_str = index_str + std::to_string(index) + "_";
644     }
645     MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real input " << i << " " << index_str;
646   }
647   if (real_indexes->size() != real_output_num) {
648     MS_LOG(EXCEPTION) << "Invalid real index size:" << std::to_string(real_indexes->size())
649                       << " start need:" << std::to_string(real_output_num) << " for actor:" << actor->GetAID().Name();
650   }
651 }
652 }  // namespace
653 
CollectDynamicLenIndexForArgment(const GraphCompilerInfo & graph_compiler_info) const654 void ControlNodeScheduler::CollectDynamicLenIndexForArgment(const GraphCompilerInfo &graph_compiler_info) const {
655   const auto &parser = graph_compiler_info.control_node_parser_;
656   MS_EXCEPTION_IF_NULL(parser);
657   for (const auto &node_to_func_with_index : parser->control_node_to_funcgraph_with_dynamic_sequence_index_) {
658     const auto &node = node_to_func_with_index.first;
659     MS_EXCEPTION_IF_NULL(node);
660     const auto &actor_name = GetActorName(node);
661     const auto &actor = FetchActor(actor_name);
662     MS_EXCEPTION_IF_NULL(actor);
663     const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
664     MS_EXCEPTION_IF_NULL(gather_actor);
665     for (const auto &func_with_index : node_to_func_with_index.second) {
666       const auto &func_graph = func_with_index.first;
667       MS_EXCEPTION_IF_NULL(func_graph);
668       auto dynamic_len_index = func_with_index.second;
669       std::vector<std::pair<std::vector<size_t>, bool>> real_indexes;
670       ParseRealIndex(dynamic_len_index, gather_actor->formal_parameters_.size(), &real_indexes, gather_actor);
671       MS_LOG(INFO) << "Add dynamic len index for funcgraph:" << func_graph->ToString()
672                    << " actor:" << gather_actor->GetAID()
673                    << " formal parameter num:" << gather_actor->formal_parameters_.size();
674       gather_actor->dynamic_len_index_[func_graph] = real_indexes;
675     }
676   }
677 
678   for (const auto &node_to_call_with_index : parser->return_to_call_with_dynamic_sequence_index_) {
679     const auto &node = node_to_call_with_index.first;
680     MS_EXCEPTION_IF_NULL(node);
681     MS_EXCEPTION_IF_NULL(node->func_graph());
682     MS_LOG(DEBUG) << "for node:" << node->DebugString();
683     const auto &actor_name = node->func_graph()->ToString() + kExitActorNameSuffix;
684     auto actor = FetchActor(actor_name);
685     MS_EXCEPTION_IF_NULL(actor);
686     const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
687     MS_EXCEPTION_IF_NULL(exit_actor);
688     for (const auto &call_with_index : node_to_call_with_index.second) {
689       const auto &call = call_with_index.first;
690       MS_EXCEPTION_IF_NULL(call);
691       int branch_id = parser->FetchBranchIDByCallNode(call);
692       std::vector<std::pair<std::vector<size_t>, bool>> real_indexes;
693       ParseRealIndex(call_with_index.second, exit_actor->formal_parameters_.size(), &real_indexes, exit_actor);
694       exit_actor->output_branch_dynamic_len_index_[branch_id] = real_indexes;
695     }
696   }
697 }
698 
Link(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const699 void ControlNodeScheduler::Link(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info) const {
700   MS_EXCEPTION_IF_NULL(actor_set);
701   MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
702   MS_LOG(DEBUG) << "Control node scheduler link start.";
703   // Link data arrows and partial arrows between control actors.
704   LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info);
705 
706   // Link arrows from host data source actor or data prepare actor to entrance actor of root graph.
707   LinkArrowForRootGraphEntranceActor(graph_compiler_info);
708 
709   // Link output data arrows from control actors to output actor.
710   LinkDataArrowForOutputActor(actor_set, graph_compiler_info);
711 
712   // Link data arrows from entrance actors to kernel actors.
713   LinkDataArrowForKernelActor(graph_compiler_info);
714 
715   // Link branch id arrows between control actors.
716   LinkBranchIDArrowForControlActor(actor_set->control_actors_.get());
717 
718   // Link all control arrows between control actors.
719   LinkControlArrowForControlActor(actor_set, graph_compiler_info);
720 
721   // Link control arrows for no input and no output kernel actor.
722   LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
723 
724   LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
725 
726   LinkDataArrowForCustomActor(actor_set, graph_compiler_info);
727 
728   LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
729 
730   SetTimeSummaryForControlActor(graph_compiler_info);
731 
732   CollectDynamicLenIndexForArgment(graph_compiler_info);
733   MS_LOG(DEBUG) << "Control node scheduler link end.";
734 }
735 
736 namespace {
FetchInternalParameterInput(const AnfNodePtr & node,const ControlNodeParserPtr & parser,const KernelGraphPtr & graph)737 AnfNodePtr FetchInternalParameterInput(const AnfNodePtr &node, const ControlNodeParserPtr &parser,
738                                        const KernelGraphPtr &graph) {
739   MS_EXCEPTION_IF_NULL(node);
740   if (!node->isa<CNode>()) {
741     MS_LOG(WARNING) << "Node:" << node->DebugString() << " is not a cnode.";
742     return nullptr;
743   }
744   const auto &kernel = node->cast<CNodePtr>();
745   MS_EXCEPTION_IF_NULL(kernel);
746   for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
747     auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
748     MS_EXCEPTION_IF_NULL(input_node);
749     auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
750     auto input = input_with_index.first;
751     MS_EXCEPTION_IF_NULL(input);
752     if (HasAbstractMonad(input) || (!parser->IsControlFlowDataArrow(graph, input))) {
753       continue;
754     }
755 
756     auto from_node_with_index = GetFrontNodeByKernelGraph(input, graph.get());
757     MS_EXCEPTION_IF_NULL(from_node_with_index.first);
758     const auto &from_node = from_node_with_index.first;
759     if (from_node->isa<CNode>()) {
760       return from_node;
761     }
762   }
763   return nullptr;
764 }
765 }  // namespace
766 
LinkControlArrowForCustomActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const767 void ControlNodeScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
768                                                           const GraphCompilerInfo &graph_compiler_info) const {
769   MS_EXCEPTION_IF_NULL(actor_set);
770   const auto &parser = graph_compiler_info.control_node_parser_;
771   MS_EXCEPTION_IF_NULL(parser);
772 
773   for (auto &custom_actor : actor_set->custom_actors_) {
774     MS_EXCEPTION_IF_NULL(custom_actor);
775     const auto &kernel = custom_actor->kernel().lock();
776     MS_EXCEPTION_IF_NULL(kernel);
777     const auto &graph = kernel->func_graph();
778     MS_EXCEPTION_IF_NULL(graph);
779     if (custom_actor->output_data_arrows().empty() && custom_actor->output_control_arrows().empty()) {
780       const auto &actor_name = graph->ToString() + kExitActorNameSuffix;
781       auto actor = FetchActor(actor_name);
782       MS_EXCEPTION_IF_NULL(actor);
783       SchedulerHelper::AddControlArrow(custom_actor.get(), actor);
784     }
785     if (custom_actor->input_control_arrow_aids().empty() && custom_actor->input_data_arrow_aids().empty()) {
786       const auto &kernel_graph = std::dynamic_pointer_cast<KernelGraph>(graph);
787       MS_EXCEPTION_IF_NULL(kernel_graph);
788       auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
789       AnfNodePtr internal_parameter = nullptr;
790       if (base_node != nullptr) {
791         internal_parameter = FetchInternalParameterInput(base_node, parser, kernel_graph);
792       }
793       AbstractActor *from_actor = nullptr;
794       if (parser->IsCallInputKernelGraph(kernel_graph.get())) {
795         auto kernel_graph_ptr = std::dynamic_pointer_cast<KernelGraph>(kernel->func_graph());
796         const auto &actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph_ptr) + kStackActorNameSuffix;
797         from_actor = FetchActor(actor_name);
798       } else if (internal_parameter != nullptr) {
799         const auto &from_graph = parser->FetchKernelGraphByFrontNode(internal_parameter);
800         MS_EXCEPTION_IF_NULL(from_graph);
801         from_actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
802       } else {
803         const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
804         MS_EXCEPTION_IF_NULL(func_graph);
805         const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
806         from_actor = FetchActor(actor_name);
807       }
808       MS_EXCEPTION_IF_NULL(from_actor);
809       SchedulerHelper::AddControlArrow(from_actor, custom_actor.get());
810     }
811   }
812 }
813 
ClearActorData(const ControlActorSet * control_actor_set) const814 void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) const {
815   if (control_actor_set == nullptr) {
816     return;
817   }
818 
819   for (auto &switch_actor : control_actor_set->switch_actors_) {
820     MS_EXCEPTION_IF_NULL(switch_actor);
821     switch_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
822   }
823 
824   for (auto &gather_actor : control_actor_set->gather_actors_) {
825     MS_EXCEPTION_IF_NULL(gather_actor);
826     gather_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
827     gather_actor->created_device_tensors_.clear();
828     gather_actor->created_new_graphs_.clear();
829     gather_actor->created_new_nodes_.clear();
830   }
831 
832   for (auto &entrance_actor : control_actor_set->entrance_actors_) {
833     MS_EXCEPTION_IF_NULL(entrance_actor);
834     entrance_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
835   }
836 
837   for (auto &stack_actor : control_actor_set->stack_actors_) {
838     MS_EXCEPTION_IF_NULL(stack_actor);
839     stack_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
840   }
841 
842   for (auto &exit_actor : control_actor_set->exit_actors_) {
843     MS_EXCEPTION_IF_NULL(exit_actor);
844     exit_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
845     exit_actor->last_step_created_device_tensors_.swap(exit_actor->created_device_tensors_);
846     exit_actor->created_new_graphs_.clear();
847     exit_actor->created_new_nodes_.clear();
848   }
849 }
850 
851 namespace {
GetLazyInlineFuncGraph(const StackActorPtr & stack_actor)852 FuncGraphPtr GetLazyInlineFuncGraph(const StackActorPtr &stack_actor) {
853   MS_EXCEPTION_IF_NULL(stack_actor);
854   for (const auto &input_pair : stack_actor->input_data_arrow_aids()) {
855     if (input_pair.second == nullptr) {
856       continue;
857     }
858     const auto &data_arrow = input_pair.second;
859     if (IntToSize(data_arrow->to_input_index_) < stack_actor->input_stack_data_num()) {
860       continue;
861     }
862     std::string actor_name = input_pair.first.Name();
863     if (actor_name.empty()) {
864       continue;
865     }
866     const auto &actor = FetchActor(actor_name);
867     if (actor == nullptr) {
868       MS_LOG(WARNING) << "Failed to get actor by aid:" << actor_name << " for stack actor:" << stack_actor->GetAID();
869       continue;
870     }
871     if (actor->type() != KernelTransformType::kExitActor) {
872       continue;
873     }
874     const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
875     MS_EXCEPTION_IF_NULL(exit_actor);
876     if (exit_actor->node() == nullptr) {
877       continue;
878     }
879     return exit_actor->node()->func_graph();
880   }
881   return nullptr;
882 }
883 }  // namespace
884 
Optimize(const ActorSetPtr & actor_set,const GraphCompilerInfo & graph_compiler_info) const885 void ControlNodeScheduler::Optimize(const ActorSetPtr &actor_set, const GraphCompilerInfo &graph_compiler_info) const {
886   MS_EXCEPTION_IF_NULL(actor_set);
887   if (actor_set->control_actors_ == nullptr || graph_compiler_info.control_node_parser_ == nullptr) {
888     return;
889   }
890   const auto &parser = graph_compiler_info.control_node_parser_;
891 
892   for (const auto &entrance_actor : actor_set->control_actors_->entrance_actors_) {
893     MS_EXCEPTION_IF_NULL(entrance_actor);
894     std::stable_partition(entrance_actor->output_data_arrows_.begin(), entrance_actor->output_data_arrows_.end(),
895                           [](const DataArrowPtr &arrow) {
896                             MS_EXCEPTION_IF_NULL(arrow);
897                             const auto &actor = FetchActor(arrow->to_op_id_.Name());
898                             return actor != nullptr && (actor->type() == KernelTransformType::kKernelActor ||
899                                                         actor->type() == KernelTransformType::kSuperKernelActor ||
900                                                         actor->type() == KernelTransformType::kCopyActor);
901                           });
902   }
903 
904   for (const auto &stack_actor : actor_set->control_actors_->stack_actors_) {
905     MS_EXCEPTION_IF_NULL(stack_actor);
906     if (stack_actor->formal_parameters_.size() !=
907         stack_actor->input_data_arrow_aids_.size() + stack_actor->device_tensor_store_keys_.size()) {
908       continue;
909     }
910     if (stack_actor->input_stack_data_num_ == 0 ||
911         stack_actor->input_stack_data_num_ >= stack_actor->device_tensor_store_keys_.size() +
912                                                 stack_actor->local_device_tensors_.size() +
913                                                 stack_actor->input_data_arrow_aids_.size()) {
914       continue;
915     }
916     const auto &func_graph = GetLazyInlineFuncGraph(stack_actor);
917     if (func_graph == nullptr) {
918       continue;
919     }
920     if (!func_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
921       continue;
922     }
923     const auto &iter = parser->func_graph_to_call_nodes_.find(func_graph);
924     if (iter != parser->func_graph_to_call_nodes_.end() && (!iter->second.empty())) {
925       continue;
926     }
927     for (const auto &input_aid : stack_actor->input_branch_id_arrow_aids_) {
928       const auto &actor = FetchActor(input_aid.Name());
929       if (actor == nullptr || actor->type() != KernelTransformType::kEntranceActor) {
930         MS_LOG(WARNING) << "Invalid input branch id aid:" << input_aid;
931         continue;
932       }
933       const auto &entrance_actor = dynamic_cast<EntranceActor *>(actor);
934       MS_EXCEPTION_IF_NULL(entrance_actor);
935       const auto &branch_id_iter = std::find(entrance_actor->output_branch_id_arrows_.begin(),
936                                              entrance_actor->output_branch_id_arrows_.end(), stack_actor->GetAID());
937       if (branch_id_iter != entrance_actor->output_branch_id_arrows_.end()) {
938         stack_actor->is_branch_id_enable_ = false;
939         entrance_actor->output_branch_id_arrows_.erase(branch_id_iter);
940         MS_LOG(DEBUG) << "Disable branch id from entrance actor:" << entrance_actor->GetAID()
941                       << " to stack actor:" << stack_actor->GetAID()
942                       << " for cell reuse funcgraph:" << func_graph->ToString();
943         break;
944       }
945     }
946   }
947 }
948 
LinkArrowForControlActor(ControlActorSet * const control_actor_set,const GraphCompilerInfo & graph_compiler_info) const949 void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set,
950                                                     const GraphCompilerInfo &graph_compiler_info) const {
951   if (control_actor_set == nullptr) {
952     return;
953   }
954 
955   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
956   const auto &parser = graph_compiler_info.control_node_parser_;
957   for (auto &switch_actor : control_actor_set->switch_actors_) {
958     MS_EXCEPTION_IF_NULL(switch_actor);
959     if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
960       for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
961         LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
962                                    graph_compiler_info);
963       }
964     } else {
965       // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
966       auto stack_actor_name = GetActorName(switch_actor->node_) + kStackActorNameSuffix;
967       auto actor = FetchActor(stack_actor_name);
968       MS_EXCEPTION_IF_NULL(actor);
969       auto stack_actor = dynamic_cast<StackActor *>(actor);
970       MS_EXCEPTION_IF_NULL(stack_actor);
971       LinkArrowFromStackActor(stack_actor, switch_actor.get(), graph_compiler_info);
972     }
973   }
974 
975   for (auto &gather_actor : control_actor_set->gather_actors_) {
976     MS_EXCEPTION_IF_NULL(gather_actor);
977     MS_EXCEPTION_IF_NULL(gather_actor->node_);
978     if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
979       for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
980         LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
981                                    graph_compiler_info);
982       }
983     } else {
984       // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
985       auto stack_actor_name = GetActorName(gather_actor->node_) + kStackActorNameSuffix;
986       auto actor = FetchActor(stack_actor_name);
987       MS_EXCEPTION_IF_NULL(actor);
988       auto stack_actor = dynamic_cast<StackActor *>(actor);
989       MS_EXCEPTION_IF_NULL(stack_actor);
990       LinkArrowFromStackActor(stack_actor, gather_actor.get(), graph_compiler_info);
991     }
992   }
993 
994   for (auto &entrance_actor : control_actor_set->entrance_actors_) {
995     MS_EXCEPTION_IF_NULL(entrance_actor);
996     for (const auto &call_node : entrance_actor->call_nodes_) {
997       LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, graph_compiler_info);
998     }
999   }
1000 
1001   for (auto &exit_actor : control_actor_set->exit_actors_) {
1002     MS_EXCEPTION_IF_NULL(exit_actor);
1003 
1004     auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
1005                                                           : GetActorName(exit_actor->node_) + kStackActorNameSuffix);
1006     auto actor = FetchActor(stack_actor_name);
1007     if (actor == nullptr) {
1008       for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
1009         LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i},
1010                                    graph_compiler_info);
1011       }
1012     } else {
1013       // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
1014       auto stack_actor = dynamic_cast<StackActor *>(actor);
1015       MS_EXCEPTION_IF_NULL(stack_actor);
1016       LinkArrowFromStackActor(stack_actor, exit_actor.get(), graph_compiler_info);
1017     }
1018   }
1019 
1020   for (auto &stack_actor : control_actor_set->stack_actors_) {
1021     MS_EXCEPTION_IF_NULL(stack_actor);
1022     for (size_t i = 0; i < stack_actor->formal_parameters_.size(); ++i) {
1023       LinkArrowbyFormalParameter(stack_actor.get(), stack_actor->formal_parameters_[i], {stack_actor->node_, i},
1024                                  graph_compiler_info);
1025     }
1026   }
1027 }
1028 
LinkArrowFromStackActor(StackActor * const stack_actor,ControlActor * const to_actor,const GraphCompilerInfo & graph_compiler_info) const1029 void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
1030                                                    const GraphCompilerInfo &graph_compiler_info) const {
1031   MS_EXCEPTION_IF_NULL(stack_actor);
1032   MS_EXCEPTION_IF_NULL(to_actor);
1033   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
1034   const auto &parser = graph_compiler_info.control_node_parser_;
1035   MS_EXCEPTION_IF_NULL(parser);
1036 
1037   for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
1038     const auto &formal_parameter =
1039       common::AnfAlgo::FetchRealNodeSkipMonadControl(to_actor->formal_parameters_[to_index]);
1040     const auto &from_node = formal_parameter.first;
1041     MS_EXCEPTION_IF_NULL(from_node);
1042     if (from_node->isa<ValueNode>()) {
1043       LinkArrowByValueNode(from_node, to_actor, formal_parameter.second, to_index);
1044       continue;
1045     }
1046 
1047     // Fetch the arrow type of input.
1048     if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr && from_node->isa<CNode>() &&
1049         (!common::AnfAlgo::IsCallNode(from_node) || IsNotCut(from_node)) &&
1050         (!common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
1051         to_actor->GetAID().Name().find(
1052           parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
1053       LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, graph_compiler_info);
1054       continue;
1055     }
1056 
1057     size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
1058     const auto &abstract = from_node->abstract();
1059     MS_EXCEPTION_IF_NULL(abstract);
1060     const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
1061     MS_EXCEPTION_IF_NULL(real_abstract);
1062 
1063     // Link arrow according to abstract.
1064     if (real_abstract->isa<abstract::AbstractFunction>()) {
1065       SchedulerHelper::AddPartialArrow(stack_actor, to_actor, from_index, to_index);
1066     } else {
1067       SchedulerHelper::AddDataArrow(stack_actor, to_actor, from_index, to_index);
1068     }
1069   }
1070 }
1071 
LinkArrowbyFormalParameter(ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1072 void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
1073                                                       const KernelWithIndex &from_node_with_index,
1074                                                       const KernelWithIndex &to_node_with_index,
1075                                                       const GraphCompilerInfo &graph_compiler_info) const {
1076   const auto &real_from_node_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(from_node_with_index);
1077   const auto &from_node = real_from_node_with_index.first;
1078   MS_EXCEPTION_IF_NULL(from_node);
1079   MS_EXCEPTION_IF_NULL(to_actor);
1080   MS_LOG(DEBUG) << "Link arrow by formal parameter, from node:" << from_node->DebugString()
1081                 << " from index:" << real_from_node_with_index.second << " to actor:" << to_actor->GetAID()
1082                 << " to index:" << to_node_with_index.second;
1083   if (from_node->isa<ValueNode>()) {
1084     LinkArrowByValueNode(from_node, to_actor, real_from_node_with_index.second, to_node_with_index.second);
1085   } else if (from_node->isa<Parameter>()) {
1086     LinkArrowByParameter(from_node, to_actor, real_from_node_with_index, to_node_with_index,
1087                          graph_compiler_info.control_node_parser_);
1088   } else if (common::AnfAlgo::IsCallNode(from_node) && !IsNotCut(from_node)) {
1089     // Link arrow by call node.
1090     LinkArrowByCallNode(from_node, to_actor, real_from_node_with_index, to_node_with_index, graph_compiler_info);
1091   } else if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
1092              common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
1093     // Link arrow from switch actor.
1094     const auto &actor_name = GetActorName(from_node);
1095     const auto &actor = FetchActor(actor_name);
1096     MS_EXCEPTION_IF_NULL(actor);
1097     const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
1098     MS_EXCEPTION_IF_NULL(switch_actor);
1099     if (IsPartialInput(from_node)) {
1100       SchedulerHelper::AddPartialArrow(switch_actor, to_actor, real_from_node_with_index.second,
1101                                        to_node_with_index.second);
1102     } else {
1103       SchedulerHelper::AddDataArrow(switch_actor, to_actor, real_from_node_with_index.second,
1104                                     to_node_with_index.second);
1105     }
1106   } else if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
1107     // If the funcgraph of the partial node is a deadnode, in order to ensure the correspondence between formal
1108     // parameters and real parameters, we need to create an empty partial for it.
1109     if (IsInvalidPartial(from_node)) {
1110       MS_LOG(DEBUG) << "Invalid partial node:" << from_node->DebugString();
1111       to_actor->local_partials_[to_node_with_index.second] = std::make_shared<OpPartial>();
1112       return;
1113     }
1114     // Link arrow from gather actor
1115     const auto &actor_name = GetActorName(from_node);
1116     const auto &actor = FetchActor(actor_name);
1117     if (actor == nullptr) {
1118       MS_LOG(DEBUG) << "No actor of " << actor_name;
1119       return;
1120     }
1121     const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1122     MS_EXCEPTION_IF_NULL(gather_actor);
1123     SchedulerHelper::AddPartialArrow(gather_actor, to_actor, real_from_node_with_index.second,
1124                                      to_node_with_index.second);
1125   } else if (from_node->isa<CNode>()) {
1126     // Link arrow by kernel.
1127     LinkArrowByKernel(from_node, to_actor, real_from_node_with_index, to_node_with_index, graph_compiler_info);
1128   }
1129 }
1130 
LinkArrowByValueNode(const AnfNodePtr & value_node,ControlActor * const to_actor,size_t from_index,size_t to_index) const1131 void ControlNodeScheduler::LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor,
1132                                                 size_t from_index, size_t to_index) const {
1133   MS_EXCEPTION_IF_NULL(value_node);
1134   MS_EXCEPTION_IF_NULL(to_actor);
1135 
1136   if (IsValueNode<FuncGraph>(value_node)) {
1137     // Link local partial.
1138     const auto &func_graph = GetValueNode<FuncGraphPtr>(value_node);
1139     MS_EXCEPTION_IF_NULL(func_graph);
1140     MS_LOG(DEBUG) << "Add local partial, graph:" << func_graph->ToString() << " for actor:" << to_actor->GetAID();
1141     to_actor->local_partials_[to_index] = std::make_shared<OpPartial>();
1142     *(to_actor->local_partials_[to_index]) = {func_graph.get(), {}, {}};
1143   } else {
1144     // Link device store value node.
1145     if (!AnfAlgo::OutputAddrExist(value_node, from_index)) {
1146       auto node = value_node->cast<ValueNodePtr>();
1147       MS_EXCEPTION_IF_NULL(node);
1148       auto value = node->value();
1149       MS_EXCEPTION_IF_NULL(value);
1150       // If the from index exceeds the size of the value node, we need to change the from index to 0.
1151       if (!value->isa<ValueTuple>() && from_index > 0) {
1152         from_index = 0;
1153       } else {
1154         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, value_node)
1155           << "#dmsg#Runtime error info:#dmsg#Invalid output address index:" << from_index
1156           << " for value node:" << value_node->DebugString() << " to actor:" << to_actor->GetAID();
1157       }
1158     }
1159     to_actor->local_device_tensors_[to_index] = AnfAlgo::GetMutableOutputAddr(value_node, from_index, false).get();
1160     to_actor->local_device_tensors_[to_index]->SetNodeIndex(value_node, from_index);
1161     MS_LOG(DEBUG) << "Add local device tensor:" << to_actor->local_device_tensors_[to_index] << " index:" << to_index
1162                   << " for actor:" << to_actor->GetAID() << " from index:" << from_index;
1163   }
1164 }
1165 
LinkArrowByParameter(const AnfNodePtr & parameter,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const ControlNodeParserPtr & parser) const1166 void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr &parameter, ControlActor *const to_actor,
1167                                                 const KernelWithIndex &from_node_with_index,
1168                                                 const KernelWithIndex &to_node_with_index,
1169                                                 const ControlNodeParserPtr &parser) const {
1170   MS_EXCEPTION_IF_NULL(parameter);
1171   MS_EXCEPTION_IF_NULL(to_actor);
1172   MS_EXCEPTION_IF_NULL(parser);
1173   MS_LOG(DEBUG) << "Link arrow by parameter:" << parameter->DebugString() << " indx:" << from_node_with_index.second
1174                 << " for actor:" << to_actor->GetAID();
1175   if (parser->IsRootGraphPersistentDeviceTensor(parameter)) {
1176     (void)to_actor->device_tensor_store_keys_.emplace_back(to_node_with_index.second, parameter);
1177     return;
1178   }
1179   // Link arrow from entrance actor.
1180   const auto &func_graph = parameter->func_graph();
1181   MS_EXCEPTION_IF_NULL(func_graph);
1182   const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1183   auto actor = FetchActor(actor_name);
1184   MS_EXCEPTION_IF_NULL(actor);
1185 
1186   // If the input of the exit actor of the kernel graph is a parameter node, and there is a corresponding stack actor,
1187   // it should be linked to the stack actor.
1188   if (to_actor->type() == KernelTransformType::kExitActor) {
1189     auto stack_actor_name = (to_actor->node_ == nullptr ? GetStackActorNameByExitName(to_actor->GetAID().Name())
1190                                                         : GetActorName(to_actor->node_) + kStackActorNameSuffix);
1191     auto stack_actor = FetchActor(stack_actor_name);
1192     actor = (stack_actor == nullptr ? actor : stack_actor);
1193   }
1194 
1195   auto from_actor = dynamic_cast<ControlActor *>(actor);
1196   MS_EXCEPTION_IF_NULL(from_actor);
1197 
1198   auto abstract = parameter->abstract();
1199   MS_EXCEPTION_IF_NULL(abstract);
1200   auto dst_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
1201   MS_EXCEPTION_IF_NULL(dst_abstract);
1202   if (dst_abstract->isa<abstract::AbstractFunction>()) {
1203     SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
1204                                      to_node_with_index.second);
1205   } else {
1206     SchedulerHelper::AddDataArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
1207                                   to_node_with_index.second);
1208   }
1209 }
1210 
LinkArrowByCallNode(const AnfNodePtr & call_node,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1211 void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, ControlActor *const to_actor,
1212                                                const KernelWithIndex &from_node_with_index,
1213                                                const KernelWithIndex &to_node_with_index,
1214                                                const GraphCompilerInfo &graph_compiler_info) const {
1215   MS_EXCEPTION_IF_NULL(call_node);
1216   MS_EXCEPTION_IF_NULL(to_actor);
1217   const auto &from_node = from_node_with_index.first;
1218   MS_EXCEPTION_IF_NULL(from_node);
1219   auto parser = graph_compiler_info.control_node_parser_;
1220   MS_EXCEPTION_IF_NULL(parser);
1221 
1222   if (to_actor->type_ != KernelTransformType::kEntranceActor) {
1223     // Link arrow from exit actor to control actor.
1224     const auto &abstract = call_node->abstract();
1225     MS_EXCEPTION_IF_NULL(abstract);
1226     const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
1227     MS_EXCEPTION_IF_NULL(real_abstract);
1228 
1229     std::set<FuncGraphPtr> func_graphs;
1230     try {
1231       func_graphs = parser->FetchFuncGraphbyCallNode(from_node);
1232     } catch (std::exception &e) {
1233       LinkArrowByKernel(call_node, to_actor, from_node_with_index, to_node_with_index, graph_compiler_info);
1234       func_graphs.clear();
1235     }
1236 
1237     for (const auto &func_graph : func_graphs) {
1238       MS_EXCEPTION_IF_NULL(func_graph);
1239       const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix;
1240       auto actor = FetchActor(actor_name);
1241       MS_EXCEPTION_IF_NULL(actor);
1242       auto exit_actor = dynamic_cast<ExitActor *>(actor);
1243       MS_EXCEPTION_IF_NULL(exit_actor);
1244       auto branch_id = parser->FetchBranchIDByCallNode(from_node);
1245       if (real_abstract->isa<abstract::AbstractFunction>()) {
1246         SchedulerHelper::AddPartialArrowForExitActor(exit_actor, to_actor, from_node_with_index.second,
1247                                                      to_node_with_index.second, branch_id);
1248       } else {
1249         SchedulerHelper::AddDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second,
1250                                                   to_node_with_index.second, branch_id);
1251       }
1252       MS_LOG(DEBUG) << "Link data arrow from:" << exit_actor->GetAID() << " index:" << from_node_with_index.second
1253                     << " to:" << to_actor->GetAID() << " index" << to_node_with_index.second;
1254     }
1255     if (real_abstract->isa<abstract::AbstractFunction>()) {
1256       to_actor->input_partials_num_++;
1257     } else {
1258       MS_LOG(DEBUG) << "Actor:" << to_actor->GetAID() << " add input num:" << to_actor->input_datas_num_;
1259       to_actor->input_datas_num_++;
1260     }
1261   } else {
1262     // Link arrow from gather actor to entrance actor.
1263     const auto &actor_name = GetActorName(from_node);
1264     const auto &actor = FetchActor(actor_name);
1265     MS_EXCEPTION_IF_NULL(actor);
1266     const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1267     MS_EXCEPTION_IF_NULL(gather_actor);
1268     const auto &to_node = to_node_with_index.first;
1269     MS_EXCEPTION_IF_NULL(to_node);
1270     const auto &func_graph = to_node->func_graph();
1271     MS_EXCEPTION_IF_NULL(func_graph);
1272     SchedulerHelper::AddDataWithBranchIDArrow(gather_actor, dynamic_cast<EntranceActor *>(to_actor), func_graph);
1273   }
1274 }
1275 
LinkArrowByKernel(const AnfNodePtr & kernel,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1276 void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
1277                                              const KernelWithIndex &from_node_with_index,
1278                                              const KernelWithIndex &to_node_with_index,
1279                                              const GraphCompilerInfo &graph_compiler_info) const {
1280   MS_EXCEPTION_IF_NULL(kernel);
1281   MS_EXCEPTION_IF_NULL(to_actor);
1282   const auto &from_node = from_node_with_index.first;
1283   MS_EXCEPTION_IF_NULL(from_node);
1284   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
1285   const auto &parser = graph_compiler_info.control_node_parser_;
1286   MS_EXCEPTION_IF_NULL(parser);
1287   const auto &graph = parser->FetchKernelGraphByFrontNode(from_node);
1288   MS_LOG(DEBUG) << "Link arrow by kernel, from mode:" << from_node->DebugString() << " to actor:" << to_actor->GetAID();
1289   MS_EXCEPTION_IF_NULL(graph);
1290   const auto &group_name = parser->FetchGroupNameByKernelGraph(graph);
1291 
1292   if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr &&
1293       to_actor->GetAID().Name().find(group_name) != std::string::npos) {
1294     // Link arrow from actor of output node to exit actor of kernel graph.
1295     auto kernel_with_index = parser->FetchBackendNodeByFrontNode(from_node_with_index);
1296     if (kernel_with_index.first == nullptr) {
1297       kernel_with_index = parser->FetchBackendOutputByKernelGraph(graph, from_node_with_index);
1298       if (kernel_with_index.first == nullptr) {
1299         parser->PrintParseInfo();
1300         MS_LOG_WITH_NODE(EXCEPTION, from_node)
1301           << "Failed to get kernel with index by front node:" << from_node->fullname_with_scope()
1302           << " debug string:" << from_node->DebugString() << " index:" << from_node_with_index.second
1303           << " by graph:" << graph->ToString() << " to actor:" << to_actor->GetAID();
1304       }
1305     }
1306     auto type = FetchKernelTransformType(kernel_with_index.first, graph, {});
1307     auto from_actor = FetchActor(type, graph_compiler_info.name_, kernel_with_index.first, graph);
1308     if (from_actor == nullptr) {
1309       parser->PrintParseInfo();
1310       MS_LOG_WITH_NODE(EXCEPTION, from_node)
1311         << "Failed to get from actor by backend node:" << kernel_with_index.first->DebugString()
1312         << " front node : " << from_node->fullname_with_scope() << " debug string:" << from_node->DebugString()
1313         << " index:" << from_node_with_index.second << " by graph:" << graph->ToString()
1314         << " to actor:" << to_actor->GetAID() << " type:" << type;
1315     }
1316     SchedulerHelper::AddDataArrow(from_actor, to_actor, kernel_with_index.second, to_node_with_index.second,
1317                                   kernel_with_index.first);
1318   } else {
1319     // Link arrow from exit actor of kernel graph to exit actor of function graph.
1320     const auto &actor_name = parser->FetchGroupNameByKernelGraph(graph) + kExitActorNameSuffix;
1321     MS_LOG(DEBUG) << "Actor name:" << actor_name << " from node:" << from_node->DebugString();
1322     auto actor = FetchActor(actor_name);
1323     MS_EXCEPTION_IF_NULL(actor);
1324     auto exit_actor = dynamic_cast<ExitActor *>(actor);
1325     MS_EXCEPTION_IF_NULL(exit_actor);
1326     size_t from_index = exit_actor->FetchNodePosition(from_node_with_index);
1327     SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, to_node_with_index.second);
1328   }
1329 }
1330 
LinkControlArrowForControlActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1331 void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set,
1332                                                            const GraphCompilerInfo &graph_compiler_info) const {
1333   MS_EXCEPTION_IF_NULL(actor_set);
1334   auto control_actor_set = actor_set->control_actors_.get();
1335   MS_EXCEPTION_IF_NULL(control_actor_set);
1336   const auto &parser = graph_compiler_info.control_node_parser_;
1337   MS_EXCEPTION_IF_NULL(parser);
1338   LinkControlArrowForEntranceActor(actor_set, graph_compiler_info);
1339 
1340   // When the switch actor and gather actor have no input, need to link a control arrow from entrance actor.
1341   std::vector<ControlActor *> need_check_control_actors;
1342   (void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
1343                        std::back_inserter(need_check_control_actors),
1344                        [](const auto &switch_actor) { return switch_actor.get(); });
1345   (void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
1346                        std::back_inserter(need_check_control_actors),
1347                        [](const auto &gather_actor) { return gather_actor.get(); });
1348 
1349   for (auto control_actor : need_check_control_actors) {
1350     MS_EXCEPTION_IF_NULL(control_actor);
1351     if (IsNoInputActor(control_actor)) {
1352       MS_EXCEPTION_IF_NULL(control_actor->node_);
1353       if (parser->IsNeedStackControlNode(control_actor->node_)) {
1354         const auto &stack_actor_name = GetActorName(control_actor->node_) + kStackActorNameSuffix;
1355         auto actor = FetchActor(stack_actor_name);
1356         MS_EXCEPTION_IF_NULL(actor);
1357         auto to_actor = dynamic_cast<ControlActor *>(actor);
1358         MS_EXCEPTION_IF_NULL(to_actor);
1359         SchedulerHelper::AddControlArrow(to_actor, control_actor);
1360         continue;
1361       }
1362       const FuncGraphPtr &func_graph = control_actor->node_->func_graph();
1363       MS_EXCEPTION_IF_NULL(func_graph);
1364       const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1365       const auto &entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
1366       MS_EXCEPTION_IF_NULL(entrance_actor);
1367       SchedulerHelper::AddControlArrow(entrance_actor, control_actor);
1368     }
1369   }
1370 
1371   // Link auto monad control arrow for control actor.
1372   std::vector<ControlActor *> control_actors;
1373   (void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
1374                        std::back_inserter(control_actors), [](auto &switch_actor) { return switch_actor.get(); });
1375   (void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
1376                        std::back_inserter(control_actors), [](auto &gather_actor) { return gather_actor.get(); });
1377   (void)std::transform(control_actor_set->exit_actors_.begin(), control_actor_set->exit_actors_.end(),
1378                        std::back_inserter(control_actors), [](auto &exit_actor) { return exit_actor.get(); });
1379   for (auto control_actor : control_actors) {
1380     MS_EXCEPTION_IF_NULL(control_actor);
1381     const auto &node = control_actor->node_;
1382     if (node == nullptr) {
1383       continue;
1384     }
1385 
1386     auto to_actor = control_actor;
1387     if (parser->IsNeedStackControlNode(node)) {
1388       const auto &stack_actor_name = GetActorName(node) + kStackActorNameSuffix;
1389       auto actor = FetchActor(stack_actor_name);
1390       MS_EXCEPTION_IF_NULL(actor);
1391       to_actor = dynamic_cast<ControlActor *>(actor);
1392       MS_EXCEPTION_IF_NULL(to_actor);
1393     }
1394 
1395     const auto &cnode = node->cast<CNodePtr>();
1396     MS_EXCEPTION_IF_NULL(cnode);
1397     const auto &inputs = cnode->inputs();
1398     for (const auto &input : inputs) {
1399       MS_EXCEPTION_IF_NULL(input);
1400       std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
1401       for (const auto &monad_node : monad_nodes) {
1402         MS_EXCEPTION_IF_NULL(monad_node);
1403         LinkControlArrowByAutoMonad(to_actor, monad_node, parser);
1404       }
1405     }
1406   }
1407 
1408   // Link copy actor to exit actor.
1409   for (const auto &copy_actor : actor_set->copy_actors_) {
1410     MS_EXCEPTION_IF_NULL(copy_actor);
1411     if ((!copy_actor->output_data_arrows_.empty()) || (!copy_actor->output_control_arrows_.empty())) {
1412       continue;
1413     }
1414     KernelGraphPtr kernel_graph = nullptr;
1415     if (copy_actor->from_graph_ != nullptr) {
1416       kernel_graph = copy_actor->from_graph_;
1417     } else if (copy_actor->from_kernel_ != nullptr) {
1418       kernel_graph = std::dynamic_pointer_cast<KernelGraph>(copy_actor->from_kernel_->func_graph());
1419     }
1420     if (kernel_graph == nullptr) {
1421       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid copy actor:" << copy_actor->GetAID().Name();
1422     }
1423     auto exit_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1424     auto exit_actor = FetchActor(exit_actor_name);
1425     MS_EXCEPTION_IF_NULL(exit_actor);
1426     SchedulerHelper::AddControlArrow(copy_actor.get(), exit_actor);
1427   }
1428 
1429   LinkControlArrowByKernelGraphGroup(graph_compiler_info);
1430 }
1431 
LinkControlArrowForEntranceActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1432 void ControlNodeScheduler::LinkControlArrowForEntranceActor(ActorSet *const actor_set,
1433                                                             const GraphCompilerInfo &graph_compiler_info) const {
1434   MS_EXCEPTION_IF_NULL(actor_set);
1435   auto control_actor_set = actor_set->control_actors_.get();
1436   MS_EXCEPTION_IF_NULL(control_actor_set);
1437   const auto &parser = graph_compiler_info.control_node_parser_;
1438   MS_EXCEPTION_IF_NULL(parser);
1439 
1440   // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
1441   // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
1442   // graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
1443   // arrow have been parsed in the control node parser.
1444   for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
1445     // Fetch the entrance actor.
1446     const auto &func_graph = graph_to_nodes.first;
1447     MS_EXCEPTION_IF_NULL(func_graph);
1448     auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1449     auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
1450     MS_EXCEPTION_IF_NULL(entrance_actor);
1451 
1452     const auto &nodes = graph_to_nodes.second;
1453     for (const auto &node : nodes) {
1454       // Fetch the source actor of control arrow.
1455       MS_EXCEPTION_IF_NULL(node);
1456       if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
1457         actor_name = func_graph->ToString() + kExitActorNameSuffix;
1458       } else {
1459         actor_name = GetActorName(node);
1460       }
1461       auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
1462       MS_EXCEPTION_IF_NULL(from_actor);
1463       SchedulerHelper::AddLoopBodyControlArrow(from_actor, entrance_actor);
1464     }
1465   }
1466 
1467   // In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
1468   // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
1469   // to the entrance actor.
1470   for (const auto &func_graph_to_group_info : parser->func_graph_to_first_kernel_graphs_) {
1471     const auto &func_graph = func_graph_to_group_info.first;
1472     MS_EXCEPTION_IF_NULL(func_graph);
1473     auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1474     auto actor = FetchActor(actor_name);
1475     MS_EXCEPTION_IF_NULL(actor);
1476     auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1477     MS_EXCEPTION_IF_NULL(entrance_actor);
1478     for (const auto &group_info : func_graph_to_group_info.second) {
1479       MS_EXCEPTION_IF_NULL(group_info);
1480       actor_name = group_info->group_name_ + kExitActorNameSuffix;
1481       auto from_actor = FetchActor(actor_name);
1482       MS_EXCEPTION_IF_NULL(from_actor);
1483       SchedulerHelper::AddLoopBodyControlArrow(from_actor, entrance_actor);
1484     }
1485   }
1486 }
1487 
LinkControlArrowForLoopCountActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const1488 void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
1489                                                              const GraphCompilerInfo &graph_compiler_info) const {
1490   MS_EXCEPTION_IF_NULL(actor_set);
1491   auto loop_count_actor = actor_set->loop_count_actor_;
1492   MS_EXCEPTION_IF_NULL(loop_count_actor);
1493 
1494   // The final output is always sent by the exit of the root graph in control flow.
1495   const auto &parser = graph_compiler_info.control_node_parser_;
1496   MS_EXCEPTION_IF_NULL(parser);
1497   const auto &root_graph = parser->root_func_graph_;
1498   MS_EXCEPTION_IF_NULL(root_graph);
1499   auto exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
1500   auto root_exit_actor = dynamic_cast<ExitActor *>(FetchActor(exit_actor_name));
1501   MS_EXCEPTION_IF_NULL(root_exit_actor);
1502   // link control arrow from root exit actor to loop count actor.
1503   SchedulerHelper::AddControlArrowForExitActor(root_exit_actor, loop_count_actor.get(), kMainBranchID);
1504 
1505   // The entrance actor will generate some data in the loop body execution, so need clear on the end of step.
1506   MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
1507   for (auto &entrance_actor : actor_set->control_actors_->entrance_actors_) {
1508     MS_EXCEPTION_IF_NULL(entrance_actor);
1509     (void)loop_count_actor->entrance_aids_.emplace_back(entrance_actor->GetAID());
1510   }
1511 }
1512 
LinkOutputControlArrowForActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1513 void ControlNodeScheduler::LinkOutputControlArrowForActor(ActorSet *const actor_set,
1514                                                           const GraphCompilerInfo &graph_compiler_info) const {
1515   MS_EXCEPTION_IF_NULL(actor_set);
1516   const auto &parser = graph_compiler_info.control_node_parser_;
1517   MS_EXCEPTION_IF_NULL(parser);
1518   // Link control arrows from no output kernel actor to the corresponding exit actor.
1519   for (auto &kernel_actor : actor_set->kernel_actors_) {
1520     MS_EXCEPTION_IF_NULL(kernel_actor);
1521     if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1522         (!IsInlineKernelActor(kernel_actor))) {
1523       auto kernel_graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
1524       MS_EXCEPTION_IF_NULL(kernel_graph);
1525       auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1526       auto to_actor = FetchActor(to_actor_name);
1527       MS_EXCEPTION_IF_NULL(to_actor);
1528       SchedulerHelper::AddControlArrow(kernel_actor.get(), to_actor);
1529     }
1530   }
1531 
1532   // Link control arrows from no super kernel actor to the corresponding exit actor.
1533   for (auto &super_actor : actor_set->super_kernel_actors_) {
1534     MS_EXCEPTION_IF_NULL(super_actor);
1535     if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
1536       auto kernel_graph = super_actor->graph();
1537       MS_EXCEPTION_IF_NULL(kernel_graph);
1538       auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1539       auto to_actor = FetchActor(to_actor_name);
1540       MS_EXCEPTION_IF_NULL(to_actor);
1541       SchedulerHelper::AddControlArrow(super_actor.get(), to_actor);
1542     }
1543   }
1544 
1545   // Link control arrows from no super kernel actor to the corresponding exit actor.
1546   for (auto &any_type_kernel_actor : actor_set->any_type_kernel_actors_) {
1547     MS_EXCEPTION_IF_NULL(any_type_kernel_actor);
1548     if ((any_type_kernel_actor->output_data_arrows_.size() == 0) &&
1549         (any_type_kernel_actor->output_control_arrows_.size() == 0)) {
1550       auto kernel_graph = any_type_kernel_actor->graph();
1551       MS_EXCEPTION_IF_NULL(kernel_graph);
1552       auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1553       auto to_actor = FetchActor(to_actor_name);
1554       MS_EXCEPTION_IF_NULL(to_actor);
1555       SchedulerHelper::AddControlArrow(any_type_kernel_actor.get(), to_actor);
1556     }
1557   }
1558 }
1559 
LinkControlArrowForKernelActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1560 void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,
1561                                                           const GraphCompilerInfo &graph_compiler_info) const {
1562   MS_EXCEPTION_IF_NULL(actor_set);
1563   const auto &parser = graph_compiler_info.control_node_parser_;
1564   MS_EXCEPTION_IF_NULL(parser);
1565 
1566   // Link control arrow from entrance actors or stack actors to no input kernel actors.
1567   for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
1568     // In control flow, when the input of the kernel actor is a parameter, this input needs to be linked to the
1569     // control actor, so the no-input kernel actor collected in the graph scheduler will also collect this actor,
1570     // and it needs to be skipped here.
1571     MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
1572     // Control arrow for custom actor will be linked in next step.
1573     if ((no_input_kernel_actor->input_datas_num_ != 0) || (no_input_kernel_actor->input_controls_num_ != 0) ||
1574         no_input_kernel_actor->type() == KernelTransformType::kCustomActor ||
1575         IsInlineKernelActor(no_input_kernel_actor)) {
1576       continue;
1577     }
1578 
1579     KernelGraphPtr kernel_graph = nullptr;
1580     if (no_input_kernel_actor->type_ == KernelTransformType::kSuperKernelActor) {
1581       const auto &super_kernel_actor = dynamic_cast<SuperKernelActor *>(no_input_kernel_actor.get());
1582       MS_EXCEPTION_IF_NULL(super_kernel_actor);
1583       kernel_graph = super_kernel_actor->graph();
1584     } else if (no_input_kernel_actor->type_ == KernelTransformType::kKernelActor) {
1585       const auto &kernel_actor = dynamic_cast<KernelActor *>(no_input_kernel_actor.get());
1586       MS_EXCEPTION_IF_NULL(kernel_actor);
1587       kernel_graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
1588     } else if (no_input_kernel_actor->type_ == KernelTransformType::kCustomActor) {
1589       const auto &custom_actor = dynamic_cast<CustomActor *>(no_input_kernel_actor.get());
1590       MS_EXCEPTION_IF_NULL(custom_actor);
1591       auto custom_kernel = custom_actor->kernel().lock();
1592       MS_EXCEPTION_IF_NULL(custom_kernel);
1593       const auto &base_node = AnfUtils::GetCustomActorBaseNode(custom_kernel);
1594       MS_EXCEPTION_IF_NULL(base_node);
1595       kernel_graph = AnfAlgo::FetchKernelGraph(base_node.get());
1596     } else {
1597       MS_LOG(EXCEPTION) << "Invalid no input actor: " << no_input_kernel_actor->GetAID().Name();
1598     }
1599     MS_EXCEPTION_IF_NULL(kernel_graph);
1600     auto actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kStackActorNameSuffix;
1601     if (!parser->IsCallInputKernelGraph(kernel_graph.get())) {
1602       const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
1603       MS_EXCEPTION_IF_NULL(func_graph);
1604       actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1605     }
1606 
1607     auto from_actor = FetchActor(actor_name);
1608     MS_EXCEPTION_IF_NULL(from_actor);
1609     SchedulerHelper::AddControlArrow(from_actor, no_input_kernel_actor.get());
1610   }
1611   LinkOutputControlArrowForActor(actor_set, graph_compiler_info);
1612 }
1613 
LinkControlArrowByAutoMonad(ControlActor * to_actor,const AnfNodePtr & from_node,const ControlNodeParserPtr & parser) const1614 void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
1615                                                        const ControlNodeParserPtr &parser) const {
1616   MS_EXCEPTION_IF_NULL(to_actor);
1617   MS_EXCEPTION_IF_NULL(from_node);
1618   MS_EXCEPTION_IF_NULL(parser);
1619   MS_LOG(DEBUG) << "Link auto monad control arrow from node:" << from_node->DebugString()
1620                 << " to actor:" << to_actor->GetAID();
1621 
1622   std::set<AnfNodePtr> depend_nodes;
1623   FetchRealDependNodeByAutoMonad(from_node, &depend_nodes);
1624 
1625   for (const auto &depend_node : depend_nodes) {
1626     MS_EXCEPTION_IF_NULL(depend_node);
1627     MS_LOG(DEBUG) << "Add depend node:" << depend_node->DebugString() << " for actor:" << to_actor->GetAID();
1628     auto from_actor = FetchActor(GetActorName(depend_node));
1629     auto graph = parser->FetchKernelGraphByFrontNode(depend_node);
1630 
1631     std::vector<AbstractActor *> from_actors;
1632     if (common::AnfAlgo::IsCallNode(depend_node) && !IsNotCut(depend_node)) {
1633       // If the actor already exists with control arrow, skip it.
1634       if (IsControlArrowExistForCallNode(depend_node, to_actor, parser)) {
1635         MS_LOG(DEBUG) << "Control arrow from call node:" << depend_node << " to actor:" << to_actor->GetAID()
1636                       << "is exist, skip it";
1637         continue;
1638       }
1639       int branch_id = parser->FetchBranchIDByCallNode(depend_node);
1640       const auto &func_graphs = parser->FetchFuncGraphbyCallNode(depend_node);
1641       if (func_graphs.empty()) {
1642         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, depend_node)
1643           << "#dmsg#Runtime error info:#dmsg#Failed to get funcgraph by call node:" << depend_node->DebugString();
1644       }
1645       for (const auto &func_graph : func_graphs) {
1646         MS_EXCEPTION_IF_NULL(func_graph);
1647         auto exit_actor_name = func_graph->ToString() + kExitActorNameSuffix;
1648         from_actor = FetchActor(exit_actor_name);
1649         MS_EXCEPTION_IF_NULL(from_actor);
1650         (void)from_actors.emplace_back(from_actor);
1651         auto exit_actor = dynamic_cast<ExitActor *>(from_actor);
1652         MS_EXCEPTION_IF_NULL(exit_actor);
1653         SchedulerHelper::AddControlArrowForExitActor(exit_actor, to_actor, branch_id);
1654       }
1655       to_actor->input_controls_num_ -= (func_graphs.size() - 1);
1656     } else if (from_actor != nullptr) {
1657       (void)from_actors.emplace_back(from_actor);
1658       SchedulerHelper::AddControlArrow(from_actor, to_actor);
1659     } else {
1660       if (graph == nullptr) {
1661         MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, depend_node)
1662           << "#dmsg#Runtime error info:#dmsg#Failed to find actor for node:" << depend_node->DebugString();
1663       }
1664       from_actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kExitActorNameSuffix);
1665       MS_EXCEPTION_IF_NULL(from_actor);
1666       if (std::find_if(from_actor->output_control_arrows_.begin(), from_actor->output_control_arrows_.end(),
1667                        [&to_actor](auto &output_control_arrow) {
1668                          MS_EXCEPTION_IF_NULL(output_control_arrow);
1669                          return output_control_arrow->to_op_id_.Name() == to_actor->GetAID().Name();
1670                        }) != from_actor->output_control_arrows_.end()) {
1671         MS_LOG(DEBUG) << "Link auto monad control from actor:" << from_actor->GetAID()
1672                       << " to actor:" << to_actor->GetAID() << " is already exist.";
1673         continue;
1674       }
1675       (void)from_actors.emplace_back(from_actor);
1676       SchedulerHelper::AddControlArrow(from_actor, to_actor);
1677     }
1678     if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsNeedStackControlNode(depend_node) ||
1679         parser->IsRecursionCallNode(depend_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
1680       continue;
1681     }
1682     // If the control arrow comes from a recursive call node or a recursive kernel graph, these control edges will be
1683     // directly linked to the stack actor, otherwise, they need to be cached in the stack of the stack actor.
1684     auto stack_actor = dynamic_cast<StackActor *>(to_actor);
1685     MS_EXCEPTION_IF_NULL(stack_actor);
1686     stack_actor->input_controls_num_--;
1687     stack_actor->input_stack_controls_num_++;
1688     for (const auto &actor : from_actors) {
1689       MS_EXCEPTION_IF_NULL(actor);
1690       MS_LOG(DEBUG) << "Add stack control aid:" << actor->GetAID() << " for actor:" << stack_actor->GetAID();
1691       (void)stack_actor->stack_control_aids_.emplace(actor->GetAID());
1692       stack_actor->control_aid_to_indexs_[actor->GetAID()] = stack_actor->input_stack_controls_num_;
1693     }
1694   }
1695   MS_LOG(DEBUG) << "Link auto monad control arrow from node:" << from_node->DebugString()
1696                 << " to actor:" << to_actor->GetAID() << " end";
1697 }
1698 
LinkControlArrowByKernelGraphGroup(const GraphCompilerInfo & graph_compiler_info) const1699 void ControlNodeScheduler::LinkControlArrowByKernelGraphGroup(const GraphCompilerInfo &graph_compiler_info) const {
1700   const auto &parser = graph_compiler_info.control_node_parser_;
1701   MS_EXCEPTION_IF_NULL(parser);
1702 
1703   for (const auto &graph_group : parser->kernel_graph_group_infos_) {
1704     MS_EXCEPTION_IF_NULL(graph_group);
1705     if (!graph_group->need_stack_) {
1706       continue;
1707     }
1708     auto stack_actor = FetchActor(graph_group->group_name_ + kStackActorNameSuffix);
1709     MS_EXCEPTION_IF_NULL(stack_actor);
1710     auto to_actor = dynamic_cast<ControlActor *>(stack_actor);
1711     MS_EXCEPTION_IF_NULL(to_actor);
1712     for (const auto &monad_input : graph_group->monad_inputs_) {
1713       MS_EXCEPTION_IF_NULL(monad_input);
1714       MS_LOG(DEBUG) << "Add monad control arrow for group:" << graph_group->group_name_
1715                     << " to actor:" << to_actor->GetAID() << " by monad input:" << monad_input->DebugString();
1716       LinkControlArrowByAutoMonad(to_actor, monad_input, parser);
1717     }
1718   }
1719 }
1720 
LinkBranchIDArrowForControlActor(ControlActorSet * const control_actor_set) const1721 void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set) const {
1722   MS_EXCEPTION_IF_NULL(control_actor_set);
1723 
1724   // Connect the branch id arrows from the entrance actor to the exit actor for each funcgraph.
1725   for (auto exit_actor : control_actor_set->exit_actors_) {
1726     MS_EXCEPTION_IF_NULL(exit_actor);
1727 
1728     // If the node in the exit actor is empty, it means that it is between the kernel actor and the control actor,
1729     // and no need to send the branch id.
1730     const auto &node = exit_actor->node_;
1731     if (node == nullptr) {
1732       continue;
1733     }
1734 
1735     const auto &func_graph = node->func_graph();
1736     MS_EXCEPTION_IF_NULL(func_graph);
1737     const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1738     auto actor = FetchActor(actor_name);
1739     MS_EXCEPTION_IF_NULL(actor);
1740     auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1741     MS_EXCEPTION_IF_NULL(entrance_actor);
1742     SchedulerHelper::AddBranchIDArrow(entrance_actor, exit_actor.get());
1743   }
1744 
1745   // Connect the branch id arrows from the entrance actor to the stack actor.
1746   for (auto stack_actor : control_actor_set->stack_actors_) {
1747     MS_EXCEPTION_IF_NULL(stack_actor);
1748     auto node = stack_actor->node_;
1749     if (!stack_actor->formal_parameters_.empty()) {
1750       node = stack_actor->formal_parameters_.back().first;
1751     } else {
1752       MS_LOG(INFO) << "No formal parameter for stack actor:" << stack_actor->GetAID();
1753     }
1754     MS_EXCEPTION_IF_NULL(node);
1755     const auto &func_graph = node->func_graph();
1756     MS_EXCEPTION_IF_NULL(func_graph);
1757     const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1758     auto actor = FetchActor(actor_name);
1759     MS_EXCEPTION_IF_NULL(actor);
1760     auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1761     MS_EXCEPTION_IF_NULL(entrance_actor);
1762     SchedulerHelper::AddBranchIDArrow(entrance_actor, stack_actor.get());
1763   }
1764 }
1765 
LinkDataArrowForKernelActor(const GraphCompilerInfo & graph_compiler_info) const1766 void ControlNodeScheduler::LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info) const {
1767   const auto &parser = graph_compiler_info.control_node_parser_;
1768   MS_EXCEPTION_IF_NULL(parser);
1769 
1770   // Link data arrows from entrance actors and stack actors to kernel actors.
1771   for (const auto &func_graph_to_kernel_graphs : parser->func_graph_to_kernel_graph_groups_) {
1772     // Fetch the source entrance actor.
1773     const auto &func_graph = func_graph_to_kernel_graphs.first;
1774     MS_EXCEPTION_IF_NULL(func_graph);
1775     auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1776     auto actor = FetchActor(actor_name);
1777     MS_EXCEPTION_IF_NULL(actor);
1778     auto entrance_actor = dynamic_cast<ControlActor *>(actor);
1779     MS_EXCEPTION_IF_NULL(entrance_actor);
1780 
1781     for (const auto &kernel_graph_group : func_graph_to_kernel_graphs.second) {
1782       for (const auto &kernel_graph : kernel_graph_group) {
1783         MS_EXCEPTION_IF_NULL(kernel_graph);
1784         if (kernel_graph->execution_order().empty()) {
1785           continue;
1786         }
1787         LinkDataArrowByKernelGraph(kernel_graph, entrance_actor, parser);
1788       }
1789     }
1790   }
1791 }
1792 
LinkDataArrowForCustomActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const1793 void ControlNodeScheduler::LinkDataArrowForCustomActor(const ActorSet *actor_set,
1794                                                        const GraphCompilerInfo &graph_compiler_info) const {
1795   MS_EXCEPTION_IF_NULL(actor_set);
1796   const auto &parser = graph_compiler_info.control_node_parser_;
1797   MS_EXCEPTION_IF_NULL(parser);
1798 
1799   for (const auto &custom_actor : actor_set->custom_actors_) {
1800     MS_EXCEPTION_IF_NULL(custom_actor);
1801     auto kernel = custom_actor->kernel().lock();
1802     MS_EXCEPTION_IF_NULL(kernel);
1803     if (AnfUtils::GetCustomActorType(kernel) != kInfer) {
1804       continue;
1805     }
1806     // Kernel in depends form map should link data arrow for infer shape.
1807     auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
1808     MS_EXCEPTION_IF_NULL(base_node);
1809     auto dynamic_shape_depends = abstract::GetValueDependArgIndices(base_node);
1810     for (auto iter = dynamic_shape_depends.begin(); iter != dynamic_shape_depends.end(); ++iter) {
1811       auto input_node = common::AnfAlgo::GetInputNode(base_node, LongToSize(*iter));
1812       MS_EXCEPTION_IF_NULL(input_node);
1813       KernelWithIndex from_kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1814       const AnfNodePtr real_input_node = from_kernel_with_index.first;
1815       MS_EXCEPTION_IF_NULL(real_input_node);
1816       if (real_input_node->isa<ValueNode>()) {
1817         continue;
1818       }
1819       auto graph = AnfAlgo::FetchKernelGraph(real_input_node.get());
1820       MS_EXCEPTION_IF_NULL(graph);
1821       if (!parser->IsControlFlowDataArrow(graph, real_input_node)) {
1822         continue;
1823       }
1824 
1825       // Link data arrow from entrance actor or stack actor to infer shape custom actor.
1826       const auto &front_node_with_index = GetFrontNodeByKernelGraph(real_input_node, graph.get());
1827       MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1828       AbstractActor *from_base_actor = nullptr;
1829       if (parser->IsCallInputKernelGraph(graph.get())) {
1830         from_base_actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kStackActorNameSuffix);
1831         MS_EXCEPTION_IF_NULL(from_base_actor);
1832       } else if (!front_node_with_index.first->isa<Parameter>()) {
1833         MS_LOG(INFO) << "Internal front node:" << front_node_with_index.first->DebugString()
1834                      << " index:" << front_node_with_index.second << " for custom actor:" << custom_actor->GetAID()
1835                      << " kernel:" << kernel->fullname_with_scope() << " input index:" << *iter;
1836         const auto &from_graph = parser->FetchKernelGraphByFrontNode(front_node_with_index.first);
1837         MS_EXCEPTION_IF_NULL(from_graph);
1838         from_base_actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1839         MS_EXCEPTION_IF_NULL(from_base_actor);
1840       } else {
1841         const auto &func_graph = front_node_with_index.first->func_graph();
1842         MS_EXCEPTION_IF_NULL(func_graph);
1843         from_base_actor = FetchActor(func_graph->ToString() + kEntranceActorNameSuffix);
1844         MS_EXCEPTION_IF_NULL(from_base_actor);
1845       }
1846       const auto &from_actor = dynamic_cast<ControlActor *>(from_base_actor);
1847       MS_EXCEPTION_IF_NULL(from_actor);
1848       size_t from_index = from_actor->FetchNodePosition(front_node_with_index);
1849       MS_LOG(DEBUG) << "Link data arrow from actor:" << from_actor->GetAID()
1850                     << " to custom actor:" << custom_actor->GetAID();
1851       SchedulerHelper::AddDataArrow(from_actor, custom_actor.get(), from_index, LongToSize(*iter));
1852     }
1853   }
1854 }
1855 
LinkDataArrowByKernelGraphInSinkMode(const KernelGraphPtr & graph,ControlActor * const from_actor,const ControlNodeParserPtr & parser) const1856 void ControlNodeScheduler::LinkDataArrowByKernelGraphInSinkMode(const KernelGraphPtr &graph,
1857                                                                 ControlActor *const from_actor,
1858                                                                 const ControlNodeParserPtr &parser) const {
1859   MS_EXCEPTION_IF_NULL(graph);
1860   MS_EXCEPTION_IF_NULL(from_actor);
1861   MS_EXCEPTION_IF_NULL(parser);
1862   MS_LOG(DEBUG) << "Link data arrow in sink mode by kernel graph:" << graph->ToString();
1863   auto to_actor = FetchActor(
1864     graph->is_any_type_input() ? KernelTransformType::kAnyTypeKernelActor : KernelTransformType::kSuperKernelActor, "",
1865     nullptr, graph);
1866   MS_EXCEPTION_IF_NULL(to_actor);
1867   auto super_kernel_actor = dynamic_cast<SuperKernelActor *>(to_actor);
1868   MS_EXCEPTION_IF_NULL(super_kernel_actor);
1869 
1870   auto &input_nodes = graph->input_nodes();
1871   for (size_t i = 0; i < input_nodes.size(); ++i) {
1872     const auto &input_node = input_nodes[i];
1873     MS_EXCEPTION_IF_NULL(input_node);
1874     if (HasAbstractMonad(input_node) || (!parser->IsControlFlowDataArrow(graph, input_node))) {
1875       continue;
1876     }
1877     size_t to_index = super_kernel_actor->FetchInputNodePosition(input_node);
1878     const auto &front_node_with_index = GetFrontNodeByKernelGraph(input_node, graph.get());
1879     MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1880     if (front_node_with_index.first->isa<ValueNode>()) {
1881       continue;
1882     }
1883     if (front_node_with_index.first->isa<CNode>() && (from_actor->type() != KernelTransformType::kStackActor)) {
1884       // If the input is an internal parameter, the input arrow should be linked to the exit actor of the kernel
1885       // graph which the internal parameter belong.
1886       MS_LOG(INFO) << "Internal parameter in control flow, backend input:" << input_node->DebugString()
1887                    << " front node:" << front_node_with_index.first->DebugString();
1888       const auto &from_graph = parser->FetchKernelGraphByFrontNode(front_node_with_index.first);
1889       MS_EXCEPTION_IF_NULL(from_graph);
1890       auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1891       MS_EXCEPTION_IF_NULL(actor);
1892       auto exit_actor = dynamic_cast<ControlActor *>(actor);
1893       MS_EXCEPTION_IF_NULL(exit_actor);
1894       size_t from_index = exit_actor->FetchNodePosition(front_node_with_index);
1895       SchedulerHelper::AddFormalParameterDeviceTensor(exit_actor, from_index, input_node, graph);
1896       SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, i);
1897       continue;
1898     }
1899     size_t from_index = from_actor->FetchNodePosition(front_node_with_index);
1900     SchedulerHelper::AddFormalParameterDeviceTensor(from_actor, from_index, input_node, graph);
1901     SchedulerHelper::AddDataArrow(from_actor, to_actor, from_index, to_index);
1902   }
1903   return;
1904 }
1905 
LinkDataArrowByKernelGraph(const KernelGraphPtr & graph,ControlActor * const entrance_actor,const ControlNodeParserPtr & parser) const1906 void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, ControlActor *const entrance_actor,
1907                                                       const ControlNodeParserPtr &parser) const {
1908   MS_EXCEPTION_IF_NULL(graph);
1909   MS_EXCEPTION_IF_NULL(parser);
1910   MS_LOG(DEBUG) << "Link data arrow by kernel graph:" << graph->ToString();
1911   auto from_actor = entrance_actor;
1912   // If there is a call node in the input of the graph, the parameter of the graph needs to be sent by the
1913   // corresponding stack actor, otherwise it is sent by the entrance actor.
1914   if (parser->IsCallInputKernelGraph(graph.get())) {
1915     auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kStackActorNameSuffix);
1916     MS_EXCEPTION_IF_NULL(actor);
1917     from_actor = dynamic_cast<ControlActor *>(actor);
1918   }
1919 
1920   if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
1921     // Link data arrow in graph mode.
1922     LinkDataArrowByKernelGraphInSinkMode(graph, from_actor, parser);
1923     return;
1924   }
1925 
1926   auto &execution_order = graph->execution_order();
1927   for (const auto &kernel : execution_order) {
1928     MS_EXCEPTION_IF_NULL(kernel);
1929     if ((!graph->is_graph_run_mode()) && (IsSkippedKernelActor(kernel) || !IsKernelActor(kernel))) {
1930       continue;
1931     }
1932     for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
1933       auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
1934       MS_EXCEPTION_IF_NULL(input_node);
1935       auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1936       auto input = input_with_index.first;
1937       MS_EXCEPTION_IF_NULL(input);
1938       if (HasAbstractMonad(input) || (!parser->IsControlFlowDataArrow(graph, input))) {
1939         continue;
1940       }
1941 
1942       auto from_node_with_index = GetFrontNodeByKernelGraph(input, graph.get());
1943       MS_EXCEPTION_IF_NULL(from_node_with_index.first);
1944       if (common::AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem) &&
1945           (!from_node_with_index.first->cast<CNodePtr>()->HasAttr(kAttrReplaceRealKernelInBackend)) &&
1946           (!common::AnfAlgo::IsTupleOutput(from_node_with_index.first))) {
1947         MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString()
1948                         << " for graph:" << graph->ToString() << " is a tuple get item";
1949         from_node_with_index = FetchRealNodeByGetItem(from_node_with_index);
1950       }
1951 
1952       // If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond
1953       // to the front parameter, but the node in the internal parameter.
1954       const auto &from_node = from_node_with_index.first;
1955       MS_EXCEPTION_IF_NULL(from_node);
1956       MS_LOG(DEBUG) << "Graph:" << graph->ToString() << " from node:" << from_node_with_index.first->DebugString()
1957                     << " index:" << from_node_with_index.second;
1958 
1959       // Fetch actor and link.
1960       auto type = FetchKernelTransformType(kernel, graph, {});
1961       auto to_actor = FetchActor(type, "", kernel, graph);
1962       MS_EXCEPTION_IF_NULL(to_actor);
1963       size_t from_index = 0;
1964       // If the input is a switch node and the graph does not need a stack, then the data arrow needs to be connected
1965       // from the switch actor.
1966       if ((common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
1967            common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) &&
1968           (from_actor->type() != KernelTransformType::kStackActor)) {
1969         const auto &actor_name = GetActorName(from_node);
1970         auto actor = FetchActor(actor_name);
1971         MS_EXCEPTION_IF_NULL(actor);
1972         from_actor = dynamic_cast<ControlActor *>(actor);
1973       } else if (from_node->isa<CNode>() && (from_actor->type() != KernelTransformType::kStackActor)) {
1974         // If the input is an internal parameter, the input arrow should be linked to the exit actor of the kernel
1975         // graph which the internal parameter belong.
1976         MS_LOG(INFO) << "Internal parameter in control flow, backend input:" << input->DebugString()
1977                      << " front node:" << from_node->DebugString();
1978         const auto &from_graph = parser->FetchKernelGraphByFrontNode(from_node);
1979         MS_EXCEPTION_IF_NULL(from_graph);
1980         auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1981         MS_EXCEPTION_IF_NULL(actor);
1982         auto exit_actor = dynamic_cast<ControlActor *>(actor);
1983         from_index = exit_actor->FetchNodePosition(from_node_with_index);
1984         SchedulerHelper::AddFormalParameterDeviceTensor(exit_actor, from_index, input, graph);
1985         SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, i);
1986         continue;
1987       } else {
1988         MS_LOG(DEBUG) << "Fetch node:" << from_node_with_index.first->DebugString()
1989                       << " index:" << from_node_with_index.second << " from actor:" << from_actor->GetAID();
1990         from_index = from_actor->FetchNodePosition(from_node_with_index);
1991       }
1992 
1993       MS_EXCEPTION_IF_NULL(from_actor);
1994       SchedulerHelper::AddFormalParameterDeviceTensor(from_actor, from_index, input, graph);
1995       SchedulerHelper::AddDataArrow(from_actor, to_actor, from_index, i);
1996     }
1997   }
1998 }
1999 
LinkDataArrowForOutputActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const2000 void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set,
2001                                                        const GraphCompilerInfo &graph_compiler_info) const {
2002   MS_EXCEPTION_IF_NULL(actor_set);
2003   auto &to_actor = actor_set->output_actor_;
2004   MS_EXCEPTION_IF_NULL(to_actor);
2005   const auto &parser = graph_compiler_info.control_node_parser_;
2006   MS_EXCEPTION_IF_NULL(parser);
2007   const auto &root_graph = parser->root_func_graph_;
2008   MS_EXCEPTION_IF_NULL(root_graph);
2009   const auto &return_node = root_graph->return_node();
2010   MS_EXCEPTION_IF_NULL(return_node);
2011 
2012   const auto &exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
2013   auto actor = FetchActor(exit_actor_name);
2014   MS_EXCEPTION_IF_NULL(actor);
2015   auto exit_actor = dynamic_cast<ExitActor *>(actor);
2016   MS_EXCEPTION_IF_NULL(exit_actor);
2017   for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
2018     SchedulerHelper::AddDataArrowForExitActor(exit_actor, to_actor.get(), i, i, 0);
2019     to_actor->input_datas_num_++;
2020   }
2021 
2022   auto control_node_to_device_contexts = parser->control_node_to_device_contexts_;
2023   auto iter = control_node_to_device_contexts.find(return_node);
2024   if (iter == control_node_to_device_contexts.end()) {
2025     MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, return_node)
2026       << "#dmsg#Runtime error info:#dmsg#Failed to find device contexts for node:" << return_node->DebugString();
2027   }
2028   if (iter->second.size() != to_actor->device_contexts().size()) {
2029     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid context size, need:"
2030                                << to_actor->device_contexts().size() << " current:" << iter->second.size();
2031   }
2032   to_actor->device_contexts_ = iter->second;
2033 }
2034 
LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo & graph_compiler_info) const2035 void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info) const {
2036   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
2037   const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_;
2038   MS_EXCEPTION_IF_NULL(root_graph);
2039   const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix;
2040   auto to_actor = dynamic_cast<EntranceActor *>(FetchActor(entrance_actor_name));
2041   MS_EXCEPTION_IF_NULL(to_actor);
2042 
2043   const auto &host_ds_actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
2044   auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(host_ds_actor_name));
2045   // No host data source actor scenario.
2046   if (host_ds_actor == nullptr) {
2047     const auto &data_prepare_actor_name = graph_compiler_info.name_ + kDataPrepareActorNameSuffix;
2048     auto data_prepare_actor = FetchActor(data_prepare_actor_name);
2049     MS_EXCEPTION_IF_NULL(data_prepare_actor);
2050     SchedulerHelper::AddControlArrow(data_prepare_actor, to_actor);
2051     return;
2052   }
2053 
2054   // The host data source actor sends all the input to the entrance actor of the root graph.
2055   for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) {
2056     const auto &formal_parameter = to_actor->formal_parameters_[i];
2057     MS_EXCEPTION_IF_NULL(formal_parameter.first);
2058     MS_LOG(DEBUG) << "Formal parameter:" << formal_parameter.first->DebugString()
2059                   << " index:" << formal_parameter.second;
2060     const auto &iter = host_ds_actor->data_node_position_map_.find(formal_parameter);
2061     if (iter != host_ds_actor->data_node_position_map_.end()) {
2062       const auto &parameter_with_index = host_ds_actor->data_nodes()[iter->second];
2063       SchedulerHelper::AddDataArrow(host_ds_actor, to_actor, parameter_with_index.second, i,
2064                                     parameter_with_index.first);
2065     } else {
2066       MS_LOG(INFO) << "Invalid formal parameter:" << formal_parameter.first->DebugString()
2067                    << " index:" << formal_parameter.second << " for actor:" << to_actor->GetAID();
2068     }
2069   }
2070 }
2071 
SetTimeSummaryForControlActor(const GraphCompilerInfo & graph_compiler_info) const2072 void ControlNodeScheduler::SetTimeSummaryForControlActor(const GraphCompilerInfo &graph_compiler_info) const {
2073   const auto &parser = graph_compiler_info.control_node_parser_;
2074   MS_EXCEPTION_IF_NULL(parser);
2075 
2076   for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
2077     MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2078     const auto &exit_actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix;
2079     const auto &exit_base_actor = FetchActor(exit_actor_name);
2080     if (exit_base_actor == nullptr) {
2081       continue;
2082     }
2083     const auto &exit_actor = dynamic_cast<ControlActor *>(exit_base_actor);
2084     MS_EXCEPTION_IF_NULL(exit_actor);
2085 
2086     // Set the exit actor of kernel graph to its entrance actor or stack actor.
2087     if (kernel_graph_group_info->need_stack_ == false) {
2088       if (kernel_graph_group_info->graphs_.empty()) {
2089         continue;
2090       }
2091       const auto &graph = *(kernel_graph_group_info->graphs_.begin());
2092       const auto &func_graph = parser->FetchFuncGraphByKernelGraph(graph.get());
2093       MS_EXCEPTION_IF_NULL(func_graph);
2094       auto entrance_base_actor = FetchActor(func_graph->ToString() + kEntranceActorNameSuffix);
2095       if (entrance_base_actor != nullptr) {
2096         const auto &entrance_actor = dynamic_cast<ControlActor *>(entrance_base_actor);
2097         MS_EXCEPTION_IF_NULL(entrance_actor);
2098         (void)entrance_actor->end_actors_.emplace(exit_actor);
2099         MS_LOG(DEBUG) << "Add time summart for exit actor:" << exit_actor->GetAID()
2100                       << " to actor:" << entrance_actor->GetAID();
2101       }
2102       continue;
2103     }
2104 
2105     auto stack_base_actor = FetchActor(kernel_graph_group_info->group_name_ + kStackActorNameSuffix);
2106     if (stack_base_actor != nullptr) {
2107       const auto &stack_actor = dynamic_cast<ControlActor *>(stack_base_actor);
2108       MS_EXCEPTION_IF_NULL(stack_actor);
2109       (void)stack_actor->end_actors_.emplace(exit_actor);
2110       MS_LOG(DEBUG) << "Add time summart for exit actor:" << exit_actor->GetAID()
2111                     << " to actor:" << stack_actor->GetAID();
2112     }
2113   }
2114 }
2115 
IsNoInputActor(const ControlActor * control_actor) const2116 bool ControlNodeScheduler::IsNoInputActor(const ControlActor *control_actor) const {
2117   MS_EXCEPTION_IF_NULL(control_actor);
2118   return (control_actor->input_datas_num_ == 0 && control_actor->input_controls_num_ == 0 &&
2119           control_actor->input_partials_num_ == 0 && control_actor->input_branch_ids_num_ == 0);
2120 }
2121 }  // namespace runtime
2122 }  // namespace mindspore
2123