• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
18 
19 #include <map>
20 #include <set>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/executable.h"
27 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/stream_executor/device_memory.h"
36 
37 namespace xla {
38 
39 // The TransferManager interface lets backends provide platform-specific
40 // mechanisms for constructing literals from given device memory handles.
41 // This lets each platform customize how literals are transferred to/from the
42 // device in terms of padding, leading dimension, etc.
43 class TransferManager {
44  public:
~TransferManager()45   virtual ~TransferManager() {}
46 
47   // Returns the ID of the platform that this transfer manager acts on.
48   virtual se::Platform::Id PlatformId() const = 0;
49 
50   // Returns the shape of the on-device representation for the given shape on
51   // the host. This is intended for use with ShapedBuffer where buffers are
52   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
53   // needing to consider device-specific behaviors.
HostShapeToDeviceShape(const Shape & host_shape)54   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
55     // Strips off any preexisting tiling or memory space information.
56     // TODO(phawkins): fix clients not to including tiling or memory space
57     // information in shapes passed to this function and turn this into an
58     // assertion.
59     return ShapeUtil::DeviceShapeToHostShape(host_shape);
60   }
61 
62   // Base class for specifying platform specific transfer metadata that can be
63   // used to tell the underlying implementation to perform specific optimization
64   // to a transfer. Actual metadata passed to supported transfer methods should
65   // subclass this class.
66   class TransferMetadata {
67    public:
68     virtual ~TransferMetadata() = 0;
69   };
70   // Returns a literal containing the data held in the given ShapedBuffer
71   // using the provided executor. This operation is performed synchronously
72   // without waiting for any other operation on a stream to complete.
73   //
74   // This function should be avoided in favor of the asynchronous version below.
75   //
76   // Optionally caller can specify platform-specific transfer metadata that
77   // tells the actual implementation to do something special.
78   virtual StatusOr<Literal> TransferLiteralFromDevice(
79       se::Stream* stream, const ShapedBuffer& device_buffer,
80       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)81   StatusOr<Literal> TransferLiteralFromDevice(
82       se::Stream* stream, const ShapedBuffer& device_buffer) {
83     return TransferLiteralFromDevice(stream, device_buffer, nullptr);
84   }
85   virtual Status TransferLiteralFromDevice(
86       se::Stream* stream, const ShapedBuffer& device_buffer,
87       const MutableBorrowingLiteral& literal,
88       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)89   Status TransferLiteralFromDevice(se::Stream* stream,
90                                    const ShapedBuffer& device_buffer,
91                                    const MutableBorrowingLiteral& literal) {
92     return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
93   }
94 
95   // Begins transferring a literal containing the data held in the given
96   // ShapedBuffer using the provided executor.
97   //
98   // This operation is performed asynchronously on the given stream. It returns
99   // once the transfer is enqueued. 'done' is invoked with the result when
100   // complete.
101   //
102   // device_buffer is copied by reference and must live at least until done() is
103   // invoked.
104   //
105   // Optionally caller can specify platform-specific transfer metadata that
106   // tells the actual implementation to do something special.
107   virtual void TransferLiteralFromDevice(
108       se::Stream* stream, const ShapedBuffer& device_buffer,
109       MutableBorrowingLiteral literal, std::function<void(Status)> done,
110       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)111   void TransferLiteralFromDevice(se::Stream* stream,
112                                  const ShapedBuffer& device_buffer,
113                                  MutableBorrowingLiteral literal,
114                                  std::function<void(Status)> done) {
115     return TransferLiteralFromDevice(stream, device_buffer, literal, done,
116                                      nullptr);
117   }
118 
119   // Transfers the given literal into the previously allocated device memory
120   // represented by the given ShapedBuffer using the given executor. The shape
121   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
122   // but need not have the same layout.
123   //
124   // This operation is performed synchronously without waiting for any other
125   // operation on a stream to complete. This function should be avoided in favor
126   // of the asynchronous version below.
127   //
128   // Optionally caller can specify platform-specific transfer metadata that
129   // tells the actual implementation to do something special.
130   virtual Status TransferLiteralToDevice(
131       se::Stream* stream, const LiteralSlice& literal,
132       const ShapedBuffer& device_buffer,
133       const TransferMetadata* transfer_metadata);
TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)134   Status TransferLiteralToDevice(se::Stream* stream,
135                                  const LiteralSlice& literal,
136                                  const ShapedBuffer& device_buffer) {
137     return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
138   }
139 
140   // Transfers the given literal into the previously allocated device memory
141   // represented by the given ShapedBuffer using the given executor. The shape
142   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
143   // but need not have the same layout.
144   //
145   // This operation is performed asynchronously on the given stream. It returns
146   // once the transfer is enqueued, and may return before the transfer has
147   // completed.
148   //
149   // The caller may free the data structures 'literal' and 'device_buffer'
150   // immediately after this function returns, however their constituent buffers
151   // on both host and device must remain valid until the enqueued transfer has
152   // completed on 'stream'.
153   //
154   // Optionally caller can specify platform-specific transfer metadata that
155   // tells the actual implementation to do something special.
156   virtual Status TransferLiteralToDeviceAsync(
157       se::Stream* stream, const LiteralSlice& literal,
158       const ShapedBuffer& device_buffer,
159       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)160   Status TransferLiteralToDeviceAsync(se::Stream* stream,
161                                       const LiteralSlice& literal,
162                                       const ShapedBuffer& device_buffer) {
163     return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
164                                         nullptr);
165   }
166 
167   // Convenience methods for transferring an array to or from the device at a
168   // known address. This avoids having to construct a ShapedBuffer just to
169   // transfer an array at a known address.
170   //
171   // Optionally caller can specify platform-specific transfer metadata that
172   // tells the actual implementation to do something special.
173   Status TransferArrayToDevice(
174       se::Stream* stream, const LiteralSlice& literal,
175       const se::DeviceMemoryBase& dest,
176       const TransferMetadata* transfer_metadata = nullptr);
177   void TransferArrayFromDevice(
178       se::Stream* stream, const Shape& shape,
179       const se::DeviceMemoryBase& source,
180       const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
181       const TransferMetadata* transfer_metadata = nullptr);
182 
183   Status TransferArrayToDeviceAsync(
184       se::Stream* stream, const LiteralSlice& literal,
185       const se::DeviceMemoryBase& dest,
186       const TransferMetadata* transfer_metadata = nullptr);
187   StatusOr<Literal> TransferArrayFromDevice(
188       se::Stream* stream, const Shape& shape,
189       const se::DeviceMemoryBase& source,
190       const TransferMetadata* transfer_metadata = nullptr);
191 
192   // Read from a device buffer and update the dynamic dimension sizes of
193   // `host_shape` and `device_shape`. The function takes in bounded dynamic
194   // shapes, and returns static shapes with dynamic shapes updated.
195   // The shape of the buffer also have to be compatible with the host shape and
196   // device shape.
197   virtual Status ReadDynamicShapes(se::Stream* stream,
198                                    ShapedBuffer* device_buffer,
199                                    Shape* device_shape);
200 
201   // Transfers the given literal into the Infeed interface of the device,
202   // using the given executor.
203   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
204                                          const LiteralSlice& literal) = 0;
205 
206   // Transfers the given literal from the Outfeed interface of the device,
207   // using the given executor. The shape and layout are determined by the
208   // shape and layout of `literal`.
209   virtual Status TransferLiteralFromOutfeed(
210       se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;
211 
212   // Resets the devices associated with this transfer manager.
213   virtual Status ResetDevices(
214       absl::Span<se::StreamExecutor* const> executor) = 0;
215 
216   // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
217   // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
218   // ShapedBuffer is array-shaped this method does nothing.
219   Status WriteTupleIndexTables(se::Stream* stream,
220                                const ShapedBuffer& device_buffer);
221   Status WriteTupleIndexTablesAsync(se::Stream* stream,
222                                     const ShapedBuffer& device_buffer);
223 
224   // Writes a tuple index buffer for the root of 'device_buffer', which must
225   // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer,
226   // rather than writing all subbuffers. This method is always asynchronous.
227   Status WriteRootTupleIndexTable(se::Stream* stream,
228                                   const ShapedBuffer& device_buffer);
229   Status WriteRootTupleIndexTable(
230       se::Stream* stream,
231       const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree);
232 
233   // Determines the byte size requirement for the given shape on the underlying
234   // architecture. This will be used to allocate an appropriately sized memory
235   // region for a host-to-device transfer.
236   virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
237 
238   // Chooses a compact layout for 'shape', ignoring any existing layout on
239   // 'shape'. What "reasonable" means is left up to the backend. The
240   // intended use case is to choose a layout that avoids excessive padding on
241   // devices that have tiled memory architectures.
242   // The default implementation always picks a default (major-to-minor) layout.
243   // Fails if 'shape' cannot be represented by the device.
244   virtual StatusOr<Shape> ChooseCompactLayoutForShape(
245       const Shape& host_shape) const;
246 
247   typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn;
248 
249   // Allocates a ScopedShapedBuffer which can hold data with the given on-host
250   // shape. The on-device shape may be different as indicated by
251   // HostShapeToDeviceShape.
252   StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
253       const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
254       int device_ordinal,
255       DeviceShapeRepresentationFn shape_representation_fn = nullptr);
256 
257   // The given ShapedBuffer holds a handle to allocated memory, but it is not
258   // in the general case legal to immediately copy or access that allocated
259   // memory because queued operations on the device may alias that memory.
260   // Memory ordering is enforced by the Stream's happens-before relationship
261   // which allows eager deallocation and reallocation of buffers host-side even
262   // if the device hasn't finished with them.
263   //
264   // In certain cases, it can be known that a ShapedBuffer does not have any
265   // conflicting accesses on the device and thus is eligible to be accessed at
266   // any time from the host.
267   //
268   // This function returns true if device_buffer can be accessed immediately
269   // without waiting for the Stream's previously enqueued items. This only
270   // returns true if all subbuffers in device_buffer can be accessed
271   // immediately.
CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)272   virtual bool CanShapedBufferBeAccessedNow(
273       se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
274     return false;
275   }
276 
277   // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer)278   virtual bool CanBufferBeAccessedNow(
279       se::StreamExecutor* executor,
280       const se::DeviceMemoryBase& device_buffer) const {
281     return false;
282   }
283 
284   /////
285   // The TransferManager class also serves as a point to register objects for
286   // the various platforms.
287 
288   // Registers the TransferManager singleton for the platform kind. This is
289   // assumed to be a singleton, so no ownership is transferred.
290   //
291   // Precondition: a platform kind must not be registered more than once.
292   typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
293   static void RegisterTransferManager(
294       se::Platform::Id platform_id,
295       TransferManagerCreationFunction transfer_manager);
296 
297   // Returns the transfer manager singleton pointer if it is available for the
298   // given platform, or an error status if it is not.
299   static StatusOr<TransferManager*> GetForPlatform(
300       const se::Platform* platform);
301 
302   // Writes the given device-memory pointers in 'elements' to the given region
303   // to construct a tuple index table in the platform-specific tuple
304   // representation.
305   virtual Status WriteSingleTupleIndexTable(
306       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
307       const Shape& shape, se::DeviceMemoryBase* region) = 0;
308 
309  protected:
310   // Transfer a memory block of the given size from the device source into the
311   // 'destination' buffer.
312   //
313   // size is the size to transfer to destination in bytes.
314   virtual Status TransferBufferFromDevice(se::Stream* stream,
315                                           const se::DeviceMemoryBase& source,
316                                           int64_t size, void* destination);
317 
318   // Transfer a memory block of the given size from 'source' buffer to the given
319   // destination of the device.
320   //
321   // size is the size to transfer from source in bytes.
322   virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size,
323                                         const void* source,
324                                         se::DeviceMemoryBase* destination);
325 
326  private:
327   // The mutex that guards the platform-to-transfer manager map.
328   static tensorflow::mutex platform_transfer_manager_mutex_;
329 
330   // State kept for each kind of TransferManager.  Registration functions
331   // set up creation_function, and then we use that to lazily create
332   // "manager" the first time GetForPlatform is invoked for a particular id.
333   struct State {
334     std::unique_ptr<TransferManager> manager;
335     TransferManagerCreationFunction creation_function = nullptr;
336   };
337 
338   // Map from platform kind to transfer manager singleton.
339   static absl::flat_hash_map<se::Platform::Id, State>*
340   GetPlatformTransferManagers();
341 };
342 
343 }  // namespace xla
344 
345 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
346