• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/ascend/ge_runtime/task/tbe_task.h"
18 #include <vector>
19 #include "runtime/mem.h"
20 #include "runtime/kernel.h"
21 #include "runtime/device/ascend/ge_runtime/task/task_factory.h"
22 
23 namespace mindspore::ge::model_runner {
TbeTask(const ModelContext & model_context,const std::shared_ptr<TbeTaskInfo> & task_info)24 TbeTask::TbeTask(const ModelContext &model_context, const std::shared_ptr<TbeTaskInfo> &task_info)
25     : TaskRepeater<TbeTaskInfo>(model_context, task_info),
26       task_info_(task_info),
27       stream_(nullptr),
28       stub_func_(nullptr),
29       args_(nullptr) {
30   MS_EXCEPTION_IF_NULL(task_info);
31 
32   auto stream_list = model_context.stream_list();
33   if (stream_list.size() == 1) {
34     stream_ = stream_list[0];
35   } else if (stream_list.size() > task_info->stream_id()) {
36     stream_ = stream_list[task_info->stream_id()];
37   } else {
38     MS_LOG(EXCEPTION) << "Index: " << task_info->stream_id() << " >= stream_list.size(): " << stream_list.size();
39   }
40 }
41 
~TbeTask()42 TbeTask::~TbeTask() {
43   if (args_ != nullptr) {
44     rtError_t rt_ret = rtFree(args_);
45     if (rt_ret != RT_ERROR_NONE) {
46       MS_LOG(ERROR) << "Call rt api rtFree failed, ret: " << rt_ret;
47     }
48     args_ = nullptr;
49   }
50 }
51 
Distribute()52 void TbeTask::Distribute() {
53   MS_LOG(INFO) << "InitTbeTask start.";
54   MS_EXCEPTION_IF_NULL(stream_);
55   // Get stub_func
56   if (task_info_->stub_func().empty()) {
57     MS_LOG(EXCEPTION) << "kernel_info->stub_func is empty!";
58   }
59 
60   rtError_t rt_ret = rtGetFunctionByName(const_cast<char *>(task_info_->stub_func().c_str()), &stub_func_);
61   if (rt_ret != RT_ERROR_NONE) {
62     MS_LOG(EXCEPTION) << "Call rt api rtGetFunctionByName failed, ret: " << rt_ret;
63   }
64   MS_LOG(INFO) << "TbeTask: stub_func = " << task_info_->stub_func();
65 
66   // Get args
67   std::vector<void *> tensor_device_addrs;
68   tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->input_data_addrs().begin(),
69                              task_info_->input_data_addrs().end());
70   tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->output_data_addrs().begin(),
71                              task_info_->output_data_addrs().end());
72   tensor_device_addrs.insert(tensor_device_addrs.end(), task_info_->workspace_addrs().begin(),
73                              task_info_->workspace_addrs().end());
74   auto args_size = static_cast<uint32_t>(tensor_device_addrs.size() * sizeof(void *));
75 
76   rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM);
77   if (rt_ret != RT_ERROR_NONE) {
78     MS_LOG(EXCEPTION) << "Call rt api rtMalloc failed, ret: " << rt_ret << " mem size " << args_size;
79   }
80 
81   rt_ret = rtMemcpy(args_, args_size, reinterpret_cast<void *>(tensor_device_addrs.data()), args_size,
82                     RT_MEMCPY_HOST_TO_DEVICE);
83   if (rt_ret != RT_ERROR_NONE) {
84     MS_LOG(EXCEPTION) << "Call rt api rtMemcpy failed, ret: " << rt_ret;
85   }
86 
87   MS_LOG(INFO) << "DistributeTbeTask start.";
88   auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
89   rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
90   if (rt_ret != RT_ERROR_NONE) {
91     MS_LOG(EXCEPTION) << "Call rt api rtKernelLaunch failed, ret: " << rt_ret << " mem size " << args_size;
92   }
93   MS_LOG(INFO) << "[DataDump] task name: " << task_info_->op_name() << " dump_flag: " << dump_flag;
94 }
95 
// Registers TbeTask with the task factory (REGISTER_TASK comes from task_factory.h) for
// TaskInfoType::TBE / TbeTaskInfo. Presumably this lets the factory construct a TbeTask from a
// TBE task info at model load time — confirm against task_factory.h.
96 REGISTER_TASK(TaskInfoType::TBE, TbeTask, TbeTaskInfo);
97 }  // namespace mindspore::ge::model_runner
98