/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets the `xla::Shape*` argument to the shape of the
  // tensor's representation on device, fully padded. On error, the contents
  // of the `xla::Shape*` are undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
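  //
  // A minimal sketch of a PaddedShapeFn that applies no extra padding (this is
  // the behavior provided by DefaultPaddedShapeFn, declared at the bottom of
  // this file); the lambda form is illustrative only:
  //
  //   XlaDevice::PaddedShapeFn padded_shape_fn =
  //       [](const Tensor& tensor, xla::Shape* shape) -> Status {
  //         return DefaultPaddedShapeFn(tensor, shape);
  //       };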

  // Wrapper class to store metadata about the XlaDevice, where it can be
  // retrieved e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by
  // `device`.
  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);
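  //
  // For example, a kernel running on an XLA device can look up the metadata
  // from its OpKernelContext. A minimal sketch, assuming `ctx` is the context
  // passed to the kernel's Compute() method:
  //
  //   const XlaDevice::Metadata* metadata = nullptr;
  //   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //   const int ordinal = metadata->device_ordinal();
  //   se::Platform* platform = metadata->platform();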

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // A function that describes how the on-host shapes of
    // a) arguments and return values, for entry computations
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be
    // filled from the visible_gpu_devices list in the session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;
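  //
  // A minimal construction sketch for an XLA CPU backend. The platform lookup,
  // shape_fn, and session_options values are assumed to be provided by the
  // caller; the names below are illustrative, not part of this API:
  //
  //   XlaDevice::Options options;
  //   options.platform = platform;  // e.g., the StreamExecutor host platform
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.shape_representation_fn = shape_fn;  // must be non-null
  //   auto device = absl::make_unique<XlaDevice>(session_options, options);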

  Allocator* GetAllocator(AllocatorAttributes attr) override
      TF_LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status TryGetDeviceContext(DeviceContext** out_context) override
      TF_LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_);

  // Allocates the tensor in the fast memory space. This only applies to new
  // TPU hardware that has faster read/write memory. If the hardware doesn't
  // have such a memory space, we fall back to the ordinary memory space.
  Status MakeFastMemTensorFromProto(const TensorProto& tensor_proto,
                                    const AllocatorAttributes alloc_attrs,
                                    Tensor* tensor) TF_LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() TF_LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' for
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion)
      TF_LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_);

  // Installs an error-handling callback that is invoked when RefreshStatus
  // sees !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);
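  //
  // A minimal sketch of installing such a callback; the recovery logic in the
  // lambda body is illustrative only:
  //
  //   device->SetHandleDeviceErrorCallback([]() -> Status {
  //     LOG(WARNING) << "XLA device reported an error; attempting recovery.";
  //     return Status::OK();
  //   });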

  Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_);

 private:
  xla::StatusOr<xla::LocalClient*> GetOrCreateClient() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns a pair of device contexts; the second one is the fast_mem device
  // context.
  xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
  GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status MakeTensorFromProto(XlaDeviceContext* device_context,
                             const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor);

  // Handles error when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Intra-op threads to spawn (from SessionOptions).
  const int intra_op_parallelism_threads_;
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ TF_GUARDED_BY(mu_);
  // If false, only stream_ is valid and all computation and transfers use
  // stream_. If true, computation is performed by stream_, and transfers are
  // performed by the host_to_device/device_to_device streams or by borrowing
  // a stream for each device-to-host transfer.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ TF_GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      TF_GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls to
  // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
  // also filled in to that struct. XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // The device context that allocates memory in the fast memory space on TPU.
  // XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* fast_mem_device_context_ TF_GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_gpu_device_info_ TF_GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ TF_GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ TF_GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ TF_GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be
  // filled from the visible_gpu_devices list in the session configuration.
  absl::optional<std::set<int>> allowed_devices_;
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);
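//
// For example, a device factory for an "XLA_CPU" device would typically pair
// it with the "XLA_CPU_JIT" compilation device. A minimal sketch (the caller
// owns the returned registrations and must keep them alive for as long as the
// kernels are needed):
//
//   XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels("XLA_CPU", "XLA_CPU_JIT");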

Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_