1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/scheduler_helper.h"
18 #include "mindspore/core/ops/framework_ops.h"
19 #include "mindspore/core/ops/array_ops.h"
20 #include "runtime/graph_scheduler/actor/actor_dump.h"
21 #include "include/backend/anf_runtime_algorithm.h"
22 #include "include/common/utils/anfalgo.h"
23 #include "utils/anf_utils.h"
24 #include "utils/log_adapter.h"
25 #include "include/common/utils/convert_utils.h"
26
27 namespace mindspore {
28 namespace runtime {
29 size_t SchedulerHelper::fusion_actor_index_ = 0;
30
31 namespace {
CollectControlActors(const ActorSet * actor_set,std::vector<AbstractActorPtr> * actors)32 void CollectControlActors(const ActorSet *actor_set, std::vector<AbstractActorPtr> *actors) {
33 MS_EXCEPTION_IF_NULL(actor_set);
34 MS_EXCEPTION_IF_NULL(actors);
35 if (actor_set->control_actors_ != nullptr) {
36 const auto &control_actor_set = actor_set->control_actors_;
37 for (auto &switch_actor : control_actor_set->switch_actors_) {
38 MS_EXCEPTION_IF_NULL(switch_actor);
39 (void)actors->emplace_back(static_cast<AbstractActorPtr>(switch_actor));
40 }
41 for (auto &gather_actor : control_actor_set->gather_actors_) {
42 MS_EXCEPTION_IF_NULL(gather_actor);
43 (void)actors->emplace_back(static_cast<AbstractActorPtr>(gather_actor));
44 }
45 for (auto &entrance_actor : control_actor_set->entrance_actors_) {
46 MS_EXCEPTION_IF_NULL(entrance_actor);
47 (void)actors->emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
48 }
49 for (auto &exit_actor : control_actor_set->exit_actors_) {
50 MS_EXCEPTION_IF_NULL(exit_actor);
51 (void)actors->emplace_back(static_cast<AbstractActorPtr>(exit_actor));
52 }
53 for (auto &stack_actor : control_actor_set->stack_actors_) {
54 MS_EXCEPTION_IF_NULL(stack_actor);
55 (void)actors->emplace_back(static_cast<AbstractActorPtr>(stack_actor));
56 }
57 }
58 }
59
IsSkipLaunchShapeRelatedOp(KernelActor * kernel_actor)60 bool IsSkipLaunchShapeRelatedOp(KernelActor *kernel_actor) {
61 MS_EXCEPTION_IF_NULL(kernel_actor);
62 if (kernel_actor->skip_launch_shape_related_op()) {
63 return true;
64 }
65
66 auto &kernel = kernel_actor->kernel();
67 MS_EXCEPTION_IF_NULL(kernel);
68
69 // RealMakeTuple --> ShapeCalc pattern:
70 // If ShapeCalc is not value depend for one input RealMakeTuple op, we can skip launch this RealMakeTuple.
71 if (IsPrimitiveCNode(kernel, prim::kPrimRealMakeTuple)) {
72 auto func_graph = kernel->func_graph();
73 MS_EXCEPTION_IF_NULL(func_graph);
74 auto manager = func_graph->manager();
75 if (manager == nullptr) {
76 manager = Manage(func_graph, true);
77 func_graph->set_manager(manager);
78 }
79
80 const auto &users_set = manager->node_users()[kernel];
81 bool can_skip_launch_real_make_tuple = true;
82 for (const auto &item : users_set) {
83 const auto &user_node = item.first;
84 if (!user_node->isa<CNode>()) {
85 can_skip_launch_real_make_tuple = false;
86 break;
87 }
88 auto user_cnode = user_node->cast<CNodePtr>();
89 MS_EXCEPTION_IF_NULL(user_cnode);
90 if (!IsPrimitiveCNode(user_cnode, prim::kPrimShapeCalc)) {
91 can_skip_launch_real_make_tuple = false;
92 break;
93 }
94
95 if (!common::AnfAlgo::HasNodeAttr(kAttrOnlyDependShape, user_cnode)) {
96 can_skip_launch_real_make_tuple = false;
97 break;
98 }
99 const auto &only_depend_shape = common::AnfAlgo::GetNodeAttr<std::vector<bool>>(user_cnode, kAttrOnlyDependShape);
100 auto user_input_index = item.second;
101 if (user_input_index < 1) {
102 MS_LOG(EXCEPTION) << "The input index should start from 1, but got: " << user_input_index;
103 }
104 if (IntToSize(user_input_index) > only_depend_shape.size()) {
105 MS_LOG(EXCEPTION) << "The input index[" << user_input_index
106 << "] is out of range, input size: " << only_depend_shape.size();
107 }
108 if (!only_depend_shape[user_input_index - 1]) {
109 can_skip_launch_real_make_tuple = false;
110 break;
111 }
112 }
113
114 if (can_skip_launch_real_make_tuple) {
115 return true;
116 }
117 }
118
119 return false;
120 }
121
UpdateDataArrowRefCount(AbstractActor * const to_actor,size_t to_input_index,const DeviceTensorPtr & device_tensor)122 void UpdateDataArrowRefCount(AbstractActor *const to_actor, size_t to_input_index,
123 const DeviceTensorPtr &device_tensor) {
124 MS_LOG(DEBUG) << "Process shape depend attribute for actor : " << to_actor->GetAID().Name();
125 bool need_increase_ref_count = true;
126 auto to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
127 auto ms_context = MsContext::GetInstance();
128 MS_EXCEPTION_IF_NULL(ms_context);
129 static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
130 if (to_kernel_actor != nullptr && !enable_infer_boost) {
131 auto to_kernel = to_kernel_actor->kernel();
132 auto cnode = to_kernel->cast<CNodePtr>();
133 if (cnode != nullptr) {
134 MS_LOG(DEBUG) << "Process shape depend attribute for cnode : " << cnode->fullname_with_scope();
135 const auto &only_depend_shape_attr = common::AnfAlgo::GetCNodePrimitiveAttr(cnode, kAttrOnlyDependShape);
136 if (only_depend_shape_attr != nullptr) {
137 auto only_depend_shape = GetValue<std::vector<bool>>(only_depend_shape_attr);
138 if (only_depend_shape.size() <= to_input_index) {
139 MS_LOG(DEBUG) << "to_input_index : " << to_input_index
140 << " is out of range, only_depend_shape size : " << only_depend_shape.size();
141 } else {
142 auto is_shape_depend = only_depend_shape[to_input_index];
143 MS_LOG(DEBUG) << "only_depend_shape[" << to_input_index << "] : " << is_shape_depend;
144 if (is_shape_depend) {
145 need_increase_ref_count = false;
146 }
147 }
148 }
149 }
150 }
151 if (need_increase_ref_count) {
152 UpdateRefCount(device_tensor.get(), false);
153 } else {
154 device_tensor->UpdateFlag(device::kDeviceAddressFlagNullptr);
155 }
156 }
157 } // namespace
158
CollectActors(const ActorSet * actor_set)159 std::vector<AbstractActorPtr> SchedulerHelper::CollectActors(const ActorSet *actor_set) {
160 MS_EXCEPTION_IF_NULL(actor_set);
161 std::vector<AbstractActorPtr> actors;
162
163 if (actor_set->data_prepare_actor_ != nullptr) {
164 (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
165 }
166 for (auto &data_source_actor : actor_set->data_source_actors_) {
167 MS_EXCEPTION_IF_NULL(data_source_actor);
168 (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
169 }
170 for (auto &custom_actor : actor_set->custom_actors_) {
171 MS_EXCEPTION_IF_NULL(custom_actor);
172 (void)actors.emplace_back(static_cast<AbstractActorPtr>(custom_actor));
173 }
174 for (auto &kernel_actor : actor_set->kernel_actors_) {
175 MS_EXCEPTION_IF_NULL(kernel_actor);
176 (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
177 }
178 for (auto &kernel_infer_actor : actor_set->kernel_infer_actors_) {
179 MS_EXCEPTION_IF_NULL(kernel_infer_actor);
180 (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_infer_actor));
181 }
182 for (auto &kernel_resize_actor : actor_set->kernel_resize_actors_) {
183 MS_EXCEPTION_IF_NULL(kernel_resize_actor);
184 (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_resize_actor));
185 }
186 for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
187 MS_EXCEPTION_IF_NULL(super_kernel_actor);
188 (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
189 }
190 for (auto &any_type_kernel_actor : actor_set->any_type_kernel_actors_) {
191 MS_EXCEPTION_IF_NULL(any_type_kernel_actor);
192 (void)actors.emplace_back(static_cast<AbstractActorPtr>(any_type_kernel_actor));
193 }
194 for (auto &memory_actor : actor_set->memory_actors_) {
195 MS_EXCEPTION_IF_NULL(memory_actor);
196 (void)actors.emplace_back(static_cast<AbstractActorPtr>(memory_actor));
197 }
198 for (auto ©_actor : actor_set->copy_actors_) {
199 MS_EXCEPTION_IF_NULL(copy_actor);
200 (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
201 }
202 for (auto &fusion_actor : actor_set->fusion_actors_) {
203 MS_EXCEPTION_IF_NULL(fusion_actor);
204 (void)actors.emplace_back(static_cast<AbstractActorPtr>(fusion_actor));
205 }
206 for (auto &swap_actors : actor_set->swap_actors_) {
207 (void)std::for_each(swap_actors.cbegin(), swap_actors.cend(), [&](const MemSwapActorPtr &swap_actor) {
208 if (swap_actor != nullptr) {
209 (void)actors.emplace_back(static_cast<AbstractActorPtr>(swap_actor));
210 }
211 });
212 }
213 if (actor_set->loop_count_actor_ != nullptr) {
214 (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
215 }
216 if (actor_set->output_actor_ != nullptr) {
217 (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
218 }
219 CollectControlActors(actor_set, &actors);
220 return actors;
221 }
222
HasMonadControl(const AnfNodePtr & input_node,const KernelGraphPtr & graph)223 bool SchedulerHelper::HasMonadControl(const AnfNodePtr &input_node, const KernelGraphPtr &graph) {
224 MS_EXCEPTION_IF_NULL(input_node);
225 MS_EXCEPTION_IF_NULL(graph);
226 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
227 prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
228 if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
229 return true;
230 }
231
232 // The subgraph input.
233 if (IsInternalParameter(input_node, graph)) {
234 auto front_output_with_index = graph->GetOriginFrontNodeByInternalParameter(input_node);
235 auto front_output_node = front_output_with_index.first;
236 MS_EXCEPTION_IF_NULL(front_output_node);
237 if (IsOneOfPrimitiveCNode(front_output_node, auto_monad_prims) || HasAbstractMonad(front_output_node)) {
238 MS_LOG(INFO) << "The graph " << graph->graph_id()
239 << " has monad control from internal parameter: " << input_node->DebugString()
240 << ", front output node: " << front_output_node->fullname_with_scope();
241 return true;
242 }
243 }
244
245 return false;
246 }
247
AddDeviceTensorStore(const AnfNode * anf_node,const DeviceTensorPtr & device_tensor)248 void SchedulerHelper::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
249 MS_EXCEPTION_IF_NULL(anf_node);
250 MS_EXCEPTION_IF_NULL(device_tensor);
251 MS_LOG(DEBUG) << "Add device tensor store:" << device_tensor << " for node:" << anf_node->DebugString()
252 << " node addr:" << anf_node << " device type:" << device_tensor->GetDeviceType();
253 DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
254 device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
255 UpdateRefCount(device_tensor.get(), true);
256 }
257
AddMonadDeviceTensorStore(AbstractActor * const to_actor,const CNodePtr & kernel,const KernelGraphPtr & graph)258 void SchedulerHelper::AddMonadDeviceTensorStore(AbstractActor *const to_actor, const CNodePtr &kernel,
259 const KernelGraphPtr &graph) {
260 MS_EXCEPTION_IF_NULL(to_actor);
261 MS_EXCEPTION_IF_NULL(kernel);
262 MS_EXCEPTION_IF_NULL(graph);
263 // Ref node monad device tensor store.
264 if (common::AnfAlgo::HasNodeAttr(kAttrRefNodeMonadOutputIdx, kernel)) {
265 auto output_idx = common::AnfAlgo::GetNodeAttr<size_t>(kernel, kAttrRefNodeMonadOutputIdx);
266 const auto &origin_pair = graph->GetRefNodeRecursive({kernel, output_idx});
267 auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(origin_pair.first, *graph);
268 MS_EXCEPTION_IF_NULL(front_node);
269 if (IsPersistentDeviceTensor(front_node)) {
270 MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
271 << " add ref node monad device tensor store:" << front_node->fullname_with_scope();
272 (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
273 }
274 }
275
276 // Input node monad device tensor store.
277 if (!common::AnfAlgo::HasMonadInput(kernel)) {
278 return;
279 }
280
281 // Super kernel actor need fetch by the input device tensor store.
282 if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
283 for (size_t i = 0; i < common::AnfAlgo::GetInputTensorNum(kernel); ++i) {
284 KernelWithIndex from_kernel_with_output_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, i, false);
285 auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(from_kernel_with_output_idx.first, *graph);
286 MS_EXCEPTION_IF_NULL(front_node);
287 if (IsPersistentDeviceTensor(front_node)) {
288 MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
289 << " add input node monad device tensor store:" << front_node->fullname_with_scope();
290 (void)to_actor->auto_monad_device_tensor_stores_.insert(front_node);
291 }
292 }
293 } else {
294 // Kernel actor can fetch by the device tensor store key directly.
295 const auto &device_tensor_store_keys = to_actor->device_tensor_store_keys_;
296 (void)std::for_each(device_tensor_store_keys.begin(), device_tensor_store_keys.end(), [&](const auto &store_key) {
297 MS_EXCEPTION_IF_NULL(store_key.second);
298 MS_LOG(INFO) << to_actor->GetAID().Name() << ", kernel:" << kernel->fullname_with_scope()
299 << " add input node monad device tensor store:" << store_key.second->fullname_with_scope();
300 (void)to_actor->auto_monad_device_tensor_stores_.insert(store_key.second);
301 });
302 }
303 }
304
IsIgnoredInputAddress(AbstractActor * const to_actor,size_t to_input_index)305 bool SchedulerHelper::IsIgnoredInputAddress(AbstractActor *const to_actor, size_t to_input_index) {
306 MS_EXCEPTION_IF_NULL(to_actor);
307 if (to_actor->type() != KernelTransformType::kKernelActor) {
308 return false;
309 }
310
311 auto kernel_actor = dynamic_cast<KernelActor *>(to_actor);
312 auto &to_kernel = kernel_actor->kernel();
313 MS_EXCEPTION_IF_NULL(to_kernel);
314
315 if (IsSkipLaunchShapeRelatedOp(kernel_actor)) {
316 kernel_actor->set_skip_launch_shape_related_op(true);
317 return true;
318 }
319
320 auto kernel_mod = AnfAlgo::GetKernelMod(to_kernel);
321 MS_EXCEPTION_IF_NULL(kernel_mod);
322 const auto &ignored_address = kernel_mod->GetLaunchIgnoredInputAddressIdx();
323 if (ignored_address.empty()) {
324 return false;
325 }
326
327 if (std::find(ignored_address.begin(), ignored_address.end(), to_input_index) != ignored_address.end()) {
328 MS_LOG(INFO) << "Ignore the input address for kernel: " << to_kernel->fullname_with_scope()
329 << " with input index: " << to_input_index;
330 return true;
331 }
332
333 return false;
334 }
335
GetIgnoredInputAddressCount(const AnfNodePtr & node)336 size_t SchedulerHelper::GetIgnoredInputAddressCount(const AnfNodePtr &node) {
337 MS_EXCEPTION_IF_NULL(node);
338 size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
339 auto kernel_mod = AnfAlgo::GetKernelMod(node);
340 MS_EXCEPTION_IF_NULL(kernel_mod);
341 const auto &ignored_input_addresses = kernel_mod->GetLaunchIgnoredInputAddressIdx();
342 if (ignored_input_addresses.empty()) {
343 return 0;
344 }
345
346 auto count = std::count_if(ignored_input_addresses.begin(), ignored_input_addresses.end(),
347 [input_num](size_t index) { return index < input_num; });
348 return static_cast<size_t>(count);
349 }
350
AddDataArrow(AbstractActor * const from_actor,AbstractActor * const to_actor,size_t from_output_index,size_t to_input_index,const AnfNodePtr & from_kernel)351 void SchedulerHelper::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
352 size_t from_output_index, size_t to_input_index, const AnfNodePtr &from_kernel) {
353 MS_EXCEPTION_IF_NULL(from_actor);
354 MS_EXCEPTION_IF_NULL(to_actor);
355 MS_LOG(DEBUG) << "Add data arrow from actor:" << from_actor->GetAID() << " index:" << from_output_index
356 << " to actor:" << to_actor->GetAID() << " to index:" << to_input_index
357 << " from kernel:" << (from_kernel == nullptr ? "null" : from_kernel->fullname_with_scope());
358 // Check the data arrow legitimacy.
359 if (IsControlFlowActor(to_actor->type()) && (from_actor->type() == KernelTransformType::kKernelActor) &&
360 (to_actor->type() != KernelTransformType::kExitActor)) {
361 MS_LOG(WARNING) << "Kernel actor:" << from_actor->GetAID().Name()
362 << " link data arrow to actor:" << to_actor->GetAID().Name() << " is not an exit actor.";
363 }
364
365 if (from_actor->type() == KernelTransformType::kKernelActor &&
366 to_actor->type() == KernelTransformType::kKernelActor) {
367 auto from_kernel_actor = dynamic_cast<KernelActor *>(from_actor);
368 MS_EXCEPTION_IF_NULL(from_kernel_actor);
369 if (IsSkipLaunchShapeRelatedOp(from_kernel_actor)) {
370 from_kernel_actor->set_skip_launch_shape_related_op(true);
371 }
372 }
373
374 // The continuous memory inpus need allocate memory in advance, so must be from the inside subgraph.
375 if (to_actor->type() == KernelTransformType::kKernelActor) {
376 auto to_kernel_actor = dynamic_cast<KernelActor *>(to_actor);
377 MS_EXCEPTION_IF_NULL(to_kernel_actor);
378 if (to_kernel_actor->inputs_continuous_memory() && (from_actor->type() != KernelTransformType::kKernelActor)) {
379 MS_LOG(INTERNAL_EXCEPTION)
380 << "#dmsg#Runtime error info:#dmsg#The continuous memory input is not from the inside subgraph, to actor: "
381 << to_actor->GetAID().Name() << ", to input index: " << to_input_index
382 << ", from actor: " << from_actor->GetAID().Name() << ", from output index: " << from_output_index;
383 }
384 }
385
386 AddMemorySign(from_actor, to_actor);
387
388 auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
389 (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
390 (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
391 to_actor->input_datas_num_++;
392 (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), data_arrow.get()));
393
394 if (from_kernel == nullptr) {
395 return;
396 }
397 // Update the reference count of from_kernel.
398 auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
399 MS_EXCEPTION_IF_NULL(device_tensor);
400 // The superkernel actor is linked by input parameter, maybe the not used parameter.
401 if (to_actor->type() != KernelTransformType::kSuperKernelActor) {
402 device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
403 }
404 // The device address of super kernel actor can't be changed, so set the max reference count.
405 if (IsControlFlowActor(to_actor->type()) || (from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
406 (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
407 UpdateRefCount(device_tensor.get(), true);
408 } else {
409 UpdateDataArrowRefCount(to_actor, to_input_index, device_tensor);
410 }
411
412 if (IsControlFlowActor(to_actor->type())) {
413 device_tensor->SetNodeIndex(from_kernel, from_output_index);
414 }
415 }
416
AddResultArrow(AbstractActor * const from_actor,OutputActor * const to_actor,const AnfNodePtr & from_kernel,size_t from_output_index,size_t output_position)417 void SchedulerHelper::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
418 const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
419 MS_EXCEPTION_IF_NULL(to_actor);
420 MS_EXCEPTION_IF_NULL(from_kernel);
421
422 if (from_actor == nullptr) {
423 (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, from_kernel);
424 } else {
425 auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
426 (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
427 (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
428 to_actor->input_datas_num_++;
429 (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), result_arrow.get()));
430 }
431
432 if (!AnfAlgo::OutputAddrExist(from_kernel, from_output_index, false)) {
433 MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, from_kernel)
434 << "#dmsg#Runtime error info:#dmsg#" << from_kernel->DebugString() << " index:" << from_output_index
435 << " device address does not exist";
436 }
437 auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
438 MS_EXCEPTION_IF_NULL(device_tensor);
439 device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
440 // The output actor need use the relevant information of node to create output tensor.
441 device_tensor->SetNodeIndex(from_kernel, from_output_index);
442 // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
443 UpdateRefCount(device_tensor.get(), true);
444
445 MS_LOG(DEBUG) << "Add result arrow from actor:" << (from_actor != nullptr ? from_actor->GetAID().Name() : "null")
446 << " to actor:" << to_actor->GetAID() << " from kernel"
447 << (from_kernel == nullptr ? "null" : from_kernel->DebugString()) << " device address:" << device_tensor
448 << " original ref count:" << device_tensor->original_ref_count()
449 << " ref count:" << device_tensor->ref_count()
450 << " dynamic ref count:" << device_tensor->dynamic_ref_count();
451
452 // Set the device contexts of to_actor.
453 if (output_position >= to_actor->device_contexts_.size()) {
454 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range.";
455 }
456 auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
457 {device_tensor->device_name(), device_tensor->device_id()});
458 to_actor->device_contexts_[output_position] = device_context;
459 }
460
AddControlArrow(AbstractActor * const from_actor,AbstractActor * const to_actor)461 void SchedulerHelper::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
462 MS_EXCEPTION_IF_NULL(from_actor);
463 MS_EXCEPTION_IF_NULL(to_actor);
464
465 // Check the control arrow whether exists.
466 auto iter = std::find_if(from_actor->output_control_arrows_.begin(), from_actor->output_control_arrows_.end(),
467 [&to_actor](const auto &output_control_arrow) {
468 return output_control_arrow->to_op_id_.Name() == to_actor->GetAID().Name();
469 });
470 if (iter != from_actor->output_control_arrows_.end()) {
471 // The stack actor can only link the single control arrow.
472 if (to_actor->type_ == KernelTransformType::kStackActor) {
473 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The control arrow between "
474 << from_actor->GetAID().Name() << " and " << to_actor->GetAID().Name()
475 << " is repeated.";
476 }
477 return;
478 }
479
480 // No need add control arrow if already exists data arrow in from and to actor.
481 if (from_actor->type() == KernelTransformType::kKernelActor &&
482 to_actor->type() == KernelTransformType::kKernelActor) {
483 const auto &input_data_arrows = to_actor->input_data_arrow_aids();
484 if (std::any_of(input_data_arrows.begin(), input_data_arrows.end(),
485 [&from_actor](const std::pair<AID, DataArrow *> &input_data_arrow_pair) {
486 return input_data_arrow_pair.first.Name() == from_actor->GetAID().Name();
487 })) {
488 MS_LOG(INFO) << "No need add control arrow, because already exists data arrow in from actor: "
489 << from_actor->GetAID().Name() << " and to actor: " << to_actor->GetAID().Name();
490 return;
491 }
492 }
493
494 auto control_arrow = std::make_shared<ControlArrow>(to_actor->GetAID());
495 (void)from_actor->output_control_arrows_.emplace_back(control_arrow);
496 to_actor->input_controls_num_++;
497 (void)to_actor->input_control_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), control_arrow.get()));
498 MS_LOG(DEBUG) << "Add control arrow from actor:" << from_actor->GetAID() << " to actor:" << to_actor->GetAID();
499 AddMemorySign(from_actor, to_actor);
500 }
501
AddPartialArrow(ControlActor * const from_actor,ControlActor * const to_actor,size_t from_index,size_t to_index)502 void SchedulerHelper::AddPartialArrow(ControlActor *const from_actor, ControlActor *const to_actor, size_t from_index,
503 size_t to_index) {
504 MS_EXCEPTION_IF_NULL(from_actor);
505 MS_EXCEPTION_IF_NULL(to_actor);
506 auto op_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
507 (void)from_actor->output_partial_arrows_.emplace_back(op_arrow);
508 to_actor->input_partials_num_++;
509 (void)to_actor->input_partial_arrow_aids_.emplace_back(from_actor->GetAID());
510 }
511
AddBranchIDArrow(ControlActor * const from_actor,ControlActor * const to_actor)512 void SchedulerHelper::AddBranchIDArrow(ControlActor *const from_actor, ControlActor *const to_actor) {
513 MS_EXCEPTION_IF_NULL(from_actor);
514 MS_EXCEPTION_IF_NULL(to_actor);
515 (void)from_actor->output_branch_id_arrows_.emplace_back(to_actor->GetAID());
516 (void)to_actor->input_branch_id_arrow_aids_.emplace_back(from_actor->GetAID());
517 to_actor->input_branch_ids_num_++;
518 }
519
AddLoopBodyControlArrow(AbstractActor * from_actor,EntranceActor * to_actor)520 void SchedulerHelper::AddLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor) {
521 MS_EXCEPTION_IF_NULL(from_actor);
522 MS_EXCEPTION_IF_NULL(to_actor);
523 MS_LOG(DEBUG) << "Link loop body control arrow from:" << from_actor->GetAID() << " to actor:" << to_actor->GetAID();
524 auto control_arrow = std::make_shared<ControlArrow>(to_actor->GetAID());
525 (void)from_actor->output_control_arrows_.emplace_back(control_arrow);
526 to_actor->loop_body_input_controls_nums_++;
527 (void)to_actor->loop_body_input_control_arrow_aids_.emplace_back(from_actor->GetAID());
528 }
529
AddDataWithBranchIDArrow(GatherActor * const gather_actor,const EntranceActor * entrance_actor,const FuncGraphPtr & func_graph)530 void SchedulerHelper::AddDataWithBranchIDArrow(GatherActor *const gather_actor, const EntranceActor *entrance_actor,
531 const FuncGraphPtr &func_graph) {
532 MS_EXCEPTION_IF_NULL(gather_actor);
533 MS_EXCEPTION_IF_NULL(entrance_actor);
534 (void)gather_actor->output_data_with_branch_id_arrows_[func_graph.get()].emplace_back(entrance_actor->GetAID());
535 }
536
AddDataArrowForExitActor(ExitActor * const exit_actor,AbstractActor * const to_actor,size_t from_index,size_t to_index,int branch_id)537 void SchedulerHelper::AddDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
538 size_t from_index, size_t to_index, int branch_id) {
539 MS_EXCEPTION_IF_NULL(exit_actor);
540 MS_EXCEPTION_IF_NULL(to_actor);
541
542 MS_LOG(DEBUG) << "Link data arrow from actor:" << exit_actor->GetAID() << " from index:" << from_index
543 << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
544 auto data_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
545 (void)exit_actor->output_branch_data_arrows_[branch_id].emplace_back(data_arrow);
546 (void)to_actor->input_data_arrow_aids_.emplace_back(std::make_pair(exit_actor->GetAID(), data_arrow.get()));
547 }
548
AddPartialArrowForExitActor(ExitActor * const exit_actor,ControlActor * const to_actor,size_t from_index,size_t to_index,int branch_id)549 void SchedulerHelper::AddPartialArrowForExitActor(ExitActor *const exit_actor, ControlActor *const to_actor,
550 size_t from_index, size_t to_index, int branch_id) {
551 MS_EXCEPTION_IF_NULL(exit_actor);
552 MS_EXCEPTION_IF_NULL(to_actor);
553 MS_LOG(DEBUG) << "Link partial arrow from actor:" << exit_actor->GetAID() << " from index:" << from_index
554 << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
555 auto partial_arrow = std::make_shared<DataArrow>(from_index, to_actor->GetAID(), to_index);
556 (void)exit_actor->output_branch_partial_arrows_[branch_id].emplace_back(partial_arrow);
557 (void)to_actor->input_partial_arrow_aids_.emplace_back(exit_actor->GetAID());
558 }
559
AddControlArrowForExitActor(ExitActor * from_actor,AbstractActor * to_actor,int branch_id)560 void SchedulerHelper::AddControlArrowForExitActor(ExitActor *from_actor, AbstractActor *to_actor, int branch_id) {
561 MS_EXCEPTION_IF_NULL(from_actor);
562 MS_EXCEPTION_IF_NULL(to_actor);
563
564 MS_LOG(DEBUG) << "Link control arrow from:" << from_actor->GetAID() << " to:" << to_actor->GetAID();
565 (void)from_actor->output_branch_control_arrows_[branch_id].emplace_back(to_actor->GetAID());
566 to_actor->input_controls_num_++;
567 (void)to_actor->input_control_arrow_aids_.emplace_back(std::make_pair(from_actor->GetAID(), nullptr));
568 }
569
AddFormalParameterDeviceTensor(ControlActor * const from_actor,size_t from_index,const AnfNodePtr & input_node,const KernelGraphPtr & graph)570 void SchedulerHelper::AddFormalParameterDeviceTensor(ControlActor *const from_actor, size_t from_index,
571 const AnfNodePtr &input_node, const KernelGraphPtr &graph) {
572 MS_EXCEPTION_IF_NULL(from_actor);
573 MS_EXCEPTION_IF_NULL(input_node);
574 MS_EXCEPTION_IF_NULL(graph);
575 // Graph mode does not support dynamic shape and ref node.
576 if (graph->is_graph_run_mode() || graph->is_any_type_input()) {
577 return;
578 }
579
580 // Collect backend parameters with dynamic shapes.
581 auto base_shape = input_node->Shape();
582 if (input_node->isa<Parameter>() && base_shape != nullptr &&
583 ((base_shape->isa<abstract::Shape>() && base_shape->IsDynamic()) ||
584 base_shape->isa<abstract::DynamicSequenceShape>())) {
585 if (from_index >= from_actor->backend_parameters_.size()) {
586 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid from index:" << from_index
587 << " for actor:" << from_actor->GetAID()
588 << " vector size:" << from_actor->backend_parameters_.size();
589 }
590 MS_LOG(INFO) << "Add dynamic shape backend parameter:" << input_node->DebugString() << " index:" << from_index
591 << " for actor:" << from_actor->GetAID();
592 (void)from_actor->backend_parameters_[from_index].emplace_back(input_node);
593 }
594
595 if (!common::AnfAlgo::HasAbstractRef(input_node)) {
596 return;
597 }
598
599 auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
600 MS_EXCEPTION_IF_NULL(device_tensor);
601 (void)from_actor->ref_formal_parameter_device_tensors_[from_index].insert(device_tensor);
602 if (graph->IsRefOutputMapValue({input_node, 0})) {
603 (void)from_actor->ref_node_formal_parameter_device_tensors_[from_index].insert(device_tensor);
604 }
605
606 device_tensor->ClearFlag(device::kDeviceAddressFlagNotUsed);
607 UpdateRefCount(device_tensor.get(), true);
608 device_tensor->SetNodeIndex(input_node, 0);
609 }
610
ConvertDataArrowToControlArrow(AbstractActor * const from_actor,AbstractActor * const to_actor,const DataArrowPtr & data_arrow,size_t data_arrow_index)611 void SchedulerHelper::ConvertDataArrowToControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
612 const DataArrowPtr &data_arrow, size_t data_arrow_index) {
613 MS_EXCEPTION_IF_NULL(from_actor);
614 MS_EXCEPTION_IF_NULL(to_actor);
615 MS_EXCEPTION_IF_NULL(data_arrow);
616 MS_EXCEPTION_IF_CHECK_FAIL((data_arrow_index < from_actor->output_data_nodes_.size()), "Index out of range.");
617 auto &need_converted_node = from_actor->output_data_nodes_[data_arrow_index];
618 MS_EXCEPTION_IF_NULL(need_converted_node);
619
  // Skip the ref node because its reference count can't be recalculated correctly.
621 auto device_tensor =
622 AnfAlgo::GetMutableOutputAddr(need_converted_node, IntToSize(data_arrow->from_output_index_), false);
623 MS_EXCEPTION_IF_NULL(device_tensor);
624 if (TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagRefNode)) {
625 MS_LOG(INFO) << "Skip the invalid data arrow of ref node, from actor:" << from_actor->GetAID().Name()
626 << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
627 << ", to index:" << data_arrow->to_input_index_;
628 return;
629 }
630
631 auto kernel_info = dynamic_cast<KernelInfo *>(need_converted_node->kernel_info());
632 MS_EXCEPTION_IF_NULL(kernel_info);
633 const auto &somas_outputs = kernel_info->somas_output_result();
634 if (kernel_info->IsTensorEnableSomas(somas_outputs, data_arrow->from_output_index_)) {
635 MS_LOG(INFO) << "Skip the invalid data arrow of somas inner address, from actor:" << from_actor->GetAID().Name()
636 << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
637 << ", to index:" << data_arrow->to_input_index_;
638 return;
639 }
640
641 // Erase the output data arrow in from actor.
642 (void)from_actor->output_data_arrows_.erase(from_actor->output_data_arrows_.begin() + SizeToLong(data_arrow_index));
643 (void)from_actor->output_data_nodes_.erase(from_actor->output_data_nodes_.begin() + SizeToLong(data_arrow_index));
644
645 // Erase the input data arrow aid in to actor.
646 bool to_actor_erase = false;
647 for (auto iter = to_actor->input_data_arrow_aids_.begin(); iter != to_actor->input_data_arrow_aids_.end(); ++iter) {
648 if ((*iter).first == from_actor->GetAID()) {
649 (void)to_actor->input_data_arrow_aids_.erase(iter);
650 to_actor_erase = true;
651 to_actor->input_datas_num_--;
652 break;
653 }
654 }
655 if (to_actor_erase == false) {
656 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Erase no input data arrow, from actor:"
657 << from_actor->GetAID().Name() << ", to actor:" << to_actor->GetAID().Name()
658 << ", data arrow index:" << data_arrow_index;
659 }
660
661 // Recalculate the ref count of converted node.
662 size_t old_ref_count = device_tensor->ref_count();
663 // Ref count Initial value is 1.
664 size_t new_ref_count = 1;
665 for (auto &output_data_arrow : from_actor->output_data_arrows_) {
666 MS_EXCEPTION_IF_NULL(output_data_arrow);
667 if (output_data_arrow->from_output_index_ != data_arrow->from_output_index_) {
668 continue;
669 }
670 if ((output_data_arrow->to_op_id_.Name().find(kExitActorNameSuffix) != std::string::npos) ||
671 (output_data_arrow->to_op_id_.Name().find(kOutputActorNameSuffix) != std::string::npos)) {
672 new_ref_count = SIZE_MAX;
673 break;
674 }
675 ++new_ref_count;
676 }
677 device_tensor->set_original_ref_count(new_ref_count);
678 device_tensor->ResetRefCount();
679 MS_LOG(INFO) << "Erase the invalid data arrow, from actor:" << from_actor->GetAID().Name()
680 << ", from index:" << data_arrow->from_output_index_ << ", to actor:" << to_actor->GetAID().Name()
681 << ", to index:" << data_arrow->to_input_index_ << ", old ref count:" << old_ref_count
682 << ", new ref count:" << new_ref_count;
683
684 // Add the control arrow.
685 SchedulerHelper::AddControlArrow(from_actor, to_actor);
686 }
687
FuseDataArrowsToBatchDataArrow(AbstractActor * const actor)688 void SchedulerHelper::FuseDataArrowsToBatchDataArrow(AbstractActor *const actor) {
689 MS_EXCEPTION_IF_NULL(actor);
690 // Count the number of the same destination actor.
691 mindspore::HashMap<std::string, size_t> to_actor_count;
692 for (const auto &data_arrow : actor->output_data_arrows()) {
693 MS_EXCEPTION_IF_NULL(data_arrow);
694 ++(to_actor_count[data_arrow->to_op_id_.Name()]);
695 }
696
697 // Sign and add the batch data arrow.
698 for (auto &data_arrow : actor->output_data_arrows()) {
699 MS_EXCEPTION_IF_NULL(data_arrow);
700 auto &to_op_name = data_arrow->to_op_id_.Name();
701 // The output data cannot be reused whose destination is stack actor, and cannot to be fused.
702 if ((to_actor_count[to_op_name] > 1) && (to_op_name.find(kStackActorNameSuffix) == std::string::npos)) {
703 SET_FLAG(data_arrow->flag_, kOutputDataFlagBatch);
704 (void)actor->batch_output_data_arrows_[to_op_name].emplace_back(data_arrow);
705 }
706 }
707 }
708
AddDependency(AbstractActor * const actor,const AbstractActor * dependent_actor)709 void SchedulerHelper::AddDependency(AbstractActor *const actor, const AbstractActor *dependent_actor) {
710 MS_EXCEPTION_IF_NULL(actor);
711 MS_EXCEPTION_IF_NULL(dependent_actor);
712 // For example, ActorA->dependent_actor->actor, the expanded dependent actors of actor are dependent_actor and ActorA.
713 (void)actor->dependent_actors_.insert(dependent_actor->GetAID().Name());
714 actor->dependent_actors_.insert(dependent_actor->dependent_actors_.begin(), dependent_actor->dependent_actors_.end());
715 }
716
CheckDependency(const std::vector<AbstractActorPtr> & output_actors)717 bool SchedulerHelper::CheckDependency(const std::vector<AbstractActorPtr> &output_actors) {
718 if (output_actors.size() <= 1) {
719 return true;
720 }
721
722 for (size_t i = 1; i < output_actors.size(); ++i) {
723 auto &pre_actor = output_actors[i - 1];
724 auto &actor = output_actors[i];
725 MS_EXCEPTION_IF_NULL(pre_actor);
726 MS_EXCEPTION_IF_NULL(actor);
727 // The outputs have no dependencies.
728 if ((actor->dependent_actors_.count(pre_actor->GetAID().Name()) == 0) &&
729 (pre_actor->dependent_actors_.count(actor->GetAID().Name()) == 0)) {
730 return false;
731 }
732 }
733
734 return true;
735 }
736
BuildFusionActor(const std::vector<AbstractActorPtr> & actors)737 FusionActorPtr SchedulerHelper::BuildFusionActor(const std::vector<AbstractActorPtr> &actors) {
738 if (actors.size() <= 1) {
739 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The fusion actor size must be greater than 1.";
740 }
741
742 std::string fusion_actor_name = std::to_string(++fusion_actor_index_) + kFusionActorNameSuffix;
743 auto fusion_actor = std::make_shared<FusionActor>(fusion_actor_name);
744 InsertActor(fusion_actor.get());
745 for (auto &actor : actors) {
746 MS_EXCEPTION_IF_NULL(actor);
747 actor->parent_fusion_actor_ = fusion_actor.get();
748 MS_LOG(DEBUG) << "Set fusion actor:" << fusion_actor->GetAID() << " to actor:" << actor->GetAID();
749 fusion_actor->sub_actors_[actor->GetAID().Name()] = actor;
750 }
751 return fusion_actor;
752 }
753
AddArrowForFusionActor(FusionActor * fusion_actor)754 void SchedulerHelper::AddArrowForFusionActor(FusionActor *fusion_actor) {
755 MS_EXCEPTION_IF_NULL(fusion_actor);
756 for (auto &actor_iter : fusion_actor->sub_actors_) {
757 auto &actor = actor_iter.second;
758 MS_EXCEPTION_IF_NULL(actor);
759
760 // Link data arrow of fusion actor by the input data arrow of real actor.
761 for (auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
762 auto input_data_arrow = input_data_arrow_aid.second;
763 MS_EXCEPTION_IF_NULL(input_data_arrow);
764 // Mark the kOutputDataFlagBetweenFusion flag when the input data arrow is the Internal actor in fusion actor.
765 if (fusion_actor->sub_actors_.count(input_data_arrow_aid.first.Name()) > 0) {
766 SET_FLAG(input_data_arrow->flag_, kOutputDataFlagBetweenFusion);
767 continue;
768 }
769
770 SET_FLAG(input_data_arrow->flag_, kOutputDataFlagToFusion);
771 // The ActorB is in fusion actor and the input ActorA is on the outside of fusion actor, then change
772 // 'ActorA->ActorB' to 'ActorA->FusionActor'.
773 auto from_actor = FetchActor(input_data_arrow_aid.first.Name());
774 MS_EXCEPTION_IF_NULL(from_actor);
775 // Record the input index of real actor and fusion actor.
776 (void)fusion_actor->real_input_data_.emplace_back(std::make_pair(actor.get(), input_data_arrow->to_input_index_));
777 from_actor->data_arrow_to_fusion_actor_indexs_[input_data_arrow] = fusion_actor->input_data_arrow_aids_.size();
778 input_data_arrow->to_input_index_ = SizeToInt(fusion_actor->input_data_arrow_aids_.size());
779
780 input_data_arrow->to_op_id_ = fusion_actor->GetAID();
781 ++fusion_actor->input_datas_num_;
782 (void)fusion_actor->input_data_arrow_aids_.emplace_back(
783 std::make_pair(input_data_arrow_aid.first, input_data_arrow));
784 }
785
786 // Link control arrow of fusion actor by the input control arrow of real actor.
787 for (auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
788 auto input_control_arrow = input_control_arrow_aid.second;
789 MS_EXCEPTION_IF_NULL(input_control_arrow);
790 // Mark the kOutputDataFlagBetweenFusion flag when the input control arrow is the Internal actor in fusion
791 // actor.
792 if (fusion_actor->sub_actors_.count(input_control_arrow_aid.first.Name()) > 0) {
793 SET_FLAG(input_control_arrow->flag_, kOutputDataFlagBetweenFusion);
794 continue;
795 }
796
797 SET_FLAG(input_control_arrow->flag_, kOutputDataFlagToFusion);
798 // The ActorB is in fusion actor and the input ActorA is on the outside of fusion actor, then change
799 // 'ActorA->ActorB' to 'ActorA->FusionActor'.
800 (void)fusion_actor->real_input_controls_[input_control_arrow_aid.first.Name()].emplace_back(actor.get());
801 input_control_arrow->to_op_id_ = fusion_actor->GetAID();
802 ++fusion_actor->input_controls_num_;
803 (void)fusion_actor->input_control_arrow_aids_.emplace_back(
804 std::make_pair(input_control_arrow_aid.first, input_control_arrow));
805 }
806 }
807 }
808
AddMemorySign(AbstractActor * const from_actor,AbstractActor * const to_actor)809 void SchedulerHelper::AddMemorySign(AbstractActor *const from_actor, AbstractActor *const to_actor) {
810 MS_EXCEPTION_IF_NULL(from_actor);
811 MS_EXCEPTION_IF_NULL(to_actor);
812 auto ms_context = MsContext::GetInstance();
813 MS_EXCEPTION_IF_NULL(ms_context);
814 if (ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO0) {
815 return;
816 }
817 // The link of memory actor no need add the memory sign.
818 if (IsMemoryActor(from_actor->type()) || IsMemoryActor(to_actor->type())) {
819 return;
820 }
821
822 // Add the somas info.
823 AddSomasInfo(from_actor);
824 AddSomasInfo(to_actor);
825
826 auto from_graph = FetchKernelGraphByActor(from_actor);
827 auto to_graph = FetchKernelGraphByActor(to_actor);
828 // Add the memory alloc and free sign at the boundary of the graph.
829 if ((from_graph != nullptr) && (to_graph != nullptr)) {
830 // The same graph no need insert the memory actor.
831 if (from_graph->graph_id() == to_graph->graph_id()) {
832 return;
833 }
834 AddMemoryFreeSign(from_actor, to_actor, from_graph);
835 AddMemoryAllocSign(from_actor, to_actor, to_graph);
836 } else if (from_graph != nullptr) {
837 AddMemoryFreeSign(from_actor, to_actor, from_graph);
838 } else if (to_graph != nullptr) {
839 AddMemoryAllocSign(from_actor, to_actor, to_graph);
840 }
841 }
842
FetchKernelGraphByActor(AbstractActor * const actor)843 KernelGraphPtr SchedulerHelper::FetchKernelGraphByActor(AbstractActor *const actor) {
844 MS_EXCEPTION_IF_NULL(actor);
845 AnfNode *from_kernel = nullptr;
846 if (actor->type() == KernelTransformType::kKernelActor ||
847 actor->type() == KernelTransformType::kConditionGatherActor ||
848 actor->type() == KernelTransformType::kConditionSwitchActor) {
849 auto kernel_actor = dynamic_cast<KernelActor *>(actor);
850 MS_EXCEPTION_IF_NULL(kernel_actor);
851 from_kernel = kernel_actor->kernel().get();
852 MS_EXCEPTION_IF_NULL(from_kernel);
853 }
854
855 // The device data source actor is from the GetNext cnode that is not a boundary of the graph and is equivalent to the
856 // kernel actor when inserted the memory actor.
857 if (actor->type() == KernelTransformType::kDeviceDataSourceActor) {
858 auto device_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor);
859 MS_EXCEPTION_IF_NULL(device_ds_actor);
860 from_kernel = device_ds_actor->data_kernel().get();
861 MS_EXCEPTION_IF_NULL(from_kernel);
862 }
863
864 // Only the copy actor from device tensor store need to fetch the kernel graph, because the copy actor is not a
865 // boundary of the graph and is equivalent to the kernel actor when inserted the memory actor.
866 if ((actor->type() == KernelTransformType::kCopyActor) &&
867 (actor->GetAID().Name().find(kCopyActorNameSignFromStore) != std::string::npos)) {
868 auto copy_actor = dynamic_cast<CopyActor *>(actor);
869 MS_EXCEPTION_IF_NULL(copy_actor);
870 from_kernel = copy_actor->from_kernel_;
871 }
872
873 if (from_kernel == nullptr) {
874 return nullptr;
875 }
876 auto graph = AnfAlgo::FetchKernelGraph(from_kernel);
877 if (graph == nullptr) {
878 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#No associated graph for node: "
879 << from_kernel->fullname_with_scope();
880 }
881
882 return graph;
883 }
884
AddMemoryAllocSign(AbstractActor * const from_actor,AbstractActor * const to_actor,const KernelGraphPtr & to_graph)885 void SchedulerHelper::AddMemoryAllocSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
886 const KernelGraphPtr &to_graph) {
887 MS_EXCEPTION_IF_NULL(from_actor);
888 MS_EXCEPTION_IF_NULL(to_actor);
889 MS_EXCEPTION_IF_NULL(to_graph);
890 // Somas is not work for this graph.
891 if (to_graph->somas_whole_block_size() == 0) {
892 return;
893 }
894
895 // Set the memory alloc info.
896 to_actor->memory_alloc_insert_position_ = from_actor;
897 }
898
AddMemoryFreeSign(AbstractActor * const from_actor,AbstractActor * const to_actor,const KernelGraphPtr & from_graph)899 void SchedulerHelper::AddMemoryFreeSign(AbstractActor *const from_actor, AbstractActor *const to_actor,
900 const KernelGraphPtr &from_graph) {
901 MS_EXCEPTION_IF_NULL(from_actor);
902 MS_EXCEPTION_IF_NULL(to_actor);
903 MS_EXCEPTION_IF_NULL(from_graph);
904 // Somas is not work for this graph.
905 if (from_graph->somas_whole_block_size() == 0) {
906 return;
907 }
908
909 // Set the memory free info.
910 from_actor->memory_free_insert_position_ = to_actor;
911 }
912
AddSomasInfo(AbstractActor * const actor)913 void SchedulerHelper::AddSomasInfo(AbstractActor *const actor) {
914 MS_EXCEPTION_IF_NULL(actor);
915 // Only the kernel actor supports somas.
916 if (actor->type() != KernelTransformType::kKernelActor &&
917 actor->type() != KernelTransformType::kConditionGatherActor &&
918 actor->type() != KernelTransformType::kConditionSwitchActor) {
919 return;
920 }
921 auto kernel_actor = dynamic_cast<KernelActor *>(actor);
922 MS_EXCEPTION_IF_NULL(kernel_actor);
923 if (kernel_actor->somas_info_ != nullptr) {
924 return;
925 }
926
927 MS_EXCEPTION_IF_NULL(kernel_actor->kernel());
928 auto graph = AnfAlgo::FetchKernelGraph(kernel_actor->kernel().get());
929 if (graph == nullptr) {
930 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#No associated graph for node: "
931 << kernel_actor->kernel()->fullname_with_scope();
932 }
933 // Somas is not work for this graph.
934 if (graph->somas_whole_block_size() == 0) {
935 return;
936 }
937
938 // Set the somas info.
939 auto somas_info = graph->MutableSomasInfo();
940 MS_EXCEPTION_IF_NULL(somas_info);
941 somas_info->graph_id_ = graph->graph_id();
942 kernel_actor->somas_info_ = somas_info;
943 }
944
AddSomasInfoForGraphOutput(AbstractActor * const output_actor,const AnfNodePtr & output_kernel,size_t output_index,size_t graph_id)945 void SchedulerHelper::AddSomasInfoForGraphOutput(AbstractActor *const output_actor, const AnfNodePtr &output_kernel,
946 size_t output_index, size_t graph_id) {
947 auto ms_context = MsContext::GetInstance();
948 MS_EXCEPTION_IF_NULL(ms_context);
949 if (ms_context->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO0) {
950 return;
951 }
952 if ((output_actor == nullptr) || (output_actor->type() != KernelTransformType::kKernelActor &&
953 output_actor->type() != KernelTransformType::kConditionSwitchActor &&
954 output_actor->type() != KernelTransformType::kConditionGatherActor)) {
955 return;
956 }
957
958 MS_EXCEPTION_IF_NULL(output_kernel);
959 auto kernel_info = dynamic_cast<KernelInfo *>(output_kernel->kernel_info());
960 MS_EXCEPTION_IF_NULL(kernel_info);
961 const auto &somas_outputs = kernel_info->somas_output_result();
962 auto is_somas = kernel_info->IsTensorEnableSomas(somas_outputs, output_index);
963 MS_LOG(INFO) << "The graph " << graph_id << " output node:" << output_kernel->fullname_with_scope()
964 << " with index: " << output_index << " somas enable or not: " << is_somas
965 << ", somas offset: " << kernel_info->GetTensorSomasOffset(somas_outputs, output_index)
966 << ", aligned size: " << kernel_info->GetTensorSomasAlignedSize(somas_outputs, output_index);
967 if (is_somas) {
968 auto kernel_actor = dynamic_cast<KernelActor *>(output_actor);
969 kernel_actor->somas_graph_output_indexes_.insert(output_index);
970 }
971 }
972
973 namespace {
CheckKernelActorValid(const std::vector<KernelActorPtr> & kernel_actors)974 void CheckKernelActorValid(const std::vector<KernelActorPtr> &kernel_actors) {
975 for (const auto &kernel_actor : kernel_actors) {
976 MS_EXCEPTION_IF_NULL(kernel_actor);
977 std::string exit_actor_name = "";
978
979 for (const auto &arrow : kernel_actor->output_data_arrows()) {
980 MS_EXCEPTION_IF_NULL(arrow);
981 if (arrow->to_op_id_.Name().find(kExitActorNameSuffix) == std::string::npos) {
982 continue;
983 }
984 if (exit_actor_name == "") {
985 exit_actor_name = arrow->to_op_id_.Name();
986 continue;
987 }
988 if (exit_actor_name != arrow->to_op_id_.Name()) {
989 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Kernel actor:" << kernel_actor->GetAID()
990 << " link to two exit actor:" << exit_actor_name
991 << " and:" << arrow->to_op_id_.Name();
992 }
993 }
994 }
995 }
996
CheckExitActorInvalid(const ExitActorPtr & exit_actor)997 bool CheckExitActorInvalid(const ExitActorPtr &exit_actor) {
998 MS_EXCEPTION_IF_NULL(exit_actor);
999
1000 return exit_actor->output_data_arrows().empty() && exit_actor->output_partial_arrows().empty() &&
1001 exit_actor->output_control_arrows().empty() && exit_actor->output_branch_control_arrows().empty() &&
1002 exit_actor->output_branch_data_arrows().empty() && exit_actor->output_branch_partial_arrows().empty() &&
1003 !exit_actor->input_data_arrow_aids().empty();
1004 }
1005
1006 // Convert the control actors vector by the control actor set.
CollectControlActors(const ControlActorSetPtr & control_actor_set)1007 std::vector<ControlActorPtr> CollectControlActors(const ControlActorSetPtr &control_actor_set) {
1008 MS_EXCEPTION_IF_NULL(control_actor_set);
1009 std::vector<ControlActorPtr> actors;
1010
1011 for (auto &switch_actor : control_actor_set->switch_actors_) {
1012 MS_EXCEPTION_IF_NULL(switch_actor);
1013 (void)actors.emplace_back(static_cast<ControlActorPtr>(switch_actor));
1014 }
1015 for (auto &gather_actor : control_actor_set->gather_actors_) {
1016 MS_EXCEPTION_IF_NULL(gather_actor);
1017 (void)actors.emplace_back(static_cast<ControlActorPtr>(gather_actor));
1018 }
1019 for (auto &entrance_actor : control_actor_set->entrance_actors_) {
1020 MS_EXCEPTION_IF_NULL(entrance_actor);
1021 (void)actors.emplace_back(static_cast<ControlActorPtr>(entrance_actor));
1022 }
1023 for (auto &exit_actor : control_actor_set->exit_actors_) {
1024 MS_EXCEPTION_IF_NULL(exit_actor);
1025 (void)actors.emplace_back(static_cast<ControlActorPtr>(exit_actor));
1026 }
1027 for (auto &stack_actor : control_actor_set->stack_actors_) {
1028 MS_EXCEPTION_IF_NULL(stack_actor);
1029 (void)actors.emplace_back(static_cast<ControlActorPtr>(stack_actor));
1030 }
1031
1032 return actors;
1033 }
1034
CheckControlActorValid(const ActorSet * actor_set)1035 void CheckControlActorValid(const ActorSet *actor_set) {
1036 MS_EXCEPTION_IF_NULL(actor_set);
1037 if (actor_set->control_actors_ == nullptr) {
1038 return;
1039 }
1040
1041 CheckKernelActorValid(actor_set->kernel_actors_);
1042
1043 auto control_actors = CollectControlActors(actor_set->control_actors_);
1044 for (const auto &control_actor : control_actors) {
1045 MS_EXCEPTION_IF_NULL(control_actor);
1046 for (auto &ref_node_formal_parameter_device_tensor : control_actor->ref_node_formal_parameter_device_tensors()) {
1047 auto &device_tensors = ref_node_formal_parameter_device_tensor.second;
1048 for (auto iter = device_tensors.begin(); iter != device_tensors.end(); ++iter) {
1049 if (((*device_tensors.begin())->format() != (*iter)->format()) ||
1050 ((*device_tensors.begin())->GetDeviceType() != (*iter)->GetDeviceType()) ||
1051 ((*device_tensors.begin())->type_id() != (*iter)->type_id())) {
1052 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << control_actor->GetAID().Name()
1053 << " does not support the ref node formal parameters with different format.";
1054 }
1055 }
1056 }
1057
1058 for (auto &ref_formal_parameter_device_tensor : control_actor->ref_formal_parameter_device_tensors()) {
1059 auto &device_tensors = ref_formal_parameter_device_tensor.second;
1060 for (auto iter = device_tensors.begin(); iter != device_tensors.end(); ++iter) {
1061 if ((*device_tensors.begin())->type_id() != (*iter)->type_id()) {
1062 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << control_actor->GetAID().Name()
1063 << " does not support the ref formal parameters with different type.";
1064 }
1065 }
1066 }
1067 }
1068
1069 for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
1070 MS_EXCEPTION_IF_NULL(exit_actor);
1071 if (CheckExitActorInvalid(exit_actor)) {
1072 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Invalid exit actor:" << exit_actor->GetAID();
1073 }
1074 }
1075
1076 // Since some control arrows of stack actors need to be counted according to aid, the input control arrow cannot
1077 // be repeated, otherwise the count will be inaccurate. But there are exceptions, if the control arrow does not
1078 // need to be counted according to aid, it can be repeated.
1079 for (const auto &stack_actor : actor_set->control_actors_->stack_actors_) {
1080 MS_EXCEPTION_IF_NULL(stack_actor);
1081 const auto &input_control_aids = stack_actor->input_control_arrow_aids();
1082 std::set<AID> aid_set;
1083 (void)std::for_each(input_control_aids.begin(), input_control_aids.end(),
1084 [&aid_set](const auto &input_control_aid) { (void)aid_set.emplace(input_control_aid.first); });
1085 if (aid_set.size() != input_control_aids.size()) {
1086 MS_LOG(WARNING) << "Stack actor:" << stack_actor->GetAID() << " has duplicate control arrows.";
1087 }
1088 }
1089 }
1090 } // namespace
1091
CheckActorValid(const ActorSet * actor_set)1092 void SchedulerHelper::CheckActorValid(const ActorSet *actor_set) {
1093 MS_EXCEPTION_IF_NULL(actor_set);
1094 auto actors = SchedulerHelper::CollectActors(actor_set);
1095 for (auto &actor : actors) {
1096 MS_EXCEPTION_IF_NULL(actor);
1097 if (actor->type_ >= KernelTransformType::kSwitchActor) {
1098 continue;
1099 }
1100
1101 if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
1102 (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
1103 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The input num of " << actor->GetAID().Name()
1104 << " is wrong, expect data num: " << actor->input_datas_num_
1105 << ", actual data num: " << actor->input_data_arrow_aids_.size()
1106 << ", expect control num: " << actor->input_controls_num_
1107 << ", actual control num: " << actor->input_control_arrow_aids_.size();
1108 }
1109
1110 if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->type_ != KernelTransformType::kCustomActor) &&
1111 (actor->output_data_arrows_.size() == 0) && (actor->output_control_arrows_.size() == 0)) {
1112 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << actor->GetAID().Name() << " has no user.";
1113 }
1114 if ((actor->type_ != KernelTransformType::kDataPrepareActor) &&
1115 (actor->type_ != KernelTransformType::kCustomActor) && (actor->input_datas_num_ == 0) &&
1116 (actor->input_controls_num_ == 0)) {
1117 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << actor->GetAID().Name() << " has no source.";
1118 }
1119
1120 // Check the input of kernel actors and copy actors.
1121 if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
1122 size_t expect_input_num = 1;
1123 if (actor->type_ == KernelTransformType::kKernelActor) {
1124 auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
1125 MS_EXCEPTION_IF_NULL(kernel_actor);
1126 auto &kernel = kernel_actor->kernel();
1127 MS_EXCEPTION_IF_NULL(kernel);
1128 expect_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
1129 }
1130 auto input_data_num = actor->input_datas_num_;
1131 auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
1132 if (input_data_num + device_tensor_store_num != expect_input_num) {
1133 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The input building of " << actor->GetAID().Name()
1134 << " is wrong, input data num: " << input_data_num
1135 << ", device tensor store num: " << device_tensor_store_num
1136 << ", total input num: " << expect_input_num;
1137 }
1138 }
1139 }
1140
1141 // Check the output actor.
1142 auto output_actor = actor_set->output_actor_;
1143 MS_EXCEPTION_IF_NULL(output_actor);
1144 if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num()) {
1145 MS_LOG(INTERNAL_EXCEPTION)
1146 << "#dmsg#Runtime error info:#dmsg#The outputs num of output actor is wrong, the total outputs num: "
1147 << output_actor->outputs_num() << ", the input data arrows num: " << output_actor->input_datas_num_
1148 << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
1149 }
1150
1151 CheckControlActorValid(actor_set);
1152 }
1153
DumpActorSet(const ActorSet * actor_set,std::ofstream & ofs)1154 void SchedulerHelper::DumpActorSet(const ActorSet *actor_set, std::ofstream &ofs) {
1155 MS_EXCEPTION_IF_NULL(actor_set);
1156 DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
1157 DumpDSActors(actor_set->data_source_actors_, ofs);
1158 DumpKernelActors(actor_set->kernel_actors_, ofs);
1159 DumpKernelInferActors(actor_set->kernel_infer_actors_, ofs);
1160 DumpKernelResizeActors(actor_set->kernel_resize_actors_, ofs);
1161 DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
1162 DumpAnyTypeKernelActors(actor_set->any_type_kernel_actors_, ofs);
1163 // The on input kernel actors are taken over by control actor in the control flow scene.
1164 if (actor_set->control_actors_ == nullptr) {
1165 DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
1166 }
1167 DumpMemoryActors(actor_set->memory_actors_, ofs);
1168 DumpCopyActors(actor_set->copy_actors_, ofs);
1169 DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
1170 DumpOutputActor(actor_set->output_actor_, ofs);
1171 DumpFusionActors(actor_set->fusion_actors_, ofs);
1172 DumpControlActors(actor_set->control_actors_, ofs);
1173 DumpCustomActors(actor_set->custom_actors_, ofs);
1174 DumpSwapActors(actor_set->swap_actors_, ofs);
1175 }
1176
DumpFormatActorSet(const ActorSet * actor_set,std::ofstream & ofs)1177 void SchedulerHelper::DumpFormatActorSet(const ActorSet *actor_set, std::ofstream &ofs) {
1178 MS_EXCEPTION_IF_NULL(actor_set);
1179 try {
1180 MS_LOG(DEBUG) << "Start dump format actor set:" << actor_set->name_;
1181 if (actor_set->control_actors_ != nullptr) {
1182 for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
1183 if (exit_actor->node() != nullptr) {
1184 continue;
1185 }
1186 auto actors = TopoSortForActor(exit_actor.get());
1187 ActorInfoMap actor_info;
1188 ofs << "\n\nBase Block : "
1189 << exit_actor->GetAID().Name().substr(0, exit_actor->GetAID().Name().find(kExitActorNameSuffix)) << "\n\n";
1190 for (size_t i = 0; i < actors.size(); ++i) {
1191 DumpActorInfo(actors[i], i, &actor_info, ofs);
1192 }
1193 }
1194 return;
1195 }
1196
1197 auto actors = TopoSortForActor(actor_set->output_actor_.get());
1198 ActorInfoMap actor_info;
1199 for (size_t i = 0; i < actors.size(); ++i) {
1200 DumpActorInfo(actors[i], i, &actor_info, ofs);
1201 }
1202 MS_LOG(DEBUG) << "End dump format actor set:" << actor_set->name_;
1203 } catch (const std::exception &e) {
1204 MS_LOG(INFO) << "Failed to dump actor set:" << actor_set->name_ << ", msg: " << e.what();
1205 }
1206 }
1207 } // namespace runtime
1208 } // namespace mindspore
1209