1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
18 #include <vector>
19 #include "runtime/mem.h"
20 #include "runtime/kernel.h"
21 #include "runtime/device/ascend/ge_runtime/task/task_factory.h"
22
23 namespace mindspore::ge::model_runner {
TbeTask(const ModelContext & model_context,const std::shared_ptr<TbeTaskInfo> & task_info)24 TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
25 : TaskRepeater<TbeTaskInfo>(model_context, task_info),
26 task_info_(task_info),
27 stream_(nullptr),
28 stub_func_(nullptr),
29 args_(nullptr) {
30 MS_EXCEPTION_IF_NULL(task_info);
31
32 auto stream_list = model_context.stream_list();
33 if (stream_list.size() == 1) {
34 stream_ = stream_list[0];
35 } else if (stream_list.size() > task_info->stream_id()) {
36 stream_ = stream_list[task_info->stream_id()];
37 } else {
38 MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
39 }
40 }
41
~TbeTask()42 TbeTask::~TbeTask() {
43 if (args_ != nullptr) {
44 rtError_t rt_ret = rtFree(args_);
45 if (rt_ret != RT_ERROR_NONE) {
46 MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << rt_ret;
47 }
48 args_ = nullptr;
49 }
50 }
51
Distribute()52 void TbeTask::Distribute() {
53 MS_LOG(INFO) << "InitTbeTask start.";
54 MS_EXCEPTION_IF_NULL(stream_);
55 // Get stub_func
56 if (task_info_->stub_func().empty()) {
57 MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!";
58 }
59
60 rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
61 if (rt_ret != RT_ERROR_NONE) {
62 MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << rt_ret;
63 }
64 MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func();
65
66 // Get args
67 std::vector<void *> tensor_device_addrs;
68 tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
69 task_info_->input_data_addrs().end());
70 tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
71 task_info_->output_data_addrs().end());
72 tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
73 task_info_->workspace_addrs().end());
74 auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
75
76 rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
77 if (rt_ret != RT_ERROR_NONE) {
78 MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size;
79 }
80
81 rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
82 RT_MEMCPY_HOST_TO_DEVICE);
83 if (rt_ret != RT_ERROR_NONE) {
84 MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret;
85 }
86
87 MS_LOG(INFO) << "DistributeTbeTask start.";
88 auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
89 rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
90 if (rt_ret != RT_ERROR_NONE) {
91 MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size;
92 }
93 MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
94 }
95
// Register TbeTask with the TaskFactory so TBE-type task infos are built into TbeTask instances.
REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
97 } // namespace mindspore::ge::model_runner
98