/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets the `xla::Shape*` output to the shape of the tensor's
  // representation on device, fully padded. On error, the contents of the
  // `xla::Shape*` output are undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
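  // For illustration, a PaddedShapeFn that simply reports the unpadded logical
  // shape might look like the sketch below (a sketch only; it assumes the
  // TensorShapeToXLAShape helper from tensorflow/compiler/tf2xla/shape_util.h;
  // see also DefaultPaddedShapeFn at the bottom of this header):
  //
  //   PaddedShapeFn unpadded_fn = [](const Tensor& tensor,
  //                                  xla::Shape* shape) -> Status {
  //     // No device-specific padding: use the logical shape as-is.
  //     return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  //   };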

  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);
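
  // For example, an op kernel running on an XLA device can retrieve the
  // device's client and ordinal roughly as follows (a sketch only; `ctx` and
  // the surrounding kernel are hypothetical):
  //
  //   const XlaDevice::Metadata* metadata = nullptr;
  //   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //   xla::LocalClient* client = metadata->client();
  //   int device_ordinal = metadata->device_ordinal();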

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7")
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU")
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT");
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // If true, XLA devices with the same device ordinal will share the same
    // compute stream. Otherwise each XLA device will have its own compute
    // stream.
    bool use_global_compute_stream = false;

    // A function that describes how the on-host shapes of
    // a) argument and return value, for entry computations
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be
    // filled from visible_gpu_devices list from session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };
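
  // For illustration, a device factory might fill in Options and construct an
  // XlaDevice roughly as sketched below (a sketch only; the real factories
  // live in xla_cpu_device.cc and xla_gpu_device.cc, and `platform`,
  // `session_options`, and `shape_fn` are assumed to come from the caller):
  //
  //   XlaDevice::Options options;
  //   options.platform = platform;                  // se::Platform*, not owned.
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = DEVICE_XLA_CPU;         // "XLA_CPU"
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = DEVICE_CPU_XLA_JIT;
  //   options.use_multiple_streams = false;
  //   options.shape_representation_fn = shape_fn;
  //   auto device = absl::make_unique<XlaDevice>(session_options, options);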

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  // Allocates a tensor in the fast memory space. This only applies to newer
  // TPU hardware, which has faster read/write memory. If the hardware doesn't
  // have such a memory space, we fall back to the ordinary memory space.
  Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
                                    const AllocatorAttributes alloc_attrs,
                                    Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' for
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error handling callback when RefreshStatus sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);

 private:
  StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a pair of device contexts; the second one is the fast_mem device
  // context.
  StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
  GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  // Handles error when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_, transfers are
  // performed by the host_to_device/device_to_device streams, and a stream is
  // borrowed for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are performed
  // using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls to
  // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
  // also filled in to that struct. XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // This device context allocates memory in the fast memory space on TPU.
  // XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);

  // Thread pool used for running closures
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be
  // filled from visible_gpu_devices list from session configuration.
  absl::optional<std::set<int>> allowed_devices_;

  const bool use_global_compute_stream_;

  // A static vector with device_ordinal as its index, describing the global
  // compute streams used in each XLA device. It is only used if
  // `use_global_compute_stream` in `XlaDevice::Options` is set to true.
  static mutex global_mu_;
  static std::vector<std::shared_ptr<se::Stream>>* global_compute_streams_
      TF_GUARDED_BY(global_mu_);
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);
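//
// For example, a device registration translation unit might call this once,
// roughly as sketched below (a sketch only; the real call sites are in
// xla_cpu_device.cc and xla_gpu_device.cc):
//
//   static XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);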

Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_