/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <algorithm>
#include "include/backend/mem_reuse/mem_tracker.h"
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
#include "runtime/graph_scheduler/scheduler_helper.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
#include "runtime/graph_scheduler/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/phase.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
namespace {
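// Refresh the size and shape of the graph parameter's device tensor from the actual runtime input,
// which is needed for dynamic-shape parameters and for graphs other than the super-kernel-actor type.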
inline void UpdateShape(const AnfNodePtr &input_node, const DeviceTensorPtr &node_device_tensor,
                        DeviceTensor *input_device_tensor, const KernelTransformType &type) {
  MS_EXCEPTION_IF_NULL(input_node);
  const auto &node_device_kernel_tensor = node_device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(input_device_tensor);
  const auto &input_kernel_tensor = input_device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(node_device_kernel_tensor);
  MS_EXCEPTION_IF_NULL(input_kernel_tensor);
  if (type != KernelTransformType::kSuperKernelActor || input_node->cast<ParameterPtr>()->has_dynamic_shape()) {
    // For dynamic-shape parameters in sub-graph sink and any-type parameters, the input size should be updated.
    node_device_tensor->SetSize(input_device_tensor->GetSize());
    // Update the shape.
    node_device_kernel_tensor->SetShape(input_kernel_tensor->GetShape()->Clone());
  }
}

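// Returns true when the runtime input can be used directly (same device tensor, same ptr, or the
// parameter is marked as not used), so no copy into the graph's own device tensor is required.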
inline bool InputDataNoNeedCopy(const AnfNodePtr &input_node, DeviceTensor *input_device_tensor,
                                const DeviceTensorPtr &node_device_tensor, const KernelTransformType &type) {
  if (input_device_tensor == node_device_tensor.get()) {
    (void)input_device_tensor->TouchSyncHandler();
    return true;
  }

  if (input_device_tensor == nullptr) {
    return true;
  }

  UpdateShape(input_node, node_device_tensor, input_device_tensor, type);

  if (TEST_FLAG(node_device_tensor->flag(), device::kDeviceAddressFlagNotUsed) ||
      input_device_tensor->GetPtr() == node_device_tensor->GetPtr()) {
    return true;
  }

  return false;
}

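// Update the ref count of a kernel input. If the kernel only depends on the input's shape
// (kAttrOnlyDependShape) and infer boost is disabled, skip the ref count update and only mark the
// device address flag instead.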
void UpdateRefCountWithOnlyDependShape(const CNodePtr &kernel, size_t input_index, const AnfNodePtr &node,
                                       size_t output_index) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
  if (enable_infer_boost) {
    UpdateRefCount(node, output_index, false);
    return;
  }

  // A kernel that only depends on the shape of its input should not increase the ref count.
  const auto &only_depend_shape_attr = common::AnfAlgo::GetCNodePrimitiveAttr(kernel, kAttrOnlyDependShape);
  if (only_depend_shape_attr != nullptr) {
    const auto &only_depend_shape = GetValue<std::vector<bool>>(only_depend_shape_attr);
    if (input_index < only_depend_shape.size() && only_depend_shape[input_index]) {
      // A shape-only dependency does not need to increase the ref count; just update the flag.
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      device_tensor->UpdateFlag(device::kDeviceAddressFlagNullptr);
      return;
    }
  }
  UpdateRefCount(node, output_index, false);
}
}  // namespace
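// Initialize the super kernel actor: bind output data to device addresses, collect the graph outputs
// that need memory allocation, and record which input parameters have to be copied. For
// kernel-by-kernel sub-graph execution, also build the inner kernel actors.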
void SuperKernelActor::Init() {
  MS_EXCEPTION_IF_NULL(graph_);
  // Check the number of device contexts.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
  }

  // Set the number of actor running dependent messages.
  running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);

  // Init the output data.
  InitOutputData();
  if (output_data_arrows_.size() != output_data_nodes_.size()) {
    MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the size of output data nodes.";
  }
  if (output_data_arrows_.size() != output_data_.size()) {
    MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the size of output data.";
  }
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    auto &data_arrow = output_data_arrows_[i];
    auto &output_node = output_data_nodes_[i];
    auto data = output_data_[i].first.get();
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(output_node);
    MS_EXCEPTION_IF_NULL(data);
    auto device_address = AnfAlgo::GetMutableOutputAddr(output_node, IntToSize(data_arrow->from_output_index_), false);
    data->data_ = device_address.get();
  }

  const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph_->output());
  for (const auto &origin_output_with_index : output_with_indexs) {
    const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(origin_output_with_index);
    const auto &output_node = output_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (output_node->isa<CNode>() && (!HasAbstractMonad(output_node))) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(output_node, output_with_index.second, false);
      MS_EXCEPTION_IF_NULL(device_address);
      if (device_address->is_ptr_persisted() || graph_->is_dynamic_shape()) {
        MS_LOG(DEBUG) << "Actor:" << GetAID() << " skip alloc memory for device address:" << device_address
                      << " is persist:" << device_address->is_ptr_persisted()
                      << " is dynamic shape:" << graph_->is_dynamic_shape()
                      << " output node:" << output_node->DebugString();
        continue;
      }
      // Free the ptr in the device address of the output node.
      if (device_address->GetPtr() != nullptr) {
        MS_LOG(INFO) << "Output node:" << output_node->DebugString() << " has a default ptr, maybe a mem leak.";
        device_address->set_ptr(nullptr);
      }
      if (common::IsNeedProfileMemory()) {
        device_address_to_node_[device_address.get()] = {device_address->GetSize(), output_node->fullname_with_scope()};
      }
      memory_alloc_list_.emplace_back(device_address.get());
    }
  }

  // Check whether the parameters need to be copied out.
  node_device_tensors_.resize(graph_->input_nodes().size());
  is_parameters_need_copy_.resize(graph_->input_nodes().size());
  copy_input_device_tensors_.resize(graph_->input_nodes().size());
  for (size_t i = 0; i < graph_->input_nodes().size(); ++i) {
    const auto &input_node = graph_->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input_node);
    node_device_tensors_[i] = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
    if (!common::AnfAlgo::HasAbstractRef(input_node)) {
      is_parameters_need_copy_[i] = false;
      continue;
    }
    // If the parameter has the ref attribute and is directly used by a kernel in the graph, it needs to be copied.
    is_parameters_need_copy_[i] = true;
  }

  if (enable_kbk_sub_graph_execute_) {
    BuildKernelActors();
    ParseInputIndex();
    CalcRefCount();
  }

  if (type_ == KernelTransformType::kSuperKernelActor && !enable_kbk_sub_graph_execute_) {
    MS_EXCEPTION_IF_NULL(device_contexts_[0]);
    MS_EXCEPTION_IF_NULL(device_contexts_[0]->graph_executor_);
    device_contexts_[0]->graph_executor_->InitGraphInfo(graph_);
  }
}

size_t SuperKernelActor::FetchInputNodePosition(const AnfNodePtr &intput_node) {
  MS_EXCEPTION_IF_NULL(intput_node);
  MS_EXCEPTION_IF_NULL(graph_);

  auto &input_nodes = graph_->input_nodes();
  const auto &iter = find(input_nodes.begin(), input_nodes.end(), intput_node);
  if (iter == input_nodes.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, intput_node) << "Invalid input node:" << intput_node->fullname_with_scope();
  }
  return iter - input_nodes.begin();
}

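// Collect the device tensors sent by upstream actors for this step and record which of them need to
// be freed after the graph finishes running.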
void SuperKernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  std::vector<DeviceTensor *> memory_free_list;
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      MS_EXCEPTION_IF_NULL(input_data->data_);
      size_t index = IntToSize(input_data->index_);
      if (index >= input_device_tensors_.size()) {
        std::string error_info = "Invalid input index:" + std::to_string(index) +
                                 " total:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_device_tensors_[index] = input_data->data_;

      if (common::IsNeedProfileMemory()) {
        auto output_address = reinterpret_cast<std::uintptr_t>(input_device_tensors_[index]);
        MS_LOG(WARNING) << "Need Profile Memory, Memory use, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", device address class ptr: " << output_address
                        << ", device address size: " << input_device_tensors_[index]->GetSize()
                        << ", device address addr: " << input_device_tensors_[index]->GetPtr() << ", index: " << index;
      }

      if (input_data->data_->dynamic_ref_count() != INT32_MAX) {
        (void)memory_free_list.emplace_back(input_data->data_);
      }
    }
    memory_free_lists_.push(memory_free_list);
  }
}

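// Entry point of the actor: either launch the whole graph through the graph executor, or, when
// kernel-by-kernel sub-graph execution is enabled, dispatch the inner kernels one by one.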
void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "SuperKernelActor", graph_->ToString());
  }

  if (enable_kbk_sub_graph_execute_) {
    return RunGraphKernelByKernel(context);
  }
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name()
               << ") launches graph: " << std::to_string(graph_->graph_id());
  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, launch actor name: " << GetAID().Name()
                    << ", kernel graph: " << graph_->ToString();
  }
  if (!WaitRuntimePipelineFinish(context)) {
    MS_LOG(INFO) << "Run failed and early stop.";
    return;
  }
  FetchInputDeviceTensor(context);
  if (!already_fetch_persistent_device_tensor_) {
    FetchPersistentDeviceTensor();
    already_fetch_persistent_device_tensor_ = IsTwoPhaseInfer();
  }

  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    for (auto &device_addr : input_device_tensors_) {
      if (device_addr == nullptr || !device_addr->IsPtrValid()) {
        continue;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(UseMemBlock, GetAID().Name(), device_addr->GetPtr());
    }
  }
  if (memory_alloc_list_.size() > 0) {
    for (auto &device_tensor : memory_alloc_list_) {
      if (device_tensor->IsNotNeedAlloc()) {
        continue;
      }
      if (common::IsNeedProfileMemory()) {
        MS_EXCEPTION_IF_NULL(device_tensor);
        auto &info = device_address_to_node_[device_tensor];
        auto output_address = reinterpret_cast<std::uintptr_t>(device_tensor);
        MS_LOG(WARNING) << "Need Profile Memory, Memory need allocated, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", node: " << info.node_full_name
                        << ", device address class ptr: " << output_address << ", device address size: " << info.size;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(
        AddMemInfo, GetAID().Name(), device::tracker::MemType::kGraphOutput, device_tensor->GetSize(), device_tensor);
    }
    SendMemoryAllocReq(context);
  } else {
    OnMemoryAllocFinish(context);
  }
  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, end launch, actor name: " << GetAID().Name()
                    << ", kernel graph: " << graph_->ToString();
  }
}

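// Fetch the persistent device tensors (weights and other graph-lifetime data) from the device tensor
// store and fill them into the input device tensor list.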
void SuperKernelActor::FetchPersistentDeviceTensor() {
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto input_device_tensor = DeviceTensorStore::GetInstance()
                                 .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
                                 .get();
    // The GE backend may return a nullptr here.
    if (input_device_tensor == nullptr) {
      MS_LOG(DEBUG) << "Failed to get device tensor for node:" << device_tensor_store_key.second->DebugString()
                    << " index:" << device_tensor_store_key.first;
      continue;
    }

    size_t index = device_tensor_store_key.first;
    input_device_tensors_[index] = input_device_tensor;
  }
}

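// Prepare the memory trace manager for this step. On the recording step (static shape not yet
// enabled) the previously traced kernel output/workspace pointers are cleared and new trace blocks
// are reserved; on later steps the merged trace blocks are allocated in one shot.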
void SuperKernelActor::UpdateMemoryTraceMangerStatus(OpContext<DeviceTensor> *const context) {
  MemoryTraceManager::GetInstance().PickMemoryTrackInfoForGraph(graph_->graph_id());
  if (!ActorDispatcher::enable_static_shape()) {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryAlloc, GetAID().Name());

    const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>> &all_kernel_block_info =
      MemoryTraceManager::GetInstance().GetAllKernelBlocksnfo();
    MS_EXCEPTION_IF_NULL(all_kernel_block_info);

    if (!all_kernel_block_info->empty()) {
      size_t kernel_num = kernel_actors_.size();
      for (size_t i = 0; i < kernel_num; i++) {
        const auto &kernel_actor = kernel_actors_[i];
        if (kernel_actor == nullptr) {
          continue;
        }

        const auto &kernel = kernel_actor->kernel_;
        MS_EXCEPTION_IF_NULL(kernel);

        const auto &iter = all_kernel_block_info->find(kernel);
        if (iter == all_kernel_block_info->end()) {
          MS_LOG(DEBUG) << "Not found kernel block info for kernel: " << kernel->fullname_with_scope()
                        << ", is output kernel: " << kernel_actor->is_output_kernel_;
        } else {
          const auto &kernel_mem_block = iter->second;
          for (auto &block : kernel_mem_block) {
            MS_EXCEPTION_IF_NULL(block);
            if (block->mem_type_ == kOutputMem) {
              kernel_actor->output_kernel_tensors_.at(block->index_)->set_device_ptr(nullptr);
            } else {
              kernel_actor->workspace_kernel_tensors_.at(block->index_)->set_device_ptr(nullptr);
            }
          }
        }
      }
    }

    // First step for dynamic shape: the memory trace needs to be recorded.
    MemoryTraceManager::GetInstance().Clear();
    static const size_t memory_block_size = 3000;
    MemoryTraceManager::GetInstance().ReserveKernelMemoryBlocks(memory_block_size, device_contexts_[0]);
  } else {
    // Not the first step for dynamic shape: use the recorded trace memory.
    // Allocate block memory for the static memory step.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryAlloc, GetAID().Name());
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    for (auto &item : *merge_blocks_with_device_context) {
      const auto &device_context = item.first;
      MS_EXCEPTION_IF_NULL(device_context);
      const auto &merge_blocks = item.second;
      for (auto &block : merge_blocks) {
        MS_EXCEPTION_IF_NULL(block);
        static const size_t kMemoryAlignSize = 1024;
        void *block_addr = device_context->device_res_manager_->AllocateMemory(block->size_ + kMemoryAlignSize);
        if (block_addr == nullptr) {
          SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context,
                                                      *(device_contexts_[0]), GetAID().Name(), block->size_);
        }
        block->start_ = reinterpret_cast<uint8_t *>(block_addr);
      }
    }
  }
}

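// Assign the recorded trace memory to one kernel: point its output and workspace kernel tensors at
// the proper offsets inside the pre-allocated merged memory blocks.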
void SuperKernelActor::SetTraceMemoryForKernel(const KernelActorPtr &kernel_actor) {
  const auto &kernel = kernel_actor->kernel();
  MS_EXCEPTION_IF_NULL(kernel);

  // Allocate trace memory for the static memory step.
  const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>> &all_kernel_block_info =
    MemoryTraceManager::GetInstance().GetAllKernelBlocksnfo();
  MS_EXCEPTION_IF_NULL(all_kernel_block_info);
  const auto &iter = all_kernel_block_info->find(kernel);
  if (iter == all_kernel_block_info->end()) {
    MS_LOG(DEBUG) << "Not found kernel block info for kernel: " << kernel->fullname_with_scope()
                  << ", is output kernel: " << kernel_actor->is_output_kernel_;
  } else {
    const auto &kernel_mem_block = iter->second;
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    const auto &merge_blocks = merge_blocks_with_device_context->at(kernel_actor->device_contexts_[0]);
    for (auto &block : kernel_mem_block) {
      MS_EXCEPTION_IF_NULL(block);
      void *ptr = merge_blocks.at(block->in_memory_trace_block_index_)->start_ + block->offset_in_memory_trace_block_;
      MS_EXCEPTION_IF_NULL(ptr);
      if (block->mem_type_ == kOutputMem) {
        kernel_actor->output_kernel_tensors_.at(block->index_)->set_device_ptr(ptr);
      } else {
        kernel_actor->workspace_kernel_tensors_.at(block->index_)->set_device_ptr(ptr);
      }
    }
  }
}

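// Launch the graph kernel by kernel: fetch inputs, allocate somas/trace memory, dispatch every inner
// kernel actor to the async infer/launch pipeline, and release the temporary memory afterwards.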
void SuperKernelActor::RunGraphKernelByKernel(OpContext<DeviceTensor> *const context) {
  if (!ActorDispatcher::enable_async_launch_kernel()) {
    std::string error_info =
      "Runtime pipeline optimization is disabled, failed to execute graph kernel by kernel mode.";
    MS_LOG(ERROR) << "Run graph failed, graph id: " << std::to_string(graph_->graph_id()) << ". " << error_info;
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }
  if (!graph_->is_dynamic_shape()) {
    ActorDispatcher::set_enable_static_shape(false);
  }

  // 1. Fetch the input data.
  FetchInputDeviceTensor(context);
  if (!already_fetch_persistent_device_tensor_) {
    FetchPersistentDeviceTensor();
    already_fetch_persistent_device_tensor_ = true;
  }

  // 2. Allocate somas memory for the graph.
  if ((somas_info_ != nullptr) && (somas_info_->whole_block_size_ != 0)) {
    MemoryManagerActor::GetInstance()->AllocateSomasMemory(somas_info_, device_contexts_[0], context, GetAID());
  }
  const auto &phase = PhaseManager::GetInstance().phase();
  bool is_increment_graph = (phase.find("increment") != std::string::npos);
  if (enable_trace_memory_ && graph_->is_dynamic_shape() && is_increment_graph) {
    MS_LOG(DEBUG) << "Enable trace memory for increment inference graph: " << graph_->graph_id()
                  << ", phase: " << phase;
    UpdateMemoryTraceMangerStatus(context);

    if (IsRunningFailed(context)) {
      // Memory allocation may have failed, so stop running the graph early.
      MS_LOG(INFO) << "Run failed and early stop to run graph: " << graph_->graph_id();
      return;
    }
  }

  // 3. Launch all kernels.
  size_t kernel_num = kernel_actors_.size();
  const auto &execution_order = graph_->execution_order();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel_actor = kernel_actors_[i];
    if (kernel_actor == nullptr) {
      continue;
    }
    const auto &kernel = execution_order[i];
    // 3.1 Prepare the input data for the kernel.
    const auto &iter = kernel_input_to_graph_input_indices_.find(kernel.get());
    if (iter != kernel_input_to_graph_input_indices_.end()) {
      std::vector<std::pair<size_t, size_t>> &input_to_graph_input_indices = iter->second;
      for (const auto &item : input_to_graph_input_indices) {
        kernel_actor->SetInputDeviceTensor(input_device_tensors_[item.second], item.first);
      }
    }

    // 3.2 Allocate somas memory for this kernel.
    kernel_actor->SetSomasMemory(context);

    if (ActorDispatcher::enable_use_trace_memory()) {
      SetTraceMemoryForKernel(kernel_actor);
    }

    // Asynchronously run infer or launch.
    if (ActorDispatcher::enable_runtime_multi_pipeline() && !ActorDispatcher::enable_static_shape()) {
      // If the kernel needs user data and is dynamic, it may need the output user data of an input kernel to infer
      // its shape. This value-depend case cannot be handled in the KernelTensor auto-sync phase currently.
      if (kernel_actor->kernel_mod_->need_user_data() && kernel_actor->has_dynamic_) {
        MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
        if (!WaitRuntimePipelineFinish(context)) {
          MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_actor->kernel_->fullname_with_scope();
          return;
        }
        MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
      }

      // Push the run task to the pipeline.
      // Note: kernels with dynamic values or static shapes also need to push a task into the infer actor to keep the
      // kernel execution order correct.
      Async(kernel_async_infer_aid_, &KernelAsyncInferActor::InferShape, context, kernel_actor.get());

      // A compute-depend kernel should wait for the output shape update after the kernel launch.
      if (kernel_actor->kernel_mod_->IsNeedUpdateOutputShapeAndSize()) {
        MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
        if (!WaitRuntimePipelineFinish(context)) {
          MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_actor->kernel_->fullname_with_scope();
          return;
        }
        MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
      }
    } else {
      Async(kernel_async_launch_aid_, &KernelAsyncLaunchActor::LaunchKernel, context, kernel_actor.get());
    }
  }

  WaitRuntimePipelineFinish(context);

  // 4. Free somas memory for the graph.
  if ((somas_info_ != nullptr) && (somas_info_->whole_block_size_ != 0)) {
    MemoryManagerActor::GetInstance()->FreeSomasMemory(somas_info_, device_contexts_[0], context, GetAID());
  }

  if (ActorDispatcher::enable_trace_dynamic_memory()) {
    // Record and analyse the memory trace of this step, which is used to optimize memory management performance.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryFree, GetAID().Name());
    MemoryTraceManager::GetInstance().MergeBlocks();
  }
  if (ActorDispatcher::enable_use_trace_memory()) {
    // Free the block memory for the static memory step.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryFree, GetAID().Name());
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    for (auto &item : *merge_blocks_with_device_context) {
      const auto &device_context = item.first;
      MS_EXCEPTION_IF_NULL(device_context);
      const auto &merge_blocks = item.second;
      for (auto &block : merge_blocks) {
        MS_EXCEPTION_IF_NULL(block);
        device_context->device_res_manager_->FreeMemory(block->start_);
      }
    }
  }

  // Free the input data.
  PostRun(context);
}

void SuperKernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  sort(memory_alloc_list_.begin(), memory_alloc_list_.end(), [](const DeviceTensor *a, const DeviceTensor *b) {
    MS_EXCEPTION_IF_NULL(a);
    MS_EXCEPTION_IF_NULL(b);
    return a->GetSize() > b->GetSize();
  });
  if (ActorDispatcher::is_memory_allocation_sync()) {
    ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_,
                              device_contexts_[0], context, GetAID());
    OnMemoryAllocFinish(context);
  } else {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_,
                          device_contexts_[0], context, GetAID());
  }
}

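// Runs after graph output memory has been allocated: copy the graph inputs into the graph's own
// device tensors, launch the whole graph through the graph executor, and copy ref-node outputs back.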
void SuperKernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);
  if (IsRunningFailed(context)) {
    MS_LOG(INFO) << "Running failed in actor:" << GetAID().Name();
    return;
  }
  {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
    if (!CopyInputData(context, graph_)) {
      std::string error_info = "Copy the input data failed, graph id: " + std::to_string(graph_->graph_id());
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  try {
    const std::vector<tensor::Tensor> inputs;
    std::vector<tensor::Tensor> outputs;
    const std::map<string, string> compile_options;
    if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                    "Invalid device context for super kernel actor:" + GetAID().Name());
    }
    MS_EXCEPTION_IF_NULL(device_contexts_[0]->graph_executor_);
    if (!IsSkippedLaunch(nullptr, graph_)) {
      ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kGraphLaunch, GetAID().Name());
      auto ret = device_contexts_[0]->graph_executor_->RunGraph(graph_, inputs, &outputs, compile_options);
      if (!ret) {
        std::string error_info = "Launch graph failed, graph id: " + std::to_string(graph_->graph_id());
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
    } else if (common::IsNeedProfileMemory()) {
      auto memory_size = device_contexts_[0]->graph_executor_->GetGraphFeatureMemory(graph_);
      MS_LOG(WARNING) << "Need Profile Memory, graph: " << graph_->ToString() << ", feature memory: " << memory_size;
      MS_LOG(WARNING) << "Need Profile Memory, max used static memory: "
                      << device_contexts_[0]->device_res_manager_->GetMaxUsedMemorySize();
    }
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Launch graph exception, graph id: " + std::to_string(graph_->graph_id());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPostLaunch, GetAID().Name());
    for (auto item : ref_node_addr_map_) {
      MS_EXCEPTION_IF_NULL(item.first);
      MS_EXCEPTION_IF_NULL(item.second);
      MS_LOG(INFO) << "The input ref node copy back from address: " << item.first->GetPtr()
                   << " to address: " << item.second->GetPtr() << ".";
      if (!Copy(item.second, item.first)) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
      }
    }
    ref_node_addr_map_.clear();
  }

  // The debug actor is blocked; this actor must wait for the debug actor's callback message before continuing.
  if (debug_aid_ != nullptr) {
    SendDebugReq(context);
    return;
  }
  PostRun(context);
}

void SuperKernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  running_dependent_msg_num_ = 1;
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  OnDebugFinish(context);
}

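// Handle a non-persistent graph input: if the runtime input is on the same device with an equivalent
// format, share its pointer for zero copy and return true so the caller skips the copy; otherwise
// prepare a copy device tensor on the target device and return false so the caller performs the copy.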
bool SuperKernelActor::CopyInputDataPersistedHandle(const DeviceContext *device_context,
                                                    DeviceTensor *input_device_tensor,
                                                    const DeviceTensorPtr &node_device_tensor, size_t i) {
  if ((input_device_tensor->GetDeviceType() == node_device_tensor->GetDeviceType()) &&
      AnfAlgo::IsEquivalentFormat(input_device_tensor->format(), node_device_tensor->format())) {
    MS_LOG(DEBUG) << "No need to copy for device tensor:" << node_device_tensor << " ptr:" << node_device_tensor->GetPtr()
                  << " index:" << i << " for actor:" << GetAID();
    // Set the ptr from input_device_tensor and set mem pool false to avoid double memory management when
    // supporting zero copy.
    if (type_ != KernelTransformType::kSuperKernelActor) {
      node_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
    } else {
      node_device_tensor->set_ptr(input_device_tensor->GetValidPtr(input_device_tensor->stream_id()));
    }
    MS_LOG(DEBUG) << "Actor:" << GetAID() << " set need sync flag from:" << input_device_tensor
                  << " to:" << node_device_tensor
                  << " sync user data handler:" << node_device_tensor->need_sync_user_data();
    node_device_tensor->set_from_mem_pool(false);
    // No copy is needed; the caller continues with the next input.
    return true;
  }
  if (device_context->GetDeviceType() != node_device_tensor->GetDeviceType()) {
    device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {node_device_tensor->device_name(), node_device_tensor->device_id()});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  }

  if (copy_input_device_tensors_[i] == nullptr) {
    MS_EXCEPTION_IF_NULL(node_device_tensor->kernel_tensor());
    const auto new_kernel_tensor = node_device_tensor->kernel_tensor()->CloneKernelTensor();
    MS_EXCEPTION_IF_NULL(new_kernel_tensor);
    new_kernel_tensor->set_device_name(device_context->device_context_key().device_name_);
    new_kernel_tensor->set_device_id(device_context->device_context_key().device_id_);
    new_kernel_tensor->set_device_ptr(nullptr);

    copy_input_device_tensors_[i] = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
    MS_LOG(DEBUG) << "Create new device tensor:" << copy_input_device_tensors_[i] << " index:" << i
                  << " for actor:" << GetAID();
  }
  auto copy_device_tensor = copy_input_device_tensors_[i];
  MS_EXCEPTION_IF_NULL(copy_device_tensor);
  copy_device_tensor->set_user_data(node_device_tensor->user_data());
  copy_device_tensor->set_need_sync_user_data(node_device_tensor->need_sync_user_data());
  if ((copy_device_tensor->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(copy_device_tensor.get()))) {
    MS_LOG(ERROR) << "Device(id:" << std::to_string(device_context->device_context_key().device_id_)
                  << ") memory isn't enough and alloc failed, kernel name: " << GetAID()
                  << ", alloc size: " + std::to_string(copy_device_tensor->GetSize()) << "B.";
    return true;
  }
  MS_LOG(DEBUG) << "Alloc memory for device tensor:" << copy_device_tensor << " ptr:" << copy_device_tensor->GetPtr()
                << " size:" << copy_device_tensor->GetSize() << " index:" << i << " for actor:" << GetAID();
  if (type_ != KernelTransformType::kSuperKernelActor) {
    node_device_tensor->set_ptr(copy_device_tensor->GetMutablePtr());
  } else {
    node_device_tensor->set_ptr(copy_device_tensor->GetValidPtr(copy_device_tensor->stream_id()));
  }
  node_device_tensor->set_from_mem_pool(false);
  return false;
}

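// Copy the runtime inputs into the graph's own input device tensors when zero copy is not possible,
// and remember which ref parameters must be copied back after the graph runs.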
bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr ||
      device_contexts_[0]->device_res_manager_ == nullptr) {
    MS_LOG(ERROR) << "Invalid device context for actor:" << GetAID();
    return false;
  }
  auto device_context = device_contexts_[0];
  auto &input_nodes = graph->input_nodes();
  if (input_device_tensors_.size() != node_device_tensors_.size()) {
    MS_LOG(ERROR) << "The size of input_device_tensors_[" << input_device_tensors_.size()
                  << "] is not equal to the size of node_device_tensors_[" << node_device_tensors_.size() << "].";
    return false;
  }

  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    auto &node_device_tensor = node_device_tensors_[i];
    MS_EXCEPTION_IF_NULL(node_device_tensor);
    auto &input_device_tensor = input_device_tensors_[i];
    if (InputDataNoNeedCopy(input_nodes[i], input_device_tensor, node_device_tensor, type_)) {
      MS_LOG(DEBUG) << "Actor:" << GetAID() << " input device tensor " << i << ":" << input_device_tensor
                    << " no need copy.";
      continue;
    }
    MS_EXCEPTION_IF_NULL(input_nodes[i]);
    const auto &node_device_kernel_tensor = node_device_tensor->kernel_tensor();
    MS_EXCEPTION_IF_NULL(input_device_tensor);
    const auto &input_kernel_tensor = input_device_tensor->kernel_tensor();
    MS_EXCEPTION_IF_NULL(node_device_kernel_tensor);
    MS_EXCEPTION_IF_NULL(input_kernel_tensor);
    UpdateShape(input_nodes[i], node_device_tensor, input_device_tensor, type_);
    node_device_tensor->set_user_data(input_device_tensor->user_data());
    node_device_tensor->set_need_sync_user_data(input_device_tensor->need_sync_user_data());
    if (type_ != KernelTransformType::kSuperKernelActor) {
      node_device_kernel_tensor->SetValue(input_kernel_tensor->GetValueTrack());
    }

    // Copy.
    DeviceTensorPtr copy_device_tensor = nullptr;
    // If the input is not a persisted device address, a new device address needs to be created in a heterogeneous
    // scenario, and its ptr is set to the node device address to support zero copy of graph input nodes.
    if (!node_device_tensor->is_ptr_persisted()) {
      if (CopyInputDataPersistedHandle(device_context, input_device_tensor, node_device_tensor, i)) {
        continue;
      }
      copy_device_tensor = copy_input_device_tensors_[i];
    } else {
      if (node_device_tensor->GetPtr() == nullptr) {
        MS_LOG(INFO) << "The node device tensor, which is shared with another graph, has no device memory and will "
                        "skip copy for actor:"
                     << GetAID();
        continue;
      }
      copy_device_tensor = node_device_tensor;
    }
    MS_EXCEPTION_IF_NULL(copy_device_tensor);
    MS_LOG(INFO) << "The input data of node:" << input_nodes[i]->DebugString()
                 << " need copy from device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
                 << " size:" << input_device_tensor->GetSize() << ", type:" << input_device_tensor->GetDeviceType()
                 << " to device address:" << copy_device_tensor << " ptr:" << copy_device_tensor->GetPtr()
                 << " size:" << copy_device_tensor->GetSize() << ", type:" << copy_device_tensor->GetDeviceType()
                 << ", is ref node need copy back:" << is_parameters_need_copy_[i] << " for actor:" << GetAID();
    if (!Copy(copy_device_tensor.get(), input_device_tensor)) {
      MS_LOG(ERROR) << "Copy data failed for actor:" << GetAID() << " input index:" << i;
      continue;
    }

    if (is_parameters_need_copy_[i]) {
      ref_node_addr_map_[copy_device_tensor.get()] = input_device_tensor;
    }
  }
  return true;
}

void SuperKernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);

  if (device_contexts_.empty() || device_contexts_[0] == nullptr ||
      device_contexts_[0]->device_res_manager_ == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  if (memory_free_lists_.size() > 0 && memory_free_lists_.back().size() > 0) {
    if (common::IsNeedProfileMemory()) {
      for (auto data : memory_free_lists_.back()) {
        auto output_address = reinterpret_cast<std::uintptr_t>(data);
        MS_LOG(WARNING) << "Need Profile Memory, Memory need Decrease DynamicRefCount, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", device address class ptr: " << output_address
                        << ", device address size: " << data->GetSize() << ", device address addr: " << data->GetPtr();
      }
    }

    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }

  // Free the addresses that are the temporary store for kernel input copies.
  for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
    if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
      device_contexts_[0]->device_res_manager_->FreeMemory(copy_input_device_tensor.get());
    }
  }
}

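// Build one kernel actor for every real kernel in the graph's execution order (used by the
// kernel-by-kernel execution path), attach somas information, and initialize the actors.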
void SuperKernelActor::BuildKernelActors() {
  MS_EXCEPTION_IF_NULL(graph_);
  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  kernel_actors_.resize(kernel_num);

  mindspore::HashMap<AnfNodePtr, KernelActor *> node_to_kernel_actor_;

  // 1. Create the kernel actors if needed.
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (IsSkippedKernelActor(kernel)) {
      kernel_actors_[i] = nullptr;
      continue;
    }

    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline)) {
      MS_LOG(WARNING) << "Found a non-real cnode in the execution order for graph: " << graph_->graph_id();
      kernel_actors_[i] = nullptr;
      continue;
    }

    auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
    auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph_);
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_contexts_[0]);
    MS_EXCEPTION_IF_NULL(real_device_context);
    if (IsRpcActor(kernel)) {
      MS_LOG(EXCEPTION) << "Can not launch a sub graph which contains rpc kernel by kbk.";
    } else if (IsInnerControlFlowActor(kernel)) {
      MS_LOG(EXCEPTION) << "Can not launch a sub graph which contains ConditionSwitch or ConditionGather by kbk.";
    }

    KernelActorPtr kernel_actor = std::make_shared<KernelActor>(
      kernel->fullname_with_scope(), kernel, real_device_context, memory_manager_aid_, debug_aid_, recorder_aid_,
      GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(kernel_actor);
    kernel_actors_[i] = kernel_actor;

    // Set the members of the kernel actor.
    kernel_actor->is_launch_skipped_ =
      common::AnfAlgo::IsNopNode(kernel) && graph_->IsInRefOutputMap(std::make_pair(kernel, 0));
    kernel_actor->inputs_continuous_memory_ =
      (common::AnfAlgo::IsCommunicationOp(kernel) && common::AnfAlgo::GetCNodeName(kernel) != kMatMulAllReduceOpName) &&
      (common::AnfAlgo::GetInputTensorNum(kernel) > 1);

    SchedulerHelper::AddSomasInfo(kernel_actor.get());

    node_to_kernel_actor_[kernel] = kernel_actor.get();
  }

  // 2. Add somas info for the graph outputs.
  for (const auto &front_backend_pair : graph_->front_node_to_graph_output_map()) {
    const auto &output_with_index = front_backend_pair.second;
    auto output_kernel = output_with_index.first;
    auto output_index = output_with_index.second;
    MS_EXCEPTION_IF_NULL(output_kernel);
    auto origin_output_with_index = front_backend_pair.first;
    if (origin_output_with_index.first == nullptr) {
      MS_LOG(WARNING) << "The graph " << graph_->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                      << " with index: " << output_index << " has no front node.";
      continue;
    }
    if (!AnfUtils::IsRealCNodeKernel(output_kernel)) {
      continue;
    }
    auto iter = node_to_kernel_actor_.find(output_kernel);
    if (iter == node_to_kernel_actor_.end()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_kernel)
        << "Can not find kernel actor for node: " << output_kernel->fullname_with_scope();
    }
    const auto &output_actor = iter->second;
    MS_EXCEPTION_IF_NULL(output_actor);
    output_actor->is_output_kernel_ = true;
    SchedulerHelper::AddSomasInfoForGraphOutput(output_actor, output_kernel, output_index, graph_->graph_id());
  }

  // 3. Initialize all kernel actors.
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel_actor = kernel_actors_[i];
    if (kernel_actor) {
      kernel_actor->Init();
    }
  }
}

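// For each kernel, map its input slots either to graph input parameter indices (filled at run time)
// or, for value nodes, directly to the persistent device tensors from the device tensor store.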
void SuperKernelActor::ParseInputIndex() {
  const auto &input_nodes = graph_->input_nodes();
  size_t input_num = input_nodes.size();
  mindspore::HashMap<AnfNode *, size_t> node_to_input_idx;
  node_to_input_idx.reserve(input_num);

  for (size_t i = 0; i < input_num; i++) {
    node_to_input_idx[input_nodes[i].get()] = i;
  }

  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);

    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline) || IsSkippedKernelActor(kernel)) {
      continue;
    }

    auto real_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < real_input_num; j++) {
      auto real_input_node = common::AnfAlgo::GetPrevNodeOutput(kernel, j, false);
      MS_EXCEPTION_IF_NULL(real_input_node.first);
      // Note: only graph input data is recorded here; persistent weights are handled in the compile phase.
      if (real_input_node.first->isa<Parameter>()) {
        auto iter = node_to_input_idx.find(real_input_node.first.get());
        if (iter == node_to_input_idx.end()) {
          MS_LOG_WITH_NODE(EXCEPTION, real_input_node.first)
            << "Can not find index for input node: " << real_input_node.first->fullname_with_scope();
        }
        kernel_input_to_graph_input_indices_[kernel.get()].emplace_back(j, iter->second);
      } else if (real_input_node.first->isa<ValueNode>()) {
        const auto &kernel_actor = kernel_actors_[i];
        MS_EXCEPTION_IF_NULL(kernel_actor);

        const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_contexts_[0]);
        MS_EXCEPTION_IF_NULL(real_device_context);
        const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(real_input_node.first, *graph_);
        MS_EXCEPTION_IF_NULL(front_node);
        auto device_address =
          DeviceTensorStore::GetInstance().Fetch(front_node.get(), real_device_context->GetDeviceType());
        MS_EXCEPTION_IF_NULL(device_address);
        kernel_actor->SetInputDeviceTensor(device_address.get(), j);
      }
    }
  }
}

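// Set up ref counts for every kernel input: regular CNode inputs go through
// UpdateRefCountWithOnlyDependShape, while persistent device tensors (weights and value nodes) are
// updated separately.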
void SuperKernelActor::CalcRefCount() {
  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline) || IsSkippedKernelActor(kernel)) {
      continue;
    }

    auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; j++) {
      auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, j, false);
      MS_EXCEPTION_IF_NULL(input_node_with_idx.first);

      if (input_node_with_idx.first->isa<CNode>()) {
        if (IsSkippedKernelActor(input_node_with_idx.first)) {
          const auto &real_input_node_with_idx =
            common::AnfAlgo::GetPrevNodeOutput(input_node_with_idx.first, 0, false);
          UpdateRefCountWithOnlyDependShape(kernel, j, real_input_node_with_idx.first, real_input_node_with_idx.second);
        } else {
          UpdateRefCountWithOnlyDependShape(kernel, j, input_node_with_idx.first, input_node_with_idx.second);
        }
      } else if (IsPersistentDeviceTensor(input_node_with_idx.first)) {
        UpdateRefCount(input_node_with_idx.first, input_node_with_idx.second, true);
      }
    }
  }
}
}  // namespace runtime
}  // namespace mindspore