• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/framework/graph_scheduler.h"
18 #include "runtime/framework/actor/memory_manager_actor.h"
19 #include "runtime/framework/actor/debug_actor.h"
20 #include "runtime/framework/actor/recorder_actor.h"
21 #include "runtime/hardware/device_context_manager.h"
22 #include "mindrt/src/actor/actormgr.h"
23 #include "mindrt/include/async/async.h"
24 #include "backend/session/anf_runtime_algorithm.h"
25 #include "backend/optimizer/common/helper.h"
26 #include "utils/config_manager.h"
27 #include "utils/log_adapter.h"
28 #include "utils/convert_utils.h"
29 #include "utils/ms_context.h"
30 #if !defined(_WIN32) && !defined(_WIN64)
31 #include "utils/signal_util.h"
32 #endif
33 #ifndef ENABLE_SECURITY
34 #include "debug/data_dump/dump_json_parser.h"
35 #endif
36 #ifdef ENABLE_DUMP_IR
37 #include "debug/rdr/recorder_manager.h"
38 #endif
39 #ifdef ENABLE_DEBUGGER
40 #include "debug/debugger/debugger.h"
41 #endif
42 #include "profiler/device/profiling.h"
43 #include "debug/common.h"
44 
45 namespace mindspore {
46 namespace runtime {
47 namespace {
IsNeedInsertCopyActor(const DeviceContext * from_device_context,const DeviceContext * to_device_context)48 bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
49   MS_EXCEPTION_IF_NULL(from_device_context);
50   MS_EXCEPTION_IF_NULL(to_device_context);
51 
52   if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
53     return false;
54   } else {
55     return true;
56   }
57 }
58 
IsSingleOpActorSet(const ActorSet * actor_set)59 inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
60   MS_EXCEPTION_IF_NULL(actor_set);
61   return actor_set->kernel_actors_.size() == 1;
62 }
63 
64 // Convert the actors vector by the actor set.
CollectActors(const ActorSet * actor_set)65 std::vector<ActorReference> CollectActors(const ActorSet *actor_set) {
66   MS_EXCEPTION_IF_NULL(actor_set);
67   std::vector<ActorReference> actors;
68 
69   if (actor_set->data_prepare_actor_ != nullptr) {
70     (void)actors.emplace_back(static_cast<ActorReference>(actor_set->data_prepare_actor_));
71   }
72   for (auto &data_source_actor : actor_set->data_source_actors_) {
73     MS_EXCEPTION_IF_NULL(data_source_actor);
74     (void)actors.emplace_back(static_cast<ActorReference>(data_source_actor));
75   }
76   for (auto &kernel_actor : actor_set->kernel_actors_) {
77     MS_EXCEPTION_IF_NULL(kernel_actor);
78     (void)actors.emplace_back(static_cast<ActorReference>(kernel_actor));
79   }
80   for (auto &switch_actor : actor_set->switch_actors_) {
81     MS_EXCEPTION_IF_NULL(switch_actor);
82     (void)actors.emplace_back(static_cast<ActorReference>(switch_actor));
83   }
84   for (auto &gather_actor : actor_set->gather_actors_) {
85     MS_EXCEPTION_IF_NULL(gather_actor);
86     (void)actors.emplace_back(static_cast<ActorReference>(gather_actor));
87   }
88   for (auto &copy_actor : actor_set->copy_actors_) {
89     MS_EXCEPTION_IF_NULL(copy_actor);
90     (void)actors.emplace_back(static_cast<ActorReference>(copy_actor));
91   }
92   if (actor_set->loop_count_actor_ != nullptr) {
93     (void)actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
94   }
95   if (actor_set->output_actor_ != nullptr) {
96     (void)actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
97   }
98 
99   return actors;
100 }
101 
ClearNodeInfo(const KernelGraphPtr & graph)102 void ClearNodeInfo(const KernelGraphPtr &graph) {
103   MS_EXCEPTION_IF_NULL(graph);
104 
105   // Clear input parameter device tensor and device tensor store.
106   for (const auto &input_node : graph->input_nodes()) {
107     MS_EXCEPTION_IF_NULL(input_node);
108     if (!input_node->isa<Parameter>()) {
109       continue;
110     }
111     auto parameter = input_node->cast<ParameterPtr>();
112     MS_EXCEPTION_IF_NULL(parameter);
113     parameter->DecreaseUsedGraphCount();
114     // Only the parameter has no graph used, then clear the device tensor.
115     if (parameter->used_graph_count() != 0) {
116       continue;
117     }
118     auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
119     DeviceTensorStore::GetInstance().Remove(front_input_node.get());
120     size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
121     for (size_t index = 0; index < output_num; ++index) {
122       if (AnfAlgo::OutputAddrExist(input_node, index)) {
123         AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
124       }
125     }
126   }
127 
128   // Clear input value node device tensor and device tensor store.
129   for (const auto &value_node : graph->graph_value_nodes()) {
130     auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
131     DeviceTensorStore::GetInstance().Remove(front_value_node.get());
132     if (AnfAlgo::OutputAddrExist(value_node, 0)) {
133       AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
134     }
135   }
136 
137   // Clear cnode device tensor.
138   for (const auto &cnode : graph->execution_order()) {
139     size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
140     for (size_t index = 0; index < output_num; ++index) {
141       if (AnfAlgo::OutputAddrExist(cnode, index)) {
142         AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
143       }
144     }
145   }
146 }
147 
148 #if !defined(_WIN32) && !defined(_WIN64)
IntHandler(int,siginfo_t *,void *)149 void IntHandler(int, siginfo_t *, void *) {
150   int this_pid = getpid();
151   MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
152   (void)kill(this_pid, SIGTERM);
153 }
154 #endif
155 }  // namespace
156 
Clear(const ActorInfo & actor_info,const std::vector<KernelGraphPtr> & graphs)157 void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept {
158   // Terminate the actors of actor info.
159   if (actors_.count(actor_info) > 0) {
160     auto actor_manager = ActorMgr::GetActorMgrRef();
161     if (actor_manager == nullptr) {
162       MS_LOG(ERROR) << "Actor manager is not exist.";
163       return;
164     }
165     auto actor_set = actors_[actor_info];
166     auto base_actors = CollectActors(actor_set.get());
167     for (auto &base_actor : base_actors) {
168       MS_EXCEPTION_IF_NULL(base_actor);
169       (void)actor_name_to_actor_.erase(base_actor->GetAID().Name());
170       actor_manager->Terminate(base_actor->GetAID());
171     }
172   }
173 
174   // Clear device tensor and device tensor store.
175   for (auto &graph : graphs) {
176     ClearNodeInfo(graph);
177   }
178 
179   // Clear global maps of actor info.
180   (void)actors_.erase(actor_info);
181 }
182 
Clear()183 void GraphScheduler::Clear() {
184   // Terminate all actors.
185   auto actor_manager = ActorMgr::GetActorMgrRef();
186   MS_EXCEPTION_IF_NULL(actor_manager);
187   actor_manager->Finalize();
188 
189   // Clear the member of DeviceTensorStore.
190   DeviceTensorStore::GetInstance().Clear();
191 
192   // Clear global maps.
193   actors_.clear();
194   actor_name_to_actor_.clear();
195 }
196 
197 using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, KernelActor *const, const KernelWithIndex &,
198                                                    const KernelWithIndex &, const KernelGraphPtr &);
199 static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc = {};
200 
Initialize()201 void GraphScheduler::Initialize() {
202   if (init_) {
203     return;
204   }
205   init_ = true;
206 
207   (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
208                                       &GraphScheduler::LinkDataArrowForDeviceDSActor);
209   (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
210                                       &GraphScheduler::LinkDataArrowForHostDSActor);
211   (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
212   (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
213                                       &GraphScheduler::LinkDataArrowForDeviceTensorStore);
214   (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
215                                       &GraphScheduler::LinkDataArrowForInternalParameter);
216 
217   // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
218   size_t actor_thread_num = 0;
219   size_t OMP_thread_num = 0;
220   size_t max_thread_num = 0;
221   ComputeThreadNums(&actor_thread_num, &OMP_thread_num, &max_thread_num);
222   auto actor_manager = ActorMgr::GetActorMgrRef();
223   MS_EXCEPTION_IF_NULL(actor_manager);
224   auto ret = actor_manager->Initialize(true, actor_thread_num, max_thread_num);
225   if (ret != MINDRT_OK) {
226     MS_LOG(EXCEPTION) << "Actor manager init failed.";
227   }
228   std::string OMP_env = std::to_string(OMP_thread_num);
229   (void)common::SetEnv("OMP_NUM_THREADS", OMP_env.c_str(), 0);
230   auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
231   MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
232                << ", the computed OMP thread number : " << OMP_thread_num
233                << ", the used OMP thread number : " << OMP_thread_num_used;
234 
235   BuildAndScheduleGlobalActor();
236 }
237 
BuildAndScheduleGlobalActor()238 void GraphScheduler::BuildAndScheduleGlobalActor() {
239   auto actor_manager = ActorMgr::GetActorMgrRef();
240   MS_EXCEPTION_IF_NULL(actor_manager);
241 
242   // Create and schedule memory manager actor.
243   auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
244   MS_EXCEPTION_IF_NULL(memory_manager_actor);
245   memory_manager_aid_ = memory_manager_actor->GetAID();
246   auto base_actor = static_cast<ActorReference>(memory_manager_actor);
247   // Bind single thread to response to memory alloc and free quickly.
248   (void)actor_manager->Spawn(base_actor, false);
249 
250   // Create and schedule recorder actor.
251   auto recorder_actor = std::make_shared<RecorderActor>();
252   MS_EXCEPTION_IF_NULL(recorder_actor);
253   recorder_aid_ = &(recorder_actor->GetAID());
254   auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
255   (void)actor_manager->Spawn(base_recorder_actor, true);
256 
257   // Create and schedule debug actor.
258 #ifndef ENABLE_SECURITY
259   bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
260 #endif
261 #ifdef ENABLE_DEBUGGER
262   if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
263     debugger_actor_need = true;
264   }
265 #endif
266 #ifndef ENABLE_SECURITY
267   if (debugger_actor_need) {
268     auto debug_actor = std::make_shared<DebugActor>();
269     MS_EXCEPTION_IF_NULL(debug_actor);
270     debug_aid_ = &(debug_actor->GetAID());
271     auto base_debug_actor = static_cast<ActorReference>(debug_actor);
272     (void)actor_manager->Spawn(base_debug_actor, true);
273   }
274 #endif
275 }
276 
Transform(const GraphCompilerInfo & graph_compiler_info)277 ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
278   MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
279   if (graph_compiler_info.graphs_.size() == 0) {
280     MS_LOG(EXCEPTION) << "The number of graphs is zero.";
281   }
282   if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
283     MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
284   }
285 
286   PersistDeviceTensor(graph_compiler_info);
287   const auto &actor_set = Build(graph_compiler_info);
288   MS_EXCEPTION_IF_NULL(actor_set);
289   CacheGraphOutputToActor(graph_compiler_info);
290   Link(actor_set.get(), graph_compiler_info);
291   // The copy actors are built in the link, so need push into the actor set after link.
292   actor_set->copy_actors_ = copy_actors_;
293 
294   (void)actors_.emplace(actor_set->name_, actor_set);
295 
296   DumpActor(actor_set.get(), graph_compiler_info);
297   if (!CheckActorValid(actor_set.get(), graph_compiler_info.strategy_)) {
298     MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
299   }
300   MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
301 
302   // Local maps and vectors clear.
303   graph_output_to_actor_.clear();
304   front_node_to_actor_.clear();
305   copy_actors_.clear();
306 
307   return actor_set.get();
308 }
309 
Schedule(const ActorSet * actor_set)310 void GraphScheduler::Schedule(const ActorSet *actor_set) {
311   MS_EXCEPTION_IF_NULL(actor_set);
312   auto actors = CollectActors(actor_set);
313   // Schedule actors.
314   auto actor_manager = ActorMgr::GetActorMgrRef();
315   MS_EXCEPTION_IF_NULL(actor_manager);
316   for (auto actor : actors) {
317     (void)actor_manager->Spawn(actor);
318   }
319 }
320 
Run(const ActorSet * actor_set,const std::vector<std::vector<TensorPtr>> & input_tensors,const std::vector<TensorPtr> & input_tensors_with_value_node,GraphExecutionStrategy strategy)321 bool GraphScheduler::Run(const ActorSet *actor_set, const std::vector<std::vector<TensorPtr>> &input_tensors,
322                          const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
323   MS_EXCEPTION_IF_NULL(actor_set);
324   MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
325 #if !defined(_WIN32) && !defined(_WIN64)
326   SignalGuard sg(IntHandler);
327 #endif
328 
329   // Construct OpContext.
330   OpContext<DeviceTensor> op_context;
331   std::vector<Promise<int>> result(1);
332   op_context.sequential_num_ = RandInt::Instance().Get();
333   op_context.results_ = &result;
334 
335   if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
336     actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context);
337     MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
338     actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
339     return true;
340   }
341 
342   // Trigger data prepare actor running.
343   Async(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors, &op_context);
344 
345   // Get the run result.
346   auto result_future = result[0].GetFuture();
347   result_future.Wait();
348   MsException::Instance().CheckException();
349   return result_future.IsOK();
350 }
351 
Fetch(const ActorInfo & actor_info) const352 ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
353   auto iter = actors_.find(actor_info);
354   if (iter != actors_.end()) {
355     return iter->second.get();
356   } else {
357     MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
358     return nullptr;
359   }
360 }
361 
Build(const GraphCompilerInfo & graph_compiler_info)362 ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
363   auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
364   MS_EXCEPTION_IF_NULL(actor_set);
365 
366   auto host_queue = std::make_shared<HostTensorQueue>();
367   actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
368   actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
369   actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
370   actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
371   actor_set->data_prepare_actor_ =
372     BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
373   actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
374   actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);
375 
376   return actor_set;
377 }
378 
CacheGraphOutputToActor(const GraphCompilerInfo & graph_compiler_info)379 void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
380   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
381     return;
382   }
383 
384   for (const auto &graph : graph_compiler_info.graphs_) {
385     MS_EXCEPTION_IF_NULL(graph);
386     auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
387     for (const auto &output_with_index : outputs) {
388       auto output_kernel = output_with_index.first;
389       MS_EXCEPTION_IF_NULL(output_kernel);
390       auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
391       if (origin_output_with_index.first == nullptr) {
392         MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
393                         << " with index: " << output_with_index.second << " has no actor.";
394         continue;
395       }
396 
397       auto actor_output_index = output_with_index.second;
398       OpActor<DeviceTensor> *actor = nullptr;
399       if (IsKernelActor(output_kernel, graph_compiler_info.strategy_)) {
400         actor = FetchActor(output_kernel->fullname_with_scope());
401       } else if (IsDeviceQueueDSActor(output_kernel, graph_compiler_info.strategy_)) {
402         std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
403         actor = FetchActor(actor_name);
404       } else if (IsHostQueueDSActor(output_kernel, graph, graph_compiler_info.origin_parameters_order_,
405                                     graph_compiler_info.strategy_)) {
406         actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor");
407         const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
408         MS_EXCEPTION_IF_NULL(host_ds_actor);
409         // Get the position of output kernel in the data source actor.
410         actor_output_index = host_ds_actor->FetchNodePosition(output_kernel);
411       } else if (IsPersistentDeviceTensor(output_kernel)) {
412         MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
413                      << " is device tensor store.";
414         continue;
415       } else {
416         MS_LOG(INFO) << "Ignore the internal parameter node:" << output_kernel->DebugString();
417         continue;
418       }
419 
420       MS_EXCEPTION_IF_NULL(actor);
421       MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
422                    << " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
423                    << " with index:" << actor_output_index
424                    << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
425                    << " with index: " << origin_output_with_index.second;
426       (void)graph_output_to_actor_.emplace(origin_output_with_index,
427                                            GraphOutputPair(dynamic_cast<AbstractActor *>(actor), actor_output_index));
428     }
429   }
430 }
431 
Link(ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info)432 void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
433   MS_EXCEPTION_IF_NULL(actor_set);
434   std::vector<KernelActor *> auto_monad_actors;
435   std::vector<CNodePtr> communication_nodes;
436   const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
437     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
438 
439   // Foreach the execution order to link the actors.
440   for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
441     const auto &graph = graph_compiler_info.graphs_[index];
442     MS_EXCEPTION_IF_NULL(graph);
443     auto execution_order = graph->execution_order();
444     for (auto &kernel : execution_order) {
445       MS_EXCEPTION_IF_NULL(kernel);
446       if (AnfAlgo::IsCommunicationOp(kernel)) {
447         (void)communication_nodes.emplace_back(kernel);
448       }
449       if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
450         continue;
451       }
452       const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
453       MS_EXCEPTION_IF_NULL(kernel_actor);
454 
455       for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
456         auto input_node = AnfAlgo::GetInputNode(kernel, i);
457         // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
458         if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
459           LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
460         }
461         if (HasAbstractMonad(input_node)) {
462           (void)auto_monad_actors.emplace_back(kernel_actor);
463           continue;  // No data arrow for monad input.
464         }
465 
466         KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
467         KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
468         // The gather of linking data arrows of kernel by the different from kernel type.
469         LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
470       }
471     }
472     // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
473     LinkControlArrowBySendRecvNodes(graph);
474   }
475 
476   // Link the arrow in the control flow scene.
477   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
478     LinkArrowByControlNode(graph_compiler_info, actor_set);
479   }
480 
481   LinkGlobalControlArrow(actor_set, communication_nodes, auto_monad_actors, graph_compiler_info);
482   LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
483 }
484 
BuildDataSourceActor(const GraphCompilerInfo & graph_compiler_info,const HostTensorQueuePtr & host_queue)485 std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
486                                                                      const HostTensorQueuePtr &host_queue) {
487   std::vector<DataSourceActorPtr> data_source_actors;
488   HostQueueDSActorPtr host_queue_ds_actor = nullptr;
489   size_t data_node_position = 0;
490   std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;
491 
492   for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
493     const auto &graph = graph_compiler_info.graphs_[i];
494     const auto &device_context = graph_compiler_info.device_contexts_[i];
495     MS_EXCEPTION_IF_NULL(graph);
496     // Build host queue data source actor.
497     const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
498 
499     for (size_t j = 0; j < input_nodes.size(); j++) {
500       const auto &input_node = input_nodes[j];
501       MS_EXCEPTION_IF_NULL(input_node);
502 
503       if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
504                              graph_compiler_info.strategy_)) {
505         if (host_queue_ds_actor == nullptr) {
506           auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
507           MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
508           host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
509                                                                            nullptr, host_queue);
510           InsertActor(host_queue_ds_actor.get());
511           (void)data_source_actors.emplace_back(host_queue_ds_actor);
512         }
513 
514         const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
515         // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
516         // is saved in the host queue data source actor.
517         if (front_node_position_temp_map.count(front_node) > 0) {
518           (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
519                                                                      front_node_position_temp_map[front_node]);
520           continue;
521         }
522         (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
523         (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
524         (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
525         (void)front_node_position_temp_map.emplace(front_node, data_node_position);
526         data_node_position++;
527       }
528     }
529 
530     // Build device queue data source actor.
531     const auto &execution_order = graph->execution_order();
532     const auto &iter =
533       std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
534         return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
535       });
536     if (iter != execution_order.end()) {
537       auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
538       MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
539       auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
540         actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
541       MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
542       InsertActor(device_queue_ds_actor.get());
543       (void)data_source_actors.emplace_back(device_queue_ds_actor);
544       device_queue_ds_actor->data_kernel_ = *iter;
545       device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
546     }
547   }
548 
549   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
550   const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
551 
552   // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
553   // the corresponding backend parameter from the map, and insert it into the host data source actor
554   const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
555   for (const auto &parameter : control_node_parameters) {
556     if (IsPersistentDeviceTensor(parameter)) {
557       continue;
558     }
559     auto backend_iter = front_to_backend_parameter.find(parameter);
560     if (backend_iter == front_to_backend_parameter.end()) {
561       MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
562     }
563 
564     if (host_queue_ds_actor == nullptr) {
565       auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
566       MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
567       host_queue_ds_actor =
568         std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr, nullptr, host_queue);
569       InsertActor(host_queue_ds_actor.get());
570       (void)data_source_actors.emplace_back(host_queue_ds_actor);
571     }
572 
573     const auto &backend_node = backend_iter->second.first;
574     auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
575     if (iter != host_queue_ds_actor->data_nodes_.end()) {
576       (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
577                                                                  iter - host_queue_ds_actor->data_nodes_.begin());
578     } else {
579       (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
580       (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
581       (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
582     }
583   }
584 
585   return data_source_actors;
586 }
587 
BuildKernelActor(const GraphCompilerInfo & graph_compiler_info)588 std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
589   std::vector<KernelActorPtr> kernel_actors;
590 
591   for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
592     const auto &graph = graph_compiler_info.graphs_[i];
593     const auto &device_context = graph_compiler_info.device_contexts_[i];
594     MS_EXCEPTION_IF_NULL(graph);
595     auto execution_order = graph->execution_order();
596 
597     // Single op graph in step mode, kernel actor executes synchronously.
598     bool is_single_op_graph = execution_order.size() == 1;
599     GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
600     if (strategy == GraphExecutionStrategy::kStep) {
601       strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
602     }
603 
604     for (auto &kernel : execution_order) {
605       MS_EXCEPTION_IF_NULL(kernel);
606       if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
607         auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
608                                                           memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
609         MS_EXCEPTION_IF_NULL(kernel_actor);
610         InsertActor(kernel_actor.get());
611         (void)kernel_actors.emplace_back(kernel_actor);
612         auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
613         if (front_node != nullptr) {
614           front_node_to_actor_[front_node] = kernel_actor;
615         }
616       }
617     }
618   }
619   return kernel_actors;
620 }
621 
BuildLoopCountActor(const GraphCompilerInfo & graph_compiler_info)622 LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
623   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
624     return nullptr;
625   }
626 
627   auto loop_count = ConfigManager::GetInstance().iter_num();
628   auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
629   auto loop_count_actor =
630     std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
631   MS_LOG(INFO) << "Create loop count actor: " << actor_name;
632   MS_EXCEPTION_IF_NULL(loop_count_actor);
633 
634   InsertActor(loop_count_actor.get());
635   return loop_count_actor;
636 }
637 
BuildOutputActor(const GraphCompilerInfo & graph_compiler_info)638 OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
639   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
640     return nullptr;
641   }
642 
643   auto loop_count = ConfigManager::GetInstance().iter_num();
644   auto actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
645   bool need_loop_count = (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) ? true : false;
646 
647   auto output_actor =
648     std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_, need_loop_count);
649   MS_LOG(INFO) << "Create output actor: " << actor_name;
650   MS_EXCEPTION_IF_NULL(output_actor);
651   InsertActor(output_actor.get());
652   return output_actor;
653 }
654 
BuildDataPrepareActor(const GraphCompilerInfo & graph_compiler_info,const std::vector<DataSourceActorPtr> & data_source_actors,const HostTensorQueuePtr & host_queue)655 DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
656                                                           const std::vector<DataSourceActorPtr> &data_source_actors,
657                                                           const HostTensorQueuePtr &host_queue) {
658   HostQueueDSActorPtr host_queue_ds_actor = nullptr;
659   auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
660     return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
661   });
662   if (iter != data_source_actors.end()) {
663     host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
664   }
665 
666   auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
667   auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
668                                                                &graph_compiler_info, host_queue_ds_actor, host_queue);
669   MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
670   MS_EXCEPTION_IF_NULL(data_prepare_actor);
671 
672   // Cache the nodes which need continuous memory.
673   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
674     for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
675       const auto &graph = graph_compiler_info.graphs_[index];
676       MS_EXCEPTION_IF_NULL(graph);
677       auto &execution_order = graph->execution_order();
678       for (auto &kernel : execution_order) {
679         if (!AnfAlgo::IsCommunicationOp(kernel)) {
680           continue;
681         }
682 
683         auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
684         auto value = std::make_pair(false, false);
685         if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
686           value.first = true;
687         }
688         if (AnfAlgo::GetOutputTensorNum(kernel) > 1) {
689           value.second = true;
690         }
691         if ((value.first == true) || (value.second == true)) {
692           data_prepare_actor->continuous_memory_nodes_[key] = value;
693         }
694       }
695     }
696   }
697 
698   InsertActor(data_prepare_actor.get());
699   return data_prepare_actor;
700 }
701 
BuildNoInputKernelActor(const ActorSet * actor_set,GraphExecutionStrategy strategy)702 std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
703                                                                     GraphExecutionStrategy strategy) {
704   MS_EXCEPTION_IF_NULL(actor_set);
705   std::vector<KernelActorPtr> no_input_kernel_actors;
706 
707   for (auto &kernel_actor : actor_set->kernel_actors_) {
708     MS_EXCEPTION_IF_NULL(kernel_actor);
709     // Framework will trigger kernel actor running in the step execution strategy.
710     if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
711       kernel_actor->input_controls_num_++;
712       continue;
713     }
714 
715     if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
716       // Check whether the kernel actor belongs to the root graph.
717       // In general, all no input nodes belong to the root funcgraph, and the corresponding gather actor should be
718       // empty. In control flow, the control arrow of the no input node in the sub funcgraph should be sent by the
719       // gather actor and should not be placed in the no input list.
720       MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
721       const auto &graph = kernel_actor->kernel_->func_graph();
722       if (graph != nullptr) {
723         const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
724         MS_EXCEPTION_IF_NULL(kernel_graph);
725         const auto func_graph = kernel_graph->GetFuncGraph();
726         if (func_graph != nullptr && FetchActor(func_graph->ToString()) != nullptr) {
727           continue;
728         }
729       }
730 
731       (void)no_input_kernel_actors.emplace_back(kernel_actor);
732     }
733   }
734   return no_input_kernel_actors;
735 }
736 
BuildSwitchActor(const GraphCompilerInfo & graph_compiler_info)737 std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
738   std::vector<SwitchActorPtr> switch_actors;
739   std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
740   for (const auto &pair : front_node_to_actor_) {
741     front_to_backend_kernel[pair.first] = pair.second->kernel_;
742   }
743 
744   // Build switch actor by switch node and switchlayer node.
745   for (const auto &control_node : graph_compiler_info.control_nodes_) {
746     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
747         AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
748       const auto func_graph = control_node->func_graph();
749       const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
750       const auto &actor_name = control_node->DebugString();
751       auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
752                                                         control_node->cast<CNodePtr>(), branch_id, false);
753       switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
754 
755       // Fetch all the input nodes of switch actor.
756       switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
757       InsertActor(switch_actor.get());
758       (void)switch_actors.emplace_back(switch_actor);
759     }
760   }
761 
762   // Build switch actor by return node.
763   const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
764   for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
765     const auto &return_node = func_graph_to_call_num.first->get_return();
766     MS_EXCEPTION_IF_NULL(return_node);
767     const auto &actor_name = return_node->DebugString();
768     auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
769                                                       return_node->cast<CNodePtr>(), kInvalidBranchID, true);
770     switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
771 
772     // Fetch all the input nodes of switch actor.
773     switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
774     InsertActor(switch_actor.get());
775     (void)switch_actors.emplace_back(switch_actor);
776   }
777 
778   return switch_actors;
779 }
780 
BuildGatherActor(const GraphCompilerInfo & graph_compiler_info)781 std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
782   std::vector<GatherActorPtr> gather_actors;
783 
784   const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
785   const auto &loop_count_actor = FetchActor(loop_count_actor_name);
786   if (loop_count_actor == nullptr) {
787     return gather_actors;
788   }
789 
790   const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
791   const auto &output_actor = FetchActor(output_actor_name);
792   MS_EXCEPTION_IF_NULL(output_actor);
793 
794   const auto parser = graph_compiler_info.control_node_parser_;
795 
796   bool is_main_return = true;
797   // Each funcgraph has a return node, get the funcgraph from the return node, and create a gather actor.
798   std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
799   for (const auto &pair : front_node_to_actor_) {
800     front_to_backend_kernel[pair.first] = pair.second->kernel_;
801   }
802 
803   for (const auto &control_node : graph_compiler_info.control_nodes_) {
804     const auto &func_graph = control_node->func_graph();
805     const auto &cnode = control_node->cast<CNodePtr>();
806     const auto &inputs = cnode->inputs();
807     const auto &return_node = func_graph->get_return();
808 
809     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
810       // Root funcgraph does not need to create a gather actor.
811       if (is_main_return) {
812         is_main_return = false;
813         continue;
814       }
815 
816       if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
817         continue;
818       }
819       auto actor_name = func_graph->ToString();
820       std::vector<KernelWithIndex> parameters;
821       for (const auto &parameter : func_graph->get_inputs()) {
822         if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
823           continue;
824         }
825         (void)parameters.emplace_back(parameter, 0);
826       }
827 
828       const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
829 
830       const auto &output_switch_actor = FetchActor(return_node->DebugString());
831       MS_EXCEPTION_IF_NULL(output_switch_actor);
832       const auto &output_switch_aid = output_switch_actor->GetAID();
833 
834       auto gather_actor =
835         std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
836       gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
837       InsertActor(gather_actor.get());
838       (void)gather_actors.emplace_back(gather_actor);
839     }
840   }
841 
842   // Create gather actor for call node which input0 of call node is a funcgraph.
843   for (const auto &control_node : graph_compiler_info.control_nodes_) {
844     const auto &cnode = control_node->cast<CNodePtr>();
845     const auto &inputs = cnode->inputs();
846 
847     if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
848       // Collect the parameters.
849       std::vector<KernelWithIndex> parameters;
850       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
851         if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
852           continue;
853         }
854         (void)parameters.emplace_back(inputs[i], 0);
855       }
856 
857       auto func_graph = control_node->func_graph();
858       auto actor_name = control_node->DebugString();
859       const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
860       const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
861       const auto &to_actor = FetchActor(to_func_graph->ToString());
862       auto gather_actor =
863         std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
864       gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
865 
866       InsertActor(gather_actor.get());
867       (void)gather_actors.emplace_back(gather_actor);
868     }
869   }
870 
871   // Create gather actor for kernel graph which has a call input.
872   const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_;
873   for (const auto &graph_with_device_context : graph_with_device_contexts) {
874     const auto &graph = graph_with_device_context.first;
875     const auto &parameters = FetchParameterbyKernelGraph(graph);
876 
877     auto actor_name = graph->ToString();
878     auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID);
879     InsertActor(gather_actor.get());
880     (void)gather_actors.emplace_back(gather_actor);
881   }
882 
883   return gather_actors;
884 }
885 
LinkDataArrow(KernelActor * const to_actor,const GraphCompilerInfo & graph_compiler_info,const KernelGraphPtr & graph,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)886 void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
887                                    const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
888                                    const KernelWithIndex &to_kernel_with_input_idx) {
889   MS_EXCEPTION_IF_NULL(to_actor);
890   MS_EXCEPTION_IF_NULL(graph);
891 
892   auto from_kernel = from_kernel_with_output_idx.first;
893   MS_EXCEPTION_IF_NULL(from_kernel);
894   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
895   if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
896     const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
897     const auto &real_front_node_with_index =
898       AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
899     if (HasAbstractRef(real_front_node_with_index.first)) {
900       (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
901                                                              real_front_node_with_index.first);
902       return;
903     }
904     // When there is a call input in the kernel graph, all the inputs of the kernel graph needs to be sent by gather.
905     const auto actor_name = graph->ToString();
906     auto actor = FetchActor(actor_name);
907     MS_EXCEPTION_IF_NULL(actor);
908     LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), to_actor, real_front_node_with_index,
909                                 to_kernel_with_input_idx);
910     return;
911   }
912 
913   auto front_node = GetFrontNodeByBackendNode(from_kernel);
914   if (front_node != nullptr && IsGatherActor(front_node, actor_name_to_actor_)) {
915     // Link the data arrows of gather actor.
916     auto func_graph = GetFuncgraphByBackendNode(from_kernel);
917     if (func_graph == nullptr) {
918       MS_LOG(EXCEPTION) << "Cannot find funcgraph of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
919     }
920     auto actor_name = func_graph->ToString();
921     const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name));
922     if (HasAbstractRef(from_kernel)) {
923       (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node);
924       return;
925     }
926     LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
927     return;
928   }
929 
930   auto kernel_type = KernelTransformType::kUnknown;
931   std::string kernel_name = "";
932   FetchKernelTransformTypeAndName(from_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
933   auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
934   if (kKernelTypeToLinkFunc.count(kernel_type) > 0) {
935     (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
936                                                 to_kernel_with_input_idx, graph);
937   }
938 }
939 
LinkDataArrowForDeviceTensorStore(AbstractActor * const,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr & graph)940 void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, KernelActor *const to_actor,
941                                                        const KernelWithIndex &from_kernel_with_output_idx,
942                                                        const KernelWithIndex &to_kernel_with_input_idx,
943                                                        const KernelGraphPtr &graph) {
944   MS_EXCEPTION_IF_NULL(to_actor);
945   MS_EXCEPTION_IF_NULL(graph);
946   auto from_kernel = from_kernel_with_output_idx.first;
947   MS_EXCEPTION_IF_NULL(from_kernel);
948 
949   auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
950   (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
951 }
952 
LinkDataArrowForInternalParameter(AbstractActor * const,KernelActor * to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr & graph)953 void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, KernelActor *to_actor,
954                                                        const KernelWithIndex &from_kernel_with_output_idx,
955                                                        const KernelWithIndex &to_kernel_with_input_idx,
956                                                        const KernelGraphPtr &graph) {
957   MS_EXCEPTION_IF_NULL(to_actor);
958   MS_EXCEPTION_IF_NULL(graph);
959   auto internal_parameter = from_kernel_with_output_idx.first;
960   MS_EXCEPTION_IF_NULL(internal_parameter);
961 
962   // Parameter ---> front node.
963   auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
964   auto front_output_node = front_output_with_index.first;
965   MS_EXCEPTION_IF_NULL(front_output_node);
966   if (IsSwitchActor(front_output_node)) {
967     auto switch_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->DebugString()));
968     MS_EXCEPTION_IF_NULL(switch_actor);
969     LinkDataArrowForSwitchActor(switch_actor, 0, to_actor, to_kernel_with_input_idx.second);
970     to_actor->input_datas_num_++;
971     return;
972   }
973 
974   auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
975   AbstractActor *real_from_actor = nullptr;
976   KernelTransformType kernel_type;
977   if (IsPersistentDeviceTensor(front_output_node)) {
978     kernel_type = KernelTransformType::kDeviceTensorStore;
979   } else {
980     // front node ---> actor.
981     if (graph_output_to_actor_.count(front_output_with_index) == 0) {
982       MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node)
983                         << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
984     }
985     auto actor_pair = graph_output_to_actor_[front_output_with_index];
986     MS_EXCEPTION_IF_NULL(actor_pair.first);
987     MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
988                  << ", corresponding front node:" << front_output_node->fullname_with_scope()
989                  << " with index:" << front_output_with_index.second
990                  << ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
991                  << ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
992     real_from_actor = actor_pair.first;
993     real_from_kernel_with_output_idx = KernelWithIndex(nullptr, actor_pair.second);
994     kernel_type = actor_pair.first->type_;
995   }
996 
997   if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
998     MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
999   }
1000   (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
1001                                               to_kernel_with_input_idx, graph);
1002 }
1003 
LinkDataArrowForBaseActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)1004 void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1005                                                const KernelWithIndex &from_kernel_with_output_idx,
1006                                                const KernelWithIndex &to_kernel_with_input_idx) {
1007   MS_EXCEPTION_IF_NULL(from_actor);
1008   MS_EXCEPTION_IF_NULL(to_actor);
1009 
1010   auto from_kernel = from_kernel_with_output_idx.first;
1011   MS_EXCEPTION_IF_NULL(from_kernel);
1012   auto from_output_index = from_kernel_with_output_idx.second;
1013   auto to_input_index = to_kernel_with_input_idx.second;
1014 
1015   // Get the position of from kernel in the data source actor.
1016   auto position = from_actor->FetchNodePosition(from_kernel);
1017   if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.empty())) {
1018     MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
1019   }
1020 
1021   if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
1022     LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
1023   } else {
1024     auto to_aid = to_actor->GetAID();
1025     auto op_arrow = std::make_shared<DataArrow>(from_output_index, to_aid, to_input_index);
1026     // If the from actor has the multi nodes, then use the real output position.
1027     if (position != 0) {
1028       op_arrow->from_output_index_ = SizeToInt(position);
1029     }
1030 
1031     (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1032     to_actor->input_datas_num_++;
1033     (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
1034 
1035     // Update the reference count of device tensor.
1036     UpdateRefCount(from_kernel, from_output_index);
1037   }
1038 }
1039 
LinkDataArrowForDeviceDSActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1040 void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1041                                                    const KernelWithIndex &from_kernel_with_output_idx,
1042                                                    const KernelWithIndex &to_kernel_with_input_idx,
1043                                                    const KernelGraphPtr &) {
1044   auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
1045   if (real_from_kernel_with_output_idx.first == nullptr) {
1046     auto device_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
1047     MS_EXCEPTION_IF_NULL(device_ds_actor);
1048     real_from_kernel_with_output_idx.first = device_ds_actor->data_kernel_;
1049   }
1050 
1051   LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1052 }
1053 
LinkDataArrowForHostDSActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1054 void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1055                                                  const KernelWithIndex &from_kernel_with_output_idx,
1056                                                  const KernelWithIndex &to_kernel_with_input_idx,
1057                                                  const KernelGraphPtr &) {
1058   auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
1059   MS_EXCEPTION_IF_NULL(host_ds_actor);
1060 
1061   KernelWithIndex real_from_kernel_with_output_idx;
1062   if (from_kernel_with_output_idx.first != nullptr) {
1063     // Get the position of from kernel in the data source actor.
1064     auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
1065     real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
1066     real_from_kernel_with_output_idx.second = from_kernel_with_output_idx.second;
1067   } else {
1068     real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(from_kernel_with_output_idx.second);
1069     real_from_kernel_with_output_idx.second = 0;
1070   }
1071 
1072   LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1073 }
1074 
LinkDataArrowForKernelActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1075 void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1076                                                  const KernelWithIndex &from_kernel_with_output_idx,
1077                                                  const KernelWithIndex &to_kernel_with_input_idx,
1078                                                  const KernelGraphPtr &) {
1079   auto real_from_actor = from_actor;
1080   auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
1081   auto from_kernel = from_kernel_with_output_idx.first;
1082   if (from_kernel == nullptr) {
1083     auto kernel_actor = dynamic_cast<KernelActor *>(from_actor);
1084     MS_EXCEPTION_IF_NULL(kernel_actor);
1085     from_kernel = kernel_actor->kernel_;
1086     real_from_kernel_with_output_idx.first = kernel_actor->kernel_;
1087   }
1088 
1089   // Update the from kernel info by the real node info.
1090   MS_EXCEPTION_IF_NULL(from_kernel);
1091   if (IsSkippedKernelActor(from_kernel)) {
1092     real_from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(from_kernel, 0);
1093     MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
1094     LinkControlArrowBySkippedNode(to_actor, from_kernel);
1095 
1096     MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
1097     MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
1098                  << to_kernel_with_input_idx.first->fullname_with_scope()
1099                  << ", aggregate input index: " << to_kernel_with_input_idx.second
1100                  << ", skip node: " << from_kernel->fullname_with_scope()
1101                  << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
1102     real_from_actor =
1103       dynamic_cast<AbstractActor *>(FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope()));
1104     MS_EXCEPTION_IF_NULL(real_from_actor);
1105   }
1106 
1107   LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1108 }
1109 
LinkDataArrowForCopyActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)1110 void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1111                                                const KernelWithIndex &from_kernel_with_output_idx,
1112                                                const KernelWithIndex &to_kernel_with_input_idx) {
1113   MS_EXCEPTION_IF_NULL(from_actor);
1114   MS_EXCEPTION_IF_NULL(to_actor);
1115   auto from_kernel = from_kernel_with_output_idx.first;
1116   MS_EXCEPTION_IF_NULL(from_kernel);
1117   auto from_output_index = from_kernel_with_output_idx.second;
1118   auto to_input_index = to_kernel_with_input_idx.second;
1119 
1120   std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
1121                      "_output_index:" + std::to_string(from_output_index);
1122   CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
1123   // Link between from actor and copy actor.
1124   if (copy_actor == nullptr) {
1125     // Create the copy actor.
1126     auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
1127     (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
1128     copy_actor = copy_actor_shared_ptr.get();
1129     MS_EXCEPTION_IF_NULL(copy_actor);
1130     InsertActor(copy_actor);
1131 
1132     // Get the position of from kernel in the data source actor.
1133     auto position = from_actor->FetchNodePosition(from_kernel);
1134     if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.empty())) {
1135       MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
1136     }
1137     auto from_device_context = from_actor->device_contexts_[position];
1138     auto to_device_context = to_actor->device_contexts_[0];
1139     auto from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
1140     MS_EXCEPTION_IF_NULL(from_device_context);
1141     MS_EXCEPTION_IF_NULL(to_device_context);
1142     MS_EXCEPTION_IF_NULL(from_device_tensor);
1143     auto op_arrow_to_copy = std::make_shared<DataArrow>(from_output_index, copy_actor->GetAID(), 0);
1144     // If the from actor has the multi nodes, then use the real output position.
1145     if (position != 0) {
1146       op_arrow_to_copy->from_output_index_ = SizeToInt(position);
1147     }
1148 
1149     // Link.
1150     (void)from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
1151     copy_actor->input_datas_num_++;
1152 
1153     // Set the member of the copy actor.
1154     auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
1155     MS_EXCEPTION_IF_NULL(to_kernel_mod);
1156     auto input_sizes = to_kernel_mod->GetInputSizeList();
1157     if (to_input_index >= input_sizes.size()) {
1158       MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size();
1159     }
1160     copy_actor->output_ = to_device_context->CreateDeviceAddress(
1161       nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
1162     (void)copy_actor->device_contexts_.emplace_back(from_device_context);
1163     (void)copy_actor->device_contexts_.emplace_back(to_device_context);
1164 
1165     // Update the reference count of device tensor.
1166     UpdateRefCount(from_device_tensor.get());
1167   }
1168 
1169   // If the copy actor already exists, only need link between copy actor and to actor.
1170   auto op_arrow_from_copy = std::make_shared<DataArrow>(0, to_actor->GetAID(), to_input_index);
1171   (void)copy_actor->output_data_arrows_.emplace_back(op_arrow_from_copy);
1172   to_actor->input_datas_num_++;
1173   UpdateRefCount(copy_actor->output_.get());
1174 }
1175 
LinkControlArrowByAutoMonad(KernelActor * to_actor,const AnfNodePtr & from_node,const KernelGraphPtr & graph)1176 void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
1177                                                  const KernelGraphPtr &graph) {
1178   MS_EXCEPTION_IF_NULL(to_actor);
1179   MS_EXCEPTION_IF_NULL(from_node);
1180   MS_EXCEPTION_IF_NULL(graph);
1181   // Find the real input node, include the monad node and make tuple node.
1182   const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
1183                                                   prim::kPrimMakeTuple};
1184   const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
1185   MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
1186   auto input_anfnode = input_kernel_with_output_idx.first;
1187   CNodePtr input_cnode = nullptr;
1188   if (input_anfnode->isa<CNode>()) {
1189     input_cnode = input_anfnode->cast<CNodePtr>();
1190   }
1191   // Make tuple node needs to be expanded.
1192   if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
1193     MS_EXCEPTION_IF_NULL(input_cnode);
1194     for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
1195       LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
1196     }
1197     return;
1198   }
1199 
1200   const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
1201     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
1202   // Get the real depend input by monad node which needs to link the control arrow.
1203   std::vector<AnfNodePtr> real_depend_inputs;
1204   if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
1205       AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
1206     MS_EXCEPTION_IF_NULL(input_cnode);
1207     real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
1208     // The real input may be this scene:  depend/load --> load/depend, so need add the control arrow for real input
1209     // node in this scene.
1210     if (AnfAlgo::IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
1211       real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
1212     }
1213   } else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
1214     MS_EXCEPTION_IF_NULL(input_cnode);
1215     for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
1216       real_depend_inputs.push_back(input_cnode->input(i));
1217     }
1218   } else {
1219     real_depend_inputs.push_back(input_anfnode);
1220   }
1221 
1222   for (const auto &real_depend_input : real_depend_inputs) {
1223     auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
1224     auto real_depend_kernel = real_depend_input_with_idx.first;
1225     // The monad node and make tuple node need recursion.
1226     if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
1227       LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
1228       continue;
1229     }
1230 
1231     KernelActor *from_actor = nullptr;
1232     if (IsKernelActor(real_depend_kernel)) {
1233       from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
1234     } else if (IsInternalParameter(real_depend_kernel, graph)) {
1235       auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
1236       MS_EXCEPTION_IF_NULL(front_output_with_index.first);
1237       if (IsKernelActor(front_output_with_index.first)) {
1238         if (graph_output_to_actor_.count(front_output_with_index) == 0) {
1239           MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_with_index.first->DebugString();
1240         }
1241         from_actor = dynamic_cast<KernelActor *>(graph_output_to_actor_[front_output_with_index].first);
1242       }
1243     }
1244     if (from_actor == nullptr) {
1245       continue;
1246     }
1247     MS_LOG(INFO) << "Link control arrow by auto monad, from actor:  " << from_actor->GetAID().Name()
1248                  << ", to actor: " << to_actor->GetAID().Name();
1249     (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1250     to_actor->input_controls_num_++;
1251   }
1252 }
1253 
LinkControlArrowBySkippedNode(KernelActor * to_actor,const AnfNodePtr & skipped_node)1254 void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
1255   MS_EXCEPTION_IF_NULL(to_actor);
1256   MS_EXCEPTION_IF_NULL(skipped_node);
1257   auto to_aid = to_actor->GetAID();
1258 
1259   // Link the control arrow from all the inputs of skipped node to the user of skipped node.
1260   auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
1261   for (size_t i = 0; i < input_num; ++i) {
1262     auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
1263     MS_EXCEPTION_IF_NULL(kernel_with_index.first);
1264     auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
1265     MS_EXCEPTION_IF_NULL(from_actor);
1266     MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
1267                  << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
1268     (void)from_actor->output_control_arrows_.emplace_back(to_aid);
1269     to_actor->input_controls_num_++;
1270   }
1271 }
1272 
LinkControlArrowBySendRecvNodes(const KernelGraphPtr & graph)1273 void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
1274   MS_EXCEPTION_IF_NULL(graph);
1275   for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
1276     auto to_allreduce_node = from_iter.first;
1277     auto from_send_node = from_iter.second.first;
1278     auto from_recv_node = from_iter.second.second;
1279     MS_EXCEPTION_IF_NULL(to_allreduce_node);
1280     MS_EXCEPTION_IF_NULL(from_send_node);
1281     MS_EXCEPTION_IF_NULL(from_recv_node);
1282     MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
1283     auto to_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(to_allreduce_node->fullname_with_scope()));
1284     auto from_send_actor = dynamic_cast<KernelActor *>(FetchActor(from_send_node->fullname_with_scope()));
1285     auto from_recv_actor = dynamic_cast<KernelActor *>(FetchActor(from_recv_node->fullname_with_scope()));
1286     MS_EXCEPTION_IF_NULL(to_allreduce_actor);
1287     MS_EXCEPTION_IF_NULL(from_send_actor);
1288     MS_EXCEPTION_IF_NULL(from_recv_actor);
1289 
1290     // inputs of to_allreduce_actor  --> from_send_actor
1291     for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
1292       auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
1293       if (input_actor != nullptr) {
1294         (void)input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
1295         from_send_actor->input_controls_num_++;
1296       }
1297     }
1298 
1299     // from_send_actor --> from_recv_actor
1300     (void)from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
1301     from_recv_actor->input_controls_num_++;
1302 
1303     // from_recv_actor --> to_allreduce_actor
1304     (void)from_recv_actor->output_control_arrows_.emplace_back(to_allreduce_actor->GetAID());
1305     to_allreduce_actor->input_controls_num_++;
1306   }
1307 
1308   for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
1309     auto from_allreduce_node = to_iter.first;
1310     auto to_send_node = to_iter.second.first;
1311     auto to_recv_node = to_iter.second.second;
1312     MS_EXCEPTION_IF_NULL(from_allreduce_node);
1313     MS_EXCEPTION_IF_NULL(to_send_node);
1314     MS_EXCEPTION_IF_NULL(to_recv_node);
1315     MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
1316     auto from_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(from_allreduce_node->fullname_with_scope()));
1317     auto to_send_actor = dynamic_cast<KernelActor *>(FetchActor(to_send_node->fullname_with_scope()));
1318     auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
1319     MS_EXCEPTION_IF_NULL(from_allreduce_actor);
1320     MS_EXCEPTION_IF_NULL(to_send_actor);
1321     MS_EXCEPTION_IF_NULL(to_recv_actor);
1322 
1323     // from_allreduce_actor  --> to_send_actor
1324     (void)from_allreduce_actor->output_control_arrows_.emplace_back(to_send_actor->GetAID());
1325     to_send_actor->input_controls_num_++;
1326 
1327     // to_send_actor --> to_recv_actor
1328     (void)to_send_actor->output_control_arrows_.emplace_back(to_recv_actor->GetAID());
1329     to_recv_actor->input_controls_num_++;
1330 
1331     // to_recv_actor --> outputs of from_allreduce_actor
1332     for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
1333       auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
1334       if (output_actor != nullptr) {
1335         (void)to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
1336         output_actor->input_controls_num_++;
1337       }
1338     }
1339 
1340     // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
1341     // reused only when the recv node runs finished, which is expressed by the reference count increased.
1342     for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
1343       auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
1344       MS_EXCEPTION_IF_NULL(device_tensor);
1345       UpdateRefCount(device_tensor.get());
1346       (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
1347     }
1348   }
1349 }
1350 
LinkGlobalControlArrow(ActorSet * const actor_set,const std::vector<CNodePtr> & communication_nodes,const std::vector<KernelActor * > & auto_monad_actors,const GraphCompilerInfo & graph_compiler_info)1351 void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
1352                                             const std::vector<KernelActor *> &auto_monad_actors,
1353                                             const GraphCompilerInfo &graph_compiler_info) {
1354   MS_EXCEPTION_IF_NULL(actor_set);
1355 
1356   // Link the control arrows by the communication nodes to ensure communication nodes running order.
1357   LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);
1358 
1359   // Auto monad actor may modify the device tensor store.
1360   LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
1361 
1362   // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
1363   actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
1364 
1365   // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
1366   if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
1367     LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set);
1368   }
1369 
1370   LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
1371                                     graph_compiler_info.control_node_parser_);
1372 }
1373 
LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> & communication_nodes,const GraphCompilerInfo & graph_compiler_info)1374 void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
1375                                                          const GraphCompilerInfo &graph_compiler_info) {
1376   const size_t kCommunicationNodesMinNum = 2;
1377   if (communication_nodes.size() < kCommunicationNodesMinNum) {
1378     return;
1379   }
1380 
1381   // Ensure communication node to execute orderly.
1382   for (size_t i = 1; i < communication_nodes.size(); ++i) {
1383     auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
1384     auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
1385     MS_EXCEPTION_IF_NULL(from_actor);
1386     MS_EXCEPTION_IF_NULL(to_actor);
1387     (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1388     to_actor->input_controls_num_++;
1389   }
1390 
1391   // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
1392   // Using the multi stream to optimize the performance in the future.
1393   for (auto &graph : graph_compiler_info.graphs_) {
1394     MS_EXCEPTION_IF_NULL(graph);
1395     auto &execution_order = graph->execution_order();
1396     for (size_t i = 1; i < execution_order.size(); ++i) {
1397       auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
1398       auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
1399       if ((from_actor != nullptr) && (to_actor != nullptr)) {
1400         (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1401         to_actor->input_controls_num_++;
1402       }
1403     }
1404   }
1405 }
1406 
LinkControlArrowForDataPrepareActor(DataPrepareActor * data_prepare_actor,const ActorSet * actor_set)1407 void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
1408                                                          const ActorSet *actor_set) {
1409   MS_EXCEPTION_IF_NULL(data_prepare_actor);
1410   MS_EXCEPTION_IF_NULL(actor_set);
1411 
1412   // Data prepare actor --> data source actor.
1413   for (auto &data_source_actor : actor_set->data_source_actors_) {
1414     MS_EXCEPTION_IF_NULL(data_source_actor);
1415     (void)data_prepare_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
1416   }
1417 
1418   // Data prepare actor --> no input kernel actor.
1419   for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
1420     MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
1421     (void)data_prepare_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
1422     no_input_kernel_actor->input_controls_num_++;
1423   }
1424 
1425   // Data prepare actor --> loop count actor.
1426   if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
1427       (actor_set->loop_count_actor_ != nullptr)) {
1428     data_prepare_actor->loop_count_aid_ = &(actor_set->loop_count_actor_->GetAID());
1429   }
1430 }
1431 
LinkControlArrowForLoopCountActor(LoopCountActor * loop_count_actor,const ActorSet * actor_set,const ControlNodeParserPtr & parser)1432 void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
1433                                                        const ControlNodeParserPtr &parser) {
1434   MS_EXCEPTION_IF_NULL(actor_set);
1435   MS_EXCEPTION_IF_NULL(parser);
1436   // There is no loop count actor in step mode.
1437   if (loop_count_actor == nullptr) {
1438     return;
1439   }
1440 
1441   // Collect the actors which have no output.
1442   std::vector<MemoryAwareActor *> no_output_actors;
1443   for (auto &kernel_actor : actor_set->kernel_actors_) {
1444     // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
1445     if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1446         parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
1447       MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
1448       MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
1449       (void)no_output_actors.emplace_back(kernel_actor.get());
1450     }
1451   }
1452   for (auto &data_actor : actor_set->data_source_actors_) {
1453     if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
1454       (void)no_output_actors.emplace_back(data_actor.get());
1455     }
1456   }
1457   for (auto &copy_actor : copy_actors_) {
1458     if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
1459       (void)no_output_actors.emplace_back(copy_actor.get());
1460     }
1461   }
1462 
1463   // No output actor --> loop count actor.
1464   for (auto &no_output_actor : no_output_actors) {
1465     (void)no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID());
1466     loop_count_actor->input_controls_num_++;
1467   }
1468 
1469   // Loop count actor --> data prepare actor.
1470   MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
1471   loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
1472 
1473   // Loop count actor --> output actor.
1474   MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
1475   loop_count_actor->output_aid_ = actor_set->output_actor_->GetAID();
1476 }
1477 
LinkOutputResultArrowForOutputActor(OutputActor * to_actor,const GraphCompilerInfo & graph_compiler_info)1478 void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
1479                                                          const GraphCompilerInfo &graph_compiler_info) {
1480   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
1481     return;
1482   }
1483 
1484   MS_EXCEPTION_IF_NULL(to_actor);
1485 
1486   for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
1487     const auto &graph = graph_compiler_info.graphs_[i];
1488     MS_EXCEPTION_IF_NULL(graph);
1489     auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
1490     std::set<std::vector<size_t>> unique_output_positions;
1491     std::set<KernelWithIndex> unique_outputs;
1492     for (const auto &output : outputs) {
1493       if (IsInternalParameter(output.first, graph)) {
1494         MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
1495         continue;
1496       }
1497       (void)unique_outputs.insert(output);
1498     }
1499     for (const auto &output_with_index : unique_outputs) {
1500       MS_EXCEPTION_IF_NULL(output_with_index.first);
1501       auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
1502       const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
1503       if (iter == graph_compiler_info.origin_outputs_order_.end()) {
1504         continue;
1505       }
1506 
1507       // Skip duplicate position.
1508       if (unique_output_positions.count(iter->second) > 0) {
1509         continue;
1510       }
1511       (void)unique_output_positions.insert(iter->second);
1512       for (auto &output_position : iter->second) {
1513         if (output_position >= to_actor->device_contexts_.size()) {
1514           MS_LOG(EXCEPTION) << "The output position is out of range.";
1515         }
1516         to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
1517         // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
1518         UpdateRefCount(output_with_index.first, output_with_index.second, true);
1519 
1520         // The graph output is from device tensor store.
1521         if (IsPersistentDeviceTensor(output_with_index.first)) {
1522           (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
1523           continue;
1524         }
1525 
1526         // The graph output is from kernel actor or data source actor.
1527         auto kernel_type = KernelTransformType::kUnknown;
1528         std::string kernel_name = "";
1529         FetchKernelTransformTypeAndName(output_with_index.first, graph, graph_compiler_info, &kernel_type,
1530                                         &kernel_name);
1531         auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
1532         if (from_actor == nullptr) {
1533           continue;
1534         }
1535         auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position);
1536         auto position = from_actor->FetchNodePosition(output_with_index.first);
1537         // If the from actor has the multi nodes, then use the real output position.
1538         if (position != 0) {
1539           op_arrow->from_output_index_ = SizeToInt(position);
1540         }
1541         (void)from_actor->output_result_arrows_.emplace_back(op_arrow);
1542         if (kernel_type == KernelTransformType::kHostDataSourceActor) {
1543           auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
1544           MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
1545           UpdateRefCount(host_queue_ds_actor->data_nodes_[position], output_with_index.second, true);
1546         }
1547       }
1548     }
1549   }
1550 }
1551 
LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info,const ActorSet * actor_set)1552 void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
1553                                                          const ActorSet *actor_set) {
1554   const auto &to_actor = actor_set->output_actor_;
1555   const auto &loop_count_actor = actor_set->loop_count_actor_;
1556   if (to_actor == nullptr || loop_count_actor == nullptr) {
1557     return;
1558   }
1559 
1560   const auto &switch_actors = actor_set->switch_actors_;
1561   for (const auto &from_actor : switch_actors) {
1562     MS_EXCEPTION_IF_NULL(from_actor);
1563     auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
1564     const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
1565     if (iter == graph_compiler_info.origin_outputs_order_.end()) {
1566       continue;
1567     }
1568 
1569     // If the switch actor is in the output list, the output of switch actor should be sent to the output actor.
1570     // And need to link a control arrow to the loop count actor.
1571     for (const auto pos : iter->second) {
1572       to_actor->device_contexts_[pos] = from_actor->device_context_;
1573     }
1574 
1575     for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
1576       const auto &input_pos = from_actor->branch_inputs_pos_[i];
1577       if (input_pos.empty()) {
1578         MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
1579       }
1580 
1581       for (const auto pos : iter->second) {
1582         auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
1583         (void)from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
1584       }
1585 
1586       (void)from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
1587     }
1588     loop_count_actor->input_controls_num_++;
1589   }
1590 }
1591 
LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor * > & auto_monad_actors)1592 void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
1593   const size_t kNeedUpdateDeviceTensorStoreNum = 2;
1594   for (auto &kernel_actor : auto_monad_actors) {
1595     MS_EXCEPTION_IF_NULL(kernel_actor);
1596     for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
1597       auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
1598       if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
1599         continue;
1600       }
1601 
1602       // Create the copy actor.
1603       std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
1604                          "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
1605       if (FetchActor(name) != nullptr) {
1606         continue;
1607       }
1608       auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
1609       MS_EXCEPTION_IF_NULL(copy_actor);
1610       (void)copy_actors_.emplace_back(copy_actor);
1611       InsertActor(copy_actor.get());
1612 
1613       // Set the member of the copy actor.
1614       (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
1615       auto input_device_context = kernel_actor->device_contexts_[0];
1616       (void)copy_actor->device_contexts_.emplace_back(input_device_context);
1617       auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
1618                                      ? device_tensors[1]
1619                                      : device_tensors[0];
1620       MS_EXCEPTION_IF_NULL(another_device_tensor);
1621       auto another_device_type = another_device_tensor->DeviceType();
1622       const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1623         {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
1624       MS_EXCEPTION_IF_NULL(another_device_context);
1625       (void)copy_actor->device_contexts_.emplace_back(another_device_context);
1626 
1627       MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
1628                    << "has control arrows number:" << kernel_actor->output_control_arrows_.size();
1629       // Link from copy actor to kernel actor users.
1630       for (auto &output_contorl : kernel_actor->output_control_arrows_) {
1631         (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
1632       }
1633       // Move the control arrows from kernel actor to kernel actor users.
1634       kernel_actor->output_control_arrows_.clear();
1635 
1636       // Link from kernel actor to copy actor.
1637       (void)kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
1638       copy_actor->input_controls_num_++;
1639     }
1640   }
1641 }
1642 
PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> & control_nodes)1643 void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
1644   for (const auto &node : control_nodes) {
1645     CNodePtr cnode = node->cast<CNodePtr>();
1646     auto inputs = cnode->inputs();
1647     // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
1648     if (inputs[0]->isa<CNode>()) {
1649       auto actor = FetchActor(inputs[0]->DebugString());
1650       MS_EXCEPTION_IF_NULL(actor);
1651       auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1652       MS_EXCEPTION_IF_NULL(switch_actor);
1653 
1654       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1655         if (HasAbstractMonad(inputs[i])) {
1656           continue;
1657         }
1658         switch_actor->AddCommonInput(inputs[i]);
1659       }
1660     }
1661   }
1662 }
1663 
LinkArrowByControlNode(const GraphCompilerInfo & graph_compiler_info,ActorSet * const actor_set)1664 void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
1665   PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
1666 
1667   for (const auto &node : graph_compiler_info.control_nodes_) {
1668     CNodePtr cnode = node->cast<CNodePtr>();
1669     const auto &from_func_graph = node->func_graph();
1670     auto inputs = cnode->inputs();
1671     // Link data arrow for switch node.
1672     if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
1673         AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
1674       auto actor = actor_name_to_actor_[node->DebugString()];
1675       MS_EXCEPTION_IF_NULL(actor);
1676       auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1677       MS_EXCEPTION_IF_NULL(switch_actor);
1678       LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
1679     } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1680       // Link the data arrow for the input of the call node.
1681       const auto &actor_name = node->DebugString();
1682       auto actor = FetchActor(actor_name);
1683       MS_EXCEPTION_IF_NULL(actor);
1684       auto gather_actor = dynamic_cast<GatherActor *>(actor);
1685       MS_EXCEPTION_IF_NULL(gather_actor);
1686 
1687       const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
1688       MS_EXCEPTION_IF_NULL(func_graph);
1689       const auto &to_actor = FetchActor(func_graph->ToString());
1690       MS_EXCEPTION_IF_NULL(to_actor);
1691 
1692       size_t persist_input_num = 0;
1693       for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1694         MS_EXCEPTION_IF_NULL(actor);
1695         if (inputs[i]->isa<ValueNode>()) {
1696           const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
1697           if (!node_value->isa<tensor::Tensor>()) {
1698             persist_input_num++;
1699             continue;
1700           }
1701 
1702           (void)gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
1703                                                                      inputs[i].get());
1704           gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
1705             graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
1706         } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
1707                    AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) {
1708           persist_input_num++;
1709           continue;
1710         } else {
1711           const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0);
1712           LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor,
1713                                      i - kCallInputStartPos - persist_input_num);
1714         }
1715 
1716         auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(),
1717                                                     i - kCallInputStartPos - persist_input_num);
1718         (void)gather_actor->output_data_arrows_.emplace_back(op_arrow);
1719       }
1720     }
1721   }
1722 
1723   // Link arrow for switch actor of subgraph output.
1724   for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
1725     const auto &func_graph = func_graph_with_call_num.first;
1726     MS_EXCEPTION_IF_NULL(func_graph);
1727     auto actor = FetchActor(func_graph->get_return()->DebugString());
1728     MS_EXCEPTION_IF_NULL(actor);
1729     auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1730     MS_EXCEPTION_IF_NULL(switch_actor);
1731     LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
1732   }
1733 
1734   // Link arrow for gather actor for call input kernel graph.
1735   for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) {
1736     const auto &kernel_graph = call_input_kernel_graph.first;
1737     MS_EXCEPTION_IF_NULL(kernel_graph);
1738     auto actor = FetchActor(kernel_graph->ToString());
1739     MS_EXCEPTION_IF_NULL(actor);
1740     auto gather_actor = dynamic_cast<GatherActor *>(actor);
1741     MS_EXCEPTION_IF_NULL(gather_actor);
1742 
1743     for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
1744       const auto &input_with_index = gather_actor->data_nodes_[i];
1745       const auto &from_func_graph = kernel_graph->GetFuncGraph();
1746       LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
1747     }
1748   }
1749   LinkBranchArrowForSwitchActor(graph_compiler_info);
1750 
1751   LinkBranchArrowForGatherActor(graph_compiler_info);
1752 
1753   LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
1754                                  graph_compiler_info.control_node_parser_);
1755 
1756   LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
1757                                  graph_compiler_info.origin_outputs_order_);
1758 
1759   LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
1760 }
1761 
LinkDataArrowForGatherActor(GatherActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & front_node_with_index,const KernelWithIndex & to_node_with_index)1762 void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
1763                                                  const KernelWithIndex &front_node_with_index,
1764                                                  const KernelWithIndex &to_node_with_index) {
1765   MS_EXCEPTION_IF_NULL(from_actor);
1766   MS_EXCEPTION_IF_NULL(to_actor);
1767   MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1768 
1769   auto position = from_actor->FetchDataNodePosition(front_node_with_index);
1770 
1771   auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
1772   (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1773   to_actor->input_datas_num_++;
1774 }
1775 
LinkDataArrowByCallInput(const KernelWithIndex & call_node_with_index,const ControlNodeParserPtr & parser,const FuncGraphPtr & from_func_graph,OpActor<DeviceTensor> * const to_actor,const size_t to_index)1776 void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index,
1777                                               const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph,
1778                                               OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
1779   // Fetch all the funcgraph that call node would call.
1780   const auto cnode = call_node_with_index.first->cast<CNodePtr>();
1781   std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
1782 
1783   // Collect the output of each funcgraph.
1784   for (const auto &func_graph : func_graphs) {
1785     const auto actor_name = func_graph->get_return()->DebugString();
1786     auto actor = FetchActor(actor_name);
1787     MS_EXCEPTION_IF_NULL(actor);
1788     auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1789     MS_EXCEPTION_IF_NULL(switch_actor);
1790     const size_t branch_index = switch_actor->branch_id_to_index_.size();
1791 
1792     const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_;
1793     const auto &iter = func_graph_to_branch_id.find(from_func_graph);
1794 
1795     int branch_id = kMainBranchID;
1796     if (iter != func_graph_to_branch_id.end()) {
1797       branch_id = iter->second;
1798     }
1799     if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) {
1800       LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index,
1801                                   switch_actor->branch_id_to_index_[branch_id]);
1802       continue;
1803     }
1804     LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index);
1805     switch_actor->branch_id_to_index_[branch_id] = branch_index;
1806   }
1807 }
1808 
LinkDataArrowForSwitchActor(SwitchActor * from_actor,const size_t from_index,OpActor<DeviceTensor> * to_actor,const size_t to_index,const size_t branch_index)1809 void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index,
1810                                                  OpActor<DeviceTensor> *to_actor, const size_t to_index,
1811                                                  const size_t branch_index) {
1812   MS_EXCEPTION_IF_NULL(from_actor);
1813   MS_EXCEPTION_IF_NULL(to_actor);
1814   size_t start_branch = 0;
1815   size_t max_branch = from_actor->output_branch_arrows_.size();
1816   if (branch_index != SIZE_MAX) {
1817     start_branch = branch_index;
1818     max_branch = branch_index + 1;
1819   }
1820   for (size_t i = start_branch; i < max_branch; ++i) {
1821     if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
1822       MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
1823                         << " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
1824                         << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
1825     }
1826     auto op_arrow =
1827       std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
1828     (void)from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
1829   }
1830 }
1831 
LinkDataArrowByControlNode(const GraphCompilerInfo & graph_compiler_info,const KernelWithIndex & input_with_index,const FuncGraphPtr & from_func_graph,OpActor<DeviceTensor> * const to_actor,const size_t to_index)1832 void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
1833                                                 const KernelWithIndex &input_with_index,
1834                                                 const FuncGraphPtr &from_func_graph,
1835                                                 OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
1836   const auto &parameters = graph_compiler_info.origin_parameters_order_;
1837   const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
1838   const auto &input_node = input_with_index.first;
1839 
1840   if (IsCallNode(input_node)) {
1841     // The actor input is a call node.
1842     LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor,
1843                              to_index);
1844   } else if (IsGatherActor(input_node, actor_name_to_actor_)) {
1845     // The actor input is a parameter in gather actor.
1846     auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
1847     auto position = from_actor->FetchDataNodePosition({input_node, 0});
1848     auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
1849     (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1850   } else if (IsSwitchActor(input_node)) {
1851     const auto &actor_name = input_node->DebugString();
1852     auto actor = FetchActor(actor_name);
1853     MS_EXCEPTION_IF_NULL(actor);
1854     LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index);
1855   } else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
1856     // The actor input is a cnode.
1857     if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
1858       const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
1859       const auto &backend_node =
1860         graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
1861       if (backend_node.first == nullptr) {
1862         MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
1863                           << " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
1864       }
1865       const auto &actor_name = backend_node.first->fullname_with_scope();
1866       const auto &actor = FetchActor(actor_name);
1867       MS_EXCEPTION_IF_NULL(actor);
1868       auto from_actor = dynamic_cast<KernelActor *>(actor);
1869       MS_EXCEPTION_IF_NULL(from_actor);
1870 
1871       auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
1872       (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1873       auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
1874       UpdateRefCount(device_tensor.get(), true);
1875       return;
1876     }
1877 
1878     auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
1879     auto from_actor = front_node_to_actor_[input_node];
1880     (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1881     auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false);
1882     UpdateRefCount(device_tensor.get(), true);
1883   } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
1884     // The actor input is a parameter in host data source actor.
1885     std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
1886 
1887     auto actor = FetchActor(actor_name);
1888     MS_EXCEPTION_IF_NULL(actor);
1889     auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
1890     MS_EXCEPTION_IF_NULL(from_actor);
1891 
1892     auto backend_iter = front_to_backend_parameter.find(input_node);
1893     if (backend_iter == front_to_backend_parameter.end()) {
1894       MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(input_node);
1895     }
1896 
1897     const auto &backend_node = backend_iter->second.first;
1898     auto iter = from_actor->data_node_position_map_.find(input_node);
1899     if (iter == from_actor->data_node_position_map_.end()) {
1900       MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
1901                         << AnfAlgo::GetNodeDebugString(backend_node)
1902                         << " front node:" << AnfAlgo::GetNodeDebugString(input_node);
1903     }
1904 
1905     auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
1906     (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1907     auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
1908     UpdateRefCount(device_tensor.get(), true);
1909   } else {
1910     MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
1911                       << " to actor:" << to_actor->GetAID();
1912   }
1913 }
1914 
LinkDataArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info,SwitchActor * const actor)1915 void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
1916                                                  SwitchActor *const actor) {
1917   // Link switch input.
1918   const auto &inputs = actor->input_nodes_;
1919   for (size_t i = 0; i < inputs.size(); ++i) {
1920     auto input = inputs[i];
1921     if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
1922       continue;
1923     }
1924 
1925     const FuncGraphPtr from_func_graph = actor->node_->func_graph();
1926     LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i);
1927   }
1928 
1929   // Link switch output.
1930   for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
1931     auto func_graph = actor->branch_func_graph_[i];
1932     if (func_graph == nullptr) {
1933       continue;
1934     }
1935 
1936     auto gather_name = func_graph->ToString();
1937     if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
1938       MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
1939                         << ",switch input size:" << actor->input_nodes_.size();
1940     }
1941     auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
1942     for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
1943       auto pos = actor->branch_inputs_pos_[i][j];
1944       auto to_actor_index = j;
1945       auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_actor_index);
1946       (void)actor->output_branch_arrows_[i].emplace_back(op_arrow);
1947     }
1948   }
1949 }
1950 
LinkControlArrowForGatherActor(std::vector<KernelActorPtr> * const kernel_actors,const std::vector<KernelGraphPtr> & graphs,const ControlNodeParserPtr & parser)1951 void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
1952                                                     const std::vector<KernelGraphPtr> &graphs,
1953                                                     const ControlNodeParserPtr &parser) {
1954   // Link control arrow to kernel actor.
1955   for (size_t i = 0; i < graphs.size(); ++i) {
1956     const auto &kernel_graph = graphs[i];
1957     MS_EXCEPTION_IF_NULL(kernel_graph);
1958     const auto &func_graph = kernel_graph->GetFuncGraph();
1959     if (func_graph == nullptr) {
1960       continue;
1961     }
1962     const auto &actor = FetchActor(func_graph->ToString());
1963     if (actor == nullptr) {
1964       continue;
1965     }
1966     const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1967     MS_EXCEPTION_IF_NULL(gather_actor);
1968 
1969     // When gather actor is not empty, it means the control arrow of no input kernel actor needs to be sent by gather.
1970     for (const auto &kernel : kernel_graph->execution_order()) {
1971       if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
1972         const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
1973         MS_EXCEPTION_IF_NULL(kernel_actor);
1974 
1975         if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
1976           (void)gather_actor->output_control_arrows_.emplace_back(kernel_actor->GetAID());
1977           kernel_actor->input_controls_num_ = 1;
1978         }
1979       }
1980     }
1981   }
1982 
1983   for (auto &kernel_actor : *kernel_actors) {
1984     MS_EXCEPTION_IF_NULL(kernel_actor);
1985 
1986     if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1987         !parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
1988       // Check whether the kernel actor belongs to the root graph.
1989       // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output
1990       // should be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be
1991       // sent to the output switch actor.
1992       const auto &graph = kernel_actor->kernel_->func_graph();
1993       OpActor<DeviceTensor> *actor = nullptr;
1994 
1995       if (graph != nullptr) {
1996         const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
1997         const auto func_graph = kernel_graph->GetFuncGraph();
1998         if (func_graph != nullptr) {
1999           actor = FetchActor(func_graph->get_return()->DebugString());
2000           if (actor != nullptr) {
2001             auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2002             MS_EXCEPTION_IF_NULL(switch_actor);
2003 
2004             (void)kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
2005             switch_actor->input_controls_num_++;
2006           }
2007         }
2008       }
2009     }
2010   }
2011 
2012   // Link input auto monad control arrow from kernel actor to gather actor.
2013   const auto &monad_nodes = parser->kernel_to_call_nodes_;
2014   for (const auto node_pair : monad_nodes) {
2015     const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
2016     const auto &gather_actor_name = node_pair.second->DebugString();
2017     auto kernel_op_actor = FetchActor(kernel_actor_name);
2018     auto gather_op_actor = FetchActor(gather_actor_name);
2019     if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
2020       continue;
2021     }
2022     auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
2023     auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
2024     (void)kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
2025     gather_actor->input_controls_num_++;
2026   }
2027 }
2028 
LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> * const switch_actors,LoopCountActor * const to_actor,const KernelMapPosition & origin_outputs_order)2029 void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors,
2030                                                     LoopCountActor *const to_actor,
2031                                                     const KernelMapPosition &origin_outputs_order) {
2032   if (to_actor == nullptr || (*switch_actors).empty()) {
2033     return;
2034   }
2035 
2036   // If there is no output from the switch actor branch, it means that the subgraph has no input,
2037   // and need to connect a control arrow to the corresponding gather actor.
2038   for (auto &switch_actor : (*switch_actors)) {
2039     if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
2040       const auto &func_graph = switch_actor->node_->func_graph();
2041       if (func_graph->output()->isa<ValueNode>()) {
2042         const auto &actor_name = func_graph->ToString();
2043         auto actor = FetchActor(actor_name);
2044         MS_EXCEPTION_IF_NULL(actor);
2045         auto gather_actor = dynamic_cast<GatherActor *>(actor);
2046         MS_EXCEPTION_IF_NULL(gather_actor);
2047         (void)gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
2048         switch_actor->input_controls_num_++;
2049       }
2050     }
2051 
2052     for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
2053       const auto &arrows = switch_actor->output_branch_arrows_[i];
2054       if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
2055         const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
2056         const auto &actor = FetchActor(actor_name);
2057         if (actor != nullptr) {
2058           const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
2059           MS_EXCEPTION_IF_NULL(gather_actor);
2060           (void)switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
2061           gather_actor->input_controls_num_++;
2062         }
2063       }
2064     }
2065   }
2066 
2067   // Collect all the call node in outputs.
2068   std::set<AnfNodePtr> call_nodes;
2069   for (const auto &output : origin_outputs_order) {
2070     if (IsCallNode(output.first.first)) {
2071       (void)call_nodes.insert(output.first.first);
2072     }
2073   }
2074   to_actor->input_controls_num_ += call_nodes.size();
2075 
2076   // Link the output switch actor of the subgraph to the output actor.
2077   for (const auto &call_node : call_nodes) {
2078     const auto &func_graphs = FetchFuncGraphbyCallNode(call_node);
2079     for (const auto func_graph : func_graphs) {
2080       MS_EXCEPTION_IF_NULL(func_graph);
2081       const auto &actor_name = func_graph->get_return()->DebugString();
2082       auto actor = FetchActor(actor_name);
2083       MS_EXCEPTION_IF_NULL(actor);
2084       auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2085       MS_EXCEPTION_IF_NULL(switch_actor);
2086 
2087       size_t branch_index = switch_actor->branch_id_to_index_.size();
2088       if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
2089         branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
2090       } else {
2091         switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
2092       }
2093 
2094       (void)switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
2095     }
2096   }
2097 }
2098 
LinkBranchArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info)2099 void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
2100   for (const auto &control_node : graph_compiler_info.control_nodes_) {
2101     if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
2102         AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
2103       const auto &actor_name = control_node->DebugString();
2104       auto actor = FetchActor(actor_name);
2105       MS_EXCEPTION_IF_NULL(actor);
2106       auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2107       MS_EXCEPTION_IF_NULL(switch_actor);
2108 
2109       for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
2110         const auto &func_graph = switch_actor->branch_func_graph_[i];
2111         if (func_graph == nullptr) {
2112           continue;
2113         }
2114 
2115         const auto &gather_actor = FetchActor(func_graph->ToString());
2116         MS_EXCEPTION_IF_NULL(gather_actor);
2117         (void)switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
2118       }
2119     }
2120   }
2121 }
2122 
LinkBranchArrowForGatherActor(const GraphCompilerInfo & graph_compiler_info)2123 void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info) {
2124   if (graph_compiler_info.control_nodes_.empty()) {
2125     return;
2126   }
2127 
2128   // Link branch arrow from gather actor to gather actor.
2129   for (const auto &control_node : graph_compiler_info.control_nodes_) {
2130     const auto &cnode = control_node->cast<CNodePtr>();
2131     const auto &inputs = cnode->inputs();
2132     if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
2133       const auto &actor_name = control_node->DebugString();
2134       auto actor = FetchActor(actor_name);
2135       MS_EXCEPTION_IF_NULL(actor);
2136       auto gather_actor = dynamic_cast<GatherActor *>(actor);
2137       MS_EXCEPTION_IF_NULL(gather_actor);
2138       (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_);
2139     }
2140   }
2141 
2142   // Link branch arrow from gather actor to switch actor.
2143   for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
2144     const auto &actor_name = func_graph_with_call_num.first->ToString();
2145     auto actor = FetchActor(actor_name);
2146     MS_EXCEPTION_IF_NULL(actor);
2147     auto gather_actor = dynamic_cast<GatherActor *>(actor);
2148     MS_EXCEPTION_IF_NULL(gather_actor);
2149     (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
2150   }
2151 }
2152 
CheckActorValid(const ActorSet * actor_set,GraphExecutionStrategy strategy) const2153 bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
2154   MS_EXCEPTION_IF_NULL(actor_set);
2155   // Check the data source actors.
2156   for (const auto &data_source_actor : actor_set->data_source_actors_) {
2157     MS_EXCEPTION_IF_NULL(data_source_actor);
2158     if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
2159           data_source_actor->output_control_arrows_.size() ==
2160         0) {
2161       MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
2162       return false;
2163     }
2164   }
2165 
2166   if (strategy == GraphExecutionStrategy::kStep) {
2167     return true;
2168   }
2169 
2170   // Check the kernel actors.
2171   for (const auto &kernel_actor : actor_set->kernel_actors_) {
2172     MS_EXCEPTION_IF_NULL(kernel_actor);
2173     if (kernel_actor->output_data_arrows_.size() + kernel_actor->output_control_arrows_.size() == 0) {
2174       MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user.";
2175       return false;
2176     }
2177 
2178     auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
2179     auto input_data_num = kernel_actor->input_datas_num_;
2180     auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size();
2181     if (input_data_num + device_tensor_store_num != input_num) {
2182       MS_LOG(ERROR) << "The input building of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_)
2183                     << " is wrong, input data num: " << input_data_num
2184                     << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num;
2185       return false;
2186     }
2187   }
2188 
2189   // Check the copy actors.
2190   for (const auto &copy_actor : actor_set->copy_actors_) {
2191     MS_EXCEPTION_IF_NULL(copy_actor);
2192     if (copy_actor->output_data_arrows_.size() + copy_actor->output_control_arrows_.size() == 0) {
2193       MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
2194       return false;
2195     }
2196 
2197     const size_t kCopyActorInputDataNum = 1;
2198     auto input_data_num = copy_actor->input_datas_num_;
2199     size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
2200     if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
2201       MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
2202                     << " is wrong, input data num: " << input_data_num
2203                     << ", device tensor store num: " << device_tensor_store_num
2204                     << ", total input num: " << kCopyActorInputDataNum;
2205       return false;
2206     }
2207   }
2208 
2209   // Check the loop count actor.
2210   const auto &loop_count_actor = actor_set->loop_count_actor_;
2211   if ((loop_count_actor != nullptr) &&
2212       (actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) {
2213     if (loop_count_actor->input_controls_num_ == 0) {
2214       MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
2215       return false;
2216     }
2217   }
2218 
2219   return true;
2220 }
2221 
PersistDeviceTensor(const GraphCompilerInfo & graph_compiler_info)2222 void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
2223   for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
2224     const auto &graph = graph_compiler_info.graphs_[i];
2225     const auto &device_context = graph_compiler_info.device_contexts_[i];
2226     MS_EXCEPTION_IF_NULL(graph);
2227     MS_EXCEPTION_IF_NULL(device_context);
2228 
2229     for (auto &value_node : graph->graph_value_nodes()) {
2230       MS_EXCEPTION_IF_NULL(value_node);
2231       if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
2232         MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
2233         continue;
2234       }
2235       auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
2236       const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
2237       DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
2238       UpdateRefCount(device_tensor.get(), true);
2239     }
2240 
2241     for (auto &input_node : graph->input_nodes()) {
2242       MS_EXCEPTION_IF_NULL(input_node);
2243       AnfNodePtr sub_front_node = nullptr;
2244       if (IsInternalParameter(input_node, graph)) {
2245         auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
2246         sub_front_node = front_output_with_index.first;
2247       } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
2248         sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
2249       }
2250       if (sub_front_node == nullptr) {
2251         continue;
2252       }
2253 
2254       // The sub front nodes share the device tensor store with the root front node.
2255       MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
2256       auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
2257       MS_EXCEPTION_IF_NULL(front_node);
2258       MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
2259                     << ", root front node:" << front_node->DebugString();
2260 
2261       auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
2262       MS_EXCEPTION_IF_NULL(device_tensor);
2263       if (IsPersistentDeviceTensor(input_node)) {
2264         DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
2265         UpdateRefCount(device_tensor.get(), true);
2266       }
2267 
2268       // Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
2269       if (!IsPersistentDeviceTensor(front_node)) {
2270         continue;
2271       }
2272       // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
2273       if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
2274         MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
2275                      << ", type:" << device_context->GetDeviceAddressType();
2276         auto other_type_device_tensor = device_context->CreateDeviceAddress(
2277           nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
2278         DeviceTensorStore::GetInstance().Insert(front_node.get(), other_type_device_tensor);
2279         UpdateRefCount(other_type_device_tensor.get(), true);
2280       }
2281     }
2282   }
2283 
2284   // In control flow, there may be some value nodes that is not in the kernel graph and needs to be placed
2285   // in the tensor store separately.
2286   for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
2287     MS_EXCEPTION_IF_NULL(value_node.first);
2288     auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
2289     DeviceTensorStore::GetInstance().Insert(value_node.first.get(), device_tensor);
2290     UpdateRefCount(device_tensor.get(), true);
2291   }
2292 }
2293 
FetchKernelTransformTypeAndName(const AnfNodePtr & node,const KernelGraphPtr & graph,const GraphCompilerInfo & graph_compiler_info,KernelTransformType * const kernel_type,std::string * const kernel_name)2294 void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
2295                                                      const GraphCompilerInfo &graph_compiler_info,
2296                                                      KernelTransformType *const kernel_type,
2297                                                      std::string *const kernel_name) {
2298   MS_EXCEPTION_IF_NULL(node);
2299   MS_EXCEPTION_IF_NULL(graph);
2300   MS_EXCEPTION_IF_NULL(kernel_type);
2301   MS_EXCEPTION_IF_NULL(kernel_name);
2302 
2303   if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) {
2304     *kernel_type = KernelTransformType::kDeviceDataSourceActor;
2305     *kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
2306   } else if (IsHostQueueDSActor(node, graph, graph_compiler_info.origin_parameters_order_,
2307                                 graph_compiler_info.strategy_)) {
2308     *kernel_type = KernelTransformType::kHostDataSourceActor;
2309     *kernel_name = graph_compiler_info.name_ + "_HostDSActor";
2310   } else if (IsKernelActor(node, graph_compiler_info.strategy_)) {
2311     *kernel_type = KernelTransformType::kKernelActor;
2312     *kernel_name = node->fullname_with_scope();
2313   } else if (IsInternalParameter(node, graph)) {
2314     *kernel_type = KernelTransformType::kInternalParameter;
2315     *kernel_name = "";
2316   } else if (IsPersistentDeviceTensor(node)) {
2317     *kernel_type = KernelTransformType::kDeviceTensorStore;
2318     *kernel_name = "";
2319   } else {
2320     // May exist the from kernel that no need link in the pynative mode.
2321     MS_LOG(DEBUG) << "Invalid from kernel: " << node->fullname_with_scope();
2322     *kernel_type = KernelTransformType::kUnknown;
2323     *kernel_name = "";
2324   }
2325 }
2326 
InsertActor(OpActor<DeviceTensor> * actor)2327 void GraphScheduler::InsertActor(OpActor<DeviceTensor> *actor) {
2328   MS_EXCEPTION_IF_NULL(actor);
2329   if (actor_name_to_actor_.count(actor->GetAID().Name()) > 0) {
2330     MS_LOG(EXCEPTION) << "The actor already exists: " << actor->GetAID().Name();
2331   }
2332   actor_name_to_actor_[actor->GetAID().Name()] = actor;
2333 }
2334 
FetchActor(const std::string & actor_name) const2335 OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string &actor_name) const {
2336   const auto &iter = actor_name_to_actor_.find(actor_name);
2337   if (iter == actor_name_to_actor_.end()) {
2338     return nullptr;
2339   }
2340   return iter->second;
2341 }
2342 
DumpActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const2343 void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
2344   MS_EXCEPTION_IF_NULL(actor_set);
2345   const auto &context_ptr = MsContext::GetInstance();
2346   MS_EXCEPTION_IF_NULL(context_ptr);
2347   auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
2348   if (!save_graphs) {
2349     return;
2350   }
2351 
2352   std::string filename = GetSaveGraphsPathName("actor_set_" + actor_set->name_ + ".ir");
2353   std::ofstream ofs(filename);
2354   if (!ofs.is_open()) {
2355     MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
2356     return;
2357   }
2358 
2359   ofs << "[Device tensor stores]\n";
2360   DumpDeviceTensorStore(graph_compiler_info, ofs);
2361 
2362   const auto &data_prepare_actor = actor_set->data_prepare_actor_;
2363   ofs << "\n\n[Data prepare actor:" << (data_prepare_actor != nullptr ? 1 : 0) << "]\n";
2364   if (data_prepare_actor != nullptr) {
2365     DumpDataPrepareActor(data_prepare_actor.get(), ofs);
2366   }
2367 
2368   ofs << "\n\n[Data source actors:" << actor_set->data_source_actors_.size() << "]\n";
2369   for (const auto &data_source_actor : actor_set->data_source_actors_) {
2370     DumpDSActor(data_source_actor.get(), ofs);
2371   }
2372 
2373   ofs << "\n\n[Kernel actors:" << actor_set->kernel_actors_.size() << "]\n";
2374   for (const auto &kernel_actor : actor_set->kernel_actors_) {
2375     DumpKernelActor(kernel_actor.get(), ofs);
2376   }
2377 
2378   ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
2379   for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
2380     DumpKernelActor(no_input_kernel_actor.get(), ofs);
2381   }
2382 
2383   ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
2384   for (const auto &copy_actor : actor_set->copy_actors_) {
2385     DumpCopyActor(copy_actor.get(), ofs);
2386   }
2387 
2388   ofs << "\n\n[Gather actors:" << actor_set->gather_actors_.size() << "]\n";
2389   for (const auto &gather_actor : actor_set->gather_actors_) {
2390     DumpGatherActor(gather_actor.get(), ofs);
2391   }
2392 
2393   ofs << "\n\n[Switch actors:" << actor_set->switch_actors_.size() << "]\n";
2394   for (const auto &switch_actor : actor_set->switch_actors_) {
2395     DumpSwitchActor(switch_actor.get(), ofs);
2396   }
2397 
2398   const auto &loop_count_actor = actor_set->loop_count_actor_;
2399   ofs << "\n\n[Loop count actor:" << (loop_count_actor != nullptr ? 1 : 0) << "]\n";
2400   if (loop_count_actor != nullptr) {
2401     DumpLoopCountActor(loop_count_actor.get(), ofs);
2402   }
2403 
2404   const auto &output_actor = actor_set->output_actor_;
2405   ofs << "\n\n[Output actor:" << (output_actor != nullptr ? 1 : 0) << "]\n";
2406   if (output_actor != nullptr) {
2407     DumpOutputActor(output_actor.get(), ofs);
2408   }
2409 }
2410 
DumpAbstractActor(const AbstractActor * actor,std::ofstream & ofs) const2411 void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
2412   MS_EXCEPTION_IF_NULL(actor);
2413   ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
2414       << "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
2415       << "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
2416       << "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
2417   ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
2418       << "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
2419       << "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";
2420 
2421   if (actor->device_contexts_.size() > 0) {
2422     ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
2423     for (const auto &device_context : actor->device_contexts_) {
2424       if (device_context == nullptr) {
2425         ofs << "\t\t\tdevice_context:" << device_context << "\n";
2426         continue;
2427       }
2428       ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
2429     }
2430   }
2431 
2432   if (actor->device_tensor_store_keys_.size() > 0) {
2433     ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
2434     for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
2435       MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
2436       ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
2437           << "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
2438     }
2439   }
2440 
2441   if (actor->input_data_arrow_aids_.size() > 0) {
2442     ofs << "\t\tinput_data_arrow_actors:" << actor->input_data_arrow_aids_.size() << "\n ";
2443     for (const auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
2444       ofs << "\t\t\tfrom_actor_name:" << input_data_arrow_aid.Name() << "\n";
2445     }
2446   }
2447 
2448   if (actor->input_control_arrow_aids_.size() > 0) {
2449     ofs << "\t\tinput_control_arrow_actors:" << actor->input_control_arrow_aids_.size() << "\n ";
2450     for (const auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
2451       ofs << "\t\t\tfrom_actor_name:" << input_control_arrow_aid.Name() << "\n";
2452     }
2453   }
2454 
2455   const auto &output_data_arrows = actor->output_data_arrows();
2456   if (output_data_arrows.size() > 0) {
2457     ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
2458     for (const auto &data_arrow : output_data_arrows) {
2459       MS_EXCEPTION_IF_NULL(data_arrow);
2460       ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
2461           << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
2462           << "\n";
2463     }
2464   }
2465 
2466   const auto &output_control_arrows = actor->output_control_arrows();
2467   if (output_control_arrows.size() > 0) {
2468     ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
2469     for (const auto &aid : output_control_arrows) {
2470       ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2471     }
2472   }
2473 
2474   if (actor->output_result_arrows_.size() > 0) {
2475     ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
2476     for (const auto &result_arrow : actor->output_result_arrows_) {
2477       MS_EXCEPTION_IF_NULL(result_arrow);
2478       ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
2479           << "\tto_actor_name:" << result_arrow->to_op_id_.Name()
2480           << "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
2481     }
2482   }
2483 }
2484 
DumpDataPrepareActor(const DataPrepareActor * actor,std::ofstream & ofs) const2485 void GraphScheduler::DumpDataPrepareActor(const DataPrepareActor *actor, std::ofstream &ofs) const {
2486   MS_EXCEPTION_IF_NULL(actor);
2487   ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2488   DumpAbstractActor(actor, ofs);
2489 
2490   ofs << "\t\toutput_control_arrows:" << actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() << "\n ";
2491   for (const auto &aid : actor->data_source_aids_) {
2492     ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2493   }
2494   for (const auto &aid : actor->no_input_kernel_aids_) {
2495     ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2496   }
2497 
2498   ofs << "\t\tcontinuous_memory_nodes:" << actor->continuous_memory_nodes_.size() << "\n ";
2499   for (const auto &iter : actor->continuous_memory_nodes_) {
2500     MS_EXCEPTION_IF_NULL(iter.first.first);
2501     MS_EXCEPTION_IF_NULL(iter.first.second);
2502     ofs << "\t\t\tnode_name:" << iter.first.first->fullname_with_scope()
2503         << "\tdevice_context:" << iter.first.second->device_context_key().ToString()
2504         << "\tis_input_need:" << iter.second.first << "\tis_output_need:" << iter.second.second << "\n";
2505   }
2506 }
2507 
DumpDSActor(const DataSourceActor * actor,std::ofstream & ofs) const2508 void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
2509   MS_EXCEPTION_IF_NULL(actor);
2510   const auto &actor_name = actor->GetAID().Name();
2511   ofs << "\tactor_name:" << actor_name << "\n";
2512 
2513   if (actor->type_ == KernelTransformType::kDeviceDataSourceActor) {
2514     // Dump the member info of device queue data source actor.
2515     const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
2516     MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
2517     const auto &data_kernel = device_queue_ds_actor->data_kernel_;
2518     MS_EXCEPTION_IF_NULL(data_kernel);
2519     ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
2520         << "\tinput_number:" << AnfAlgo::GetInputTensorNum(data_kernel)
2521         << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(data_kernel) << "\n";
2522     for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel); ++i) {
2523       const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false);
2524       MS_EXCEPTION_IF_NULL(device_tensor);
2525       ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2526           << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2527     }
2528   } else if (actor->type_ == KernelTransformType::kHostDataSourceActor) {
2529     // Dump the member info of host queue data source actor.
2530     const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
2531     MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
2532     ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
2533     for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
2534       const auto &data_node = host_queue_ds_actor->data_nodes_[i];
2535       MS_EXCEPTION_IF_NULL(data_node);
2536       const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
2537       MS_EXCEPTION_IF_NULL(device_tensor);
2538       ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
2539           << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2540           << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n";
2541     }
2542   }
2543 
2544   DumpAbstractActor(actor, ofs);
2545   ofs << "\n";
2546 }
2547 
DumpLoopCountActor(const LoopCountActor * actor,std::ofstream & ofs) const2548 void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
2549   MS_EXCEPTION_IF_NULL(actor);
2550   ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\n";
2551   DumpAbstractActor(actor, ofs);
2552 
2553   const size_t kOutputControlArrowsNum = 2;
2554   ofs << "\t\toutput_control_arrows:" << kOutputControlArrowsNum << "\n ";
2555   ofs << "\t\t\tto_actor_name:" << actor->output_aid_.Name() << "\n";
2556   ofs << "\t\t\tto_actor_name:" << actor->data_prepare_aid_.Name() << "\n";
2557 }
2558 
DumpKernelActor(const KernelActor * actor,std::ofstream & ofs) const2559 void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
2560   MS_EXCEPTION_IF_NULL(actor);
2561   ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2562 
2563   const auto &kernel = actor->kernel_;
2564   MS_EXCEPTION_IF_NULL(kernel);
2565   ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << AnfAlgo::GetInputTensorNum(kernel)
2566       << "\toutputs_num:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
2567   for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
2568     const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
2569     MS_EXCEPTION_IF_NULL(device_tensor);
2570     ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2571         << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2572   }
2573 
2574   DumpAbstractActor(actor, ofs);
2575   ofs << "\n";
2576 }
2577 
DumpOutputActor(const OutputActor * actor,std::ofstream & ofs) const2578 void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const {
2579   MS_EXCEPTION_IF_NULL(actor);
2580   ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
2581       << "\toutputs_num:" << actor->outputs_num_ << "\n";
2582   DumpAbstractActor(actor, ofs);
2583 }
2584 
DumpCopyActor(const CopyActor * actor,std::ofstream & ofs) const2585 void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
2586   MS_EXCEPTION_IF_NULL(actor);
2587   ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2588 
2589   auto device_tensor = actor->output_;
2590   if (device_tensor != nullptr) {
2591     ofs << "\t\toutput_index:" << 0 << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2592         << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2593   }
2594 
2595   DumpAbstractActor(actor, ofs);
2596   ofs << "\n";
2597 }
2598 
DumpDeviceTensorStore(const GraphCompilerInfo & graph_compiler_info,std::ofstream & ofs) const2599 void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
2600   for (const auto &graph : graph_compiler_info.graphs_) {
2601     MS_EXCEPTION_IF_NULL(graph);
2602     ofs << "\tgraph id:" << graph->graph_id() << "\n";
2603 
2604     for (auto &value_node : graph->graph_value_nodes()) {
2605       MS_EXCEPTION_IF_NULL(value_node);
2606       if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
2607         continue;
2608       }
2609       const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
2610       MS_EXCEPTION_IF_NULL(front_node);
2611       const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
2612       ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
2613           << "\n";
2614       for (const auto &device_tensor : device_tensors) {
2615         MS_EXCEPTION_IF_NULL(device_tensor);
2616         ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
2617             << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
2618             << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
2619       }
2620     }
2621 
2622     for (auto &input_node : graph->input_nodes()) {
2623       MS_EXCEPTION_IF_NULL(input_node);
2624       if (!IsPersistentDeviceTensor(input_node)) {
2625         continue;
2626       }
2627       const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
2628       // The sub front nodes share the device tensor store with the root front node.
2629       auto front_node = sub_front_node;
2630       if (graph_compiler_info.control_node_parser_ != nullptr) {
2631         front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
2632       }
2633       const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
2634       MS_EXCEPTION_IF_NULL(front_node);
2635       ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
2636           << "\n";
2637       for (const auto &device_tensor : device_tensors) {
2638         MS_EXCEPTION_IF_NULL(device_tensor);
2639         ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
2640             << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
2641             << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
2642       }
2643     }
2644     ofs << "\n";
2645   }
2646 }
2647 
DumpGatherActor(const GatherActor * actor,std::ofstream & ofs) const2648 void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
2649   MS_EXCEPTION_IF_NULL(actor);
2650   ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
2651 
2652   ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
2653   for (const auto &node : actor->data_nodes_) {
2654     ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
2655   }
2656 
2657   ofs << "\t\tactor front to backend node:\n";
2658   for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
2659     ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
2660     for (const auto node_with_index : front_to_backend_parameter.second) {
2661       ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
2662           << "\tindex:" << node_with_index.second << '\n';
2663     }
2664   }
2665 
2666   ofs << "\t\tactor output data arrow:\n";
2667   for (const auto &data_arrow : actor->output_data_arrows_) {
2668     MS_EXCEPTION_IF_NULL(data_arrow);
2669     ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
2670         << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
2671         << "\n";
2672   }
2673 
2674   ofs << "\t\tactor output result arrow:\n";
2675   for (const auto &result_arrow : actor->output_result_arrows_) {
2676     MS_EXCEPTION_IF_NULL(result_arrow);
2677     ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
2678         << "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
2679         << "\n";
2680   }
2681 
2682   ofs << "\t\tactor output control arrow:\n";
2683   for (const auto &control_arrow : actor->output_control_arrows_) {
2684     ofs << "\t\t\tto_actor_name:" << control_arrow;
2685   }
2686   ofs << "\n";
2687 }
2688 
DumpSwitchActor(const SwitchActor * actor,std::ofstream & ofs) const2689 void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
2690   MS_EXCEPTION_IF_NULL(actor);
2691   ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
2692 
2693   ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
2694   for (const auto &node : actor->input_nodes_) {
2695     ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n';
2696   }
2697 
2698   ofs << "\t\tactor input pos:\n";
2699   for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
2700     ofs << "\t\t\tbranch " << i << " input pos:";
2701     for (const auto pos : actor->branch_inputs_pos_[i]) {
2702       ofs << pos << '\t';
2703     }
2704     ofs << '\n';
2705   }
2706 
2707   ofs << "\t\tactor output data arrow:\n";
2708   for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
2709     ofs << "\t\t\tbranch " << i << " output data:\n";
2710     for (const auto arrow : actor->output_branch_arrows_[i]) {
2711       MS_EXCEPTION_IF_NULL(arrow);
2712       ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
2713           << "\tto_input_index:" << arrow->to_input_index_ << '\n';
2714     }
2715   }
2716 
2717   ofs << "\t\tactor output result arrow:\n";
2718   for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
2719     ofs << "\t\t\tbranch " << i << " output result:\n";
2720     for (const auto arrow : actor->output_branch_result_arrows_[i]) {
2721       MS_EXCEPTION_IF_NULL(arrow);
2722       ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
2723           << "\tto_input_index:" << arrow->to_input_index_ << '\n';
2724     }
2725   }
2726 
2727   ofs << "\t\tactor output control arrow:\n";
2728   for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
2729     ofs << "\t\t\tbranch " << i << " output control:\n";
2730     for (const auto arrow : actor->output_branch_control_arrows_[i]) {
2731       ofs << "\t\t\t\t from index:" << arrow << '\n';
2732     }
2733   }
2734   ofs << "\n";
2735 }
2736 }  // namespace runtime
2737 }  // namespace mindspore
2738