• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 
36 // The TransferManager interface lets backends provide platform-specific
37 // mechanisms for constructing literals from given device memory handles.
38 // This lets each platform customize how literals are transferred to/from the
39 // device in terms of padding, leading dimension, etc.
40 class TransferManager {
41  public:
~TransferManager()42   virtual ~TransferManager() {}
43 
44   // Returns the ID of the platform that this transfer manager acts on.
45   virtual se::Platform::Id PlatformId() const = 0;
46 
47   // Returns the shape of the on-device representation for the given shape on
48   // the host. This is intended for use with ShapedBuffer where buffers are
49   // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
50   // needing to consider device-specific behaviors.
HostShapeToDeviceShape(const Shape & host_shape)51   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
52     return host_shape;
53   }
54 
55   // Base class for specifying platform specific transfer metadata that can be
56   // used to tell the underlying implementation to perform specific optimization
57   // to a transfer. Actual metadata passed to supported transfer methods should
58   // subclass this class.
59   class TransferMetadata {
60    public:
61     virtual ~TransferMetadata() = 0;
62   };
63   // Returns a literal containing the data held in the given ShapedBuffer
64   // using the provided executor. This operation is performed synchronously
65   // without waiting for any other operation on a stream to complete.
66   //
67   // This function should be avoided in favor of the asynchronous version below.
68   //
69   // Optionally caller can specify platform-specific transfer metadata that
70   // tells the actual implementation to do something special.
71   virtual StatusOr<Literal> TransferLiteralFromDevice(
72       se::Stream* stream, const ShapedBuffer& device_buffer,
73       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)74   StatusOr<Literal> TransferLiteralFromDevice(
75       se::Stream* stream, const ShapedBuffer& device_buffer) {
76     return TransferLiteralFromDevice(stream, device_buffer, nullptr);
77   }
78   virtual Status TransferLiteralFromDevice(
79       se::Stream* stream, const ShapedBuffer& device_buffer,
80       const MutableBorrowingLiteral& literal,
81       const TransferMetadata* transfer_metadata);
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)82   Status TransferLiteralFromDevice(se::Stream* stream,
83                                    const ShapedBuffer& device_buffer,
84                                    const MutableBorrowingLiteral& literal) {
85     return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
86   }
87 
88   // Begins transferring a literal containing the data held in the given
89   // ShapedBuffer using the provided executor.
90   //
91   // This operation is performed asynchronously on the given stream. It returns
92   // once the transfer is enqueued. 'done' is invoked with the result when
93   // complete.
94   //
95   // device_buffer is copied by reference and must live at least until done() is
96   // invoked.
97   //
98   // Optionally caller can specify platform-specific transfer metadata that
99   // tells the actual implementation to do something special.
100   virtual void TransferLiteralFromDevice(
101       se::Stream* stream, const ShapedBuffer& device_buffer,
102       MutableBorrowingLiteral literal, std::function<void(Status)> done,
103       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)104   void TransferLiteralFromDevice(se::Stream* stream,
105                                  const ShapedBuffer& device_buffer,
106                                  MutableBorrowingLiteral literal,
107                                  std::function<void(Status)> done) {
108     return TransferLiteralFromDevice(stream, device_buffer, literal, done,
109                                      nullptr);
110   }
111 
112   // Transfers the given literal into the previously allocated device memory
113   // represented by the given ShapedBuffer using the given executor. The shape
114   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
115   // but need not have the same layout.
116   //
117   // This operation is performed synchronously without waiting for any other
118   // operation on a stream to complete. This function should be avoided in favor
119   // of the asynchronous version below.
120   //
121   // Optionally caller can specify platform-specific transfer metadata that
122   // tells the actual implementation to do something special.
123   virtual Status TransferLiteralToDevice(
124       se::Stream* stream, const LiteralSlice& literal,
125       const ShapedBuffer& device_buffer,
126       const TransferMetadata* transfer_metadata);
TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)127   Status TransferLiteralToDevice(se::Stream* stream,
128                                  const LiteralSlice& literal,
129                                  const ShapedBuffer& device_buffer) {
130     return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
131   }
132 
133   // Transfers the given literal into the previously allocated device memory
134   // represented by the given ShapedBuffer using the given executor. The shape
135   // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
136   // but need not have the same layout.
137   //
138   // This operation is performed asynchronously on the given stream. It returns
139   // once the transfer is enqueued, and may return before the transfer has
140   // completed.
141   //
142   // The caller may free the data structures 'literal' and 'device_buffer'
143   // immediately after this function returns, however their constituent buffers
144   // on both host and device must remain valid until the enqueued transfer has
145   // completed on 'stream'.
146   //
147   // Optionally caller can specify platform-specific transfer metadata that
148   // tells the actual implementation to do something special.
149   virtual Status TransferLiteralToDeviceAsync(
150       se::Stream* stream, const LiteralSlice& literal,
151       const ShapedBuffer& device_buffer,
152       const TransferMetadata* transfer_metadata) = 0;
TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)153   Status TransferLiteralToDeviceAsync(se::Stream* stream,
154                                       const LiteralSlice& literal,
155                                       const ShapedBuffer& device_buffer) {
156     return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
157                                         nullptr);
158   }
159 
160   // Convenience methods for transferring an array to or from the device at a
161   // known address. This avoids having to construct a ShapedBuffer just to
162   // transfer an array at a known address.
163   //
164   // Optionally caller can specify platform-specific transfer metadata that
165   // tells the actual implementation to do something special.
166   Status TransferArrayToDevice(
167       se::Stream* stream, const LiteralSlice& literal,
168       const se::DeviceMemoryBase& dest,
169       const TransferMetadata* transfer_metadata = nullptr);
170   void TransferArrayFromDevice(
171       se::Stream* stream, const Shape& shape,
172       const se::DeviceMemoryBase& source,
173       const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
174       const TransferMetadata* transfer_metadata = nullptr);
175 
176   Status TransferArrayToDeviceAsync(
177       se::Stream* stream, const LiteralSlice& literal,
178       const se::DeviceMemoryBase& dest,
179       const TransferMetadata* transfer_metadata = nullptr);
180   StatusOr<Literal> TransferArrayFromDevice(
181       se::Stream* stream, const Shape& shape,
182       const se::DeviceMemoryBase& source,
183       const TransferMetadata* transfer_metadata = nullptr);
184 
185   // Transfers the given literal into the Infeed interface of the device,
186   // using the given executor.
187   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
188                                          const LiteralSlice& literal) = 0;
189 
190   // Transfers the given literal from the Outfeed interface of the device,
191   // using the given executor.
192   virtual Status TransferLiteralFromOutfeed(
193       se::StreamExecutor* executor, const Shape& literal_shape,
194       MutableBorrowingLiteral literal) = 0;
195 
196   // Resets the devices associated with this transfer manager.
197   virtual Status ResetDevices(
198       absl::Span<se::StreamExecutor* const> executor) = 0;
199 
200   // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
201   // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
202   // ShapedBuffer is array-shaped this method does nothing.
203   Status WriteTupleIndexTables(se::Stream* stream,
204                                const ShapedBuffer& device_buffer);
205   Status WriteTupleIndexTablesAsync(se::Stream* stream,
206                                     const ShapedBuffer& device_buffer);
207 
208   // Writes a tuple index buffer for the root of 'device_buffer', which must
209   // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer,
210   // rather than writing all subbuffers. This method is always asynchronous.
211   Status WriteRootTupleIndexTable(se::Stream* stream,
212                                   const ShapedBuffer& device_buffer);
213 
214   // Determines the byte size requirement for the given shape on the underlying
215   // architecture. This will be used to allocate an appropriately sized memory
216   // region for a host-to-device transfer.
217   virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
218 
219   // Allocates a ScopedShapedBuffer which can hold data with the given on-host
220   // shape. The on-device shape may be different as indicated by
221   // HostShapeToDeviceShape.
222   StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
223       const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
224       int device_ordinal);
225 
226   // The given ShapedBuffer holds a handle to allocated memory, but it is not
227   // in the general case legal to immediately copy or access that allocated
228   // memory because queued operations on the device may alias that memory.
229   // Memory ordering is enforced by the Stream's happens-before relationship
230   // which allows eager deallocation and reallocation of buffers host-side even
231   // if the device hasn't finished with them.
232   //
233   // In certain cases, it can be known that a ShapedBuffer does not have any
234   // conflicting accesses on the device and thus is eligible to be accessed at
235   // any time from the host.
236   //
237   // This function returns true if device_buffer can be accessed immediately
238   // without waiting for the Stream's previously enqueued items. This only
239   // returns true if all subbuffers in device_buffer can be accessed
240   // immediately.
CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)241   virtual bool CanShapedBufferBeAccessedNow(
242       se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
243     return false;
244   }
245 
246   /////
247   // The TransferManager class also serves as a point to register objects for
248   // the various platforms.
249 
250   // Registers the TransferManager singleton for the platform kind. This is
251   // assumed to be a singleton, so no ownership is transferred.
252   //
253   // Precondition: a platform kind must not be registered more than once.
254   typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
255   static void RegisterTransferManager(
256       se::Platform::Id platform_id,
257       TransferManagerCreationFunction transfer_manager);
258 
259   // Returns the transfer manager singleton pointer if it is available for the
260   // given platform, or an error status if it is not.
261   static StatusOr<TransferManager*> GetForPlatform(
262       const se::Platform* platform);
263 
264  protected:
265   // Transfer a memory block of the given size from the device source into the
266   // 'destination' buffer.
267   //
268   // size is the size to transfer to destination in bytes.
269   virtual Status TransferBufferFromDevice(se::Stream* stream,
270                                           const se::DeviceMemoryBase& source,
271                                           int64 size, void* destination);
272 
273   // Transfer a memory block of the given size from 'source' buffer to the given
274   // destination of the device.
275   //
276   // size is the size to transfer from source in bytes.
277   virtual Status TransferBufferToDevice(se::Stream* stream, int64 size,
278                                         const void* source,
279                                         se::DeviceMemoryBase* destination);
280 
281   // Writes the given device-memory pointers in 'elements' to the given region
282   // to construct a tuple index table in the platform-specific tuple
283   // representation.
284   virtual Status WriteSingleTupleIndexTable(
285       se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
286       const Shape& shape, se::DeviceMemoryBase* region) = 0;
287 
288  private:
289   // The mutex that guards the platform-to-transfer manager map.
290   static tensorflow::mutex platform_transfer_manager_mutex_;
291 
292   // State kept for each kind of TransferManager.  Registration functions
293   // set up creation_function, and then we use that to lazily create
294   // "manager" the first time GetForPlatform is invoked for a particular id.
295   struct State {
296     std::unique_ptr<TransferManager> manager;
297     TransferManagerCreationFunction creation_function = nullptr;
298   };
299 
300   // Map from platform kind to transfer manager singleton.
301   static std::map<se::Platform::Id, State>* GetPlatformTransferManagers();
302 };
303 
304 }  // namespace xla
305 
306 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
307