1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/scheduler_helper.h"
18 #include "mindspore/core/ops/framework_ops.h"
19 #include "mindspore/core/ops/array_ops.h"
20 #include "runtime/graph_scheduler/actor/actor_dump.h"
21 #include "include/backend/anf_runtime_algorithm.h"
22 #include "include/common/utils/anfalgo.h"
23 #include "utils/anf_utils.h"
24 #include "utils/log_adapter.h"
25 #include "include/common/utils/convert_utils.h"
26 
27 namespace mindspore {
28 namespace runtime {
29 size_t SchedulerHelper::fusion_actor_index_ = 0;
30 
31 namespace {
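// Append every control actor (switch, gather, entrance, exit and stack) of the actor set to the output vector.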
32 void CollectControlActors(const ActorSet *actor_set, std::vector<AbstractActorPtr> *actors) {
33   MS_EXCEPTION_IF_NULL(actor_set);
34   MS_EXCEPTION_IF_NULL(actors);
35   if (actor_set->control_actors_ != nullptr) {
36     const auto &control_actor_set = actor_set->control_actors_;
37     for (auto &switch_actor : control_actor_set->switch_actors_) {
38       MS_EXCEPTION_IF_NULL(switch_actor);
39       (void)actors->emplace_back(static_cast<AbstractActorPtr>(switch_actor));
40     }
41     for (auto &gather_actor : control_actor_set->gather_actors_) {
42       MS_EXCEPTION_IF_NULL(gather_actor);
43       (void)actors->emplace_back(static_cast<AbstractActorPtr>(gather_actor));
44     }
45     for (auto &entrance_actor : control_actor_set->entrance_actors_) {
46       MS_EXCEPTION_IF_NULL(entrance_actor);
47       (void)actors->emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
48     }
49     for (auto &exit_actor : control_actor_set->exit_actors_) {
50       MS_EXCEPTION_IF_NULL(exit_actor);
51       (void)actors->emplace_back(static_cast<AbstractActorPtr>(exit_actor));
52     }
53     for (auto &stack_actor : control_actor_set->stack_actors_) {
54       MS_EXCEPTION_IF_NULL(stack_actor);
55       (void)actors->emplace_back(static_cast<AbstractActorPtr>(stack_actor));
56     }
57   }
58 }
59 
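// Return true when launching this kernel can be skipped: either the skip flag is already set, or the kernel is a
// RealMakeTuple whose users are all ShapeCalc nodes that only depend on the shape of the corresponding input.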
60 bool IsSkipLaunchShapeRelatedOp(KernelActor *kernel_actor) {
61   MS_EXCEPTION_IF_NULL(kernel_actor);
62   if (kernel_actor->skip_launch_shape_related_op()) {
63     return true;
64   }
65 
66   auto &kernel = kernel_actor->kernel();
67   MS_EXCEPTION_IF_NULL(kernel);
68 
69   // RealMakeTuple --> ShapeCalc pattern:
70   // If no ShapeCalc user is value dependent on this RealMakeTuple input, we can skip launching the RealMakeTuple.
71   if (IsPrimitiveCNode(kernel, prim::kPrimRealMakeTuple)) {
72     auto func_graph = kernel->func_graph();
73     MS_EXCEPTION_IF_NULL(func_graph);
74     auto manager = func_graph->manager();
75     if (manager == nullptr) {
76       manager = Manage(func_graph, true);
77       func_graph->set_manager(manager);
78     }
79 
80     const auto &users_set = manager->node_users()[kernel];
81     bool can_skip_launch_real_make_tuple = true;
82     for (const auto &item : users_set) {
83       const auto &user_node = item.first;
84       if (!user_node->isa<CNode>()) {
85         can_skip_launch_real_make_tuple = false;
86         break;
87       }
88       auto user_cnode = user_node->cast<CNodePtr>();
89       MS_EXCEPTION_IF_NULL(user_cnode);
90       if (!IsPrimitiveCNode(user_cnode, prim::kPrimShapeCalc)) {
91         can_skip_launch_real_make_tuple = false;
92         break;
93       }
94 
95       if (!common::AnfAlgo::HasNodeAttr(kAttrOnlyDependShape, user_cnode)) {
96         can_skip_launch_real_make_tuple = false;
97         break;
98       }
99       const auto &only_depend_shape = common::AnfAlgo::GetNodeAttr<std::vector<bool>>(user_cnode, kAttrOnlyDependShape);
100       auto user_input_index = item.second;
101       if (user_input_index < 1) {
102         MS_LOG(EXCEPTION) << "The input index should start from 1, but got: " << user_input_index;
103       }
104       if (IntToSize(user_input_index) > only_depend_shape.size()) {
105         MS_LOG(EXCEPTION) << "The input index[" << user_input_index
106                           << "] is out of range, input size: " << only_depend_shape.size();
107       }
108       if (!only_depend_shape[user_input_index - 1]) {
109         can_skip_launch_real_make_tuple = false;
110         break;
111       }
112     }
113 
114     if (can_skip_launch_real_make_tuple) {
115       return true;
116     }
117   }
118 
119   return false;
120 }
121 
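// Update the ref count of the device tensor behind a data arrow. If the destination kernel only depends on the shape
// of this input (kAttrOnlyDependShape), the ref count is not increased and the address is marked with
// kDeviceAddressFlagNullptr.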
122 void UpdateDataArrowRefCount(AbstractActor *const to_actor, size_t to_input_index,
123                              const DeviceTensorPtr &device_tensor) {
124   MS_LOG(DEBUG) << "Process shape depend attribute for actor : " << to_actor->GetAID().Name();
125   bool need_increase_ref_count = true;
126   auto to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
127   auto ms_context = MsContext::GetInstance();
128   MS_EXCEPTION_IF_NULL(ms_context);
129   static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
130   if (to_kernel_actor != nullptr && !enable_infer_boost) {
131     auto to_kernel = to_kernel_actor->kernel();
132     auto cnode = to_kernel->cast<CNodePtr>();
133     if (cnode != nullptr) {
134       MS_LOG(DEBUG) << "Process shape depend attribute for cnode : " << cnode->fullname_with_scope();
135       const auto &only_depend_shape_attr = common::AnfAlgo::GetCNodePrimitiveAttr(cnode, kAttrOnlyDependShape);
136       if (only_depend_shape_attr != nullptr) {
137         auto only_depend_shape = GetValue<std::vector<bool>>(only_depend_shape_attr);
138         if (only_depend_shape.size() <= to_input_index) {
139           MS_LOG(DEBUG) << "to_input_index : " << to_input_index
140                         << " is out of range, only_depend_shape size : " << only_depend_shape.size();
141         } else {
142           auto is_shape_depend = only_depend_shape[to_input_index];
143           MS_LOG(DEBUG) << "only_depend_shape[" << to_input_index << "] : " << is_shape_depend;
144           if (is_shape_depend) {
145             need_increase_ref_count = false;
146           }
147         }
148       }
149     }
150   }
151   if (need_increase_ref_count) {
152     UpdateRefCount(device_tensor.get(), false);
153   } else {
154     device_tensor->UpdateFlag(device::kDeviceAddressFlagNullptr);
155   }
156 }
157 }  // namespace
158 
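// Collect all actors of the actor set (data prepare, data source, custom, kernel, super kernel, memory, copy, fusion,
// swap, loop count, output and control actors) into a flat vector.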
159 std::vector<AbstractActorPtr> SchedulerHelper::CollectActors(const ActorSet *actor_set) {
160   MS_EXCEPTION_IF_NULL(actor_set);
161   std::vector<AbstractActorPtr> actors;
162 
163   if (actor_set->data_prepare_actor_ != nullptr) {
164     (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
165   }
166   for (auto &data_source_actor : actor_set->data_source_actors_) {
167     MS_EXCEPTION_IF_NULL(data_source_actor);
168     (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
169   }
170   for (auto &custom_actor : actor_set->custom_actors_) {
171     MS_EXCEPTION_IF_NULL(custom_actor);
172     (void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
173   }
174   for (auto &kernel_actor : actor_set->kernel_actors_) {
175     MS_EXCEPTION_IF_NULL(kernel_actor);
176     (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
177   }
178   for (auto &kernel_infer_actor : actor_set->kernel_infer_actors_) {
179     MS_EXCEPTION_IF_NULL(kernel_infer_actor);
180     (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_infer_actor));
181   }
182   for (auto &kernel_resize_actor : actor_set->kernel_resize_actors_) {
183     MS_EXCEPTION_IF_NULL(kernel_resize_actor);
184     (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_resize_actor));
185   }
186   for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
187     MS_EXCEPTION_IF_NULL(super_kernel_actor);
188     (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
189   }
190   for (auto &any_type_kernel_actor : actor_set->any_type_kernel_actors_) {
191     MS_EXCEPTION_IF_NULL(any_type_kernel_actor);
192     (void)actors.emplace_back(static_cast<AbstractActorPtr>(any_type_kernel_actor));
193   }
194   for (auto &memory_actor : actor_set->memory_actors_) {
195     MS_EXCEPTION_IF_NULL(memory_actor);
196     (void)actors.emplace_back(static_cast<AbstractActorPtr>(memory_actor));
197   }
198   for (auto &copy_actor : actor_set->copy_actors_) {
199     MS_EXCEPTION_IF_NULL(copy_actor);
200     (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
201   }
202   for (auto &fusion_actor : actor_set->fusion_actors_) {
203     MS_EXCEPTION_IF_NULL(fusion_actor);
204     (void)actors.emplace_back(static_cast<AbstractActorPtr>(fusion_actor));
205   }
206   for (auto &swap_actors : actor_set->swap_actors_) {
207     (void)std::for_each(swap_actors.cbegin(), swap_actors.cend(), [&](const MemSwapActorPtr &swap_actor) {
208       if (swap_actor != nullptr) {
209         (void)actors.emplace_back(static_cast<AbstractActorPtr>(swap_actor));
210       }
211     });
212   }
213   if (actor_set->loop_count_actor_ != nullptr) {
214     (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
215   }
216   if (actor_set->output_actor_ != nullptr) {
217     (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
218   }
219   CollectControlActors(actor_set, &actors);
220   return actors;
221 }
222 
223 bool SchedulerHelper::HasMonadControl(const AnfNodePtr &input_node, const KernelGraphPtr &graph) {
224   MS_EXCEPTION_IF_NULL(input_node);
225   MS_EXCEPTION_IF_NULL(graph);
226   const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
227     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
228   if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
229     return true;
230   }
231 
232   // The subgraph input.
233   if (IsInternalParameter(input_node, graph)) {
234     auto front_output_with_index = graph->GetOriginFrontNodeByInternalParameter(input_node);
235     auto front_output_node = front_output_with_index.first;
236     MS_EXCEPTION_IF_NULL(front_output_node);
237     if (IsOneOfPrimitiveCNode(front_output_node, auto_monad_prims) || HasAbstractMonad(front_output_node)) {
238       MS_LOG(INFO) << "The graph " << graph->graph_id()
239                    << " has monad control from internal parameter: " << input_node->DebugString()
240                    << ", front output node: " << front_output_node->fullname_with_scope();
241       return true;
242     }
243   }
244 
245   return false;
246 }
247 
248 void SchedulerHelper::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
249   MS_EXCEPTION_IF_NULL(anf_node);
250   MS_EXCEPTION_IF_NULL(device_tensor);
251   MS_LOG(DEBUG) << "Add device tensor store:" << device_tensor << " for node:" << anf_node->DebugString()
252                 << " node addr:" << anf_node << " device type:" << device_tensor->GetDeviceType();
253   DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
254   device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
255   UpdateRefCount(device_tensor.get(), true);
256 }
257 
258 void SchedulerHelper::AddMonadDeviceTensorStore(AbstractActor *const to_actor, const CNodePtr &kernel,
259                                                 const KernelGraphPtr &graph) {
260   MS_EXCEPTION_IF_NULL(to_actor);
261   MS_EXCEPTION_IF_NULL(kernel);
262   MS_EXCEPTION_IF_NULL(graph);
263   // Ref node monad device tensor store.
264   if (common::AnfAlgo::HasNodeAttr(kAttrRefNodeMonadOutputIdx, kernel)) {
265     auto output_idx = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrRefNodeMonadOutputIdx);
266     const auto &origin_pair = graph->GetRefNodeRecursive({kernel, output_idx});
267     auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(origin_pair.first, *graph);
268     MS_EXCEPTION_IF_NULL(front_node);
269     if (IsPersistentDeviceTensor(front_node)) {
270       MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
271                    << " add ref node monad device tensor store:" << front_node->fullname_with_scope();
272       (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
273     }
274   }
275 
276   // Input node monad device tensor store.
277   if (!common::AnfAlgo::HasMonadInput(kernel)) {
278     return;
279   }
280 
281   // The super kernel actor needs to fetch by the input device tensor store.
282   if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
283     for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(kernel); ++i) {
284       KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, i, false);
285       auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(from_kernel_with_output_idx.first, *graph);
286       MS_EXCEPTION_IF_NULL(front_node);
287       if (IsPersistentDeviceTensor(front_node)) {
288         MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
289                      << " add input node monad device tensor store:" << front_node->fullname_with_scope();
290         (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
291       }
292     }
293   } else {
294     // Kernel actor can fetch by the device tensor store key directly.
295     const auto &device_tensor_store_keys = to_actor->device_tensor_store_keys_;
296     (void)std::for_each(device_tensor_store_keys.begin(), device_tensor_store_keys.end(), [&](const auto &store_key) {
297       MS_EXCEPTION_IF_NULL(store_key.second);
298       MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
299                    << " add input node monad device tensor store:" << store_key.second->fullname_with_scope();
300       (void)to_actor->auto_monad_device_tensor_stores_.insert(store_key.second);
301     });
302   }
303 }
304 
305 bool SchedulerHelper::IsIgnoredInputAddress(AbstractActor *const to_actor, size_t to_input_index) {
306   MS_EXCEPTION_IF_NULL(to_actor);
307   if (to_actor->type() != KernelTransformType::kKernelActor) {
308     return false;
309   }
310 
311   auto kernel_actor = dynamic_cast<KernelActor *>(to_actor);
312   auto &to_kernel = kernel_actor->kernel();
313   MS_EXCEPTION_IF_NULL(to_kernel);
314 
315   if (IsSkipLaunchShapeRelatedOp(kernel_actor)) {
316     kernel_actor->set_skip_launch_shape_related_op(true);
317     return true;
318   }
319 
320   auto kernel_mod = AnfAlgo::GetKernelMod(to_kernel);
321   MS_EXCEPTION_IF_NULL(kernel_mod);
322   const auto &ignored_address = kernel_mod->GetLaunchIgnoredInputAddressIdx();
323   if (ignored_address.empty()) {
324     return false;
325   }
326 
327   if (std::find(ignored_address.begin(), ignored_address.end(), to_input_index) != ignored_address.end()) {
328     MS_LOG(INFO) << "Ignore the input address for kernel: " << to_kernel->fullname_with_scope()
329                  << " with input index: " << to_input_index;
330     return true;
331   }
332 
333   return false;
334 }
335 
336 size_t SchedulerHelper::GetIgnoredInputAddressCount(const AnfNodePtr &node) {
337   MS_EXCEPTION_IF_NULL(node);
338   size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
339   auto kernel_mod = AnfAlgo::GetKernelMod(node);
340   MS_EXCEPTION_IF_NULL(kernel_mod);
341   const auto &ignored_input_addresses = kernel_mod->GetLaunchIgnoredInputAddressIdx();
342   if (ignored_input_addresses.empty()) {
343     return 0;
344   }
345 
346   auto count = std::count_if(ignored_input_addresses.begin(), ignored_input_addresses.end(),
347                              [input_num](size_t index) { return index < input_num; });
348   return static_cast<size_t>(count);
349 }
350 
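// Link a data arrow from from_actor's output to to_actor's input, add the memory sign, and update the reference count
// of the device tensor of from_kernel.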
351 void SchedulerHelper::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
352                                    size_t from_output_index, size_t to_input_index, const AnfNodePtr &from_kernel) {
353   MS_EXCEPTION_IF_NULL(from_actor);
354   MS_EXCEPTION_IF_NULL(to_actor);
355   MS_LOG(DEBUG) << "Add data arrow from actor:" << from_actor->GetAID() << " index:" << from_output_index
356                 << " to actor:" << to_actor->GetAID() << " to index:" << to_input_index
357                 << " from kernel:" << (from_kernel == nullptr ? "null" : from_kernel->fullname_with_scope());
358   // Check the data arrow legitimacy.
359   if (IsControlFlowActor(to_actor->type()) && (from_actor->type() == KernelTransformType::kKernelActor) &&
360       (to_actor->type() != KernelTransformType::kExitActor)) {
361     MS_LOG(WARNING) << "Kernel actor:" << from_actor->GetAID().Name()
362                     << " link data arrow to actor:" << to_actor->GetAID().Name() << " is not an exit actor.";
363   }
364 
365   if (from_actor->type() == KernelTransformType::kKernelActor &&
366       to_actor->type() == KernelTransformType::kKernelActor) {
367     auto from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
368     MS_EXCEPTION_IF_NULL(from_kernel_actor);
369     if (IsSkipLaunchShapeRelatedOp(from_kernel_actor)) {
370       from_kernel_actor->set_skip_launch_shape_related_op(true);
371     }
372   }
373 
374   // The continuous memory inputs need to allocate memory in advance, so they must come from inside the subgraph.
375   if (to_actor->type() == KernelTransformType::kKernelActor) {
376     auto to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
377     MS_EXCEPTION_IF_NULL(to_kernel_actor);
378     if (to_kernel_actor->inputs_continuous_memory() && (from_actor->type() != KernelTransformType::kKernelActor)) {
379       MS_LOG(INTERNAL_EXCEPTION)
380         << "#dmsg#Runtime error info:#dmsg#The continuous memory input is not from the inside subgraph, to actor: "
381         << to_actor->GetAID().Name() << ", to input index: " << to_input_index
382         << ", from actor: " << from_actor->GetAID().Name() << ", from output index: " << from_output_index;
383     }
384   }
385 
386   AddMemorySign(from_actor, to_actor);
387 
388   auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
389   (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
390   (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
391   to_actor->input_datas_num_++;
392   (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), data_arrow.get()));
393 
394   if (from_kernel == nullptr) {
395     return;
396   }
397   // Update the reference count of from_kernel.
398   auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
399   MS_EXCEPTION_IF_NULL(device_tensor);
400   // The super kernel actor is linked by the input parameter, which may be an unused parameter.
401   if (to_actor->type() != KernelTransformType::kSuperKernelActor) {
402     device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
403   }
404   // The device address of super kernel actor can't be changed, so set the max reference count.
405   if (IsControlFlowActor(to_actor->type()) || (from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
406       (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
407     UpdateRefCount(device_tensor.get(), true);
408   } else {
409     UpdateDataArrowRefCount(to_actor, to_input_index, device_tensor);
410   }
411 
412   if (IsControlFlowActor(to_actor->type())) {
413     device_tensor->SetNodeIndex(from_kernel, from_output_index);
414   }
415 }
416 
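// Link a result arrow from from_actor (or record a device tensor store key when from_actor is null) to the output
// actor, set the max reference count and node index of the output device tensor, and set the output device context.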
417 void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
418                                      const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
419   MS_EXCEPTION_IF_NULL(to_actor);
420   MS_EXCEPTION_IF_NULL(from_kernel);
421 
422   if (from_actor == nullptr) {
423     (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, from_kernel);
424   } else {
425     auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
426     (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
427     (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
428     to_actor->input_datas_num_++;
429     (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get()));
430   }
431 
432   if (!AnfAlgo::OutputAddrExist(from_kernel, from_output_index, false)) {
433     MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, from_kernel)
434       << "#dmsg#Runtime error info:#dmsg#" << from_kernel->DebugString() << " index:" << from_output_index
435       << " device address does not exist";
436   }
437   auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
438   MS_EXCEPTION_IF_NULL(device_tensor);
439   device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
440   // The output actor needs the relevant node information to create the output tensor.
441   device_tensor->SetNodeIndex(from_kernel, from_output_index);
442   // The device tensor of the graph output needs to be taken over by the host tensor, so set the max reference count.
443   UpdateRefCount(device_tensor.get(), true);
444 
445   MS_LOG(DEBUG) << "Add result arrow from actor:" << (from_actor != nullptr ? from_actor->GetAID().Name() : "null")
446                 << " to actor:" << to_actor->GetAID() << " from kernel"
447                 << (from_kernel == nullptr ? "null" : from_kernel->DebugString()) << " device address:" << device_tensor
448                 << " original ref count:" << device_tensor->original_ref_count()
449                 << " ref count:" << device_tensor->ref_count()
450                 << " dynamic ref count:" << device_tensor->dynamic_ref_count();
451 
452   // Set the device contexts of to_actor.
453   if (output_position >= to_actor->device_contexts_.size()) {
454     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range.";
455   }
456   auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
457     {device_tensor->device_name(), device_tensor->device_id()});
458   to_actor->device_contexts_[output_position] = device_context;
459 }
460 
461 void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
462   MS_EXCEPTION_IF_NULL(from_actor);
463   MS_EXCEPTION_IF_NULL(to_actor);
464 
465   // Check whether the control arrow already exists.
466   auto iter = std::find_if(from_actor->output_control_arrows_.begin(), from_actor->output_control_arrows_.end(),
467                            [&to_actor](const auto &output_control_arrow) {
468                              return output_control_arrow->to_op_id_.Name() == to_actor->GetAID().Name();
469                            });
470   if (iter != from_actor->output_control_arrows_.end()) {
471     // The stack actor can only link a single control arrow.
472     if (to_actor->type_ == KernelTransformType::kStackActor) {
473       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The control arrow between "
474                                  << from_actor->GetAID().Name() << " and " << to_actor->GetAID().Name()
475                                  << " is repeated.";
476     }
477     return;
478   }
479 
480   // No need to add a control arrow if a data arrow already exists between the from actor and the to actor.
481   if (from_actor->type() == KernelTransformType::kKernelActor &&
482       to_actor->type() == KernelTransformType::kKernelActor) {
483     const auto &input_data_arrows = to_actor->input_data_arrow_aids();
484     if (std::any_of(input_data_arrows.begin(), input_data_arrows.end(),
485                     [&from_actor](const std::pair<AID, DataArrow *> &input_data_arrow_pair) {
486                       return input_data_arrow_pair.first.Name() == from_actor->GetAID().Name();
487                     })) {
488       MS_LOG(INFO) << "No need add control arrow, because already exists data arrow in from actor: "
489                    << from_actor->GetAID().Name() << " and to actor: " << to_actor->GetAID().Name();
490       return;
491     }
492   }
493 
494   auto control_arrow = std::make_shared<ControlArrow>(to_actor->GetAID());
495   (void)from_actor->output_control_arrows_.emplace_back(control_arrow);
496   to_actor->input_controls_num_++;
497   (void)to_actor->input_control_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), control_arrow.get()));
498   MS_LOG(DEBUG) << "Add control arrow from actor:" << from_actor->GetAID() << " to actor:" << to_actor->GetAID();
499   AddMemorySign(from_actor, to_actor);
500 }
501 
502 void SchedulerHelper::AddPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index,
503                                       size_t to_index) {
504   MS_EXCEPTION_IF_NULL(from_actor);
505   MS_EXCEPTION_IF_NULL(to_actor);
506   auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
507   (void)from_actor->output_partial_arrows_.emplace_back(op_arrow);
508   to_actor->input_partials_num_++;
509   (void)to_actor->input_partial_arrow_aids_.emplace_back(from_actor->GetAID());
510 }
511 
512 void SchedulerHelper::AddBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor) {
513   MS_EXCEPTION_IF_NULL(from_actor);
514   MS_EXCEPTION_IF_NULL(to_actor);
515   (void)from_actor->output_branch_id_arrows_.emplace_back(to_actor->GetAID());
516   (void)to_actor->input_branch_id_arrow_aids_.emplace_back(from_actor->GetAID());
517   to_actor->input_branch_ids_num_++;
518 }
519 
520 void SchedulerHelper::AddLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor) {
521   MS_EXCEPTION_IF_NULL(from_actor);
522   MS_EXCEPTION_IF_NULL(to_actor);
523   MS_LOG(DEBUG) << "Link loop body control arrow from:" << from_actor->GetAID() << " to actor:" << to_actor->GetAID();
524   auto control_arrow = std::make_shared<ControlArrow>(to_actor->GetAID());
525   (void)from_actor->output_control_arrows_.emplace_back(control_arrow);
526   to_actor->loop_body_input_controls_nums_++;
527   (void)to_actor->loop_body_input_control_arrow_aids_.emplace_back(from_actor->GetAID());
528 }
529 
530 void SchedulerHelper::AddDataWithBranchIDArrow(GatherActor *const gather_actor, const EntranceActor *entrance_actor,
531                                                const FuncGraphPtr &func_graph) {
532   MS_EXCEPTION_IF_NULL(gather_actor);
533   MS_EXCEPTION_IF_NULL(entrance_actor);
534   (void)gather_actor->output_data_with_branch_id_arrows_[func_graph.get()].emplace_back(entrance_actor->GetAID());
535 }
536 
537 void SchedulerHelper::AddDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
538                                                size_t from_index, size_t to_index, int branch_id) {
539   MS_EXCEPTION_IF_NULL(exit_actor);
540   MS_EXCEPTION_IF_NULL(to_actor);
541 
542   MS_LOG(DEBUG) << "Link data arrow from actor:" << exit_actor->GetAID() << " from index:" << from_index
543                 << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
544   auto data_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
545   (void)exit_actor->output_branch_data_arrows_[branch_id].emplace_back(data_arrow);
546   (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(exit_actor->GetAID(), data_arrow.get()));
547 }
548 
549 void SchedulerHelper::AddPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
550                                                   size_t from_index, size_t to_index, int branch_id) {
551   MS_EXCEPTION_IF_NULL(exit_actor);
552   MS_EXCEPTION_IF_NULL(to_actor);
553   MS_LOG(DEBUG) << "Link partial arrow from actor:" << exit_actor->GetAID() << " from index:" << from_index
554                 << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
555   auto partial_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
556   (void)exit_actor->output_branch_partial_arrows_[branch_id].emplace_back(partial_arrow);
557   (void)to_actor->input_partial_arrow_aids_.emplace_back(exit_actor->GetAID());
558 }
559 
560 void SchedulerHelper::AddControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id) {
561   MS_EXCEPTION_IF_NULL(from_actor);
562   MS_EXCEPTION_IF_NULL(to_actor);
563 
564   MS_LOG(DEBUG) << "Link control arrow from:" << from_actor->GetAID() << " to:" << to_actor->GetAID();
565   (void)from_actor->output_branch_control_arrows_[branch_id].emplace_back(to_actor->GetAID());
566   to_actor->input_controls_num_++;
567   (void)to_actor->input_control_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), nullptr));
568 }
569 
570 void SchedulerHelper::AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index,
571                                                      const AnfNodePtr &input_node, const KernelGraphPtr &graph) {
572   MS_EXCEPTION_IF_NULL(from_actor);
573   MS_EXCEPTION_IF_NULL(input_node);
574   MS_EXCEPTION_IF_NULL(graph);
575   // Graph mode does not support dynamic shape and ref node.
576   if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
577     return;
578   }
579 
580   // Collect backend parameters with dynamic shapes.
581   auto base_shape = input_node->Shape();
582   if (input_node->isa<Parameter>() && base_shape != nullptr &&
583       ((base_shape->isa<abstract::Shape>() && base_shape->IsDynamic()) ||
584        base_shape->isa<abstract::DynamicSequenceShape>())) {
585     if (from_index >= from_actor->backend_parameters_.size()) {
586       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid from index:" << from_index
587                                  << " for actor:" << from_actor->GetAID()
588                                  << " vector size:" << from_actor->backend_parameters_.size();
589     }
590     MS_LOG(INFO) << "Add dynamic shape backend parameter:" << input_node->DebugString() << " index:" << from_index
591                  << " for actor:" << from_actor->GetAID();
592     (void)from_actor->backend_parameters_[from_index].emplace_back(input_node);
593   }
594 
595   if (!common::AnfAlgo::HasAbstractRef(input_node)) {
596     return;
597   }
598 
599   auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
600   MS_EXCEPTION_IF_NULL(device_tensor);
601   (void)from_actor->ref_formal_parameter_device_tensors_[from_index].insert(device_tensor);
602   if (graph->IsRefOutputMapValue({input_node, 0})) {
603     (void)from_actor->ref_node_formal_parameter_device_tensors_[from_index].insert(device_tensor);
604   }
605 
606   device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
607   UpdateRefCount(device_tensor.get(), true);
608   device_tensor->SetNodeIndex(input_node, 0);
609 }
610 
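// Convert an invalid data arrow into a control arrow: erase the data arrow on both sides (skipping ref-node and somas
// inner addresses), recalculate the ref count of the converted output, then add a control arrow between the actors.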
611 void SchedulerHelper::ConvertDataArrowToControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
612                                                      const DataArrowPtr &data_arrow, size_t data_arrow_index) {
613   MS_EXCEPTION_IF_NULL(from_actor);
614   MS_EXCEPTION_IF_NULL(to_actor);
615   MS_EXCEPTION_IF_NULL(data_arrow);
616   MS_EXCEPTION_IF_CHECK_FAIL((data_arrow_index < from_actor->output_data_nodes_.size()), "Index out of range.");
617   auto &need_converted_node = from_actor->output_data_nodes_[data_arrow_index];
618   MS_EXCEPTION_IF_NULL(need_converted_node);
619 
620   // Skip the ref node because its reference count can't be recalculated correctly.
621   auto device_tensor =
622     AnfAlgo::GetMutableOutputAddr(need_converted_node, IntToSize(data_arrow->from_output_index_), false);
623   MS_EXCEPTION_IF_NULL(device_tensor);
624   if (TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagRefNode)) {
625     MS_LOG(INFO) << "Skip the invalid data arrow of ref node, from actor:" << from_actor->GetAID().Name()
626                  << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
627                  << ", to index:" << data_arrow->to_input_index_;
628     return;
629   }
630 
631   auto kernel_info = dynamic_cast<KernelInfo *>(need_converted_node->kernel_info());
632   MS_EXCEPTION_IF_NULL(kernel_info);
633   const auto &somas_outputs = kernel_info->somas_output_result();
634   if (kernel_info->IsTensorEnableSomas(somas_outputs, data_arrow->from_output_index_)) {
635     MS_LOG(INFO) << "Skip the invalid data arrow of somas inner address, from actor:" << from_actor->GetAID().Name()
636                  << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
637                  << ", to index:" << data_arrow->to_input_index_;
638     return;
639   }
640 
641   // Erase the output data arrow in from actor.
642   (void)from_actor->output_data_arrows_.erase(from_actor->output_data_arrows_.begin() + SizeToLong(data_arrow_index));
643   (void)from_actor->output_data_nodes_.erase(from_actor->output_data_nodes_.begin() + SizeToLong(data_arrow_index));
644 
645   // Erase the input data arrow aid in to actor.
646   bool to_actor_erase = false;
647   for (auto iter = to_actor->input_data_arrow_aids_.begin(); iter != to_actor->input_data_arrow_aids_.end(); ++iter) {
648     if ((*iter).first == from_actor->GetAID()) {
649       (void)to_actor->input_data_arrow_aids_.erase(iter);
650       to_actor_erase = true;
651       to_actor->input_datas_num_--;
652       break;
653     }
654   }
655   if (to_actor_erase == false) {
656     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Erase no input data arrow, from actor:"
657                                << from_actor->GetAID().Name() << ", to actor:" << to_actor->GetAID().Name()
658                                << ", data arrow index:" << data_arrow_index;
659   }
660 
661   // Recalculate the ref count of converted node.
662   size_t old_ref_count = device_tensor->ref_count();
663   // The initial ref count value is 1.
664   size_t new_ref_count = 1;
665   for (auto &output_data_arrow : from_actor->output_data_arrows_) {
666     MS_EXCEPTION_IF_NULL(output_data_arrow);
667     if (output_data_arrow->from_output_index_ != data_arrow->from_output_index_) {
668       continue;
669     }
670     if ((output_data_arrow->to_op_id_.Name().find(kExitActorNameSuffix) != std::string::npos) ||
671         (output_data_arrow->to_op_id_.Name().find(kOutputActorNameSuffix) != std::string::npos)) {
672       new_ref_count = SIZE_MAX;
673       break;
674     }
675     ++new_ref_count;
676   }
677   device_tensor->set_original_ref_count(new_ref_count);
678   device_tensor->ResetRefCount();
679   MS_LOG(INFO) << "Erase the invalid data arrow, from actor:" << from_actor->GetAID().Name()
680                << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
681                << ", to index:" << data_arrow->to_input_index_ << ", old ref count:" << old_ref_count
682                << ", new ref count:" << new_ref_count;
683 
684   // Add the control arrow.
685   SchedulerHelper::AddControlArrow(from_actor, to_actor);
686 }
687 
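// Fuse the data arrows that share the same destination actor into a batch data arrow; arrows to stack actors are not
// fused.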
688 void SchedulerHelper::FuseDataArrowsToBatchDataArrow(AbstractActor *const actor) {
689   MS_EXCEPTION_IF_NULL(actor);
690   // Count the number of the same destination actor.
691   mindspore::HashMap<std::string, size_t> to_actor_count;
692   for (const auto &data_arrow : actor->output_data_arrows()) {
693     MS_EXCEPTION_IF_NULL(data_arrow);
694     ++(to_actor_count[data_arrow->to_op_id_.Name()]);
695   }
696 
697   // Sign and add the batch data arrow.
698   for (auto &data_arrow : actor->output_data_arrows()) {
699     MS_EXCEPTION_IF_NULL(data_arrow);
700     auto &to_op_name = data_arrow->to_op_id_.Name();
701     // Output data whose destination is a stack actor cannot be reused and therefore cannot be fused.
702     if ((to_actor_count[to_op_name] > 1) && (to_op_name.find(kStackActorNameSuffix) == std::string::npos)) {
703       SET_FLAG(data_arrow->flag_, kOutputDataFlagBatch);
704       (void)actor->batch_output_data_arrows_[to_op_name].emplace_back(data_arrow);
705     }
706   }
707 }
708 
709 void SchedulerHelper::AddDependency(AbstractActor *const actor, const AbstractActor *dependent_actor) {
710   MS_EXCEPTION_IF_NULL(actor);
711   MS_EXCEPTION_IF_NULL(dependent_actor);
712   // For example, given ActorA->dependent_actor->actor, the expanded dependent actors of actor are dependent_actor and ActorA.
713   (void)actor->dependent_actors_.insert(dependent_actor->GetAID().Name());
714   actor->dependent_actors_.insert(dependent_actor->dependent_actors_.begin(), dependent_actor->dependent_actors_.end());
715 }
716 
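// Return true when every adjacent pair of output actors has a dependency in at least one direction.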
717 bool SchedulerHelper::CheckDependency(const std::vector<AbstractActorPtr> &output_actors) {
718   if (output_actors.size() <= 1) {
719     return true;
720   }
721 
722   for (size_t i = 1; i < output_actors.size(); ++i) {
723     auto &pre_actor = output_actors[i - 1];
724     auto &actor = output_actors[i];
725     MS_EXCEPTION_IF_NULL(pre_actor);
726     MS_EXCEPTION_IF_NULL(actor);
727     // The outputs have no dependencies.
728     if ((actor->dependent_actors_.count(pre_actor->GetAID().Name()) == 0) &&
729         (pre_actor->dependent_actors_.count(actor->GetAID().Name()) == 0)) {
730       return false;
731     }
732   }
733 
734   return true;
735 }
736 
737 FusionActorPtr SchedulerHelper::BuildFusionActor(const std::vector<AbstractActorPtr> &actors) {
738   if (actors.size() <= 1) {
739     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The fusion actor size must be greater than 1.";
740   }
741 
742   std::string fusion_actor_name = std::to_string(++fusion_actor_index_) + kFusionActorNameSuffix;
743   auto fusion_actor = std::make_shared<FusionActor>(fusion_actor_name);
744   InsertActor(fusion_actor.get());
745   for (auto &actor : actors) {
746     MS_EXCEPTION_IF_NULL(actor);
747     actor->parent_fusion_actor_ = fusion_actor.get();
748     MS_LOG(DEBUG) << "Set fusion actor:" << fusion_actor->GetAID() << " to actor:" << actor->GetAID();
749     fusion_actor->sub_actors_[actor->GetAID().Name()] = actor;
750   }
751   return fusion_actor;
752 }
753 
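// Redirect the input data and control arrows of the sub actors: arrows coming from outside the fusion actor are
// retargeted to the fusion actor itself, while arrows between sub actors only get the between-fusion flag.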
754 void SchedulerHelper::AddArrowForFusionActor(FusionActor *fusion_actor) {
755   MS_EXCEPTION_IF_NULL(fusion_actor);
756   for (auto &actor_iter : fusion_actor->sub_actors_) {
757     auto &actor = actor_iter.second;
758     MS_EXCEPTION_IF_NULL(actor);
759 
760     // Link data arrow of fusion actor by the input data arrow of real actor.
761     for (auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
762       auto input_data_arrow = input_data_arrow_aid.second;
763       MS_EXCEPTION_IF_NULL(input_data_arrow);
764       // Mark the kOutputDataFlagBetweenFusion flag when the input data arrow comes from an internal actor in the fusion actor.
765       if (fusion_actor->sub_actors_.count(input_data_arrow_aid.first.Name()) > 0) {
766         SET_FLAG(input_data_arrow->flag_, kOutputDataFlagBetweenFusion);
767         continue;
768       }
769 
770       SET_FLAG(input_data_arrow->flag_, kOutputDataFlagToFusion);
771       // If ActorB is in the fusion actor and its input ActorA is outside of the fusion actor, change
772       // 'ActorA->ActorB' to 'ActorA->FusionActor'.
773       auto from_actor = FetchActor(input_data_arrow_aid.first.Name());
774       MS_EXCEPTION_IF_NULL(from_actor);
775       // Record the input index of real actor and fusion actor.
776       (void)fusion_actor->real_input_data_.emplace_back(std::make_pair(actor.get(), input_data_arrow->to_input_index_));
777       from_actor->data_arrow_to_fusion_actor_indexs_[input_data_arrow] = fusion_actor->input_data_arrow_aids_.size();
778       input_data_arrow->to_input_index_ = SizeToInt(fusion_actor->input_data_arrow_aids_.size());
779 
780       input_data_arrow->to_op_id_ = fusion_actor->GetAID();
781       ++fusion_actor->input_datas_num_;
782       (void)fusion_actor->input_data_arrow_aids_.emplace_back(
783         std::make_pair(input_data_arrow_aid.first, input_data_arrow));
784     }
785 
786     // Link control arrow of fusion actor by the input control arrow of real actor.
787     for (auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
788       auto input_control_arrow = input_control_arrow_aid.second;
789       MS_EXCEPTION_IF_NULL(input_control_arrow);
790       // Mark the kOutputDataFlagBetweenFusion flag when the input control arrow comes from an internal actor in the
791       // fusion actor.
792       if (fusion_actor->sub_actors_.count(input_control_arrow_aid.first.Name()) > 0) {
793         SET_FLAG(input_control_arrow->flag_, kOutputDataFlagBetweenFusion);
794         continue;
795       }
796 
797       SET_FLAG(input_control_arrow->flag_, kOutputDataFlagToFusion);
798       // If ActorB is in the fusion actor and its input ActorA is outside of the fusion actor, change
799       // 'ActorA->ActorB' to 'ActorA->FusionActor'.
800       (void)fusion_actor->real_input_controls_[input_control_arrow_aid.first.Name()].emplace_back(actor.get());
801       input_control_arrow->to_op_id_ = fusion_actor->GetAID();
802       ++fusion_actor->input_controls_num_;
803       (void)fusion_actor->input_control_arrow_aids_.emplace_back(
804         std::make_pair(input_control_arrow_aid.first, input_control_arrow));
805     }
806   }
807 }
808 
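// Add somas info and memory alloc/free signs at the kernel graph boundary between the two actors (skipped when memory
// optimization is disabled).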
809 void SchedulerHelper::AddMemorySign(AbstractActor *const from_actor, AbstractActor *const to_actor) {
810   MS_EXCEPTION_IF_NULL(from_actor);
811   MS_EXCEPTION_IF_NULL(to_actor);
812   auto ms_context = MsContext::GetInstance();
813   MS_EXCEPTION_IF_NULL(ms_context);
814   if (ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO0) {
815     return;
816   }
817   // Links to or from a memory actor do not need the memory sign.
818   if (IsMemoryActor(from_actor->type()) || IsMemoryActor(to_actor->type())) {
819     return;
820   }
821 
822   // Add the somas info.
823   AddSomasInfo(from_actor);
824   AddSomasInfo(to_actor);
825 
826   auto from_graph = FetchKernelGraphByActor(from_actor);
827   auto to_graph = FetchKernelGraphByActor(to_actor);
828   // Add the memory alloc and free sign at the boundary of the graph.
829   if ((from_graph != nullptr) && (to_graph != nullptr)) {
830     // No need to insert the memory actor within the same graph.
831     if (from_graph->graph_id() == to_graph->graph_id()) {
832       return;
833     }
834     AddMemoryFreeSign(from_actor, to_actor, from_graph);
835     AddMemoryAllocSign(from_actor, to_actor, to_graph);
836   } else if (from_graph != nullptr) {
837     AddMemoryFreeSign(from_actor, to_actor, from_graph);
838   } else if (to_graph != nullptr) {
839     AddMemoryAllocSign(from_actor, to_actor, to_graph);
840   }
841 }
842 
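// Fetch the kernel graph that the actor's kernel belongs to; return nullptr for actor types without an associated
// kernel.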
843 KernelGraphPtr SchedulerHelper::FetchKernelGraphByActor(AbstractActor *const actor) {
844   MS_EXCEPTION_IF_NULL(actor);
845   AnfNode *from_kernel = nullptr;
846   if (actor->type() == KernelTransformType::kKernelActor ||
847       actor->type() == KernelTransformType::kConditionGatherActor ||
848       actor->type() == KernelTransformType::kConditionSwitchActor) {
849     auto kernel_actor = dynamic_cast<KernelActor *>(actor);
850     MS_EXCEPTION_IF_NULL(kernel_actor);
851     from_kernel = kernel_actor->kernel().get();
852     MS_EXCEPTION_IF_NULL(from_kernel);
853   }
854 
855   // The device data source actor comes from the GetNext cnode, which is not a boundary of the graph and is
856   // equivalent to a kernel actor when inserting the memory actor.
857   if (actor->type() == KernelTransformType::kDeviceDataSourceActor) {
858     auto device_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor);
859     MS_EXCEPTION_IF_NULL(device_ds_actor);
860     from_kernel = device_ds_actor->data_kernel().get();
861     MS_EXCEPTION_IF_NULL(from_kernel);
862   }
863 
864   // Only the copy actor from the device tensor store needs to fetch the kernel graph, because that copy actor is not
865   // a boundary of the graph and is equivalent to a kernel actor when inserting the memory actor.
866   if ((actor->type() == KernelTransformType::kCopyActor) &&
867       (actor->GetAID().Name().find(kCopyActorNameSignFromStore) != std::string::npos)) {
868     auto copy_actor = dynamic_cast<CopyActor *>(actor);
869     MS_EXCEPTION_IF_NULL(copy_actor);
870     from_kernel = copy_actor->from_kernel_;
871   }
872 
873   if (from_kernel == nullptr) {
874     return nullptr;
875   }
876   auto graph = AnfAlgo::FetchKernelGraph(from_kernel);
877   if (graph == nullptr) {
878     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#No associated graph for node: "
879                                << from_kernel->fullname_with_scope();
880   }
881 
882   return graph;
883 }
884 
885 void SchedulerHelper::AddMemoryAllocSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
886                                          const KernelGraphPtr &to_graph) {
887   MS_EXCEPTION_IF_NULL(from_actor);
888   MS_EXCEPTION_IF_NULL(to_actor);
889   MS_EXCEPTION_IF_NULL(to_graph);
890   // Somas does not work for this graph.
891   if (to_graph->somas_whole_block_size() == 0) {
892     return;
893   }
894 
895   // Set the memory alloc info.
896   to_actor->memory_alloc_insert_position_ = from_actor;
897 }
898 
899 void SchedulerHelper::AddMemoryFreeSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
900                                         const KernelGraphPtr &from_graph) {
901   MS_EXCEPTION_IF_NULL(from_actor);
902   MS_EXCEPTION_IF_NULL(to_actor);
903   MS_EXCEPTION_IF_NULL(from_graph);
904   // Somas does not work for this graph.
905   if (from_graph->somas_whole_block_size() == 0) {
906     return;
907   }
908 
909   // Set the memory free info.
910   from_actor->memory_free_insert_position_ = to_actor;
911 }
912 
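// Attach the somas info of the owning kernel graph to the kernel actor when somas is enabled for that graph.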
913 void SchedulerHelper::AddSomasInfo(AbstractActor *const actor) {
914   MS_EXCEPTION_IF_NULL(actor);
915   // Only the kernel actor supports somas.
916   if (actor->type() != KernelTransformType::kKernelActor &&
917       actor->type() != KernelTransformType::kConditionGatherActor &&
918       actor->type() != KernelTransformType::kConditionSwitchActor) {
919     return;
920   }
921   auto kernel_actor = dynamic_cast<KernelActor *>(actor);
922   MS_EXCEPTION_IF_NULL(kernel_actor);
923   if (kernel_actor->somas_info_ != nullptr) {
924     return;
925   }
926 
927   MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
928   auto graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
929   if (graph == nullptr) {
930     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#No associated graph for node: "
931                                << kernel_actor->kernel()->fullname_with_scope();
932   }
933   // Somas does not work for this graph.
934   if (graph->somas_whole_block_size() == 0) {
935     return;
936   }
937 
938   // Set the somas info.
939   auto somas_info = graph->MutableSomasInfo();
940   MS_EXCEPTION_IF_NULL(somas_info);
941   somas_info->graph_id_ = graph->graph_id();
942   kernel_actor->somas_info_ = somas_info;
943 }
944 
945 void SchedulerHelper::AddSomasInfoForGraphOutput(AbstractActor *const output_actor, const AnfNodePtr &output_kernel,
946                                                  size_t output_index, size_t graph_id) {
947   auto ms_context = MsContext::GetInstance();
948   MS_EXCEPTION_IF_NULL(ms_context);
949   if (ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO0) {
950     return;
951   }
952   if ((output_actor == nullptr) || (output_actor->type() != KernelTransformType::kKernelActor &&
953                                     output_actor->type() != KernelTransformType::kConditionSwitchActor &&
954                                     output_actor->type() != KernelTransformType::kConditionGatherActor)) {
955     return;
956   }
957 
958   MS_EXCEPTION_IF_NULL(output_kernel);
959   auto kernel_info = dynamic_cast<KernelInfo *>(output_kernel->kernel_info());
960   MS_EXCEPTION_IF_NULL(kernel_info);
961   const auto &somas_outputs = kernel_info->somas_output_result();
962   auto is_somas = kernel_info->IsTensorEnableSomas(somas_outputs, output_index);
963   MS_LOG(INFO) << "The graph " << graph_id << " output node:" << output_kernel->fullname_with_scope()
964                << " with index: " << output_index << " somas enable or not: " << is_somas
965                << ", somas offset: " << kernel_info->GetTensorSomasOffset(somas_outputs, output_index)
966                << ", aligned size: " << kernel_info->GetTensorSomasAlignedSize(somas_outputs, output_index);
967   if (is_somas) {
968     auto kernel_actor = dynamic_cast<KernelActor *>(output_actor);
969     kernel_actor->somas_graph_output_indexes_.insert(output_index);
970   }
971 }
972 
973 namespace {
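// Check that each kernel actor links its output data arrows to no more than one exit actor.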
974 void CheckKernelActorValid(const std::vector<KernelActorPtr> &kernel_actors) {
975   for (const auto &kernel_actor : kernel_actors) {
976     MS_EXCEPTION_IF_NULL(kernel_actor);
977     std::string exit_actor_name = "";
978 
979     for (const auto &arrow : kernel_actor->output_data_arrows()) {
980       MS_EXCEPTION_IF_NULL(arrow);
981       if (arrow->to_op_id_.Name().find(kExitActorNameSuffix) == std::string::npos) {
982         continue;
983       }
984       if (exit_actor_name == "") {
985         exit_actor_name = arrow->to_op_id_.Name();
986         continue;
987       }
988       if (exit_actor_name != arrow->to_op_id_.Name()) {
989         MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Kernel actor:" << kernel_actor->GetAID()
990                                    << " link to two exit actor:" << exit_actor_name
991                                    << " and:" << arrow->to_op_id_.Name();
992       }
993     }
994   }
995 }
996 
997 bool CheckExitActorInvalid(const ExitActorPtr &exit_actor) {
998   MS_EXCEPTION_IF_NULL(exit_actor);
999 
1000   return exit_actor->output_data_arrows().empty() && exit_actor->output_partial_arrows().empty() &&
1001          exit_actor->output_control_arrows().empty() && exit_actor->output_branch_control_arrows().empty() &&
1002          exit_actor->output_branch_data_arrows().empty() && exit_actor->output_branch_partial_arrows().empty() &&
1003          !exit_actor->input_data_arrow_aids().empty();
1004 }
1005 
1006 // Collect the control actor vector from the control actor set.
1007 std::vector<ControlActorPtr> CollectControlActors(const ControlActorSetPtr &control_actor_set) {
1008   MS_EXCEPTION_IF_NULL(control_actor_set);
1009   std::vector<ControlActorPtr> actors;
1010 
1011   for (auto &switch_actor : control_actor_set->switch_actors_) {
1012     MS_EXCEPTION_IF_NULL(switch_actor);
1013     (void)actors.emplace_back(static_cast<ControlActorPtr>(switch_actor));
1014   }
1015   for (auto &gather_actor : control_actor_set->gather_actors_) {
1016     MS_EXCEPTION_IF_NULL(gather_actor);
1017     (void)actors.emplace_back(static_cast<ControlActorPtr>(gather_actor));
1018   }
1019   for (auto &entrance_actor : control_actor_set->entrance_actors_) {
1020     MS_EXCEPTION_IF_NULL(entrance_actor);
1021     (void)actors.emplace_back(static_cast<ControlActorPtr>(entrance_actor));
1022   }
1023   for (auto &exit_actor : control_actor_set->exit_actors_) {
1024     MS_EXCEPTION_IF_NULL(exit_actor);
1025     (void)actors.emplace_back(static_cast<ControlActorPtr>(exit_actor));
1026   }
1027   for (auto &stack_actor : control_actor_set->stack_actors_) {
1028     MS_EXCEPTION_IF_NULL(stack_actor);
1029     (void)actors.emplace_back(static_cast<ControlActorPtr>(stack_actor));
1030   }
1031 
1032   return actors;
1033 }
1034 
1035 void CheckControlActorValid(const ActorSet *actor_set) {
1036   MS_EXCEPTION_IF_NULL(actor_set);
1037   if (actor_set->control_actors_ == nullptr) {
1038     return;
1039   }
1040 
1041   CheckKernelActorValid(actor_set->kernel_actors_);
1042 
1043   auto control_actors = CollectControlActors(actor_set->control_actors_);
1044   for (const auto &control_actor : control_actors) {
1045     MS_EXCEPTION_IF_NULL(control_actor);
1046     for (auto &ref_node_formal_parameter_device_tensor : control_actor->ref_node_formal_parameter_device_tensors()) {
1047       auto &device_tensors = ref_node_formal_parameter_device_tensor.second;
1048       for (auto iter = device_tensors.begin(); iter != device_tensors.end(); ++iter) {
1049         if (((*device_tensors.begin())->format() != (*iter)->format()) ||
1050             ((*device_tensors.begin())->GetDeviceType() != (*iter)->GetDeviceType()) ||
1051             ((*device_tensors.begin())->type_id() != (*iter)->type_id())) {
1052           MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << control_actor->GetAID().Name()
1053                                      << " does not support the ref node formal parameters with different format.";
1054         }
1055       }
1056     }
1057 
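    // Likewise, all device tensors bound to the same ref formal parameter must share the same type id.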
1058     for (auto &ref_formal_parameter_device_tensor : control_actor->ref_formal_parameter_device_tensors()) {
1059       auto &device_tensors = ref_formal_parameter_device_tensor.second;
1060       for (auto iter = device_tensors.begin(); iter != device_tensors.end(); ++iter) {
1061         if ((*device_tensors.begin())->type_id() != (*iter)->type_id()) {
1062           MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << control_actor->GetAID().Name()
1063                                      << " does not support ref formal parameters with different type ids.";
1064         }
1065       }
1066     }
1067   }
1068 
1069   for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
1070     MS_EXCEPTION_IF_NULL(exit_actor);
1071     if (CheckExitActorInvalid(exit_actor)) {
1072       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid exit actor:" << exit_actor->GetAID();
1073     }
1074   }
1075 
1076   // Some input control arrows of stack actors are counted by the AID of their sender, so such arrows must not be
1077   // repeated; otherwise the count becomes inaccurate. As an exception, an arrow that is not counted by AID may be
1078   // repeated, which is why duplicates only trigger the warning below instead of an error.
1079   for (const auto &stack_actor : actor_set->control_actors_->stack_actors_) {
1080     MS_EXCEPTION_IF_NULL(stack_actor);
1081     const auto &input_control_aids = stack_actor->input_control_arrow_aids();
1082     std::set<AID> aid_set;
1083     (void)std::for_each(input_control_aids.begin(), input_control_aids.end(),
1084                         [&aid_set](const auto &input_control_aid) { (void)aid_set.emplace(input_control_aid.first); });
1085     if (aid_set.size() != input_control_aids.size()) {
1086       MS_LOG(WARNING) << "Stack actor:" << stack_actor->GetAID() << " has duplicate control arrows.";
1087     }
1088   }
1089 }
1090 }  // namespace
1091 
1092 void SchedulerHelper::CheckActorValid(const ActorSet *actor_set) {
1093   MS_EXCEPTION_IF_NULL(actor_set);
1094   auto actors = SchedulerHelper::CollectActors(actor_set);
1095   for (auto &actor : actors) {
1096     MS_EXCEPTION_IF_NULL(actor);
1097     if (actor->type_ >= KernelTransformType::kSwitchActor) {
1098       continue;
1099     }
1100 
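    // The declared input counts must match the numbers of input arrows that were actually linked.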
1101     if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
1102         (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
1103       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The input num of " << actor->GetAID().Name()
1104                                  << " is wrong, expect data num: " << actor->input_datas_num_
1105                                  << ", actual data num: " << actor->input_data_arrow_aids_.size()
1106                                  << ", expect control num: " << actor->input_controls_num_
1107                                  << ", actual control num: " << actor->input_control_arrow_aids_.size();
1108     }
1109 
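    // Except for output and custom actors, every actor must have at least one output arrow; except for
    // data prepare and custom actors, every actor must have at least one input.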
1110     if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
1111         (actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
1112       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << actor->GetAID().Name() << " has no user.";
1113     }
1114     if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
1115         (actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
1116         (actor->input_controls_num_ == 0)) {
1117       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << actor->GetAID().Name() << " has no source.";
1118     }
1119 
1120     // Check the input of kernel actors and copy actors.
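    // Their inputs come either from input data arrows or from the device tensor store, and the two
    // together must exactly match the expected input number: one for copy actors, the kernel input
    // tensor number for kernel actors.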
1121     if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
1122       size_t expect_input_num = 1;
1123       if (actor->type_ == KernelTransformType::kKernelActor) {
1124         auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
1125         MS_EXCEPTION_IF_NULL(kernel_actor);
1126         auto &kernel = kernel_actor->kernel();
1127         MS_EXCEPTION_IF_NULL(kernel);
1128         expect_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
1129       }
1130       auto input_data_num = actor->input_datas_num_;
1131       auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
1132       if (input_data_num + device_tensor_store_num != expect_input_num) {
1133         MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The input building of " << actor->GetAID().Name()
1134                                    << " is wrong, input data num: " << input_data_num
1135                                    << ", device tensor store num: " << device_tensor_store_num
1136                                    << ", expect input num: " << expect_input_num;
1137       }
1138     }
1139   }
1140 
1141   // Check the output actor.
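  // The number of input data arrows plus device tensor store keys must equal the number of outputs.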
1142   auto output_actor = actor_set->output_actor_;
1143   MS_EXCEPTION_IF_NULL(output_actor);
1144   if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num()) {
1145     MS_LOG(INTERNAL_EXCEPTION)
1146       << "#dmsg#Runtime error info:#dmsg#The outputs num of output actor is wrong, the total outputs num: "
1147       << output_actor->outputs_num() << ", the input data arrows num: " << output_actor->input_datas_num_
1148       << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
1149   }
1150 
1151   CheckControlActorValid(actor_set);
1152 }
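
// A minimal usage sketch (hypothetical call site, not part of this file): the check is intended to run
// after all actors of the set have been created and linked, e.g.
//   SchedulerHelper::CheckActorValid(actor_set.get());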
1153 
1154 void SchedulerHelper::DumpActorSet(const ActorSet *actor_set, std::ofstream &ofs) {
1155   MS_EXCEPTION_IF_NULL(actor_set);
1156   DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
1157   DumpDSActors(actor_set->data_source_actors_, ofs);
1158   DumpKernelActors(actor_set->kernel_actors_, ofs);
1159   DumpKernelInferActors(actor_set->kernel_infer_actors_, ofs);
1160   DumpKernelResizeActors(actor_set->kernel_resize_actors_, ofs);
1161   DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
1162   DumpAnyTypeKernelActors(actor_set->any_type_kernel_actors_, ofs);
1163   // The no-input kernel actors are taken over by control actors in the control flow scene.
1164   if (actor_set->control_actors_ == nullptr) {
1165     DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
1166   }
1167   DumpMemoryActors(actor_set->memory_actors_, ofs);
1168   DumpCopyActors(actor_set->copy_actors_, ofs);
1169   DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
1170   DumpOutputActor(actor_set->output_actor_, ofs);
1171   DumpFusionActors(actor_set->fusion_actors_, ofs);
1172   DumpControlActors(actor_set->control_actors_, ofs);
1173   DumpCustomActors(actor_set->custom_actors_, ofs);
1174   DumpSwapActors(actor_set->swap_actors_, ofs);
1175 }
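
// A minimal usage sketch (hypothetical call site, not part of this file), assuming the caller manages
// the dump file stream:
//   std::ofstream ofs("actor_set_dump.ir");
//   if (ofs.is_open()) {
//     SchedulerHelper::DumpActorSet(actor_set, ofs);
//   }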
1176 
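// Dump the actor set in a topologically sorted, formatted form. In the control flow scene every exit
// actor without an associated node roots a "Base Block" that is dumped separately; otherwise the whole
// set is dumped starting from the output actor. Any dump failure is logged and swallowed.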
1177 void SchedulerHelper::DumpFormatActorSet(const ActorSet *actor_set, std::ofstream &ofs) {
1178   MS_EXCEPTION_IF_NULL(actor_set);
1179   try {
1180     MS_LOG(DEBUG) << "Start dump format actor set:" << actor_set->name_;
1181     if (actor_set->control_actors_ != nullptr) {
1182       for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
1183         if (exit_actor->node() != nullptr) {
1184           continue;
1185         }
1186         auto actors = TopoSortForActor(exit_actor.get());
1187         ActorInfoMap actor_info;
1188         ofs << "\n\nBase Block : "
1189             << exit_actor->GetAID().Name().substr(0, exit_actor->GetAID().Name().find(kExitActorNameSuffix)) << "\n\n";
1190         for (size_t i = 0; i < actors.size(); ++i) {
1191           DumpActorInfo(actors[i], i, &actor_info, ofs);
1192         }
1193       }
1194       return;
1195     }
1196 
1197     auto actors = TopoSortForActor(actor_set->output_actor_.get());
1198     ActorInfoMap actor_info;
1199     for (size_t i = 0; i < actors.size(); ++i) {
1200       DumpActorInfo(actors[i], i, &actor_info, ofs);
1201     }
1202     MS_LOG(DEBUG) << "End dump format actor set:" << actor_set->name_;
1203   } catch (const std::exception &e) {
1204     MS_LOG(INFO) << "Failed to dump actor set:" << actor_set->name_ << ", msg: " << e.what();
1205   }
1206 }
1207 }  // namespace runtime
1208 }  // namespace mindspore
1209