• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
17 
18 #include <cstring>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
23 #include "tensorflow/core/common_runtime/device/device_id_utils.h"
24 #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
25 #include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
27 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
28 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
29 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
30 #include "tensorflow/core/common_runtime/gpu/gpu_virtual_mem_allocator.h"
31 #include "tensorflow/core/common_runtime/pool_allocator.h"
32 #include "tensorflow/core/common_runtime/shared_counter.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/log_memory.h"
35 #include "tensorflow/core/framework/tracking_allocator.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/stream_executor.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/util/env_var.h"
42 
43 namespace tensorflow {
44 namespace {
45 
// Returns true iff the TF_GPU_ALLOCATOR environment variable is set to exactly
// "cuda_malloc". An unset variable, or any other value, yields false.
bool useCudaMallocAllocator() {
  const char* const requested = std::getenv("TF_GPU_ALLOCATOR");
  if (requested == nullptr) {
    return false;
  }
  return std::strcmp(requested, "cuda_malloc") == 0;
}
51 
// Returns true iff the TF_GPU_ALLOCATOR environment variable is set to exactly
// "memory_guard"; GetGPUAllocator then wraps the GPU allocator with the debug
// (overwrite-detecting) and NaN-reset allocators.
bool useCudaMemoryGuardAllocator() {
  const char* const requested = std::getenv("TF_GPU_ALLOCATOR");
  return requested == nullptr ? false
                              : std::strcmp(requested, "memory_guard") == 0;
}
57 
58 }  // namespace
59 
singleton(GPUProcessState * ps)60 /*static*/ GPUProcessState* GPUProcessState::singleton(GPUProcessState* ps) {
61   static GPUProcessState* instance = ps ? ps : new GPUProcessState;
62   DCHECK((!ps) || (ps == instance))
63       << "Multiple calls to GPUProcessState with non-null ps";
64   return instance;
65 }
66 
GPUProcessState()67 GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
68   process_state_ = ProcessState::singleton();
69 }
70 
BusIdForGPU(TfGpuId tf_gpu_id)71 int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
72   // Return the NUMA node associated with the GPU's StreamExecutor.
73   se::StreamExecutor* se = DeviceIdUtil::ExecutorForTfDeviceId(
74                                DEVICE_GPU, GPUMachineManager(), tf_gpu_id)
75                                .ValueOrDie();
76   int numa_node = se->GetDeviceDescription().numa_node();
77   // bus_id must be non-negative.  If the numa_node is not known,
78   // use 0.
79   return numa_node >= 0 ? numa_node : 0;
80 }
81 
82 // NOLINTNEXTLINE: clang-tidy complains this is unused because of build flags.
CreateSubAllocator(const GPUOptions & options,PlatformGpuId platform_gpu_id,const std::vector<SubAllocator::Visitor> & alloc_visitors,size_t total_bytes,const std::vector<TfGpuId> & peer_gpu_ids)83 static SubAllocator* CreateSubAllocator(
84     const GPUOptions& options, PlatformGpuId platform_gpu_id,
85     const std::vector<SubAllocator::Visitor>& alloc_visitors,
86     size_t total_bytes, const std::vector<TfGpuId>& peer_gpu_ids) {
87   auto executor = DeviceIdUtil::ExecutorForPlatformDeviceId(GPUMachineManager(),
88                                                             platform_gpu_id)
89                       .ValueOrDie();
90 
91   // FIXME(imintz): Observed OOM issues when using the virtual memory
92   // allocators. This should be reenabled when resolved.
93 #if 0 && defined(GOOGLE_CUDA) && CUDA_VERSION >= 10020
94   // Use the old allocator when unified memory is required.
95   // TODO(imintz): Remove the cuMemAlloc capability of this allocator.
96   if (options.per_process_gpu_memory_fraction() > 1.0 ||
97       options.experimental().use_unified_memory()) {
98     return new DeviceMemAllocator(executor, platform_gpu_id,
99                                   /*use_unified_memory=*/true, alloc_visitors,
100                                   {});
101   } else {
102     auto* gpu_context = reinterpret_cast<stream_executor::gpu::GpuContext*>(
103         executor->implementation()->GpuContextHack());
104 
105     absl::flat_hash_set<PlatformGpuId> platform_peer_gpu_ids;
106     platform_peer_gpu_ids.reserve(peer_gpu_ids.size());
107     for (const TfGpuId tf_gpu_id : peer_gpu_ids) {
108       PlatformGpuId platform_gpu_id;
109       TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
110       platform_peer_gpu_ids.insert(platform_gpu_id);
111     }
112     std::vector<PlatformGpuId> platform_peer_gpu_ids_vec(
113         platform_peer_gpu_ids.begin(), platform_peer_gpu_ids.end());
114 
115     // Adjust virtual address space to be slightly larger than the physical
116     // address space in case the BFC allocator performs suboptimal garbage
117     // collection.
118     // TODO(imintz): Update BFC allocator to ensure it doesn't create holes in
119     // the va space.
120     return GpuVirtualMemAllocator::Create(
121                alloc_visitors, {}, *gpu_context, platform_gpu_id,
122                /*virtual_address_space_size=*/total_bytes * 2,
123                platform_peer_gpu_ids_vec)
124         .ValueOrDie()
125         .release();
126   }
127 #else
128   return new DeviceMemAllocator(
129       executor, platform_gpu_id,
130       (options.per_process_gpu_memory_fraction() > 1.0 ||
131        options.experimental().use_unified_memory()),
132       alloc_visitors, {});
133 #endif
134 }
135 
GetGPUAllocator(const GPUOptions & options,TfGpuId tf_gpu_id,size_t total_bytes,const std::vector<TfGpuId> & peer_gpu_ids)136 Allocator* GPUProcessState::GetGPUAllocator(
137     const GPUOptions& options, TfGpuId tf_gpu_id, size_t total_bytes,
138     const std::vector<TfGpuId>& peer_gpu_ids) {
139   CHECK(process_state_);
140 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
141     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
142   const string& allocator_type = options.allocator_type();
143   mutex_lock lock(mu_);
144   DeviceIdUtil::CheckValidTfDeviceId(DEVICE_GPU, GPUMachineManager(),
145                                      tf_gpu_id);
146 
147   if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
148     gpu_allocators_.resize(tf_gpu_id.value() + 1);
149   }
150 
151   AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
152   if (allocator_parts.allocator == nullptr) {
153     // Validate allocator types.
154     if (!allocator_type.empty() && allocator_type != "BFC") {
155       LOG(ERROR) << "Invalid allocator type: " << allocator_type;
156       return nullptr;
157     }
158 
159     PlatformGpuId platform_gpu_id;
160     TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
161     int bus_id = BusIdForGPU(tf_gpu_id);
162     DCHECK_GE(bus_id, 0);
163     while (bus_id >= gpu_visitors_.size()) {
164       gpu_visitors_.push_back({});
165     }
166     auto* sub_allocator =
167         CreateSubAllocator(options, platform_gpu_id, gpu_visitors_[bus_id],
168                            total_bytes, peer_gpu_ids);
169     GPUBFCAllocator* gpu_bfc_allocator =
170         new GPUBFCAllocator(sub_allocator, total_bytes, options,
171                             strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
172     Allocator* gpu_allocator = gpu_bfc_allocator;
173     SharedCounter* timing_counter = nullptr;
174     if (options.experimental().timestamped_allocator()) {
175       timing_counter = new SharedCounter;
176       gpu_bfc_allocator->SetTimingCounter(timing_counter);
177     }
178 
179     // If true, checks for memory overwrites by writing
180     // distinctive patterns on both ends of allocated memory.
181     if (useCudaMemoryGuardAllocator()) {
182       LOG(INFO) << "Using memory guard allocator for GPU.";
183       gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
184       gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
185     } else if (useCudaMallocAllocator()) {
186       LOG(INFO) << "Using CUDA malloc allocator for GPU.";
187       // If true, passes all allocation requests through to cudaMalloc
188       // useful for doing memory debugging with tools like cuda-memcheck
189       // **WARNING** probably will not work in a multi-gpu scenario
190       gpu_allocator =
191           new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
192     }
193 
194     Allocator* recording_allocator = nullptr;
195     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
196       ProcessState::MemDesc md;
197       md.loc = ProcessState::MemDesc::GPU;
198       md.dev_index = platform_gpu_id.value();
199       md.gpu_registered = false;
200       md.nic_registered = true;
201       recording_allocator = new internal::RecordingAllocator(
202           &process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
203     }
204     allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
205                        std::unique_ptr<SharedCounter>(timing_counter),
206                        gpu_bfc_allocator, sub_allocator,
207                        std::unique_ptr<Allocator>(recording_allocator)};
208   }
209   if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
210     return allocator_parts.recording_allocator.get();
211   } else {
212     return allocator_parts.allocator.get();
213   }
214 #else
215   LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda or "
216                 "--config=rocm.";
217   return nullptr;
218 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
219 }
220 
GPUAllocatorCounter(TfGpuId tf_gpu_id)221 SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
222   DCHECK(process_state_);
223 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
224     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
225   DeviceIdUtil::CheckValidTfDeviceId(DEVICE_GPU, GPUMachineManager(),
226                                      tf_gpu_id);
227   mutex_lock l(mu_);
228   if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
229     LOG(ERROR) << "Asked for counter for GPU allocator " << tf_gpu_id.value()
230                << " but only have " << gpu_allocators_.size();
231     return nullptr;
232   }
233 
234   AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
235   if (allocator_parts.counter.get() == nullptr) {
236     SharedCounter* timing_counter = new SharedCounter;
237     allocator_parts.bfc_allocator->SetTimingCounter(timing_counter);
238     allocator_parts.counter.reset(timing_counter);
239   }
240   return allocator_parts.counter.get();
241 #else
242   return nullptr;
243 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
244 }
245 
GetGpuHostAllocator(int numa_node)246 Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
247   CHECK(process_state_);
248   if (!HasGPUDevice() ||
249       !process_state_->ProcessState::FLAGS_brain_mem_reg_gpu_dma) {
250     return process_state_->GetCPUAllocator(numa_node);
251   }
252   if (numa_node == port::kNUMANoAffinity) {
253     numa_node = 0;
254   }
255   {
256     // Here we optimize the most common use case where gpu_host_allocators_
257     // have already been populated and since we're only reading
258     // these vectors, we can get by with a shared lock. In the slower case,
259     // we take a unique lock and populate these vectors.
260     tf_shared_lock lock(mu_);
261 
262     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
263         !gpu_host_allocators_.empty() &&
264         gpu_host_allocators_[0].recording_allocator != nullptr) {
265       return gpu_host_allocators_[0].recording_allocator.get();
266     }
267     if (static_cast<int>(gpu_host_allocators_.size()) > numa_node) {
268       return gpu_host_allocators_[0].allocator.get();
269     }
270   }
271 
272   mutex_lock lock(mu_);
273   // Find the first valid StreamExecutor to request CUDA or ROCm host memory
274   // through, since any will work.
275   //
276   // This search isn't super clean, and it would be nice to use a
277   // better source of information about which executor to use.  For
278   // example, process_state could maybe save the first stream executor
279   // it knows is valid.
280   se::StreamExecutor* se = nullptr;
281   for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
282     if (gpu_allocators_[i].allocator != nullptr) {
283       se = DeviceIdUtil::ExecutorForTfDeviceId(DEVICE_GPU, GPUMachineManager(),
284                                                TfGpuId(i))
285                .ValueOrDie();
286       break;
287     }
288   }
289 
290   CHECK_NE(nullptr, se);
291 
292   while (static_cast<int>(gpu_host_allocators_.size()) <= numa_node) {
293     while (gpu_host_alloc_visitors_.size() <= numa_node) {
294       gpu_host_alloc_visitors_.push_back({});
295     }
296     while (gpu_host_free_visitors_.size() <= numa_node) {
297       gpu_host_free_visitors_.push_back({});
298     }
299     SubAllocator* sub_allocator = new DeviceHostAllocator(
300         se, numa_node, gpu_host_alloc_visitors_[numa_node],
301         gpu_host_free_visitors_[numa_node]);
302     // TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
303     int64 gpu_host_mem_limit_in_mb = -1;
304     Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB",
305                                         1LL << 16 /*64GB max by default*/,
306                                         &gpu_host_mem_limit_in_mb);
307     if (!status.ok()) {
308       LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
309     }
310     int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
311 
312     Allocator* allocator =
313         new BFCAllocator(sub_allocator, gpu_host_mem_limit,
314                          /*allow_growth=*/true, /*name=*/"gpu_host_bfc");
315 
316     if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
317       // Wrap the allocator to track allocation ids for better logging
318       // at the cost of performance.
319       allocator = new TrackingAllocator(allocator, true);
320     }
321     gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
322                                     std::unique_ptr<SharedCounter>(nullptr),
323                                     nullptr, sub_allocator,
324                                     std::unique_ptr<Allocator>(nullptr)});
325     AllocatorParts& allocator_parts = gpu_host_allocators_.back();
326     if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
327       ProcessState::MemDesc md;
328       md.loc = ProcessState::MemDesc::CPU;
329       md.dev_index = 0;
330       md.gpu_registered = true;
331       md.nic_registered = false;
332       allocator_parts.recording_allocator.reset(
333           new internal::RecordingAllocator(&process_state_->mem_desc_map_,
334                                            allocator_parts.allocator.get(), md,
335                                            &mu_));
336     }
337   }
338   if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
339     return gpu_host_allocators_[0].recording_allocator.get();
340   } else {
341     return gpu_host_allocators_[0].allocator.get();
342   }
343 }
344 
AddGPUAllocVisitor(int bus_id,const SubAllocator::Visitor & visitor)345 void GPUProcessState::AddGPUAllocVisitor(int bus_id,
346                                          const SubAllocator::Visitor& visitor) {
347 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
348     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
349   mutex_lock lock(mu_);
350   CHECK(gpu_allocators_.empty())  // Crash OK
351       << "AddGPUAllocVisitor must be called before "
352          "first call to GetGPUAllocator.";
353   DCHECK_GE(bus_id, 0);
354   while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
355     gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
356   }
357   gpu_visitors_[bus_id].push_back(visitor);
358 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
359 }
360 
AddGpuHostAllocVisitor(int numa_node,const SubAllocator::Visitor & visitor)361 void GPUProcessState::AddGpuHostAllocVisitor(
362     int numa_node, const SubAllocator::Visitor& visitor) {
363 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
364     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
365   mutex_lock lock(mu_);
366   CHECK(gpu_host_allocators_.empty())  // Crash OK
367       << "AddGpuHostAllocVisitor must be called before "
368          "first call to GetGpuHostAllocator.";
369   while (numa_node >= static_cast<int64>(gpu_host_alloc_visitors_.size())) {
370     gpu_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
371   }
372   gpu_host_alloc_visitors_[numa_node].push_back(visitor);
373 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
374 }
375 
AddGpuHostFreeVisitor(int numa_node,const SubAllocator::Visitor & visitor)376 void GPUProcessState::AddGpuHostFreeVisitor(
377     int numa_node, const SubAllocator::Visitor& visitor) {
378 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
379     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
380   mutex_lock lock(mu_);
381   CHECK(gpu_host_allocators_.empty())  // Crash OK
382       << "AddGpuHostFreeVisitor must be called before "
383          "first call to GetGpuHostAllocator.";
384   while (numa_node >= static_cast<int64>(gpu_host_free_visitors_.size())) {
385     gpu_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
386   }
387   gpu_host_free_visitors_[numa_node].push_back(visitor);
388 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
389 }
390 
TestOnlyReset()391 void GPUProcessState::TestOnlyReset() {
392   if (process_state_) {
393     process_state_->ProcessState::TestOnlyReset();
394   }
395   {
396     mutex_lock lock(mu_);
397     gpu_device_enabled_ = false;
398     gpu_allocators_.clear();
399     gpu_visitors_.clear();
400     gpu_host_allocators_.clear();
401     gpu_host_alloc_visitors_.clear();
402     gpu_host_free_visitors_.clear();
403   }
404 }
405 
406 }  // namespace tensorflow
407