/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/tensorrt/tensorrt_allocator.h"
#include <cuda_runtime.h>
#include <mutex>
#include "src/common/log_adapter.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"

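// TensorRTAllocator caches one CUDA device buffer per tensor name in
// cuda_tensor_map_ and reuses a cached buffer whenever the requested size
// still fits, so repeated inferences avoid re-allocating device memory.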
namespace mindspore::lite {
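// Convenience overload: resolves the tensor's name and TensorRT data type,
// then forwards to the name-based MallocDeviceMem below.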
void *TensorRTAllocator::MallocDeviceMem(const mindspore::MSTensor &host_tensor, size_t size) {
  if (host_tensor == nullptr) {
    return nullptr;
  }
  return MallocDeviceMem(host_tensor.Name(), size, ConvertDataType(host_tensor.DataType()));
}

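// Returns a device buffer of at least `size` bytes for `name`, reusing the
// cached buffer when it is still large enough. On re-allocation the old
// buffer is freed and the entry is marked invalid so the next sync re-uploads
// host data. Note that `data_type` is not used by the allocation itself.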
void *TensorRTAllocator::MallocDeviceMem(const std::string &name, size_t size, nvinfer1::DataType data_type) {
  auto iter = cuda_tensor_map_.find(name);
  if (iter != cuda_tensor_map_.end() && size <= iter->second.size) {
    MS_LOG(DEBUG) << "tensor: " << name << " is already in the cuda allocator pool.";
    return iter->second.data;
  }
  void *device_ptr = nullptr;
  auto cuda_ret = cudaMalloc(&device_ptr, size);
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "cudaMalloc failed for size: " << size;
    return nullptr;
  }
  MS_LOG(INFO) << "cudaMalloc size: " << size << " for " << name;
  // Release the old, smaller buffer before caching the new one.
  if (cuda_tensor_map_[name].data != nullptr) {
    cuda_ret = cudaFree(cuda_tensor_map_[name].data);
    if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
      MS_LOG(ERROR) << "free old cuda device_ptr failed for " << cudaGetErrorName(cuda_ret);
      cuda_ret = cudaFree(device_ptr);
      if (cuda_ret != cudaSuccess) {
        MS_LOG(ERROR) << "free new cuda device_ptr failed for " << cudaGetErrorName(cuda_ret);
      }
      return nullptr;
    }
  }
  cuda_tensor_map_[name].data = device_ptr;
  cuda_tensor_map_[name].is_valid_mem = false;
  cuda_tensor_map_[name].size = size;
  return device_ptr;
}

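// Tags the cached buffer for `name` as holding valid (or stale) data. Uses
// operator[], so an empty entry is created if `name` is not yet in the pool.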
void TensorRTAllocator::MarkMemValid(const std::string &name, bool isValid) {
  cuda_tensor_map_[name].is_valid_mem = isValid;
}

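// Returns the validity tag for `name`; warns and returns false when the
// tensor has no buffer in the pool.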
bool TensorRTAllocator::GetMemIsValid(const std::string &name) {
  if (cuda_tensor_map_.find(name) == cuda_tensor_map_.end()) {
    MS_LOG(WARNING) << "tensor: " << name << " is not in the cuda allocator pool.";
    return false;
  }
  return cuda_tensor_map_[name].is_valid_mem;
}

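// Returns the device pointer cached for `tensor_name`, or nullptr for an
// empty name or an unknown tensor.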
void *TensorRTAllocator::GetDevicePtr(const std::string &tensor_name) {
  if (tensor_name.empty()) {
    return nullptr;
  }
  auto iter = cuda_tensor_map_.find(tensor_name);
  if (iter == cuda_tensor_map_.end()) {
    return nullptr;
  }
  return iter->second.data;
}

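// Convenience overload: forwards the tensor's host data pointer and byte size
// to the raw-pointer version below.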
int TensorRTAllocator::SyncMemInHostAndDevice(mindspore::MSTensor host_tensor, const std::string &device_tensor_name,
                                              bool is_host2device, bool sync) {
  if (host_tensor == nullptr) {
    MS_LOG(ERROR) << "host tensor is null.";
    return RET_ERROR;
  }
  return SyncMemInHostAndDevice(host_tensor.MutableData(), device_tensor_name, host_tensor.DataSize(), is_host2device,
                                sync);
}

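// Copies `data_size` bytes between `host_data` and the cached device buffer.
// A host-to-device copy is skipped when the buffer is already tagged valid,
// and a device-to-host copy marks the entry valid. The copy uses the
// synchronous cudaMemcpy, so the `sync` flag does not change behavior here.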
int TensorRTAllocator::SyncMemInHostAndDevice(void *host_data, const std::string &device_tensor_name, size_t data_size,
                                              bool is_host2device, bool sync) {
  if (host_data == nullptr || cuda_tensor_map_.find(device_tensor_name) == cuda_tensor_map_.end()) {
    MS_LOG(ERROR) << "host or device ptr is null.";
    return RET_ERROR;
  }
  CudaTensorParam &current_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second;
  // A device-to-host copy implies the memory is valid; update the tag for the mem pool.
  current_cuda_tensor.is_valid_mem = is_host2device ? current_cuda_tensor.is_valid_mem : true;
  if (is_host2device && current_cuda_tensor.is_valid_mem) {
    MS_LOG(DEBUG) << "no need to memcpy for: " << device_tensor_name;
    return RET_OK;
  }
  auto device_ptr = current_cuda_tensor.data;
  if (device_ptr == nullptr) {
    MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
    return RET_ERROR;
  }

  void *src_ptr = is_host2device ? host_data : device_ptr;
  void *dst_ptr = is_host2device ? device_ptr : host_data;
  cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost;
  auto cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "cuda memcpy failed for " << cudaGetErrorName(cuda_ret);
    return RET_ERROR;
  }
  MS_LOG(INFO) << "cuda memcpy success for " << device_tensor_name;
  return RET_OK;
}

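// Frees every cached device buffer and resets the entries; the map keys
// remain but point at nullptr afterwards.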
int TensorRTAllocator::ClearDeviceMem() {
  for (auto &iter : cuda_tensor_map_) {
    auto cuda_ret = cudaFree(iter.second.data);
    if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
      MS_LOG(WARNING) << "free cuda memory failed for " << cudaGetErrorName(cuda_ret);
    }
    iter.second.data = nullptr;
    iter.second.is_valid_mem = false;
  }
  return RET_OK;
}
}  // namespace mindspore::lite
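
// Illustrative call sequence (a minimal sketch, not part of this file; the
// caller and the `input`/`output` tensor variables are assumed, and the
// arguments mirror the signatures above):
//
//   TensorRTAllocator allocator;
//   // Allocate (or reuse) a device buffer sized for the input tensor.
//   void *dev_ptr = allocator.MallocDeviceMem(input, input.DataSize());
//   // Upload host data, then tag the buffer so a repeated upload is skipped.
//   allocator.SyncMemInHostAndDevice(input, input.Name(), true, true);
//   allocator.MarkMemValid(input.Name(), true);
//   // ... bind allocator.GetDevicePtr(input.Name()) to TensorRT and run ...
//   // Download the result, then release all cached buffers.
//   allocator.SyncMemInHostAndDevice(output, output.Name(), false, true);
//   allocator.ClearDeviceMem();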