1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17 // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
18
19 #include <set>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/jit/flags.h"
25 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
26 #include "tensorflow/compiler/jit/xla_device.h"
27 #include "tensorflow/compiler/jit/xla_device_ops.h"
28 #include "tensorflow/compiler/jit/xla_platform_info.h"
29 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
32 #include "tensorflow/core/lib/core/status.h"
33
34 namespace tensorflow {
35
36 class XlaGpuDeviceFactory : public DeviceFactory {
37 public:
38 Status ListPhysicalDevices(std::vector<string>* devices) override;
39 Status CreateDevices(const SessionOptions& options, const string& name_prefix,
40 std::vector<std::unique_ptr<Device>>* devices) override;
41 };
42
ListPhysicalDevices(std::vector<string> * devices)43 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
44 XlaDeviceFlags* flags = GetXlaDeviceFlags();
45 if (!flags->tf_xla_enable_xla_devices) {
46 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
47 return Status::OK();
48 }
49
50 auto platform =
51 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
52 if (!platform.ok()) {
53 // Treat failures as non-fatal; there might not be a GPU in the machine.
54 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
55 return Status::OK();
56 }
57
58 int device_count = platform.ValueOrDie()->VisibleDeviceCount();
59 if (device_count <= 0) {
60 return Status::OK();
61 }
62
63 for (int i = 0; i < device_count; ++i) {
64 devices->push_back(
65 absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
66 }
67
68 return Status::OK();
69 }
70
CreateDevices(const SessionOptions & session_options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)71 Status XlaGpuDeviceFactory::CreateDevices(
72 const SessionOptions& session_options, const string& name_prefix,
73 std::vector<std::unique_ptr<Device>>* devices) {
74 XlaDeviceFlags* flags = GetXlaDeviceFlags();
75 if (!flags->tf_xla_enable_xla_devices) {
76 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
77 return Status::OK();
78 }
79
80 XlaOpRegistry::DeviceRegistration registration;
81 registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
82 registration.autoclustering_policy =
83 XlaOpRegistry::AutoclusteringPolicy::kAlways;
84 registration.cluster_resource_variable_ops_unsafely = true;
85 registration.cluster_stack_ops = false;
86 registration.cluster_tensor_array_ops = true;
87 registration.cluster_stateful_rng_ops = true;
88 registration.cluster_control_trigger = true;
89 registration.elide_assert_and_checknumerics = true;
90 registration.cluster_variant_ops = true;
91 registration.cluster_slow_ops = true;
92 registration.cluster_inaccurate_ops = true;
93 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
94
95 static XlaDeviceOpRegistrations* registrations =
96 RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
97 (void)registrations;
98
99 auto platform =
100 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
101 if (!platform.ok()) {
102 // Treat failures as non-fatal; there might not be a GPU in the machine.
103 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
104 return Status::OK();
105 }
106
107 auto iter = session_options.config.device_count().find("GPU");
108 if (iter != session_options.config.device_count().end() &&
109 iter->second == 0) {
110 // Device count for GPU is 0.
111 return Status::OK();
112 }
113
114 string allowed_gpus =
115 session_options.config.gpu_options().visible_device_list();
116 absl::optional<std::set<int>> gpu_ids =
117 ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
118 if (!gpu_ids) {
119 gpu_ids.emplace();
120 // Fill the gpu_ids set with all devices if config string is empty.
121 for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
122 gpu_ids->insert(i);
123 }
124 }
125 for (int i : *gpu_ids) {
126 XlaDevice::Options options;
127 options.platform = platform.ValueOrDie();
128 options.device_name_prefix = name_prefix;
129 options.device_name = DEVICE_XLA_GPU;
130 options.device_ordinal = i;
131 options.compilation_device_name = DEVICE_GPU_XLA_JIT;
132 options.use_multiple_streams = true;
133 options.allowed_devices = gpu_ids;
134 auto device = absl::make_unique<XlaDevice>(session_options, options);
135
136 Status status = device->UseGpuDeviceInfo();
137 if (!status.ok()) {
138 LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT
139 << " device. Device number is " << i << ", reason: " << status;
140 continue;
141 }
142
143 devices->push_back(std::move(device));
144 }
145 return Status::OK();
146 }
147
// Registers XlaGpuDeviceFactory as the local device factory for the
// DEVICE_XLA_GPU device type.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
149
150 // Kernel registrations
151
152 constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
153 {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
154 DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
155 DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
156
// Register the XLA cluster kernels (XlaLocalLaunchOp, XlaCompileOp, XlaRunOp)
// on DEVICE_XLA_GPU, constrained to the types in kAllXlaGpuTypes.
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);

// Register the common XlaDevice kernel set for DEVICE_XLA_GPU over the same
// types (macro presumably provided by xla_device_ops.h — confirm).
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
162
163 } // namespace tensorflow
164