/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <set>
#include <algorithm>
#include "include/backend/mem_reuse/mem_tracker.h"
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
#include "runtime/graph_scheduler/scheduler_helper.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/graph_scheduler/actor/memory_manager_actor.h"
#include "runtime/graph_scheduler/actor/debug_actor.h"
#include "mindrt/include/async/async.h"
#include "utils/phase.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace runtime {
namespace {
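// Sync the size and shape recorded on the graph parameter's device tensor with the actual input device tensor.
// For the super kernel actor (graph sink), only parameters with dynamic shape need this update.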
inline void UpdateShape(const AnfNodePtr &input_node, const DeviceTensorPtr &node_device_tensor,
                        DeviceTensor *input_device_tensor, const KernelTransformType &type) {
  MS_EXCEPTION_IF_NULL(input_node);
  const auto &node_device_kernel_tensor = node_device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(input_device_tensor);
  const auto &input_kernel_tensor = input_device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(node_device_kernel_tensor);
  MS_EXCEPTION_IF_NULL(input_kernel_tensor);
  if (type != KernelTransformType::kSuperKernelActor || input_node->cast<ParameterPtr>()->has_dynamic_shape()) {
    // For dynamic shape in sub-graph sink and any-type parameters, the input size should be updated.
    node_device_tensor->SetSize(input_device_tensor->GetSize());
    // Update the shape.
    node_device_kernel_tensor->SetShape(input_kernel_tensor->GetShape()->Clone());
  }
}

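// Return true when the input can be used directly without a device-to-device copy: the input and the parameter
// share the same device tensor or the same device pointer, the input is empty, or the parameter is marked as unused.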
inline bool InputDataNoNeedCopy(const AnfNodePtr &input_node, DeviceTensor *input_device_tensor,
                                const DeviceTensorPtr &node_device_tensor, const KernelTransformType &type) {
  if (input_device_tensor == node_device_tensor.get()) {
    (void)input_device_tensor->TouchSyncHandler();
    return true;
  }

  if (input_device_tensor == nullptr) {
    return true;
  }

  UpdateShape(input_node, node_device_tensor, input_device_tensor, type);

  if (TEST_FLAG(node_device_tensor->flag(), device::kDeviceAddressFlagNotUsed) ||
      input_device_tensor->GetPtr() == node_device_tensor->GetPtr()) {
    return true;
  }

  return false;
}

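// Increase the ref count of the input's device tensor unless the kernel only depends on the input's shape,
// in which case the device tensor is only flagged with kDeviceAddressFlagNullptr. When infer boost is enabled,
// the shape-only optimization is skipped and the ref count is always increased.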
void UpdateRefCountWithOnlyDependShape(const CNodePtr &kernel, size_t input_index, const AnfNodePtr &node,
                                       size_t output_index) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
  if (enable_infer_boost) {
    UpdateRefCount(node, output_index, false);
    return;
  }

  // A kernel that only depends on an input's shape should not increase that input's ref count.
  const auto &only_depend_shape_attr = common::AnfAlgo::GetCNodePrimitiveAttr(kernel, kAttrOnlyDependShape);
  if (only_depend_shape_attr != nullptr) {
    const auto &only_depend_shape = GetValue<std::vector<bool>>(only_depend_shape_attr);
    if (input_index < only_depend_shape.size() && only_depend_shape[input_index]) {
      // A shape-only dependency does not need a ref count increase; only update the flag.
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      device_tensor->UpdateFlag(device::kDeviceAddressFlagNullptr);
      return;
    }
  }
  UpdateRefCount(node, output_index, false);
}
}  // namespace
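// Initialize the super kernel actor: bind output data arrows to device addresses, collect the graph outputs that
// need memory allocation, record which parameters must be copied, and, for kernel-by-kernel execution, build the
// inner kernel actors and their ref counts.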
void SuperKernelActor::Init() {
  MS_EXCEPTION_IF_NULL(graph_);
  // Check device contexts number.
  if (device_contexts_.size() != device::kDeviceContextsNumOne) {
    MS_LOG(EXCEPTION) << "The number of device contexts is wrong.";
  }

  // Set the number of actor running dependent messages.
  running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);

  // Init the output data.
  InitOutputData();
  if (output_data_arrows_.size() != output_data_nodes_.size()) {
    MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the size of output data nodes.";
  }
  if (output_data_arrows_.size() != output_data_.size()) {
    MS_LOG(EXCEPTION) << "The size of output data arrows is not equal to the size of output data.";
  }
  for (size_t i = 0; i < output_data_arrows_.size(); ++i) {
    auto &data_arrow = output_data_arrows_[i];
    auto &output_node = output_data_nodes_[i];
    auto data = output_data_[i].first.get();
    MS_EXCEPTION_IF_NULL(data_arrow);
    MS_EXCEPTION_IF_NULL(output_node);
    MS_EXCEPTION_IF_NULL(data);
    auto device_address = AnfAlgo::GetMutableOutputAddr(output_node, IntToSize(data_arrow->from_output_index_), false);
    data->data_ = device_address.get();
  }

  const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph_->output());
  for (const auto &origin_output_with_index : output_with_indexs) {
    const auto &output_with_index = common::AnfAlgo::FetchRealNodeSkipMonadControl(origin_output_with_index);
    const auto &output_node = output_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (output_node->isa<CNode>() && (!HasAbstractMonad(output_node))) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(output_node, output_with_index.second, false);
      MS_EXCEPTION_IF_NULL(device_address);
      if (device_address->is_ptr_persisted() || graph_->is_dynamic_shape()) {
        MS_LOG(DEBUG) << "Actor:" << GetAID() << " skip alloc memory for device address:" << device_address
                      << " is persist:" << device_address->is_ptr_persisted()
                      << " is dynamic shape:" << graph_->is_dynamic_shape()
                      << " output node:" << output_node->DebugString();
        continue;
      }
      // Free the ptr in device address of output node.
      if (device_address->GetPtr() != nullptr) {
        MS_LOG(INFO) << "Output node:" << output_node->DebugString() << " has a default ptr, maybe a mem leak.";
        device_address->set_ptr(nullptr);
      }
      if (common::IsNeedProfileMemory()) {
        device_address_to_node_[device_address.get()] = {device_address->GetSize(), output_node->fullname_with_scope()};
      }
      memory_alloc_list_.emplace_back(device_address.get());
    }
  }

  // Check whether the parameter needs to be copied out.
  node_device_tensors_.resize(graph_->input_nodes().size());
  is_parameters_need_copy_.resize(graph_->input_nodes().size());
  copy_input_device_tensors_.resize(graph_->input_nodes().size());
  for (size_t i = 0; i < graph_->input_nodes().size(); ++i) {
    const auto &input_node = graph_->input_nodes()[i];
    MS_EXCEPTION_IF_NULL(input_node);
    node_device_tensors_[i] = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
    if (!common::AnfAlgo::HasAbstractRef(input_node)) {
      is_parameters_need_copy_[i] = false;
      continue;
    }
    // If the parameter has a ref attribute and is directly used by a kernel in the graph, it needs to be copied.
    is_parameters_need_copy_[i] = true;
  }

  if (enable_kbk_sub_graph_execute_) {
    BuildKernelActors();
    ParseInputIndex();
    CalcRefCount();
  }

  if (type_ == KernelTransformType::kSuperKernelActor && !enable_kbk_sub_graph_execute_) {
    MS_EXCEPTION_IF_NULL(device_contexts_[0]);
    MS_EXCEPTION_IF_NULL(device_contexts_[0]->graph_executor_);
    device_contexts_[0]->graph_executor_->InitGraphInfo(graph_);
  }
}

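// Return the position of the given node in the graph's input node list; raise an exception if it is not an input.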
size_t SuperKernelActor::FetchInputNodePosition(const AnfNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  MS_EXCEPTION_IF_NULL(graph_);

  auto &input_nodes = graph_->input_nodes();
  const auto &iter = find(input_nodes.begin(), input_nodes.end(), input_node);
  if (iter == input_nodes.end()) {
    MS_LOG_WITH_NODE(EXCEPTION, input_node) << "Invalid input node:" << input_node->fullname_with_scope();
  }
  return iter - input_nodes.begin();
}

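// Collect the device tensors delivered by the input data arrows for this step and remember which of them need to
// be freed after the graph finishes.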
void SuperKernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
  ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
  MS_EXCEPTION_IF_NULL(context);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  std::vector<DeviceTensor *> memory_free_list;
  const auto &data_iter = input_op_datas_.find(context->sequential_num_);
  if (data_iter != input_op_datas_.end()) {
    for (auto &input_data : data_iter->second) {
      MS_EXCEPTION_IF_NULL(input_data);
      MS_EXCEPTION_IF_NULL(input_data->data_);
      size_t index = IntToSize(input_data->index_);
      if (index >= input_device_tensors_.size()) {
        std::string error_info = "Invalid input index:" + std::to_string(index) +
                                 " total:" + std::to_string(input_device_tensors_.size()) +
                                 " for actor:" + GetAID().Name();
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
      input_device_tensors_[index] = input_data->data_;

      if (common::IsNeedProfileMemory()) {
        auto output_address = reinterpret_cast<std::uintptr_t>(input_device_tensors_[index]);
        MS_LOG(WARNING) << "Need Profile Memory, Memory use, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", device address class ptr: " << output_address
                        << ", device address size: " << input_device_tensors_[index]->GetSize()
                        << ", device address addr: " << input_device_tensors_[index]->GetPtr() << ", index: " << index;
      }

      if (input_data->data_->dynamic_ref_count() != INT32_MAX) {
        (void)memory_free_list.emplace_back(input_data->data_);
      }
    }
    memory_free_lists_.push(memory_free_list);
  }
}

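// Entry point for one step: either run the graph kernel by kernel, or fetch the inputs, allocate the graph output
// memory and launch the whole graph through the graph executor.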
void SuperKernelActor::Run(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), "SuperKernelActor", graph_->ToString());
  }

  if (enable_kbk_sub_graph_execute_) {
    return RunGraphKernelByKernel(context);
  }
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  MS_LOG(INFO) << "Super kernel actor(" << GetAID().Name()
               << ") launches graph: " << std::to_string(graph_->graph_id());
  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, launch actor name: " << GetAID().Name()
                    << ", kernel graph: " << graph_->ToString();
  }
  if (!WaitRuntimePipelineFinish(context)) {
    MS_LOG(INFO) << "Run failed and early stop.";
    return;
  }
  FetchInputDeviceTensor(context);
  if (!already_fetch_persistent_device_tensor_) {
    FetchPersistentDeviceTensor();
    already_fetch_persistent_device_tensor_ = IsTwoPhaseInfer();
  }

  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    for (auto &device_addr : input_device_tensors_) {
      if (device_addr == nullptr || !device_addr->IsPtrValid()) {
        continue;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(UseMemBlock, GetAID().Name(), device_addr->GetPtr());
    }
  }
  if (memory_alloc_list_.size() > 0) {
    for (auto &device_tensor : memory_alloc_list_) {
      if (device_tensor->IsNotNeedAlloc()) {
        continue;
      }
      if (common::IsNeedProfileMemory()) {
        MS_EXCEPTION_IF_NULL(device_tensor);
        auto &info = device_address_to_node_[device_tensor];
        auto output_address = reinterpret_cast<std::uintptr_t>(device_tensor);
        MS_LOG(WARNING) << "Need Profile Memory, Memory need allocated, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", node: " << info.node_full_name
                        << ", device address class ptr: " << output_address << ", device address size: " << info.size;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(
        AddMemInfo, GetAID().Name(), device::tracker::MemType::kGraphOutput, device_tensor->GetSize(), device_tensor);
    }
    SendMemoryAllocReq(context);
  } else {
    OnMemoryAllocFinish(context);
  }
  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, end launch, actor name: " << GetAID().Name()
                    << ", kernel graph: " << graph_->ToString();
  }
}

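// Fetch the device tensors of weights and other persistent values from the device tensor store into the input list.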
void SuperKernelActor::FetchPersistentDeviceTensor() {
  for (auto &device_tensor_store_key : device_tensor_store_keys_) {
    auto input_device_tensor = DeviceTensorStore::GetInstance()
                                 .Fetch(device_tensor_store_key.second.get(), device_contexts_[0]->GetDeviceType())
                                 .get();
    // The GE backend may return nullptr here.
    if (input_device_tensor == nullptr) {
      MS_LOG(DEBUG) << "Failed to get device tensor for node:" << device_tensor_store_key.second->DebugString()
                    << " index:" << device_tensor_store_key.first;
      continue;
    }

    size_t index = device_tensor_store_key.first;
    input_device_tensors_[index] = input_device_tensor;
  }
}

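// Prepare the memory trace manager for this step. On the recording step the previously traced device pointers are
// cleared and new kernel memory blocks are reserved; on the replay step the merged blocks are allocated as a whole.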
void SuperKernelActor::UpdateMemoryTraceMangerStatus(OpContext<DeviceTensor> *const context) {
  MemoryTraceManager::GetInstance().PickMemoryTrackInfoForGraph(graph_->graph_id());
  if (!ActorDispatcher::enable_static_shape()) {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryAlloc, GetAID().Name());

    const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>> &all_kernel_block_info =
      MemoryTraceManager::GetInstance().GetAllKernelBlocksnfo();
    MS_EXCEPTION_IF_NULL(all_kernel_block_info);

    if (!all_kernel_block_info->empty()) {
      size_t kernel_num = kernel_actors_.size();
      for (size_t i = 0; i < kernel_num; i++) {
        const auto &kernel_actor = kernel_actors_[i];
        if (kernel_actor == nullptr) {
          continue;
        }

        const auto &kernel = kernel_actor->kernel_;
        MS_EXCEPTION_IF_NULL(kernel);

        const auto &iter = all_kernel_block_info->find(kernel);
        if (iter == all_kernel_block_info->end()) {
          MS_LOG(DEBUG) << "No kernel block info found for kernel: " << kernel->fullname_with_scope()
                        << ", is output kernel: " << kernel_actor->is_output_kernel_;
        } else {
          const auto &kernel_mem_block = iter->second;
          for (auto &block : kernel_mem_block) {
            MS_EXCEPTION_IF_NULL(block);
            if (block->mem_type_ == kOutputMem) {
              kernel_actor->output_kernel_tensors_.at(block->index_)->set_device_ptr(nullptr);
            } else {
              kernel_actor->workspace_kernel_tensors_.at(block->index_)->set_device_ptr(nullptr);
            }
          }
        }
      }
    }

    // First step for dynamic shape: the memory trace needs to be recorded.
    MemoryTraceManager::GetInstance().Clear();
    static const size_t memory_block_size = 3000;
    MemoryTraceManager::GetInstance().ReserveKernelMemoryBlocks(memory_block_size, device_contexts_[0]);
  } else {
    // Not the first step for dynamic shape: reuse the recorded trace memory.
    // Allocate block memory for the static memory step.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryAlloc, GetAID().Name());
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    for (auto &item : *merge_blocks_with_device_context) {
      const auto &device_context = item.first;
      MS_EXCEPTION_IF_NULL(device_context);
      const auto &merge_blocks = item.second;
      for (auto &block : merge_blocks) {
        MS_EXCEPTION_IF_NULL(block);
        static const size_t kMemoryAlignSize = 1024;
        void *block_addr = device_context->device_res_manager_->AllocateMemory(block->size_ + kMemoryAlignSize);
        if (block_addr == nullptr) {
          SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(GraphExecutionStrategy::kPipeline, *context,
                                                      *(device_contexts_[0]), GetAID().Name(), block->size_);
        }
        block->start_ = reinterpret_cast<uint8_t *>(block_addr);
      }
    }
  }
}

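// Point the kernel's output and workspace kernel tensors at the offsets recorded for them inside the merged
// trace memory blocks, so no per-kernel allocation is needed on the static memory step.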
void SuperKernelActor::SetTraceMemoryForKernel(const KernelActorPtr &kernel_actor) {
  const auto &kernel = kernel_actor->kernel();
  MS_EXCEPTION_IF_NULL(kernel);

  // Assign trace memory for the static memory step.
  const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>> &all_kernel_block_info =
    MemoryTraceManager::GetInstance().GetAllKernelBlocksnfo();
  MS_EXCEPTION_IF_NULL(all_kernel_block_info);
  const auto &iter = all_kernel_block_info->find(kernel);
  if (iter == all_kernel_block_info->end()) {
    MS_LOG(DEBUG) << "No kernel block info found for kernel: " << kernel->fullname_with_scope()
                  << ", is output kernel: " << kernel_actor->is_output_kernel_;
  } else {
    const auto &kernel_mem_block = iter->second;
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    const auto &merge_blocks = merge_blocks_with_device_context->at(kernel_actor->device_contexts_[0]);
    for (auto &block : kernel_mem_block) {
      MS_EXCEPTION_IF_NULL(block);
      void *ptr = merge_blocks.at(block->in_memory_trace_block_index_)->start_ + block->offset_in_memory_trace_block_;
      MS_EXCEPTION_IF_NULL(ptr);
      if (block->mem_type_ == kOutputMem) {
        kernel_actor->output_kernel_tensors_.at(block->index_)->set_device_ptr(ptr);
      } else {
        kernel_actor->workspace_kernel_tensors_.at(block->index_)->set_device_ptr(ptr);
      }
    }
  }
}

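// Execute the sub graph kernel by kernel: fetch the inputs, allocate somas/trace memory, dispatch every kernel to
// the async infer/launch pipelines in execution order, then wait for the pipeline and release the step memory.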
void SuperKernelActor::RunGraphKernelByKernel(OpContext<DeviceTensor> *const context) {
  if (!ActorDispatcher::enable_async_launch_kernel()) {
    std::string error_info =
      "Runtime pipeline optimization is disabled, failed to execute graph kernel by kernel mode.";
    MS_LOG(ERROR) << "Run graph failed, graph id: " << std::to_string(graph_->graph_id()) << ". " << error_info;
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }
  if (!graph_->is_dynamic_shape()) {
    ActorDispatcher::set_enable_static_shape(false);
  }

  // 1. Fetch input data.
  FetchInputDeviceTensor(context);
  if (!already_fetch_persistent_device_tensor_) {
    FetchPersistentDeviceTensor();
    already_fetch_persistent_device_tensor_ = true;
  }

  // 2. Allocate somas memory for the graph.
  if ((somas_info_ != nullptr) && (somas_info_->whole_block_size_ != 0)) {
    MemoryManagerActor::GetInstance()->AllocateSomasMemory(somas_info_, device_contexts_[0], context, GetAID());
  }
  const auto &phase = PhaseManager::GetInstance().phase();
  bool is_increment_graph = (phase.find("increment") != std::string::npos);
  if (enable_trace_memory_ && graph_->is_dynamic_shape() && is_increment_graph) {
    MS_LOG(DEBUG) << "Enable trace memory for increment inference graph: " << graph_->graph_id()
                  << ", phase: " << phase;
    UpdateMemoryTraceMangerStatus(context);

    if (IsRunningFailed(context)) {
      // Memory allocation may have failed; early stop running the graph.
      MS_LOG(INFO) << "Run failed and early stop to run graph: " << graph_->graph_id();
      return;
    }
  }

  // 3. Launch all kernels.
  size_t kernel_num = kernel_actors_.size();
  const auto &execution_order = graph_->execution_order();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel_actor = kernel_actors_[i];
    if (kernel_actor == nullptr) {
      continue;
    }
    const auto &kernel = execution_order[i];
    // 3.1 Prepare input data for the kernel.
    const auto &iter = kernel_input_to_graph_input_indices_.find(kernel.get());
    if (iter != kernel_input_to_graph_input_indices_.end()) {
      std::vector<std::pair<size_t, size_t>> &input_to_graph_input_indices = iter->second;
      for (const auto &item : input_to_graph_input_indices) {
        kernel_actor->SetInputDeviceTensor(input_device_tensors_[item.second], item.first);
      }
    }

    // 3.2 Allocate somas memory for this kernel.
    kernel_actor->SetSomasMemory(context);

    if (ActorDispatcher::enable_use_trace_memory()) {
      SetTraceMemoryForKernel(kernel_actor);
    }

    // Asynchronously run infer or launch.
    if (ActorDispatcher::enable_runtime_multi_pipeline() && !ActorDispatcher::enable_static_shape()) {
      // If the kernel needs user data and is dynamic, it may need the input kernel's output user data to infer the
      // shape. This value-depend case can not be handled in the KernelTensor auto sync phase currently.
      if (kernel_actor->kernel_mod_->need_user_data() && kernel_actor->has_dynamic_) {
        MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
        if (!WaitRuntimePipelineFinish(context)) {
          MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_actor->kernel_->fullname_with_scope();
          return;
        }
        MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
      }

      // Push the run task to the pipeline.
      // Note: kernels with dynamic values or static shapes also need to push tasks into the infer actor to keep the
      // kernel execution order correct.
      Async(kernel_async_infer_aid_, &KernelAsyncInferActor::InferShape, context, kernel_actor.get());

      // A compute-depend kernel should wait for the output shape update after the kernel launch.
      if (kernel_actor->kernel_mod_->IsNeedUpdateOutputShapeAndSize()) {
        MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
        if (!WaitRuntimePipelineFinish(context)) {
          MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_actor->kernel_->fullname_with_scope();
          return;
        }
        MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_actor->kernel_->fullname_with_scope();
      }
    } else {
      Async(kernel_async_launch_aid_, &KernelAsyncLaunchActor::LaunchKernel, context, kernel_actor.get());
    }
  }

  WaitRuntimePipelineFinish(context);

  // 4. Free somas memory for the graph.
  if ((somas_info_ != nullptr) && (somas_info_->whole_block_size_ != 0)) {
    MemoryManagerActor::GetInstance()->FreeSomasMemory(somas_info_, device_contexts_[0], context, GetAID());
  }

  if (ActorDispatcher::enable_trace_dynamic_memory()) {
    // Record and analyse the memory trace of this step; it is used to optimize memory management performance.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryFree, GetAID().Name());
    MemoryTraceManager::GetInstance().MergeBlocks();
  }
  if (ActorDispatcher::enable_use_trace_memory()) {
    // Free block memory for the static memory step.
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kMemoryFree, GetAID().Name());
    const auto &merge_blocks_with_device_context = MemoryTraceManager::GetInstance().GetMergeBlocks();
    MS_EXCEPTION_IF_NULL(merge_blocks_with_device_context);
    for (auto &item : *merge_blocks_with_device_context) {
      const auto &device_context = item.first;
      MS_EXCEPTION_IF_NULL(device_context);
      const auto &merge_blocks = item.second;
      for (auto &block : merge_blocks) {
        MS_EXCEPTION_IF_NULL(block);
        device_context->device_res_manager_->FreeMemory(block->start_);
      }
    }
  }

  // Free input data.
  PostRun(context);
}

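// Request memory for the graph outputs from the memory manager actor, largest device tensor first; in synchronous
// allocation mode, continue with OnMemoryAllocFinish directly.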
void SuperKernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  sort(memory_alloc_list_.begin(), memory_alloc_list_.end(), [](const DeviceTensor *a, const DeviceTensor *b) {
    MS_EXCEPTION_IF_NULL(a);
    MS_EXCEPTION_IF_NULL(b);
    return a->GetSize() > b->GetSize();
  });
  if (ActorDispatcher::is_memory_allocation_sync()) {
    ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_,
                              device_contexts_[0], context, GetAID());
    OnMemoryAllocFinish(context);
  } else {
    ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, &memory_alloc_list_,
                          device_contexts_[0], context, GetAID());
  }
}

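// After output memory is ready: copy the graph inputs, launch the whole graph through the graph executor, copy
// ref-node results back to the original parameters, and hand over to the debug actor or finish the step.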
void SuperKernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);
  if (IsRunningFailed(context)) {
    MS_LOG(INFO) << "Running failed in actor:" << GetAID().Name();
    return;
  }
  {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, GetAID().Name());
    if (!CopyInputData(context, graph_)) {
      std::string error_info = "Copy the input data failed, graph id: " + std::to_string(graph_->graph_id());
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
    }
  }

  try {
    const std::vector<tensor::Tensor> inputs;
    std::vector<tensor::Tensor> outputs;
    const std::map<string, string> compile_options;
    if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                    "Invalid device context for super kernel actor:" + GetAID().Name());
    }
    MS_EXCEPTION_IF_NULL(device_contexts_[0]->graph_executor_);
    if (!IsSkippedLaunch(nullptr, graph_)) {
      ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kGraphLaunch, GetAID().Name());
      auto ret = device_contexts_[0]->graph_executor_->RunGraph(graph_, inputs, &outputs, compile_options);
      if (!ret) {
        std::string error_info = "Launch graph failed, graph id: " + std::to_string(graph_->graph_id());
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
      }
    } else if (common::IsNeedProfileMemory()) {
      auto memory_size = device_contexts_[0]->graph_executor_->GetGraphFeatureMemory(graph_);
      MS_LOG(WARNING) << "Need Profile Memory, graph: " << graph_->ToString() << ", feature memory: " << memory_size;
      MS_LOG(WARNING) << "Need Profile Memory, max used static memory: "
                      << device_contexts_[0]->device_res_manager_->GetMaxUsedMemorySize();
    }
  } catch (const std::exception &e) {
    MsException::Instance().SetException();
    std::string error_info = "Launch graph exception, graph id: " + std::to_string(graph_->graph_id());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
  }

  {
    ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPostLaunch, GetAID().Name());
    for (auto item : ref_node_addr_map_) {
      MS_EXCEPTION_IF_NULL(item.first);
      MS_EXCEPTION_IF_NULL(item.second);
      MS_LOG(INFO) << "Copy the input ref node back from address: " << item.first->GetPtr()
                   << " to address: " << item.second->GetPtr() << ".";
      if (!Copy(item.second, item.first)) {
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Copy data failed.");
      }
    }
    ref_node_addr_map_.clear();
  }

  // The debug actor is blocked; we must wait for its callback message before continuing.
  if (debug_aid_ != nullptr) {
    SendDebugReq(context);
    return;
  }
  PostRun(context);
}

void SuperKernelActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  running_dependent_msg_num_ = 1;
  if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  OnDebugFinish(context);
}

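// Handle one non-persisted graph input. Returns true when the caller can skip the copy: either the parameter can
// directly reuse the input pointer (same device type and equivalent format, zero copy), or allocating the copy
// buffer failed. Returns false when a real copy into copy_input_device_tensors_[i] is still required.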
bool SuperKernelActor::CopyInputDataPersistedHandle(const DeviceContext *device_context,
                                                    DeviceTensor *input_device_tensor,
                                                    const DeviceTensorPtr &node_device_tensor, size_t i) {
  if ((input_device_tensor->GetDeviceType() == node_device_tensor->GetDeviceType()) &&
      AnfAlgo::IsEquivalentFormat(input_device_tensor->format(), node_device_tensor->format())) {
    MS_LOG(DEBUG) << "No need to copy for device tensor:" << node_device_tensor << " ptr:" << node_device_tensor->GetPtr()
                  << " index:" << i << " for actor:" << GetAID();
    // Set the ptr from input_device_tensor and set mem pool false to avoid double memory management, supporting
    // zero copy.
    if (type_ != KernelTransformType::kSuperKernelActor) {
      node_device_tensor->set_ptr(input_device_tensor->GetMutablePtr());
    } else {
      node_device_tensor->set_ptr(input_device_tensor->GetValidPtr(input_device_tensor->stream_id()));
    }
    MS_LOG(DEBUG) << "Actor:" << GetAID() << " set need sync flag from:" << input_device_tensor
                  << " to:" << node_device_tensor
                  << " sync user data handler:" << node_device_tensor->need_sync_user_data();
    node_device_tensor->set_from_mem_pool(false);
    // The caller continues to the next input.
    return true;
  }
  if (device_context->GetDeviceType() != node_device_tensor->GetDeviceType()) {
    device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {node_device_tensor->device_name(), node_device_tensor->device_id()});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  }

  if (copy_input_device_tensors_[i] == nullptr) {
    MS_EXCEPTION_IF_NULL(node_device_tensor->kernel_tensor());
    const auto new_kernel_tensor = node_device_tensor->kernel_tensor()->CloneKernelTensor();
    MS_EXCEPTION_IF_NULL(new_kernel_tensor);
    new_kernel_tensor->set_device_name(device_context->device_context_key().device_name_);
    new_kernel_tensor->set_device_id(device_context->device_context_key().device_id_);
    new_kernel_tensor->set_device_ptr(nullptr);

    copy_input_device_tensors_[i] = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
    MS_LOG(DEBUG) << "Create new device tensor:" << copy_input_device_tensors_[i] << " index:" << i
                  << " for actor:" << GetAID();
  }
  auto copy_device_tensor = copy_input_device_tensors_[i];
  MS_EXCEPTION_IF_NULL(copy_device_tensor);
  copy_device_tensor->set_user_data(node_device_tensor->user_data());
  copy_device_tensor->set_need_sync_user_data(node_device_tensor->need_sync_user_data());
  if ((copy_device_tensor->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(copy_device_tensor.get()))) {
    MS_LOG(ERROR) << "Device(id:" << std::to_string(device_context->device_context_key().device_id_)
                  << ") memory isn't enough and alloc failed, kernel name: " << GetAID()
                  << ", alloc size: " + std::to_string(copy_device_tensor->GetSize()) << "B.";
    return true;
  }
  MS_LOG(DEBUG) << "Alloc memory for device tensor:" << copy_device_tensor << " ptr:" << copy_device_tensor->GetPtr()
                << " size:" << copy_device_tensor->GetSize() << " index:" << i << " for actor:" << GetAID();
  if (type_ != KernelTransformType::kSuperKernelActor) {
    node_device_tensor->set_ptr(copy_device_tensor->GetMutablePtr());
  } else {
    node_device_tensor->set_ptr(copy_device_tensor->GetValidPtr(copy_device_tensor->stream_id()));
  }
  node_device_tensor->set_from_mem_pool(false);
  return false;
}

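// Sync every graph input into the device tensor bound to the corresponding parameter, copying across devices when
// necessary, and remember ref parameters whose results must be copied back after the graph runs.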
bool SuperKernelActor::CopyInputData(const OpContext<DeviceTensor> *context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph);
  if (device_contexts_.empty() || device_contexts_[0] == nullptr ||
      device_contexts_[0]->device_res_manager_ == nullptr) {
    MS_LOG(ERROR) << "Invalid device context for actor:" << GetAID();
    return false;
  }
  auto device_context = device_contexts_[0];
  auto &input_nodes = graph->input_nodes();
  if (input_device_tensors_.size() != node_device_tensors_.size()) {
    MS_LOG(ERROR) << "The size of input_device_tensors_[" << input_device_tensors_.size()
                  << "] is not equal to the size of node_device_tensors_[" << node_device_tensors_.size() << "].";
    return false;
  }

  for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
    auto &node_device_tensor = node_device_tensors_[i];
    MS_EXCEPTION_IF_NULL(node_device_tensor);
    auto &input_device_tensor = input_device_tensors_[i];
    if (InputDataNoNeedCopy(input_nodes[i], input_device_tensor, node_device_tensor, type_)) {
      MS_LOG(DEBUG) << "Actor:" << GetAID() << " input device tensor " << i << ":" << input_device_tensor
                    << " no need copy.";
      continue;
    }
    MS_EXCEPTION_IF_NULL(input_nodes[i]);
    const auto &node_device_kernel_tensor = node_device_tensor->kernel_tensor();
    MS_EXCEPTION_IF_NULL(input_device_tensor);
    const auto &input_kernel_tensor = input_device_tensor->kernel_tensor();
    MS_EXCEPTION_IF_NULL(node_device_kernel_tensor);
    MS_EXCEPTION_IF_NULL(input_kernel_tensor);
    UpdateShape(input_nodes[i], node_device_tensor, input_device_tensor, type_);
    node_device_tensor->set_user_data(input_device_tensor->user_data());
    node_device_tensor->set_need_sync_user_data(input_device_tensor->need_sync_user_data());
    if (type_ != KernelTransformType::kSuperKernelActor) {
      node_device_kernel_tensor->SetValue(input_kernel_tensor->GetValueTrack());
    }

    // Copy.
    DeviceTensorPtr copy_device_tensor = nullptr;
    // If the input is not a persisted device address, a new device address needs to be created in a heterogeneous
    // scenario, and its ptr is set on the node device address to support zero copy of graph input nodes.
    if (!node_device_tensor->is_ptr_persisted()) {
      if (CopyInputDataPersistedHandle(device_context, input_device_tensor, node_device_tensor, i)) {
        continue;
      }
      copy_device_tensor = copy_input_device_tensors_[i];
    } else {
      if (node_device_tensor->GetPtr() == nullptr) {
        MS_LOG(INFO) << "The node device tensor, which is shared with another graph, has no device memory and will "
                        "skip copy for actor:"
                     << GetAID();
        continue;
      }
      copy_device_tensor = node_device_tensor;
    }
    MS_EXCEPTION_IF_NULL(copy_device_tensor);
    MS_LOG(INFO) << "The input data of node:" << input_nodes[i]->DebugString()
                 << " need copy from device address:" << input_device_tensor << " ptr:" << input_device_tensor->GetPtr()
                 << " size:" << input_device_tensor->GetSize() << ", type:" << input_device_tensor->GetDeviceType()
                 << " to device address:" << copy_device_tensor << " ptr:" << copy_device_tensor->GetPtr()
                 << " size:" << copy_device_tensor->GetSize() << ", type:" << copy_device_tensor->GetDeviceType()
                 << ", is ref node need copy back:" << is_parameters_need_copy_[i] << " for actor:" << GetAID();
    if (!Copy(copy_device_tensor.get(), input_device_tensor)) {
      MS_LOG(ERROR) << "Copy data failed for actor:" << GetAID() << " input index:" << i;
      continue;
    }

    if (is_parameters_need_copy_[i]) {
      ref_node_addr_map_[copy_device_tensor.get()] = input_device_tensor;
    }
  }
  return true;
}

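// Ask the memory manager actor to free the input device tensors recorded for this step, then release the
// temporary buffers used for input copies.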
void SuperKernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  MS_EXCEPTION_IF_NULL(graph_);

  if (device_contexts_.empty() || device_contexts_[0] == nullptr ||
      device_contexts_[0]->device_res_manager_ == nullptr) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(GraphExecutionStrategy::kPipeline, (*context),
                                                  "Invalid device context for super kernel actor:" + GetAID().Name());
  }
  if (memory_free_lists_.size() > 0 && memory_free_lists_.back().size() > 0) {
    if (common::IsNeedProfileMemory()) {
      for (auto data : memory_free_lists_.back()) {
        auto output_address = reinterpret_cast<std::uintptr_t>(data);
        MS_LOG(WARNING) << "Need Profile Memory, Memory need Decrease DynamicRefCount, actor name: " << GetAID().Name()
                        << ", kernel graph: " << graph_->ToString() << ", device address class ptr: " << output_address
                        << ", device address size: " << data->GetSize() << ", device address addr: " << data->GetPtr();
      }
    }

    if (ActorDispatcher::is_memory_free_sync()) {
      ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                                device_contexts_[0], context, GetAID());
    } else {
      ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::FreeMemory, &(memory_free_lists_.back()),
                            device_contexts_[0], context, GetAID());
    }
  }

  // Free the address that is the temp store for kernel input copy.
  for (auto &copy_input_device_tensor : copy_input_device_tensors_) {
    if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
      device_contexts_[0]->device_res_manager_->FreeMemory(copy_input_device_tensor.get());
    }
  }
}

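// Create a kernel actor for every real kernel in the execution order (skipping nop/skipped kernels), attach somas
// info for graph outputs, and initialize all created actors.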
void SuperKernelActor::BuildKernelActors() {
  MS_EXCEPTION_IF_NULL(graph_);
  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  kernel_actors_.resize(kernel_num);

  mindspore::HashMap<AnfNodePtr, KernelActor *> node_to_kernel_actor_;

  // 1. Create kernel actors if needed.
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (IsSkippedKernelActor(kernel)) {
      kernel_actors_[i] = nullptr;
      continue;
    }

    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline)) {
      MS_LOG(WARNING) << "Found a node that is not a real cnode in the execution order for graph: "
                      << graph_->graph_id();
      kernel_actors_[i] = nullptr;
      continue;
    }

    auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
    auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph_);
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_contexts_[0]);
    MS_EXCEPTION_IF_NULL(real_device_context);
    if (IsRpcActor(kernel)) {
      MS_LOG(EXCEPTION) << "Can not launch a sub graph which contains an rpc kernel by kbk.";
    } else if (IsInnerControlFlowActor(kernel)) {
      MS_LOG(EXCEPTION) << "Can not launch a sub graph which contains ConditionSwitch or ConditionGather by kbk.";
    }

    KernelActorPtr kernel_actor = std::make_shared<KernelActor>(
      kernel->fullname_with_scope(), kernel, real_device_context, memory_manager_aid_, debug_aid_, recorder_aid_,
      GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);
    MS_EXCEPTION_IF_NULL(kernel_actor);
    kernel_actors_[i] = kernel_actor;

    // Set the members of the kernel actor.
    kernel_actor->is_launch_skipped_ =
      common::AnfAlgo::IsNopNode(kernel) && graph_->IsInRefOutputMap(std::make_pair(kernel, 0));
    kernel_actor->inputs_continuous_memory_ =
      (common::AnfAlgo::IsCommunicationOp(kernel) && common::AnfAlgo::GetCNodeName(kernel) != kMatMulAllReduceOpName) &&
      (common::AnfAlgo::GetInputTensorNum(kernel) > 1);

    SchedulerHelper::AddSomasInfo(kernel_actor.get());

    node_to_kernel_actor_[kernel] = kernel_actor.get();
  }

  // 2. Add somas info for graph outputs.
  for (const auto &front_backend_pair : graph_->front_node_to_graph_output_map()) {
    const auto &output_with_index = front_backend_pair.second;
    auto output_kernel = output_with_index.first;
    auto output_index = output_with_index.second;
    MS_EXCEPTION_IF_NULL(output_kernel);
    auto origin_output_with_index = front_backend_pair.first;
    if (origin_output_with_index.first == nullptr) {
      MS_LOG(WARNING) << "The graph " << graph_->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                      << " with index: " << output_index << " has no front node.";
      continue;
    }
    if (!AnfUtils::IsRealCNodeKernel(output_kernel)) {
      continue;
    }
    auto iter = node_to_kernel_actor_.find(output_kernel);
    if (iter == node_to_kernel_actor_.end()) {
      MS_LOG_WITH_NODE(EXCEPTION, output_kernel)
        << "Can not find kernel actor for node: " << output_kernel->fullname_with_scope();
    }
    const auto &output_actor = iter->second;
    MS_EXCEPTION_IF_NULL(output_actor);
    output_actor->is_output_kernel_ = true;
    SchedulerHelper::AddSomasInfoForGraphOutput(output_actor, output_kernel, output_index, graph_->graph_id());
  }

  // 3. Initialize all kernel actors.
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel_actor = kernel_actors_[i];
    if (kernel_actor) {
      kernel_actor->Init();
    }
  }
}

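// For every kernel input that comes from a graph parameter, record the (kernel input index, graph input index)
// pair; for value node inputs, bind the device tensor from the device tensor store directly.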
void SuperKernelActor::ParseInputIndex() {
  const auto &input_nodes = graph_->input_nodes();
  size_t input_num = input_nodes.size();
  mindspore::HashMap<AnfNode *, size_t> node_to_input_idx;
  node_to_input_idx.reserve(input_num);

  for (size_t i = 0; i < input_num; i++) {
    node_to_input_idx[input_nodes[i].get()] = i;
  }

  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);

    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline) || IsSkippedKernelActor(kernel)) {
      continue;
    }

    auto real_input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < real_input_num; j++) {
      auto real_input_node = common::AnfAlgo::GetPrevNodeOutput(kernel, j, false);
      MS_EXCEPTION_IF_NULL(real_input_node.first);
      // Note: only record input data here; persistent weights are handled in the compile phase.
      if (real_input_node.first->isa<Parameter>()) {
        auto iter = node_to_input_idx.find(real_input_node.first.get());
        if (iter == node_to_input_idx.end()) {
          MS_LOG_WITH_NODE(EXCEPTION, real_input_node.first)
            << "Can not find index for input node: " << real_input_node.first->fullname_with_scope();
        }
        kernel_input_to_graph_input_indices_[kernel.get()].emplace_back(j, iter->second);
      } else if (real_input_node.first->isa<ValueNode>()) {
        const auto &kernel_actor = kernel_actors_[i];
        MS_EXCEPTION_IF_NULL(kernel_actor);

        const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_contexts_[0]);
        MS_EXCEPTION_IF_NULL(real_device_context);
        const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(real_input_node.first, *graph_);
        MS_EXCEPTION_IF_NULL(front_node);
        auto device_address =
          DeviceTensorStore::GetInstance().Fetch(front_node.get(), real_device_context->GetDeviceType());
        MS_EXCEPTION_IF_NULL(device_address);
        kernel_actor->SetInputDeviceTensor(device_address.get(), j);
      }
    }
  }
}

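// Walk the execution order and update the ref count of every kernel input: CNode inputs go through the
// shape-only-dependency check, while persistent inputs (weights and value nodes) get a max ref count.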
void SuperKernelActor::CalcRefCount() {
  const auto &execution_order = graph_->execution_order();
  size_t kernel_num = execution_order.size();
  for (size_t i = 0; i < kernel_num; i++) {
    const auto &kernel = execution_order[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (!IsKernelActor(kernel, GraphExecutionStrategy::kPipeline) || IsSkippedKernelActor(kernel)) {
      continue;
    }

    auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; j++) {
      auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(kernel, j, false);
      MS_EXCEPTION_IF_NULL(input_node_with_idx.first);

      if (input_node_with_idx.first->isa<CNode>()) {
        if (IsSkippedKernelActor(input_node_with_idx.first)) {
          const auto &real_input_node_with_idx =
            common::AnfAlgo::GetPrevNodeOutput(input_node_with_idx.first, 0, false);
          UpdateRefCountWithOnlyDependShape(kernel, j, real_input_node_with_idx.first, real_input_node_with_idx.second);
        } else {
          UpdateRefCountWithOnlyDependShape(kernel, j, input_node_with_idx.first, input_node_with_idx.second);
        }
      } else if (IsPersistentDeviceTensor(input_node_with_idx.first)) {
        UpdateRefCount(input_node_with_idx.first, input_node_with_idx.second, true);
      }
    }
  }
}
}  // namespace runtime
}  // namespace mindspore