1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
17
18 #include <cstring>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
22 #include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"
23 #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
24 #include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
25 #include "tensorflow/core/common_runtime/gpu/gpu_id.h"
26 #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
27 #include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
28 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
29 #include "tensorflow/core/common_runtime/pool_allocator.h"
30 #include "tensorflow/core/common_runtime/shared_counter.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/log_memory.h"
33 #include "tensorflow/core/framework/tracking_allocator.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/platform/stream_executor.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/util/env_var.h"
40
41 namespace tensorflow {
namespace {

// Environment variable that selects a debugging GPU allocator wrapper.
constexpr char kGpuAllocatorEnvVar[] = "TF_GPU_ALLOCATOR";

// Returns true iff TF_GPU_ALLOCATOR is set and exactly equals `expected`.
// An unset variable never matches.
bool GpuAllocatorEnvVarEquals(const char* expected) {
  const char* value = std::getenv(kGpuAllocatorEnvVar);
  return value != nullptr && std::strcmp(value, expected) == 0;
}

// True when TF_GPU_ALLOCATOR=cuda_malloc: route allocations straight to
// cudaMalloc (useful with tools such as cuda-memcheck).
bool useCudaMallocAllocator() {
  return GpuAllocatorEnvVarEquals("cuda_malloc");
}

// True when TF_GPU_ALLOCATOR=memory_guard: wrap the GPU allocator with
// guard-pattern / NaN-reset debugging allocators.
bool useCudaMemoryGuardAllocator() {
  return GpuAllocatorEnvVarEquals("memory_guard");
}

}  // namespace
57
singleton(GPUProcessState * ps)58 /*static*/ GPUProcessState* GPUProcessState::singleton(GPUProcessState* ps) {
59 static GPUProcessState* instance = ps ? ps : new GPUProcessState;
60 DCHECK((!ps) || (ps == instance))
61 << "Multiple calls to GPUProcessState with non-null ps";
62 return instance;
63 }
64
GPUProcessState()65 GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
66 process_state_ = ProcessState::singleton();
67 }
68
BusIdForGPU(TfGpuId tf_gpu_id)69 int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
70 // Return the NUMA node associated with the GPU's StreamExecutor.
71 se::StreamExecutor* se =
72 GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
73 int numa_node = se->GetDeviceDescription().numa_node();
74 // bus_id must be non-negative. If the numa_node is not known,
75 // use 0.
76 return numa_node >= 0 ? numa_node : 0;
77 }
78
GetGPUAllocator(const GPUOptions & options,TfGpuId tf_gpu_id,size_t total_bytes)79 Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
80 TfGpuId tf_gpu_id,
81 size_t total_bytes) {
82 CHECK(process_state_);
83 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
84 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
85 const string& allocator_type = options.allocator_type();
86 mutex_lock lock(mu_);
87 GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
88
89 if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
90 gpu_allocators_.resize(tf_gpu_id.value() + 1);
91 }
92
93 AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
94 if (allocator_parts.allocator == nullptr) {
95 // Validate allocator types.
96 if (!allocator_type.empty() && allocator_type != "BFC") {
97 LOG(ERROR) << "Invalid allocator type: " << allocator_type;
98 return nullptr;
99 }
100
101 PlatformGpuId platform_gpu_id;
102 TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
103 int bus_id = BusIdForGPU(tf_gpu_id);
104 DCHECK_GE(bus_id, 0);
105 while (bus_id >= gpu_visitors_.size()) {
106 gpu_visitors_.push_back({});
107 }
108 GPUMemAllocator* sub_allocator = new GPUMemAllocator(
109 GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
110 platform_gpu_id,
111 (options.per_process_gpu_memory_fraction() > 1.0 ||
112 options.experimental().use_unified_memory()),
113 gpu_visitors_[bus_id], {});
114 GPUBFCAllocator* gpu_bfc_allocator =
115 new GPUBFCAllocator(sub_allocator, total_bytes, options,
116 strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
117 Allocator* gpu_allocator = gpu_bfc_allocator;
118 SharedCounter* timing_counter = nullptr;
119 if (options.experimental().timestamped_allocator()) {
120 timing_counter = new SharedCounter;
121 gpu_bfc_allocator->SetTimingCounter(timing_counter);
122 }
123
124 // If true, checks for memory overwrites by writing
125 // distinctive patterns on both ends of allocated memory.
126 if (useCudaMemoryGuardAllocator()) {
127 LOG(INFO) << "Using memory guard allocator for GPU.";
128 gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
129 gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
130 } else if (useCudaMallocAllocator()) {
131 LOG(INFO) << "Using CUDA malloc allocator for GPU.";
132 // If true, passes all allocation requests through to cudaMalloc
133 // useful for doing memory debugging with tools like cuda-memcheck
134 // **WARNING** probably will not work in a multi-gpu scenario
135 gpu_allocator =
136 new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
137 }
138
139 Allocator* recording_allocator = nullptr;
140 if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
141 ProcessState::MemDesc md;
142 md.loc = ProcessState::MemDesc::GPU;
143 md.dev_index = platform_gpu_id.value();
144 md.gpu_registered = false;
145 md.nic_registered = true;
146 recording_allocator = new internal::RecordingAllocator(
147 &process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
148 }
149 allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
150 std::unique_ptr<SharedCounter>(timing_counter),
151 gpu_bfc_allocator, sub_allocator,
152 std::unique_ptr<Allocator>(recording_allocator)};
153 }
154 if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
155 return allocator_parts.recording_allocator.get();
156 } else {
157 return allocator_parts.allocator.get();
158 }
159 #else
160 LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda or "
161 "--config=rocm.";
162 return nullptr;
163 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
164 }
165
GPUAllocatorCounter(TfGpuId tf_gpu_id)166 SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
167 DCHECK(process_state_);
168 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
169 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
170 GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
171 mutex_lock l(mu_);
172 if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
173 LOG(ERROR) << "Asked for counter for GPU allocator " << tf_gpu_id.value()
174 << " but only have " << gpu_allocators_.size();
175 return nullptr;
176 }
177
178 AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
179 if (allocator_parts.counter.get() == nullptr) {
180 SharedCounter* timing_counter = new SharedCounter;
181 allocator_parts.bfc_allocator->SetTimingCounter(timing_counter);
182 allocator_parts.counter.reset(timing_counter);
183 }
184 return allocator_parts.counter.get();
185 #else
186 return nullptr;
187 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
188 }
189
GetGpuHostAllocator(int numa_node)190 Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
191 CHECK(process_state_);
192 if (!HasGPUDevice() ||
193 !process_state_->ProcessState::FLAGS_brain_mem_reg_gpu_dma) {
194 return process_state_->GetCPUAllocator(numa_node);
195 }
196 if (numa_node == port::kNUMANoAffinity) {
197 numa_node = 0;
198 }
199 {
200 // Here we optimize the most common use case where gpu_host_allocators_
201 // have already been populated and since we're only reading
202 // these vectors, we can get by with a shared lock. In the slower case,
203 // we take a unique lock and populate these vectors.
204 tf_shared_lock lock(mu_);
205
206 if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
207 !gpu_host_allocators_.empty() &&
208 gpu_host_allocators_[0].recording_allocator != nullptr) {
209 return gpu_host_allocators_[0].recording_allocator.get();
210 }
211 if (static_cast<int>(gpu_host_allocators_.size()) > numa_node) {
212 return gpu_host_allocators_[0].allocator.get();
213 }
214 }
215
216 mutex_lock lock(mu_);
217 // Find the first valid StreamExecutor to request CUDA or ROCm host memory
218 // through, since any will work.
219 //
220 // This search isn't super clean, and it would be nice to use a
221 // better source of information about which executor to use. For
222 // example, process_state could maybe save the first stream executor
223 // it knows is valid.
224 se::StreamExecutor* se = nullptr;
225 for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
226 if (gpu_allocators_[i].allocator != nullptr) {
227 se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
228 break;
229 }
230 }
231
232 CHECK_NE(nullptr, se);
233
234 while (static_cast<int>(gpu_host_allocators_.size()) <= numa_node) {
235 while (gpu_host_alloc_visitors_.size() <= numa_node) {
236 gpu_host_alloc_visitors_.push_back({});
237 }
238 while (gpu_host_free_visitors_.size() <= numa_node) {
239 gpu_host_free_visitors_.push_back({});
240 }
241 SubAllocator* sub_allocator =
242 new GpuHostAllocator(se, numa_node, gpu_host_alloc_visitors_[numa_node],
243 gpu_host_free_visitors_[numa_node]);
244 // TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
245 int64 gpu_host_mem_limit_in_mb = -1;
246 Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB",
247 1LL << 16 /*64GB max by default*/,
248 &gpu_host_mem_limit_in_mb);
249 if (!status.ok()) {
250 LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
251 }
252 int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
253
254 Allocator* allocator =
255 new BFCAllocator(sub_allocator, gpu_host_mem_limit,
256 true /*allow_growth*/, "gpu_host_bfc" /*name*/);
257
258 if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
259 // Wrap the allocator to track allocation ids for better logging
260 // at the cost of performance.
261 allocator = new TrackingAllocator(allocator, true);
262 }
263 gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
264 std::unique_ptr<SharedCounter>(nullptr),
265 nullptr, sub_allocator,
266 std::unique_ptr<Allocator>(nullptr)});
267 AllocatorParts& allocator_parts = gpu_host_allocators_.back();
268 if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
269 ProcessState::MemDesc md;
270 md.loc = ProcessState::MemDesc::CPU;
271 md.dev_index = 0;
272 md.gpu_registered = true;
273 md.nic_registered = false;
274 allocator_parts.recording_allocator.reset(
275 new internal::RecordingAllocator(&process_state_->mem_desc_map_,
276 allocator_parts.allocator.get(), md,
277 &mu_));
278 }
279 }
280 if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
281 return gpu_host_allocators_[0].recording_allocator.get();
282 } else {
283 return gpu_host_allocators_[0].allocator.get();
284 }
285 }
286
AddGPUAllocVisitor(int bus_id,const SubAllocator::Visitor & visitor)287 void GPUProcessState::AddGPUAllocVisitor(int bus_id,
288 const SubAllocator::Visitor& visitor) {
289 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
290 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
291 mutex_lock lock(mu_);
292 CHECK(gpu_allocators_.empty()) // Crash OK
293 << "AddGPUAllocVisitor must be called before "
294 "first call to GetGPUAllocator.";
295 DCHECK_GE(bus_id, 0);
296 while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
297 gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
298 }
299 gpu_visitors_[bus_id].push_back(visitor);
300 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
301 }
302
AddGpuHostAllocVisitor(int numa_node,const SubAllocator::Visitor & visitor)303 void GPUProcessState::AddGpuHostAllocVisitor(
304 int numa_node, const SubAllocator::Visitor& visitor) {
305 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
306 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
307 mutex_lock lock(mu_);
308 CHECK(gpu_host_allocators_.empty()) // Crash OK
309 << "AddGpuHostAllocVisitor must be called before "
310 "first call to GetGpuHostAllocator.";
311 while (numa_node >= static_cast<int64>(gpu_host_alloc_visitors_.size())) {
312 gpu_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
313 }
314 gpu_host_alloc_visitors_[numa_node].push_back(visitor);
315 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
316 }
317
AddGpuHostFreeVisitor(int numa_node,const SubAllocator::Visitor & visitor)318 void GPUProcessState::AddGpuHostFreeVisitor(
319 int numa_node, const SubAllocator::Visitor& visitor) {
320 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
321 (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
322 mutex_lock lock(mu_);
323 CHECK(gpu_host_allocators_.empty()) // Crash OK
324 << "AddGpuHostFreeVisitor must be called before "
325 "first call to GetGpuHostAllocator.";
326 while (numa_node >= static_cast<int64>(gpu_host_free_visitors_.size())) {
327 gpu_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
328 }
329 gpu_host_free_visitors_[numa_node].push_back(visitor);
330 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
331 }
332
TestOnlyReset()333 void GPUProcessState::TestOnlyReset() {
334 if (process_state_) {
335 process_state_->ProcessState::TestOnlyReset();
336 }
337 {
338 mutex_lock lock(mu_);
339 gpu_device_enabled_ = false;
340 gpu_allocators_.clear();
341 gpu_visitors_.clear();
342 gpu_host_allocators_.clear();
343 gpu_host_alloc_visitors_.clear();
344 gpu_host_free_visitors_.clear();
345 }
346 }
347
348 } // namespace tensorflow
349