1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17 // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
18
19 #include <set>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/jit/defs.h"
25 #include "tensorflow/compiler/jit/flags.h"
26 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
27 #include "tensorflow/compiler/jit/xla_device.h"
28 #include "tensorflow/compiler/jit/xla_device_ops.h"
29 #include "tensorflow/compiler/jit/xla_platform_info.h"
30 #include "tensorflow/compiler/tf2xla/layout_util.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
32 #include "tensorflow/core/common_runtime/device_factory.h"
33 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
34 #include "tensorflow/core/lib/core/status.h"
35
36 namespace tensorflow {
37
38 class XlaGpuDeviceFactory : public DeviceFactory {
39 public:
40 Status ListPhysicalDevices(std::vector<string>* devices) override;
41 Status CreateDevices(const SessionOptions& options, const string& name_prefix,
42 std::vector<std::unique_ptr<Device>>* devices) override;
43 };
44
ListPhysicalDevices(std::vector<string> * devices)45 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
46 XlaDeviceFlags* flags = GetXlaDeviceFlags();
47 if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) {
48 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set "
49 "and XLA devices creation not required";
50 return OkStatus();
51 }
52
53 auto platform =
54 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
55 if (!platform.ok()) {
56 // Treat failures as non-fatal; there might not be a GPU in the machine.
57 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
58 return OkStatus();
59 }
60
61 int device_count = platform.ValueOrDie()->VisibleDeviceCount();
62 if (device_count <= 0) {
63 return OkStatus();
64 }
65
66 for (int i = 0; i < device_count; ++i) {
67 devices->push_back(
68 absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
69 }
70
71 return OkStatus();
72 }
73
CreateDevices(const SessionOptions & session_options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)74 Status XlaGpuDeviceFactory::CreateDevices(
75 const SessionOptions& session_options, const string& name_prefix,
76 std::vector<std::unique_ptr<Device>>* devices) {
77 XlaDeviceFlags* flags = GetXlaDeviceFlags();
78 if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) {
79 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
80 return OkStatus();
81 }
82
83 XlaOpRegistry::DeviceRegistration registration;
84 registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
85 registration.autoclustering_policy =
86 XlaOpRegistry::AutoclusteringPolicy::kAlways;
87 registration.cluster_resource_variable_ops_unsafely = true;
88 registration.cluster_stack_ops = false;
89 registration.cluster_tensor_array_ops = true;
90 registration.cluster_stateful_rng_ops = true;
91 registration.cluster_control_trigger = true;
92 registration.elide_assert_and_checknumerics = true;
93 registration.cluster_variant_ops = true;
94 registration.cluster_slow_ops = true;
95 registration.cluster_inaccurate_ops = true;
96 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
97
98 static XlaDeviceOpRegistrations* registrations =
99 RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
100 (void)registrations;
101
102 auto platform =
103 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
104 if (!platform.ok()) {
105 // Treat failures as non-fatal; there might not be a GPU in the machine.
106 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
107 return OkStatus();
108 }
109
110 auto iter = session_options.config.device_count().find("GPU");
111 if (iter != session_options.config.device_count().end() &&
112 iter->second == 0) {
113 // Device count for GPU is 0.
114 return OkStatus();
115 }
116
117 string allowed_gpus =
118 session_options.config.gpu_options().visible_device_list();
119 std::optional<std::set<int>> gpu_ids =
120 ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
121 if (!gpu_ids) {
122 gpu_ids.emplace();
123 // Fill the gpu_ids set with all devices if config string is empty.
124 for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
125 gpu_ids->insert(i);
126 }
127 }
128 for (int i : *gpu_ids) {
129 XlaDevice::Options options;
130 options.platform = platform.ValueOrDie();
131 options.device_name_prefix = name_prefix;
132 options.device_name = DEVICE_XLA_GPU;
133 options.device_ordinal = i;
134 options.compilation_device_name = DEVICE_GPU_XLA_JIT;
135 options.use_multiple_streams = true;
136 options.allowed_devices = gpu_ids;
137 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_representation_fns{
138 UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
139 options.shape_determination_fns = {shape_representation_fns};
140 auto device = std::make_unique<XlaDevice>(session_options, options);
141
142 Status status = device->UseAcceleratorDeviceInfo();
143 if (!status.ok()) {
144 LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT
145 << " device. Device number is " << i << ", reason: " << status;
146 continue;
147 }
148
149 devices->push_back(std::move(device));
150 }
151 return OkStatus();
152 }
153
// Registers XlaGpuDeviceFactory as the factory for the DEVICE_XLA_GPU device
// type. NOTE(review): the exact meaning of the "LOCAL" variant, as opposed to a
// plain REGISTER_DEVICE_FACTORY, is defined in device_factory.h — confirm
// there.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
155
156 // Kernel registrations
157
158 constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
159 {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
160 DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
161 DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
162
163 REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
164 REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
165 REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
166
167 REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
168
169 } // namespace tensorflow
170