/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h"
#include <cuda_runtime.h>
#include <mutex>
#include "src/common/log_adapter.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cast.cuh"

namespace mindspore::lite {
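// Allocates (or reuses) device memory for the given host tensor by delegating to the
// name/size/data-type overload below.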
void *TensorRTAllocator::MallocDeviceMem(const TensorInfo &host_tensor, size_t size) {
  return MallocDeviceMem(host_tensor.Name(), size, ConvertDataType(host_tensor.DataType()));
}

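// Returns a device buffer of at least `size` bytes keyed by tensor name. An existing buffer is
// reused when it is large enough; otherwise a new buffer is allocated, any old buffer is freed,
// and the entry is marked as not yet holding valid data.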
void *TensorRTAllocator::MallocDeviceMem(const std::string &name, size_t size, nvinfer1::DataType data_type) {
  if (cuda_tensor_map_.find(name) != cuda_tensor_map_.end() && size <= cuda_tensor_map_[name].size) {
    MS_LOG(DEBUG) << "tensor: " << name << " is already in cuda Allocator pool.";
    return cuda_tensor_map_[name].data;
  }
  void *device_ptr = nullptr;
  auto cuda_ret = cudaMalloc(&device_ptr, size);
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "Cuda Malloc failed for size: " << size;
    return nullptr;
  }
  MS_LOG(INFO) << "cudaMalloc size: " << size << " for " << name;
  if (cuda_tensor_map_[name].data != nullptr) {
    cuda_ret = cudaFree(cuda_tensor_map_[name].data);
    if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
      MS_LOG(ERROR) << "free old cuda device_ptr failed for " << cudaGetErrorName(cuda_ret);
      cuda_ret = cudaFree(device_ptr);
      if (cuda_ret != cudaSuccess) {
        MS_LOG(ERROR) << "free new cuda device_ptr failed for " << cudaGetErrorName(cuda_ret);
        return nullptr;
      }
      return nullptr;
    }
  }
  cuda_tensor_map_[name].data = device_ptr;
  cuda_tensor_map_[name].is_valid_mem = false;
  cuda_tensor_map_[name].size = size;
  return device_ptr;
}

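// Marks whether the device buffer registered under `name` currently holds valid data.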
void TensorRTAllocator::MarkMemValid(const std::string &name, bool isValid) {
  cuda_tensor_map_[name].is_valid_mem = isValid;
}

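// Returns whether the device buffer registered under `name` holds valid data; false when the
// tensor is not in the pool.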
bool TensorRTAllocator::GetMemIsValid(const std::string &name) {
  if (cuda_tensor_map_.find(name) == cuda_tensor_map_.end()) {
    MS_LOG(WARNING) << "tensor: " << name << " is not in cuda Allocator pool.";
    return false;
  }
  return cuda_tensor_map_[name].is_valid_mem;
}

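// Looks up the device pointer registered under `tensor_name`; returns nullptr when the name is
// empty or not present in the pool.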
void *TensorRTAllocator::GetDevicePtr(const std::string &tensor_name) {
  if (tensor_name.empty()) {
    return nullptr;
  }
  auto it = cuda_tensor_map_.find(tensor_name);
  if (it == cuda_tensor_map_.end()) {
    return nullptr;
  }
  return it->second.data;
}

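// Copies `size` bytes (or the whole host tensor when size is 0) from host memory to the device
// buffer registered under `device_tensor_name`.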
int TensorRTAllocator::SyncMemHostToDevice(const tensor::Tensor &host_tensor, const std::string &device_tensor_name,
                                           bool sync, size_t size) {
  size = (size == 0) ? host_tensor.Size() : size;
  return SyncMemInHostAndDevice(const_cast<void *>(host_tensor.data_c()), device_tensor_name, size, true, sync);
}

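// Copies device data back into the host tensor. For bool tensors (TensorRT >= 7.2 build) the
// device data is read as int32 through a staging buffer and converted to bool on the host.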
int TensorRTAllocator::SyncMemDeviceToHost(tensor::Tensor *host_tensor, const std::string &device_tensor_name,
                                           bool sync) {
  if (host_tensor == nullptr) {
    MS_LOG(ERROR) << "host tensor is null.";
    return RET_ERROR;
  }
#if TRT_VERSION_GE(7, 2)
  if (host_tensor->data_type() == TypeId::kNumberTypeBool) {
    auto it = cuda_tensor_map_.find(device_tensor_name);
    if (it == cuda_tensor_map_.end()) {
      MS_LOG(ERROR) << "cannot find device address " << device_tensor_name;
      return RET_ERROR;
    }
    auto device_ptr = it->second.data;
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
      return RET_ERROR;
    }
    int *host_ptr = reinterpret_cast<int *>(malloc(host_tensor->DataSize() * sizeof(int)));
    if (host_ptr == nullptr) {
      MS_LOG(ERROR) << "malloc host buffer failed for " << device_tensor_name;
      return RET_ERROR;
    }
    cudaError_t cuda_ret;
    if (sync) {
      cuda_ret = cudaMemcpy(host_ptr, device_ptr, host_tensor->DataSize() * sizeof(int), cudaMemcpyDeviceToHost);
    } else {
      cuda_ret =
        cudaMemcpyAsync(host_ptr, device_ptr, host_tensor->DataSize() * sizeof(int), cudaMemcpyDeviceToHost, stream_);
    }
    if (cuda_ret != cudaSuccess) {
      MS_LOG(ERROR) << "copy mem failed, ret " << cudaGetErrorName(cuda_ret);
      free(host_ptr);
      return RET_ERROR;
    }
    bool *host_tensor_ptr = static_cast<bool *>(host_tensor->data_c());
    for (size_t i = 0; i != host_tensor->Size(); ++i) {
      host_tensor_ptr[i] = (host_ptr[i] != 0);
    }
    free(host_ptr);
    return RET_OK;
  }
#endif
  return SyncMemDeviceToHost(host_tensor->data_c(), host_tensor->Size(), device_tensor_name, sync);
}

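// Copies `data_size` bytes from the device buffer registered under `device_tensor_name` into
// `dst_data`, and marks the pooled buffer as holding valid data.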
int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name,
                                           bool sync) {
  if (dst_data == nullptr) {
    MS_LOG(ERROR) << "dst host data cannot be nullptr.";
    return RET_ERROR;
  }
  auto it = cuda_tensor_map_.find(device_tensor_name);
  if (it == cuda_tensor_map_.end()) {
    MS_LOG(ERROR) << "cannot find device address " << device_tensor_name;
    return RET_ERROR;
  }
  CudaTensorParam &current_cuda_tensor = it->second;
  // memcpy from device to host makes the host mem valid, so change the tag for the mem pool.
  current_cuda_tensor.is_valid_mem = true;
  auto device_ptr = current_cuda_tensor.data;
  if (device_ptr == nullptr) {
    MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
    return RET_ERROR;
  }
  cudaError_t cuda_ret;
  if (sync) {
    cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost);
  } else {
    cuda_ret = cudaMemcpyAsync(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost, stream_);
  }
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "copy mem failed, ret " << cudaGetErrorName(cuda_ret);
    return RET_ERROR;
  }
  MS_LOG(INFO) << "cuda memcpy success for " << device_tensor_name;
  return RET_OK;
}

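// Synchronizes data between the host tensor and the device buffer in the direction given by
// `is_host2device`. Bool device-to-host copies on TensorRT >= 7.2 go through an int32 staging
// buffer, matching SyncMemDeviceToHost above.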
int TensorRTAllocator::SyncMemInHostAndDevice(tensor::Tensor *host_tensor, const std::string &device_tensor_name,
                                              bool is_host2device, bool sync) {
  if (host_tensor == nullptr) {
    MS_LOG(ERROR) << "host tensor is null.";
    return RET_ERROR;
  }
#if TRT_VERSION_GE(7, 2)
  if (host_tensor->data_type() == TypeId::kNumberTypeBool && !is_host2device) {
    auto it = cuda_tensor_map_.find(device_tensor_name);
    if (it == cuda_tensor_map_.end()) {
      MS_LOG(ERROR) << "cannot find device address " << device_tensor_name;
      return RET_ERROR;
    }
    auto device_ptr = it->second.data;
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
      return RET_ERROR;
    }
    int *host_ptr = reinterpret_cast<int *>(malloc(host_tensor->DataSize() * sizeof(int)));
    if (host_ptr == nullptr) {
      MS_LOG(ERROR) << "malloc host buffer failed for " << device_tensor_name;
      return RET_ERROR;
    }
    auto cuda_ret = cudaMemcpy(host_ptr, device_ptr, host_tensor->DataSize() * sizeof(int), cudaMemcpyDeviceToHost);
    if (cuda_ret != cudaSuccess) {
      MS_LOG(ERROR) << "copy mem failed, ret " << cudaGetErrorName(cuda_ret);
      free(host_ptr);
      return RET_ERROR;
    }
    bool *host_tensor_ptr = static_cast<bool *>(host_tensor->data_c());
    for (size_t i = 0; i != host_tensor->Size(); ++i) {
      host_tensor_ptr[i] = (host_ptr[i] != 0);
    }
    free(host_ptr);
    return RET_OK;
  }
#endif
  return SyncMemInHostAndDevice(host_tensor->data_c(), device_tensor_name, host_tensor->Size(), is_host2device, sync);
}

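// Raw pointer variant: copies `data_size` bytes between `host_data` and the device buffer
// registered under `device_tensor_name`. Device-to-host copies mark the buffer as valid, and
// host-to-device copies are skipped when the device buffer already holds valid data.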
int TensorRTAllocator::SyncMemInHostAndDevice(void *host_data, const std::string &device_tensor_name, size_t data_size,
                                              bool is_host2device, bool sync) {
  if (host_data == nullptr || cuda_tensor_map_.find(device_tensor_name) == cuda_tensor_map_.end()) {
    MS_LOG(ERROR) << "host or device ptr is null.";
    return RET_ERROR;
  }
  CudaTensorParam &current_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second;
  // memcpy from device to host makes the host mem valid, so change the tag for the mem pool.
  current_cuda_tensor.is_valid_mem = is_host2device ? current_cuda_tensor.is_valid_mem : true;
  if (is_host2device && current_cuda_tensor.is_valid_mem) {
    MS_LOG(DEBUG) << "no need memcpy for: " << device_tensor_name;
    return RET_OK;
  }
  auto device_ptr = current_cuda_tensor.data;
  if (device_ptr == nullptr) {
    MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
    return RET_ERROR;
  }

  void *src_ptr = is_host2device ? host_data : device_ptr;
  void *dst_ptr = is_host2device ? device_ptr : host_data;
  cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost;
  cudaError_t cuda_ret;
  if (sync) {
    cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
  } else {
    cuda_ret = cudaMemcpyAsync(dst_ptr, src_ptr, data_size, kind, stream_);
  }
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "copy mem failed, ret " << cudaGetErrorName(cuda_ret);
    return RET_ERROR;
  }
  MS_LOG(INFO) << "cuda memcpy success for " << device_tensor_name;
  return RET_OK;
}

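// Releases all device buffers in the pool and invalidates their entries; the map keys are kept
// so the buffers can be reallocated later.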
int TensorRTAllocator::ClearDeviceMem() {
  for (auto &iter : cuda_tensor_map_) {
    auto cuda_ret = cudaFree(iter.second.data);
    if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) {
      MS_LOG(WARNING) << "free cuda memory failed for " << cudaGetErrorName(cuda_ret);
    }
    iter.second.data = nullptr;
    iter.second.is_valid_mem = false;
  }
  return RET_OK;
}
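
// Returns a copy of the tensor-name to device-buffer map.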
std::map<std::string, CudaTensorParam> TensorRTAllocator::GetAllDevicePtr() { return this->cuda_tensor_map_; }
}  // namespace mindspore::lite