/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or
// GPU), under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets `*shape` to the shape of the tensor's on-device
  // representation, fully padded. On error, the contents of `*shape` are
  // undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
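
  // As an illustrative sketch (not part of this API), a PaddedShapeFn that
  // matches the documented default behavior (the logical on-device shape
  // with no padding) could be written as below, assuming the
  // TensorShapeToXLAShape helper from
  // tensorflow/compiler/tf2xla/shape_util.h:
  //
  //   Status NoPaddingShapeFn(const Tensor& tensor, xla::Shape* shape) {
  //     return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  //   }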

  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved, e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // A function that describes how the on-host shapes of
    // a) arguments and return values, for entry computations, and
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be filled
    // from the visible_gpu_devices list in the session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };
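
  // An illustrative sketch of filling in Options before invoking the
  // constructor below. The values are hypothetical, and the lambda assumes
  // the (TensorShape, DataType) -> StatusOr<TensorShape> form of
  // XlaCompiler::ShapeRepresentationFn:
  //
  //   XlaDevice::Options options;
  //   options.platform = platform;  // e.g., from se::MultiPlatformManager
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.shape_representation_fn =
  //       [](const TensorShape& shape,
  //          DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
  //   auto device = absl::make_unique<XlaDevice>(session_options, options);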

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status FillContextMap(const Graph* graph,
                        DeviceContextMap* device_context_map) override
      LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e., all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' from
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);

  // Installs an error-handling callback that is invoked when RefreshStatus
  // sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
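
  // An illustrative sketch of how a kernel might consult this device's
  // metadata via the static GetMetadata() helpers above (MyKernel is
  // hypothetical):
  //
  //   void MyKernel::Compute(OpKernelContext* ctx) {
  //     const XlaDevice::Metadata* metadata = nullptr;
  //     OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //     if (metadata->UseMultipleStreams()) {
  //       // Compute and transfers run on separate streams.
  //     }
  //   }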

 private:
  xla::LocalClient* client() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
  xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);

  // Handles errors when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_, host-to-device
  // and device-to-device transfers use the dedicated streams below, and a
  // stream is borrowed for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls
  // to EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer
  // is also filled in to that struct. XlaDeviceContext is a ref-counted
  // object.
  XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g., the device
  // context.
  bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be filled
  // from the visible_gpu_devices list in the session configuration.
  absl::optional<std::set<int>> allowed_devices_;
};

// Builds OpKernel registrations on 'device' for the JIT operators registered
// on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations object
// that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_