/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"

#include <cstring>
#include <vector>

#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_host_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/common_runtime/shared_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace {

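// Both helpers below read the TF_GPU_ALLOCATOR environment variable to decide
// whether a debugging allocator should wrap the default BFC allocator.
// Illustrative usage:
//   TF_GPU_ALLOCATOR=cuda_malloc   -> route allocations through cudaMalloc
//   TF_GPU_ALLOCATOR=memory_guard  -> guard allocations against overwrites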
bool useCudaMallocAllocator() {
  const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR");
  return debug_allocator_str != nullptr &&
         std::strcmp(debug_allocator_str, "cuda_malloc") == 0;
}

bool useCudaMemoryGuardAllocator() {
  const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR");
  return debug_allocator_str != nullptr &&
         std::strcmp(debug_allocator_str, "memory_guard") == 0;
}

}  // namespace

/*static*/ GPUProcessState* GPUProcessState::singleton(GPUProcessState* ps) {
  static GPUProcessState* instance = ps ? ps : new GPUProcessState;
  DCHECK((!ps) || (ps == instance))
      << "Multiple calls to GPUProcessState with non-null ps";
  return instance;
}
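
// Illustrative access pattern (a sketch; assumes the default-null `ps`
// parameter declared in gpu_process_state.h):
//   GPUProcessState* gps = GPUProcessState::singleton();
//   Allocator* gpu_alloc =
//       gps->GetGPUAllocator(gpu_options, TfGpuId(0), /*total_bytes=*/1 << 30);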

GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
  process_state_ = ProcessState::singleton();
}

int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
  // Return the NUMA node associated with the GPU's StreamExecutor.
  se::StreamExecutor* se =
      GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
  int numa_node = se->GetDeviceDescription().numa_node();
  // bus_id must be non-negative.  If the numa_node is not known, use 0.
  return numa_node >= 0 ? numa_node : 0;
}

Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
                                            TfGpuId tf_gpu_id,
                                            size_t total_bytes) {
  CHECK(process_state_);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  const string& allocator_type = options.allocator_type();
  mutex_lock lock(mu_);
  GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);

  if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
    gpu_allocators_.resize(tf_gpu_id.value() + 1);
  }

  AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
  if (allocator_parts.allocator == nullptr) {
    // Validate allocator types.
    if (!allocator_type.empty() && allocator_type != "BFC") {
      LOG(ERROR) << "Invalid allocator type: " << allocator_type;
      return nullptr;
    }

    PlatformGpuId platform_gpu_id;
    TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
    int bus_id = BusIdForGPU(tf_gpu_id);
    DCHECK_GE(bus_id, 0);
    while (bus_id >= gpu_visitors_.size()) {
      gpu_visitors_.push_back({});
    }
    GPUMemAllocator* sub_allocator = new GPUMemAllocator(
        GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
        platform_gpu_id,
        (options.per_process_gpu_memory_fraction() > 1.0 ||
         options.experimental().use_unified_memory()),
        gpu_visitors_[bus_id], {});
    GPUBFCAllocator* gpu_bfc_allocator =
        new GPUBFCAllocator(sub_allocator, total_bytes, options,
                            strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
    Allocator* gpu_allocator = gpu_bfc_allocator;
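
    // If requested, attach a SharedCounter so the BFC allocator can stamp
    // allocations with a logical timestamp; the same counter is exposed via
    // GPUAllocatorCounter() below.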
    SharedCounter* timing_counter = nullptr;
    if (options.experimental().timestamped_allocator()) {
      timing_counter = new SharedCounter;
      gpu_bfc_allocator->SetTimingCounter(timing_counter);
    }

    // When the memory-guard allocator is requested, check for memory
    // overwrites by writing distinctive patterns on both ends of each
    // allocation, and reset memory to NaN to help catch reads of
    // uninitialized or freed data.
    if (useCudaMemoryGuardAllocator()) {
      gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
      gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
    } else if (useCudaMallocAllocator()) {
      // When the cuda_malloc allocator is requested, pass all allocation
      // requests straight through to cudaMalloc. This is useful for memory
      // debugging with tools like cuda-memcheck.
      // **WARNING** This probably will not work in a multi-GPU scenario.
      gpu_allocator =
          new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
    }

    Allocator* recording_allocator = nullptr;
    if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
      ProcessState::MemDesc md;
      md.loc = ProcessState::MemDesc::GPU;
      md.dev_index = platform_gpu_id.value();
      md.gpu_registered = false;
      md.nic_registered = true;
      recording_allocator = new internal::RecordingAllocator(
          &process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
    }
    allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
                       std::unique_ptr<SharedCounter>(timing_counter),
                       sub_allocator,
                       std::unique_ptr<Allocator>(recording_allocator)};
  }
  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
    return allocator_parts.recording_allocator.get();
  } else {
    return allocator_parts.allocator.get();
  }
#else
  LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda or "
                "--config=rocm.";
  return nullptr;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}

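// Returns the SharedCounter installed on the GPU BFC allocator when the
// experimental timestamped_allocator option was set, or nullptr otherwise.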
SharedCounter* GPUProcessState::GPUAllocatorCounter(TfGpuId tf_gpu_id) {
  DCHECK(process_state_);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
  mutex_lock l(mu_);
  if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
    return nullptr;
  }

  AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
  return allocator_parts.counter.get();
#else
  return nullptr;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}

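// Returns the allocator used for CUDA/ROCm pinned host ("gpu host") memory,
// which backs staging buffers for DMA between host and device. Falls back to
// the plain CPU allocator when no GPU is present or host registration for GPU
// DMA is disabled.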
Allocator* GPUProcessState::GetGpuHostAllocator(int numa_node) {
  CHECK(process_state_);
  if (!HasGPUDevice() ||
      !process_state_->ProcessState::FLAGS_brain_mem_reg_gpu_dma) {
    return process_state_->GetCPUAllocator(numa_node);
  }
  if (numa_node == port::kNUMANoAffinity) {
    numa_node = 0;
  }
  {
    // Here we optimize the most common case, where gpu_host_allocators_ has
    // already been populated. Since we're only reading these vectors, a
    // shared lock is enough. In the slower path below, we take a unique lock
    // and populate them.
    tf_shared_lock lock(mu_);

    if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
        !gpu_host_allocators_.empty() &&
        gpu_host_allocators_[0].recording_allocator != nullptr) {
      return gpu_host_allocators_[0].recording_allocator.get();
    }
    if (static_cast<int>(gpu_host_allocators_.size()) > numa_node) {
      return gpu_host_allocators_[0].allocator.get();
    }
  }

  mutex_lock lock(mu_);
  // Find the first valid StreamExecutor to request CUDA or ROCm host memory
  // through, since any will work.
  //
  // This search isn't super clean, and it would be nice to use a
  // better source of information about which executor to use.  For
  // example, process_state could maybe save the first stream executor
  // it knows is valid.
  se::StreamExecutor* se = nullptr;
  for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
    if (gpu_allocators_[i].allocator != nullptr) {
      se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
      break;
    }
  }

  CHECK_NE(nullptr, se);

  while (static_cast<int>(gpu_host_allocators_.size()) <= numa_node) {
    while (gpu_host_alloc_visitors_.size() <= numa_node) {
      gpu_host_alloc_visitors_.push_back({});
    }
    while (gpu_host_free_visitors_.size() <= numa_node) {
      gpu_host_free_visitors_.push_back({});
    }
    SubAllocator* sub_allocator =
        new GpuHostAllocator(se, numa_node, gpu_host_alloc_visitors_[numa_node],
                             gpu_host_free_visitors_[numa_node]);
    // TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
    int64 gpu_host_mem_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB",
                                        1LL << 16 /*64GB max by default*/,
                                        &gpu_host_mem_limit_in_mb);
    if (!status.ok()) {
      LOG(ERROR) << "GetGpuHostAllocator: " << status.error_message();
    }
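    // Convert the MB limit to bytes. For example, setting
    // TF_GPU_HOST_MEM_LIMIT_IN_MB=8192 caps the pinned host pool at 8 GiB
    // (illustrative value).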
    int64 gpu_host_mem_limit = gpu_host_mem_limit_in_mb * (1LL << 20);
    Allocator* allocator =
        new BFCAllocator(sub_allocator, gpu_host_mem_limit,
                         true /*allow_growth*/, "gpu_host_bfc" /*name*/);

    if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
      // Wrap the allocator to track allocation ids for better logging
      // at the cost of performance.
      allocator = new TrackingAllocator(allocator, true);
    }
    gpu_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
                                    std::unique_ptr<SharedCounter>(nullptr),
                                    sub_allocator,
                                    std::unique_ptr<Allocator>(nullptr)});
    AllocatorParts& allocator_parts = gpu_host_allocators_.back();
    if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
      ProcessState::MemDesc md;
      md.loc = ProcessState::MemDesc::CPU;
      md.dev_index = 0;
      md.gpu_registered = true;
      md.nic_registered = false;
      allocator_parts.recording_allocator.reset(
          new internal::RecordingAllocator(&process_state_->mem_desc_map_,
                                           allocator_parts.allocator.get(), md,
                                           &mu_));
    }
  }
  if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
    return gpu_host_allocators_[0].recording_allocator.get();
  } else {
    return gpu_host_allocators_[0].allocator.get();
  }
}

void GPUProcessState::AddGPUAllocVisitor(int bus_id,
                                         const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  mutex_lock lock(mu_);
  CHECK(gpu_allocators_.empty())  // Crash OK
      << "AddGPUAllocVisitor must be called before "
         "first call to GetGPUAllocator.";
  DCHECK_GE(bus_id, 0);
  while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
    gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
  }
  gpu_visitors_[bus_id].push_back(visitor);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
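
// Illustrative registration of an alloc visitor (a sketch; assumes
// SubAllocator::Visitor is invocable as (void* ptr, int index,
// size_t num_bytes), per tensorflow/core/framework/allocator.h):
//   GPUProcessState::singleton()->AddGPUAllocVisitor(
//       /*bus_id=*/0, [](void* ptr, int index, size_t num_bytes) {
//         VLOG(1) << "GPU region of " << num_bytes << " bytes at " << ptr;
//       });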

void GPUProcessState::AddGpuHostAllocVisitor(
    int numa_node, const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  mutex_lock lock(mu_);
  CHECK(gpu_host_allocators_.empty())  // Crash OK
      << "AddGpuHostAllocVisitor must be called before "
         "first call to GetGpuHostAllocator.";
  while (numa_node >= static_cast<int64>(gpu_host_alloc_visitors_.size())) {
    gpu_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
  }
  gpu_host_alloc_visitors_[numa_node].push_back(visitor);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}

void GPUProcessState::AddGpuHostFreeVisitor(
    int numa_node, const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  mutex_lock lock(mu_);
  CHECK(gpu_host_allocators_.empty())  // Crash OK
      << "AddGpuHostFreeVisitor must be called before "
         "first call to GetGpuHostAllocator.";
  while (numa_node >= static_cast<int64>(gpu_host_free_visitors_.size())) {
    gpu_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
  }
  gpu_host_free_visitors_[numa_node].push_back(visitor);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}

void GPUProcessState::TestOnlyReset() {
  if (process_state_) {
    process_state_->ProcessState::TestOnlyReset();
  }
  {
    mutex_lock lock(mu_);
    gpu_device_enabled_ = false;
    gpu_allocators_.clear();
    gpu_visitors_.clear();
    gpu_host_allocators_.clear();
    gpu_host_alloc_visitors_.clear();
    gpu_host_free_visitors_.clear();
  }
}

}  // namespace tensorflow