• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17 // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
18 
19 #include <set>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/compiler/jit/flags.h"
25 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
26 #include "tensorflow/compiler/jit/xla_device.h"
27 #include "tensorflow/compiler/jit/xla_device_ops.h"
28 #include "tensorflow/compiler/jit/xla_platform_info.h"
29 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
30 #include "tensorflow/core/common_runtime/device_factory.h"
31 #include "tensorflow/core/common_runtime/gpu/gpu_init.h"
32 #include "tensorflow/core/lib/core/status.h"
33 
34 namespace tensorflow {
35 
36 class XlaGpuDeviceFactory : public DeviceFactory {
37  public:
38   Status ListPhysicalDevices(std::vector<string>* devices) override;
39   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
40                        std::vector<std::unique_ptr<Device>>* devices) override;
41 };
42 
ListPhysicalDevices(std::vector<string> * devices)43 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
44   XlaDeviceFlags* flags = GetXlaDeviceFlags();
45   if (!flags->tf_xla_enable_xla_devices) {
46     VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
47     return Status::OK();
48   }
49 
50   auto platform =
51       se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
52   if (!platform.ok()) {
53     // Treat failures as non-fatal; there might not be a GPU in the machine.
54     VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
55     return Status::OK();
56   }
57 
58   int device_count = platform.ValueOrDie()->VisibleDeviceCount();
59   if (device_count <= 0) {
60     return Status::OK();
61   }
62 
63   for (int i = 0; i < device_count; ++i) {
64     devices->push_back(
65         absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
66   }
67 
68   return Status::OK();
69 }
70 
CreateDevices(const SessionOptions & session_options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)71 Status XlaGpuDeviceFactory::CreateDevices(
72     const SessionOptions& session_options, const string& name_prefix,
73     std::vector<std::unique_ptr<Device>>* devices) {
74   XlaDeviceFlags* flags = GetXlaDeviceFlags();
75   if (!flags->tf_xla_enable_xla_devices) {
76     VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
77     return Status::OK();
78   }
79 
80   XlaOpRegistry::DeviceRegistration registration;
81   registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
82   registration.autoclustering_policy =
83       XlaOpRegistry::AutoclusteringPolicy::kAlways;
84   registration.cluster_resource_variable_ops_unsafely = true;
85   registration.cluster_stack_ops = false;
86   registration.cluster_tensor_array_ops = true;
87   registration.cluster_stateful_rng_ops = true;
88   registration.cluster_control_trigger = true;
89   registration.elide_assert_and_checknumerics = true;
90   registration.cluster_variant_ops = true;
91   registration.cluster_slow_ops = true;
92   registration.cluster_inaccurate_ops = true;
93   XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
94 
95   static XlaDeviceOpRegistrations* registrations =
96       RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
97   (void)registrations;
98 
99   auto platform =
100       se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
101   if (!platform.ok()) {
102     // Treat failures as non-fatal; there might not be a GPU in the machine.
103     VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
104     return Status::OK();
105   }
106 
107   auto iter = session_options.config.device_count().find("GPU");
108   if (iter != session_options.config.device_count().end() &&
109       iter->second == 0) {
110     // Device count for GPU is 0.
111     return Status::OK();
112   }
113 
114   string allowed_gpus =
115       session_options.config.gpu_options().visible_device_list();
116   absl::optional<std::set<int>> gpu_ids =
117       ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
118   if (!gpu_ids) {
119     gpu_ids.emplace();
120     // Fill the gpu_ids set with all devices if config string is empty.
121     for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
122       gpu_ids->insert(i);
123     }
124   }
125   for (int i : *gpu_ids) {
126     XlaDevice::Options options;
127     options.platform = platform.ValueOrDie();
128     options.device_name_prefix = name_prefix;
129     options.device_name = DEVICE_XLA_GPU;
130     options.device_ordinal = i;
131     options.compilation_device_name = DEVICE_GPU_XLA_JIT;
132     options.use_multiple_streams = true;
133     options.allowed_devices = gpu_ids;
134     auto device = absl::make_unique<XlaDevice>(session_options, options);
135 
136     Status status = device->UseGpuDeviceInfo();
137     if (!status.ok()) {
138       LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT
139                 << " device. Device number is " << i << ", reason: " << status;
140       continue;
141     }
142 
143     devices->push_back(std::move(device));
144   }
145   return Status::OK();
146 }
147 
// Registers XlaGpuDeviceFactory as the device factory for DEVICE_XLA_GPU.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
149 
// Kernel registrations
//
// Everything below registers kernels for DEVICE_XLA_GPU using the data types
// listed in kAllXlaGpuTypes.

// The 16 data types passed to every registration macro below.
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

// Launch, compile and run kernels, implemented by XlaLocalLaunchOp,
// XlaCompileOp and XlaRunOp respectively.
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);

// NOTE(review): presumably the generic per-device kernels shared by all XLA
// devices; confirm against the macro definition in xla_device_ops.h.
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
162 
163 }  // namespace tensorflow
164