/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or
// GPU), under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets the `xla::Shape*` output to the shape of the tensor's
  // representation on device, fully padded. On error, the contents of the
  // `xla::Shape*` are undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;

  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata of the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata of the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata of the XLA device `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);
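
  // A minimal usage sketch for the GetMetadata helpers above. The surrounding
  // kernel and the `ctx` variable are illustrative placeholders, not part of
  // this header:
  //
  //   const XlaDevice::Metadata* metadata = nullptr;
  //   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //   const int ordinal = metadata->device_ordinal();
  //   xla::LocalClient* client = metadata->client();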

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // A function that describes how the on-host shapes of
    // a) arguments and return values, for entry computations
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be filled
    // from the visible_gpu_devices list in the session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };
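
  // A minimal sketch of how a backend might fill in Options and construct an
  // XlaDevice. The platform lookup and the `shape_fn`/`session_options`
  // variables are illustrative assumptions; real backends do this from their
  // DeviceFactory implementations, which are not part of this header:
  //
  //   XlaDevice::Options options;
  //   options.platform =
  //       se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.shape_representation_fn = shape_fn;  // must be non-null
  //   auto device = absl::make_unique<XlaDevice>(session_options, options);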

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  // Allocates a tensor in the fast memory space. This only applies to new TPU
  // hardware, which has faster read/write memory. If the hardware doesn't
  // have such a memory space, we fall back to the ordinary memory space.
  Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
                                    const AllocatorAttributes alloc_attrs,
                                    Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' for
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error handling callback that is invoked when RefreshStatus
  // sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);
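
  // Illustrative sketch of using the error-handling hooks above. The lambda
  // body is a hypothetical recovery policy chosen for this example, not an
  // API defined in this header:
  //
  //   device->SetAllowsSyncOnCompletion(false);
  //   device->SetHandleDeviceErrorCallback(
  //       [device] { return device->EnsureDeviceContextOk(); });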

 private:
  xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a pair of device contexts; the second one is the fast_mem device
  // context.
  xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
  GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  // Handles errors when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_ and transfers are
  // performed by the host_to_device/device_to_device streams or by borrowing
  // a stream for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls to
  // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
  // also filled in to that struct. XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // The device context that allocates memory in the fast memory space on TPU.
  // XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be filled
  // from the visible_gpu_devices list in the session configuration.
  absl::optional<std::set<int>> allowed_devices_;
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);

Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_