1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/framework/graph_scheduler.h"

#include <algorithm>
#include <exception>

#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "runtime/framework/actor/recorder_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "mindrt/src/actor/actormgr.h"
#include "mindrt/include/async/async.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
#include "utils/config_manager.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "utils/signal_util.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/recorder_manager.h"
#endif
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#include "profiler/device/profiling.h"
#include "debug/common.h"
44
45 namespace mindspore {
46 namespace runtime {
47 namespace {
IsNeedInsertCopyActor(const DeviceContext * from_device_context,const DeviceContext * to_device_context)48 bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
49 MS_EXCEPTION_IF_NULL(from_device_context);
50 MS_EXCEPTION_IF_NULL(to_device_context);
51
52 if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
53 return false;
54 } else {
55 return true;
56 }
57 }
58
IsSingleOpActorSet(const ActorSet * actor_set)59 inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
60 MS_EXCEPTION_IF_NULL(actor_set);
61 return actor_set->kernel_actors_.size() == 1;
62 }
63
64 // Convert the actors vector by the actor set.
CollectActors(const ActorSet * actor_set)65 std::vector<ActorReference> CollectActors(const ActorSet *actor_set) {
66 MS_EXCEPTION_IF_NULL(actor_set);
67 std::vector<ActorReference> actors;
68
69 if (actor_set->data_prepare_actor_ != nullptr) {
70 (void)actors.emplace_back(static_cast<ActorReference>(actor_set->data_prepare_actor_));
71 }
72 for (auto &data_source_actor : actor_set->data_source_actors_) {
73 MS_EXCEPTION_IF_NULL(data_source_actor);
74 (void)actors.emplace_back(static_cast<ActorReference>(data_source_actor));
75 }
76 for (auto &kernel_actor : actor_set->kernel_actors_) {
77 MS_EXCEPTION_IF_NULL(kernel_actor);
78 (void)actors.emplace_back(static_cast<ActorReference>(kernel_actor));
79 }
80 for (auto &switch_actor : actor_set->switch_actors_) {
81 MS_EXCEPTION_IF_NULL(switch_actor);
82 (void)actors.emplace_back(static_cast<ActorReference>(switch_actor));
83 }
84 for (auto &gather_actor : actor_set->gather_actors_) {
85 MS_EXCEPTION_IF_NULL(gather_actor);
86 (void)actors.emplace_back(static_cast<ActorReference>(gather_actor));
87 }
88 for (auto ©_actor : actor_set->copy_actors_) {
89 MS_EXCEPTION_IF_NULL(copy_actor);
90 (void)actors.emplace_back(static_cast<ActorReference>(copy_actor));
91 }
92 if (actor_set->loop_count_actor_ != nullptr) {
93 (void)actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
94 }
95 if (actor_set->output_actor_ != nullptr) {
96 (void)actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
97 }
98
99 return actors;
100 }
101
ClearNodeInfo(const KernelGraphPtr & graph)102 void ClearNodeInfo(const KernelGraphPtr &graph) {
103 MS_EXCEPTION_IF_NULL(graph);
104
105 // Clear input parameter device tensor and device tensor store.
106 for (const auto &input_node : graph->input_nodes()) {
107 MS_EXCEPTION_IF_NULL(input_node);
108 if (!input_node->isa<Parameter>()) {
109 continue;
110 }
111 auto parameter = input_node->cast<ParameterPtr>();
112 MS_EXCEPTION_IF_NULL(parameter);
113 parameter->DecreaseUsedGraphCount();
114 // Only the parameter has no graph used, then clear the device tensor.
115 if (parameter->used_graph_count() != 0) {
116 continue;
117 }
118 auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
119 DeviceTensorStore::GetInstance().Remove(front_input_node.get());
120 size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
121 for (size_t index = 0; index < output_num; ++index) {
122 if (AnfAlgo::OutputAddrExist(input_node, index)) {
123 AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
124 }
125 }
126 }
127
128 // Clear input value node device tensor and device tensor store.
129 for (const auto &value_node : graph->graph_value_nodes()) {
130 auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
131 DeviceTensorStore::GetInstance().Remove(front_value_node.get());
132 if (AnfAlgo::OutputAddrExist(value_node, 0)) {
133 AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
134 }
135 }
136
137 // Clear cnode device tensor.
138 for (const auto &cnode : graph->execution_order()) {
139 size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
140 for (size_t index = 0; index < output_num; ++index) {
141 if (AnfAlgo::OutputAddrExist(cnode, index)) {
142 AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
143 }
144 }
145 }
146 }
147
148 #if !defined(_WIN32) && !defined(_WIN64)
IntHandler(int,siginfo_t *,void *)149 void IntHandler(int, siginfo_t *, void *) {
150 int this_pid = getpid();
151 MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
152 (void)kill(this_pid, SIGTERM);
153 }
154 #endif
155 } // namespace
156
Clear(const ActorInfo & actor_info,const std::vector<KernelGraphPtr> & graphs)157 void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept {
158 // Terminate the actors of actor info.
159 if (actors_.count(actor_info) > 0) {
160 auto actor_manager = ActorMgr::GetActorMgrRef();
161 if (actor_manager == nullptr) {
162 MS_LOG(ERROR) << "Actor manager is not exist.";
163 return;
164 }
165 auto actor_set = actors_[actor_info];
166 auto base_actors = CollectActors(actor_set.get());
167 for (auto &base_actor : base_actors) {
168 MS_EXCEPTION_IF_NULL(base_actor);
169 (void)actor_name_to_actor_.erase(base_actor->GetAID().Name());
170 actor_manager->Terminate(base_actor->GetAID());
171 }
172 }
173
174 // Clear device tensor and device tensor store.
175 for (auto &graph : graphs) {
176 ClearNodeInfo(graph);
177 }
178
179 // Clear global maps of actor info.
180 (void)actors_.erase(actor_info);
181 }
182
Clear()183 void GraphScheduler::Clear() {
184 // Terminate all actors.
185 auto actor_manager = ActorMgr::GetActorMgrRef();
186 MS_EXCEPTION_IF_NULL(actor_manager);
187 actor_manager->Finalize();
188
189 // Clear the member of DeviceTensorStore.
190 DeviceTensorStore::GetInstance().Clear();
191
192 // Clear global maps.
193 actors_.clear();
194 actor_name_to_actor_.clear();
195 }
196
197 using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, KernelActor *const, const KernelWithIndex &,
198 const KernelWithIndex &, const KernelGraphPtr &);
199 static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc = {};
200
Initialize()201 void GraphScheduler::Initialize() {
202 if (init_) {
203 return;
204 }
205 init_ = true;
206
207 (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
208 &GraphScheduler::LinkDataArrowForDeviceDSActor);
209 (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
210 &GraphScheduler::LinkDataArrowForHostDSActor);
211 (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
212 (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
213 &GraphScheduler::LinkDataArrowForDeviceTensorStore);
214 (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
215 &GraphScheduler::LinkDataArrowForInternalParameter);
216
217 // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
218 size_t actor_thread_num = 0;
219 size_t OMP_thread_num = 0;
220 size_t max_thread_num = 0;
221 ComputeThreadNums(&actor_thread_num, &OMP_thread_num, &max_thread_num);
222 auto actor_manager = ActorMgr::GetActorMgrRef();
223 MS_EXCEPTION_IF_NULL(actor_manager);
224 auto ret = actor_manager->Initialize(true, actor_thread_num, max_thread_num);
225 if (ret != MINDRT_OK) {
226 MS_LOG(EXCEPTION) << "Actor manager init failed.";
227 }
228 std::string OMP_env = std::to_string(OMP_thread_num);
229 (void)common::SetEnv("OMP_NUM_THREADS", OMP_env.c_str(), 0);
230 auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
231 MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
232 << ", the computed OMP thread number : " << OMP_thread_num
233 << ", the used OMP thread number : " << OMP_thread_num_used;
234
235 BuildAndScheduleGlobalActor();
236 }
237
BuildAndScheduleGlobalActor()238 void GraphScheduler::BuildAndScheduleGlobalActor() {
239 auto actor_manager = ActorMgr::GetActorMgrRef();
240 MS_EXCEPTION_IF_NULL(actor_manager);
241
242 // Create and schedule memory manager actor.
243 auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
244 MS_EXCEPTION_IF_NULL(memory_manager_actor);
245 memory_manager_aid_ = memory_manager_actor->GetAID();
246 auto base_actor = static_cast<ActorReference>(memory_manager_actor);
247 // Bind single thread to response to memory alloc and free quickly.
248 (void)actor_manager->Spawn(base_actor, false);
249
250 // Create and schedule recorder actor.
251 auto recorder_actor = std::make_shared<RecorderActor>();
252 MS_EXCEPTION_IF_NULL(recorder_actor);
253 recorder_aid_ = &(recorder_actor->GetAID());
254 auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
255 (void)actor_manager->Spawn(base_recorder_actor, true);
256
257 // Create and schedule debug actor.
258 #ifndef ENABLE_SECURITY
259 bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
260 #endif
261 #ifdef ENABLE_DEBUGGER
262 if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
263 debugger_actor_need = true;
264 }
265 #endif
266 #ifndef ENABLE_SECURITY
267 if (debugger_actor_need) {
268 auto debug_actor = std::make_shared<DebugActor>();
269 MS_EXCEPTION_IF_NULL(debug_actor);
270 debug_aid_ = &(debug_actor->GetAID());
271 auto base_debug_actor = static_cast<ActorReference>(debug_actor);
272 (void)actor_manager->Spawn(base_debug_actor, true);
273 }
274 #endif
275 }
276
Transform(const GraphCompilerInfo & graph_compiler_info)277 ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
278 MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
279 if (graph_compiler_info.graphs_.size() == 0) {
280 MS_LOG(EXCEPTION) << "The number of graphs is zero.";
281 }
282 if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
283 MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
284 }
285
286 PersistDeviceTensor(graph_compiler_info);
287 const auto &actor_set = Build(graph_compiler_info);
288 MS_EXCEPTION_IF_NULL(actor_set);
289 CacheGraphOutputToActor(graph_compiler_info);
290 Link(actor_set.get(), graph_compiler_info);
291 // The copy actors are built in the link, so need push into the actor set after link.
292 actor_set->copy_actors_ = copy_actors_;
293
294 (void)actors_.emplace(actor_set->name_, actor_set);
295
296 DumpActor(actor_set.get(), graph_compiler_info);
297 if (!CheckActorValid(actor_set.get(), graph_compiler_info.strategy_)) {
298 MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
299 }
300 MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
301
302 // Local maps and vectors clear.
303 graph_output_to_actor_.clear();
304 front_node_to_actor_.clear();
305 copy_actors_.clear();
306
307 return actor_set.get();
308 }
309
Schedule(const ActorSet * actor_set)310 void GraphScheduler::Schedule(const ActorSet *actor_set) {
311 MS_EXCEPTION_IF_NULL(actor_set);
312 auto actors = CollectActors(actor_set);
313 // Schedule actors.
314 auto actor_manager = ActorMgr::GetActorMgrRef();
315 MS_EXCEPTION_IF_NULL(actor_manager);
316 for (auto actor : actors) {
317 (void)actor_manager->Spawn(actor);
318 }
319 }
320
Run(const ActorSet * actor_set,const std::vector<std::vector<TensorPtr>> & input_tensors,const std::vector<TensorPtr> & input_tensors_with_value_node,GraphExecutionStrategy strategy)321 bool GraphScheduler::Run(const ActorSet *actor_set, const std::vector<std::vector<TensorPtr>> &input_tensors,
322 const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
323 MS_EXCEPTION_IF_NULL(actor_set);
324 MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
325 #if !defined(_WIN32) && !defined(_WIN64)
326 SignalGuard sg(IntHandler);
327 #endif
328
329 // Construct OpContext.
330 OpContext<DeviceTensor> op_context;
331 std::vector<Promise<int>> result(1);
332 op_context.sequential_num_ = RandInt::Instance().Get();
333 op_context.results_ = &result;
334
335 if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
336 actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context);
337 MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
338 actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
339 return true;
340 }
341
342 // Trigger data prepare actor running.
343 Async(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors, &op_context);
344
345 // Get the run result.
346 auto result_future = result[0].GetFuture();
347 result_future.Wait();
348 MsException::Instance().CheckException();
349 return result_future.IsOK();
350 }
351
Fetch(const ActorInfo & actor_info) const352 ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
353 auto iter = actors_.find(actor_info);
354 if (iter != actors_.end()) {
355 return iter->second.get();
356 } else {
357 MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
358 return nullptr;
359 }
360 }
361
Build(const GraphCompilerInfo & graph_compiler_info)362 ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
363 auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
364 MS_EXCEPTION_IF_NULL(actor_set);
365
366 auto host_queue = std::make_shared<HostTensorQueue>();
367 actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
368 actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
369 actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
370 actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
371 actor_set->data_prepare_actor_ =
372 BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
373 actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
374 actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);
375
376 return actor_set;
377 }
378
CacheGraphOutputToActor(const GraphCompilerInfo & graph_compiler_info)379 void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
380 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
381 return;
382 }
383
384 for (const auto &graph : graph_compiler_info.graphs_) {
385 MS_EXCEPTION_IF_NULL(graph);
386 auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
387 for (const auto &output_with_index : outputs) {
388 auto output_kernel = output_with_index.first;
389 MS_EXCEPTION_IF_NULL(output_kernel);
390 auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
391 if (origin_output_with_index.first == nullptr) {
392 MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
393 << " with index: " << output_with_index.second << " has no actor.";
394 continue;
395 }
396
397 auto actor_output_index = output_with_index.second;
398 OpActor<DeviceTensor> *actor = nullptr;
399 if (IsKernelActor(output_kernel, graph_compiler_info.strategy_)) {
400 actor = FetchActor(output_kernel->fullname_with_scope());
401 } else if (IsDeviceQueueDSActor(output_kernel, graph_compiler_info.strategy_)) {
402 std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
403 actor = FetchActor(actor_name);
404 } else if (IsHostQueueDSActor(output_kernel, graph, graph_compiler_info.origin_parameters_order_,
405 graph_compiler_info.strategy_)) {
406 actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor");
407 const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
408 MS_EXCEPTION_IF_NULL(host_ds_actor);
409 // Get the position of output kernel in the data source actor.
410 actor_output_index = host_ds_actor->FetchNodePosition(output_kernel);
411 } else if (IsPersistentDeviceTensor(output_kernel)) {
412 MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
413 << " is device tensor store.";
414 continue;
415 } else {
416 MS_LOG(INFO) << "Ignore the internal parameter node:" << output_kernel->DebugString();
417 continue;
418 }
419
420 MS_EXCEPTION_IF_NULL(actor);
421 MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
422 << " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
423 << " with index:" << actor_output_index
424 << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
425 << " with index: " << origin_output_with_index.second;
426 (void)graph_output_to_actor_.emplace(origin_output_with_index,
427 GraphOutputPair(dynamic_cast<AbstractActor *>(actor), actor_output_index));
428 }
429 }
430 }
431
Link(ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info)432 void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
433 MS_EXCEPTION_IF_NULL(actor_set);
434 std::vector<KernelActor *> auto_monad_actors;
435 std::vector<CNodePtr> communication_nodes;
436 const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
437 prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
438
439 // Foreach the execution order to link the actors.
440 for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
441 const auto &graph = graph_compiler_info.graphs_[index];
442 MS_EXCEPTION_IF_NULL(graph);
443 auto execution_order = graph->execution_order();
444 for (auto &kernel : execution_order) {
445 MS_EXCEPTION_IF_NULL(kernel);
446 if (AnfAlgo::IsCommunicationOp(kernel)) {
447 (void)communication_nodes.emplace_back(kernel);
448 }
449 if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
450 continue;
451 }
452 const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
453 MS_EXCEPTION_IF_NULL(kernel_actor);
454
455 for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
456 auto input_node = AnfAlgo::GetInputNode(kernel, i);
457 // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
458 if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
459 LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
460 }
461 if (HasAbstractMonad(input_node)) {
462 (void)auto_monad_actors.emplace_back(kernel_actor);
463 continue; // No data arrow for monad input.
464 }
465
466 KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
467 KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
468 // The gather of linking data arrows of kernel by the different from kernel type.
469 LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
470 }
471 }
472 // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
473 LinkControlArrowBySendRecvNodes(graph);
474 }
475
476 // Link the arrow in the control flow scene.
477 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
478 LinkArrowByControlNode(graph_compiler_info, actor_set);
479 }
480
481 LinkGlobalControlArrow(actor_set, communication_nodes, auto_monad_actors, graph_compiler_info);
482 LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
483 }
484
BuildDataSourceActor(const GraphCompilerInfo & graph_compiler_info,const HostTensorQueuePtr & host_queue)485 std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
486 const HostTensorQueuePtr &host_queue) {
487 std::vector<DataSourceActorPtr> data_source_actors;
488 HostQueueDSActorPtr host_queue_ds_actor = nullptr;
489 size_t data_node_position = 0;
490 std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;
491
492 for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
493 const auto &graph = graph_compiler_info.graphs_[i];
494 const auto &device_context = graph_compiler_info.device_contexts_[i];
495 MS_EXCEPTION_IF_NULL(graph);
496 // Build host queue data source actor.
497 const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
498
499 for (size_t j = 0; j < input_nodes.size(); j++) {
500 const auto &input_node = input_nodes[j];
501 MS_EXCEPTION_IF_NULL(input_node);
502
503 if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
504 graph_compiler_info.strategy_)) {
505 if (host_queue_ds_actor == nullptr) {
506 auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
507 MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
508 host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
509 nullptr, host_queue);
510 InsertActor(host_queue_ds_actor.get());
511 (void)data_source_actors.emplace_back(host_queue_ds_actor);
512 }
513
514 const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
515 // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
516 // is saved in the host queue data source actor.
517 if (front_node_position_temp_map.count(front_node) > 0) {
518 (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
519 front_node_position_temp_map[front_node]);
520 continue;
521 }
522 (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
523 (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
524 (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
525 (void)front_node_position_temp_map.emplace(front_node, data_node_position);
526 data_node_position++;
527 }
528 }
529
530 // Build device queue data source actor.
531 const auto &execution_order = graph->execution_order();
532 const auto &iter =
533 std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
534 return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
535 });
536 if (iter != execution_order.end()) {
537 auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
538 MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
539 auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
540 actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
541 MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
542 InsertActor(device_queue_ds_actor.get());
543 (void)data_source_actors.emplace_back(device_queue_ds_actor);
544 device_queue_ds_actor->data_kernel_ = *iter;
545 device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
546 }
547 }
548
549 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
550 const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
551
552 // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
553 // the corresponding backend parameter from the map, and insert it into the host data source actor
554 const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
555 for (const auto ¶meter : control_node_parameters) {
556 if (IsPersistentDeviceTensor(parameter)) {
557 continue;
558 }
559 auto backend_iter = front_to_backend_parameter.find(parameter);
560 if (backend_iter == front_to_backend_parameter.end()) {
561 MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
562 }
563
564 if (host_queue_ds_actor == nullptr) {
565 auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
566 MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
567 host_queue_ds_actor =
568 std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr, nullptr, host_queue);
569 InsertActor(host_queue_ds_actor.get());
570 (void)data_source_actors.emplace_back(host_queue_ds_actor);
571 }
572
573 const auto &backend_node = backend_iter->second.first;
574 auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
575 if (iter != host_queue_ds_actor->data_nodes_.end()) {
576 (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
577 iter - host_queue_ds_actor->data_nodes_.begin());
578 } else {
579 (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
580 (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
581 (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
582 }
583 }
584
585 return data_source_actors;
586 }
587
BuildKernelActor(const GraphCompilerInfo & graph_compiler_info)588 std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
589 std::vector<KernelActorPtr> kernel_actors;
590
591 for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
592 const auto &graph = graph_compiler_info.graphs_[i];
593 const auto &device_context = graph_compiler_info.device_contexts_[i];
594 MS_EXCEPTION_IF_NULL(graph);
595 auto execution_order = graph->execution_order();
596
597 // Single op graph in step mode, kernel actor executes synchronously.
598 bool is_single_op_graph = execution_order.size() == 1;
599 GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
600 if (strategy == GraphExecutionStrategy::kStep) {
601 strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
602 }
603
604 for (auto &kernel : execution_order) {
605 MS_EXCEPTION_IF_NULL(kernel);
606 if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
607 auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
608 memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
609 MS_EXCEPTION_IF_NULL(kernel_actor);
610 InsertActor(kernel_actor.get());
611 (void)kernel_actors.emplace_back(kernel_actor);
612 auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
613 if (front_node != nullptr) {
614 front_node_to_actor_[front_node] = kernel_actor;
615 }
616 }
617 }
618 }
619 return kernel_actors;
620 }
621
BuildLoopCountActor(const GraphCompilerInfo & graph_compiler_info)622 LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
623 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
624 return nullptr;
625 }
626
627 auto loop_count = ConfigManager::GetInstance().iter_num();
628 auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
629 auto loop_count_actor =
630 std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
631 MS_LOG(INFO) << "Create loop count actor: " << actor_name;
632 MS_EXCEPTION_IF_NULL(loop_count_actor);
633
634 InsertActor(loop_count_actor.get());
635 return loop_count_actor;
636 }
637
BuildOutputActor(const GraphCompilerInfo & graph_compiler_info)638 OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
639 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
640 return nullptr;
641 }
642
643 auto loop_count = ConfigManager::GetInstance().iter_num();
644 auto actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
645 bool need_loop_count = (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) ? true : false;
646
647 auto output_actor =
648 std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_, need_loop_count);
649 MS_LOG(INFO) << "Create output actor: " << actor_name;
650 MS_EXCEPTION_IF_NULL(output_actor);
651 InsertActor(output_actor.get());
652 return output_actor;
653 }
654
BuildDataPrepareActor(const GraphCompilerInfo & graph_compiler_info,const std::vector<DataSourceActorPtr> & data_source_actors,const HostTensorQueuePtr & host_queue)655 DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
656 const std::vector<DataSourceActorPtr> &data_source_actors,
657 const HostTensorQueuePtr &host_queue) {
658 HostQueueDSActorPtr host_queue_ds_actor = nullptr;
659 auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
660 return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
661 });
662 if (iter != data_source_actors.end()) {
663 host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
664 }
665
666 auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
667 auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
668 &graph_compiler_info, host_queue_ds_actor, host_queue);
669 MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
670 MS_EXCEPTION_IF_NULL(data_prepare_actor);
671
672 // Cache the nodes which need continuous memory.
673 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
674 for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
675 const auto &graph = graph_compiler_info.graphs_[index];
676 MS_EXCEPTION_IF_NULL(graph);
677 auto &execution_order = graph->execution_order();
678 for (auto &kernel : execution_order) {
679 if (!AnfAlgo::IsCommunicationOp(kernel)) {
680 continue;
681 }
682
683 auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
684 auto value = std::make_pair(false, false);
685 if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
686 value.first = true;
687 }
688 if (AnfAlgo::GetOutputTensorNum(kernel) > 1) {
689 value.second = true;
690 }
691 if ((value.first == true) || (value.second == true)) {
692 data_prepare_actor->continuous_memory_nodes_[key] = value;
693 }
694 }
695 }
696 }
697
698 InsertActor(data_prepare_actor.get());
699 return data_prepare_actor;
700 }
701
BuildNoInputKernelActor(const ActorSet * actor_set,GraphExecutionStrategy strategy)702 std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
703 GraphExecutionStrategy strategy) {
704 MS_EXCEPTION_IF_NULL(actor_set);
705 std::vector<KernelActorPtr> no_input_kernel_actors;
706
707 for (auto &kernel_actor : actor_set->kernel_actors_) {
708 MS_EXCEPTION_IF_NULL(kernel_actor);
709 // Framework will trigger kernel actor running in the step execution strategy.
710 if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
711 kernel_actor->input_controls_num_++;
712 continue;
713 }
714
715 if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
716 // Check whether the kernel actor belongs to the root graph.
717 // In general, all no input nodes belong to the root funcgraph, and the corresponding gather actor should be
718 // empty. In control flow, the control arrow of the no input node in the sub funcgraph should be sent by the
719 // gather actor and should not be placed in the no input list.
720 MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
721 const auto &graph = kernel_actor->kernel_->func_graph();
722 if (graph != nullptr) {
723 const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
724 MS_EXCEPTION_IF_NULL(kernel_graph);
725 const auto func_graph = kernel_graph->GetFuncGraph();
726 if (func_graph != nullptr && FetchActor(func_graph->ToString()) != nullptr) {
727 continue;
728 }
729 }
730
731 (void)no_input_kernel_actors.emplace_back(kernel_actor);
732 }
733 }
734 return no_input_kernel_actors;
735 }
736
BuildSwitchActor(const GraphCompilerInfo & graph_compiler_info)737 std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
738 std::vector<SwitchActorPtr> switch_actors;
739 std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
740 for (const auto &pair : front_node_to_actor_) {
741 front_to_backend_kernel[pair.first] = pair.second->kernel_;
742 }
743
744 // Build switch actor by switch node and switchlayer node.
745 for (const auto &control_node : graph_compiler_info.control_nodes_) {
746 if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
747 AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
748 const auto func_graph = control_node->func_graph();
749 const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
750 const auto &actor_name = control_node->DebugString();
751 auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
752 control_node->cast<CNodePtr>(), branch_id, false);
753 switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
754
755 // Fetch all the input nodes of switch actor.
756 switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
757 InsertActor(switch_actor.get());
758 (void)switch_actors.emplace_back(switch_actor);
759 }
760 }
761
762 // Build switch actor by return node.
763 const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
764 for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
765 const auto &return_node = func_graph_to_call_num.first->get_return();
766 MS_EXCEPTION_IF_NULL(return_node);
767 const auto &actor_name = return_node->DebugString();
768 auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
769 return_node->cast<CNodePtr>(), kInvalidBranchID, true);
770 switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
771
772 // Fetch all the input nodes of switch actor.
773 switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
774 InsertActor(switch_actor.get());
775 (void)switch_actors.emplace_back(switch_actor);
776 }
777
778 return switch_actors;
779 }
780
BuildGatherActor(const GraphCompilerInfo & graph_compiler_info)781 std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
782 std::vector<GatherActorPtr> gather_actors;
783
784 const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
785 const auto &loop_count_actor = FetchActor(loop_count_actor_name);
786 if (loop_count_actor == nullptr) {
787 return gather_actors;
788 }
789
790 const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
791 const auto &output_actor = FetchActor(output_actor_name);
792 MS_EXCEPTION_IF_NULL(output_actor);
793
794 const auto parser = graph_compiler_info.control_node_parser_;
795
796 bool is_main_return = true;
797 // Each funcgraph has a return node, get the funcgraph from the return node, and create a gather actor.
798 std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
799 for (const auto &pair : front_node_to_actor_) {
800 front_to_backend_kernel[pair.first] = pair.second->kernel_;
801 }
802
803 for (const auto &control_node : graph_compiler_info.control_nodes_) {
804 const auto &func_graph = control_node->func_graph();
805 const auto &cnode = control_node->cast<CNodePtr>();
806 const auto &inputs = cnode->inputs();
807 const auto &return_node = func_graph->get_return();
808
809 if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
810 // Root funcgraph does not need to create a gather actor.
811 if (is_main_return) {
812 is_main_return = false;
813 continue;
814 }
815
816 if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
817 continue;
818 }
819 auto actor_name = func_graph->ToString();
820 std::vector<KernelWithIndex> parameters;
821 for (const auto ¶meter : func_graph->get_inputs()) {
822 if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
823 continue;
824 }
825 (void)parameters.emplace_back(parameter, 0);
826 }
827
828 const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
829
830 const auto &output_switch_actor = FetchActor(return_node->DebugString());
831 MS_EXCEPTION_IF_NULL(output_switch_actor);
832 const auto &output_switch_aid = output_switch_actor->GetAID();
833
834 auto gather_actor =
835 std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
836 gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
837 InsertActor(gather_actor.get());
838 (void)gather_actors.emplace_back(gather_actor);
839 }
840 }
841
842 // Create gather actor for call node which input0 of call node is a funcgraph.
843 for (const auto &control_node : graph_compiler_info.control_nodes_) {
844 const auto &cnode = control_node->cast<CNodePtr>();
845 const auto &inputs = cnode->inputs();
846
847 if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
848 // Collect the parameters.
849 std::vector<KernelWithIndex> parameters;
850 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
851 if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
852 continue;
853 }
854 (void)parameters.emplace_back(inputs[i], 0);
855 }
856
857 auto func_graph = control_node->func_graph();
858 auto actor_name = control_node->DebugString();
859 const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
860 const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
861 const auto &to_actor = FetchActor(to_func_graph->ToString());
862 auto gather_actor =
863 std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
864 gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
865
866 InsertActor(gather_actor.get());
867 (void)gather_actors.emplace_back(gather_actor);
868 }
869 }
870
871 // Create gather actor for kernel graph which has a call input.
872 const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_;
873 for (const auto &graph_with_device_context : graph_with_device_contexts) {
874 const auto &graph = graph_with_device_context.first;
875 const auto ¶meters = FetchParameterbyKernelGraph(graph);
876
877 auto actor_name = graph->ToString();
878 auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID);
879 InsertActor(gather_actor.get());
880 (void)gather_actors.emplace_back(gather_actor);
881 }
882
883 return gather_actors;
884 }
885
LinkDataArrow(KernelActor * const to_actor,const GraphCompilerInfo & graph_compiler_info,const KernelGraphPtr & graph,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)886 void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
887 const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
888 const KernelWithIndex &to_kernel_with_input_idx) {
889 MS_EXCEPTION_IF_NULL(to_actor);
890 MS_EXCEPTION_IF_NULL(graph);
891
892 auto from_kernel = from_kernel_with_output_idx.first;
893 MS_EXCEPTION_IF_NULL(from_kernel);
894 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
895 if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
896 const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
897 const auto &real_front_node_with_index =
898 AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
899 if (HasAbstractRef(real_front_node_with_index.first)) {
900 (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
901 real_front_node_with_index.first);
902 return;
903 }
904 // When there is a call input in the kernel graph, all the inputs of the kernel graph needs to be sent by gather.
905 const auto actor_name = graph->ToString();
906 auto actor = FetchActor(actor_name);
907 MS_EXCEPTION_IF_NULL(actor);
908 LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), to_actor, real_front_node_with_index,
909 to_kernel_with_input_idx);
910 return;
911 }
912
913 auto front_node = GetFrontNodeByBackendNode(from_kernel);
914 if (front_node != nullptr && IsGatherActor(front_node, actor_name_to_actor_)) {
915 // Link the data arrows of gather actor.
916 auto func_graph = GetFuncgraphByBackendNode(from_kernel);
917 if (func_graph == nullptr) {
918 MS_LOG(EXCEPTION) << "Cannot find funcgraph of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
919 }
920 auto actor_name = func_graph->ToString();
921 const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name));
922 if (HasAbstractRef(from_kernel)) {
923 (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node);
924 return;
925 }
926 LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
927 return;
928 }
929
930 auto kernel_type = KernelTransformType::kUnknown;
931 std::string kernel_name = "";
932 FetchKernelTransformTypeAndName(from_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
933 auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
934 if (kKernelTypeToLinkFunc.count(kernel_type) > 0) {
935 (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
936 to_kernel_with_input_idx, graph);
937 }
938 }
939
LinkDataArrowForDeviceTensorStore(AbstractActor * const,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr & graph)940 void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, KernelActor *const to_actor,
941 const KernelWithIndex &from_kernel_with_output_idx,
942 const KernelWithIndex &to_kernel_with_input_idx,
943 const KernelGraphPtr &graph) {
944 MS_EXCEPTION_IF_NULL(to_actor);
945 MS_EXCEPTION_IF_NULL(graph);
946 auto from_kernel = from_kernel_with_output_idx.first;
947 MS_EXCEPTION_IF_NULL(from_kernel);
948
949 auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
950 (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
951 }
952
LinkDataArrowForInternalParameter(AbstractActor * const,KernelActor * to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr & graph)953 void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, KernelActor *to_actor,
954 const KernelWithIndex &from_kernel_with_output_idx,
955 const KernelWithIndex &to_kernel_with_input_idx,
956 const KernelGraphPtr &graph) {
957 MS_EXCEPTION_IF_NULL(to_actor);
958 MS_EXCEPTION_IF_NULL(graph);
959 auto internal_parameter = from_kernel_with_output_idx.first;
960 MS_EXCEPTION_IF_NULL(internal_parameter);
961
962 // Parameter ---> front node.
963 auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
964 auto front_output_node = front_output_with_index.first;
965 MS_EXCEPTION_IF_NULL(front_output_node);
966 if (IsSwitchActor(front_output_node)) {
967 auto switch_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->DebugString()));
968 MS_EXCEPTION_IF_NULL(switch_actor);
969 LinkDataArrowForSwitchActor(switch_actor, 0, to_actor, to_kernel_with_input_idx.second);
970 to_actor->input_datas_num_++;
971 return;
972 }
973
974 auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
975 AbstractActor *real_from_actor = nullptr;
976 KernelTransformType kernel_type;
977 if (IsPersistentDeviceTensor(front_output_node)) {
978 kernel_type = KernelTransformType::kDeviceTensorStore;
979 } else {
980 // front node ---> actor.
981 if (graph_output_to_actor_.count(front_output_with_index) == 0) {
982 MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node)
983 << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
984 }
985 auto actor_pair = graph_output_to_actor_[front_output_with_index];
986 MS_EXCEPTION_IF_NULL(actor_pair.first);
987 MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
988 << ", corresponding front node:" << front_output_node->fullname_with_scope()
989 << " with index:" << front_output_with_index.second
990 << ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
991 << ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
992 real_from_actor = actor_pair.first;
993 real_from_kernel_with_output_idx = KernelWithIndex(nullptr, actor_pair.second);
994 kernel_type = actor_pair.first->type_;
995 }
996
997 if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
998 MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
999 }
1000 (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
1001 to_kernel_with_input_idx, graph);
1002 }
1003
LinkDataArrowForBaseActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)1004 void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1005 const KernelWithIndex &from_kernel_with_output_idx,
1006 const KernelWithIndex &to_kernel_with_input_idx) {
1007 MS_EXCEPTION_IF_NULL(from_actor);
1008 MS_EXCEPTION_IF_NULL(to_actor);
1009
1010 auto from_kernel = from_kernel_with_output_idx.first;
1011 MS_EXCEPTION_IF_NULL(from_kernel);
1012 auto from_output_index = from_kernel_with_output_idx.second;
1013 auto to_input_index = to_kernel_with_input_idx.second;
1014
1015 // Get the position of from kernel in the data source actor.
1016 auto position = from_actor->FetchNodePosition(from_kernel);
1017 if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.empty())) {
1018 MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
1019 }
1020
1021 if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
1022 LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
1023 } else {
1024 auto to_aid = to_actor->GetAID();
1025 auto op_arrow = std::make_shared<DataArrow>(from_output_index, to_aid, to_input_index);
1026 // If the from actor has the multi nodes, then use the real output position.
1027 if (position != 0) {
1028 op_arrow->from_output_index_ = SizeToInt(position);
1029 }
1030
1031 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1032 to_actor->input_datas_num_++;
1033 (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
1034
1035 // Update the reference count of device tensor.
1036 UpdateRefCount(from_kernel, from_output_index);
1037 }
1038 }
1039
LinkDataArrowForDeviceDSActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1040 void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1041 const KernelWithIndex &from_kernel_with_output_idx,
1042 const KernelWithIndex &to_kernel_with_input_idx,
1043 const KernelGraphPtr &) {
1044 auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
1045 if (real_from_kernel_with_output_idx.first == nullptr) {
1046 auto device_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
1047 MS_EXCEPTION_IF_NULL(device_ds_actor);
1048 real_from_kernel_with_output_idx.first = device_ds_actor->data_kernel_;
1049 }
1050
1051 LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1052 }
1053
LinkDataArrowForHostDSActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1054 void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1055 const KernelWithIndex &from_kernel_with_output_idx,
1056 const KernelWithIndex &to_kernel_with_input_idx,
1057 const KernelGraphPtr &) {
1058 auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
1059 MS_EXCEPTION_IF_NULL(host_ds_actor);
1060
1061 KernelWithIndex real_from_kernel_with_output_idx;
1062 if (from_kernel_with_output_idx.first != nullptr) {
1063 // Get the position of from kernel in the data source actor.
1064 auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
1065 real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
1066 real_from_kernel_with_output_idx.second = from_kernel_with_output_idx.second;
1067 } else {
1068 real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(from_kernel_with_output_idx.second);
1069 real_from_kernel_with_output_idx.second = 0;
1070 }
1071
1072 LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1073 }
1074
LinkDataArrowForKernelActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx,const KernelGraphPtr &)1075 void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1076 const KernelWithIndex &from_kernel_with_output_idx,
1077 const KernelWithIndex &to_kernel_with_input_idx,
1078 const KernelGraphPtr &) {
1079 auto real_from_actor = from_actor;
1080 auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
1081 auto from_kernel = from_kernel_with_output_idx.first;
1082 if (from_kernel == nullptr) {
1083 auto kernel_actor = dynamic_cast<KernelActor *>(from_actor);
1084 MS_EXCEPTION_IF_NULL(kernel_actor);
1085 from_kernel = kernel_actor->kernel_;
1086 real_from_kernel_with_output_idx.first = kernel_actor->kernel_;
1087 }
1088
1089 // Update the from kernel info by the real node info.
1090 MS_EXCEPTION_IF_NULL(from_kernel);
1091 if (IsSkippedKernelActor(from_kernel)) {
1092 real_from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(from_kernel, 0);
1093 MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
1094 LinkControlArrowBySkippedNode(to_actor, from_kernel);
1095
1096 MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
1097 MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
1098 << to_kernel_with_input_idx.first->fullname_with_scope()
1099 << ", aggregate input index: " << to_kernel_with_input_idx.second
1100 << ", skip node: " << from_kernel->fullname_with_scope()
1101 << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
1102 real_from_actor =
1103 dynamic_cast<AbstractActor *>(FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope()));
1104 MS_EXCEPTION_IF_NULL(real_from_actor);
1105 }
1106
1107 LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
1108 }
1109
LinkDataArrowForCopyActor(AbstractActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & from_kernel_with_output_idx,const KernelWithIndex & to_kernel_with_input_idx)1110 void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
1111 const KernelWithIndex &from_kernel_with_output_idx,
1112 const KernelWithIndex &to_kernel_with_input_idx) {
1113 MS_EXCEPTION_IF_NULL(from_actor);
1114 MS_EXCEPTION_IF_NULL(to_actor);
1115 auto from_kernel = from_kernel_with_output_idx.first;
1116 MS_EXCEPTION_IF_NULL(from_kernel);
1117 auto from_output_index = from_kernel_with_output_idx.second;
1118 auto to_input_index = to_kernel_with_input_idx.second;
1119
1120 std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
1121 "_output_index:" + std::to_string(from_output_index);
1122 CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
1123 // Link between from actor and copy actor.
1124 if (copy_actor == nullptr) {
1125 // Create the copy actor.
1126 auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
1127 (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
1128 copy_actor = copy_actor_shared_ptr.get();
1129 MS_EXCEPTION_IF_NULL(copy_actor);
1130 InsertActor(copy_actor);
1131
1132 // Get the position of from kernel in the data source actor.
1133 auto position = from_actor->FetchNodePosition(from_kernel);
1134 if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.empty())) {
1135 MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
1136 }
1137 auto from_device_context = from_actor->device_contexts_[position];
1138 auto to_device_context = to_actor->device_contexts_[0];
1139 auto from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
1140 MS_EXCEPTION_IF_NULL(from_device_context);
1141 MS_EXCEPTION_IF_NULL(to_device_context);
1142 MS_EXCEPTION_IF_NULL(from_device_tensor);
1143 auto op_arrow_to_copy = std::make_shared<DataArrow>(from_output_index, copy_actor->GetAID(), 0);
1144 // If the from actor has the multi nodes, then use the real output position.
1145 if (position != 0) {
1146 op_arrow_to_copy->from_output_index_ = SizeToInt(position);
1147 }
1148
1149 // Link.
1150 (void)from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
1151 copy_actor->input_datas_num_++;
1152
1153 // Set the member of the copy actor.
1154 auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
1155 MS_EXCEPTION_IF_NULL(to_kernel_mod);
1156 auto input_sizes = to_kernel_mod->GetInputSizeList();
1157 if (to_input_index >= input_sizes.size()) {
1158 MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size();
1159 }
1160 copy_actor->output_ = to_device_context->CreateDeviceAddress(
1161 nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
1162 (void)copy_actor->device_contexts_.emplace_back(from_device_context);
1163 (void)copy_actor->device_contexts_.emplace_back(to_device_context);
1164
1165 // Update the reference count of device tensor.
1166 UpdateRefCount(from_device_tensor.get());
1167 }
1168
1169 // If the copy actor already exists, only need link between copy actor and to actor.
1170 auto op_arrow_from_copy = std::make_shared<DataArrow>(0, to_actor->GetAID(), to_input_index);
1171 (void)copy_actor->output_data_arrows_.emplace_back(op_arrow_from_copy);
1172 to_actor->input_datas_num_++;
1173 UpdateRefCount(copy_actor->output_.get());
1174 }
1175
LinkControlArrowByAutoMonad(KernelActor * to_actor,const AnfNodePtr & from_node,const KernelGraphPtr & graph)1176 void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
1177 const KernelGraphPtr &graph) {
1178 MS_EXCEPTION_IF_NULL(to_actor);
1179 MS_EXCEPTION_IF_NULL(from_node);
1180 MS_EXCEPTION_IF_NULL(graph);
1181 // Find the real input node, include the monad node and make tuple node.
1182 const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
1183 prim::kPrimMakeTuple};
1184 const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
1185 MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
1186 auto input_anfnode = input_kernel_with_output_idx.first;
1187 CNodePtr input_cnode = nullptr;
1188 if (input_anfnode->isa<CNode>()) {
1189 input_cnode = input_anfnode->cast<CNodePtr>();
1190 }
1191 // Make tuple node needs to be expanded.
1192 if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
1193 MS_EXCEPTION_IF_NULL(input_cnode);
1194 for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
1195 LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
1196 }
1197 return;
1198 }
1199
1200 const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
1201 prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
1202 // Get the real depend input by monad node which needs to link the control arrow.
1203 std::vector<AnfNodePtr> real_depend_inputs;
1204 if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
1205 AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
1206 MS_EXCEPTION_IF_NULL(input_cnode);
1207 real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
1208 // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
1209 // node in this scene.
1210 if (AnfAlgo::IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
1211 real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
1212 }
1213 } else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
1214 MS_EXCEPTION_IF_NULL(input_cnode);
1215 for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
1216 real_depend_inputs.push_back(input_cnode->input(i));
1217 }
1218 } else {
1219 real_depend_inputs.push_back(input_anfnode);
1220 }
1221
1222 for (const auto &real_depend_input : real_depend_inputs) {
1223 auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
1224 auto real_depend_kernel = real_depend_input_with_idx.first;
1225 // The monad node and make tuple node need recursion.
1226 if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
1227 LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
1228 continue;
1229 }
1230
1231 KernelActor *from_actor = nullptr;
1232 if (IsKernelActor(real_depend_kernel)) {
1233 from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
1234 } else if (IsInternalParameter(real_depend_kernel, graph)) {
1235 auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
1236 MS_EXCEPTION_IF_NULL(front_output_with_index.first);
1237 if (IsKernelActor(front_output_with_index.first)) {
1238 if (graph_output_to_actor_.count(front_output_with_index) == 0) {
1239 MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_with_index.first->DebugString();
1240 }
1241 from_actor = dynamic_cast<KernelActor *>(graph_output_to_actor_[front_output_with_index].first);
1242 }
1243 }
1244 if (from_actor == nullptr) {
1245 continue;
1246 }
1247 MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << from_actor->GetAID().Name()
1248 << ", to actor: " << to_actor->GetAID().Name();
1249 (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1250 to_actor->input_controls_num_++;
1251 }
1252 }
1253
LinkControlArrowBySkippedNode(KernelActor * to_actor,const AnfNodePtr & skipped_node)1254 void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
1255 MS_EXCEPTION_IF_NULL(to_actor);
1256 MS_EXCEPTION_IF_NULL(skipped_node);
1257 auto to_aid = to_actor->GetAID();
1258
1259 // Link the control arrow from all the inputs of skipped node to the user of skipped node.
1260 auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
1261 for (size_t i = 0; i < input_num; ++i) {
1262 auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
1263 MS_EXCEPTION_IF_NULL(kernel_with_index.first);
1264 auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
1265 MS_EXCEPTION_IF_NULL(from_actor);
1266 MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
1267 << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
1268 (void)from_actor->output_control_arrows_.emplace_back(to_aid);
1269 to_actor->input_controls_num_++;
1270 }
1271 }
1272
LinkControlArrowBySendRecvNodes(const KernelGraphPtr & graph)1273 void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
1274 MS_EXCEPTION_IF_NULL(graph);
1275 for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
1276 auto to_allreduce_node = from_iter.first;
1277 auto from_send_node = from_iter.second.first;
1278 auto from_recv_node = from_iter.second.second;
1279 MS_EXCEPTION_IF_NULL(to_allreduce_node);
1280 MS_EXCEPTION_IF_NULL(from_send_node);
1281 MS_EXCEPTION_IF_NULL(from_recv_node);
1282 MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
1283 auto to_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(to_allreduce_node->fullname_with_scope()));
1284 auto from_send_actor = dynamic_cast<KernelActor *>(FetchActor(from_send_node->fullname_with_scope()));
1285 auto from_recv_actor = dynamic_cast<KernelActor *>(FetchActor(from_recv_node->fullname_with_scope()));
1286 MS_EXCEPTION_IF_NULL(to_allreduce_actor);
1287 MS_EXCEPTION_IF_NULL(from_send_actor);
1288 MS_EXCEPTION_IF_NULL(from_recv_actor);
1289
1290 // inputs of to_allreduce_actor --> from_send_actor
1291 for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
1292 auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
1293 if (input_actor != nullptr) {
1294 (void)input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
1295 from_send_actor->input_controls_num_++;
1296 }
1297 }
1298
1299 // from_send_actor --> from_recv_actor
1300 (void)from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
1301 from_recv_actor->input_controls_num_++;
1302
1303 // from_recv_actor --> to_allreduce_actor
1304 (void)from_recv_actor->output_control_arrows_.emplace_back(to_allreduce_actor->GetAID());
1305 to_allreduce_actor->input_controls_num_++;
1306 }
1307
1308 for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
1309 auto from_allreduce_node = to_iter.first;
1310 auto to_send_node = to_iter.second.first;
1311 auto to_recv_node = to_iter.second.second;
1312 MS_EXCEPTION_IF_NULL(from_allreduce_node);
1313 MS_EXCEPTION_IF_NULL(to_send_node);
1314 MS_EXCEPTION_IF_NULL(to_recv_node);
1315 MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
1316 auto from_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(from_allreduce_node->fullname_with_scope()));
1317 auto to_send_actor = dynamic_cast<KernelActor *>(FetchActor(to_send_node->fullname_with_scope()));
1318 auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
1319 MS_EXCEPTION_IF_NULL(from_allreduce_actor);
1320 MS_EXCEPTION_IF_NULL(to_send_actor);
1321 MS_EXCEPTION_IF_NULL(to_recv_actor);
1322
1323 // from_allreduce_actor --> to_send_actor
1324 (void)from_allreduce_actor->output_control_arrows_.emplace_back(to_send_actor->GetAID());
1325 to_send_actor->input_controls_num_++;
1326
1327 // to_send_actor --> to_recv_actor
1328 (void)to_send_actor->output_control_arrows_.emplace_back(to_recv_actor->GetAID());
1329 to_recv_actor->input_controls_num_++;
1330
1331 // to_recv_actor --> outputs of from_allreduce_actor
1332 for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
1333 auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
1334 if (output_actor != nullptr) {
1335 (void)to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
1336 output_actor->input_controls_num_++;
1337 }
1338 }
1339
1340 // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
1341 // reused only when the recv node runs finished, which is expressed by the reference count increased.
1342 for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
1343 auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
1344 MS_EXCEPTION_IF_NULL(device_tensor);
1345 UpdateRefCount(device_tensor.get());
1346 (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
1347 }
1348 }
1349 }
1350
LinkGlobalControlArrow(ActorSet * const actor_set,const std::vector<CNodePtr> & communication_nodes,const std::vector<KernelActor * > & auto_monad_actors,const GraphCompilerInfo & graph_compiler_info)1351 void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
1352 const std::vector<KernelActor *> &auto_monad_actors,
1353 const GraphCompilerInfo &graph_compiler_info) {
1354 MS_EXCEPTION_IF_NULL(actor_set);
1355
1356 // Link the control arrows by the communication nodes to ensure communication nodes running order.
1357 LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);
1358
1359 // Auto monad actor may modify the device tensor store.
1360 LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
1361
1362 // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
1363 actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
1364
1365 // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
1366 if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
1367 LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set);
1368 }
1369
1370 LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
1371 graph_compiler_info.control_node_parser_);
1372 }
1373
LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> & communication_nodes,const GraphCompilerInfo & graph_compiler_info)1374 void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
1375 const GraphCompilerInfo &graph_compiler_info) {
1376 const size_t kCommunicationNodesMinNum = 2;
1377 if (communication_nodes.size() < kCommunicationNodesMinNum) {
1378 return;
1379 }
1380
1381 // Ensure communication node to execute orderly.
1382 for (size_t i = 1; i < communication_nodes.size(); ++i) {
1383 auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
1384 auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
1385 MS_EXCEPTION_IF_NULL(from_actor);
1386 MS_EXCEPTION_IF_NULL(to_actor);
1387 (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1388 to_actor->input_controls_num_++;
1389 }
1390
1391 // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
1392 // Using the multi stream to optimize the performance in the future.
1393 for (auto &graph : graph_compiler_info.graphs_) {
1394 MS_EXCEPTION_IF_NULL(graph);
1395 auto &execution_order = graph->execution_order();
1396 for (size_t i = 1; i < execution_order.size(); ++i) {
1397 auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
1398 auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
1399 if ((from_actor != nullptr) && (to_actor != nullptr)) {
1400 (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
1401 to_actor->input_controls_num_++;
1402 }
1403 }
1404 }
1405 }
1406
LinkControlArrowForDataPrepareActor(DataPrepareActor * data_prepare_actor,const ActorSet * actor_set)1407 void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
1408 const ActorSet *actor_set) {
1409 MS_EXCEPTION_IF_NULL(data_prepare_actor);
1410 MS_EXCEPTION_IF_NULL(actor_set);
1411
1412 // Data prepare actor --> data source actor.
1413 for (auto &data_source_actor : actor_set->data_source_actors_) {
1414 MS_EXCEPTION_IF_NULL(data_source_actor);
1415 (void)data_prepare_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
1416 }
1417
1418 // Data prepare actor --> no input kernel actor.
1419 for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
1420 MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
1421 (void)data_prepare_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
1422 no_input_kernel_actor->input_controls_num_++;
1423 }
1424
1425 // Data prepare actor --> loop count actor.
1426 if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
1427 (actor_set->loop_count_actor_ != nullptr)) {
1428 data_prepare_actor->loop_count_aid_ = &(actor_set->loop_count_actor_->GetAID());
1429 }
1430 }
1431
LinkControlArrowForLoopCountActor(LoopCountActor * loop_count_actor,const ActorSet * actor_set,const ControlNodeParserPtr & parser)1432 void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
1433 const ControlNodeParserPtr &parser) {
1434 MS_EXCEPTION_IF_NULL(actor_set);
1435 MS_EXCEPTION_IF_NULL(parser);
1436 // There is no loop count actor in step mode.
1437 if (loop_count_actor == nullptr) {
1438 return;
1439 }
1440
1441 // Collect the actors which have no output.
1442 std::vector<MemoryAwareActor *> no_output_actors;
1443 for (auto &kernel_actor : actor_set->kernel_actors_) {
1444 // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
1445 if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1446 parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
1447 MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
1448 MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
1449 (void)no_output_actors.emplace_back(kernel_actor.get());
1450 }
1451 }
1452 for (auto &data_actor : actor_set->data_source_actors_) {
1453 if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
1454 (void)no_output_actors.emplace_back(data_actor.get());
1455 }
1456 }
1457 for (auto ©_actor : copy_actors_) {
1458 if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
1459 (void)no_output_actors.emplace_back(copy_actor.get());
1460 }
1461 }
1462
1463 // No output actor --> loop count actor.
1464 for (auto &no_output_actor : no_output_actors) {
1465 (void)no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID());
1466 loop_count_actor->input_controls_num_++;
1467 }
1468
1469 // Loop count actor --> data prepare actor.
1470 MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
1471 loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
1472
1473 // Loop count actor --> output actor.
1474 MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
1475 loop_count_actor->output_aid_ = actor_set->output_actor_->GetAID();
1476 }
1477
LinkOutputResultArrowForOutputActor(OutputActor * to_actor,const GraphCompilerInfo & graph_compiler_info)1478 void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
1479 const GraphCompilerInfo &graph_compiler_info) {
1480 if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
1481 return;
1482 }
1483
1484 MS_EXCEPTION_IF_NULL(to_actor);
1485
1486 for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
1487 const auto &graph = graph_compiler_info.graphs_[i];
1488 MS_EXCEPTION_IF_NULL(graph);
1489 auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
1490 std::set<std::vector<size_t>> unique_output_positions;
1491 std::set<KernelWithIndex> unique_outputs;
1492 for (const auto &output : outputs) {
1493 if (IsInternalParameter(output.first, graph)) {
1494 MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
1495 continue;
1496 }
1497 (void)unique_outputs.insert(output);
1498 }
1499 for (const auto &output_with_index : unique_outputs) {
1500 MS_EXCEPTION_IF_NULL(output_with_index.first);
1501 auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
1502 const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
1503 if (iter == graph_compiler_info.origin_outputs_order_.end()) {
1504 continue;
1505 }
1506
1507 // Skip duplicate position.
1508 if (unique_output_positions.count(iter->second) > 0) {
1509 continue;
1510 }
1511 (void)unique_output_positions.insert(iter->second);
1512 for (auto &output_position : iter->second) {
1513 if (output_position >= to_actor->device_contexts_.size()) {
1514 MS_LOG(EXCEPTION) << "The output position is out of range.";
1515 }
1516 to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
1517 // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
1518 UpdateRefCount(output_with_index.first, output_with_index.second, true);
1519
1520 // The graph output is from device tensor store.
1521 if (IsPersistentDeviceTensor(output_with_index.first)) {
1522 (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
1523 continue;
1524 }
1525
1526 // The graph output is from kernel actor or data source actor.
1527 auto kernel_type = KernelTransformType::kUnknown;
1528 std::string kernel_name = "";
1529 FetchKernelTransformTypeAndName(output_with_index.first, graph, graph_compiler_info, &kernel_type,
1530 &kernel_name);
1531 auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
1532 if (from_actor == nullptr) {
1533 continue;
1534 }
1535 auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position);
1536 auto position = from_actor->FetchNodePosition(output_with_index.first);
1537 // If the from actor has the multi nodes, then use the real output position.
1538 if (position != 0) {
1539 op_arrow->from_output_index_ = SizeToInt(position);
1540 }
1541 (void)from_actor->output_result_arrows_.emplace_back(op_arrow);
1542 if (kernel_type == KernelTransformType::kHostDataSourceActor) {
1543 auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
1544 MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
1545 UpdateRefCount(host_queue_ds_actor->data_nodes_[position], output_with_index.second, true);
1546 }
1547 }
1548 }
1549 }
1550 }
1551
LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info,const ActorSet * actor_set)1552 void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
1553 const ActorSet *actor_set) {
1554 const auto &to_actor = actor_set->output_actor_;
1555 const auto &loop_count_actor = actor_set->loop_count_actor_;
1556 if (to_actor == nullptr || loop_count_actor == nullptr) {
1557 return;
1558 }
1559
1560 const auto &switch_actors = actor_set->switch_actors_;
1561 for (const auto &from_actor : switch_actors) {
1562 MS_EXCEPTION_IF_NULL(from_actor);
1563 auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
1564 const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
1565 if (iter == graph_compiler_info.origin_outputs_order_.end()) {
1566 continue;
1567 }
1568
1569 // If the switch actor is in the output list, the output of switch actor should be sent to the output actor.
1570 // And need to link a control arrow to the loop count actor.
1571 for (const auto pos : iter->second) {
1572 to_actor->device_contexts_[pos] = from_actor->device_context_;
1573 }
1574
1575 for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
1576 const auto &input_pos = from_actor->branch_inputs_pos_[i];
1577 if (input_pos.empty()) {
1578 MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
1579 }
1580
1581 for (const auto pos : iter->second) {
1582 auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
1583 (void)from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
1584 }
1585
1586 (void)from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
1587 }
1588 loop_count_actor->input_controls_num_++;
1589 }
1590 }
1591
LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor * > & auto_monad_actors)1592 void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
1593 const size_t kNeedUpdateDeviceTensorStoreNum = 2;
1594 for (auto &kernel_actor : auto_monad_actors) {
1595 MS_EXCEPTION_IF_NULL(kernel_actor);
1596 for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
1597 auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
1598 if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
1599 continue;
1600 }
1601
1602 // Create the copy actor.
1603 std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
1604 "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
1605 if (FetchActor(name) != nullptr) {
1606 continue;
1607 }
1608 auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
1609 MS_EXCEPTION_IF_NULL(copy_actor);
1610 (void)copy_actors_.emplace_back(copy_actor);
1611 InsertActor(copy_actor.get());
1612
1613 // Set the member of the copy actor.
1614 (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
1615 auto input_device_context = kernel_actor->device_contexts_[0];
1616 (void)copy_actor->device_contexts_.emplace_back(input_device_context);
1617 auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
1618 ? device_tensors[1]
1619 : device_tensors[0];
1620 MS_EXCEPTION_IF_NULL(another_device_tensor);
1621 auto another_device_type = another_device_tensor->DeviceType();
1622 const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1623 {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
1624 MS_EXCEPTION_IF_NULL(another_device_context);
1625 (void)copy_actor->device_contexts_.emplace_back(another_device_context);
1626
1627 MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
1628 << "has control arrows number:" << kernel_actor->output_control_arrows_.size();
1629 // Link from copy actor to kernel actor users.
1630 for (auto &output_contorl : kernel_actor->output_control_arrows_) {
1631 (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
1632 }
1633 // Move the control arrows from kernel actor to kernel actor users.
1634 kernel_actor->output_control_arrows_.clear();
1635
1636 // Link from kernel actor to copy actor.
1637 (void)kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
1638 copy_actor->input_controls_num_++;
1639 }
1640 }
1641 }
1642
PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> & control_nodes)1643 void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
1644 for (const auto &node : control_nodes) {
1645 CNodePtr cnode = node->cast<CNodePtr>();
1646 auto inputs = cnode->inputs();
1647 // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
1648 if (inputs[0]->isa<CNode>()) {
1649 auto actor = FetchActor(inputs[0]->DebugString());
1650 MS_EXCEPTION_IF_NULL(actor);
1651 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1652 MS_EXCEPTION_IF_NULL(switch_actor);
1653
1654 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1655 if (HasAbstractMonad(inputs[i])) {
1656 continue;
1657 }
1658 switch_actor->AddCommonInput(inputs[i]);
1659 }
1660 }
1661 }
1662 }
1663
LinkArrowByControlNode(const GraphCompilerInfo & graph_compiler_info,ActorSet * const actor_set)1664 void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
1665 PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
1666
1667 for (const auto &node : graph_compiler_info.control_nodes_) {
1668 CNodePtr cnode = node->cast<CNodePtr>();
1669 const auto &from_func_graph = node->func_graph();
1670 auto inputs = cnode->inputs();
1671 // Link data arrow for switch node.
1672 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
1673 AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
1674 auto actor = actor_name_to_actor_[node->DebugString()];
1675 MS_EXCEPTION_IF_NULL(actor);
1676 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1677 MS_EXCEPTION_IF_NULL(switch_actor);
1678 LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
1679 } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
1680 // Link the data arrow for the input of the call node.
1681 const auto &actor_name = node->DebugString();
1682 auto actor = FetchActor(actor_name);
1683 MS_EXCEPTION_IF_NULL(actor);
1684 auto gather_actor = dynamic_cast<GatherActor *>(actor);
1685 MS_EXCEPTION_IF_NULL(gather_actor);
1686
1687 const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
1688 MS_EXCEPTION_IF_NULL(func_graph);
1689 const auto &to_actor = FetchActor(func_graph->ToString());
1690 MS_EXCEPTION_IF_NULL(to_actor);
1691
1692 size_t persist_input_num = 0;
1693 for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
1694 MS_EXCEPTION_IF_NULL(actor);
1695 if (inputs[i]->isa<ValueNode>()) {
1696 const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
1697 if (!node_value->isa<tensor::Tensor>()) {
1698 persist_input_num++;
1699 continue;
1700 }
1701
1702 (void)gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
1703 inputs[i].get());
1704 gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
1705 graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
1706 } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
1707 AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) {
1708 persist_input_num++;
1709 continue;
1710 } else {
1711 const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0);
1712 LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor,
1713 i - kCallInputStartPos - persist_input_num);
1714 }
1715
1716 auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(),
1717 i - kCallInputStartPos - persist_input_num);
1718 (void)gather_actor->output_data_arrows_.emplace_back(op_arrow);
1719 }
1720 }
1721 }
1722
1723 // Link arrow for switch actor of subgraph output.
1724 for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
1725 const auto &func_graph = func_graph_with_call_num.first;
1726 MS_EXCEPTION_IF_NULL(func_graph);
1727 auto actor = FetchActor(func_graph->get_return()->DebugString());
1728 MS_EXCEPTION_IF_NULL(actor);
1729 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1730 MS_EXCEPTION_IF_NULL(switch_actor);
1731 LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
1732 }
1733
1734 // Link arrow for gather actor for call input kernel graph.
1735 for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) {
1736 const auto &kernel_graph = call_input_kernel_graph.first;
1737 MS_EXCEPTION_IF_NULL(kernel_graph);
1738 auto actor = FetchActor(kernel_graph->ToString());
1739 MS_EXCEPTION_IF_NULL(actor);
1740 auto gather_actor = dynamic_cast<GatherActor *>(actor);
1741 MS_EXCEPTION_IF_NULL(gather_actor);
1742
1743 for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
1744 const auto &input_with_index = gather_actor->data_nodes_[i];
1745 const auto &from_func_graph = kernel_graph->GetFuncGraph();
1746 LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
1747 }
1748 }
1749 LinkBranchArrowForSwitchActor(graph_compiler_info);
1750
1751 LinkBranchArrowForGatherActor(graph_compiler_info);
1752
1753 LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
1754 graph_compiler_info.control_node_parser_);
1755
1756 LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
1757 graph_compiler_info.origin_outputs_order_);
1758
1759 LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
1760 }
1761
LinkDataArrowForGatherActor(GatherActor * const from_actor,KernelActor * const to_actor,const KernelWithIndex & front_node_with_index,const KernelWithIndex & to_node_with_index)1762 void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
1763 const KernelWithIndex &front_node_with_index,
1764 const KernelWithIndex &to_node_with_index) {
1765 MS_EXCEPTION_IF_NULL(from_actor);
1766 MS_EXCEPTION_IF_NULL(to_actor);
1767 MS_EXCEPTION_IF_NULL(front_node_with_index.first);
1768
1769 auto position = from_actor->FetchDataNodePosition(front_node_with_index);
1770
1771 auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
1772 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1773 to_actor->input_datas_num_++;
1774 }
1775
LinkDataArrowByCallInput(const KernelWithIndex & call_node_with_index,const ControlNodeParserPtr & parser,const FuncGraphPtr & from_func_graph,OpActor<DeviceTensor> * const to_actor,const size_t to_index)1776 void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index,
1777 const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph,
1778 OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
1779 // Fetch all the funcgraph that call node would call.
1780 const auto cnode = call_node_with_index.first->cast<CNodePtr>();
1781 std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
1782
1783 // Collect the output of each funcgraph.
1784 for (const auto &func_graph : func_graphs) {
1785 const auto actor_name = func_graph->get_return()->DebugString();
1786 auto actor = FetchActor(actor_name);
1787 MS_EXCEPTION_IF_NULL(actor);
1788 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
1789 MS_EXCEPTION_IF_NULL(switch_actor);
1790 const size_t branch_index = switch_actor->branch_id_to_index_.size();
1791
1792 const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_;
1793 const auto &iter = func_graph_to_branch_id.find(from_func_graph);
1794
1795 int branch_id = kMainBranchID;
1796 if (iter != func_graph_to_branch_id.end()) {
1797 branch_id = iter->second;
1798 }
1799 if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) {
1800 LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index,
1801 switch_actor->branch_id_to_index_[branch_id]);
1802 continue;
1803 }
1804 LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index);
1805 switch_actor->branch_id_to_index_[branch_id] = branch_index;
1806 }
1807 }
1808
LinkDataArrowForSwitchActor(SwitchActor * from_actor,const size_t from_index,OpActor<DeviceTensor> * to_actor,const size_t to_index,const size_t branch_index)1809 void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index,
1810 OpActor<DeviceTensor> *to_actor, const size_t to_index,
1811 const size_t branch_index) {
1812 MS_EXCEPTION_IF_NULL(from_actor);
1813 MS_EXCEPTION_IF_NULL(to_actor);
1814 size_t start_branch = 0;
1815 size_t max_branch = from_actor->output_branch_arrows_.size();
1816 if (branch_index != SIZE_MAX) {
1817 start_branch = branch_index;
1818 max_branch = branch_index + 1;
1819 }
1820 for (size_t i = start_branch; i < max_branch; ++i) {
1821 if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
1822 MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
1823 << " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
1824 << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
1825 }
1826 auto op_arrow =
1827 std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
1828 (void)from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
1829 }
1830 }
1831
LinkDataArrowByControlNode(const GraphCompilerInfo & graph_compiler_info,const KernelWithIndex & input_with_index,const FuncGraphPtr & from_func_graph,OpActor<DeviceTensor> * const to_actor,const size_t to_index)1832 void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
1833 const KernelWithIndex &input_with_index,
1834 const FuncGraphPtr &from_func_graph,
1835 OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
1836 const auto ¶meters = graph_compiler_info.origin_parameters_order_;
1837 const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
1838 const auto &input_node = input_with_index.first;
1839
1840 if (IsCallNode(input_node)) {
1841 // The actor input is a call node.
1842 LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor,
1843 to_index);
1844 } else if (IsGatherActor(input_node, actor_name_to_actor_)) {
1845 // The actor input is a parameter in gather actor.
1846 auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
1847 auto position = from_actor->FetchDataNodePosition({input_node, 0});
1848 auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
1849 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1850 } else if (IsSwitchActor(input_node)) {
1851 const auto &actor_name = input_node->DebugString();
1852 auto actor = FetchActor(actor_name);
1853 MS_EXCEPTION_IF_NULL(actor);
1854 LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index);
1855 } else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
1856 // The actor input is a cnode.
1857 if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
1858 const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
1859 const auto &backend_node =
1860 graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
1861 if (backend_node.first == nullptr) {
1862 MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
1863 << " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
1864 }
1865 const auto &actor_name = backend_node.first->fullname_with_scope();
1866 const auto &actor = FetchActor(actor_name);
1867 MS_EXCEPTION_IF_NULL(actor);
1868 auto from_actor = dynamic_cast<KernelActor *>(actor);
1869 MS_EXCEPTION_IF_NULL(from_actor);
1870
1871 auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
1872 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1873 auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
1874 UpdateRefCount(device_tensor.get(), true);
1875 return;
1876 }
1877
1878 auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
1879 auto from_actor = front_node_to_actor_[input_node];
1880 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1881 auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false);
1882 UpdateRefCount(device_tensor.get(), true);
1883 } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
1884 // The actor input is a parameter in host data source actor.
1885 std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
1886
1887 auto actor = FetchActor(actor_name);
1888 MS_EXCEPTION_IF_NULL(actor);
1889 auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
1890 MS_EXCEPTION_IF_NULL(from_actor);
1891
1892 auto backend_iter = front_to_backend_parameter.find(input_node);
1893 if (backend_iter == front_to_backend_parameter.end()) {
1894 MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(input_node);
1895 }
1896
1897 const auto &backend_node = backend_iter->second.first;
1898 auto iter = from_actor->data_node_position_map_.find(input_node);
1899 if (iter == from_actor->data_node_position_map_.end()) {
1900 MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
1901 << AnfAlgo::GetNodeDebugString(backend_node)
1902 << " front node:" << AnfAlgo::GetNodeDebugString(input_node);
1903 }
1904
1905 auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
1906 (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
1907 auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
1908 UpdateRefCount(device_tensor.get(), true);
1909 } else {
1910 MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
1911 << " to actor:" << to_actor->GetAID();
1912 }
1913 }
1914
LinkDataArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info,SwitchActor * const actor)1915 void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
1916 SwitchActor *const actor) {
1917 // Link switch input.
1918 const auto &inputs = actor->input_nodes_;
1919 for (size_t i = 0; i < inputs.size(); ++i) {
1920 auto input = inputs[i];
1921 if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
1922 continue;
1923 }
1924
1925 const FuncGraphPtr from_func_graph = actor->node_->func_graph();
1926 LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i);
1927 }
1928
1929 // Link switch output.
1930 for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
1931 auto func_graph = actor->branch_func_graph_[i];
1932 if (func_graph == nullptr) {
1933 continue;
1934 }
1935
1936 auto gather_name = func_graph->ToString();
1937 if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
1938 MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
1939 << ",switch input size:" << actor->input_nodes_.size();
1940 }
1941 auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
1942 for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
1943 auto pos = actor->branch_inputs_pos_[i][j];
1944 auto to_actor_index = j;
1945 auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_actor_index);
1946 (void)actor->output_branch_arrows_[i].emplace_back(op_arrow);
1947 }
1948 }
1949 }
1950
LinkControlArrowForGatherActor(std::vector<KernelActorPtr> * const kernel_actors,const std::vector<KernelGraphPtr> & graphs,const ControlNodeParserPtr & parser)1951 void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
1952 const std::vector<KernelGraphPtr> &graphs,
1953 const ControlNodeParserPtr &parser) {
1954 // Link control arrow to kernel actor.
1955 for (size_t i = 0; i < graphs.size(); ++i) {
1956 const auto &kernel_graph = graphs[i];
1957 MS_EXCEPTION_IF_NULL(kernel_graph);
1958 const auto &func_graph = kernel_graph->GetFuncGraph();
1959 if (func_graph == nullptr) {
1960 continue;
1961 }
1962 const auto &actor = FetchActor(func_graph->ToString());
1963 if (actor == nullptr) {
1964 continue;
1965 }
1966 const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
1967 MS_EXCEPTION_IF_NULL(gather_actor);
1968
1969 // When gather actor is not empty, it means the control arrow of no input kernel actor needs to be sent by gather.
1970 for (const auto &kernel : kernel_graph->execution_order()) {
1971 if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
1972 const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
1973 MS_EXCEPTION_IF_NULL(kernel_actor);
1974
1975 if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
1976 (void)gather_actor->output_control_arrows_.emplace_back(kernel_actor->GetAID());
1977 kernel_actor->input_controls_num_ = 1;
1978 }
1979 }
1980 }
1981 }
1982
1983 for (auto &kernel_actor : *kernel_actors) {
1984 MS_EXCEPTION_IF_NULL(kernel_actor);
1985
1986 if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
1987 !parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
1988 // Check whether the kernel actor belongs to the root graph.
1989 // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output
1990 // should be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be
1991 // sent to the output switch actor.
1992 const auto &graph = kernel_actor->kernel_->func_graph();
1993 OpActor<DeviceTensor> *actor = nullptr;
1994
1995 if (graph != nullptr) {
1996 const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
1997 const auto func_graph = kernel_graph->GetFuncGraph();
1998 if (func_graph != nullptr) {
1999 actor = FetchActor(func_graph->get_return()->DebugString());
2000 if (actor != nullptr) {
2001 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2002 MS_EXCEPTION_IF_NULL(switch_actor);
2003
2004 (void)kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
2005 switch_actor->input_controls_num_++;
2006 }
2007 }
2008 }
2009 }
2010 }
2011
2012 // Link input auto monad control arrow from kernel actor to gather actor.
2013 const auto &monad_nodes = parser->kernel_to_call_nodes_;
2014 for (const auto node_pair : monad_nodes) {
2015 const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
2016 const auto &gather_actor_name = node_pair.second->DebugString();
2017 auto kernel_op_actor = FetchActor(kernel_actor_name);
2018 auto gather_op_actor = FetchActor(gather_actor_name);
2019 if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
2020 continue;
2021 }
2022 auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
2023 auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
2024 (void)kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
2025 gather_actor->input_controls_num_++;
2026 }
2027 }
2028
LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> * const switch_actors,LoopCountActor * const to_actor,const KernelMapPosition & origin_outputs_order)2029 void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors,
2030 LoopCountActor *const to_actor,
2031 const KernelMapPosition &origin_outputs_order) {
2032 if (to_actor == nullptr || (*switch_actors).empty()) {
2033 return;
2034 }
2035
2036 // If there is no output from the switch actor branch, it means that the subgraph has no input,
2037 // and need to connect a control arrow to the corresponding gather actor.
2038 for (auto &switch_actor : (*switch_actors)) {
2039 if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
2040 const auto &func_graph = switch_actor->node_->func_graph();
2041 if (func_graph->output()->isa<ValueNode>()) {
2042 const auto &actor_name = func_graph->ToString();
2043 auto actor = FetchActor(actor_name);
2044 MS_EXCEPTION_IF_NULL(actor);
2045 auto gather_actor = dynamic_cast<GatherActor *>(actor);
2046 MS_EXCEPTION_IF_NULL(gather_actor);
2047 (void)gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
2048 switch_actor->input_controls_num_++;
2049 }
2050 }
2051
2052 for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
2053 const auto &arrows = switch_actor->output_branch_arrows_[i];
2054 if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
2055 const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
2056 const auto &actor = FetchActor(actor_name);
2057 if (actor != nullptr) {
2058 const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
2059 MS_EXCEPTION_IF_NULL(gather_actor);
2060 (void)switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
2061 gather_actor->input_controls_num_++;
2062 }
2063 }
2064 }
2065 }
2066
2067 // Collect all the call node in outputs.
2068 std::set<AnfNodePtr> call_nodes;
2069 for (const auto &output : origin_outputs_order) {
2070 if (IsCallNode(output.first.first)) {
2071 (void)call_nodes.insert(output.first.first);
2072 }
2073 }
2074 to_actor->input_controls_num_ += call_nodes.size();
2075
2076 // Link the output switch actor of the subgraph to the output actor.
2077 for (const auto &call_node : call_nodes) {
2078 const auto &func_graphs = FetchFuncGraphbyCallNode(call_node);
2079 for (const auto func_graph : func_graphs) {
2080 MS_EXCEPTION_IF_NULL(func_graph);
2081 const auto &actor_name = func_graph->get_return()->DebugString();
2082 auto actor = FetchActor(actor_name);
2083 MS_EXCEPTION_IF_NULL(actor);
2084 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2085 MS_EXCEPTION_IF_NULL(switch_actor);
2086
2087 size_t branch_index = switch_actor->branch_id_to_index_.size();
2088 if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
2089 branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
2090 } else {
2091 switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
2092 }
2093
2094 (void)switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
2095 }
2096 }
2097 }
2098
LinkBranchArrowForSwitchActor(const GraphCompilerInfo & graph_compiler_info)2099 void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
2100 for (const auto &control_node : graph_compiler_info.control_nodes_) {
2101 if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
2102 AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
2103 const auto &actor_name = control_node->DebugString();
2104 auto actor = FetchActor(actor_name);
2105 MS_EXCEPTION_IF_NULL(actor);
2106 auto switch_actor = dynamic_cast<SwitchActor *>(actor);
2107 MS_EXCEPTION_IF_NULL(switch_actor);
2108
2109 for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
2110 const auto &func_graph = switch_actor->branch_func_graph_[i];
2111 if (func_graph == nullptr) {
2112 continue;
2113 }
2114
2115 const auto &gather_actor = FetchActor(func_graph->ToString());
2116 MS_EXCEPTION_IF_NULL(gather_actor);
2117 (void)switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
2118 }
2119 }
2120 }
2121 }
2122
LinkBranchArrowForGatherActor(const GraphCompilerInfo & graph_compiler_info)2123 void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info) {
2124 if (graph_compiler_info.control_nodes_.empty()) {
2125 return;
2126 }
2127
2128 // Link branch arrow from gather actor to gather actor.
2129 for (const auto &control_node : graph_compiler_info.control_nodes_) {
2130 const auto &cnode = control_node->cast<CNodePtr>();
2131 const auto &inputs = cnode->inputs();
2132 if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
2133 const auto &actor_name = control_node->DebugString();
2134 auto actor = FetchActor(actor_name);
2135 MS_EXCEPTION_IF_NULL(actor);
2136 auto gather_actor = dynamic_cast<GatherActor *>(actor);
2137 MS_EXCEPTION_IF_NULL(gather_actor);
2138 (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_);
2139 }
2140 }
2141
2142 // Link branch arrow from gather actor to switch actor.
2143 for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
2144 const auto &actor_name = func_graph_with_call_num.first->ToString();
2145 auto actor = FetchActor(actor_name);
2146 MS_EXCEPTION_IF_NULL(actor);
2147 auto gather_actor = dynamic_cast<GatherActor *>(actor);
2148 MS_EXCEPTION_IF_NULL(gather_actor);
2149 (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
2150 }
2151 }
2152
CheckActorValid(const ActorSet * actor_set,GraphExecutionStrategy strategy) const2153 bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
2154 MS_EXCEPTION_IF_NULL(actor_set);
2155 // Check the data source actors.
2156 for (const auto &data_source_actor : actor_set->data_source_actors_) {
2157 MS_EXCEPTION_IF_NULL(data_source_actor);
2158 if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
2159 data_source_actor->output_control_arrows_.size() ==
2160 0) {
2161 MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
2162 return false;
2163 }
2164 }
2165
2166 if (strategy == GraphExecutionStrategy::kStep) {
2167 return true;
2168 }
2169
2170 // Check the kernel actors.
2171 for (const auto &kernel_actor : actor_set->kernel_actors_) {
2172 MS_EXCEPTION_IF_NULL(kernel_actor);
2173 if (kernel_actor->output_data_arrows_.size() + kernel_actor->output_control_arrows_.size() == 0) {
2174 MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user.";
2175 return false;
2176 }
2177
2178 auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
2179 auto input_data_num = kernel_actor->input_datas_num_;
2180 auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size();
2181 if (input_data_num + device_tensor_store_num != input_num) {
2182 MS_LOG(ERROR) << "The input building of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_)
2183 << " is wrong, input data num: " << input_data_num
2184 << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num;
2185 return false;
2186 }
2187 }
2188
2189 // Check the copy actors.
2190 for (const auto ©_actor : actor_set->copy_actors_) {
2191 MS_EXCEPTION_IF_NULL(copy_actor);
2192 if (copy_actor->output_data_arrows_.size() + copy_actor->output_control_arrows_.size() == 0) {
2193 MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
2194 return false;
2195 }
2196
2197 const size_t kCopyActorInputDataNum = 1;
2198 auto input_data_num = copy_actor->input_datas_num_;
2199 size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
2200 if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
2201 MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
2202 << " is wrong, input data num: " << input_data_num
2203 << ", device tensor store num: " << device_tensor_store_num
2204 << ", total input num: " << kCopyActorInputDataNum;
2205 return false;
2206 }
2207 }
2208
2209 // Check the loop count actor.
2210 const auto &loop_count_actor = actor_set->loop_count_actor_;
2211 if ((loop_count_actor != nullptr) &&
2212 (actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) {
2213 if (loop_count_actor->input_controls_num_ == 0) {
2214 MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
2215 return false;
2216 }
2217 }
2218
2219 return true;
2220 }
2221
PersistDeviceTensor(const GraphCompilerInfo & graph_compiler_info)2222 void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
2223 for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
2224 const auto &graph = graph_compiler_info.graphs_[i];
2225 const auto &device_context = graph_compiler_info.device_contexts_[i];
2226 MS_EXCEPTION_IF_NULL(graph);
2227 MS_EXCEPTION_IF_NULL(device_context);
2228
2229 for (auto &value_node : graph->graph_value_nodes()) {
2230 MS_EXCEPTION_IF_NULL(value_node);
2231 if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
2232 MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
2233 continue;
2234 }
2235 auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
2236 const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
2237 DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
2238 UpdateRefCount(device_tensor.get(), true);
2239 }
2240
2241 for (auto &input_node : graph->input_nodes()) {
2242 MS_EXCEPTION_IF_NULL(input_node);
2243 AnfNodePtr sub_front_node = nullptr;
2244 if (IsInternalParameter(input_node, graph)) {
2245 auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
2246 sub_front_node = front_output_with_index.first;
2247 } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
2248 sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
2249 }
2250 if (sub_front_node == nullptr) {
2251 continue;
2252 }
2253
2254 // The sub front nodes share the device tensor store with the root front node.
2255 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
2256 auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
2257 MS_EXCEPTION_IF_NULL(front_node);
2258 MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
2259 << ", root front node:" << front_node->DebugString();
2260
2261 auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
2262 MS_EXCEPTION_IF_NULL(device_tensor);
2263 if (IsPersistentDeviceTensor(input_node)) {
2264 DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
2265 UpdateRefCount(device_tensor.get(), true);
2266 }
2267
2268 // Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
2269 if (!IsPersistentDeviceTensor(front_node)) {
2270 continue;
2271 }
2272 // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
2273 if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
2274 MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
2275 << ", type:" << device_context->GetDeviceAddressType();
2276 auto other_type_device_tensor = device_context->CreateDeviceAddress(
2277 nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
2278 DeviceTensorStore::GetInstance().Insert(front_node.get(), other_type_device_tensor);
2279 UpdateRefCount(other_type_device_tensor.get(), true);
2280 }
2281 }
2282 }
2283
2284 // In control flow, there may be some value nodes that is not in the kernel graph and needs to be placed
2285 // in the tensor store separately.
2286 for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
2287 MS_EXCEPTION_IF_NULL(value_node.first);
2288 auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
2289 DeviceTensorStore::GetInstance().Insert(value_node.first.get(), device_tensor);
2290 UpdateRefCount(device_tensor.get(), true);
2291 }
2292 }
2293
FetchKernelTransformTypeAndName(const AnfNodePtr & node,const KernelGraphPtr & graph,const GraphCompilerInfo & graph_compiler_info,KernelTransformType * const kernel_type,std::string * const kernel_name)2294 void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
2295 const GraphCompilerInfo &graph_compiler_info,
2296 KernelTransformType *const kernel_type,
2297 std::string *const kernel_name) {
2298 MS_EXCEPTION_IF_NULL(node);
2299 MS_EXCEPTION_IF_NULL(graph);
2300 MS_EXCEPTION_IF_NULL(kernel_type);
2301 MS_EXCEPTION_IF_NULL(kernel_name);
2302
2303 if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) {
2304 *kernel_type = KernelTransformType::kDeviceDataSourceActor;
2305 *kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
2306 } else if (IsHostQueueDSActor(node, graph, graph_compiler_info.origin_parameters_order_,
2307 graph_compiler_info.strategy_)) {
2308 *kernel_type = KernelTransformType::kHostDataSourceActor;
2309 *kernel_name = graph_compiler_info.name_ + "_HostDSActor";
2310 } else if (IsKernelActor(node, graph_compiler_info.strategy_)) {
2311 *kernel_type = KernelTransformType::kKernelActor;
2312 *kernel_name = node->fullname_with_scope();
2313 } else if (IsInternalParameter(node, graph)) {
2314 *kernel_type = KernelTransformType::kInternalParameter;
2315 *kernel_name = "";
2316 } else if (IsPersistentDeviceTensor(node)) {
2317 *kernel_type = KernelTransformType::kDeviceTensorStore;
2318 *kernel_name = "";
2319 } else {
2320 // May exist the from kernel that no need link in the pynative mode.
2321 MS_LOG(DEBUG) << "Invalid from kernel: " << node->fullname_with_scope();
2322 *kernel_type = KernelTransformType::kUnknown;
2323 *kernel_name = "";
2324 }
2325 }
2326
InsertActor(OpActor<DeviceTensor> * actor)2327 void GraphScheduler::InsertActor(OpActor<DeviceTensor> *actor) {
2328 MS_EXCEPTION_IF_NULL(actor);
2329 if (actor_name_to_actor_.count(actor->GetAID().Name()) > 0) {
2330 MS_LOG(EXCEPTION) << "The actor already exists: " << actor->GetAID().Name();
2331 }
2332 actor_name_to_actor_[actor->GetAID().Name()] = actor;
2333 }
2334
FetchActor(const std::string & actor_name) const2335 OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string &actor_name) const {
2336 const auto &iter = actor_name_to_actor_.find(actor_name);
2337 if (iter == actor_name_to_actor_.end()) {
2338 return nullptr;
2339 }
2340 return iter->second;
2341 }
2342
DumpActor(const ActorSet * actor_set,const GraphCompilerInfo & graph_compiler_info) const2343 void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
2344 MS_EXCEPTION_IF_NULL(actor_set);
2345 const auto &context_ptr = MsContext::GetInstance();
2346 MS_EXCEPTION_IF_NULL(context_ptr);
2347 auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
2348 if (!save_graphs) {
2349 return;
2350 }
2351
2352 std::string filename = GetSaveGraphsPathName("actor_set_" + actor_set->name_ + ".ir");
2353 std::ofstream ofs(filename);
2354 if (!ofs.is_open()) {
2355 MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
2356 return;
2357 }
2358
2359 ofs << "[Device tensor stores]\n";
2360 DumpDeviceTensorStore(graph_compiler_info, ofs);
2361
2362 const auto &data_prepare_actor = actor_set->data_prepare_actor_;
2363 ofs << "\n\n[Data prepare actor:" << (data_prepare_actor != nullptr ? 1 : 0) << "]\n";
2364 if (data_prepare_actor != nullptr) {
2365 DumpDataPrepareActor(data_prepare_actor.get(), ofs);
2366 }
2367
2368 ofs << "\n\n[Data source actors:" << actor_set->data_source_actors_.size() << "]\n";
2369 for (const auto &data_source_actor : actor_set->data_source_actors_) {
2370 DumpDSActor(data_source_actor.get(), ofs);
2371 }
2372
2373 ofs << "\n\n[Kernel actors:" << actor_set->kernel_actors_.size() << "]\n";
2374 for (const auto &kernel_actor : actor_set->kernel_actors_) {
2375 DumpKernelActor(kernel_actor.get(), ofs);
2376 }
2377
2378 ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
2379 for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
2380 DumpKernelActor(no_input_kernel_actor.get(), ofs);
2381 }
2382
2383 ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
2384 for (const auto ©_actor : actor_set->copy_actors_) {
2385 DumpCopyActor(copy_actor.get(), ofs);
2386 }
2387
2388 ofs << "\n\n[Gather actors:" << actor_set->gather_actors_.size() << "]\n";
2389 for (const auto &gather_actor : actor_set->gather_actors_) {
2390 DumpGatherActor(gather_actor.get(), ofs);
2391 }
2392
2393 ofs << "\n\n[Switch actors:" << actor_set->switch_actors_.size() << "]\n";
2394 for (const auto &switch_actor : actor_set->switch_actors_) {
2395 DumpSwitchActor(switch_actor.get(), ofs);
2396 }
2397
2398 const auto &loop_count_actor = actor_set->loop_count_actor_;
2399 ofs << "\n\n[Loop count actor:" << (loop_count_actor != nullptr ? 1 : 0) << "]\n";
2400 if (loop_count_actor != nullptr) {
2401 DumpLoopCountActor(loop_count_actor.get(), ofs);
2402 }
2403
2404 const auto &output_actor = actor_set->output_actor_;
2405 ofs << "\n\n[Output actor:" << (output_actor != nullptr ? 1 : 0) << "]\n";
2406 if (output_actor != nullptr) {
2407 DumpOutputActor(output_actor.get(), ofs);
2408 }
2409 }
2410
DumpAbstractActor(const AbstractActor * actor,std::ofstream & ofs) const2411 void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
2412 MS_EXCEPTION_IF_NULL(actor);
2413 ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
2414 << "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
2415 << "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
2416 << "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
2417 ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
2418 << "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
2419 << "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";
2420
2421 if (actor->device_contexts_.size() > 0) {
2422 ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
2423 for (const auto &device_context : actor->device_contexts_) {
2424 if (device_context == nullptr) {
2425 ofs << "\t\t\tdevice_context:" << device_context << "\n";
2426 continue;
2427 }
2428 ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
2429 }
2430 }
2431
2432 if (actor->device_tensor_store_keys_.size() > 0) {
2433 ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
2434 for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
2435 MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
2436 ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
2437 << "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
2438 }
2439 }
2440
2441 if (actor->input_data_arrow_aids_.size() > 0) {
2442 ofs << "\t\tinput_data_arrow_actors:" << actor->input_data_arrow_aids_.size() << "\n ";
2443 for (const auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
2444 ofs << "\t\t\tfrom_actor_name:" << input_data_arrow_aid.Name() << "\n";
2445 }
2446 }
2447
2448 if (actor->input_control_arrow_aids_.size() > 0) {
2449 ofs << "\t\tinput_control_arrow_actors:" << actor->input_control_arrow_aids_.size() << "\n ";
2450 for (const auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
2451 ofs << "\t\t\tfrom_actor_name:" << input_control_arrow_aid.Name() << "\n";
2452 }
2453 }
2454
2455 const auto &output_data_arrows = actor->output_data_arrows();
2456 if (output_data_arrows.size() > 0) {
2457 ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
2458 for (const auto &data_arrow : output_data_arrows) {
2459 MS_EXCEPTION_IF_NULL(data_arrow);
2460 ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
2461 << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
2462 << "\n";
2463 }
2464 }
2465
2466 const auto &output_control_arrows = actor->output_control_arrows();
2467 if (output_control_arrows.size() > 0) {
2468 ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
2469 for (const auto &aid : output_control_arrows) {
2470 ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2471 }
2472 }
2473
2474 if (actor->output_result_arrows_.size() > 0) {
2475 ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
2476 for (const auto &result_arrow : actor->output_result_arrows_) {
2477 MS_EXCEPTION_IF_NULL(result_arrow);
2478 ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
2479 << "\tto_actor_name:" << result_arrow->to_op_id_.Name()
2480 << "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
2481 }
2482 }
2483 }
2484
DumpDataPrepareActor(const DataPrepareActor * actor,std::ofstream & ofs) const2485 void GraphScheduler::DumpDataPrepareActor(const DataPrepareActor *actor, std::ofstream &ofs) const {
2486 MS_EXCEPTION_IF_NULL(actor);
2487 ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2488 DumpAbstractActor(actor, ofs);
2489
2490 ofs << "\t\toutput_control_arrows:" << actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() << "\n ";
2491 for (const auto &aid : actor->data_source_aids_) {
2492 ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2493 }
2494 for (const auto &aid : actor->no_input_kernel_aids_) {
2495 ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
2496 }
2497
2498 ofs << "\t\tcontinuous_memory_nodes:" << actor->continuous_memory_nodes_.size() << "\n ";
2499 for (const auto &iter : actor->continuous_memory_nodes_) {
2500 MS_EXCEPTION_IF_NULL(iter.first.first);
2501 MS_EXCEPTION_IF_NULL(iter.first.second);
2502 ofs << "\t\t\tnode_name:" << iter.first.first->fullname_with_scope()
2503 << "\tdevice_context:" << iter.first.second->device_context_key().ToString()
2504 << "\tis_input_need:" << iter.second.first << "\tis_output_need:" << iter.second.second << "\n";
2505 }
2506 }
2507
DumpDSActor(const DataSourceActor * actor,std::ofstream & ofs) const2508 void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
2509 MS_EXCEPTION_IF_NULL(actor);
2510 const auto &actor_name = actor->GetAID().Name();
2511 ofs << "\tactor_name:" << actor_name << "\n";
2512
2513 if (actor->type_ == KernelTransformType::kDeviceDataSourceActor) {
2514 // Dump the member info of device queue data source actor.
2515 const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
2516 MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
2517 const auto &data_kernel = device_queue_ds_actor->data_kernel_;
2518 MS_EXCEPTION_IF_NULL(data_kernel);
2519 ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
2520 << "\tinput_number:" << AnfAlgo::GetInputTensorNum(data_kernel)
2521 << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(data_kernel) << "\n";
2522 for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel); ++i) {
2523 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false);
2524 MS_EXCEPTION_IF_NULL(device_tensor);
2525 ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2526 << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2527 }
2528 } else if (actor->type_ == KernelTransformType::kHostDataSourceActor) {
2529 // Dump the member info of host queue data source actor.
2530 const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
2531 MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
2532 ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
2533 for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
2534 const auto &data_node = host_queue_ds_actor->data_nodes_[i];
2535 MS_EXCEPTION_IF_NULL(data_node);
2536 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
2537 MS_EXCEPTION_IF_NULL(device_tensor);
2538 ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
2539 << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2540 << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n";
2541 }
2542 }
2543
2544 DumpAbstractActor(actor, ofs);
2545 ofs << "\n";
2546 }
2547
DumpLoopCountActor(const LoopCountActor * actor,std::ofstream & ofs) const2548 void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
2549 MS_EXCEPTION_IF_NULL(actor);
2550 ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\n";
2551 DumpAbstractActor(actor, ofs);
2552
2553 const size_t kOutputControlArrowsNum = 2;
2554 ofs << "\t\toutput_control_arrows:" << kOutputControlArrowsNum << "\n ";
2555 ofs << "\t\t\tto_actor_name:" << actor->output_aid_.Name() << "\n";
2556 ofs << "\t\t\tto_actor_name:" << actor->data_prepare_aid_.Name() << "\n";
2557 }
2558
DumpKernelActor(const KernelActor * actor,std::ofstream & ofs) const2559 void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
2560 MS_EXCEPTION_IF_NULL(actor);
2561 ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2562
2563 const auto &kernel = actor->kernel_;
2564 MS_EXCEPTION_IF_NULL(kernel);
2565 ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << AnfAlgo::GetInputTensorNum(kernel)
2566 << "\toutputs_num:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
2567 for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
2568 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
2569 MS_EXCEPTION_IF_NULL(device_tensor);
2570 ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2571 << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2572 }
2573
2574 DumpAbstractActor(actor, ofs);
2575 ofs << "\n";
2576 }
2577
DumpOutputActor(const OutputActor * actor,std::ofstream & ofs) const2578 void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const {
2579 MS_EXCEPTION_IF_NULL(actor);
2580 ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
2581 << "\toutputs_num:" << actor->outputs_num_ << "\n";
2582 DumpAbstractActor(actor, ofs);
2583 }
2584
DumpCopyActor(const CopyActor * actor,std::ofstream & ofs) const2585 void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
2586 MS_EXCEPTION_IF_NULL(actor);
2587 ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
2588
2589 auto device_tensor = actor->output_;
2590 if (device_tensor != nullptr) {
2591 ofs << "\t\toutput_index:" << 0 << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
2592 << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
2593 }
2594
2595 DumpAbstractActor(actor, ofs);
2596 ofs << "\n";
2597 }
2598
DumpDeviceTensorStore(const GraphCompilerInfo & graph_compiler_info,std::ofstream & ofs) const2599 void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
2600 for (const auto &graph : graph_compiler_info.graphs_) {
2601 MS_EXCEPTION_IF_NULL(graph);
2602 ofs << "\tgraph id:" << graph->graph_id() << "\n";
2603
2604 for (auto &value_node : graph->graph_value_nodes()) {
2605 MS_EXCEPTION_IF_NULL(value_node);
2606 if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
2607 continue;
2608 }
2609 const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
2610 MS_EXCEPTION_IF_NULL(front_node);
2611 const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
2612 ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
2613 << "\n";
2614 for (const auto &device_tensor : device_tensors) {
2615 MS_EXCEPTION_IF_NULL(device_tensor);
2616 ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
2617 << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
2618 << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
2619 }
2620 }
2621
2622 for (auto &input_node : graph->input_nodes()) {
2623 MS_EXCEPTION_IF_NULL(input_node);
2624 if (!IsPersistentDeviceTensor(input_node)) {
2625 continue;
2626 }
2627 const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
2628 // The sub front nodes share the device tensor store with the root front node.
2629 auto front_node = sub_front_node;
2630 if (graph_compiler_info.control_node_parser_ != nullptr) {
2631 front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
2632 }
2633 const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
2634 MS_EXCEPTION_IF_NULL(front_node);
2635 ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
2636 << "\n";
2637 for (const auto &device_tensor : device_tensors) {
2638 MS_EXCEPTION_IF_NULL(device_tensor);
2639 ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
2640 << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
2641 << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
2642 }
2643 }
2644 ofs << "\n";
2645 }
2646 }
2647
DumpGatherActor(const GatherActor * actor,std::ofstream & ofs) const2648 void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
2649 MS_EXCEPTION_IF_NULL(actor);
2650 ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
2651
2652 ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
2653 for (const auto &node : actor->data_nodes_) {
2654 ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
2655 }
2656
2657 ofs << "\t\tactor front to backend node:\n";
2658 for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
2659 ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
2660 for (const auto node_with_index : front_to_backend_parameter.second) {
2661 ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
2662 << "\tindex:" << node_with_index.second << '\n';
2663 }
2664 }
2665
2666 ofs << "\t\tactor output data arrow:\n";
2667 for (const auto &data_arrow : actor->output_data_arrows_) {
2668 MS_EXCEPTION_IF_NULL(data_arrow);
2669 ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
2670 << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
2671 << "\n";
2672 }
2673
2674 ofs << "\t\tactor output result arrow:\n";
2675 for (const auto &result_arrow : actor->output_result_arrows_) {
2676 MS_EXCEPTION_IF_NULL(result_arrow);
2677 ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
2678 << "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
2679 << "\n";
2680 }
2681
2682 ofs << "\t\tactor output control arrow:\n";
2683 for (const auto &control_arrow : actor->output_control_arrows_) {
2684 ofs << "\t\t\tto_actor_name:" << control_arrow;
2685 }
2686 ofs << "\n";
2687 }
2688
DumpSwitchActor(const SwitchActor * actor,std::ofstream & ofs) const2689 void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
2690 MS_EXCEPTION_IF_NULL(actor);
2691 ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
2692
2693 ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
2694 for (const auto &node : actor->input_nodes_) {
2695 ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n';
2696 }
2697
2698 ofs << "\t\tactor input pos:\n";
2699 for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
2700 ofs << "\t\t\tbranch " << i << " input pos:";
2701 for (const auto pos : actor->branch_inputs_pos_[i]) {
2702 ofs << pos << '\t';
2703 }
2704 ofs << '\n';
2705 }
2706
2707 ofs << "\t\tactor output data arrow:\n";
2708 for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
2709 ofs << "\t\t\tbranch " << i << " output data:\n";
2710 for (const auto arrow : actor->output_branch_arrows_[i]) {
2711 MS_EXCEPTION_IF_NULL(arrow);
2712 ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
2713 << "\tto_input_index:" << arrow->to_input_index_ << '\n';
2714 }
2715 }
2716
2717 ofs << "\t\tactor output result arrow:\n";
2718 for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
2719 ofs << "\t\t\tbranch " << i << " output result:\n";
2720 for (const auto arrow : actor->output_branch_result_arrows_[i]) {
2721 MS_EXCEPTION_IF_NULL(arrow);
2722 ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
2723 << "\tto_input_index:" << arrow->to_input_index_ << '\n';
2724 }
2725 }
2726
2727 ofs << "\t\tactor output control arrow:\n";
2728 for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
2729 ofs << "\t\t\tbranch " << i << " output control:\n";
2730 for (const auto arrow : actor->output_branch_control_arrows_[i]) {
2731 ofs << "\t\t\t\t from index:" << arrow << '\n';
2732 }
2733 }
2734 ofs << "\n";
2735 }
2736 } // namespace runtime
2737 } // namespace mindspore
2738