/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
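//
// Example (illustrative, not part of this header's API): an op can be pinned
// to an XLA device by name when building a graph, e.g. with the C++ Scope API:
//
//   auto root = tensorflow::Scope::NewRootScope();
//   auto scope = root.WithDevice("/device:XLA_CPU:0");
//   auto c = tensorflow::ops::Const(scope, {1.0f, 2.0f});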

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

#include <set>

#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace tensorflow {

class XlaDevice : public LocalDevice {
 public:
  // Given a tensor, sets `*shape` to the fully padded shape of the tensor's
  // representation on device. On error, the contents of `*shape` are
  // undefined.
  typedef std::function<Status(const Tensor&, xla::Shape*)> PaddedShapeFn;
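
  // Example (hedged sketch): a minimal PaddedShapeFn that reports the
  // unpadded on-device shape; the helper name `SimplePaddedShapeFn` is
  // hypothetical:
  //
  //   Status SimplePaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  //     const XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);
  //     if (xla_tensor == nullptr) {
  //       // The tensor has no device buffer; derive the shape from its
  //       // host TensorShape and dtype.
  //       return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  //     }
  //     *shape = xla_tensor->shaped_buffer().on_device_shape();
  //     return Status::OK();
  //   }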

  // Wrapper class to store metadata about the XlaDevice, from which it can be
  // retrieved, e.g., when lazily creating the XlaCompilationCache device.
  class Metadata {
   public:
    Metadata(int device_ordinal, se::Platform* platform,
             const DeviceType& device_type,
             XlaCompiler::ShapeRepresentationFn shape_representation_fn,
             PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

    // The index of the device on this host.
    int device_ordinal() const;

    se::Platform* platform() const;
    xla::LocalClient* client() const;
    const DeviceType& jit_device_type() const;
    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
      return shape_representation_fn_;
    }
    const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

    bool UseMultipleStreams() const { return use_multiple_streams_; }

   private:
    const int device_ordinal_;
    const DeviceType device_type_;
    se::Platform* platform_;  // Not owned.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
    PaddedShapeFn padded_shape_fn_;
    const bool use_multiple_streams_;

    TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
  };

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata);

  // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`.
  static Status GetMetadata(OpKernelConstruction* ctx,
                            const Metadata** metadata);
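
  // Example (illustrative): retrieving the metadata from inside an OpKernel's
  // Compute method:
  //
  //   const XlaDevice::Metadata* metadata = nullptr;
  //   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  //   xla::LocalClient* client = metadata->client();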

  struct Options {
    // The StreamExecutor platform. Not owned. Must be non-null.
    se::Platform* platform = nullptr;

    // The device name's prefix (e.g., "/task:7").
    string device_name_prefix;

    // The name of the XLA device (e.g., "XLA_CPU").
    string device_name;

    // The number of the device.
    int device_ordinal = -1;

    // The name of the compilation device (e.g., "XLA_CPU_JIT").
    string compilation_device_name;

    // If 'use_multiple_streams' is true, we create separate streams for
    // compute, host-to-device, and device-to-host communication.
    bool use_multiple_streams = false;

    // A function that describes how the on-host shapes of
    // a) arguments and return values, for entry computations, and
    // b) variables, for all computations,
    // should be represented in XLA. Parameters/return values will be shaped
    // according to this function, and reshaped back to/from their declared
    // shapes for computations. Must be non-null.
    XlaCompiler::ShapeRepresentationFn shape_representation_fn;

    // If padded_shape_fn is empty, a default implementation that returns
    // the logical on-device shape without padding is used.
    PaddedShapeFn padded_shape_fn;

    // Set of devices to use. This controls which of the devices on the given
    // platform will have resources allocated. For GPUs this will be
    // filled from the visible_gpu_devices list in the session configuration.
    absl::optional<std::set<int>> allowed_devices;
  };
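
  // Example (hedged sketch): populating Options for the XLA CPU backend. The
  // platform lookup and the identity shape function below are illustrative,
  // not the canonical factory code; the exact ShapeRepresentationFn signature
  // is declared in xla_compiler.h.
  //
  //   XlaDevice::Options options;
  //   options.platform =
  //       se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  //   options.device_name_prefix = "/job:localhost/replica:0/task:0";
  //   options.device_name = "XLA_CPU";
  //   options.device_ordinal = 0;
  //   options.compilation_device_name = "XLA_CPU_JIT";
  //   options.use_multiple_streams = false;
  //   options.shape_representation_fn =
  //       [](const TensorShape& shape,
  //          DataType dtype) -> xla::StatusOr<xla::Shape> {
  //     xla::Shape xla_shape;
  //     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  //     return xla_shape;
  //   };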

  // Creates a new XLA Device.
  XlaDevice(const SessionOptions& session_options, const Options& options);
  ~XlaDevice() override;

  Allocator* GetAllocator(AllocatorAttributes attr) override
      LOCKS_EXCLUDED(mu_);
  void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
  void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;
  Status Sync() override;
  void Sync(const DoneCallback& done) override;

  Status FillContextMap(const Graph* graph,
                        DeviceContextMap* device_context_map) override
      LOCKS_EXCLUDED(mu_);

  Status MakeTensorFromProto(const TensorProto& tensor_proto,
                             const AllocatorAttributes alloc_attrs,
                             Tensor* tensor) override LOCKS_EXCLUDED(mu_);

  const Metadata& metadata() { return xla_metadata_; }

  // Ensures the DeviceContext associated with this XlaDevice is created and
  // valid (i.e. all streams are ok). If any state is not valid, a new
  // DeviceContext will be created.
  //
  // TODO(b/111859745): The Eager context needs to call this method to recover
  // from failures.
  Status EnsureDeviceContextOk() LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to set a GpuDeviceInfo, which holds extra
  // information for GPU and TPU devices.
  Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);

  // Instructs this XlaDevice to return 'sync_on_completion' for
  // AllowsSyncOnCompletion().
  void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
  bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);

  // Installs an error-handling callback invoked when RefreshStatus sees
  // !status.ok().
  void SetHandleDeviceErrorCallback(std::function<Status()> callback);

  Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
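
  // Example (illustrative): installing an error-handling callback; the log
  // message and recovery policy here are placeholders:
  //
  //   device->SetHandleDeviceErrorCallback([]() -> Status {
  //     LOG(WARNING) << "XLA device hit an error; attempting recovery.";
  //     return Status::OK();
  //   });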

 private:
  xla::LocalClient* client() const;
  Allocator* GetAllocatorLocked(AllocatorAttributes attr)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
  Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
                              std::shared_ptr<se::Stream>* stream,
                              bool* stream_was_changed)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);
  xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  static Status GetMetadataFromDevice(DeviceBase* device,
                                      const XlaDevice::Metadata** metadata);

  // Handles the error when RefreshStatus sees !status.ok().
  Status HandleDeviceError();

  mutable mutex mu_;
  // The metadata of this XlaDevice.
  const Metadata xla_metadata_;
  // Which hardware device in the client's platform this XlaDevice controls.
  const int device_ordinal_;
  // The name of the device that is used to compile Ops for this XlaDevice.
  const DeviceType jit_device_name_;
  // The platform for this device.
  se::Platform* const platform_;  // Not owned.
  // Memory allocator associated with this device.
  Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr;  // Not owned.

  // Stream associated with this device. Operations enqueued on this
  // stream are executed on the device. Operations include data
  // copying back and forth between CPU and the device, and
  // computations enqueued by XLA.
  std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
  // If false, only stream_ is valid, and all computation and transfers use
  // stream_. If true, computation is performed on stream_; host-to-device and
  // device-to-device transfers use their dedicated streams, and each
  // device-to-host transfer borrows a stream.
  const bool use_multiple_streams_;
  // If use_multiple_streams_, host-to-device transfers are performed using
  // this stream.
  std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
  // If use_multiple_streams_, transfers between different devices are
  // performed using these streams.
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams_
      GUARDED_BY(mu_);

  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // The device context accessed by all users of the XlaDevice, set by calls to
  // EnsureDeviceContextOk. If gpu_device_info_ is non-null, this pointer is
  // also filled in to that struct. XlaDeviceContext is a ref-counted object.
  XlaDeviceContext* device_context_ GUARDED_BY(mu_) = nullptr;

  // Holds extra information for GPU and TPU devices, e.g. the device context.
  bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
  std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);

  // Thread pool used for running closures.
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // True if the device allows XlaDevice::Sync to be called on completion
  // regardless of status.
  bool sync_on_completion_ GUARDED_BY(mu_) = true;

  // A callback that will be invoked when RefreshStatus sees a status error.
  std::function<Status()> device_error_callback_ GUARDED_BY(mu_);

  // Set of devices to use. This controls which of the devices on the given
  // platform will have resources allocated. For GPUs this will be
  // filled from the visible_gpu_devices list in the session configuration.
  absl::optional<std::set<int>> allowed_devices_;
};

// Builds OpKernel registrations on 'device' for the JIT operators
// registered on 'jit_device'. Returns ownership of an XlaDeviceOpRegistrations
// object that encapsulates the kernel registrations.
struct XlaDeviceOpRegistrations {
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      op_kernel_registrars;
};
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device);
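
// Example (illustrative): registering the JIT kernels for a backend, in the
// style of the per-backend device factories; the specific device/jit_device
// pair is an assumption:
//
//   static XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels("XLA_CPU", "XLA_CPU_JIT");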

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_