/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or
// GPU), under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets the `xla::Shape*` output to the shape of the
  // tensor's representation on device, fully padded. On error, the contents
  // of the `xla::Shape*` are undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
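  // A minimal sketch (not part of this header) of a PaddedShapeFn: it simply
  // forwards to DefaultPaddedShapeFn, which is declared at the bottom of this
  // file. Wrapping it in a lambda is an illustrative assumption, not the
  // canonical backend behavior:
  //
  //   PaddedShapeFn padded_fn = [](const Tensor& tensor, xla::Shape* shape) {
  //     return DefaultPaddedShapeFn(tensor, shape);
  //   };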
  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If `use_multiple_streams` is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // If true, XLA devices with the same device ordinal will share the same
    // compute stream. Otherwise each XLA device will have its own compute
    // stream.
    bool use_global_compute_stream = false;

    // A function that describes how the on-host shapes of
    // a) argument and return value, for entry computations
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be filled
    // from the visible_gpu_devices list in the session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };
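  // A minimal sketch of how a backend might populate Options before
  // constructing an XlaDevice. The concrete names ("XLA_CPU", "XLA_CPU_JIT",
  // the device name prefix) and the pre-existing `platform`, `shape_fn`, and
  // `session_options` variables are illustrative assumptions, not
  // requirements of this API:
  //
  //   XlaDevice::Options options;
  //   options.platform = platform;  // se::Platform*, not owned
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.use_multiple_streams = false;
  //   options.shape_representation_fn = shape_fn;  // must be non-null
  //   XlaDevice* device = new XlaDevice(session_options, options);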
  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  // Allocates a tensor in the fast memory space. This only applies to newer
  // TPU hardware, which has a faster read/write memory space. If the hardware
  // doesn't have such a memory space, we fall back to the ordinary memory
  // space.
  Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
                                    const AllocatorAttributes alloc_attrs,
                                    Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' from
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error-handling callback that is invoked when RefreshStatus
  // sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
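  // A minimal sketch of installing an error-handling callback on an already
  // constructed `device` (a hypothetical XlaDevice*); what the callback does
  // to recover is an illustrative assumption:
  //
  //   device->SetHandleDeviceErrorCallback([]() {
  //     // Attempt backend-specific recovery here; the returned Status is
  //     // surfaced by RefreshStatus.
  //     return Status::OK();
  //   });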
 private:
  StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a pair of device contexts; the second one is the fast_mem device
  // context.
  StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
  GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  // Handles the error when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_ and transfers are
  // performed by a host_to_device/device_to_device stream or by borrowing a
  // stream for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls
  // to EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer
  // is also filled in to that struct. XlaDeviceContext is a ref-counted
  // object.
  XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // The device context that allocates memory in the fast memory space on TPU.
  // XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be filled
  // from the visible_gpu_devices list in the session configuration.
  absl::optional<std::set<int>> allowed_devices_;

  const bool use_global_compute_stream_;

  // A static vector indexed by device_ordinal, describing the global compute
  // stream used by each XLA device. It is only used if
  // `use_global_compute_stream` in `XlaDevice::Options` is set to true.
  static mutex global_mu_;
  static std::vector<std::shared_ptr<se::Stream>>* global_compute_streams_
      TF_GUARDED_BY(global_mu_);
};
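// A minimal sketch of how an op kernel running on an XLA device might look up
// the device metadata via XlaDevice::GetMetadata. The kernel name and what is
// done with the metadata are illustrative assumptions, not part of this
// header:
//
//   void MyXlaAwareKernel::Compute(OpKernelContext* ctx) {
//     const XlaDevice::Metadata* metadata = nullptr;
//     OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
//     // e.g., branch on the compilation device type or stream configuration.
//     const DeviceType& jit_type = metadata->jit_device_type();
//     const bool multiple_streams = metadata->UseMultipleStreams();
//     ...
//   }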
// Builds OpKernel registrations on 'device' for the JIT operators registered
// on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations object
// that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);

Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_