• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/gpu_device.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
20 
21 #ifdef NCCL_ENABLED
22 #include "third_party/nccl/nccl.h"
23 #endif  // NCCL_ENABLED
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/core/common_runtime/device/device_host_allocator.h"
30 #include "tensorflow/core/common_runtime/device/device_id.h"
31 #include "tensorflow/core/common_runtime/device/device_mem_allocator.h"
32 #include "tensorflow/core/util/env_var.h"
33 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
34 
35 namespace xla {
36 namespace {
37 
38 // A custom PjRtClient that overrides the device assignment method.
39 class GpuClient : public xla::PjRtStreamExecutorClient {
40  public:
41   using xla::PjRtStreamExecutorClient::PjRtStreamExecutorClient;
42 
43   xla::StatusOr<xla::DeviceAssignment> GetDefaultDeviceAssignment(
44       int num_replicas, int num_partitions) const override;
45 };
46 
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const47 xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(
48     int num_replicas, int num_partitions) const {
49   if (num_partitions == 1 && num_replicas <= addressable_devices().size()) {
50     xla::DeviceAssignment assignment(num_replicas, 1);
51     for (int i = 0; i < num_replicas; ++i) {
52       assignment(i, 0) = addressable_devices().at(i)->id();
53     }
54     return assignment;
55   }
56   // Fallback to default global device assignment if we can't run locally.
57   return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
58                                                               num_partitions);
59 }
60 
61 // Builds an xla::LocalClient for the GPU platform.
GetGpuXlaClient()62 StatusOr<LocalClient*> GetGpuXlaClient() {
63   // "gpu" will be substitued by the default defined in platform_util.cc
64   TF_ASSIGN_OR_RETURN(se::Platform * platform,
65                       PlatformUtil::GetPlatform("gpu"));
66   if (platform->VisibleDeviceCount() <= 0) {
67     return FailedPrecondition("No visible GPU devices.");
68   }
69   LocalClientOptions options;
70   options.set_platform(platform);
71   return ClientLibrary::GetOrCreateLocalClient(options);
72 }
73 
74 // Builds a LocalDeviceState for each GPU present.
BuildLocalDeviceStates(LocalClient * xla_client,bool asynchronous)75 StatusOr<std::vector<std::unique_ptr<LocalDeviceState>>> BuildLocalDeviceStates(
76     LocalClient* xla_client, bool asynchronous) {
77   std::vector<std::unique_ptr<LocalDeviceState>> addressable_devices;
78   for (int i = 0; i < xla_client->device_count(); ++i) {
79     se::StreamExecutor* executor =
80         xla_client->backend().stream_executor(i).ValueOrDie();
81     addressable_devices.push_back(absl::make_unique<LocalDeviceState>(
82         executor, xla_client, LocalDeviceState::kComputeSynchronized,
83         asynchronous,
84         /*allow_event_reuse=*/true));
85   }
86   return std::move(addressable_devices);
87 }
88 
89 // Builds a BFCAllocator for all local GPUs.
CreateBFCAllocator(absl::Span<std::unique_ptr<LocalDeviceState> const> addressable_devices,double memory_fraction,bool preallocate)90 StatusOr<std::unique_ptr<se::MultiDeviceAdapter>> CreateBFCAllocator(
91     absl::Span<std::unique_ptr<LocalDeviceState> const> addressable_devices,
92     double memory_fraction, bool preallocate) {
93   CHECK_GT(addressable_devices.size(), 0);
94   const se::Platform* platform =
95       addressable_devices.front()->executor()->platform();
96   std::vector<se::MultiDeviceAdapter::AllocatorWithStream> allocators;
97   bool enable_unified_memory;
98   Status status = tensorflow::ReadBoolFromEnvVar("TF_FORCE_UNIFIED_MEMORY",
99                                                  false, &enable_unified_memory);
100   if (!status.ok()) {
101     LOG(ERROR) << "Unable to read TF_FORCE_UNIFIED_MEMORY: "
102                << status.error_message();
103   }
104 
105   for (auto& local_device : addressable_devices) {
106     se::StreamExecutor* executor = local_device->executor();
107     int device_ordinal = executor->device_ordinal();
108     auto sub_allocator = absl::make_unique<tensorflow::DeviceMemAllocator>(
109         executor, tensorflow::PlatformDeviceId(device_ordinal),
110         /*use_unified_memory=*/enable_unified_memory,
111         /*alloc_visitors=*/std::vector<tensorflow::SubAllocator::Visitor>(),
112         /*free_visitors=*/std::vector<tensorflow::SubAllocator::Visitor>());
113 
114     int64 free_memory;
115     int64 total_memory;
116     if (!executor->DeviceMemoryUsage(&free_memory, &total_memory)) {
117       return Unavailable("Failed to query available memory from device %i",
118                          device_ordinal);
119     }
120     // To allow full GPU memory to be visible to the BFC allocator if using
121     // unified memory.
122     size_t allocator_memory =
123         enable_unified_memory ? total_memory : free_memory * memory_fraction;
124     if (preallocate) {
125       LOG(INFO) << "XLA backend allocating " << allocator_memory
126                 << " bytes on device " << device_ordinal
127                 << " for BFCAllocator.";
128     } else {
129       LOG(INFO) << "XLA backend will use up to " << allocator_memory
130                 << " bytes on device " << device_ordinal
131                 << " for BFCAllocator.";
132     }
133     auto gpu_bfc_allocator = absl::make_unique<tensorflow::BFCAllocator>(
134         sub_allocator.release(), allocator_memory,
135         /*allow_growth=*/!preallocate,
136         absl::StrCat("GPU_", device_ordinal, "_bfc"));
137     allocators.emplace_back(std::move(gpu_bfc_allocator),
138                             local_device->compute_stream());
139   }
140   return absl::make_unique<se::MultiDeviceAdapter>(platform,
141                                                    std::move(allocators));
142 }
143 
144 // Constructs a GPU device memory allocator to use, according to the allocator
145 // configuration the client requested.
GetGpuDeviceAllocator(const GpuAllocatorConfig & allocator_config,absl::Span<std::unique_ptr<LocalDeviceState> const> addressable_devices)146 StatusOr<std::unique_ptr<se::DeviceMemoryAllocator>> GetGpuDeviceAllocator(
147     const GpuAllocatorConfig& allocator_config,
148     absl::Span<std::unique_ptr<LocalDeviceState> const> addressable_devices) {
149   std::unique_ptr<se::DeviceMemoryAllocator> allocator;
150   if (allocator_config.kind != GpuAllocatorConfig::Kind::kPlatform) {
151     TF_ASSIGN_OR_RETURN(allocator,
152                         CreateBFCAllocator(addressable_devices,
153                                            allocator_config.memory_fraction,
154                                            allocator_config.preallocate));
155   }
156   return std::move(allocator);
157 }
158 
159 // Returns a GPU pinned host memory allocator to use when staging host->GPU
160 // transfers. We use a fixed 64MB pool of pinned memory.
GetGpuHostAllocator(se::StreamExecutor * executor)161 std::unique_ptr<tensorflow::BFCAllocator> GetGpuHostAllocator(
162     se::StreamExecutor* executor) {
163   tensorflow::SubAllocator* sub_allocator = new tensorflow::DeviceHostAllocator(
164       executor, /*numa_node=*/0, /*alloc_visitors=*/{}, /*free_visitors=*/{});
165   // TODO(phawkins): allow the user to tune this.
166   const int64 kGpuHostMemoryLimitBytes = 64 * (1LL << 30);
167   return absl::make_unique<tensorflow::BFCAllocator>(
168       sub_allocator, kGpuHostMemoryLimitBytes, /*allow_growth=*/true,
169       /*name=*/"xla_gpu_host_bfc");
170 }
171 
172 // A table mapping NcclCliqueKeys to ncclUniqueId values encoded as strings.
173 // In a distributed setup the table of NCCL IDs is kept on the master node
174 // (node 0). The node of the first participating device will create the unique
175 // id.
176 class NcclIdStore {
177  public:
NcclIdStore(int node_id,std::shared_ptr<DistributedRuntimeClient> client,absl::flat_hash_map<GlobalDeviceId,int> device_to_node)178   NcclIdStore(int node_id, std::shared_ptr<DistributedRuntimeClient> client,
179               absl::flat_hash_map<GlobalDeviceId, int> device_to_node)
180       : node_id_(node_id),
181         client_(std::move(client)),
182         device_to_node_(std::move(device_to_node)) {}
183 
184   StatusOr<std::string> GetNcclUniqueId(const gpu::NcclCliqueKey& key);
185 
186  private:
187   const int node_id_;
188   const std::shared_ptr<DistributedRuntimeClient> client_;
189   const absl::flat_hash_map<GlobalDeviceId, int> device_to_node_;
190 
191   absl::Mutex mu_;
192   absl::flat_hash_map<gpu::NcclCliqueKey, std::string> cache_
193       ABSL_GUARDED_BY(mu_);
194 };
195 
GetNcclUniqueId(const gpu::NcclCliqueKey & key)196 StatusOr<std::string> NcclIdStore::GetNcclUniqueId(
197     const gpu::NcclCliqueKey& key) {
198   // The caller must ensure that threads calling this method concurrently have
199   // unique keys, otherwise the global key-value store may hold the wrong value.
200   {
201     absl::MutexLock lock(&mu_);
202     auto it = cache_.find(key);
203     if (it != cache_.end()) {
204       return it->second;
205     }
206   }
207   std::string id_string;
208   int primary_node_id = device_to_node_.at(key.devices()[0]);
209   if (node_id_ == primary_node_id) {
210 #ifdef NCCL_ENABLED
211     ncclUniqueId id;
212     ncclResult_t r = ncclGetUniqueId(&id);
213     TF_RET_CHECK(r == ncclSuccess);
214     id_string = std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
215     TF_RETURN_IF_ERROR(client_->KeyValueSet(key.ToString(), id_string));
216 #else
217     return FailedPrecondition("NCCL support was not built into XLA binary.");
218 #endif
219   } else {
220     TF_ASSIGN_OR_RETURN(id_string, client_->BlockingKeyValueGet(
221                                        key.ToString(), absl::Minutes(5)));
222   }
223   absl::MutexLock lock(&mu_);
224   auto result = cache_.emplace(key, std::move(id_string));
225   TF_RET_CHECK(result.second) << "Unique ID already in cache.";
226   return result.first->second;
227 }
228 
BuildLocalDevices(std::vector<std::unique_ptr<LocalDeviceState>> local_device_states)229 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
230     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
231   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
232   for (auto& local_device : local_device_states) {
233     int device_ordinal = local_device->device_ordinal();
234     const se::DeviceDescription& description =
235         local_device->executor()->GetDeviceDescription();
236     auto device = absl::make_unique<GpuDevice>(
237         device_ordinal, std::move(local_device), description.name(),
238         /*node_id=*/0);
239     devices.push_back(std::move(device));
240   }
241   return devices;
242 }
243 
BuildDistributedDevices(std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,std::shared_ptr<DistributedRuntimeClient> distributed_client,int node_id,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> * devices,gpu::GpuExecutableRunOptions * gpu_executable_run_options)244 Status BuildDistributedDevices(
245     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
246     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id,
247     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>* devices,
248     gpu::GpuExecutableRunOptions* gpu_executable_run_options) {
249   LocalTopologyProto local_topology;
250   local_topology.set_node_id(node_id);
251   for (const auto& local_device : local_device_states) {
252     const se::Platform* platform = local_device->executor()->platform();
253     TF_ASSIGN_OR_RETURN(
254         std::unique_ptr<xla::se::DeviceDescription> desc,
255         platform->DescriptionForDevice(local_device->device_ordinal()));
256     TF_RET_CHECK(local_device->device_ordinal() ==
257                  local_topology.devices_size());
258     DeviceProto* device_proto = local_topology.add_devices();
259     device_proto->set_local_device_ordinal(local_device->device_ordinal());
260     device_proto->set_name(desc->name());
261     device_proto->set_vendor(desc->device_vendor());
262   }
263 
264   GlobalTopologyProto global_topology;
265   TF_RETURN_IF_ERROR(
266       distributed_client->EnumerateDevices(local_topology, &global_topology));
267 
268   std::vector<GlobalDeviceId> gpu_device_ids(local_device_states.size());
269   absl::flat_hash_map<GlobalDeviceId, int> device_to_node;
270   for (const LocalTopologyProto& node : global_topology.nodes()) {
271     for (const DeviceProto& device_proto : node.devices()) {
272       GlobalDeviceId global_device_id(device_proto.global_device_id());
273       device_to_node[global_device_id] = node.node_id();
274       std::unique_ptr<LocalDeviceState> local_device;
275       if (node.node_id() == node_id) {
276         TF_RET_CHECK(device_proto.local_device_ordinal() >= 0 &&
277                      device_proto.local_device_ordinal() <
278                          local_device_states.size());
279         TF_RET_CHECK(local_device_states[device_proto.local_device_ordinal()] !=
280                      nullptr);
281         local_device =
282             std::move(local_device_states[device_proto.local_device_ordinal()]);
283         gpu_device_ids[device_proto.local_device_ordinal()] = global_device_id;
284       }
285       auto device = absl::make_unique<GpuDevice>(
286           device_proto.global_device_id(), std::move(local_device),
287           device_proto.name(), node.node_id());
288       devices->push_back(std::move(device));
289     }
290   }
291   for (const auto& device : local_device_states) {
292     TF_RET_CHECK(device == nullptr);
293   }
294   gpu_executable_run_options->set_gpu_global_device_ids(
295       std::move(gpu_device_ids));
296   auto nccl_id_store = std::make_shared<NcclIdStore>(
297       node_id, distributed_client, device_to_node);
298   gpu_executable_run_options->set_nccl_unique_id_callback(
299       [nccl_id_store](const gpu::NcclCliqueKey& key) {
300         return nccl_id_store->GetNcclUniqueId(key);
301       });
302   return Status::OK();
303 }
304 
305 }  // namespace
306 
GpuDevice(int id,std::unique_ptr<LocalDeviceState> local_device_state,std::string device_kind,int node_id)307 GpuDevice::GpuDevice(int id,
308                      std::unique_ptr<LocalDeviceState> local_device_state,
309                      std::string device_kind, int node_id)
310     : PjRtStreamExecutorDevice(id, std::move(local_device_state),
311                                std::move(device_kind), node_id) {}
312 
GetGpuClient(bool asynchronous,const GpuAllocatorConfig & allocator_config,std::shared_ptr<DistributedRuntimeClient> distributed_client,int node_id)313 StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
314     bool asynchronous, const GpuAllocatorConfig& allocator_config,
315     std::shared_ptr<DistributedRuntimeClient> distributed_client, int node_id) {
316   TF_ASSIGN_OR_RETURN(LocalClient * xla_client, GetGpuXlaClient());
317   TF_ASSIGN_OR_RETURN(
318       std::vector<std::unique_ptr<LocalDeviceState>> local_device_states,
319       BuildLocalDeviceStates(xla_client, asynchronous));
320   TF_ASSIGN_OR_RETURN(
321       auto allocator,
322       GetGpuDeviceAllocator(allocator_config, local_device_states));
323   auto host_memory_allocator =
324       GetGpuHostAllocator(local_device_states.front()->executor());
325 
326   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
327   auto gpu_run_options = absl::make_unique<gpu::GpuExecutableRunOptions>();
328   if (distributed_client) {
329     TF_RETURN_IF_ERROR(BuildDistributedDevices(
330         std::move(local_device_states), std::move(distributed_client), node_id,
331         &devices, gpu_run_options.get()));
332   } else {
333     devices = BuildLocalDevices(std::move(local_device_states));
334   }
335 
336   return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>(
337       kGpuName, xla_client, std::move(devices),
338       /*node_id=*/node_id, std::move(allocator),
339       std::move(host_memory_allocator),
340       /*should_stage_host_to_device_transfers=*/true,
341       /*gpu_run_options=*/std::move(gpu_run_options)));
342 }
343 
344 }  // namespace xla
345