1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/graph_scheduler/actor/kernel_actor.h"
18
19 #include <mutex>
20 #include <algorithm>
21
22 #include "runtime/device/multi_stream_controller.h"
23 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
24 #include "runtime/graph_scheduler/actor/output_actor.h"
25 #include "runtime/graph_scheduler/actor/recorder_actor.h"
26 #include "runtime/graph_scheduler/actor/debug_actor.h"
27 #include "mindrt/include/async/async.h"
28 #include "utils/log_adapter.h"
29 #include "include/backend/mem_reuse/mem_tracker.h"
30 #include "include/backend/distributed/recovery/recovery_context.h"
31 #include "include/backend/distributed/collective/collective_manager.h"
32 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
33 #include "kernel/framework_utils.h"
34 #include "mindspore/core/ops/framework_ops.h"
35 #include "utils/compile_config.h"
36
37 namespace mindspore {
38 namespace runtime {
39 namespace {
bool IsSomasEnable(const SomasInfo *somas_info) {
41 return ((somas_info != nullptr) && (somas_info->whole_block_size_ != 0));
42 }
43
void CheckDryRun(const CNodePtr &kernel_) {
45 static const bool is_dry_run_mode = (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileKernel);
46 static auto enabled_profile = common::GetCompileConfig("COMPILE_PROFILE") == "1";
47 if (is_dry_run_mode && !enabled_profile) {
    MS_LOG_WITH_NODE(EXCEPTION, kernel_)
      << "The dry run mode does not support a dynamic shape graph which contains the value depend kernel: "
      << kernel_->fullname_with_scope()
      << ". Kernel launch is skipped in dry run mode, which makes GetValue fail when inferring the shape of "
         "these value depend kernels. You can only simulate graph compilation without doing InferShape and "
         "Resize by `export MS_SIMULATION_LEVEL=0` instead.";
54 }
55 }
56 } // namespace
57
58 using distributed::collective::CollectiveManager;
59 using distributed::recovery::RecoveryContext;
60
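// Initialize the kernel actor: check the device context, cache the kernel mod / kernel info and the dynamic
// shape/type/value flags, build the input, output and workspace device tensor lists, bind the output data arrows
// to the output device tensors, and record the multi-stream related properties (CPU and GPU kernels skip the
// multi-stream processing).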
void KernelActor::Init() {
62 // Check device contexts number.
63 if (device_contexts_.size() != device::kDeviceContextsNumOne) {
64 MS_LOG(EXCEPTION) << "The device contexts number is wrong.";
65 }
66 MS_EXCEPTION_IF_NULL(device_contexts_[0]);
67
68 // Set the number of actor running dependent messages.
69 running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
70
71 MS_EXCEPTION_IF_NULL(kernel_);
72 real_input_num_ = common::AnfAlgo::GetInputTensorNum(kernel_);
73 kernel_info_ = dynamic_cast<KernelInfo *>(kernel_->kernel_info());
74 MS_EXCEPTION_IF_NULL(kernel_info_);
75 kernel_mod_ = kernel_info_->MutableKernelMod();
76 MS_EXCEPTION_IF_NULL(kernel_mod_);
77 is_dynamic_value_ = common::AnfAlgo::IsDynamicValue(kernel_);
78 if (is_dynamic_shape_ && IsSomasEnable(somas_info_)) {
    MS_LOG(EXCEPTION) << "Somas is not supported for dynamic shape: " << GetAID().Name();
80 }
81 is_dynamic_type_ = common::AnfAlgo::IsAnyTypeOutput(kernel_);
82 has_dynamic_ = is_dynamic_shape_ || is_dynamic_type_ || is_dynamic_value_;
83
84 if (is_dynamic_value_ && (is_dynamic_shape_ || is_dynamic_type_)) {
85 CheckDryRun(kernel_);
86 }
87
  // Check whether the kernel has an input node which is a computed depend kernel.
89 launch_ignored_inputs_ = kernel_mod_->GetLaunchIgnoredInputAddressIdx();
90
91 stream_ = device_contexts_[0]->device_res_manager_->GetStream(kernel_info_->stream_id());
92 // Init the device tensors and kernel launch info.
93 InitInputInfo();
94 InitOutputInfo();
95 InitWorkspaceInfo();
96
97 // Init the output data.
98 InitOutputData();
99 if (output_data_.size() != output_data_arrows_.size()) {
100 MS_LOG(EXCEPTION) << "The output data size is wrong: " << GetAID().Name();
101 }
102 size_t output_data_index = 0;
103 for (auto &data_arrow : output_data_arrows_) {
104 auto data = output_data_[output_data_index].first.get();
105 MS_EXCEPTION_IF_NULL(data);
106 MS_EXCEPTION_IF_NULL(data_arrow);
107 if (IntToSize(data_arrow->from_output_index_) >= output_device_tensors_.size()) {
108 MS_LOG(EXCEPTION) << "The output index is out of range: " << GetAID().Name();
109 }
110 data->data_ = output_device_tensors_[IntToSize(data_arrow->from_output_index_)];
111 ++output_data_index;
112 }
113
114 auto device_context = device_contexts_[0];
  // CPU kernels do not need multi-stream processing, and GPU kernels have not adapted to it currently.
116 if (device_context->GetDeviceType() == device::DeviceType::kCPU ||
117 device_context->GetDeviceType() == device::DeviceType::kGPU) {
118 MS_LOG(DEBUG) << "Kernel : " << kernel_->fullname_with_scope() << " device type is "
119 << device_context->GetDeviceType() << ", will skip multi stream process.";
120 is_multi_stream_process_skipped_ = true;
121 }
122
123 // Share pointer of task id on stream with output kernel tensor.
124 for (auto &output_kernel_tensor : output_kernel_tensors_) {
125 output_kernel_tensor->set_task_id_on_stream(task_id_on_stream_);
126 }
127 is_stream_recv_actor_ = IsPrimitiveCNode(kernel_, prim::kPrimStreamRecv);
128 // kernel_ may be ValueNode<FuncGraph>, skip exception situation.
129 auto cnode = kernel_->cast<CNodePtr>();
130 if (cnode == nullptr) {
131 return;
132 }
133
  // The shape depend info requires the kernel to be a CNode.
135 InitShapeDependInfo();
136
137 auto input0 = cnode->input(kAnfPrimitiveIndex);
138 if (IsValueNode<FuncGraph>(input0)) {
    MS_LOG(INFO) << "The first input of cnode is a func graph value node: " << kernel_->fullname_with_scope() << ".";
140 return;
141 }
142
143 auto multi_stream_safe_value = cnode->GetAttr(kAttrInputMultiStreamSafe);
144 if (multi_stream_safe_value != nullptr) {
145 is_multi_stream_safe_ = GetValue<bool>(multi_stream_safe_value);
146 MS_LOG(DEBUG) << "cnode : " << cnode->DebugString() << " is thread safe.";
147 }
148 }
149
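// Record the format, shape, size and type of every real input and pre-allocate the input device/kernel tensor
// vectors. When KBK subgraph execution is enabled, the input addresses are bound directly from the previous
// nodes' output addresses and only non-somas inputs are kept in the memory free list.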
void KernelActor::InitInputInfo() {
151 for (size_t i = 0; i < real_input_num_; ++i) {
152 const auto &input_device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_, i, false);
153 MS_EXCEPTION_IF_NULL(input_device_tensor);
154 (void)real_input_data_infos_.emplace_back(
155 std::make_shared<InputDataInfo>(input_device_tensor->format(), input_device_tensor->host_shape(),
156 input_device_tensor->GetSize(), input_device_tensor->type_id()));
157 }
158
159 copy_input_device_tensors_.resize(real_input_num_);
160 input_device_tensors_.resize(real_input_num_);
161 input_kernel_tensors_.resize(real_input_num_);
162 input_kernel_tensors_for_infer_.resize(real_input_num_);
163 for (auto &input_address : input_device_tensors_) {
164 (void)memory_free_list_.emplace_back(input_address);
165 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
166 (void)mem_info_.inputs_.emplace_back(std::make_shared<Address>());
167 }
168 }
169
170 if (EnableKbkSubGraphExecute()) {
171 memory_free_list_.clear();
172 for (size_t i = 0; i < real_input_num_; ++i) {
173 auto input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(kernel_, i, false);
174 MS_EXCEPTION_IF_NULL(input_node_with_idx.first);
175 if (!input_node_with_idx.first->isa<CNode>()) {
176 continue;
177 }
178
179 if (IsSkippedKernelActor(input_node_with_idx.first)) {
180 input_node_with_idx = common::AnfAlgo::GetPrevNodeOutput(input_node_with_idx.first, 0, false);
181 }
182
183 const auto &input_device_address =
184 AnfAlgo::GetMutableOutputAddr(input_node_with_idx.first, input_node_with_idx.second, false);
185 MS_EXCEPTION_IF_NULL(input_device_address);
186 input_device_tensors_[i] = input_device_address.get();
187 input_kernel_tensors_[i] = input_device_tensors_[i]->kernel_tensor().get();
188 input_kernel_tensors_for_infer_[i] = input_device_tensors_[i]->kernel_tensor();
189
190 if (!IsSomasEnable(somas_info_)) {
191 memory_free_list_.emplace_back(input_device_address.get());
192 }
193 }
194 }
195 }
196
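// Collect the output device/kernel tensors from the kernel info. Outputs taken over by somas are marked and skip
// memory allocation (graph outputs additionally keep their address via InsertGraphOutputInfo); the other outputs
// are appended to the memory alloc/free lists.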
void KernelActor::InitOutputInfo() {
198 MS_EXCEPTION_IF_NULL(kernel_info_);
199 const auto &output_addresses = kernel_info_->output_address_list();
200 const auto &somas_outputs = kernel_info_->somas_output_result();
201 bool output_need_somas = false;
202 for (size_t i = 0; i < output_addresses.size(); ++i) {
203 auto &output_address = output_addresses[i];
204 MS_EXCEPTION_IF_NULL(output_address);
205
206 if (output_address->stream_id() != kernel_info_->stream_id()) {
      MS_LOG(DEBUG) << "Output address : " << output_address << " stream id :" << output_address->stream_id()
                    << " is not equal to kernel info stream id : " << kernel_info_->stream_id() << ".";
209 }
210
211 (void)output_device_tensors_.emplace_back(output_address.get());
212 (void)output_kernel_tensors_.emplace_back(output_address->kernel_tensor().get());
213 MS_LOG(DEBUG) << "Init output[" << i << "] info for node:" << kernel_->fullname_with_scope()
214 << " addr:" << output_address << " type:" << output_address->type_id()
215 << ", kernel tensor addr:" << output_address->kernel_tensor().get()
216 << ", kernel tensor: " << output_address->kernel_tensor()->ToString();
217 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
218 (void)mem_info_.outputs_.emplace_back(std::make_shared<Address>());
219 }
    // The output taken over by somas does not need to allocate memory.
221 if (kernel_info_->IsTensorEnableSomas(somas_outputs, i)) {
222 output_address->kernel_tensor()->set_managed_by_somas(true);
223 MS_LOG(INFO) << "Device address : " << output_address << ", kernel tensor : " << output_address->kernel_tensor()
224 << " is managed by somas.";
225 // Somas outputs use the info of kernelMod, and output address use the info of device address.
226 if (somas_outputs[i].second < output_address->GetSize()) {
227 MS_LOG(INFO) << GetAID().Name() << " check somas size warning, output index:" << i
228 << " somas aligned size:" << somas_outputs[i].second
229 << " is smaller than address size:" << output_address->GetSize();
230 }
      // Keep the graph output address when the somas block memory is freed, so it can be reused via the ref count
      // in other graphs.
232 if (somas_graph_output_indexes_.count(i) > 0) {
233 MS_LOG(DEBUG) << "Somas keep output device address:" << output_address << " ptr:" << output_address->GetPtr();
234 (void)somas_info_->InsertGraphOutputInfo(output_address.get(), somas_outputs[i].first, somas_outputs[i].second);
235 } else {
236 UpdateRefCount(output_address.get(), true);
237 }
238 output_need_somas = true;
239 } else {
240 (void)memory_alloc_list_.emplace_back(output_address.get());
241 if (output_address->original_ref_count() == SIZE_MAX) {
242 max_ref_cnt_output_list_.emplace_back(output_address.get());
243 }
244 (void)memory_free_list_.emplace_back(output_address.get());
245 }
246 }
247
248 if (output_need_somas && (!IsSomasEnable(somas_info_))) {
249 MS_LOG(EXCEPTION) << "The somas is not enable for: " << GetAID().Name();
250 }
251
252 if (IsSomasEnable(somas_info_)) {
253 MS_EXCEPTION_IF_CHECK_FAIL((output_device_tensors_.size() >= somas_outputs.size()), "The output num is wrong.");
254 }
255
256 for (auto &external_reference_tensor : external_reference_tensors_) {
257 (void)memory_free_list_.emplace_back(external_reference_tensor);
258 }
259 }
260
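// Collect the workspace device/kernel tensors. Workspaces taken over by somas skip memory allocation; the others
// are appended to the end of the memory alloc/free lists so that FetchWorkspaceDeviceTensor can adjust them in
// the dynamic shape case.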
void KernelActor::InitWorkspaceInfo() {
262 MS_EXCEPTION_IF_NULL(kernel_info_);
  // The workspace size may be changed in dynamic shape, so put the workspace addresses at the end of
  // memory_alloc_list_ and memory_free_list_, for the dynamic shape handling in FetchWorkspaceDeviceTensor.
265 const auto &workspace_addresses = kernel_info_->workspace_address_list();
266 const auto &somas_workspace = kernel_info_->somas_workspace_result();
267 bool workspace_need_somas = false;
268 for (size_t i = 0; i < workspace_addresses.size(); ++i) {
269 auto &workspace_address = workspace_addresses[i];
270 MS_EXCEPTION_IF_NULL(workspace_address);
271 (void)workspace_device_tensors_.emplace_back(workspace_address.get());
272 (void)workspace_kernel_tensors_.emplace_back(workspace_address->kernel_tensor().get());
273 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
274 (void)mem_info_.workspaces_.emplace_back(std::make_shared<Address>());
275 }
276
    // The workspace taken over by somas does not need to allocate memory.
278 if (kernel_info_->IsTensorEnableSomas(somas_workspace, i)) {
279 if (somas_workspace[i].second < workspace_address->GetSize()) {
280 MS_LOG(INFO) << GetAID().Name() << " check somas size warning, workspace index:" << i
281 << " somas aligned size:" << somas_workspace[i].second
282 << " is smaller than address size:" << workspace_address->GetSize();
283 }
284 UpdateRefCount(workspace_address.get(), true);
285 workspace_need_somas = true;
286 } else {
287 (void)memory_alloc_list_.emplace_back(workspace_address.get());
288 (void)memory_free_list_.emplace_back(workspace_address.get());
289 }
290 }
291
292 if (workspace_need_somas && (!IsSomasEnable(somas_info_))) {
293 MS_LOG(EXCEPTION) << "The somas is not enable for: " << GetAID().Name();
294 }
295
296 if (IsSomasEnable(somas_info_)) {
    MS_EXCEPTION_IF_CHECK_FAIL((workspace_device_tensors_.size() >= somas_workspace.size()),
                               "The workspace num is wrong.");
299 }
300 }
301
void KernelActor::InitShapeDependInfo() {
303 auto ms_context = MsContext::GetInstance();
304 MS_EXCEPTION_IF_NULL(ms_context);
305 static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
306 if (enable_infer_boost) {
307 return;
308 }
  // Inputs which are only depended on for shape do not need to decrease the ref count.
310 const auto &only_depend_shape_attr = common::AnfAlgo::GetCNodePrimitiveAttr(kernel_, kAttrOnlyDependShape);
311 if (only_depend_shape_attr != nullptr) {
312 auto only_depend_shape = GetValue<std::vector<bool>>(only_depend_shape_attr);
313 MS_LOG(INFO) << "Init shape depend info, real_input_num_ : " << real_input_num_
314 << ", only_depend_shape size : " << only_depend_shape.size() << ".";
315 for (size_t i = 0; i < only_depend_shape.size(); i++) {
      // Shape-only depend input, no need to free this device tensor.
317 MS_LOG(INFO) << "only_shape_depend[" << i << "] : " << only_depend_shape[i] << ".";
318 depend_shape_input_list_.emplace_back(only_depend_shape[i]);
319 }
320 }
321 }
322
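// Main entry of the kernel actor: fetch the input device tensors, then either hand over to the runtime
// multi-pipeline / async-launch path, or (in the synchronous path) infer shape and resize for dynamic kernels,
// set the somas memory, allocate the remaining memory and launch the kernel.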
void KernelActor::Run(OpContext<DeviceTensor> *const context) {
324 try {
325 MS_EXCEPTION_IF_NULL(kernel_);
326 MS_EXCEPTION_IF_NULL(kernel_->func_graph());
327 if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
328 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, GetAID().Name(), kernel_->fullname_with_scope(),
329 kernel_->func_graph()->ToString());
330 }
331 FetchInputDeviceTensor(context);
332
333 if (ActorDispatcher::enable_runtime_multi_pipeline()) {
334 RunWithMultiPipeline(context);
335 return;
336 }
337
338 device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread(false);
339 if (has_dynamic_) {
340 // Infer shape and resize for dynamic shape case.
341 InferAndResize(context);
342 FetchOutputDeviceTensor(context);
343 FetchWorkspaceDeviceTensor();
344 } else {
345 FetchOutputDeviceTensor(context);
346 }
347
348 // Set the memory address for the tensors which use the somas.
349 SetSomasMemory(context);
350
351 if (ActorDispatcher::enable_async_launch_kernel()) {
352 RunWithAsyncLaunchKernel(context);
353 return;
354 }
355
356 if (!memory_alloc_list_.empty()) {
357 // Allocate the memory address for other tensors which don't use the somas.
358 SendMemoryAllocReq(context);
359 }
360 OnMemoryAllocFinish(context);
361 } catch (const std::exception &e) {
362 MsException::Instance().SetException();
363 std::string error_info =
364 "#umsg#Kernel error:#umsg#run kernel[" + kernel_->fullname_with_scope() + "] failed, exception: " + e.what();
365 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
366 }
367 }
368
void KernelActor::RunWithMultiPipeline(OpContext<DeviceTensor> *const context) {
370 // 1. Set the memory address for the tensors which use the somas if need.
371 SetSomasMemory(context);
372
  // If the kernel needs user data and is dynamic, it may need the input kernel's output user data to infer the
  // shape. This value depend case cannot be handled in the KernelTensor auto sync phase currently.
375 if (kernel_mod_->need_user_data() && has_dynamic_) {
376 MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_->fullname_with_scope();
377 if (!WaitRuntimePipelineFinish(context)) {
378 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
379 return;
380 }
381 MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_->fullname_with_scope();
382 }
383
384 // 2. Push run task to pipeline.
  // Note: dynamic value or static shape kernels also need to push a task into the infer actor to keep the correct
  // kernel execution order.
387 Async(kernel_async_infer_aid_, &KernelAsyncInferActor::InferShape, context, this);
388
389 // The computed depend kernel should wait output shape update after kernel launch.
390 if (kernel_mod_->IsNeedUpdateOutputShapeAndSize()) {
391 MS_LOG(DEBUG) << "Begin wait runtime pipeline for kernel: " << kernel_->fullname_with_scope();
392 if (!WaitRuntimePipelineFinish(context)) {
393 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
394 return;
395 }
396 MS_LOG(DEBUG) << "End wait runtime pipeline for kernel: " << kernel_->fullname_with_scope();
397 }
398
399 // 3. Post run.
400 EraseInput(context);
401 SendOutput(context);
402 }
403
void KernelActor::RunWithAsyncLaunchKernel(OpContext<DeviceTensor> *const context) {
405 Async(kernel_async_launch_aid_, &KernelAsyncLaunchActor::LaunchKernel, context, this);
406
407 if (IsRunningFailed(context)) {
408 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
409 return;
410 }
411
412 // PostLaunchKernel
413 EraseInput(context);
414 SendOutput(context);
415 }
416
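// After Resize, the workspace sizes reported by the kernel mod may differ from the cached workspace device
// tensors: shrink or extend workspace_device_tensors_ (and the related alloc/free lists) accordingly, then
// refresh every workspace size and kernel tensor pointer.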
void KernelActor::FetchWorkspaceDeviceTensor() {
418 auto workspace_sizes = kernel_mod_->GetWorkspaceSizeList();
  // Resize workspace_device_tensors_, memory_alloc_list_ and memory_free_list_ because the size of the workspace
  // is dynamic.
421 if (workspace_device_tensors_.size() > workspace_sizes.size()) {
422 size_t size = workspace_device_tensors_.size() - workspace_sizes.size();
423 (void)workspace_device_tensors_.erase(workspace_device_tensors_.end() - size, workspace_device_tensors_.end());
424 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
425 (void)mem_info_.workspaces_.erase(mem_info_.workspaces_.end() - size, mem_info_.workspaces_.end());
426 }
427
428 MS_EXCEPTION_IF_CHECK_FAIL((memory_alloc_list_.size() >= size), "The memory alloc list size is wrong.");
429 MS_EXCEPTION_IF_CHECK_FAIL((memory_free_list_.size() >= size), "The memory free list size is wrong.");
430 (void)memory_alloc_list_.erase(memory_alloc_list_.end() - size, memory_alloc_list_.end());
431 (void)memory_free_list_.erase(memory_free_list_.end() - size, memory_free_list_.end());
432 } else if (workspace_device_tensors_.size() < workspace_sizes.size()) {
433 if (device_contexts_.empty() || device_contexts_[0] == nullptr) {
434 MS_LOG(ERROR) << "Invalid device context for kernel actor:" + GetAID().Name();
435 return;
436 }
437 for (size_t i = workspace_device_tensors_.size(); i < workspace_sizes.size(); ++i) {
438 auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
439 nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
440 device_contexts_[0]->device_context_key().device_name_, device_contexts_[0]->device_context_key().device_id_);
441 kernel_tensor->set_stream_id(kernel_info_->stream_id());
442 auto device_address = device_contexts_[0]->device_res_manager_->CreateDeviceAddress(kernel_tensor);
443 MS_EXCEPTION_IF_NULL(device_address);
444 MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel_)
445 << " addr:" << device_address;
446 AnfAlgo::SetWorkspaceAddr(device_address, i, kernel_.get()); // set to kernel_info
447 (void)workspace_device_tensors_.emplace_back(device_address.get());
448 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
449 (void)mem_info_.workspaces_.emplace_back(std::make_shared<Address>());
450 }
451 (void)memory_alloc_list_.emplace_back(device_address.get());
452 (void)memory_free_list_.emplace_back(device_address.get());
453 }
454 }
455 // Set workspace address new size
456 for (size_t i = 0; i < workspace_sizes.size(); ++i) {
457 workspace_device_tensors_[i]->SetSize(workspace_sizes[i]);
458 }
459
460 // Update workspace kernel tensors.
461 workspace_kernel_tensors_.resize(workspace_device_tensors_.size());
462 for (size_t i = 0; i < workspace_sizes.size(); ++i) {
463 workspace_kernel_tensors_[i] = workspace_device_tensors_[i]->kernel_tensor().get();
464 }
465 }
466
void KernelActor::SetSomasMemory(OpContext<DeviceTensor> *const context) const {
468 if (!IsSomasEnable(somas_info_)) {
469 return;
470 }
471
472 // Set the memory address for the output tensors which use the somas.
473 const auto &somas_outputs = kernel_info_->somas_output_result();
474 for (size_t i = 0; i < somas_outputs.size(); ++i) {
475 if (somas_outputs[i].second > 0) {
476 auto device_ptr = GetSomasDevicePtr(somas_outputs[i].first);
      // In this scenario, the Init function can ensure that the relevant pointers are not nullptr.
      // For performance, the pointer validity is not checked here.
      // Check whether the graph output address needs to be freed.
480 if (somas_graph_output_indexes_.count(i) && (output_device_tensors_[i]->GetPtr() != nullptr)) {
481 MS_LOG(ERROR) << GetAID().Name() << " does not free address for graph output index: " << i;
482 device_contexts_[0]->device_res_manager_->FreeMemory(output_device_tensors_[i]);
483 }
484 MS_LOG(DEBUG) << "Set ptr:" << device_ptr << " to device address:" << output_device_tensors_[i]
485 << " in actor:" << GetAID();
486 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(),
487 device::tracker::MemType::kInSideSomas,
488 output_device_tensors_[i]->GetSize(), output_device_tensors_[i]);
489 output_device_tensors_[i]->set_ptr(device_ptr);
490 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(BindDevicePtr, output_device_tensors_[i], device_ptr);
491 }
492 }
493
494 // Set the memory address for the workspace tensors which use the somas.
495 const auto &somas_workspace = kernel_info_->somas_workspace_result();
496 for (size_t i = 0; i < somas_workspace.size(); ++i) {
497 if (somas_workspace[i].second > 0) {
498 auto device_ptr = GetSomasDevicePtr(somas_workspace[i].first);
      // In this scenario, the Init function can ensure that the relevant pointers are not nullptr.
      // For performance, the pointer validity is not checked here.
501 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(
502 AddMemInfo, GetAID().Name(), device::tracker::MemType::kInSideSomas, workspace_device_tensors_[i]->GetSize(),
503 workspace_device_tensors_[i]);
504 workspace_device_tensors_[i]->set_ptr(device_ptr);
505 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(BindDevicePtr, workspace_device_tensors_[i], device_ptr);
506 }
507 }
508 }
509
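// Translate a somas offset into a device pointer: either offset into the whole somas block, or look up the
// enclosing merged block via upper_bound and offset into it. E.g. with merged_base_addresses_ = {{0, p0},
// {1024, p1}}, an offset of 1536 resolves to p1 + 512.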
void *KernelActor::GetSomasDevicePtr(size_t offset) const {
511 // Get the ptr from the whole block.
512 if (somas_info_->base_address_ != nullptr) {
513 return AddressOffset(somas_info_->base_address_, offset);
514 }
515
516 // Get the ptr from the merged blocks.
517 auto iter = somas_info_->merged_base_addresses_.upper_bound(offset);
518 if (iter == somas_info_->merged_base_addresses_.begin()) {
519 MS_LOG(ERROR) << GetAID().Name() << " can't find the merged block for offset: " << offset;
520 return nullptr;
521 }
522 --iter;
523 size_t real_offset = offset - iter->first;
524 void *real_base_address = iter->second;
525 if (real_base_address == nullptr) {
526 MS_LOG(ERROR) << GetAID().Name() << " doesn't allocate the merged block base address for offset: " << iter->first;
527 return nullptr;
528 }
529 return AddressOffset(real_base_address, real_offset);
530 }
531
void KernelActor::TraceDynamicMemory() {
533 for (size_t i = 0; i < output_kernel_tensors_.size(); i++) {
534 if (output_device_tensors_[i]->original_ref_count() != SIZE_MAX) {
535 const auto &kernel_tensor = output_kernel_tensors_[i];
536 MemoryTraceManager::GetInstance().AddKernelMemoryTraceBlock(
537 std::make_shared<KernelMemoryTraceBlock>(kernel_, kernel_tensor->device_ptr(), kernel_tensor->size(),
538 kOutputMem, i),
539 device_contexts_[0]);
540 }
541 }
542
543 for (size_t i = 0; i < workspace_kernel_tensors_.size(); i++) {
544 const auto &kernel_tensor = workspace_kernel_tensors_[i];
545 MemoryTraceManager::GetInstance().AddKernelMemoryTraceBlock(
546 std::make_shared<KernelMemoryTraceBlock>(kernel_, kernel_tensor->device_ptr(), kernel_tensor->size(),
547 kWorkspaceMem, i),
548 device_contexts_[0]);
549 }
550 }
551
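// Request memory for memory_alloc_list_ from the memory manager actor. When a swap manager exists, a ref output
// whose ref input has been swapped out first reuses the input's valid pointer. The allocated blocks are recorded
// when dynamic memory tracing is enabled.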
void KernelActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
553 if (device_contexts_[0]->device_res_manager_->swap_manager() != nullptr) {
554 device_contexts_[0]->device_res_manager_->swap_manager()->SetSwappableBeforeMemAllocate(input_device_tensors_,
555 output_device_tensors_);
556 MS_EXCEPTION_IF_NULL(kernel_info_);
557 for (const auto &out_in : kernel_info_->out_in_ref_map()) {
558 MS_EXCEPTION_IF_NULL(input_device_tensors_[out_in.second]);
559 const auto &ptr = input_device_tensors_[out_in.second]->GetValidPtr(kDefaultStreamIndex);
560 if (ptr == nullptr || output_device_tensors_[out_in.first] == nullptr ||
561 output_device_tensors_[out_in.first]->GetPtr() != nullptr) {
562 continue;
563 }
      // The pointer in a DeviceAddress which is a ref output may not have been updated to be the same as the ref
      // input which has been swapped out.
566 MS_LOG(DEBUG) << "Set device ptr of " << out_in.first << "th ref output the same as input " << out_in.second
567 << ": " << ptr;
568 output_device_tensors_[out_in.first]->set_ptr(ptr);
569 }
570 }
571
572 MemoryManagerActor::GetInstance()->AllocateMemory(&memory_alloc_list_, device_contexts_[0], context, GetAID());
573
574 if (ActorDispatcher::enable_trace_dynamic_memory()) {
575 if (IsRunningFailed(context)) {
576 return;
577 }
578 TraceDynamicMemory();
579 }
580 }
581
void KernelActor::SendMemoryFreeReq(OpContext<DeviceTensor> *const context) {
583 if (device_contexts_[0]->device_res_manager_->swap_manager() != nullptr) {
584 device_contexts_[0]->device_res_manager_->swap_manager()->SetSwappableBeforeMemFree(
585 input_device_tensors_, output_device_tensors_, kernel_info_);
586 }
587 if (depend_shape_input_list_.empty()) {
588 MemoryManagerActor::GetInstance()->FreeMemory(&memory_free_list_, device_contexts_[0], context, GetAID());
589 } else {
590 MS_LOG(DEBUG) << "depend_shape_input_list size : " << depend_shape_input_list_.size() << ".";
591 std::vector<DeviceTensor *> free_list;
592 for (size_t i = 0; i < memory_free_list_.size(); i++) {
593 const auto device_tensor = memory_free_list_[i];
594 if (device_tensor->dynamic_ref_count() == INT32_MAX && device_tensor->ref_count() != SIZE_MAX &&
595 i < depend_shape_input_list_.size() && depend_shape_input_list_[i]) {
596 MS_LOG(DEBUG) << "Skip memory free for kernel actor : " << kernel_->fullname_with_scope() << " index : " << i
597 << ", device address : " << memory_free_list_[i] << ".";
598 continue;
599 }
600 free_list.emplace_back(memory_free_list_[i]);
601 }
602 MemoryManagerActor::GetInstance()->FreeMemory(&free_list, device_contexts_[0], context, GetAID());
603 }
604
605 // Free the address that is the temp store for kernel input copy.
606 for (auto ©_input_device_tensor : copy_input_device_tensors_) {
607 if ((copy_input_device_tensor != nullptr) && (copy_input_device_tensor->GetPtr() != nullptr)) {
608 device_contexts_[0]->device_res_manager_->FreeMemory(copy_input_device_tensor.get());
609 }
610 }
611 }
612
void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
614 if (IsRunningFailed(context)) {
615 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
616 return;
617 }
618 PreLaunchKernel(context);
619
620 if (debug_aid_ != nullptr) {
621 ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPreLaunch, kernel_, input_device_tensors_,
622 output_device_tensors_, device_contexts_[0], context, &GetAID());
623 }
624
625 bool skip_launch = CollectiveManager::instance()->need_reinit() || IsSkippedLaunch(kernel_, nullptr);
626 if (!LaunchKernel(context, skip_launch)) {
627 MS_LOG_WITH_NODE(EXCEPTION, kernel_) << "#umsg#Kernel error:#umsg#Launch kernel failed: " +
628 kernel_->fullname_with_scope()
629 << trace::DumpSourceLines(kernel_);
630 }
631
632 // Record mem info, because async send may free device info.
633 if (recorder_aid_ != nullptr || debug_aid_ != nullptr) {
634 SetMemInfoForDebugAndRdr();
635 }
636
  // The debug actor is blocking, so we must wait for the debug actor's callback message before continuing.
638 if (debug_aid_ != nullptr) {
639 ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPostLaunch, kernel_, input_device_tensors_,
640 output_device_tensors_, device_contexts_[0], context, &GetAID());
641 }
642
643 PostLaunchKernel(context);
644 }
645
void KernelActor::SetMemInfoForDebugAndRdr() {
647 for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
648 mem_info_.inputs_[i]->addr = input_device_tensors_[i]->GetMutablePtr();
649 mem_info_.inputs_[i]->size = input_device_tensors_[i]->GetSize();
650 }
651 for (size_t i = 0; i < output_device_tensors_.size(); ++i) {
652 mem_info_.outputs_[i]->addr = output_device_tensors_[i]->GetMutablePtr();
653 mem_info_.outputs_[i]->size = output_device_tensors_[i]->GetSize();
654 }
655 for (size_t i = 0; i < workspace_device_tensors_.size(); ++i) {
656 mem_info_.workspaces_[i]->addr = workspace_device_tensors_[i]->GetMutablePtr();
657 mem_info_.workspaces_[i]->size = workspace_device_tensors_[i]->GetSize();
658 }
659 }
660
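// If an input comes from a different device type or a non-equivalent format, copy it into a lazily created
// device tensor (copy_input_device_tensors_) on this actor's device and rebind the input device/kernel tensor.
// Launch-ignored inputs and shape-related skipped kernels return early without copying.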
void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data,
                                        OpContext<DeviceTensor> *const context) {
663 size_t input_data_index = IntToSize(input_data->index_);
  // The ignored input addresses are not used in the kernel launch and do not need to be copied.
665 if (!launch_ignored_inputs_.empty() && (std::find(launch_ignored_inputs_.begin(), launch_ignored_inputs_.end(),
666 input_data_index) != launch_ignored_inputs_.end())) {
667 MS_LOG(DEBUG) << GetAID().Name() << " ignore the input address for input index: " << input_data_index;
668 return;
669 }
670 if (skip_launch_shape_related_op_) {
671 return;
672 }
673 if (input_data_index >= real_input_data_infos_.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
675 }
676 auto &real_input_info = real_input_data_infos_[input_data_index];
677 if ((input_data->data_->GetDeviceType() == device_contexts_[0]->GetDeviceType()) &&
678 AnfAlgo::IsEquivalentFormat(input_data->data_->format(), real_input_info->format_)) {
679 return;
680 }
681
682 if (!WaitRuntimePipelineFinish(context)) {
683 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
684 return;
685 }
686 if (inputs_continuous_memory_) {
687 std::string error_info = GetAID().Name() + " inputs must be continuous memory and can't be copied for index " +
688 std::to_string(input_data_index);
689 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
690 }
691 if (input_data_index >= copy_input_device_tensors_.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
693 }
694 if (copy_input_device_tensors_[input_data_index] == nullptr) {
695 const auto &pre_kernel_tensor = AnfAlgo::GetPrevNodeOutputKernelTensor(kernel_, input_data_index);
696 MS_EXCEPTION_IF_NULL(pre_kernel_tensor);
697 auto new_kernel_tensor = std::make_shared<kernel::KernelTensor>(
698 pre_kernel_tensor->GetShape(), pre_kernel_tensor->GetType(), pre_kernel_tensor->GetValueTrack(), nullptr,
699 real_input_info->size_, real_input_info->format_, real_input_info->type_id_, real_input_info->shape_,
700 device_contexts_[0]->device_context_key().device_name_, device_contexts_[0]->device_context_key().device_id_);
701 MS_EXCEPTION_IF_NULL(new_kernel_tensor);
702 auto pre_stream_id = pre_kernel_tensor->stream_id();
703 if (pre_stream_id == UINT32_MAX) {
704 auto stream_id = kernel_info_->stream_id();
705 MS_LOG(DEBUG) << "Rewrite kernel tensor : " << new_kernel_tensor
706 << " stream id with kernel info stream id : " << stream_id << ".";
707 new_kernel_tensor->set_stream_id(stream_id);
708 } else {
709 MS_LOG(DEBUG) << "Rewrite kernel tensor : " << new_kernel_tensor
710 << " stream id with pre kernel tensor stream id : " << pre_stream_id << ".";
711 new_kernel_tensor->set_stream_id(pre_stream_id);
712 }
713
714 copy_input_device_tensors_[input_data_index] =
715 device_contexts_[0]->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
716 MS_EXCEPTION_IF_NULL(copy_input_device_tensors_[input_data_index]);
717 }
718 auto &new_device_tensor = copy_input_device_tensors_[input_data_index];
719 MS_EXCEPTION_IF_NULL(new_device_tensor);
720
721 MS_LOG(DEBUG) << "Prev stream id : " << input_device_tensors_[input_data_index]->stream_id()
722 << " new stream id : " << new_device_tensor->stream_id() << ".";
723 // Update the input device tensor.
724 input_device_tensors_[input_data_index] = new_device_tensor.get();
725 input_kernel_tensors_[input_data_index] = input_device_tensors_[input_data_index]->kernel_tensor().get();
726 if (is_dynamic_shape_) {
727 // Need update shape and size for dynamic shape case.
728 input_kernel_tensors_for_infer_[input_data_index] = input_device_tensors_[input_data_index]->kernel_tensor();
729 MS_EXCEPTION_IF_NULL(input_kernel_tensors_[input_data_index]);
730 MS_EXCEPTION_IF_NULL(input_data->data_->kernel_tensor());
731 MS_EXCEPTION_IF_NULL(input_data->data_->kernel_tensor()->GetShape());
732 input_kernel_tensors_[input_data_index]->SetShape(input_data->data_->kernel_tensor()->GetShape()->Clone());
733 input_kernel_tensors_[input_data_index]->set_size(input_data->data_->GetSize());
734 }
735
736 device::DynamicMemAllocatorDebugInfo::SetDebugInfo(GetAID().Name(), device::AllocatorType::kKernelOutput,
737 input_data_index);
738 if (new_device_tensor->GetPtr() == nullptr) {
739 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, GetAID().Name(), device::tracker::MemType::kOther,
740 new_device_tensor->GetSize(), new_device_tensor.get());
741 if (!device_contexts_[0]->device_res_manager_->AllocateMemory(new_device_tensor.get(), kDefaultStreamIndex)) {
742 SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *(device_contexts_[0]), GetAID().Name(),
743 new_device_tensor->GetSize());
744 }
745 }
746
747 MS_LOG(INFO) << GetAID().Name() << " the input position:" << input_data_index
748 << " copy from device address:" << input_data->data_ << " ptr:" << input_data->data_->GetPtr()
749 << ", type:" << input_data->data_->GetDeviceType() << ", format:" << input_data->data_->format()
750 << " to device address:" << new_device_tensor.get() << " ptr:" << new_device_tensor->GetPtr()
751 << ", type:" << new_device_tensor->GetDeviceType() << ", format:" << new_device_tensor->format();
752 // Copy from the real parameter to formal parameter and insert the device tensor copy store.
753 if (!Copy(new_device_tensor.get(), input_data->data_)) {
754 std::string error_info = "Copy device tensor failed: " + GetAID().Name();
755 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
756 }
757 if (modifiable_ref_input_indexes_.count(input_data->index_) > 0) {
758 DeviceTensorCopyStore::GetInstance().Insert(new_device_tensor.get(), input_data->data_);
759 }
760 }
761
void KernelActor::UpdateInputDeviceTensor(const OpData<DeviceTensor> *input_data,
                                          OpContext<DeviceTensor> *const context) {
764 size_t input_index = IntToSize(input_data->index_);
765 if (input_index >= input_device_tensors_.size()) {
766 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(
767 strategy_, (*context),
768 "The input index:" + std::to_string(input_index) + " is out of vector size:" +
769 std::to_string(input_device_tensors_.size()) + " for kernel:" + kernel_->fullname_with_scope());
770 }
771
772 // Update the input device tensor.
773 if (input_device_tensors_[input_index] != input_data->data_) {
774 input_device_tensors_[input_index] = input_data->data_;
775 memory_free_list_[input_index] = input_data->data_;
776 }
777
778 // Update the input kernel tensor.
779 const auto &kernel_tensor = input_device_tensors_[input_index]->kernel_tensor();
780 if (input_kernel_tensors_[input_index] != kernel_tensor.get()) {
781 input_kernel_tensors_[input_index] = kernel_tensor.get();
782 if (is_dynamic_shape_) {
783 input_kernel_tensors_for_infer_[input_index] = kernel_tensor;
784 }
785 }
786 }
787
void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) {
789 // Collect the inputs from input data.
790 const auto &data_iter = input_op_datas_.find(context->sequential_num_);
791 if (data_iter != input_op_datas_.end()) {
792 for (auto &input_data : data_iter->second) {
793 UpdateInputDeviceTensor(input_data, context);
794 CopyInputDeviceTensor(input_data, context);
795 }
796 }
797
798 // Collect the inputs from device tensor store.
799 FetchInputByTensorStore(&input_device_tensors_, &input_kernel_tensors_, &input_kernel_tensors_for_infer_,
800 &memory_free_list_, context);
801 }
802
void KernelActor::FetchOutputDeviceTensor(OpContext<DeviceTensor> *const context) {
804 auto &output_addresses = kernel_info_->output_address_list();
805 const auto &output_size_list = kernel_mod_->GetOutputSizeList();
806
  // This may happen for kernels which do not support dynamic shape.
808 if (output_addresses.size() != output_size_list.size()) {
809 std::string error_info = "The outputs number(" + std::to_string(output_size_list.size()) + ") is wrong, " +
810 GetAID().Name() + " may not support the dynamic shape, please check.";
811 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
812 }
813
814 // Update the size of output device tensor.
815 for (size_t i = 0; i < output_addresses.size(); ++i) {
816 if (output_size_list[i] == output_addresses[i]->GetSize()) {
817 continue;
818 }
819 output_addresses[i]->SetSize(output_size_list[i]);
820 }
821 }
822
void KernelActor::PreLaunchKernel(OpContext<DeviceTensor> *) {
824 for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
825 if (!input_device_tensors_[i]->GetValidPtr(kernel_info_->stream_id())) {
826 MS_LOG(DEBUG) << "For kernel: " << kernel_->fullname_with_scope() << ", input device tensor "
827 << input_device_tensors_[i] << " has no device ptr.";
828 }
829 }
830
831 for (size_t i = 0; i < output_device_tensors_.size(); ++i) {
832 if (!output_device_tensors_[i]->GetValidPtr(kernel_info_->stream_id())) {
833 MS_LOG(DEBUG) << "For kernel: " << kernel_->fullname_with_scope() << ", output device tensor "
834 << output_device_tensors_[i] << " has no device ptr.";
835 }
836 }
837
838 for (size_t i = 0; i < workspace_device_tensors_.size(); ++i) {
839 if (!workspace_device_tensors_[i]->GetValidPtr(kernel_info_->stream_id())) {
840 MS_LOG(DEBUG) << "For kernel: " << kernel_->fullname_with_scope() << ", workspace device tensor "
841 << workspace_device_tensors_[i] << " has no device ptr.";
842 }
843 }
844 }
845
void KernelActor::ExecuteInferShapeTask(OpContext<DeviceTensor> *const context) {
847 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelInfer, GetAID().Name());
848 if (IsRunningFailed(context)) {
849 MS_LOG(DEBUG) << "Run failed and early stop infer shape for kernel: " << kernel_->fullname_with_scope();
850 return;
851 }
852
853 if (is_dynamic_type_) {
854 InferShapeAndType();
855 } else if (is_dynamic_shape_) {
856 InferShape();
857 }
858
859 Async(kernel_async_resize_aid_, &KernelAsyncResizeActor::ResizeKernelMod, context, this);
860 }
861
void KernelActor::ExecuteResizeKernelModTask(OpContext<DeviceTensor> *const context) {
863 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelResize, GetAID().Name());
864 if (IsRunningFailed(context)) {
865 MS_LOG(DEBUG) << "Run failed and early stop resize for kernel: " << kernel_->fullname_with_scope();
866 return;
867 }
868
869 if (has_dynamic_) {
870 device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread(false);
871 ResizeKernelMod();
872
873 FetchOutputDeviceTensor(context);
874 FetchWorkspaceDeviceTensor();
875 } else {
876 FetchOutputDeviceTensor(context);
877 }
878
879 Async(kernel_async_launch_aid_, &KernelAsyncLaunchActor::LaunchKernel, context, this);
880 }
881
void KernelActor::ExecuteLaunchKernelTask(OpContext<DeviceTensor> *const context) {
883 if (IsRunningFailed(context)) {
884 MS_LOG(DEBUG) << "Run failed and early stop launch kernel: " << kernel_->fullname_with_scope();
885 return;
886 }
887 // 1. Allocate memory.
888 if (!ActorDispatcher::enable_use_trace_memory()) {
889 if (!memory_alloc_list_.empty()) {
890 SendMemoryAllocReq(context);
891 }
892 } else if (!max_ref_cnt_output_list_.empty()) {
893 // Allocate dynamic memory for graph output.
894 MemoryManagerActor::GetInstance()->AllocateMemory(&max_ref_cnt_output_list_, device_contexts_[0], context,
895 GetAID());
896 }
897
898 if (IsRunningFailed(context)) {
899 MS_LOG(DEBUG) << "Run failed and early stop launch kernel: " << kernel_->fullname_with_scope();
900 return;
901 }
  // For performance, only kernels which need user data (such as the PyExecute op) need to call 'PreLaunchKernel';
  // 'PreLaunchKernel' will be removed in the future.
904 if (ActorDispatcher::has_kernel_need_user_data()) {
905 PreLaunchKernel(context);
906 }
907
908 // 2. Launch kernel if need.
909 device_contexts_[0]->device_res_manager_->BindDeviceToCurrentThread(false);
910
911 if (debug_aid_ != nullptr) {
912 ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPreLaunch, kernel_, input_device_tensors_,
913 output_device_tensors_, device_contexts_[0], context, &GetAID());
914 }
915
916 if (!LaunchKernel(context, IsSkippedLaunch(kernel_, nullptr))) {
917 MS_LOG_WITH_NODE(EXCEPTION, kernel_) << "#umsg#Kernel error:#umsg#Launch kernel failed: " +
918 kernel_->fullname_with_scope()
919 << trace::DumpSourceLines(kernel_);
920 }
921
922 if (debug_aid_ != nullptr || recorder_aid_ != nullptr) {
923 SetMemInfoForDebugAndRdr();
924
925 if (debug_aid_ != nullptr) {
926 ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugPostLaunch, kernel_, input_device_tensors_,
927 output_device_tensors_, device_contexts_[0], context, &GetAID());
928 }
929 if (recorder_aid_ != nullptr) {
930 ActorDispatcher::Send(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &mem_info_,
931 device_contexts_[0], context);
932 }
933 }
934
935 if (is_dynamic_shape_ && kernel_mod_->IsNeedUpdateOutputShapeAndSize()) {
936 kernel_mod_->UpdateOutputShapeAndSize(input_kernel_tensors_, output_kernel_tensors_);
937 }
938
939 if (kernel_mod_->need_user_data()) {
940 for_each(output_device_tensors_.begin(), output_device_tensors_.end(),
941 [](auto &device_tensor) { device_tensor->set_need_sync_user_data(true); });
942 }
943
944 if ((modifiable_ref_input_indexes_.size() != 0) || (modifiable_ref_output_indexes_.size() != 0)) {
945 RefreshDeviceTensorCopyStore(context);
946 }
947
948 // 3. Free memory.
949 if (!ActorDispatcher::enable_use_trace_memory()) {
950 if (memory_free_list_.size() > 0) {
951 SendMemoryFreeReq(context);
952 }
953 }
954 }
955
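// Synchronous infer/resize path: dynamic type kernels run InferShapeAndType + Resize, dynamic shape kernels run
// InferShape + Resize, and dynamic value kernels only need Resize. When async infer is enabled, only the pure
// dynamic value case is resized here.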
void KernelActor::InferAndResize(OpContext<DeviceTensor> *const context) {
957 if (!enable_async_infer_) {
    // If the kernel needs user data and is dynamic, it may need the input kernel's output user data to infer the
    // shape. This value depend case cannot be handled in the KernelTensor auto sync phase currently.
960 if (ActorDispatcher::enable_async_launch_kernel() && kernel_mod_->need_user_data() &&
961 !WaitRuntimePipelineFinish(context)) {
962 MS_LOG(INFO) << "Run failed and early stop for kernel: " << kernel_->fullname_with_scope();
963 return;
964 }
965
966 if (is_dynamic_type_) {
967 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelInferAndResize, GetAID().Name());
968 // For dynamic shape case, need Re-InferShape and Resize kernel mod.
969 InferShapeAndType();
970 ResizeKernelMod();
971 } else if (is_dynamic_shape_) {
972 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelInferAndResize, GetAID().Name());
973 // For dynamic shape case, need Re-InferShape and Resize kernel mod.
974 InferShape();
975 ResizeKernelMod();
976 } else if (is_dynamic_value_) {
977 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelResize, GetAID().Name());
978 ResizeKernelMod();
979 }
980
981 return;
982 }
983
984 if (is_dynamic_value_ && !is_dynamic_shape_ && !is_dynamic_type_) {
985 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kKernelResize, GetAID().Name());
986 ResizeKernelMod();
987 }
988 }
989
void KernelActor::InferShapeAndType() {
  MS_LOG(DEBUG) << "Begin InferShapeAndType for kernel: " << kernel_->fullname_with_scope()
                << ", inputs: " << input_kernel_tensors_for_infer_;
993 // 1. Infer operator's output's Shape and Type.
994 auto abstract = opt::dynamic_shape::InferShapeAndType(kernel_mod_->primitive(), input_kernel_tensors_for_infer_);
995 MS_EXCEPTION_IF_NULL(abstract);
  MS_LOG(DEBUG) << "End InferShapeAndType for kernel: " << kernel_->fullname_with_scope()
                << ", abstract: " << abstract->ToString();
998 // 2. Update shape of output kernel tensor.
999 opt::dynamic_shape::UpdateKernelTensorType(abstract->GetType(), output_kernel_tensors_);
1000 opt::dynamic_shape::UpdateKernelTensorShape(abstract->GetShape(), output_kernel_tensors_);
1001 }
1002
void KernelActor::InferShape() {
1004 MS_LOG(DEBUG) << "Begin InferShape for kernel: " << kernel_->fullname_with_scope()
1005 << ", inputs: " << input_kernel_tensors_for_infer_;
1006 // 1. Infer operator's output's Shape.
1007 auto base_shape = opt::dynamic_shape::InferShape(kernel_mod_->primitive(), input_kernel_tensors_for_infer_);
1008 MS_EXCEPTION_IF_NULL(base_shape);
1009 MS_LOG(DEBUG) << "End InferShape for kernel: " << kernel_->fullname_with_scope()
1010 << ", shape: " << base_shape->ToString();
1011
1012 // 2. Update shape of output kernel tensor.
1013 opt::dynamic_shape::UpdateKernelTensorShape(base_shape, output_kernel_tensors_);
1014 }
1015
void KernelActor::ResizeKernelMod() {
1017 MS_LOG(DEBUG) << "Begin Resize kernel mod for kernel: " << kernel_->fullname_with_scope();
1018 int ret = kernel_mod_->Resize(input_kernel_tensors_, output_kernel_tensors_);
1019 MS_LOG(DEBUG) << "End Resize kernel mod for kernel: " << kernel_->fullname_with_scope()
1020 << ", the output size list: " << kernel_mod_->GetOutputSizeList()
1021 << ", workspace size list: " << kernel_mod_->GetWorkspaceSizeList();
1022 if (ret != kernel::KRET_OK) {
1023 MS_LOG_WITH_NODE(EXCEPTION, kernel_) << "Resize failed for kernel: " << kernel_->fullname_with_scope();
1024 }
1025 }
1026 namespace {
void TrackInputMemory(const std::vector<DeviceTensor *> &input_device_tensors, const std::string &actor_name,
                      const std::vector<bool> &depend_shape_input_list) {
1029 for (size_t i = 0, end = input_device_tensors.size(); i < end; i++) {
1030 // Skip shape depend inputs.
1031 if (i < depend_shape_input_list.size() && depend_shape_input_list[i]) {
1032 continue;
1033 }
1034 auto device_addr = input_device_tensors[i];
1035 if (device_addr == nullptr || !device_addr->IsPtrValid()) {
1036 continue;
1037 }
1038 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(UseMemBlock, actor_name, device_addr->GetPtr());
1039 }
1040 }
1041 } // namespace
1042
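// Launch the kernel through the device context's kernel executor. The launch is skipped for simulation or for
// skipped actors (a skipped launch requires the first input and output to share the same address). With
// multi-stream enabled, the launch is wrapped by the record/wait event processing below; without async launch
// the per-stream mutex is held around the launch.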
bool KernelActor::LaunchKernel(OpContext<DeviceTensor> *const context, bool is_skip_launch) {
1044 if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
1045 TrackInputMemory(input_device_tensors_, GetAID().Name(), depend_shape_input_list_);
1046 }
1047 if (is_skip_launch) {
1048 return true;
1049 }
1050 if (skip_launch_shape_related_op_) {
1051 MS_LOG(DEBUG) << "Skip launch real make tuple kernel: " << kernel_->fullname_with_scope()
1052 << " input kernel tensor: " << input_kernel_tensors_;
1053 return true;
1054 }
1055 // Check the skipped launch condition.
1056 if (is_launch_skipped_) {
1057 MS_EXCEPTION_IF_CHECK_FAIL((input_device_tensors_.size() >= 1), "The inputs size is wrong.");
1058 MS_EXCEPTION_IF_CHECK_FAIL((output_device_tensors_.size() >= 1), "The outputs size is wrong.");
1059 MS_EXCEPTION_IF_NULL(input_device_tensors_[0]);
1060 MS_EXCEPTION_IF_NULL(output_device_tensors_[0]);
1061 if (input_device_tensors_[0]->GetPtr() == output_device_tensors_[0]->GetPtr()) {
1062 MS_LOG(DEBUG) << "Skipped launch kernel: " << kernel_->fullname_with_scope();
1063 return true;
1064 } else {
1065 MS_LOG(ERROR) << "Input address:" << input_device_tensors_[0]->GetPtr()
1066 << " and output address:" << output_device_tensors_[0]->GetPtr()
1067 << " are not equal of skipped launch actor: " << GetAID().Name();
1068 return false;
1069 }
1070 }
1071
  // CPU does not support the stream lock with LaunchKernel.
1073 if (!ActorDispatcher::enable_multi_stream() || is_multi_stream_process_skipped_) {
1074 MS_LOG(DEBUG) << "Begin launch kernel: " << kernel_->fullname_with_scope();
1075 auto ret = device_contexts_[0]->GetKernelExecutor(false)->LaunchKernel(
1076 kernel_, input_kernel_tensors_, workspace_kernel_tensors_, output_kernel_tensors_, kernel_mod_, stream_);
1077 MS_LOG(DEBUG) << "End launch kernel: " << kernel_->fullname_with_scope();
1078 return ret;
1079 }
1080
1081 auto multi_stream_controller = device::MultiStreamController::GetInstance();
1082 bool ret = false;
1083 if (!ActorDispatcher::enable_async_launch_kernel()) {
1084 std::lock_guard<std::mutex> lock(
1085 multi_stream_controller->GetStreamMutex(device_contexts_[0], kernel_info_->stream_id()));
1086 ProcessMultiStreamBeforeKernelLaunch(context);
1087 MS_LOG(DEBUG) << "Begin launch kernel: " << kernel_->fullname_with_scope();
1088 ret = device_contexts_[0]->GetKernelExecutor(false)->LaunchKernel(
1089 kernel_, input_kernel_tensors_, workspace_kernel_tensors_, output_kernel_tensors_, kernel_mod_, stream_);
1090 MS_LOG(DEBUG) << "End launch kernel: " << kernel_->fullname_with_scope();
1091 ProcessMultiStreamAfterKernelLaunch(context);
1092 } else {
1093 ProcessMultiStreamBeforeKernelLaunch(context);
1094 MS_LOG(DEBUG) << "Begin launch kernel: " << kernel_->fullname_with_scope();
1095 ret = device_contexts_[0]->GetKernelExecutor(false)->LaunchKernel(
1096 kernel_, input_kernel_tensors_, workspace_kernel_tensors_, output_kernel_tensors_, kernel_mod_, stream_);
1097 MS_LOG(DEBUG) << "End launch kernel: " << kernel_->fullname_with_scope();
1098 ProcessMultiStreamAfterKernelLaunch(context);
1099 }
1100 return ret;
1101 }
1102
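// Before the kernel launch: obtain this kernel's task id on stream, let a stream-recv actor wait for the matching
// send actor's event, and for non multi-stream-safe cross-stream inputs dispatch record/wait events so that
// inputs produced on other streams are synchronized. The cross-stream input addresses are collected for the
// record step after launch.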
void KernelActor::ProcessMultiStreamBeforeKernelLaunch(OpContext<DeviceTensor> *const context) {
1104 ProfilerRecorder profiler(ProfilerModule::kKernel, ProfilerEvent::kProcessMultiStream, GetAID().Name());
1105 auto device_context = device_contexts_[0];
1106 auto stream_id = kernel_info_->stream_id();
1107 // Update output_kernel_tensors_ with task id on stream.
1108 auto multi_stream_controller = device::MultiStreamController::GetInstance();
1109 auto task_id_on_stream = multi_stream_controller->LaunchTaskIdOnStream(device_context, stream_id);
1110 MS_LOG(DEBUG) << "device context : " << device_context
1111 << ", name : " << device_context->device_context_key().device_name_ << ", stream id : " << stream_id
1112 << ", actor name : " << GetAID().Name() << ", task_id_on_stream : " << task_id_on_stream << ".";
1113 if (INT64_MAX == task_id_on_stream) {
    // The task id on stream of a CPU kernel is meaningless.
1115 *task_id_on_stream_ = 0;
1116 MS_LOG(DEBUG) << "Skip ProcessMultiStreamBeforeKernelLaunch since kernel type is CPU.";
1117 return;
1118 }
1119 *task_id_on_stream_ = task_id_on_stream;
1120
1121 // Process wait stream.
1122 if (is_stream_recv_actor_) {
    // Note: the wait node starts to launch. The event was recorded on the send node, so we can release events on
    // the send node stream. Releasing events on the send node means that the memory stream id is the recv node
    // stream id and the user stream id is the send node stream id.
1126 auto user_stream_id = kernel_mod_->record_stream_id();
1127 auto memory_stream_id = stream_id;
1128 if (stream_send_actor_ == nullptr) {
      // GPU does not add the stream send/recv pair, so nullptr is a normal case.
1130 MS_LOG(DEBUG) << "Stream_send_actor_ is nullptr.";
1131 return;
1132 }
1133 MS_LOG(DEBUG) << "Process wait stream start, memory_stream_id : " << memory_stream_id
1134 << ", send task id on stream : " << *(stream_send_actor_->task_id_on_stream_) << ".";
    // Here, we need to get the task id on stream from the send node.
1136 (void)multi_stream_controller->WaitEvent(device_context, *(stream_send_actor_->task_id_on_stream_), user_stream_id,
1137 memory_stream_id);
1138 return;
1139 }
1140
1141 // Reset cross stream addresses.
1142 cross_stream_addresses_.clear();
1143
1144 // Process inputs.
1145 if (input_kernel_tensors_.empty()) {
1146 return;
1147 }
1148
1149 std::vector<KernelTensor *> cross_stream_kernel_tensors;
1150 for (const auto &input_kernel_tensor : input_kernel_tensors_) {
1151 if (input_kernel_tensor->stream_id() == stream_id) {
1152 continue;
1153 }
1154 if (input_kernel_tensor->task_id_on_stream() == nullptr) {
1155 MS_LOG(DEBUG) << "Input_kernel_tensor : " << input_kernel_tensor
1156 << " task id on stream is nullptr, will skip multi stream process.";
1157 continue;
1158 }
1159 if (input_kernel_tensor->managed_by_somas()) {
1160 MS_LOG(DEBUG) << "Input_kernel_tensor : " << input_kernel_tensor << " is managed by somas.";
1161 continue;
1162 }
    // A nullptr device ptr is a normal case; these inputs need to be skipped here.
1164 if (input_kernel_tensor->device_ptr() == nullptr) {
1165 MS_LOG(DEBUG) << "Input kernel tensor device ptr is nullptr.";
1166 continue;
1167 }
1168 (void)cross_stream_addresses_.emplace_back(input_kernel_tensor->stream_id(), input_kernel_tensor->device_ptr());
1169 if (!is_multi_stream_safe_) {
1170 (void)cross_stream_kernel_tensors.emplace_back(input_kernel_tensor);
1171 }
1172 }
1173
1174 // Dispatch record/wait.
1175 if (!is_multi_stream_safe_) {
1176 for (const auto &cross_stream_kernel_tensor : cross_stream_kernel_tensors) {
      // A nullptr task id on stream is a normal case.
      // If cross_stream_kernel_tensor's task id on stream is nullptr, the kernel tensor must be safe.
      // The data prepare actor, data source actor and so on prepare device tensors without a task id on stream,
      // and those device tensors are multi-stream safe.
1181 if (cross_stream_kernel_tensor->task_id_on_stream() == nullptr) {
1182 continue;
1183 }
1184       // The input kernel tensor's stream id is used as the memory stream id; this is important.
1185 auto user_stream_id = stream_id;
1186 auto memory_stream_id = cross_stream_kernel_tensor->stream_id();
1187 auto memory_task_id_on_stream = *cross_stream_kernel_tensor->task_id_on_stream();
1188 auto safe_task_id_on_stream =
1189 multi_stream_controller->QueryTaskIdOnStream(device_context, user_stream_id, memory_stream_id);
1190 if (safe_task_id_on_stream >= memory_task_id_on_stream) {
1191 MS_LOG(DEBUG) << "Safe_task_id_on_stream : " << safe_task_id_on_stream
1192                       << " is not less than memory_task_id_on_stream : " << memory_task_id_on_stream << ".";
1193 continue;
1194 }
1195 multi_stream_controller->DispatchRecordWaitEvent(device_context, user_stream_id, memory_stream_id);
1196 // Add recv process.
1197 user_stream_id = memory_stream_id;
1198 memory_stream_id = stream_id;
1199 auto last_task_id_on_stream = multi_stream_controller->GetTaskIdOnStream(device_context, user_stream_id);
1200       MS_LOG(DEBUG) << "Dispatch wait stream start, user_stream_id : " << user_stream_id
1201 << ", memory_stream_id : " << memory_stream_id
1202 << ", last_task_id_on_stream : " << last_task_id_on_stream << ".";
1203       // Here, we need the last task id on the user stream, which covers the record event dispatched above.
1204 (void)multi_stream_controller->WaitEvent(device_context, last_task_id_on_stream, user_stream_id,
1205 memory_stream_id);
1206 }
1207 }
1208 }
1209
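// After the kernel launch, record one event on the kernel's stream covering all collected cross-stream addresses,
// so that other streams can later be synchronized against this kernel's work on them.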
1210 void KernelActor::ProcessMultiStreamAfterKernelLaunch(OpContext<DeviceTensor> *const context) {
1211 auto stream_id = kernel_info_->stream_id();
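  // For kernels running on a non-default stream, also track the output addresses so that the event recorded below
  // covers them.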
1212 if (stream_id != kDefaultStreamIndex) {
1213 for (const auto &output_kernel_tensor : output_kernel_tensors_) {
1214 cross_stream_addresses_.emplace_back(kDefaultStreamIndex, output_kernel_tensor->device_ptr());
1215 }
1216 }
1217
1218 // Record event.
1219 if (!cross_stream_addresses_.empty()) {
1220 MS_LOG(DEBUG) << "Record event for kernel : " << kernel_->fullname_with_scope()
1221 << ", addresses size : " << cross_stream_addresses_.size() << ".";
1222 // Record event on stream.
1223 auto device_context = device_contexts_[0];
1224 auto multi_stream_controller = device::MultiStreamController::GetInstance();
1225 multi_stream_controller->RecordEvent(device_context, *task_id_on_stream_, stream_id, cross_stream_addresses_);
1226 }
1227 }
1228
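// Post-launch bookkeeping: update dynamic-shape output metadata, mark outputs that carry user data, refresh ref
// copies, erase the consumed inputs, free memory, and finally send outputs to the downstream actors.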
1229 void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
1230 if (is_dynamic_shape_ && kernel_mod_->IsNeedUpdateOutputShapeAndSize()) {
1231 kernel_mod_->UpdateOutputShapeAndSize(input_kernel_tensors_, output_kernel_tensors_);
1232 }
1233
1234 if (kernel_mod_->need_user_data()) {
1235 for_each(output_device_tensors_.begin(), output_device_tensors_.end(),
1236 [](auto &device_tensor) { device_tensor->set_need_sync_user_data(true); });
1237 }
1238
1239 if ((modifiable_ref_input_indexes_.size() != 0) || (modifiable_ref_output_indexes_.size() != 0)) {
1240 RefreshDeviceTensorCopyStore(context);
1241 }
1242
1243 // The input is invalid and needs to be erased when finish kernel launch.
1244 EraseInput(context);
1245
1246   // Note that SendMemoryFreeReq must come before SendOutput, because SendOutput will trigger SendMemoryAllocReq
1247   // of the next actor and actors execute asynchronously. So it is necessary to ensure that SendMemoryFreeReq of
1248   // the current actor comes before SendMemoryAllocReq of the next actor: one reason is to reuse memory more
1249   // fully, the other is to guarantee the execution order and avoid illegal memory timing problems.
1250 if (memory_free_list_.size() > 0) {
1251 SendMemoryFreeReq(context);
1252 }
1253
1254 SendOutput(context);
1255 }
1256
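// Propagate the latest data of the modifiable ref inputs/outputs to every copy tracked by DeviceTensorCopyStore,
// so that aliased device addresses stay consistent after the kernel has modified them in place.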
1257 void KernelActor::RefreshDeviceTensorCopyStore(OpContext<DeviceTensor> *const context) {
1258 uint64_t start_time = 0;
1259 PROFILER_START(start_time);
1260
1261 for (auto &ref_input_index : modifiable_ref_input_indexes_) {
1262 if (ref_input_index >= input_device_tensors_.size()) {
1263       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The input index is out of range.");
1264 }
1265 auto &input_device_tensor = input_device_tensors_[ref_input_index];
1266 MS_EXCEPTION_IF_NULL(input_device_tensor);
1267 auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(input_device_tensor);
1268 for (auto &new_device_tensor : need_refreshed_device_tensors) {
1269 MS_EXCEPTION_IF_NULL(new_device_tensor);
1270 MS_LOG(INFO) << GetAID().Name() << " the input position:" << ref_input_index
1271 << " refresh from device address:" << input_device_tensor
1272 << ", type:" << input_device_tensor->GetDeviceType() << ", format:" << input_device_tensor->format()
1273 << " to device address:" << new_device_tensor << ", type:" << new_device_tensor->GetDeviceType()
1274 << ", format:" << new_device_tensor->format();
1275 if (!Copy(new_device_tensor, input_device_tensor)) {
1276 std::string error_info = "Copy input device tensor failed: " + GetAID().Name();
1277 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
1278 }
1279 }
1280 }
1281
1282 for (auto &ref_output_index : modifiable_ref_output_indexes_) {
1283 if (ref_output_index >= output_device_tensors_.size()) {
1284       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, "The output index is out of range.");
1285 }
1286 auto &output_device_tensor = output_device_tensors_[ref_output_index];
1287 MS_EXCEPTION_IF_NULL(output_device_tensor);
1288 auto need_refreshed_device_tensors = DeviceTensorCopyStore::GetInstance().Fetch(output_device_tensor);
1289 for (auto &new_device_tensor : need_refreshed_device_tensors) {
1290 MS_EXCEPTION_IF_NULL(new_device_tensor);
1291 MS_LOG(INFO) << GetAID().Name() << " the output position:" << ref_output_index
1292 << " refresh from device address:" << output_device_tensor
1293 << ", type:" << output_device_tensor->GetDeviceType()
1294 << ", format:" << output_device_tensor->format() << " to device address:" << new_device_tensor
1295 << ", type:" << new_device_tensor->GetDeviceType() << ", format:" << new_device_tensor->format();
1296 if (!Copy(new_device_tensor, output_device_tensor)) {
1297 std::string error_info = "Copy output device tensor failed: " + GetAID().Name();
1298 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, *context, error_info);
1299 }
1300 }
1301 }
1302
1303 PROFILER_END(start_time, ProfilerModule::kRuntime, ProfilerEvent::kPostLaunch, GetAID().Name(), false);
1304 }
1305
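// Send the kernel launch information to the recorder actor; this only happens in the synchronous launch path.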
1306 void KernelActor::SendRecorderInfo(OpContext<DeviceTensor> *const context) const {
1307 if (recorder_aid_ != nullptr && !ActorDispatcher::enable_async_launch_kernel()) {
1308 MS_EXCEPTION_IF_NULL(kernel_);
1309 ActorDispatcher::Send(*recorder_aid_, &RecorderActor::RecordInfo, kernel_->fullname_with_scope(), &mem_info_,
1310 device_contexts_[0], context);
1311 }
1312 }
1313
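// Cache the device tensor at the given input slot, together with its raw kernel tensor (for launch) and its shared
// kernel tensor (for shape inference).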
1314 void KernelActor::SetInputDeviceTensor(DeviceTensor *input_device_tensor, size_t input_index) {
1315 input_device_tensors_[input_index] = input_device_tensor;
1316 input_kernel_tensors_[input_index] = input_device_tensor->kernel_tensor().get();
1317 input_kernel_tensors_for_infer_[input_index] = input_device_tensor->kernel_tensor();
1318 }
1319
1320 } // namespace runtime
1321 } // namespace mindspore
1322