/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/types/optional.h"
#ifdef GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#endif  // GOOGLE_CUDA

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device/device_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamallocasync_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

#if GOOGLE_CUDA
static std::string GetCudaErrorMessage(CUresult result) {
  const char* error;
  cuGetErrorString(result, &error);
  const char* name;
  cuGetErrorName(result, &name);
  return absl::StrCat("CUDA error: ", error ? error : "<unknown>", " (",
                      name ? name : "Unknown", ")");
}
#endif  // GOOGLE_CUDA

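// Logs a histogram of the currently live allocations (and, at higher
// verbosity, every live (ptr, size) pair) together with the CUDA memory pool
// counters. Called from AllocateRaw when an allocation fails, to help
// diagnose out-of-memory situations.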
void GpuCudaMallocAsyncAllocator::PrintAllocatorStatistics() {
  mutex_lock lock(lock_);

  std::map<size_t, int> size_map_histogram;
  std::vector<string> ptr_size_string;
  for (auto p : size_map_) {
    if (VLOG_IS_ON(8)) {
      ptr_size_string.push_back(
          absl::StrCat("(", absl::Hex(p.first), ",", p.second, ")"));
    }
    size_map_histogram[p.second]++;
  }
  LOG(ERROR) << "Histogram of current allocations: (allocation_size_in_bytes, "
             << "nb_allocations_of_that_size), ...;";
  for (auto p : size_map_histogram) {
    LOG(ERROR) << p.first << ", " << p.second;
  }

  VLOG(8) << "\nThe sorted list of (ptr,size):";
  VLOG(8) << absl::StrJoin(ptr_size_string, ",");

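  // The pool attributes below report, respectively, how much backing memory
  // the pool currently reserves on the device, how much of it is in use, and
  // the high-water marks of both counters since they were last reset.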
#if CUDA_VERSION >= 11030
  cuuint64_t mem_reserved_current;
  if (auto result = cuMemPoolGetAttribute(
          pool_, CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, &mem_reserved_current)) {
    LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
               << GetCudaErrorMessage(result);
  }
  cuuint64_t mem_used_current;
  if (auto result = cuMemPoolGetAttribute(
          pool_, CU_MEMPOOL_ATTR_USED_MEM_CURRENT, &mem_used_current)) {
    LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
               << GetCudaErrorMessage(result);
  }
  cuuint64_t mem_reserved_high;
  if (auto result = cuMemPoolGetAttribute(
          pool_, CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &mem_reserved_high)) {
    LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
               << GetCudaErrorMessage(result);
  }
  cuuint64_t mem_used_high;
  if (auto result = cuMemPoolGetAttribute(pool_, CU_MEMPOOL_ATTR_USED_MEM_HIGH,
                                          &mem_used_high)) {
    LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
               << GetCudaErrorMessage(result);
  }
  LOG(ERROR) << "CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: "
             << mem_reserved_current;
  LOG(ERROR) << "CU_MEMPOOL_ATTR_USED_MEM_CURRENT: " << mem_used_current;
  LOG(ERROR) << "CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: " << mem_reserved_high;
  LOG(ERROR) << "CU_MEMPOOL_ATTR_USED_MEM_HIGH: " << mem_used_high;
#endif
}

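// The constructor binds this allocator to the default CUDA memory pool of
// `platform_device_id` and configures the pool: release threshold,
// determinism-related reuse flags, cross-GPU access, and optional
// preallocation.
//
// A minimal usage sketch (assumption: illustrative values only; a CUDA
// stream still has to be attached to the allocator, as checked in
// AllocateRaw, before it can serve allocations):
//
//   GpuCudaMallocAsyncAllocator allocator(
//       PlatformDeviceId(0), /*pool_size=*/1ULL << 30,
//       /*reserve_memory=*/false, /*compute_stats=*/true);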
GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
    PlatformDeviceId platform_device_id, size_t pool_size, bool reserve_memory,
    bool compute_stats)
    : name_(absl::StrCat("gpu_async_", platform_device_id.value())) {
#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
  stream_exec_ = DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(),
                                                           platform_device_id)
                     .ValueOrDie();
  // Initialized here as they only exist if compiled with a recent
  // enough CUDA.
  pool_ = nullptr;
  cuda_stream_ = nullptr;
  // Work around a CUDA 11.2 driver bug with multiple GPUs: it currently
  // requires that the context on GPU 0 be initialized, which isn't the
  // case for TF + Horovod.
  int driverVersion;
  cuDriverGetVersion(&driverVersion);
  VLOG(2) << "DRIVER VERSION: " << driverVersion;
  if (platform_device_id.value() > 0 && driverVersion < 11030) {
    CUcontext pctx;  // We lose track of it, but this is fine.
    if (auto result = cuDevicePrimaryCtxRetain(&pctx, 0))
      LOG(FATAL)  // Crash OK.
          << "Failed to retain context: " << GetCudaErrorMessage(result);
  }

  se::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
  int cuda_malloc_async_supported;
  if (auto status =
          cuDeviceGetAttribute(&cuda_malloc_async_supported,
                               CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
                               platform_device_id.value()))
    LOG(FATAL) <<  // Crash OK.
        "Failed to get device attribute: " << GetCudaErrorMessage(status);
  if (!cuda_malloc_async_supported)
    LOG(FATAL)  // Crash OK.
        << "TF_GPU_ALLOCATOR=cuda_malloc_async isn't currently supported on "
        << "GPU id " << platform_device_id.value() << "."
        << " Possible causes: device not supported (requires SM 6.0+), "
        << "driver too old, OS not supported, or CUDA version too old "
        << "(requires CUDA 11.2+).";

  if (auto status =
          cuDeviceGetDefaultMemPool(&pool_, platform_device_id.value()))
    LOG(FATAL) <<  // Crash OK.
        "Failed to get default CUDA pool: " << GetCudaErrorMessage(status);

  VLOG(1) << Name() << " CudaMallocAsync initialized on platform: "
          << platform_device_id.value() << " with pool size of: " << pool_size
          << " this ptr: " << this;
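  // The release threshold is the amount of reserved memory the pool may hold
  // onto across stream synchronizations before giving it back to the driver.
  // Setting it to the full pool size lets the pool retain up to pool_size
  // bytes instead of shrinking back eagerly.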
  uint64_t pool_size_64 = pool_size;
  if (auto status = cuMemPoolSetAttribute(
          pool_, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64))
    LOG(FATAL) <<  // Crash OK.
        "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);

  if (compute_stats) {
    stats_ = std::make_unique<AllocatorStats>();
    stats_->bytes_limit = static_cast<int64>(pool_size);
  }  // If not set, it means we do not compute stats.

  // If op determinism is enabled, then make the allocator behave
  // deterministically.
  // TODO(reedwm): OpDeterminismRequired() should not be used here since op
  // determinism is only supposed to affect the determinism of op outputs and
  // side effects.
  if (OpDeterminismRequired()) {
    int disable = 0;
    if (auto status = cuMemPoolSetAttribute(
            pool_, CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, &disable)) {
      LOG(FATAL) <<  // Crash OK.
          "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);
    }
    if (auto status = cuMemPoolSetAttribute(
            pool_, CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES,
            &disable)) {
      LOG(FATAL) <<  // Crash OK.
          "Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);
    }
  }

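  // Each allocator instance registers its pool and device id in the static
  // lists below; the loop then exchanges peer access with every pool created
  // before this one, in both directions, where the hardware supports it.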
  // Set read/write access to all GPUs.
  static auto* all_pools_ = new std::vector<CUmemoryPool*>();
  static auto* all_ids_ = new std::vector<PlatformDeviceId>();
  DCHECK(all_pools_->size() == all_ids_->size());
  for (int i = 0; i < all_pools_->size(); ++i) {
    // Give the previous GPUs access to the current pool.
    CUmemAccessDesc map;
    map.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    map.location.id = (*all_ids_)[i].value();

    map.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    VLOG(2) << "Setting access of the current pool to location id: "
            << map.location.id;
    int canAccessPeer;
    if (auto status = cuDeviceCanAccessPeer(
            &canAccessPeer, platform_device_id.value(), map.location.id)) {
      pool_ = nullptr;
      LOG(FATAL)  // Crash OK.
          << "cuDeviceCanAccessPeer failed to determine whether GPU id "
          << map.location.id << " can access GPU id "
          << platform_device_id.value() << ": " << GetCudaErrorMessage(status);
    }
    if (canAccessPeer == 1) {
      if (auto status = cuMemPoolSetAccess(pool_, &map, 1)) {
        pool_ = nullptr;
        LOG(FATAL)  // Crash OK.
            << "Error when setting access to the pool id: " << i
            << " location id: " << map.location.id
            << " error: " << GetCudaErrorMessage(status);
      }
    }

    // Give the current GPU access to the previous pools.
    map.location.id = platform_device_id.value();

    VLOG(2) << "Set access to the pool id: " << i
            << " location id: " << map.location.id;
    if (auto status = cuDeviceCanAccessPeer(&canAccessPeer, i,
                                            platform_device_id.value())) {
      pool_ = nullptr;
      LOG(FATAL)  // Crash OK.
          << "cuDeviceCanAccessPeer failed: " << GetCudaErrorMessage(status);
    }
    if (canAccessPeer == 1) {
      if (auto status = cuMemPoolSetAccess(*(*all_pools_)[i], &map, 1)) {
        pool_ = nullptr;
        LOG(FATAL)  // Crash OK.
            << "Error when setting access to the pool id: " << i
            << " location id: " << map.location.id
            << " error: " << GetCudaErrorMessage(status);
      }
    }
  }
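  // Record this pool and device id so that allocators created later can set
  // up the same mutual access with this one.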
  all_pools_->push_back(&pool_);
  all_ids_->push_back(platform_device_id);

  VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator PoolSize " << pool_size;
  int64 prealloc_size = 0;
  // TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC=-1 is a special value that
  // preallocates the total pool size.
  TF_CHECK_OK(ReadInt64FromEnvVar("TF_CUDA_MALLOC_ASYNC_SUPPORTED_PREALLOC", 0,
                                  &prealloc_size));
  if (prealloc_size == -1) {
    prealloc_size = pool_size;
  } else if (reserve_memory) {
    prealloc_size = pool_size;
  }

  if (prealloc_size != 0) {
    void* ptr = AllocateRaw(0, prealloc_size);
    DeallocateRaw(ptr);
    VLOG(2) << Name() << " GpuCudaMallocAsyncAllocator reserved the pool for "
            << prealloc_size << " bytes"
            << ". First ptr: " << ptr;
    ClearStats();
  }
#else   // TF_CUDA_MALLOC_ASYNC_SUPPORTED
  LOG(FATAL) << "GpuCudaMallocAsyncAllocator requires CUDA 11.2+";  // Crash OK.
#endif  // TF_CUDA_MALLOC_ASYNC_SUPPORTED
}

GpuCudaMallocAsyncAllocator::~GpuCudaMallocAsyncAllocator() {}

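// Allocates `num_bytes` from the pool as a stream-ordered allocation on
// cuda_stream_: the memory becomes usable in stream order, so other streams
// must synchronize with cuda_stream_ before touching it. The `alignment`
// argument is ignored here. Returns nullptr on failure, after logging the
// pool statistics.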
void* GpuCudaMallocAsyncAllocator::AllocateRaw(size_t alignment,
                                               size_t num_bytes) {
#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
  CHECK(cuda_stream_ != nullptr)
      << "A stream must be added to the GpuCudaMallocAsync allocator";
  if (pool_ == nullptr) {
    LOG(FATAL)  // Crash OK.
        << "The instantiation of GpuCudaMallocAsyncAllocator failed."
        << " See previous errors.";
  }
  se::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
  void* ptr = nullptr;
  if (auto result =
          cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
                                  num_bytes, pool_, cuda_stream_)) {
    size_t free, total;
    cuMemGetInfo(&free, &total);
    LOG(ERROR) << Name() << " cuMemAllocFromPoolAsync failed to allocate "
               << num_bytes << " bytes: " << GetCudaErrorMessage(result)
               << "\n Reported by CUDA: Free memory/Total memory: " << free
               << "/" << total;
    if (auto stats = GetStats())
      LOG(ERROR) << "Stats: " << stats->DebugString();

    PrintAllocatorStatistics();

    return nullptr;
  }

  // Update stats.
  if (stats_) {
    mutex_lock lock(lock_);
    ++(stats_->num_allocs);
    stats_->bytes_in_use += num_bytes;
    if (stats_->bytes_in_use > stats_->peak_bytes_in_use) {
      VLOG(9) << "New Peak memory usage of " << stats_->bytes_in_use
              << " bytes.";
    }
    stats_->peak_bytes_in_use =
        std::max(stats_->peak_bytes_in_use, stats_->bytes_in_use);
    stats_->largest_alloc_size =
        std::max<std::size_t>(stats_->largest_alloc_size, num_bytes);
    size_map_[ptr] = num_bytes;
  }
  VLOG(10) << Name() << " Allocated " << num_bytes << " at " << ptr;
  return ptr;
#else   // TF_CUDA_MALLOC_ASYNC_SUPPORTED
  return nullptr;
#endif  // TF_CUDA_MALLOC_ASYNC_SUPPORTED
}
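
// Frees `ptr` asynchronously in the order of cuda_stream_: the memory is
// returned to the pool once the work queued on the stream before this call
// has completed.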
void GpuCudaMallocAsyncAllocator::DeallocateRaw(void* ptr) {
#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
  if (auto result = cuMemFreeAsync(reinterpret_cast<const CUdeviceptr&>(ptr),
                                   cuda_stream_)) {
    if (result == CUDA_ERROR_DEINITIALIZED) {
      // With multiple GPUs, TF sometimes frees GPU allocations after the
      // driver has been unloaded. It is safe to ignore this error here.
      // TODO: Find how to fix the shutdown steps in TF.
      VLOG(1) << "Ignoring CUDA error: " << GetCudaErrorMessage(result);
    } else {
      size_t free, total;
      se::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
      cuMemGetInfo(&free, &total);
      LOG(ERROR) << "cuMemFreeAsync failed to free " << ptr << ": "
                 << GetCudaErrorMessage(result)
                 << "\n Free memory/Total memory: " << free << "/" << total;
      if (auto stats = GetStats())
        LOG(ERROR) << "Stats: " << stats->DebugString();
    }
  }

  // Update the stats.
  if (stats_) {
    mutex_lock lock(lock_);
    DCHECK(size_map_.contains(ptr));
    size_t size = size_map_[ptr];
    stats_->bytes_in_use -= size;
    size_map_.erase(ptr);
  }

  VLOG(10) << Name() << " Freed ptr: " << ptr;
#endif  // TF_CUDA_MALLOC_ASYNC_SUPPORTED
}

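// The accessors below only report meaningful values when the allocator was
// constructed with compute_stats=true (i.e. stats_ is set); otherwise they
// report that no sizes are tracked.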
bool GpuCudaMallocAsyncAllocator::TracksAllocationSizes() const {
  return static_cast<bool>(stats_);
}

size_t GpuCudaMallocAsyncAllocator::RequestedSize(const void* ptr) const {
  if (!stats_ || !ptr) return 0;
  mutex_lock l(lock_);
  return size_map_.at(ptr);
}

size_t GpuCudaMallocAsyncAllocator::AllocatedSize(const void* ptr) const {
  if (!stats_ || !ptr) return 0;
  mutex_lock l(lock_);
  return size_map_.at(ptr);
}

absl::optional<AllocatorStats> GpuCudaMallocAsyncAllocator::GetStats() {
  if (!stats_) return absl::nullopt;
  mutex_lock l(lock_);
  return *stats_;
}

bool GpuCudaMallocAsyncAllocator::ClearStats() {
  if (!stats_) return false;
  mutex_lock l(lock_);
  stats_->num_allocs = 0;
  stats_->peak_bytes_in_use = stats_->bytes_in_use;
  stats_->largest_alloc_size = 0;
  return true;
}

}  // namespace tensorflow