/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.

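// A minimal usage sketch (illustrative only, not part of this file): once the
// factory below is registered via REGISTER_LOCAL_DEVICE_FACTORY, the common
// runtime can instantiate XLA_GPU devices through DeviceFactory::AddDevices,
// for example:
//
//   SessionOptions session_options;
//   std::vector<std::unique_ptr<Device>> devices;
//   TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
//       session_options, "/job:localhost/replica:0/task:0", &devices));
//
// Any XLA_GPU devices created this way are then addressable by names such as
// "/job:localhost/replica:0/task:0/device:XLA_GPU:0".
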
#include <array>
#include <memory>
#include <optional>
#include <set>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

class XlaGpuDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

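// Lists one "/physical_device:XLA_GPU:<ordinal>" entry for each device
// reported by the GPU platform, unless XLA device creation is disabled.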
Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) {
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set "
               "and XLA devices creation not required";
    return OkStatus();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
    return OkStatus();
  }

  int device_count = platform.ValueOrDie()->VisibleDeviceCount();
  if (device_count <= 0) {
    return OkStatus();
  }

  for (int i = 0; i < device_count; ++i) {
    devices->push_back(
        absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
  }

  return OkStatus();
}

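// Registers DEVICE_GPU_XLA_JIT as the compilation device for DEVICE_XLA_GPU
// along with the XLA device kernels, then creates one XlaDevice per GPU that
// is both visible on the platform and allowed by
// GPUOptions::visible_device_list.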
Status XlaGpuDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) {
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return OkStatus();
  }

  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
  registration.autoclustering_policy =
      XlaOpRegistry::AutoclusteringPolicy::kAlways;
  registration.cluster_resource_variable_ops_unsafely = true;
  registration.cluster_stack_ops = false;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = true;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  registration.cluster_slow_ops = true;
  registration.cluster_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);

  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
  (void)registrations;

  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
    return OkStatus();
  }

  auto iter = session_options.config.device_count().find("GPU");
  if (iter != session_options.config.device_count().end() &&
      iter->second == 0) {
    // Device count for GPU is 0.
    return OkStatus();
  }

  string allowed_gpus =
      session_options.config.gpu_options().visible_device_list();
  std::optional<std::set<int>> gpu_ids =
      ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
  if (!gpu_ids) {
    gpu_ids.emplace();
    // Fill the gpu_ids set with all devices if config string is empty.
    for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
      gpu_ids->insert(i);
    }
  }
  for (int i : *gpu_ids) {
    XlaDevice::Options options;
    options.platform = platform.ValueOrDie();
    options.device_name_prefix = name_prefix;
    options.device_name = DEVICE_XLA_GPU;
    options.device_ordinal = i;
    options.compilation_device_name = DEVICE_GPU_XLA_JIT;
    options.use_multiple_streams = true;
    options.allowed_devices = gpu_ids;
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_representation_fns{
        UseNoPreferenceLayoutFn(), IdentityShapeRepresentationFn()};
    options.shape_determination_fns = {shape_representation_fns};
    auto device = std::make_unique<XlaDevice>(session_options, options);

    Status status = device->UseAcceleratorDeviceInfo();
    if (!status.ok()) {
      LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT
                << " device. Device number is " << i << ", reason: " << status;
      continue;
    }

    devices->push_back(std::move(device));
  }
  return OkStatus();
}

REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

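// XlaLocalLaunchOp compiles and runs an XLA cluster as a single op, while
// XlaCompileOp and XlaRunOp split compilation and execution into two steps
// (see tensorflow/compiler/jit/kernels/xla_ops.h).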
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);

REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);

}  // namespace tensorflow