1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 18 19 #include <map> 20 #include <set> 21 #include <vector> 22 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/literal.h" 25 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/compiler/xla/types.h" 28 #include "tensorflow/compiler/xla/xla_data.pb.h" 29 #include "tensorflow/core/platform/mutex.h" 30 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 31 #include "tensorflow/core/platform/thread_annotations.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace xla { 35 36 // The TransferManager interface lets backends provide platform-specific 37 // mechanisms for constructing literals from given device memory handles. 38 // This lets each platform customize how literals are transferred to/from the 39 // device in terms of padding, leading dimension, etc. 40 class TransferManager { 41 public: ~TransferManager()42 virtual ~TransferManager() {} 43 44 // Returns the ID of the platform that this transfer manager acts on. 45 virtual se::Platform::Id PlatformId() const = 0; 46 47 // Returns the shape of the on-device representation for the given shape on 48 // the host. This is intended for use with ShapedBuffer where buffers are 49 // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user 50 // needing to consider device-specific behaviors. HostShapeToDeviceShape(const Shape & host_shape)51 virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { 52 return host_shape; 53 } 54 55 // Base class for specifying platform specific transfer metadata that can be 56 // used to tell the underlying implementation to perform specific optimization 57 // to a transfer. Actual metadata passed to supported transfer methods should 58 // subclass this class. 59 class TransferMetadata { 60 public: 61 virtual ~TransferMetadata() = 0; 62 }; 63 // Returns a literal containing the data held in the given ShapedBuffer 64 // using the provided executor. This operation is performed synchronously 65 // without waiting for any other operation on a stream to complete. 66 // 67 // This function should be avoided in favor of the asynchronous version below. 68 // 69 // Optionally caller can specify platform-specific transfer metadata that 70 // tells the actual implementation to do something special. 71 virtual StatusOr<Literal> TransferLiteralFromDevice( 72 se::Stream* stream, const ShapedBuffer& device_buffer, 73 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)74 StatusOr<Literal> TransferLiteralFromDevice( 75 se::Stream* stream, const ShapedBuffer& device_buffer) { 76 return TransferLiteralFromDevice(stream, device_buffer, nullptr); 77 } 78 virtual Status TransferLiteralFromDevice( 79 se::Stream* stream, const ShapedBuffer& device_buffer, 80 const MutableBorrowingLiteral& literal, 81 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)82 Status TransferLiteralFromDevice(se::Stream* stream, 83 const ShapedBuffer& device_buffer, 84 const MutableBorrowingLiteral& literal) { 85 return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr); 86 } 87 88 // Begins transferring a literal containing the data held in the given 89 // ShapedBuffer using the provided executor. 90 // 91 // This operation is performed asynchronously on the given stream. It returns 92 // once the transfer is enqueued. 'done' is invoked with the result when 93 // complete. 94 // 95 // device_buffer is copied by reference and must live at least until done() is 96 // invoked. 97 // 98 // Optionally caller can specify platform-specific transfer metadata that 99 // tells the actual implementation to do something special. 100 virtual void TransferLiteralFromDevice( 101 se::Stream* stream, const ShapedBuffer& device_buffer, 102 MutableBorrowingLiteral literal, std::function<void(Status)> done, 103 const TransferMetadata* transfer_metadata) = 0; TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)104 void TransferLiteralFromDevice(se::Stream* stream, 105 const ShapedBuffer& device_buffer, 106 MutableBorrowingLiteral literal, 107 std::function<void(Status)> done) { 108 return TransferLiteralFromDevice(stream, device_buffer, literal, done, 109 nullptr); 110 } 111 112 // Transfers the given literal into the previously allocated device memory 113 // represented by the given ShapedBuffer using the given executor. The shape 114 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 115 // but need not have the same layout. 116 // 117 // This operation is performed synchronously without waiting for any other 118 // operation on a stream to complete. This function should be avoided in favor 119 // of the asynchronous version below. 120 // 121 // Optionally caller can specify platform-specific transfer metadata that 122 // tells the actual implementation to do something special. 123 virtual Status TransferLiteralToDevice( 124 se::Stream* stream, const LiteralSlice& literal, 125 const ShapedBuffer& device_buffer, 126 const TransferMetadata* transfer_metadata); TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)127 Status TransferLiteralToDevice(se::Stream* stream, 128 const LiteralSlice& literal, 129 const ShapedBuffer& device_buffer) { 130 return TransferLiteralToDevice(stream, literal, device_buffer, nullptr); 131 } 132 133 // Transfers the given literal into the previously allocated device memory 134 // represented by the given ShapedBuffer using the given executor. The shape 135 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 136 // but need not have the same layout. 137 // 138 // This operation is performed asynchronously on the given stream. It returns 139 // once the transfer is enqueued, and may return before the transfer has 140 // completed. 141 // 142 // The caller may free the data structures 'literal' and 'device_buffer' 143 // immediately after this function returns, however their constituent buffers 144 // on both host and device must remain valid until the enqueued transfer has 145 // completed on 'stream'. 146 // 147 // Optionally caller can specify platform-specific transfer metadata that 148 // tells the actual implementation to do something special. 149 virtual Status TransferLiteralToDeviceAsync( 150 se::Stream* stream, const LiteralSlice& literal, 151 const ShapedBuffer& device_buffer, 152 const TransferMetadata* transfer_metadata) = 0; TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)153 Status TransferLiteralToDeviceAsync(se::Stream* stream, 154 const LiteralSlice& literal, 155 const ShapedBuffer& device_buffer) { 156 return TransferLiteralToDeviceAsync(stream, literal, device_buffer, 157 nullptr); 158 } 159 160 // Convenience methods for transferring an array to or from the device at a 161 // known address. This avoids having to construct a ShapedBuffer just to 162 // transfer an array at a known address. 163 // 164 // Optionally caller can specify platform-specific transfer metadata that 165 // tells the actual implementation to do something special. 166 Status TransferArrayToDevice( 167 se::Stream* stream, const LiteralSlice& literal, 168 const se::DeviceMemoryBase& dest, 169 const TransferMetadata* transfer_metadata = nullptr); 170 void TransferArrayFromDevice( 171 se::Stream* stream, const Shape& shape, 172 const se::DeviceMemoryBase& source, 173 const MutableBorrowingLiteral& literal, std::function<void(Status)> done, 174 const TransferMetadata* transfer_metadata = nullptr); 175 176 Status TransferArrayToDeviceAsync( 177 se::Stream* stream, const LiteralSlice& literal, 178 const se::DeviceMemoryBase& dest, 179 const TransferMetadata* transfer_metadata = nullptr); 180 StatusOr<Literal> TransferArrayFromDevice( 181 se::Stream* stream, const Shape& shape, 182 const se::DeviceMemoryBase& source, 183 const TransferMetadata* transfer_metadata = nullptr); 184 185 // Transfers the given literal into the Infeed interface of the device, 186 // using the given executor. 187 virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, 188 const LiteralSlice& literal) = 0; 189 190 // Transfers the given literal from the Outfeed interface of the device, 191 // using the given executor. 192 virtual Status TransferLiteralFromOutfeed( 193 se::StreamExecutor* executor, const Shape& literal_shape, 194 MutableBorrowingLiteral literal) = 0; 195 196 // Resets the devices associated with this transfer manager. 197 virtual Status ResetDevices( 198 absl::Span<se::StreamExecutor* const> executor) = 0; 199 200 // Given an allocated ShapedBuffer, constructs the tuple index table(s) in 201 // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the 202 // ShapedBuffer is array-shaped this method does nothing. 203 Status WriteTupleIndexTables(se::Stream* stream, 204 const ShapedBuffer& device_buffer); 205 Status WriteTupleIndexTablesAsync(se::Stream* stream, 206 const ShapedBuffer& device_buffer); 207 208 // Writes a tuple index buffer for the root of 'device_buffer', which must 209 // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, 210 // rather than writing all subbuffers. This method is always asynchronous. 211 Status WriteRootTupleIndexTable(se::Stream* stream, 212 const ShapedBuffer& device_buffer); 213 214 // Determines the byte size requirement for the given shape on the underlying 215 // architecture. This will be used to allocate an appropriately sized memory 216 // region for a host-to-device transfer. 217 virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; 218 219 // Allocates a ScopedShapedBuffer which can hold data with the given on-host 220 // shape. The on-device shape may be different as indicated by 221 // HostShapeToDeviceShape. 222 StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer( 223 const Shape& on_host_shape, DeviceMemoryAllocator* allocator, 224 int device_ordinal); 225 226 // The given ShapedBuffer holds a handle to allocated memory, but it is not 227 // in the general case legal to immediately copy or access that allocated 228 // memory because queued operations on the device may alias that memory. 229 // Memory ordering is enforced by the Stream's happens-before relationship 230 // which allows eager deallocation and reallocation of buffers host-side even 231 // if the device hasn't finished with them. 232 // 233 // In certain cases, it can be known that a ShapedBuffer does not have any 234 // conflicting accesses on the device and thus is eligible to be accessed at 235 // any time from the host. 236 // 237 // This function returns true if device_buffer can be accessed immediately 238 // without waiting for the Stream's previously enqueued items. This only 239 // returns true if all subbuffers in device_buffer can be accessed 240 // immediately. CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)241 virtual bool CanShapedBufferBeAccessedNow( 242 se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { 243 return false; 244 } 245 246 ///// 247 // The TransferManager class also serves as a point to register objects for 248 // the various platforms. 249 250 // Registers the TransferManager singleton for the platform kind. This is 251 // assumed to be a singleton, so no ownership is transferred. 252 // 253 // Precondition: a platform kind must not be registered more than once. 254 typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)(); 255 static void RegisterTransferManager( 256 se::Platform::Id platform_id, 257 TransferManagerCreationFunction transfer_manager); 258 259 // Returns the transfer manager singleton pointer if it is available for the 260 // given platform, or an error status if it is not. 261 static StatusOr<TransferManager*> GetForPlatform( 262 const se::Platform* platform); 263 264 protected: 265 // Transfer a memory block of the given size from the device source into the 266 // 'destination' buffer. 267 // 268 // size is the size to transfer to destination in bytes. 269 virtual Status TransferBufferFromDevice(se::Stream* stream, 270 const se::DeviceMemoryBase& source, 271 int64 size, void* destination); 272 273 // Transfer a memory block of the given size from 'source' buffer to the given 274 // destination of the device. 275 // 276 // size is the size to transfer from source in bytes. 277 virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, 278 const void* source, 279 se::DeviceMemoryBase* destination); 280 281 // Writes the given device-memory pointers in 'elements' to the given region 282 // to construct a tuple index table in the platform-specific tuple 283 // representation. 284 virtual Status WriteSingleTupleIndexTable( 285 se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements, 286 const Shape& shape, se::DeviceMemoryBase* region) = 0; 287 288 private: 289 // The mutex that guards the platform-to-transfer manager map. 290 static tensorflow::mutex platform_transfer_manager_mutex_; 291 292 // State kept for each kind of TransferManager. Registration functions 293 // set up creation_function, and then we use that to lazily create 294 // "manager" the first time GetForPlatform is invoked for a particular id. 295 struct State { 296 std::unique_ptr<TransferManager> manager; 297 TransferManagerCreationFunction creation_function = nullptr; 298 }; 299 300 // Map from platform kind to transfer manager singleton. 301 static std::map<se::Platform::Id, State>* GetPlatformTransferManagers(); 302 }; 303 304 } // namespace xla 305 306 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 307