• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h"
17 
18 #include <cstring>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
23 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
24 #include "tensorflow/core/common_runtime/device/device_id.h"
25 #include "tensorflow/core/common_runtime/device/device_id_manager.h"
26 #include "tensorflow/core/common_runtime/device/device_id_utils.h"
27 #include "tensorflow/core/common_runtime/device_factory.h"
28 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h"
29 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h"
30 #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h"
31 #include "tensorflow/core/common_runtime/pool_allocator.h"
32 #include "tensorflow/core/common_runtime/shared_counter.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/log_memory.h"
35 #include "tensorflow/core/framework/tracking_allocator.h"
36 #include "tensorflow/core/lib/strings/strcat.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/mutex.h"
39 #include "tensorflow/core/platform/stream_executor.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/util/env_var.h"
42 
43 namespace tensorflow {
44 
singleton(const string & device_type,const string & platform_name)45 /*static*/ PluggableDeviceProcessState* PluggableDeviceProcessState::singleton(
46     const string& device_type, const string& platform_name) {
47   using ProcessStateMap =
48       std::unordered_map<string, PluggableDeviceProcessState*>;
49   static ProcessStateMap* process_state_map = new ProcessStateMap;
50   auto iter = process_state_map->find(platform_name);
51   if (iter != process_state_map->end()) {
52     return iter->second;
53   }
54   (*process_state_map)[platform_name] =
55       new PluggableDeviceProcessState(device_type, platform_name);
56   return (*process_state_map)[platform_name];
57 }
58 
// Constructs the per-platform state. `device_type` and `platform_name`
// identify the pluggable device plugin this state serves; devices start
// disabled until the runtime enables them. Also caches the process-wide
// CPU ProcessState, used as the host-allocator fallback when no pluggable
// device is present.
PluggableDeviceProcessState::PluggableDeviceProcessState(
    const string& device_type, const string& platform_name)
    : pluggable_device_enabled_(false),
      device_type_(device_type),
      platform_name_(platform_name) {
  process_state_ = ProcessState::singleton();
}
66 
BusIdForPluggableDevice(TfDeviceId tf_device_id)67 int PluggableDeviceProcessState::BusIdForPluggableDevice(
68     TfDeviceId tf_device_id) {
69   // Return the NUMA node associated with the PluggableDevice's StreamExecutor.
70   se::Platform* platform = PluggableDeviceMachineManager(platform_name_);
71   se::StreamExecutor* se = DeviceIdUtil::ExecutorForTfDeviceId(
72                                DeviceType(device_type_), platform, tf_device_id)
73                                .ValueOrDie();
74   int numa_node = se->GetDeviceDescription().numa_node();
75   // `bus_id` must be non-negative. If the `numa_node` is unknown, use 0.
76   return numa_node >= 0 ? numa_node : 0;
77 }
78 
// Returns (creating on first request, then caching) the device-memory
// allocator for `tf_device_id`. `total_bytes` caps the BFC allocator's pool
// when the platform opts into BFC. Returns nullptr if `options` names a
// custom allocator_type (only the default is supported). The returned
// pointer is owned by this object and remains valid for the process
// lifetime. Thread-safe: creation and lookup happen under `mu_`.
Allocator* PluggableDeviceProcessState::GetPluggableDeviceAllocator(
    const GPUOptions& options, TfDeviceId tf_device_id, size_t total_bytes) {
  DCHECK(process_state_);
  const string& allocator_type = options.allocator_type();
  se::Platform* platform = PluggableDeviceMachineManager(platform_name_);
  mutex_lock lock(mu_);
  DeviceIdUtil::CheckValidTfDeviceId(DeviceType(device_type_), platform,
                                     tf_device_id);

  // Grow the per-device cache so tf_device_id has a slot.
  if (tf_device_id.value() >=
      static_cast<int64_t>(pluggable_device_allocators_.size())) {
    pluggable_device_allocators_.resize(tf_device_id.value() + 1);
  }

  AllocatorParts& allocator_parts =
      pluggable_device_allocators_[tf_device_id.value()];
  if (allocator_parts.allocator == nullptr) {
    // Only the default allocator type is supported for pluggable devices.
    if (!allocator_type.empty()) {
      LOG(ERROR) << "Invalid allocator type: " << allocator_type;
      return nullptr;
    }

    PlatformDeviceId platform_device_id;
    TF_CHECK_OK(DeviceIdManager::TfToPlatformDeviceId(
        DeviceType(device_type_), tf_device_id, &platform_device_id));

    // Visitors are registered per bus (NUMA node); make sure the vector
    // covers this bus. Note: signed `bus_id` vs unsigned size() comparison —
    // safe here because bus_id is guaranteed non-negative.
    int bus_id = BusIdForPluggableDevice(tf_device_id);
    DCHECK_GE(bus_id, 0);
    while (bus_id >= pluggable_device_visitors_.size()) {
      pluggable_device_visitors_.push_back({});
    }

    // A memory fraction > 1.0 implies oversubscription, which requires
    // unified memory; it can also be requested explicitly.
    bool use_unified_memory = options.per_process_gpu_memory_fraction() > 1.0 ||
                              options.experimental().use_unified_memory();
    DeviceMemAllocator* sub_allocator = new DeviceMemAllocator(
        DeviceIdUtil::ExecutorForPlatformDeviceId(platform, platform_device_id)
            .ValueOrDie(),
        platform_device_id, use_unified_memory,
        pluggable_device_visitors_[bus_id], {});

    Allocator* device_allocator = nullptr;
    // The plugin's CPlatform decides BFC vs. simple allocation; any other
    // platform type is a wiring error, so crash deliberately.
    auto cplatform = dynamic_cast<se::CPlatform*>(platform);
    if (cplatform == nullptr) {
      LOG(FATAL) << "PluggableDevice's platform must be of type "  // Crash OK
                 << "stream_executor::CPlatform";
    }
    if (cplatform->UseBfcAllocator()) {
      device_allocator = new PluggableDeviceBFCAllocator(
          sub_allocator, total_bytes, options,
          strings::StrCat("PluggableDevice_", tf_device_id.value(), "_bfc"),
          cplatform->ForceMemoryGrowth());
    } else {
      device_allocator = new PluggableDeviceSimpleAllocator(sub_allocator);
    }

    // The unique_ptr owns the allocator; the raw pointers are non-owning
    // views kept for later inspection.
    allocator_parts = {std::unique_ptr<Allocator>(device_allocator),
                       device_allocator, sub_allocator};
  }
  return allocator_parts.allocator.get();
}
139 
// Returns an allocator for host (CPU) memory that the pluggable device can
// access (e.g. pinned memory). Falls back to the plain CPU allocator when no
// pluggable device exists. Allocators are created lazily per NUMA node and
// cached for the process lifetime.
Allocator* PluggableDeviceProcessState::GetPluggableDeviceHostAllocator(
    int numa_node) {
  DCHECK(process_state_);
  if (!HasPluggableDevice()) {
    return process_state_->GetCPUAllocator(numa_node);
  }
  // Treat "no affinity" as node 0 so it indexes the vectors below.
  if (numa_node == port::kNUMANoAffinity) {
    numa_node = 0;
  }
  {
    // Here we optimize the most common use case where
    // pluggable_device_host_allocators_ have already been populated and since
    // we're only reading these vectors, we can get by with a shared lock. In
    // the slower case, we take a unique lock and populate these vectors.
    tf_shared_lock lock(mu_);
    if (static_cast<int>(pluggable_device_host_allocators_.size()) >
        numa_node) {
      // NOTE(review): returns element 0 regardless of numa_node, although the
      // size check above is per-node — looks intentional (host allocators may
      // be interchangeable), but confirm; the slow path's return below does
      // the same.
      return pluggable_device_host_allocators_[0].allocator.get();
    }
  }

  mutex_lock lock(mu_);
  // Find the first valid StreamExecutor to request PluggableDevice host memory
  // through, since any will work.
  se::Platform* platform = PluggableDeviceMachineManager(platform_name_);
  se::StreamExecutor* se = nullptr;
  for (int i = 0; i < static_cast<int>(pluggable_device_allocators_.size());
       ++i) {
    if (pluggable_device_allocators_[i].allocator != nullptr) {
      se = DeviceIdUtil::ExecutorForTfDeviceId(DeviceType(device_type_),
                                               platform, TfDeviceId(i))
               .ValueOrDie();
      break;
    }
  }

  // A device allocator must have been created first (see
  // GetPluggableDeviceAllocator); otherwise `se` stays null.
  DCHECK_NE(nullptr, se);

  // Populate one host allocator per NUMA node up to and including the
  // requested one, together with its (currently empty) visitor lists.
  while (static_cast<int>(pluggable_device_host_allocators_.size()) <=
         numa_node) {
    while (pluggable_device_host_alloc_visitors_.size() <= numa_node) {
      pluggable_device_host_alloc_visitors_.push_back({});
    }
    while (pluggable_device_host_free_visitors_.size() <= numa_node) {
      pluggable_device_host_free_visitors_.push_back({});
    }
    SubAllocator* sub_allocator = new DeviceHostAllocator(
        se, numa_node, pluggable_device_host_alloc_visitors_[numa_node],
        pluggable_device_host_free_visitors_[numa_node]);
    // Pool size is overridable via the same env var the GPU runtime uses;
    // the default is 64GB (value is in MB).
    int64_t pluggable_device_host_mem_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB",
                                        1LL << 16 /*64GB max by default*/,
                                        &pluggable_device_host_mem_limit_in_mb);
    if (!status.ok()) {
      LOG(ERROR) << "GetPluggableDeviceHostAllocator: "
                 << status.error_message();
    }
    // MB -> bytes. (Assumes the env value is small enough not to overflow
    // the shift.)
    int64_t pluggable_device_host_mem_limit =
        pluggable_device_host_mem_limit_in_mb << 20;

    BFCAllocator::Options allocator_opts;
    allocator_opts.allow_growth = true;
    Allocator* allocator = new BFCAllocator(
        absl::WrapUnique(sub_allocator), pluggable_device_host_mem_limit,
        /*name=*/"pluggable_device_host_bfc", allocator_opts);

    if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
      // Wrap the allocator to track allocation ids for better logging
      // at the cost of performance.
      allocator = new TrackingAllocator(allocator, true);
    }
    pluggable_device_host_allocators_.push_back(
        {std::unique_ptr<Allocator>(allocator), nullptr /*bfc_allocator*/,
         sub_allocator});
  }
  return pluggable_device_host_allocators_[0].allocator.get();
}
217 
218 }  // namespace tensorflow
219