1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/control_node_scheduler.h"
18 #include "mindspore/core/ops/sequence_ops.h"
19 #include "mindspore/core/ops/framework_ops.h"
20 #include "runtime/graph_scheduler/control_node_parser.h"
21 #include "runtime/graph_scheduler/inline_control_flow_scheduler.h"
22 #include "runtime/graph_scheduler/scheduler_helper.h"
23
24 namespace mindspore {
25 namespace runtime {
26 namespace {
GetActorName(const AnfNodePtr & node)27 std::string GetActorName(const AnfNodePtr &node) {
28 MS_EXCEPTION_IF_NULL(node);
29 auto debug_name = node->DebugString();
30 auto index = debug_name.find('{');
31 if ((index != std::string::npos) && (index > 0)) {
32 debug_name = debug_name.substr(0, index);
33 }
34
35 if (common::AnfAlgo::IsCallNode(node)) {
36 return "Call_" + node->UniqueName() + "_" + debug_name;
37 } else {
38 return node->UniqueName() + "_" + debug_name;
39 }
40 }
41
GetStackActorNameByExitName(const std::string & exit_name)42 std::string GetStackActorNameByExitName(const std::string &exit_name) {
43 size_t pos = exit_name.find(kExitActorNameSuffix);
44 if (pos == std::string::npos) {
45 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid exit actor name:" << exit_name;
46 }
47
48 return exit_name.substr(0, pos) + kStackActorNameSuffix;
49 }
50
51 // Parameter and ref node can not copy the device tensor.
is_need_copy_device_tensor(const AnfNodePtr & backend_node,size_t index)52 bool is_need_copy_device_tensor(const AnfNodePtr &backend_node, size_t index) {
53 MS_EXCEPTION_IF_NULL(backend_node);
54 // Skip the parameter and Load node.
55 const auto &real_backend_node = common::AnfAlgo::VisitKernelWithReturnType(backend_node, index, false).first;
56 if (real_backend_node != nullptr && (!real_backend_node->isa<CNode>())) {
57 return false;
58 }
59
60 if (common::AnfAlgo::HasAbstractRef(backend_node)) {
61 return false;
62 }
63
64 auto kernel_graph = AnfAlgo::FetchKernelGraph(backend_node.get());
65 MS_EXCEPTION_IF_NULL(kernel_graph);
66 if (kernel_graph->IsInRefOutputMap({backend_node, index})) {
67 if (!kernel_graph->is_graph_run_mode()) {
68 return false;
69 }
70 const auto &origin_node = kernel_graph->GetRefCorrespondOutput({backend_node, index}).first;
71 MS_EXCEPTION_IF_NULL(origin_node);
72 if (origin_node->isa<ValueNode>() || origin_node->isa<Parameter>()) {
73 return false;
74 }
75 }
76 return true;
77 }
78
79 // Check whether the exit actor corresponding to the call node to the to actor already exists control arrow.
IsControlArrowExistForCallNode(const AnfNodePtr & node,const AbstractActor * const to_actor,const ControlNodeParserPtr & parser)80 bool IsControlArrowExistForCallNode(const AnfNodePtr &node, const AbstractActor *const to_actor,
81 const ControlNodeParserPtr &parser) {
82 MS_EXCEPTION_IF_NULL(node);
83 MS_EXCEPTION_IF_NULL(to_actor);
84 MS_EXCEPTION_IF_NULL(parser);
85 if (!common::AnfAlgo::IsCallNode(node)) {
86 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, node)
87 << "#dmsg#Runtime error info:#dmsg#Invalid call node:" << node->DebugString();
88 }
89 int branch_id = parser->FetchBranchIDByCallNode(node);
90
91 const auto &func_graphs = parser->FetchFuncGraphbyCallNode(node);
92 if (func_graphs.empty()) {
93 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, node)
94 << "#dmsg#Runtime error info:#dmsg#Failed to get funcgraph by call node:" << node->DebugString();
95 }
96 MS_EXCEPTION_IF_NULL(*(func_graphs.begin()));
97 auto actor_name = (*(func_graphs.begin()))->ToString() + kExitActorNameSuffix;
98 const auto &actor = FetchActor(actor_name);
99 MS_EXCEPTION_IF_NULL(actor);
100 const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
101 MS_EXCEPTION_IF_NULL(exit_actor);
102
103 const auto &branch_arrows = exit_actor->output_branch_control_arrows();
104 const auto &arrow_iter = branch_arrows.find(branch_id);
105 if (arrow_iter == branch_arrows.end()) {
106 return false;
107 }
108 const auto &arrows = arrow_iter->second;
109 return std::find(arrows.begin(), arrows.end(), to_actor->GetAID()) != arrows.end();
110 }
111
IsNotCut(const AnfNodePtr & node)112 bool IsNotCut(const AnfNodePtr &node) {
113 MS_EXCEPTION_IF_NULL(node);
114 if (!node->isa<CNode>()) {
115 return false;
116 }
117 auto cnode = node->cast<CNodePtr>();
118 MS_EXCEPTION_IF_NULL(cnode);
119 return cnode->HasPrimalAttr(kAttrNotCut);
120 }
121 } // namespace
122
Build(const GraphCompilerInfo & graph_compiler_info,const AID & memory_manager_aid)123 ControlActorSetPtr ControlNodeScheduler::Build(const GraphCompilerInfo &graph_compiler_info,
124 const AID &memory_manager_aid) {
125 const auto &control_nodes = graph_compiler_info.control_nodes_;
126 if (control_nodes.size() <= kSingleControlNode) {
127 return nullptr;
128 }
129
130 memory_manager_aid_ = memory_manager_aid;
131 ControlActorSetPtr control_actors = std::make_shared<ControlActorSet>();
132 MS_EXCEPTION_IF_NULL(control_actors);
133 control_actors->switch_actors_ = BuildSwitchActor(graph_compiler_info);
134 control_actors->gather_actors_ = BuildGatherActor(graph_compiler_info);
135 control_actors->entrance_actors_ = BuildEntranceActor(graph_compiler_info);
136 control_actors->exit_actors_ = BuildExitActor(graph_compiler_info);
137 control_actors->stack_actors_ = BuildStackActor(graph_compiler_info);
138 return control_actors;
139 }
140
BuildSwitchActor(const GraphCompilerInfo & graph_compiler_info) const141 std::vector<SwitchActorPtr> ControlNodeScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) const {
142 std::vector<SwitchActorPtr> switch_actors;
143 const auto &control_nodes = graph_compiler_info.control_nodes_;
144
145 for (const auto &control_node : control_nodes) {
146 // Switch node and switch layer node will be converted to switch actor.
147 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
148 common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
149 const auto &actor_name = GetActorName(control_node);
150 const auto ¶meters = FetchInputNodeByCNode(control_node);
151 const auto &switch_actor =
152 std::make_shared<SwitchActor>(actor_name, memory_manager_aid_, parameters, control_node);
153 (void)switch_actors.emplace_back(switch_actor);
154 for (const auto ¶meter : parameters) {
155 MS_EXCEPTION_IF_NULL(parameter.first);
156 MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
157 << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
158 }
159 InsertActor(switch_actor.get());
160 }
161 }
162 return switch_actors;
163 }
164
BuildDataSourceActorForControlNode(const GraphCompilerInfo & graph_compiler_info,const HostTensorQueuePtr & host_queue,const HostQueueDSActorPtr & host_queue_ds_actor,const AID & memory_manager_aid,std::vector<DataSourceActorPtr> * data_source_actors) const165 void ControlNodeScheduler::BuildDataSourceActorForControlNode(
166 const GraphCompilerInfo &graph_compiler_info, const HostTensorQueuePtr &host_queue,
167 const HostQueueDSActorPtr &host_queue_ds_actor, const AID &memory_manager_aid,
168 std::vector<DataSourceActorPtr> *data_source_actors) const {
169 HostQueueDSActorPtr control_node_ds_actor = host_queue_ds_actor;
170 const auto parser = graph_compiler_info.control_node_parser_;
171 MS_EXCEPTION_IF_NULL(parser);
172 MS_EXCEPTION_IF_NULL(data_source_actors);
173
174 // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
175 // the corresponding backend parameter from the map, and insert it into the host data source actor.
176 const auto &control_node_parameters = parser->control_node_parameters();
177 for (const auto ¶meter_with_index : control_node_parameters) {
178 MS_EXCEPTION_IF_NULL(parameter_with_index.first);
179 if (IsPersistentDeviceTensor(parameter_with_index.first)) {
180 continue;
181 }
182 if (control_node_ds_actor == nullptr) {
183 auto actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
184 MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
185 control_node_ds_actor =
186 std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid, nullptr, nullptr, host_queue);
187 MS_EXCEPTION_IF_NULL(control_node_ds_actor);
188 InsertActor(control_node_ds_actor.get());
189 (void)data_source_actors->emplace_back(control_node_ds_actor);
190 }
191
192 auto &node_map = control_node_ds_actor->data_node_position_map_;
193 if (node_map.find(parameter_with_index) != node_map.end()) {
194 continue;
195 }
196 graph_compiler_info.origin_parameters_to_backend_parameters_[parameter_with_index.first].emplace_back(
197 std::make_pair(parameter_with_index, parameter_with_index));
198
199 const auto &node_with_index_with_context =
200 parser->FetchBackendParameterWithContextByFrontParameter(parameter_with_index);
201 const auto &node_with_index = node_with_index_with_context.first;
202 const auto &device_context = node_with_index_with_context.second;
203 MS_EXCEPTION_IF_NULL(node_with_index.first);
204 MS_EXCEPTION_IF_NULL(device_context);
205 MS_LOG(DEBUG) << "Control node parameter front node:" << parameter_with_index.first->DebugString()
206 << " index:" << parameter_with_index.second
207 << " backend node:" << node_with_index.first->DebugString() << " index:" << node_with_index.second;
208 auto iter = find(control_node_ds_actor->data_node_with_indexs_.begin(),
209 control_node_ds_actor->data_node_with_indexs_.end(), node_with_index);
210 if (iter != control_node_ds_actor->data_node_with_indexs_.end()) {
211 (void)node_map.emplace(parameter_with_index, iter - control_node_ds_actor->data_node_with_indexs_.begin());
212 MS_LOG(DEBUG) << "Insert front node:" << parameter_with_index.first->DebugString()
213 << " index:" << parameter_with_index.second << " to host queue data source actor.";
214 } else {
215 CreateBuildInfoForFrontNode(parameter_with_index, node_with_index.first);
216 // Create device tensor.
217 const auto &device_address = AnfAlgo::GetMutableOutputAddr(node_with_index.first, node_with_index.second, false);
218 MS_EXCEPTION_IF_NULL(device_address);
219 const auto &sub_abstract =
220 common::AnfAlgo::FetchAbstractByIndex(parameter_with_index.first->abstract(), parameter_with_index.second);
221 MS_EXCEPTION_IF_NULL(sub_abstract);
222 const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
223 sub_abstract->BuildShape(), sub_abstract->BuildType(), nullptr, nullptr, device_address->GetSize(),
224 device_address->format(), device_address->type_id(), device_address->host_shape(),
225 device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
226 MS_EXCEPTION_IF_NULL(kernel_tensor);
227 kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(parameter_with_index.first));
228 auto new_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
229 MS_EXCEPTION_IF_NULL(new_address);
230 MS_LOG(DEBUG) << "Create new address for node that has no corresponding backend node:"
231 << parameter_with_index.first->DebugString() << " index:" << parameter_with_index.second
232 << " addr:" << new_address << " size:" << device_address->GetSize()
233 << ", type id:" << device_address->type_id()
234 << " type:" << (kernel_tensor->GetType() == nullptr ? "null" : kernel_tensor->GetType()->ToString())
235 << " shape:"
236 << (kernel_tensor->GetShape() == nullptr ? "null" : kernel_tensor->GetShape()->ToString());
237 AnfAlgo::SetOutputAddr(new_address, parameter_with_index.second, parameter_with_index.first.get());
238
239 (void)node_map.emplace(parameter_with_index, control_node_ds_actor->data_node_with_indexs_.size());
240 (void)control_node_ds_actor->data_node_with_indexs_.emplace_back(parameter_with_index);
241 (void)control_node_ds_actor->device_contexts_.emplace_back(device_context);
242 }
243 }
244 }
245
BuildGatherActor(const GraphCompilerInfo & graph_compiler_info) const246 std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) const {
247 std::vector<GatherActorPtr> gather_actors;
248 const auto &control_nodes = graph_compiler_info.control_nodes_;
249 const auto &parser = graph_compiler_info.control_node_parser_;
250 MS_EXCEPTION_IF_NULL(parser);
251
252 for (const auto &control_node : control_nodes) {
253 MS_EXCEPTION_IF_NULL(control_node);
254 // Partial node and call node will be converted to gather actor.
255 if ((common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) && (!IsInvalidPartial(control_node))) ||
256 common::AnfAlgo::IsCallNode(control_node)) {
257 const auto &actor_name = GetActorName(control_node);
258 const auto ¶meters = FetchInputNodeByCNode(control_node);
259 const auto &gather_actor =
260 std::make_shared<GatherActor>(actor_name, memory_manager_aid_, parameters, control_node);
261 MS_EXCEPTION_IF_NULL(gather_actor);
262 (void)gather_actors.emplace_back(gather_actor);
263 for (const auto ¶meter : parameters) {
264 MS_EXCEPTION_IF_NULL(parameter.first);
265 MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
266 << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
267 }
268 InsertActor(gather_actor.get());
269
270 // The gather actor corresponding to a call node needs to set the branch id.
271 if (common::AnfAlgo::IsCallNode(control_node)) {
272 gather_actor->output_branch_id_ = parser->FetchBranchIDByCallNode(control_node);
273 }
274
275 // Fetch device contexts for gather actor.
276 const auto &iter = parser->control_node_to_device_contexts_.find(control_node);
277 if (iter == parser->control_node_to_device_contexts_.end()) {
278 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, control_node)
279 << "#dmsg#Runtime error info:#dmsg#Failed to get device contexts for node:" << control_node->DebugString();
280 }
281 gather_actor->device_contexts_ = iter->second;
282 }
283 }
284 return gather_actors;
285 }
286
BuildEntranceActor(const GraphCompilerInfo & graph_compiler_info) const287 std::vector<EntranceActorPtr> ControlNodeScheduler::BuildEntranceActor(
288 const GraphCompilerInfo &graph_compiler_info) const {
289 const auto &parser = graph_compiler_info.control_node_parser_;
290 MS_EXCEPTION_IF_NULL(parser);
291 const auto &call_node_to_func_graphs = parser->call_node_to_func_graphs_;
292 std::unordered_map<FuncGraphPtr, std::set<KernelWithIndex>> func_graph_to_call_nodes;
293 for (const auto &call_node_to_func_graph : call_node_to_func_graphs) {
294 const auto &node = call_node_to_func_graph.first;
295 for (const auto &func_graph : call_node_to_func_graph.second) {
296 (void)func_graph_to_call_nodes[func_graph].emplace(node, 0);
297 }
298 }
299
300 std::vector<EntranceActorPtr> entrance_actors;
301 const auto &control_nodes = graph_compiler_info.control_nodes_;
302 for (const auto &control_node : control_nodes) {
303 MS_EXCEPTION_IF_NULL(control_node);
304 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
305 const auto &func_graph = control_node->func_graph();
306 MS_EXCEPTION_IF_NULL(func_graph);
307 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
308 std::vector<KernelWithIndex> formal_parameters;
309
310 // The entrance actor has two parts of node members :
311 // 1. The formal parameters of the subgraph are used to connect the actor's output arrows.
312 for (const auto ¶meter : func_graph->parameters()) {
313 MS_EXCEPTION_IF_NULL(parameter);
314 const auto &abstract = parameter->abstract();
315 MS_EXCEPTION_IF_NULL(abstract);
316 size_t output_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
317 for (size_t i = 0; i < output_num; ++i) {
318 (void)formal_parameters.emplace_back(parameter, i);
319 }
320 }
321
322 // 2. The caller of the subgraph, namely call nodes, is used to connect the input arrows.
323 std::set<KernelWithIndex> call_nodes;
324 const auto &iter = func_graph_to_call_nodes.find(func_graph);
325 if (iter != func_graph_to_call_nodes.end()) {
326 call_nodes = iter->second;
327 }
328 for (const auto &formal_parameter : formal_parameters) {
329 MS_EXCEPTION_IF_NULL(formal_parameter.first);
330 MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
331 << " parameter:" << formal_parameter.first->DebugString() << " index:" << formal_parameter.second;
332 }
333 const auto &entrance_actor =
334 std::make_shared<EntranceActor>(actor_name, memory_manager_aid_, formal_parameters, call_nodes, control_node);
335 MS_EXCEPTION_IF_NULL(entrance_actor);
336 auto context_iter = parser->func_graph_to_device_contexts_.find(func_graph);
337 if (context_iter == parser->func_graph_to_device_contexts_.end() ||
338 context_iter->second.size() < formal_parameters.size()) {
339 MS_LOG(INTERNAL_EXCEPTION)
340 << "#dmsg#Runtime error info:#dmsg#Invalid device contexts for funcgraph:" << func_graph->ToString()
341 << " parameter num:" << formal_parameters.size() << " device contexts num:"
342 << (context_iter == parser->func_graph_to_device_contexts_.end() ? 0 : context_iter->second.size());
343 }
344 entrance_actor->device_contexts_.clear();
345 (void)entrance_actor->device_contexts_.insert(
346 entrance_actor->device_contexts_.begin(), context_iter->second.begin(),
347 context_iter->second.begin() + SizeToLong(formal_parameters.size()));
348 (void)entrance_actors.emplace_back(entrance_actor);
349 InsertActor(entrance_actor.get());
350 }
351 }
352
353 return entrance_actors;
354 }
355
BuildExitActor(const GraphCompilerInfo & graph_compiler_info) const356 std::vector<ExitActorPtr> ControlNodeScheduler::BuildExitActor(const GraphCompilerInfo &graph_compiler_info) const {
357 std::vector<ExitActorPtr> exit_actors;
358 const auto &control_nodes = graph_compiler_info.control_nodes_;
359 const auto &parser = graph_compiler_info.control_node_parser_;
360 MS_EXCEPTION_IF_NULL(parser);
361
362 // The exit actor is used in 2 places:
363 // 1.funcgraph output, that is the output of the return node.
364 for (const auto &control_node : control_nodes) {
365 MS_EXCEPTION_IF_NULL(control_node);
366 if (common::AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
367 const auto &func_graph = control_node->func_graph();
368 MS_EXCEPTION_IF_NULL(func_graph);
369 const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix;
370 const auto ¶meters = FetchInputNodeByCNode(control_node);
371 const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, parameters, control_node);
372 MS_EXCEPTION_IF_NULL(exit_actor);
373 for (const auto ¶meter : parameters) {
374 MS_EXCEPTION_IF_NULL(parameter.first);
375 MS_LOG(DEBUG) << "Print formal parameter for actor:" << actor_name
376 << " parameter:" << parameter.first->DebugString() << " index:" << parameter.second;
377 }
378 auto context_iter = parser->control_node_to_device_contexts_.find(control_node);
379 if (context_iter == parser->control_node_to_device_contexts_.end() ||
380 context_iter->second.size() != parameters.size()) {
381 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Failed to get device contexts for funcgraph:"
382 << func_graph->ToString();
383 }
384 exit_actor->device_contexts_ = context_iter->second;
385 (void)exit_actors.emplace_back(exit_actor);
386 InsertActor(exit_actor.get());
387 }
388 }
389
390 if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
391 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid graphs num:"
392 << graph_compiler_info.graphs_.size()
393 << " and contexts num:" << graph_compiler_info.device_contexts_.size();
394 }
395
396 // 2. Replace the device address in the kernel actor when calling funcgraph, that is to say in the data exchange
397 // between kernel graph and the control node, in fact, it is the output of the kernel graph.
398 for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
399 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
400 if (kernel_graph_group_info->graphs_.empty()) {
401 continue;
402 }
403
404 std::vector<bool> is_need_copy_device_tensors;
405 std::vector<bool> is_need_dynamic_checks;
406 std::vector<bool> is_dynamic_shapes;
407 std::vector<KernelWithIndex> formal_parameters;
408 std::vector<const DeviceContext *> device_contexts;
409
410 for (const auto &node_with_context : kernel_graph_group_info->front_output_nodes_) {
411 if (HasAbstractMonad(node_with_context.first.first) || IsCsrNode(node_with_context.first.first)) {
412 continue;
413 }
414 // Collect inputs of exit actor.
415 (void)formal_parameters.emplace_back(node_with_context.first);
416 // Get the device contexts of the exit actor's cnode inputs.
417 const AnfNodePtr &backend_node = node_with_context.second.first.first;
418 MS_EXCEPTION_IF_NULL(backend_node);
419 (void)is_need_copy_device_tensors.emplace_back(
420 is_need_copy_device_tensor(backend_node, node_with_context.second.first.second));
421 (void)is_need_dynamic_checks.emplace_back(
422 common::AnfAlgo::CheckPrimitiveType(backend_node, prim::kPrimConditionGather));
423 auto is_dynamic_shape =
424 common::AnfAlgo::IsDynamicShape(backend_node) || common::AnfAlgo::IsDynamicSequence(backend_node);
425 (void)is_dynamic_shapes.emplace_back(is_dynamic_shape);
426 (void)device_contexts.emplace_back(node_with_context.second.second);
427 }
428 const auto &actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix;
429 const auto &exit_actor = std::make_shared<ExitActor>(actor_name, memory_manager_aid_, formal_parameters, nullptr);
430 MS_EXCEPTION_IF_NULL(exit_actor);
431 exit_actor->is_need_copy_device_tensors_.swap(is_need_copy_device_tensors);
432 exit_actor->is_need_dynamic_checks_.swap(is_need_dynamic_checks);
433 exit_actor->is_dynamic_shapes_.swap(is_dynamic_shapes);
434 exit_actor->device_contexts_.swap(device_contexts);
435 for (const auto &graph : kernel_graph_group_info->graphs_) {
436 MS_EXCEPTION_IF_NULL(graph);
437 std::for_each(graph->GetRefMap().begin(), graph->GetRefMap().end(),
438 [&exit_actor, &graph](const std::pair<KernelWithIndex, KernelWithIndex> &pair) {
439 exit_actor->ref_out_in_map_[pair.first] = graph->GetRefNodeRecursive(pair.first);
440 });
441 }
442 (void)exit_actors.emplace_back(exit_actor);
443 InsertActor(exit_actor.get());
444 }
445
446 return exit_actors;
447 }
448
BuildStackActor(const GraphCompilerInfo & graph_compiler_info) const449 std::vector<StackActorPtr> ControlNodeScheduler::BuildStackActor(const GraphCompilerInfo &graph_compiler_info) const {
450 std::vector<StackActorPtr> stack_actors;
451 const auto &parser = graph_compiler_info.control_node_parser_;
452 MS_EXCEPTION_IF_NULL(parser);
453
454 // Create a corresponding stack actor for each kernel graph that has a call node as input.
455 for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
456 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
457 if (!kernel_graph_group_info->need_stack_) {
458 continue;
459 }
460 const auto &actor_name = kernel_graph_group_info->group_name_ + kStackActorNameSuffix;
461 size_t input_parameter_data_num = 0;
462 std::vector<const DeviceContext *> device_contexts;
463 std::vector<KernelWithIndex> formal_parameters;
464 // Collect inputs of stack actor.
465 for (const auto &node_with_context : kernel_graph_group_info->front_input_nodes_) {
466 // If the input comes from inside funcgraph, put it at the front of the vector, otherwise put it at the end.
467 const auto &from_node = node_with_context.first.first;
468 MS_EXCEPTION_IF_NULL(from_node);
469 auto iter = parser->node_to_level_.find(from_node);
470 if (iter == parser->node_to_level_.end()) {
471 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, from_node)
472 << "#dmsg#Runtime error info:#dmsg#Failed to get level by from node:" << from_node->DebugString()
473 << " in graph:" << kernel_graph_group_info->group_name_;
474 }
475 if (iter->second == kernel_graph_group_info->level_ && (!parser->IsRootGraphPersistentDeviceTensor(from_node))) {
476 (void)formal_parameters.emplace_back(node_with_context.first);
477 (void)device_contexts.emplace_back(node_with_context.second);
478 MS_LOG(DEBUG) << "Add normal parameter for actor:" << actor_name << " node:" << from_node->DebugString()
479 << " index:" << node_with_context.first.second;
480 } else {
481 (void)formal_parameters.insert(formal_parameters.begin(), node_with_context.first);
482 (void)device_contexts.insert(device_contexts.begin(), node_with_context.second);
483 MS_LOG(DEBUG) << "Add stack parameter for actor:" << actor_name << " node:" << from_node->DebugString()
484 << " index:" << node_with_context.first.second;
485 input_parameter_data_num++;
486 }
487 }
488 const auto &stack_actor = std::make_shared<StackActor>(actor_name, memory_manager_aid_, formal_parameters);
489 MS_EXCEPTION_IF_NULL(stack_actor);
490 (void)stack_actors.emplace_back(stack_actor);
491 stack_actor->device_contexts_.swap(device_contexts);
492 stack_actor->input_stack_data_num_ = input_parameter_data_num;
493 InsertActor(stack_actor.get());
494 }
495 // Create stack actors for control nodes.
496 BuildStackActorForControlNode(graph_compiler_info, &stack_actors);
497
498 return stack_actors;
499 }
500
BuildStackActorForControlNode(const GraphCompilerInfo & graph_compiler_info,std::vector<StackActorPtr> * const stack_actors) const501 void ControlNodeScheduler::BuildStackActorForControlNode(const GraphCompilerInfo &graph_compiler_info,
502 std::vector<StackActorPtr> *const stack_actors) const {
503 const auto &parser = graph_compiler_info.control_node_parser_;
504 MS_EXCEPTION_IF_NULL(parser);
505
506 for (const auto &need_stack_control_node : parser->need_stack_control_nodes_) {
507 MS_EXCEPTION_IF_NULL(need_stack_control_node);
508 MS_LOG(DEBUG) << "Build stack actor for control node:" << need_stack_control_node->DebugString();
509 const auto &stack_actor_name = GetActorName(need_stack_control_node) + kStackActorNameSuffix;
510 std::vector<KernelWithIndex> formal_parameters;
511 std::vector<const DeviceContext *> device_contexts;
512 size_t input_parameter_data_num = 0;
513 size_t input_parameter_partials_num = 0;
514
515 // Fetch the control actor of control node.
516 std::string control_actor_name = "";
517 if (common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimReturn)) {
518 const auto &func_graph = need_stack_control_node->func_graph();
519 MS_EXCEPTION_IF_NULL(func_graph);
520 control_actor_name = func_graph->ToString() + kExitActorNameSuffix;
521 } else if (common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimPartial) ||
522 common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitch) ||
523 common::AnfAlgo::CheckPrimitiveType(need_stack_control_node, prim::kPrimSwitchLayer) ||
524 common::AnfAlgo::IsCallNode(need_stack_control_node)) {
525 control_actor_name = GetActorName(need_stack_control_node);
526 } else {
527 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, need_stack_control_node)
528 << "#dmsg#Runtime error info:#dmsg#Invalid control node:" << need_stack_control_node->DebugString();
529 }
530
531 auto iter = parser->node_to_level_.find(need_stack_control_node);
532 if (iter == parser->node_to_level_.end()) {
533 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, need_stack_control_node)
534 << "#dmsg#Runtime error info:#dmsg#Failed to get level for need stack control node:"
535 << need_stack_control_node->DebugString();
536 }
537 size_t control_node_level = iter->second;
538
539 auto actor = FetchActor(control_actor_name);
540 if (actor == nullptr) {
541 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid actor name:" << control_actor_name;
542 }
543 auto control_actor = dynamic_cast<ControlActor *>(actor);
544 MS_EXCEPTION_IF_NULL(control_actor);
545 if (control_actor->formal_parameters_.size() > control_actor->device_contexts_.size()) {
546 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid device context size:"
547 << control_actor->device_contexts_.size()
548 << " and formal parameter size:" << control_actor->formal_parameters_.size()
549 << " for actor:" << control_actor->GetAID();
550 }
551
552 // Collect formal parameters and device contexts, skip the value nodes.
553 for (size_t i = 0; i < control_actor->formal_parameters_.size(); ++i) {
554 const auto ¶meter = control_actor->formal_parameters_[i];
555 auto device_context = control_actor->device_contexts_[i];
556 MS_EXCEPTION_IF_NULL(parameter.first);
557 if (parameter.first->isa<ValueNode>()) {
558 continue;
559 }
560
561 iter = parser->node_to_level_.find(parameter.first);
562 if (iter == parser->node_to_level_.end()) {
563 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, parameter.first)
564 << "#dmsg#Runtime error info:#dmsg#Failed to get level for formal parameter:"
565 << parameter.first->DebugString()
566 << " for need stack control node:" << need_stack_control_node->DebugString();
567 }
568
569 if (control_node_level == iter->second && (!parser->IsRootGraphPersistentDeviceTensor(parameter.first))) {
570 MS_LOG(DEBUG) << "Add normal parameter:" << parameter.first->DebugString()
571 << " for stack actor:" << stack_actor_name;
572 (void)formal_parameters.emplace_back(parameter);
573 (void)device_contexts.emplace_back(device_context);
574 } else {
575 MS_LOG(DEBUG) << "Add stack parameter:" << parameter.first->DebugString()
576 << " for stack actor:" << stack_actor_name;
577 (void)formal_parameters.insert(formal_parameters.begin(), parameter);
578 (void)device_contexts.insert(device_contexts.begin(), device_context);
579
580 const auto &abstract = parameter.first->abstract();
581 MS_EXCEPTION_IF_NULL(abstract);
582 const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, parameter.second);
583 MS_EXCEPTION_IF_NULL(real_abstract);
584 if (real_abstract->isa<abstract::AbstractFunction>()) {
585 input_parameter_partials_num++;
586 } else {
587 input_parameter_data_num++;
588 }
589 }
590 }
591 // Create stack actor.
592 const auto &stack_actor = std::make_shared<StackActor>(stack_actor_name, memory_manager_aid_, formal_parameters);
593 MS_EXCEPTION_IF_NULL(stack_actor);
594 stack_actor->device_contexts_ = device_contexts;
595 stack_actor->input_stack_data_num_ = input_parameter_data_num;
596 stack_actor->input_stack_partials_num_ = input_parameter_partials_num;
597 stack_actor->node_ = need_stack_control_node;
598 InsertActor(stack_actor.get());
599 (void)stack_actors->emplace_back(stack_actor);
600 }
601 }
602
603 namespace {
ParseRealIndex(const mindspore::HashMap<size_t,size_t> & dynamic_len_index,size_t formal_input_num,std::vector<std::pair<std::vector<size_t>,bool>> * real_indexes,AbstractActor * actor)604 void ParseRealIndex(const mindspore::HashMap<size_t, size_t> &dynamic_len_index, size_t formal_input_num,
605 std::vector<std::pair<std::vector<size_t>, bool>> *real_indexes, AbstractActor *actor) {
606 MS_EXCEPTION_IF_NULL(real_indexes);
607 MS_EXCEPTION_IF_NULL(actor);
608 auto tmp_dynamic_len_index = dynamic_len_index;
609 size_t real_output_num = formal_input_num + tmp_dynamic_len_index.size();
610 for (const auto &index_pair : tmp_dynamic_len_index) {
611 if (real_output_num < index_pair.second) {
612 MS_LOG(EXCEPTION) << "Invalid dynamic len:" << std::to_string(index_pair.second)
613 << " start index:" << std::to_string(index_pair.first)
614 << " real input num:" << std::to_string(real_output_num)
615 << " for actor:" << actor->GetAID().Name();
616 }
617 real_output_num -= index_pair.second;
618 }
619 MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real output num:" << real_output_num;
620 size_t start_index = 0;
621 for (const auto &pair : tmp_dynamic_len_index) {
622 MS_LOG(DEBUG) << "start index:" << pair.first << " len:" << pair.second;
623 }
624 for (size_t i = 0; i < real_output_num; ++i) {
625 MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real output index:" << i;
626 if (tmp_dynamic_len_index.find(start_index) != tmp_dynamic_len_index.end()) {
627 std::vector<size_t> indexes;
628 for (size_t j = 0; j < tmp_dynamic_len_index[start_index]; ++j) {
629 indexes.emplace_back(j + start_index);
630 }
631 real_indexes->emplace_back(indexes, true);
632 start_index += tmp_dynamic_len_index[start_index];
633 tmp_dynamic_len_index.erase(start_index);
634 } else {
635 std::vector<size_t> indexes{start_index};
636 real_indexes->emplace_back(indexes, false);
637 start_index++;
638 }
639 }
640 for (size_t i = 0; i < real_indexes->size(); ++i) {
641 std::string index_str = "index_";
642 for (const auto &index : (*real_indexes)[i].first) {
643 index_str = index_str + std::to_string(index) + "_";
644 }
645 MS_LOG(DEBUG) << "for actor:" << actor->GetAID() << " real input " << i << " " << index_str;
646 }
647 if (real_indexes->size() != real_output_num) {
648 MS_LOG(EXCEPTION) << "Invalid real index size:" << std::to_string(real_indexes->size())
649 << " start need:" << std::to_string(real_output_num) << " for actor:" << actor->GetAID().Name();
650 }
651 }
652 } // namespace
653
CollectDynamicLenIndexForArgment(const GraphCompilerInfo & graph_compiler_info) const654 void ControlNodeScheduler::CollectDynamicLenIndexForArgment(const GraphCompilerInfo &graph_compiler_info) const {
655 const auto &parser = graph_compiler_info.control_node_parser_;
656 MS_EXCEPTION_IF_NULL(parser);
657 for (const auto &node_to_func_with_index : parser->control_node_to_funcgraph_with_dynamic_sequence_index_) {
658 const auto &node = node_to_func_with_index.first;
659 MS_EXCEPTION_IF_NULL(node);
660 const auto &actor_name = GetActorName(node);
661 const auto &actor = FetchActor(actor_name);
662 MS_EXCEPTION_IF_NULL(actor);
663 const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
664 MS_EXCEPTION_IF_NULL(gather_actor);
665 for (const auto &func_with_index : node_to_func_with_index.second) {
666 const auto &func_graph = func_with_index.first;
667 MS_EXCEPTION_IF_NULL(func_graph);
668 auto dynamic_len_index = func_with_index.second;
669 std::vector<std::pair<std::vector<size_t>, bool>> real_indexes;
670 ParseRealIndex(dynamic_len_index, gather_actor->formal_parameters_.size(), &real_indexes, gather_actor);
671 MS_LOG(INFO) << "Add dynamic len index for funcgraph:" << func_graph->ToString()
672 << " actor:" << gather_actor->GetAID()
673 << " formal parameter num:" << gather_actor->formal_parameters_.size();
674 gather_actor->dynamic_len_index_[func_graph] = real_indexes;
675 }
676 }
677
678 for (const auto &node_to_call_with_index : parser->return_to_call_with_dynamic_sequence_index_) {
679 const auto &node = node_to_call_with_index.first;
680 MS_EXCEPTION_IF_NULL(node);
681 MS_EXCEPTION_IF_NULL(node->func_graph());
682 MS_LOG(DEBUG) << "for node:" << node->DebugString();
683 const auto &actor_name = node->func_graph()->ToString() + kExitActorNameSuffix;
684 auto actor = FetchActor(actor_name);
685 MS_EXCEPTION_IF_NULL(actor);
686 const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
687 MS_EXCEPTION_IF_NULL(exit_actor);
688 for (const auto &call_with_index : node_to_call_with_index.second) {
689 const auto &call = call_with_index.first;
690 MS_EXCEPTION_IF_NULL(call);
691 int branch_id = parser->FetchBranchIDByCallNode(call);
692 std::vector<std::pair<std::vector<size_t>, bool>> real_indexes;
693 ParseRealIndex(call_with_index.second, exit_actor->formal_parameters_.size(), &real_indexes, exit_actor);
694 exit_actor->output_branch_dynamic_len_index_[branch_id] = real_indexes;
695 }
696 }
697 }
698
Link(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const699 void ControlNodeScheduler::Link(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info) const {
700 MS_EXCEPTION_IF_NULL(actor_set);
701 MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
702 MS_LOG(DEBUG) << "Control node scheduler link start.";
703 // Link data arrows and partial arrows between control actors.
704 LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info);
705
706 // Link arrows from host data source actor or data prepare actor to entrance actor of root graph.
707 LinkArrowForRootGraphEntranceActor(graph_compiler_info);
708
709 // Link output data arrows from control actors to output actor.
710 LinkDataArrowForOutputActor(actor_set, graph_compiler_info);
711
712 // Link data arrows from entrance actors to kernel actors.
713 LinkDataArrowForKernelActor(graph_compiler_info);
714
715 // Link branch id arrows between control actors.
716 LinkBranchIDArrowForControlActor(actor_set->control_actors_.get());
717
718 // Link all control arrows between control actors.
719 LinkControlArrowForControlActor(actor_set, graph_compiler_info);
720
721 // Link control arrows for no input and no output kernel actor.
722 LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
723
724 LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
725
726 LinkDataArrowForCustomActor(actor_set, graph_compiler_info);
727
728 LinkControlArrowForCustomActor(actor_set, graph_compiler_info);
729
730 SetTimeSummaryForControlActor(graph_compiler_info);
731
732 CollectDynamicLenIndexForArgment(graph_compiler_info);
733 MS_LOG(DEBUG) << "Control node scheduler link end.";
734 }
735
736 namespace {
FetchInternalParameterInput(const AnfNodePtr & node,const ControlNodeParserPtr & parser,const KernelGraphPtr & graph)737 AnfNodePtr FetchInternalParameterInput(const AnfNodePtr &node, const ControlNodeParserPtr &parser,
738 const KernelGraphPtr &graph) {
739 MS_EXCEPTION_IF_NULL(node);
740 if (!node->isa<CNode>()) {
741 MS_LOG(WARNING) << "Node:" << node->DebugString() << " is not a cnode.";
742 return nullptr;
743 }
744 const auto &kernel = node->cast<CNodePtr>();
745 MS_EXCEPTION_IF_NULL(kernel);
746 for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
747 auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
748 MS_EXCEPTION_IF_NULL(input_node);
749 auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
750 auto input = input_with_index.first;
751 MS_EXCEPTION_IF_NULL(input);
752 if (HasAbstractMonad(input) || (!parser->IsControlFlowDataArrow(graph, input))) {
753 continue;
754 }
755
756 auto from_node_with_index = GetFrontNodeByKernelGraph(input, graph.get());
757 MS_EXCEPTION_IF_NULL(from_node_with_index.first);
758 const auto &from_node = from_node_with_index.first;
759 if (from_node->isa<CNode>()) {
760 return from_node;
761 }
762 }
763 return nullptr;
764 }
765 } // namespace
766
LinkControlArrowForCustomActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const767 void ControlNodeScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
768 const GraphCompilerInfo &graph_compiler_info) const {
769 MS_EXCEPTION_IF_NULL(actor_set);
770 const auto &parser = graph_compiler_info.control_node_parser_;
771 MS_EXCEPTION_IF_NULL(parser);
772
773 for (auto &custom_actor : actor_set->custom_actors_) {
774 MS_EXCEPTION_IF_NULL(custom_actor);
775 const auto &kernel = custom_actor->kernel().lock();
776 MS_EXCEPTION_IF_NULL(kernel);
777 const auto &graph = kernel->func_graph();
778 MS_EXCEPTION_IF_NULL(graph);
779 if (custom_actor->output_data_arrows().empty() && custom_actor->output_control_arrows().empty()) {
780 const auto &actor_name = graph->ToString() + kExitActorNameSuffix;
781 auto actor = FetchActor(actor_name);
782 MS_EXCEPTION_IF_NULL(actor);
783 SchedulerHelper::AddControlArrow(custom_actor.get(), actor);
784 }
785 if (custom_actor->input_control_arrow_aids().empty() && custom_actor->input_data_arrow_aids().empty()) {
786 const auto &kernel_graph = std::dynamic_pointer_cast<KernelGraph>(graph);
787 MS_EXCEPTION_IF_NULL(kernel_graph);
788 auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
789 AnfNodePtr internal_parameter = nullptr;
790 if (base_node != nullptr) {
791 internal_parameter = FetchInternalParameterInput(base_node, parser, kernel_graph);
792 }
793 AbstractActor *from_actor = nullptr;
794 if (parser->IsCallInputKernelGraph(kernel_graph.get())) {
795 auto kernel_graph_ptr = std::dynamic_pointer_cast<KernelGraph>(kernel->func_graph());
796 const auto &actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph_ptr) + kStackActorNameSuffix;
797 from_actor = FetchActor(actor_name);
798 } else if (internal_parameter != nullptr) {
799 const auto &from_graph = parser->FetchKernelGraphByFrontNode(internal_parameter);
800 MS_EXCEPTION_IF_NULL(from_graph);
801 from_actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
802 } else {
803 const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
804 MS_EXCEPTION_IF_NULL(func_graph);
805 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
806 from_actor = FetchActor(actor_name);
807 }
808 MS_EXCEPTION_IF_NULL(from_actor);
809 SchedulerHelper::AddControlArrow(from_actor, custom_actor.get());
810 }
811 }
812 }
813
ClearActorData(const ControlActorSet * control_actor_set) const814 void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) const {
815 if (control_actor_set == nullptr) {
816 return;
817 }
818
819 for (auto &switch_actor : control_actor_set->switch_actors_) {
820 MS_EXCEPTION_IF_NULL(switch_actor);
821 switch_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
822 }
823
824 for (auto &gather_actor : control_actor_set->gather_actors_) {
825 MS_EXCEPTION_IF_NULL(gather_actor);
826 gather_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
827 gather_actor->created_device_tensors_.clear();
828 gather_actor->created_new_graphs_.clear();
829 gather_actor->created_new_nodes_.clear();
830 }
831
832 for (auto &entrance_actor : control_actor_set->entrance_actors_) {
833 MS_EXCEPTION_IF_NULL(entrance_actor);
834 entrance_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
835 }
836
837 for (auto &stack_actor : control_actor_set->stack_actors_) {
838 MS_EXCEPTION_IF_NULL(stack_actor);
839 stack_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
840 }
841
842 for (auto &exit_actor : control_actor_set->exit_actors_) {
843 MS_EXCEPTION_IF_NULL(exit_actor);
844 exit_actor->memory_free_lists_ = std::queue<std::vector<DeviceTensor *>>();
845 exit_actor->last_step_created_device_tensors_.swap(exit_actor->created_device_tensors_);
846 exit_actor->created_new_graphs_.clear();
847 exit_actor->created_new_nodes_.clear();
848 }
849 }
850
851 namespace {
GetLazyInlineFuncGraph(const StackActorPtr & stack_actor)852 FuncGraphPtr GetLazyInlineFuncGraph(const StackActorPtr &stack_actor) {
853 MS_EXCEPTION_IF_NULL(stack_actor);
854 for (const auto &input_pair : stack_actor->input_data_arrow_aids()) {
855 if (input_pair.second == nullptr) {
856 continue;
857 }
858 const auto &data_arrow = input_pair.second;
859 if (IntToSize(data_arrow->to_input_index_) < stack_actor->input_stack_data_num()) {
860 continue;
861 }
862 std::string actor_name = input_pair.first.Name();
863 if (actor_name.empty()) {
864 continue;
865 }
866 const auto &actor = FetchActor(actor_name);
867 if (actor == nullptr) {
868 MS_LOG(WARNING) << "Failed to get actor by aid:" << actor_name << " for stack actor:" << stack_actor->GetAID();
869 continue;
870 }
871 if (actor->type() != KernelTransformType::kExitActor) {
872 continue;
873 }
874 const auto &exit_actor = dynamic_cast<ExitActor *>(actor);
875 MS_EXCEPTION_IF_NULL(exit_actor);
876 if (exit_actor->node() == nullptr) {
877 continue;
878 }
879 return exit_actor->node()->func_graph();
880 }
881 return nullptr;
882 }
883 } // namespace
884
Optimize(const ActorSetPtr & actor_set,const GraphCompilerInfo & graph_compiler_info) const885 void ControlNodeScheduler::Optimize(const ActorSetPtr &actor_set, const GraphCompilerInfo &graph_compiler_info) const {
886 MS_EXCEPTION_IF_NULL(actor_set);
887 if (actor_set->control_actors_ == nullptr || graph_compiler_info.control_node_parser_ == nullptr) {
888 return;
889 }
890 const auto &parser = graph_compiler_info.control_node_parser_;
891
892 for (const auto &entrance_actor : actor_set->control_actors_->entrance_actors_) {
893 MS_EXCEPTION_IF_NULL(entrance_actor);
894 std::stable_partition(entrance_actor->output_data_arrows_.begin(), entrance_actor->output_data_arrows_.end(),
895 [](const DataArrowPtr &arrow) {
896 MS_EXCEPTION_IF_NULL(arrow);
897 const auto &actor = FetchActor(arrow->to_op_id_.Name());
898 return actor != nullptr && (actor->type() == KernelTransformType::kKernelActor ||
899 actor->type() == KernelTransformType::kSuperKernelActor ||
900 actor->type() == KernelTransformType::kCopyActor);
901 });
902 }
903
904 for (const auto &stack_actor : actor_set->control_actors_->stack_actors_) {
905 MS_EXCEPTION_IF_NULL(stack_actor);
906 if (stack_actor->formal_parameters_.size() !=
907 stack_actor->input_data_arrow_aids_.size() + stack_actor->device_tensor_store_keys_.size()) {
908 continue;
909 }
910 if (stack_actor->input_stack_data_num_ == 0 ||
911 stack_actor->input_stack_data_num_ >= stack_actor->device_tensor_store_keys_.size() +
912 stack_actor->local_device_tensors_.size() +
913 stack_actor->input_data_arrow_aids_.size()) {
914 continue;
915 }
916 const auto &func_graph = GetLazyInlineFuncGraph(stack_actor);
917 if (func_graph == nullptr) {
918 continue;
919 }
920 if (!func_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
921 continue;
922 }
923 const auto &iter = parser->func_graph_to_call_nodes_.find(func_graph);
924 if (iter != parser->func_graph_to_call_nodes_.end() && (!iter->second.empty())) {
925 continue;
926 }
927 for (const auto &input_aid : stack_actor->input_branch_id_arrow_aids_) {
928 const auto &actor = FetchActor(input_aid.Name());
929 if (actor == nullptr || actor->type() != KernelTransformType::kEntranceActor) {
930 MS_LOG(WARNING) << "Invalid input branch id aid:" << input_aid;
931 continue;
932 }
933 const auto &entrance_actor = dynamic_cast<EntranceActor *>(actor);
934 MS_EXCEPTION_IF_NULL(entrance_actor);
935 const auto &branch_id_iter = std::find(entrance_actor->output_branch_id_arrows_.begin(),
936 entrance_actor->output_branch_id_arrows_.end(), stack_actor->GetAID());
937 if (branch_id_iter != entrance_actor->output_branch_id_arrows_.end()) {
938 stack_actor->is_branch_id_enable_ = false;
939 entrance_actor->output_branch_id_arrows_.erase(branch_id_iter);
940 MS_LOG(DEBUG) << "Disable branch id from entrance actor:" << entrance_actor->GetAID()
941 << " to stack actor:" << stack_actor->GetAID()
942 << " for cell reuse funcgraph:" << func_graph->ToString();
943 break;
944 }
945 }
946 }
947 }
948
LinkArrowForControlActor(ControlActorSet * const control_actor_set,const GraphCompilerInfo & graph_compiler_info) const949 void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set,
950 const GraphCompilerInfo &graph_compiler_info) const {
951 if (control_actor_set == nullptr) {
952 return;
953 }
954
955 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
956 const auto &parser = graph_compiler_info.control_node_parser_;
957 for (auto &switch_actor : control_actor_set->switch_actors_) {
958 MS_EXCEPTION_IF_NULL(switch_actor);
959 if (!parser->IsNeedStackControlNode(switch_actor->node_)) {
960 for (size_t i = 0; i < switch_actor->formal_parameters_.size(); ++i) {
961 LinkArrowbyFormalParameter(switch_actor.get(), switch_actor->formal_parameters_[i], {switch_actor->node_, i},
962 graph_compiler_info);
963 }
964 } else {
965 // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
966 auto stack_actor_name = GetActorName(switch_actor->node_) + kStackActorNameSuffix;
967 auto actor = FetchActor(stack_actor_name);
968 MS_EXCEPTION_IF_NULL(actor);
969 auto stack_actor = dynamic_cast<StackActor *>(actor);
970 MS_EXCEPTION_IF_NULL(stack_actor);
971 LinkArrowFromStackActor(stack_actor, switch_actor.get(), graph_compiler_info);
972 }
973 }
974
975 for (auto &gather_actor : control_actor_set->gather_actors_) {
976 MS_EXCEPTION_IF_NULL(gather_actor);
977 MS_EXCEPTION_IF_NULL(gather_actor->node_);
978 if (!parser->IsNeedStackControlNode(gather_actor->node_)) {
979 for (size_t i = 0; i < gather_actor->formal_parameters_.size(); ++i) {
980 LinkArrowbyFormalParameter(gather_actor.get(), gather_actor->formal_parameters_[i], {gather_actor->node_, i},
981 graph_compiler_info);
982 }
983 } else {
984 // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
985 auto stack_actor_name = GetActorName(gather_actor->node_) + kStackActorNameSuffix;
986 auto actor = FetchActor(stack_actor_name);
987 MS_EXCEPTION_IF_NULL(actor);
988 auto stack_actor = dynamic_cast<StackActor *>(actor);
989 MS_EXCEPTION_IF_NULL(stack_actor);
990 LinkArrowFromStackActor(stack_actor, gather_actor.get(), graph_compiler_info);
991 }
992 }
993
994 for (auto &entrance_actor : control_actor_set->entrance_actors_) {
995 MS_EXCEPTION_IF_NULL(entrance_actor);
996 for (const auto &call_node : entrance_actor->call_nodes_) {
997 LinkArrowbyFormalParameter(entrance_actor.get(), call_node, {entrance_actor->node_, 0}, graph_compiler_info);
998 }
999 }
1000
1001 for (auto &exit_actor : control_actor_set->exit_actors_) {
1002 MS_EXCEPTION_IF_NULL(exit_actor);
1003
1004 auto stack_actor_name = (exit_actor->node_ == nullptr ? GetStackActorNameByExitName(exit_actor->GetAID().Name())
1005 : GetActorName(exit_actor->node_) + kStackActorNameSuffix);
1006 auto actor = FetchActor(stack_actor_name);
1007 if (actor == nullptr) {
1008 for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
1009 LinkArrowbyFormalParameter(exit_actor.get(), exit_actor->formal_parameters_[i], {exit_actor->node_, i},
1010 graph_compiler_info);
1011 }
1012 } else {
1013 // If the control actor has a corresponding stack actor, the input should be linked to the stack actor.
1014 auto stack_actor = dynamic_cast<StackActor *>(actor);
1015 MS_EXCEPTION_IF_NULL(stack_actor);
1016 LinkArrowFromStackActor(stack_actor, exit_actor.get(), graph_compiler_info);
1017 }
1018 }
1019
1020 for (auto &stack_actor : control_actor_set->stack_actors_) {
1021 MS_EXCEPTION_IF_NULL(stack_actor);
1022 for (size_t i = 0; i < stack_actor->formal_parameters_.size(); ++i) {
1023 LinkArrowbyFormalParameter(stack_actor.get(), stack_actor->formal_parameters_[i], {stack_actor->node_, i},
1024 graph_compiler_info);
1025 }
1026 }
1027 }
1028
LinkArrowFromStackActor(StackActor * const stack_actor,ControlActor * const to_actor,const GraphCompilerInfo & graph_compiler_info) const1029 void ControlNodeScheduler::LinkArrowFromStackActor(StackActor *const stack_actor, ControlActor *const to_actor,
1030 const GraphCompilerInfo &graph_compiler_info) const {
1031 MS_EXCEPTION_IF_NULL(stack_actor);
1032 MS_EXCEPTION_IF_NULL(to_actor);
1033 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
1034 const auto &parser = graph_compiler_info.control_node_parser_;
1035 MS_EXCEPTION_IF_NULL(parser);
1036
1037 for (size_t to_index = 0; to_index < to_actor->formal_parameters_.size(); ++to_index) {
1038 const auto &formal_parameter =
1039 common::AnfAlgo::FetchRealNodeSkipMonadControl(to_actor->formal_parameters_[to_index]);
1040 const auto &from_node = formal_parameter.first;
1041 MS_EXCEPTION_IF_NULL(from_node);
1042 if (from_node->isa<ValueNode>()) {
1043 LinkArrowByValueNode(from_node, to_actor, formal_parameter.second, to_index);
1044 continue;
1045 }
1046
1047 // Fetch the arrow type of input.
1048 if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr && from_node->isa<CNode>() &&
1049 (!common::AnfAlgo::IsCallNode(from_node) || IsNotCut(from_node)) &&
1050 (!common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) &&
1051 to_actor->GetAID().Name().find(
1052 parser->FetchGroupNameByKernelGraph(parser->FetchKernelGraphByFrontNode(from_node))) != std::string::npos) {
1053 LinkArrowByKernel(from_node, to_actor, formal_parameter, {to_actor->node_, to_index}, graph_compiler_info);
1054 continue;
1055 }
1056
1057 size_t from_index = stack_actor->FetchNodePosition(formal_parameter);
1058 const auto &abstract = from_node->abstract();
1059 MS_EXCEPTION_IF_NULL(abstract);
1060 const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, formal_parameter.second);
1061 MS_EXCEPTION_IF_NULL(real_abstract);
1062
1063 // Link arrow according to abstract.
1064 if (real_abstract->isa<abstract::AbstractFunction>()) {
1065 SchedulerHelper::AddPartialArrow(stack_actor, to_actor, from_index, to_index);
1066 } else {
1067 SchedulerHelper::AddDataArrow(stack_actor, to_actor, from_index, to_index);
1068 }
1069 }
1070 }
1071
LinkArrowbyFormalParameter(ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1072 void ControlNodeScheduler::LinkArrowbyFormalParameter(ControlActor *const to_actor,
1073 const KernelWithIndex &from_node_with_index,
1074 const KernelWithIndex &to_node_with_index,
1075 const GraphCompilerInfo &graph_compiler_info) const {
1076 const auto &real_from_node_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(from_node_with_index);
1077 const auto &from_node = real_from_node_with_index.first;
1078 MS_EXCEPTION_IF_NULL(from_node);
1079 MS_EXCEPTION_IF_NULL(to_actor);
1080 MS_LOG(DEBUG) << "Link arrow by formal parameter, from node:" << from_node->DebugString()
1081 << " from index:" << real_from_node_with_index.second << " to actor:" << to_actor->GetAID()
1082 << " to index:" << to_node_with_index.second;
1083 if (from_node->isa<ValueNode>()) {
1084 LinkArrowByValueNode(from_node, to_actor, real_from_node_with_index.second, to_node_with_index.second);
1085 } else if (from_node->isa<Parameter>()) {
1086 LinkArrowByParameter(from_node, to_actor, real_from_node_with_index, to_node_with_index,
1087 graph_compiler_info.control_node_parser_);
1088 } else if (common::AnfAlgo::IsCallNode(from_node) && !IsNotCut(from_node)) {
1089 // Link arrow by call node.
1090 LinkArrowByCallNode(from_node, to_actor, real_from_node_with_index, to_node_with_index, graph_compiler_info);
1091 } else if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
1092 common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) {
1093 // Link arrow from switch actor.
1094 const auto &actor_name = GetActorName(from_node);
1095 const auto &actor = FetchActor(actor_name);
1096 MS_EXCEPTION_IF_NULL(actor);
1097 const auto &switch_actor = dynamic_cast<SwitchActor *>(actor);
1098 MS_EXCEPTION_IF_NULL(switch_actor);
1099 if (IsPartialInput(from_node)) {
1100 SchedulerHelper::AddPartialArrow(switch_actor, to_actor, real_from_node_with_index.second,
1101 to_node_with_index.second);
1102 } else {
1103 SchedulerHelper::AddDataArrow(switch_actor, to_actor, real_from_node_with_index.second,
1104 to_node_with_index.second);
1105 }
1106 } else if (common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimPartial)) {
1107 // If the funcgraph of the partial node is a deadnode, in order to ensure the correspondence between formal
1108 // parameters and real parameters, we need to create an empty partial for it.
1109 if (IsInvalidPartial(from_node)) {
1110 MS_LOG(DEBUG) << "Invalid partial node:" << from_node->DebugString();
1111 to_actor->local_partials_[to_node_with_index.second] = std::make_shared<OpPartial>();
1112 return;
1113 }
1114 // Link arrow from gather actor
1115 const auto &actor_name = GetActorName(from_node);
1116 const auto &actor = FetchActor(actor_name);
1117 if (actor == nullptr) {
1118 MS_LOG(DEBUG) << "No actor of " << actor_name;
1119 return;
1120 }
1121 const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1122 MS_EXCEPTION_IF_NULL(gather_actor);
1123 SchedulerHelper::AddPartialArrow(gather_actor, to_actor, real_from_node_with_index.second,
1124 to_node_with_index.second);
1125 } else if (from_node->isa<CNode>()) {
1126 // Link arrow by kernel.
1127 LinkArrowByKernel(from_node, to_actor, real_from_node_with_index, to_node_with_index, graph_compiler_info);
1128 }
1129 }
1130
LinkArrowByValueNode(const AnfNodePtr & value_node,ControlActor * const to_actor,size_t from_index,size_t to_index) const1131 void ControlNodeScheduler::LinkArrowByValueNode(const AnfNodePtr &value_node, ControlActor *const to_actor,
1132 size_t from_index, size_t to_index) const {
1133 MS_EXCEPTION_IF_NULL(value_node);
1134 MS_EXCEPTION_IF_NULL(to_actor);
1135
1136 if (IsValueNode<FuncGraph>(value_node)) {
1137 // Link local partial.
1138 const auto &func_graph = GetValueNode<FuncGraphPtr>(value_node);
1139 MS_EXCEPTION_IF_NULL(func_graph);
1140 MS_LOG(DEBUG) << "Add local partial, graph:" << func_graph->ToString() << " for actor:" << to_actor->GetAID();
1141 to_actor->local_partials_[to_index] = std::make_shared<OpPartial>();
1142 *(to_actor->local_partials_[to_index]) = {func_graph.get(), {}, {}};
1143 } else {
1144 // Link device store value node.
1145 if (!AnfAlgo::OutputAddrExist(value_node, from_index)) {
1146 auto node = value_node->cast<ValueNodePtr>();
1147 MS_EXCEPTION_IF_NULL(node);
1148 auto value = node->value();
1149 MS_EXCEPTION_IF_NULL(value);
1150 // If the from index exceeds the size of the value node, we need to change the from index to 0.
1151 if (!value->isa<ValueTuple>() && from_index > 0) {
1152 from_index = 0;
1153 } else {
1154 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, value_node)
1155 << "#dmsg#Runtime error info:#dmsg#Invalid output address index:" << from_index
1156 << " for value node:" << value_node->DebugString() << " to actor:" << to_actor->GetAID();
1157 }
1158 }
1159 to_actor->local_device_tensors_[to_index] = AnfAlgo::GetMutableOutputAddr(value_node, from_index, false).get();
1160 to_actor->local_device_tensors_[to_index]->SetNodeIndex(value_node, from_index);
1161 MS_LOG(DEBUG) << "Add local device tensor:" << to_actor->local_device_tensors_[to_index] << " index:" << to_index
1162 << " for actor:" << to_actor->GetAID() << " from index:" << from_index;
1163 }
1164 }
1165
LinkArrowByParameter(const AnfNodePtr & parameter,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const ControlNodeParserPtr & parser) const1166 void ControlNodeScheduler::LinkArrowByParameter(const AnfNodePtr ¶meter, ControlActor *const to_actor,
1167 const KernelWithIndex &from_node_with_index,
1168 const KernelWithIndex &to_node_with_index,
1169 const ControlNodeParserPtr &parser) const {
1170 MS_EXCEPTION_IF_NULL(parameter);
1171 MS_EXCEPTION_IF_NULL(to_actor);
1172 MS_EXCEPTION_IF_NULL(parser);
1173 MS_LOG(DEBUG) << "Link arrow by parameter:" << parameter->DebugString() << " indx:" << from_node_with_index.second
1174 << " for actor:" << to_actor->GetAID();
1175 if (parser->IsRootGraphPersistentDeviceTensor(parameter)) {
1176 (void)to_actor->device_tensor_store_keys_.emplace_back(to_node_with_index.second, parameter);
1177 return;
1178 }
1179 // Link arrow from entrance actor.
1180 const auto &func_graph = parameter->func_graph();
1181 MS_EXCEPTION_IF_NULL(func_graph);
1182 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1183 auto actor = FetchActor(actor_name);
1184 MS_EXCEPTION_IF_NULL(actor);
1185
1186 // If the input of the exit actor of the kernel graph is a parameter node, and there is a corresponding stack actor,
1187 // it should be linked to the stack actor.
1188 if (to_actor->type() == KernelTransformType::kExitActor) {
1189 auto stack_actor_name = (to_actor->node_ == nullptr ? GetStackActorNameByExitName(to_actor->GetAID().Name())
1190 : GetActorName(to_actor->node_) + kStackActorNameSuffix);
1191 auto stack_actor = FetchActor(stack_actor_name);
1192 actor = (stack_actor == nullptr ? actor : stack_actor);
1193 }
1194
1195 auto from_actor = dynamic_cast<ControlActor *>(actor);
1196 MS_EXCEPTION_IF_NULL(from_actor);
1197
1198 auto abstract = parameter->abstract();
1199 MS_EXCEPTION_IF_NULL(abstract);
1200 auto dst_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
1201 MS_EXCEPTION_IF_NULL(dst_abstract);
1202 if (dst_abstract->isa<abstract::AbstractFunction>()) {
1203 SchedulerHelper::AddPartialArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
1204 to_node_with_index.second);
1205 } else {
1206 SchedulerHelper::AddDataArrow(from_actor, to_actor, from_actor->FetchNodePosition(from_node_with_index),
1207 to_node_with_index.second);
1208 }
1209 }
1210
LinkArrowByCallNode(const AnfNodePtr & call_node,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1211 void ControlNodeScheduler::LinkArrowByCallNode(const AnfNodePtr &call_node, ControlActor *const to_actor,
1212 const KernelWithIndex &from_node_with_index,
1213 const KernelWithIndex &to_node_with_index,
1214 const GraphCompilerInfo &graph_compiler_info) const {
1215 MS_EXCEPTION_IF_NULL(call_node);
1216 MS_EXCEPTION_IF_NULL(to_actor);
1217 const auto &from_node = from_node_with_index.first;
1218 MS_EXCEPTION_IF_NULL(from_node);
1219 auto parser = graph_compiler_info.control_node_parser_;
1220 MS_EXCEPTION_IF_NULL(parser);
1221
1222 if (to_actor->type_ != KernelTransformType::kEntranceActor) {
1223 // Link arrow from exit actor to control actor.
1224 const auto &abstract = call_node->abstract();
1225 MS_EXCEPTION_IF_NULL(abstract);
1226 const auto &real_abstract = common::AnfAlgo::FetchAbstractByIndex(abstract, from_node_with_index.second);
1227 MS_EXCEPTION_IF_NULL(real_abstract);
1228
1229 std::set<FuncGraphPtr> func_graphs;
1230 try {
1231 func_graphs = parser->FetchFuncGraphbyCallNode(from_node);
1232 } catch (std::exception &e) {
1233 LinkArrowByKernel(call_node, to_actor, from_node_with_index, to_node_with_index, graph_compiler_info);
1234 func_graphs.clear();
1235 }
1236
1237 for (const auto &func_graph : func_graphs) {
1238 MS_EXCEPTION_IF_NULL(func_graph);
1239 const auto &actor_name = func_graph->ToString() + kExitActorNameSuffix;
1240 auto actor = FetchActor(actor_name);
1241 MS_EXCEPTION_IF_NULL(actor);
1242 auto exit_actor = dynamic_cast<ExitActor *>(actor);
1243 MS_EXCEPTION_IF_NULL(exit_actor);
1244 auto branch_id = parser->FetchBranchIDByCallNode(from_node);
1245 if (real_abstract->isa<abstract::AbstractFunction>()) {
1246 SchedulerHelper::AddPartialArrowForExitActor(exit_actor, to_actor, from_node_with_index.second,
1247 to_node_with_index.second, branch_id);
1248 } else {
1249 SchedulerHelper::AddDataArrowForExitActor(exit_actor, to_actor, from_node_with_index.second,
1250 to_node_with_index.second, branch_id);
1251 }
1252 MS_LOG(DEBUG) << "Link data arrow from:" << exit_actor->GetAID() << " index:" << from_node_with_index.second
1253 << " to:" << to_actor->GetAID() << " index" << to_node_with_index.second;
1254 }
1255 if (real_abstract->isa<abstract::AbstractFunction>()) {
1256 to_actor->input_partials_num_++;
1257 } else {
1258 MS_LOG(DEBUG) << "Actor:" << to_actor->GetAID() << " add input num:" << to_actor->input_datas_num_;
1259 to_actor->input_datas_num_++;
1260 }
1261 } else {
1262 // Link arrow from gather actor to entrance actor.
1263 const auto &actor_name = GetActorName(from_node);
1264 const auto &actor = FetchActor(actor_name);
1265 MS_EXCEPTION_IF_NULL(actor);
1266 const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1267 MS_EXCEPTION_IF_NULL(gather_actor);
1268 const auto &to_node = to_node_with_index.first;
1269 MS_EXCEPTION_IF_NULL(to_node);
1270 const auto &func_graph = to_node->func_graph();
1271 MS_EXCEPTION_IF_NULL(func_graph);
1272 SchedulerHelper::AddDataWithBranchIDArrow(gather_actor, dynamic_cast<EntranceActor *>(to_actor), func_graph);
1273 }
1274 }
1275
LinkArrowByKernel(const AnfNodePtr & kernel,ControlActor * const to_actor,const KernelWithIndex & from_node_with_index,const KernelWithIndex & to_node_with_index,const GraphCompilerInfo & graph_compiler_info) const1276 void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlActor *const to_actor,
1277 const KernelWithIndex &from_node_with_index,
1278 const KernelWithIndex &to_node_with_index,
1279 const GraphCompilerInfo &graph_compiler_info) const {
1280 MS_EXCEPTION_IF_NULL(kernel);
1281 MS_EXCEPTION_IF_NULL(to_actor);
1282 const auto &from_node = from_node_with_index.first;
1283 MS_EXCEPTION_IF_NULL(from_node);
1284 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
1285 const auto &parser = graph_compiler_info.control_node_parser_;
1286 MS_EXCEPTION_IF_NULL(parser);
1287 const auto &graph = parser->FetchKernelGraphByFrontNode(from_node);
1288 MS_LOG(DEBUG) << "Link arrow by kernel, from mode:" << from_node->DebugString() << " to actor:" << to_actor->GetAID();
1289 MS_EXCEPTION_IF_NULL(graph);
1290 const auto &group_name = parser->FetchGroupNameByKernelGraph(graph);
1291
1292 if (to_actor->type_ == KernelTransformType::kExitActor && to_actor->node_ == nullptr &&
1293 to_actor->GetAID().Name().find(group_name) != std::string::npos) {
1294 // Link arrow from actor of output node to exit actor of kernel graph.
1295 auto kernel_with_index = parser->FetchBackendNodeByFrontNode(from_node_with_index);
1296 if (kernel_with_index.first == nullptr) {
1297 kernel_with_index = parser->FetchBackendOutputByKernelGraph(graph, from_node_with_index);
1298 if (kernel_with_index.first == nullptr) {
1299 parser->PrintParseInfo();
1300 MS_LOG_WITH_NODE(EXCEPTION, from_node)
1301 << "Failed to get kernel with index by front node:" << from_node->fullname_with_scope()
1302 << " debug string:" << from_node->DebugString() << " index:" << from_node_with_index.second
1303 << " by graph:" << graph->ToString() << " to actor:" << to_actor->GetAID();
1304 }
1305 }
1306 auto type = FetchKernelTransformType(kernel_with_index.first, graph, {});
1307 auto from_actor = FetchActor(type, graph_compiler_info.name_, kernel_with_index.first, graph);
1308 if (from_actor == nullptr) {
1309 parser->PrintParseInfo();
1310 MS_LOG_WITH_NODE(EXCEPTION, from_node)
1311 << "Failed to get from actor by backend node:" << kernel_with_index.first->DebugString()
1312 << " front node : " << from_node->fullname_with_scope() << " debug string:" << from_node->DebugString()
1313 << " index:" << from_node_with_index.second << " by graph:" << graph->ToString()
1314 << " to actor:" << to_actor->GetAID() << " type:" << type;
1315 }
1316 SchedulerHelper::AddDataArrow(from_actor, to_actor, kernel_with_index.second, to_node_with_index.second,
1317 kernel_with_index.first);
1318 } else {
1319 // Link arrow from exit actor of kernel graph to exit actor of function graph.
1320 const auto &actor_name = parser->FetchGroupNameByKernelGraph(graph) + kExitActorNameSuffix;
1321 MS_LOG(DEBUG) << "Actor name:" << actor_name << " from node:" << from_node->DebugString();
1322 auto actor = FetchActor(actor_name);
1323 MS_EXCEPTION_IF_NULL(actor);
1324 auto exit_actor = dynamic_cast<ExitActor *>(actor);
1325 MS_EXCEPTION_IF_NULL(exit_actor);
1326 size_t from_index = exit_actor->FetchNodePosition(from_node_with_index);
1327 SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, to_node_with_index.second);
1328 }
1329 }
1330
LinkControlArrowForControlActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1331 void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set,
1332 const GraphCompilerInfo &graph_compiler_info) const {
1333 MS_EXCEPTION_IF_NULL(actor_set);
1334 auto control_actor_set = actor_set->control_actors_.get();
1335 MS_EXCEPTION_IF_NULL(control_actor_set);
1336 const auto &parser = graph_compiler_info.control_node_parser_;
1337 MS_EXCEPTION_IF_NULL(parser);
1338 LinkControlArrowForEntranceActor(actor_set, graph_compiler_info);
1339
1340 // When the switch actor and gather actor have no input, need to link a control arrow from entrance actor.
1341 std::vector<ControlActor *> need_check_control_actors;
1342 (void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
1343 std::back_inserter(need_check_control_actors),
1344 [](const auto &switch_actor) { return switch_actor.get(); });
1345 (void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
1346 std::back_inserter(need_check_control_actors),
1347 [](const auto &gather_actor) { return gather_actor.get(); });
1348
1349 for (auto control_actor : need_check_control_actors) {
1350 MS_EXCEPTION_IF_NULL(control_actor);
1351 if (IsNoInputActor(control_actor)) {
1352 MS_EXCEPTION_IF_NULL(control_actor->node_);
1353 if (parser->IsNeedStackControlNode(control_actor->node_)) {
1354 const auto &stack_actor_name = GetActorName(control_actor->node_) + kStackActorNameSuffix;
1355 auto actor = FetchActor(stack_actor_name);
1356 MS_EXCEPTION_IF_NULL(actor);
1357 auto to_actor = dynamic_cast<ControlActor *>(actor);
1358 MS_EXCEPTION_IF_NULL(to_actor);
1359 SchedulerHelper::AddControlArrow(to_actor, control_actor);
1360 continue;
1361 }
1362 const FuncGraphPtr &func_graph = control_actor->node_->func_graph();
1363 MS_EXCEPTION_IF_NULL(func_graph);
1364 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1365 const auto &entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
1366 MS_EXCEPTION_IF_NULL(entrance_actor);
1367 SchedulerHelper::AddControlArrow(entrance_actor, control_actor);
1368 }
1369 }
1370
1371 // Link auto monad control arrow for control actor.
1372 std::vector<ControlActor *> control_actors;
1373 (void)std::transform(control_actor_set->switch_actors_.begin(), control_actor_set->switch_actors_.end(),
1374 std::back_inserter(control_actors), [](auto &switch_actor) { return switch_actor.get(); });
1375 (void)std::transform(control_actor_set->gather_actors_.begin(), control_actor_set->gather_actors_.end(),
1376 std::back_inserter(control_actors), [](auto &gather_actor) { return gather_actor.get(); });
1377 (void)std::transform(control_actor_set->exit_actors_.begin(), control_actor_set->exit_actors_.end(),
1378 std::back_inserter(control_actors), [](auto &exit_actor) { return exit_actor.get(); });
1379 for (auto control_actor : control_actors) {
1380 MS_EXCEPTION_IF_NULL(control_actor);
1381 const auto &node = control_actor->node_;
1382 if (node == nullptr) {
1383 continue;
1384 }
1385
1386 auto to_actor = control_actor;
1387 if (parser->IsNeedStackControlNode(node)) {
1388 const auto &stack_actor_name = GetActorName(node) + kStackActorNameSuffix;
1389 auto actor = FetchActor(stack_actor_name);
1390 MS_EXCEPTION_IF_NULL(actor);
1391 to_actor = dynamic_cast<ControlActor *>(actor);
1392 MS_EXCEPTION_IF_NULL(to_actor);
1393 }
1394
1395 const auto &cnode = node->cast<CNodePtr>();
1396 MS_EXCEPTION_IF_NULL(cnode);
1397 const auto &inputs = cnode->inputs();
1398 for (const auto &input : inputs) {
1399 MS_EXCEPTION_IF_NULL(input);
1400 std::vector<AnfNodePtr> monad_nodes = FetchAllMonadNodeByNode(input);
1401 for (const auto &monad_node : monad_nodes) {
1402 MS_EXCEPTION_IF_NULL(monad_node);
1403 LinkControlArrowByAutoMonad(to_actor, monad_node, parser);
1404 }
1405 }
1406 }
1407
1408 // Link copy actor to exit actor.
1409 for (const auto ©_actor : actor_set->copy_actors_) {
1410 MS_EXCEPTION_IF_NULL(copy_actor);
1411 if ((!copy_actor->output_data_arrows_.empty()) || (!copy_actor->output_control_arrows_.empty())) {
1412 continue;
1413 }
1414 KernelGraphPtr kernel_graph = nullptr;
1415 if (copy_actor->from_graph_ != nullptr) {
1416 kernel_graph = copy_actor->from_graph_;
1417 } else if (copy_actor->from_kernel_ != nullptr) {
1418 kernel_graph = std::dynamic_pointer_cast<KernelGraph>(copy_actor->from_kernel_->func_graph());
1419 }
1420 if (kernel_graph == nullptr) {
1421 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid copy actor:" << copy_actor->GetAID().Name();
1422 }
1423 auto exit_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1424 auto exit_actor = FetchActor(exit_actor_name);
1425 MS_EXCEPTION_IF_NULL(exit_actor);
1426 SchedulerHelper::AddControlArrow(copy_actor.get(), exit_actor);
1427 }
1428
1429 LinkControlArrowByKernelGraphGroup(graph_compiler_info);
1430 }
1431
LinkControlArrowForEntranceActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1432 void ControlNodeScheduler::LinkControlArrowForEntranceActor(ActorSet *const actor_set,
1433 const GraphCompilerInfo &graph_compiler_info) const {
1434 MS_EXCEPTION_IF_NULL(actor_set);
1435 auto control_actor_set = actor_set->control_actors_.get();
1436 MS_EXCEPTION_IF_NULL(control_actor_set);
1437 const auto &parser = graph_compiler_info.control_node_parser_;
1438 MS_EXCEPTION_IF_NULL(parser);
1439
1440 // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
1441 // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
1442 // graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
1443 // arrow have been parsed in the control node parser.
1444 for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
1445 // Fetch the entrance actor.
1446 const auto &func_graph = graph_to_nodes.first;
1447 MS_EXCEPTION_IF_NULL(func_graph);
1448 auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1449 auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
1450 MS_EXCEPTION_IF_NULL(entrance_actor);
1451
1452 const auto &nodes = graph_to_nodes.second;
1453 for (const auto &node : nodes) {
1454 // Fetch the source actor of control arrow.
1455 MS_EXCEPTION_IF_NULL(node);
1456 if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
1457 actor_name = func_graph->ToString() + kExitActorNameSuffix;
1458 } else {
1459 actor_name = GetActorName(node);
1460 }
1461 auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
1462 MS_EXCEPTION_IF_NULL(from_actor);
1463 SchedulerHelper::AddLoopBodyControlArrow(from_actor, entrance_actor);
1464 }
1465 }
1466
1467 // In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
1468 // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
1469 // to the entrance actor.
1470 for (const auto &func_graph_to_group_info : parser->func_graph_to_first_kernel_graphs_) {
1471 const auto &func_graph = func_graph_to_group_info.first;
1472 MS_EXCEPTION_IF_NULL(func_graph);
1473 auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1474 auto actor = FetchActor(actor_name);
1475 MS_EXCEPTION_IF_NULL(actor);
1476 auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1477 MS_EXCEPTION_IF_NULL(entrance_actor);
1478 for (const auto &group_info : func_graph_to_group_info.second) {
1479 MS_EXCEPTION_IF_NULL(group_info);
1480 actor_name = group_info->group_name_ + kExitActorNameSuffix;
1481 auto from_actor = FetchActor(actor_name);
1482 MS_EXCEPTION_IF_NULL(from_actor);
1483 SchedulerHelper::AddLoopBodyControlArrow(from_actor, entrance_actor);
1484 }
1485 }
1486 }
1487
LinkControlArrowForLoopCountActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const1488 void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
1489 const GraphCompilerInfo &graph_compiler_info) const {
1490 MS_EXCEPTION_IF_NULL(actor_set);
1491 auto loop_count_actor = actor_set->loop_count_actor_;
1492 MS_EXCEPTION_IF_NULL(loop_count_actor);
1493
1494 // The final output is always sent by the exit of the root graph in control flow.
1495 const auto &parser = graph_compiler_info.control_node_parser_;
1496 MS_EXCEPTION_IF_NULL(parser);
1497 const auto &root_graph = parser->root_func_graph_;
1498 MS_EXCEPTION_IF_NULL(root_graph);
1499 auto exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
1500 auto root_exit_actor = dynamic_cast<ExitActor *>(FetchActor(exit_actor_name));
1501 MS_EXCEPTION_IF_NULL(root_exit_actor);
1502 // link control arrow from root exit actor to loop count actor.
1503 SchedulerHelper::AddControlArrowForExitActor(root_exit_actor, loop_count_actor.get(), kMainBranchID);
1504
1505 // The entrance actor will generate some data in the loop body execution, so need clear on the end of step.
1506 MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
1507 for (auto &entrance_actor : actor_set->control_actors_->entrance_actors_) {
1508 MS_EXCEPTION_IF_NULL(entrance_actor);
1509 (void)loop_count_actor->entrance_aids_.emplace_back(entrance_actor->GetAID());
1510 }
1511 }
1512
LinkOutputControlArrowForActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1513 void ControlNodeScheduler::LinkOutputControlArrowForActor(ActorSet *const actor_set,
1514 const GraphCompilerInfo &graph_compiler_info) const {
1515 MS_EXCEPTION_IF_NULL(actor_set);
1516 const auto &parser = graph_compiler_info.control_node_parser_;
1517 MS_EXCEPTION_IF_NULL(parser);
1518 // Link control arrows from no output kernel actor to the corresponding exit actor.
1519 for (auto &kernel_actor : actor_set->kernel_actors_) {
1520 MS_EXCEPTION_IF_NULL(kernel_actor);
1521 if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1522 (!IsInlineKernelActor(kernel_actor))) {
1523 auto kernel_graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
1524 MS_EXCEPTION_IF_NULL(kernel_graph);
1525 auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1526 auto to_actor = FetchActor(to_actor_name);
1527 MS_EXCEPTION_IF_NULL(to_actor);
1528 SchedulerHelper::AddControlArrow(kernel_actor.get(), to_actor);
1529 }
1530 }
1531
1532 // Link control arrows from no super kernel actor to the corresponding exit actor.
1533 for (auto &super_actor : actor_set->super_kernel_actors_) {
1534 MS_EXCEPTION_IF_NULL(super_actor);
1535 if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
1536 auto kernel_graph = super_actor->graph();
1537 MS_EXCEPTION_IF_NULL(kernel_graph);
1538 auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1539 auto to_actor = FetchActor(to_actor_name);
1540 MS_EXCEPTION_IF_NULL(to_actor);
1541 SchedulerHelper::AddControlArrow(super_actor.get(), to_actor);
1542 }
1543 }
1544
1545 // Link control arrows from no super kernel actor to the corresponding exit actor.
1546 for (auto &any_type_kernel_actor : actor_set->any_type_kernel_actors_) {
1547 MS_EXCEPTION_IF_NULL(any_type_kernel_actor);
1548 if ((any_type_kernel_actor->output_data_arrows_.size() == 0) &&
1549 (any_type_kernel_actor->output_control_arrows_.size() == 0)) {
1550 auto kernel_graph = any_type_kernel_actor->graph();
1551 MS_EXCEPTION_IF_NULL(kernel_graph);
1552 auto to_actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kExitActorNameSuffix;
1553 auto to_actor = FetchActor(to_actor_name);
1554 MS_EXCEPTION_IF_NULL(to_actor);
1555 SchedulerHelper::AddControlArrow(any_type_kernel_actor.get(), to_actor);
1556 }
1557 }
1558 }
1559
LinkControlArrowForKernelActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const1560 void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,
1561 const GraphCompilerInfo &graph_compiler_info) const {
1562 MS_EXCEPTION_IF_NULL(actor_set);
1563 const auto &parser = graph_compiler_info.control_node_parser_;
1564 MS_EXCEPTION_IF_NULL(parser);
1565
1566 // Link control arrow from entrance actors or stack actors to no input kernel actors.
1567 for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
1568 // In control flow, when the input of the kernel actor is a parameter, this input needs to be linked to the
1569 // control actor, so the no-input kernel actor collected in the graph scheduler will also collect this actor,
1570 // and it needs to be skipped here.
1571 MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
1572 // Control arrow for custom actor will be linked in next step.
1573 if ((no_input_kernel_actor->input_datas_num_ != 0) || (no_input_kernel_actor->input_controls_num_ != 0) ||
1574 no_input_kernel_actor->type() == KernelTransformType::kCustomActor ||
1575 IsInlineKernelActor(no_input_kernel_actor)) {
1576 continue;
1577 }
1578
1579 KernelGraphPtr kernel_graph = nullptr;
1580 if (no_input_kernel_actor->type_ == KernelTransformType::kSuperKernelActor) {
1581 const auto &super_kernel_actor = dynamic_cast<SuperKernelActor *>(no_input_kernel_actor.get());
1582 MS_EXCEPTION_IF_NULL(super_kernel_actor);
1583 kernel_graph = super_kernel_actor->graph();
1584 } else if (no_input_kernel_actor->type_ == KernelTransformType::kKernelActor) {
1585 const auto &kernel_actor = dynamic_cast<KernelActor *>(no_input_kernel_actor.get());
1586 MS_EXCEPTION_IF_NULL(kernel_actor);
1587 kernel_graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
1588 } else if (no_input_kernel_actor->type_ == KernelTransformType::kCustomActor) {
1589 const auto &custom_actor = dynamic_cast<CustomActor *>(no_input_kernel_actor.get());
1590 MS_EXCEPTION_IF_NULL(custom_actor);
1591 auto custom_kernel = custom_actor->kernel().lock();
1592 MS_EXCEPTION_IF_NULL(custom_kernel);
1593 const auto &base_node = AnfUtils::GetCustomActorBaseNode(custom_kernel);
1594 MS_EXCEPTION_IF_NULL(base_node);
1595 kernel_graph = AnfAlgo::FetchKernelGraph(base_node.get());
1596 } else {
1597 MS_LOG(EXCEPTION) << "Invalid no input actor: " << no_input_kernel_actor->GetAID().Name();
1598 }
1599 MS_EXCEPTION_IF_NULL(kernel_graph);
1600 auto actor_name = parser->FetchGroupNameByKernelGraph(kernel_graph) + kStackActorNameSuffix;
1601 if (!parser->IsCallInputKernelGraph(kernel_graph.get())) {
1602 const auto &func_graph = parser->FetchFuncGraphByKernelGraph(kernel_graph.get());
1603 MS_EXCEPTION_IF_NULL(func_graph);
1604 actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1605 }
1606
1607 auto from_actor = FetchActor(actor_name);
1608 MS_EXCEPTION_IF_NULL(from_actor);
1609 SchedulerHelper::AddControlArrow(from_actor, no_input_kernel_actor.get());
1610 }
1611 LinkOutputControlArrowForActor(actor_set, graph_compiler_info);
1612 }
1613
LinkControlArrowByAutoMonad(ControlActor * to_actor,const AnfNodePtr & from_node,const ControlNodeParserPtr & parser) const1614 void ControlNodeScheduler::LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
1615 const ControlNodeParserPtr &parser) const {
1616 MS_EXCEPTION_IF_NULL(to_actor);
1617 MS_EXCEPTION_IF_NULL(from_node);
1618 MS_EXCEPTION_IF_NULL(parser);
1619 MS_LOG(DEBUG) << "Link auto monad control arrow from node:" << from_node->DebugString()
1620 << " to actor:" << to_actor->GetAID();
1621
1622 std::set<AnfNodePtr> depend_nodes;
1623 FetchRealDependNodeByAutoMonad(from_node, &depend_nodes);
1624
1625 for (const auto &depend_node : depend_nodes) {
1626 MS_EXCEPTION_IF_NULL(depend_node);
1627 MS_LOG(DEBUG) << "Add depend node:" << depend_node->DebugString() << " for actor:" << to_actor->GetAID();
1628 auto from_actor = FetchActor(GetActorName(depend_node));
1629 auto graph = parser->FetchKernelGraphByFrontNode(depend_node);
1630
1631 std::vector<AbstractActor *> from_actors;
1632 if (common::AnfAlgo::IsCallNode(depend_node) && !IsNotCut(depend_node)) {
1633 // If the actor already exists with control arrow, skip it.
1634 if (IsControlArrowExistForCallNode(depend_node, to_actor, parser)) {
1635 MS_LOG(DEBUG) << "Control arrow from call node:" << depend_node << " to actor:" << to_actor->GetAID()
1636 << "is exist, skip it";
1637 continue;
1638 }
1639 int branch_id = parser->FetchBranchIDByCallNode(depend_node);
1640 const auto &func_graphs = parser->FetchFuncGraphbyCallNode(depend_node);
1641 if (func_graphs.empty()) {
1642 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, depend_node)
1643 << "#dmsg#Runtime error info:#dmsg#Failed to get funcgraph by call node:" << depend_node->DebugString();
1644 }
1645 for (const auto &func_graph : func_graphs) {
1646 MS_EXCEPTION_IF_NULL(func_graph);
1647 auto exit_actor_name = func_graph->ToString() + kExitActorNameSuffix;
1648 from_actor = FetchActor(exit_actor_name);
1649 MS_EXCEPTION_IF_NULL(from_actor);
1650 (void)from_actors.emplace_back(from_actor);
1651 auto exit_actor = dynamic_cast<ExitActor *>(from_actor);
1652 MS_EXCEPTION_IF_NULL(exit_actor);
1653 SchedulerHelper::AddControlArrowForExitActor(exit_actor, to_actor, branch_id);
1654 }
1655 to_actor->input_controls_num_ -= (func_graphs.size() - 1);
1656 } else if (from_actor != nullptr) {
1657 (void)from_actors.emplace_back(from_actor);
1658 SchedulerHelper::AddControlArrow(from_actor, to_actor);
1659 } else {
1660 if (graph == nullptr) {
1661 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, depend_node)
1662 << "#dmsg#Runtime error info:#dmsg#Failed to find actor for node:" << depend_node->DebugString();
1663 }
1664 from_actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kExitActorNameSuffix);
1665 MS_EXCEPTION_IF_NULL(from_actor);
1666 if (std::find_if(from_actor->output_control_arrows_.begin(), from_actor->output_control_arrows_.end(),
1667 [&to_actor](auto &output_control_arrow) {
1668 MS_EXCEPTION_IF_NULL(output_control_arrow);
1669 return output_control_arrow->to_op_id_.Name() == to_actor->GetAID().Name();
1670 }) != from_actor->output_control_arrows_.end()) {
1671 MS_LOG(DEBUG) << "Link auto monad control from actor:" << from_actor->GetAID()
1672 << " to actor:" << to_actor->GetAID() << " is already exist.";
1673 continue;
1674 }
1675 (void)from_actors.emplace_back(from_actor);
1676 SchedulerHelper::AddControlArrow(from_actor, to_actor);
1677 }
1678 if (to_actor->type_ != KernelTransformType::kStackActor || parser->IsNeedStackControlNode(depend_node) ||
1679 parser->IsRecursionCallNode(depend_node) || (graph != nullptr && parser->IsRecursionKernelGraph(graph))) {
1680 continue;
1681 }
1682 // If the control arrow comes from a recursive call node or a recursive kernel graph, these control edges will be
1683 // directly linked to the stack actor, otherwise, they need to be cached in the stack of the stack actor.
1684 auto stack_actor = dynamic_cast<StackActor *>(to_actor);
1685 MS_EXCEPTION_IF_NULL(stack_actor);
1686 stack_actor->input_controls_num_--;
1687 stack_actor->input_stack_controls_num_++;
1688 for (const auto &actor : from_actors) {
1689 MS_EXCEPTION_IF_NULL(actor);
1690 MS_LOG(DEBUG) << "Add stack control aid:" << actor->GetAID() << " for actor:" << stack_actor->GetAID();
1691 (void)stack_actor->stack_control_aids_.emplace(actor->GetAID());
1692 stack_actor->control_aid_to_indexs_[actor->GetAID()] = stack_actor->input_stack_controls_num_;
1693 }
1694 }
1695 MS_LOG(DEBUG) << "Link auto monad control arrow from node:" << from_node->DebugString()
1696 << " to actor:" << to_actor->GetAID() << " end";
1697 }
1698
LinkControlArrowByKernelGraphGroup(const GraphCompilerInfo & graph_compiler_info) const1699 void ControlNodeScheduler::LinkControlArrowByKernelGraphGroup(const GraphCompilerInfo &graph_compiler_info) const {
1700 const auto &parser = graph_compiler_info.control_node_parser_;
1701 MS_EXCEPTION_IF_NULL(parser);
1702
1703 for (const auto &graph_group : parser->kernel_graph_group_infos_) {
1704 MS_EXCEPTION_IF_NULL(graph_group);
1705 if (!graph_group->need_stack_) {
1706 continue;
1707 }
1708 auto stack_actor = FetchActor(graph_group->group_name_ + kStackActorNameSuffix);
1709 MS_EXCEPTION_IF_NULL(stack_actor);
1710 auto to_actor = dynamic_cast<ControlActor *>(stack_actor);
1711 MS_EXCEPTION_IF_NULL(to_actor);
1712 for (const auto &monad_input : graph_group->monad_inputs_) {
1713 MS_EXCEPTION_IF_NULL(monad_input);
1714 MS_LOG(DEBUG) << "Add monad control arrow for group:" << graph_group->group_name_
1715 << " to actor:" << to_actor->GetAID() << " by monad input:" << monad_input->DebugString();
1716 LinkControlArrowByAutoMonad(to_actor, monad_input, parser);
1717 }
1718 }
1719 }
1720
LinkBranchIDArrowForControlActor(ControlActorSet * const control_actor_set) const1721 void ControlNodeScheduler::LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set) const {
1722 MS_EXCEPTION_IF_NULL(control_actor_set);
1723
1724 // Connect the branch id arrows from the entrance actor to the exit actor for each funcgraph.
1725 for (auto exit_actor : control_actor_set->exit_actors_) {
1726 MS_EXCEPTION_IF_NULL(exit_actor);
1727
1728 // If the node in the exit actor is empty, it means that it is between the kernel actor and the control actor,
1729 // and no need to send the branch id.
1730 const auto &node = exit_actor->node_;
1731 if (node == nullptr) {
1732 continue;
1733 }
1734
1735 const auto &func_graph = node->func_graph();
1736 MS_EXCEPTION_IF_NULL(func_graph);
1737 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1738 auto actor = FetchActor(actor_name);
1739 MS_EXCEPTION_IF_NULL(actor);
1740 auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1741 MS_EXCEPTION_IF_NULL(entrance_actor);
1742 SchedulerHelper::AddBranchIDArrow(entrance_actor, exit_actor.get());
1743 }
1744
1745 // Connect the branch id arrows from the entrance actor to the stack actor.
1746 for (auto stack_actor : control_actor_set->stack_actors_) {
1747 MS_EXCEPTION_IF_NULL(stack_actor);
1748 auto node = stack_actor->node_;
1749 if (!stack_actor->formal_parameters_.empty()) {
1750 node = stack_actor->formal_parameters_.back().first;
1751 } else {
1752 MS_LOG(INFO) << "No formal parameter for stack actor:" << stack_actor->GetAID();
1753 }
1754 MS_EXCEPTION_IF_NULL(node);
1755 const auto &func_graph = node->func_graph();
1756 MS_EXCEPTION_IF_NULL(func_graph);
1757 const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1758 auto actor = FetchActor(actor_name);
1759 MS_EXCEPTION_IF_NULL(actor);
1760 auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
1761 MS_EXCEPTION_IF_NULL(entrance_actor);
1762 SchedulerHelper::AddBranchIDArrow(entrance_actor, stack_actor.get());
1763 }
1764 }
1765
LinkDataArrowForKernelActor(const GraphCompilerInfo & graph_compiler_info) const1766 void ControlNodeScheduler::LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info) const {
1767 const auto &parser = graph_compiler_info.control_node_parser_;
1768 MS_EXCEPTION_IF_NULL(parser);
1769
1770 // Link data arrows from entrance actors and stack actors to kernel actors.
1771 for (const auto &func_graph_to_kernel_graphs : parser->func_graph_to_kernel_graph_groups_) {
1772 // Fetch the source entrance actor.
1773 const auto &func_graph = func_graph_to_kernel_graphs.first;
1774 MS_EXCEPTION_IF_NULL(func_graph);
1775 auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
1776 auto actor = FetchActor(actor_name);
1777 MS_EXCEPTION_IF_NULL(actor);
1778 auto entrance_actor = dynamic_cast<ControlActor *>(actor);
1779 MS_EXCEPTION_IF_NULL(entrance_actor);
1780
1781 for (const auto &kernel_graph_group : func_graph_to_kernel_graphs.second) {
1782 for (const auto &kernel_graph : kernel_graph_group) {
1783 MS_EXCEPTION_IF_NULL(kernel_graph);
1784 if (kernel_graph->execution_order().empty()) {
1785 continue;
1786 }
1787 LinkDataArrowByKernelGraph(kernel_graph, entrance_actor, parser);
1788 }
1789 }
1790 }
1791 }
1792
LinkDataArrowForCustomActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const1793 void ControlNodeScheduler::LinkDataArrowForCustomActor(const ActorSet *actor_set,
1794 const GraphCompilerInfo &graph_compiler_info) const {
1795 MS_EXCEPTION_IF_NULL(actor_set);
1796 const auto &parser = graph_compiler_info.control_node_parser_;
1797 MS_EXCEPTION_IF_NULL(parser);
1798
1799 for (const auto &custom_actor : actor_set->custom_actors_) {
1800 MS_EXCEPTION_IF_NULL(custom_actor);
1801 auto kernel = custom_actor->kernel().lock();
1802 MS_EXCEPTION_IF_NULL(kernel);
1803 if (AnfUtils::GetCustomActorType(kernel) != kInfer) {
1804 continue;
1805 }
1806 // Kernel in depends form map should link data arrow for infer shape.
1807 auto base_node = AnfUtils::GetCustomActorBaseNode(kernel);
1808 MS_EXCEPTION_IF_NULL(base_node);
1809 auto dynamic_shape_depends = abstract::GetValueDependArgIndices(base_node);
1810 for (auto iter = dynamic_shape_depends.begin(); iter != dynamic_shape_depends.end(); ++iter) {
1811 auto input_node = common::AnfAlgo::GetInputNode(base_node, LongToSize(*iter));
1812 MS_EXCEPTION_IF_NULL(input_node);
1813 KernelWithIndex from_kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1814 const AnfNodePtr real_input_node = from_kernel_with_index.first;
1815 MS_EXCEPTION_IF_NULL(real_input_node);
1816 if (real_input_node->isa<ValueNode>()) {
1817 continue;
1818 }
1819 auto graph = AnfAlgo::FetchKernelGraph(real_input_node.get());
1820 MS_EXCEPTION_IF_NULL(graph);
1821 if (!parser->IsControlFlowDataArrow(graph, real_input_node)) {
1822 continue;
1823 }
1824
1825 // Link data arrow from entrance actor or stack actor to infer shape custom actor.
1826 const auto &front_node_with_index = GetFrontNodeByKernelGraph(real_input_node, graph.get());
1827 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1828 AbstractActor *from_base_actor = nullptr;
1829 if (parser->IsCallInputKernelGraph(graph.get())) {
1830 from_base_actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kStackActorNameSuffix);
1831 MS_EXCEPTION_IF_NULL(from_base_actor);
1832 } else if (!front_node_with_index.first->isa<Parameter>()) {
1833 MS_LOG(INFO) << "Internal front node:" << front_node_with_index.first->DebugString()
1834 << " index:" << front_node_with_index.second << " for custom actor:" << custom_actor->GetAID()
1835 << " kernel:" << kernel->fullname_with_scope() << " input index:" << *iter;
1836 const auto &from_graph = parser->FetchKernelGraphByFrontNode(front_node_with_index.first);
1837 MS_EXCEPTION_IF_NULL(from_graph);
1838 from_base_actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1839 MS_EXCEPTION_IF_NULL(from_base_actor);
1840 } else {
1841 const auto &func_graph = front_node_with_index.first->func_graph();
1842 MS_EXCEPTION_IF_NULL(func_graph);
1843 from_base_actor = FetchActor(func_graph->ToString() + kEntranceActorNameSuffix);
1844 MS_EXCEPTION_IF_NULL(from_base_actor);
1845 }
1846 const auto &from_actor = dynamic_cast<ControlActor *>(from_base_actor);
1847 MS_EXCEPTION_IF_NULL(from_actor);
1848 size_t from_index = from_actor->FetchNodePosition(front_node_with_index);
1849 MS_LOG(DEBUG) << "Link data arrow from actor:" << from_actor->GetAID()
1850 << " to custom actor:" << custom_actor->GetAID();
1851 SchedulerHelper::AddDataArrow(from_actor, custom_actor.get(), from_index, LongToSize(*iter));
1852 }
1853 }
1854 }
1855
LinkDataArrowByKernelGraphInSinkMode(const KernelGraphPtr & graph,ControlActor * const from_actor,const ControlNodeParserPtr & parser) const1856 void ControlNodeScheduler::LinkDataArrowByKernelGraphInSinkMode(const KernelGraphPtr &graph,
1857 ControlActor *const from_actor,
1858 const ControlNodeParserPtr &parser) const {
1859 MS_EXCEPTION_IF_NULL(graph);
1860 MS_EXCEPTION_IF_NULL(from_actor);
1861 MS_EXCEPTION_IF_NULL(parser);
1862 MS_LOG(DEBUG) << "Link data arrow in sink mode by kernel graph:" << graph->ToString();
1863 auto to_actor = FetchActor(
1864 graph->is_any_type_input() ? KernelTransformType::kAnyTypeKernelActor : KernelTransformType::kSuperKernelActor, "",
1865 nullptr, graph);
1866 MS_EXCEPTION_IF_NULL(to_actor);
1867 auto super_kernel_actor = dynamic_cast<SuperKernelActor *>(to_actor);
1868 MS_EXCEPTION_IF_NULL(super_kernel_actor);
1869
1870 auto &input_nodes = graph->input_nodes();
1871 for (size_t i = 0; i < input_nodes.size(); ++i) {
1872 const auto &input_node = input_nodes[i];
1873 MS_EXCEPTION_IF_NULL(input_node);
1874 if (HasAbstractMonad(input_node) || (!parser->IsControlFlowDataArrow(graph, input_node))) {
1875 continue;
1876 }
1877 size_t to_index = super_kernel_actor->FetchInputNodePosition(input_node);
1878 const auto &front_node_with_index = GetFrontNodeByKernelGraph(input_node, graph.get());
1879 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1880 if (front_node_with_index.first->isa<ValueNode>()) {
1881 continue;
1882 }
1883 if (front_node_with_index.first->isa<CNode>() && (from_actor->type() != KernelTransformType::kStackActor)) {
1884 // If the input is an internal parameter, the input arrow should be linked to the exit actor of the kernel
1885 // graph which the internal parameter belong.
1886 MS_LOG(INFO) << "Internal parameter in control flow, backend input:" << input_node->DebugString()
1887 << " front node:" << front_node_with_index.first->DebugString();
1888 const auto &from_graph = parser->FetchKernelGraphByFrontNode(front_node_with_index.first);
1889 MS_EXCEPTION_IF_NULL(from_graph);
1890 auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1891 MS_EXCEPTION_IF_NULL(actor);
1892 auto exit_actor = dynamic_cast<ControlActor *>(actor);
1893 MS_EXCEPTION_IF_NULL(exit_actor);
1894 size_t from_index = exit_actor->FetchNodePosition(front_node_with_index);
1895 SchedulerHelper::AddFormalParameterDeviceTensor(exit_actor, from_index, input_node, graph);
1896 SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, i);
1897 continue;
1898 }
1899 size_t from_index = from_actor->FetchNodePosition(front_node_with_index);
1900 SchedulerHelper::AddFormalParameterDeviceTensor(from_actor, from_index, input_node, graph);
1901 SchedulerHelper::AddDataArrow(from_actor, to_actor, from_index, to_index);
1902 }
1903 return;
1904 }
1905
LinkDataArrowByKernelGraph(const KernelGraphPtr & graph,ControlActor * const entrance_actor,const ControlNodeParserPtr & parser) const1906 void ControlNodeScheduler::LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, ControlActor *const entrance_actor,
1907 const ControlNodeParserPtr &parser) const {
1908 MS_EXCEPTION_IF_NULL(graph);
1909 MS_EXCEPTION_IF_NULL(parser);
1910 MS_LOG(DEBUG) << "Link data arrow by kernel graph:" << graph->ToString();
1911 auto from_actor = entrance_actor;
1912 // If there is a call node in the input of the graph, the parameter of the graph needs to be sent by the
1913 // corresponding stack actor, otherwise it is sent by the entrance actor.
1914 if (parser->IsCallInputKernelGraph(graph.get())) {
1915 auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(graph) + kStackActorNameSuffix);
1916 MS_EXCEPTION_IF_NULL(actor);
1917 from_actor = dynamic_cast<ControlActor *>(actor);
1918 }
1919
1920 if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
1921 // Link data arrow in graph mode.
1922 LinkDataArrowByKernelGraphInSinkMode(graph, from_actor, parser);
1923 return;
1924 }
1925
1926 auto &execution_order = graph->execution_order();
1927 for (const auto &kernel : execution_order) {
1928 MS_EXCEPTION_IF_NULL(kernel);
1929 if ((!graph->is_graph_run_mode()) && (IsSkippedKernelActor(kernel) || !IsKernelActor(kernel))) {
1930 continue;
1931 }
1932 for (size_t i = 0; i < common::AnfAlgo::GetInputNum(kernel); ++i) {
1933 auto input_node = common::AnfAlgo::GetInputNode(kernel, i);
1934 MS_EXCEPTION_IF_NULL(input_node);
1935 auto input_with_index = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
1936 auto input = input_with_index.first;
1937 MS_EXCEPTION_IF_NULL(input);
1938 if (HasAbstractMonad(input) || (!parser->IsControlFlowDataArrow(graph, input))) {
1939 continue;
1940 }
1941
1942 auto from_node_with_index = GetFrontNodeByKernelGraph(input, graph.get());
1943 MS_EXCEPTION_IF_NULL(from_node_with_index.first);
1944 if (common::AnfAlgo::CheckPrimitiveType(from_node_with_index.first, prim::kPrimTupleGetItem) &&
1945 (!from_node_with_index.first->cast<CNodePtr>()->HasAttr(kAttrReplaceRealKernelInBackend)) &&
1946 (!common::AnfAlgo::IsTupleOutput(from_node_with_index.first))) {
1947 MS_LOG(WARNING) << "Input node:" << from_node_with_index.first->DebugString()
1948 << " for graph:" << graph->ToString() << " is a tuple get item";
1949 from_node_with_index = FetchRealNodeByGetItem(from_node_with_index);
1950 }
1951
1952 // If the formal parameter is a tuple type, the parameter of the kernel graph will not directly correspond
1953 // to the front parameter, but the node in the internal parameter.
1954 const auto &from_node = from_node_with_index.first;
1955 MS_EXCEPTION_IF_NULL(from_node);
1956 MS_LOG(DEBUG) << "Graph:" << graph->ToString() << " from node:" << from_node_with_index.first->DebugString()
1957 << " index:" << from_node_with_index.second;
1958
1959 // Fetch actor and link.
1960 auto type = FetchKernelTransformType(kernel, graph, {});
1961 auto to_actor = FetchActor(type, "", kernel, graph);
1962 MS_EXCEPTION_IF_NULL(to_actor);
1963 size_t from_index = 0;
1964 // If the input is a switch node and the graph does not need a stack, then the data arrow needs to be connected
1965 // from the switch actor.
1966 if ((common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitch) ||
1967 common::AnfAlgo::CheckPrimitiveType(from_node, prim::kPrimSwitchLayer)) &&
1968 (from_actor->type() != KernelTransformType::kStackActor)) {
1969 const auto &actor_name = GetActorName(from_node);
1970 auto actor = FetchActor(actor_name);
1971 MS_EXCEPTION_IF_NULL(actor);
1972 from_actor = dynamic_cast<ControlActor *>(actor);
1973 } else if (from_node->isa<CNode>() && (from_actor->type() != KernelTransformType::kStackActor)) {
1974 // If the input is an internal parameter, the input arrow should be linked to the exit actor of the kernel
1975 // graph which the internal parameter belong.
1976 MS_LOG(INFO) << "Internal parameter in control flow, backend input:" << input->DebugString()
1977 << " front node:" << from_node->DebugString();
1978 const auto &from_graph = parser->FetchKernelGraphByFrontNode(from_node);
1979 MS_EXCEPTION_IF_NULL(from_graph);
1980 auto actor = FetchActor(parser->FetchGroupNameByKernelGraph(from_graph) + kExitActorNameSuffix);
1981 MS_EXCEPTION_IF_NULL(actor);
1982 auto exit_actor = dynamic_cast<ControlActor *>(actor);
1983 from_index = exit_actor->FetchNodePosition(from_node_with_index);
1984 SchedulerHelper::AddFormalParameterDeviceTensor(exit_actor, from_index, input, graph);
1985 SchedulerHelper::AddDataArrow(exit_actor, to_actor, from_index, i);
1986 continue;
1987 } else {
1988 MS_LOG(DEBUG) << "Fetch node:" << from_node_with_index.first->DebugString()
1989 << " index:" << from_node_with_index.second << " from actor:" << from_actor->GetAID();
1990 from_index = from_actor->FetchNodePosition(from_node_with_index);
1991 }
1992
1993 MS_EXCEPTION_IF_NULL(from_actor);
1994 SchedulerHelper::AddFormalParameterDeviceTensor(from_actor, from_index, input, graph);
1995 SchedulerHelper::AddDataArrow(from_actor, to_actor, from_index, i);
1996 }
1997 }
1998 }
1999
LinkDataArrowForOutputActor(ActorSet * const actor_set,const GraphCompilerInfo & graph_compiler_info) const2000 void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set,
2001 const GraphCompilerInfo &graph_compiler_info) const {
2002 MS_EXCEPTION_IF_NULL(actor_set);
2003 auto &to_actor = actor_set->output_actor_;
2004 MS_EXCEPTION_IF_NULL(to_actor);
2005 const auto &parser = graph_compiler_info.control_node_parser_;
2006 MS_EXCEPTION_IF_NULL(parser);
2007 const auto &root_graph = parser->root_func_graph_;
2008 MS_EXCEPTION_IF_NULL(root_graph);
2009 const auto &return_node = root_graph->return_node();
2010 MS_EXCEPTION_IF_NULL(return_node);
2011
2012 const auto &exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
2013 auto actor = FetchActor(exit_actor_name);
2014 MS_EXCEPTION_IF_NULL(actor);
2015 auto exit_actor = dynamic_cast<ExitActor *>(actor);
2016 MS_EXCEPTION_IF_NULL(exit_actor);
2017 for (size_t i = 0; i < exit_actor->formal_parameters_.size(); ++i) {
2018 SchedulerHelper::AddDataArrowForExitActor(exit_actor, to_actor.get(), i, i, 0);
2019 to_actor->input_datas_num_++;
2020 }
2021
2022 auto control_node_to_device_contexts = parser->control_node_to_device_contexts_;
2023 auto iter = control_node_to_device_contexts.find(return_node);
2024 if (iter == control_node_to_device_contexts.end()) {
2025 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, return_node)
2026 << "#dmsg#Runtime error info:#dmsg#Failed to find device contexts for node:" << return_node->DebugString();
2027 }
2028 if (iter->second.size() != to_actor->device_contexts().size()) {
2029 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid context size, need:"
2030 << to_actor->device_contexts().size() << " current:" << iter->second.size();
2031 }
2032 to_actor->device_contexts_ = iter->second;
2033 }
2034
LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo & graph_compiler_info) const2035 void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info) const {
2036 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
2037 const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_;
2038 MS_EXCEPTION_IF_NULL(root_graph);
2039 const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix;
2040 auto to_actor = dynamic_cast<EntranceActor *>(FetchActor(entrance_actor_name));
2041 MS_EXCEPTION_IF_NULL(to_actor);
2042
2043 const auto &host_ds_actor_name = graph_compiler_info.name_ + kHostDSActorNameSuffix;
2044 auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(host_ds_actor_name));
2045 // No host data source actor scenario.
2046 if (host_ds_actor == nullptr) {
2047 const auto &data_prepare_actor_name = graph_compiler_info.name_ + kDataPrepareActorNameSuffix;
2048 auto data_prepare_actor = FetchActor(data_prepare_actor_name);
2049 MS_EXCEPTION_IF_NULL(data_prepare_actor);
2050 SchedulerHelper::AddControlArrow(data_prepare_actor, to_actor);
2051 return;
2052 }
2053
2054 // The host data source actor sends all the input to the entrance actor of the root graph.
2055 for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) {
2056 const auto &formal_parameter = to_actor->formal_parameters_[i];
2057 MS_EXCEPTION_IF_NULL(formal_parameter.first);
2058 MS_LOG(DEBUG) << "Formal parameter:" << formal_parameter.first->DebugString()
2059 << " index:" << formal_parameter.second;
2060 const auto &iter = host_ds_actor->data_node_position_map_.find(formal_parameter);
2061 if (iter != host_ds_actor->data_node_position_map_.end()) {
2062 const auto ¶meter_with_index = host_ds_actor->data_nodes()[iter->second];
2063 SchedulerHelper::AddDataArrow(host_ds_actor, to_actor, parameter_with_index.second, i,
2064 parameter_with_index.first);
2065 } else {
2066 MS_LOG(INFO) << "Invalid formal parameter:" << formal_parameter.first->DebugString()
2067 << " index:" << formal_parameter.second << " for actor:" << to_actor->GetAID();
2068 }
2069 }
2070 }
2071
SetTimeSummaryForControlActor(const GraphCompilerInfo & graph_compiler_info) const2072 void ControlNodeScheduler::SetTimeSummaryForControlActor(const GraphCompilerInfo &graph_compiler_info) const {
2073 const auto &parser = graph_compiler_info.control_node_parser_;
2074 MS_EXCEPTION_IF_NULL(parser);
2075
2076 for (const auto &kernel_graph_group_info : parser->kernel_graph_group_infos_) {
2077 MS_EXCEPTION_IF_NULL(kernel_graph_group_info);
2078 const auto &exit_actor_name = kernel_graph_group_info->group_name_ + kExitActorNameSuffix;
2079 const auto &exit_base_actor = FetchActor(exit_actor_name);
2080 if (exit_base_actor == nullptr) {
2081 continue;
2082 }
2083 const auto &exit_actor = dynamic_cast<ControlActor *>(exit_base_actor);
2084 MS_EXCEPTION_IF_NULL(exit_actor);
2085
2086 // Set the exit actor of kernel graph to its entrance actor or stack actor.
2087 if (kernel_graph_group_info->need_stack_ == false) {
2088 if (kernel_graph_group_info->graphs_.empty()) {
2089 continue;
2090 }
2091 const auto &graph = *(kernel_graph_group_info->graphs_.begin());
2092 const auto &func_graph = parser->FetchFuncGraphByKernelGraph(graph.get());
2093 MS_EXCEPTION_IF_NULL(func_graph);
2094 auto entrance_base_actor = FetchActor(func_graph->ToString() + kEntranceActorNameSuffix);
2095 if (entrance_base_actor != nullptr) {
2096 const auto &entrance_actor = dynamic_cast<ControlActor *>(entrance_base_actor);
2097 MS_EXCEPTION_IF_NULL(entrance_actor);
2098 (void)entrance_actor->end_actors_.emplace(exit_actor);
2099 MS_LOG(DEBUG) << "Add time summart for exit actor:" << exit_actor->GetAID()
2100 << " to actor:" << entrance_actor->GetAID();
2101 }
2102 continue;
2103 }
2104
2105 auto stack_base_actor = FetchActor(kernel_graph_group_info->group_name_ + kStackActorNameSuffix);
2106 if (stack_base_actor != nullptr) {
2107 const auto &stack_actor = dynamic_cast<ControlActor *>(stack_base_actor);
2108 MS_EXCEPTION_IF_NULL(stack_actor);
2109 (void)stack_actor->end_actors_.emplace(exit_actor);
2110 MS_LOG(DEBUG) << "Add time summart for exit actor:" << exit_actor->GetAID()
2111 << " to actor:" << stack_actor->GetAID();
2112 }
2113 }
2114 }
2115
IsNoInputActor(const ControlActor * control_actor) const2116 bool ControlNodeScheduler::IsNoInputActor(const ControlActor *control_actor) const {
2117 MS_EXCEPTION_IF_NULL(control_actor);
2118 return (control_actor->input_datas_num_ == 0 && control_actor->input_controls_num_ == 0 &&
2119 control_actor->input_partials_num_ == 0 && control_actor->input_branch_ids_num_ == 0);
2120 }
2121 } // namespace runtime
2122 } // namespace mindspore
2123