/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_

#include <map>
#include <set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {

// The TransferManager interface lets backends provide platform-specific
// mechanisms for constructing literals from given device memory handles.
// This lets each platform customize how literals are transferred to/from the
// device in terms of padding, leading dimension, etc.
class TransferManager {
 public:
  virtual ~TransferManager() {}

  // Returns the ID of the platform that this transfer manager acts on.
  virtual se::Platform::Id PlatformId() const = 0;

  // Returns the shape of the on-device representation for the given shape on
  // the host. This is intended for use with ShapedBuffer where buffers are
  // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
  // needing to consider device-specific behaviors.
  virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
    // Strips off any preexisting tiling or memory space information.
    // TODO(phawkins): fix clients not to include tiling or memory space
    // information in shapes passed to this function and turn this into an
    // assertion.
    return ShapeUtil::DeviceShapeToHostShape(host_shape);
  }

  // Base class for specifying platform-specific transfer metadata that can be
  // used to tell the underlying implementation to perform specific
  // optimizations for a transfer. Actual metadata passed to supported
  // transfer methods should subclass this class.
  class TransferMetadata {
   public:
    virtual ~TransferMetadata() = 0;
  };
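
  // For illustration only, a backend-specific metadata subclass might look
  // like the following hypothetical sketch (FooTransferMetadata and its
  // use_dma flag are not part of this interface):
  //
  //   class FooTransferMetadata : public TransferManager::TransferMetadata {
  //    public:
  //     explicit FooTransferMetadata(bool use_dma) : use_dma_(use_dma) {}
  //     bool use_dma() const { return use_dma_; }
  //
  //    private:
  //     bool use_dma_;
  //   };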
  // Returns a literal containing the data held in the given ShapedBuffer
  // using the provided executor. This operation is performed synchronously
  // without waiting for any other operation on a stream to complete.
  //
  // This function should be avoided in favor of the asynchronous version
  // below.
  //
  // Optionally the caller can specify platform-specific transfer metadata
  // that tells the actual implementation to do something special.
  virtual StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer) {
    return TransferLiteralFromDevice(stream, device_buffer, nullptr);
  }
  virtual Status TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const MutableBorrowingLiteral& literal,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralFromDevice(se::Stream* stream,
                                   const ShapedBuffer& device_buffer,
                                   const MutableBorrowingLiteral& literal) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
  }

  // Begins transferring a literal containing the data held in the given
  // ShapedBuffer using the provided executor.
  //
  // This operation is performed asynchronously on the given stream. It
  // returns once the transfer is enqueued. 'done' is invoked with the result
  // when complete.
  //
  // device_buffer is captured by reference and must live at least until
  // done() is invoked.
  //
  // Optionally the caller can specify platform-specific transfer metadata
  // that tells the actual implementation to do something special.
  virtual void TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      MutableBorrowingLiteral literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata) = 0;
  void TransferLiteralFromDevice(se::Stream* stream,
                                 const ShapedBuffer& device_buffer,
                                 MutableBorrowingLiteral literal,
                                 std::function<void(Status)> done) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, done,
                                     nullptr);
  }
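
  // For illustration, a typical device-to-host read-back might look like the
  // following hypothetical sketch, where `transfer_manager`, `stream`, and
  // `buffer` come from the caller's context:
  //
  //   // Synchronous:
  //   StatusOr<Literal> result =
  //       transfer_manager->TransferLiteralFromDevice(stream, buffer);
  //
  //   // Asynchronous, into a preallocated literal:
  //   Literal literal(buffer.on_host_shape());
  //   transfer_manager->TransferLiteralFromDevice(
  //       stream, buffer, MutableBorrowingLiteral(&literal),
  //       [](Status s) { /* transfer complete; literal is populated */ });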
  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given executor. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed synchronously without waiting for any other
  // operation on a stream to complete. This function should be avoided in
  // favor of the asynchronous version below.
  //
  // Optionally the caller can specify platform-specific transfer metadata
  // that tells the actual implementation to do something special.
  virtual Status TransferLiteralToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralToDevice(se::Stream* stream,
                                 const LiteralSlice& literal,
                                 const ShapedBuffer& device_buffer) {
    return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
  }

  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given executor. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed asynchronously on the given stream. It
  // returns once the transfer is enqueued, and may return before the transfer
  // has completed.
  //
  // The caller may free the data structures 'literal' and 'device_buffer'
  // immediately after this function returns; however, their constituent
  // buffers on both host and device must remain valid until the enqueued
  // transfer has completed on 'stream'.
  //
  // Optionally the caller can specify platform-specific transfer metadata
  // that tells the actual implementation to do something special.
  virtual Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata) = 0;
  Status TransferLiteralToDeviceAsync(se::Stream* stream,
                                      const LiteralSlice& literal,
                                      const ShapedBuffer& device_buffer) {
    return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
                                        nullptr);
  }

  // Convenience methods for transferring an array to or from the device at a
  // known address. These avoid having to construct a ShapedBuffer just to
  // transfer an array at a known address.
  //
  // Optionally the caller can specify platform-specific transfer metadata
  // that tells the actual implementation to do something special.
  Status TransferArrayToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  void TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata = nullptr);

  Status TransferArrayToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  StatusOr<Literal> TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const TransferMetadata* transfer_metadata = nullptr);
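
  // For illustration, a round trip through the array helpers might look like
  // the following hypothetical sketch, where `transfer_manager`, `stream`,
  // `literal`, and `dest` come from the caller's context:
  //
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->TransferArrayToDevice(stream, literal, dest));
  //   TF_ASSIGN_OR_RETURN(
  //       Literal round_trip,
  //       transfer_manager->TransferArrayFromDevice(stream, literal.shape(),
  //                                                 dest));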
  // Reads from a device buffer and updates the dynamic dimension sizes of
  // `device_shape`. The function takes in bounded dynamic shapes, and returns
  // static shapes with the dynamic dimension sizes updated.
  // The shape of the buffer also has to be compatible with the host shape and
  // device shape.
  virtual Status ReadDynamicShapes(se::Stream* stream,
                                   ShapedBuffer* device_buffer,
                                   Shape* device_shape);

  // Transfers the given literal into the Infeed interface of the device,
  // using the given executor.
  virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                         const LiteralSlice& literal) = 0;

  // Transfers the given literal from the Outfeed interface of the device,
  // using the given executor. The shape and layout are determined by the
  // shape and layout of `literal`.
  virtual Status TransferLiteralFromOutfeed(
      se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;

  // Resets the devices associated with this transfer manager.
  virtual Status ResetDevices(
      absl::Span<se::StreamExecutor* const> executors) = 0;

  // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
  // each buffer of the given ShapedBuffer corresponding to tuple shapes. If
  // the ShapedBuffer is array-shaped this method does nothing.
  Status WriteTupleIndexTables(se::Stream* stream,
                               const ShapedBuffer& device_buffer);
  Status WriteTupleIndexTablesAsync(se::Stream* stream,
                                    const ShapedBuffer& device_buffer);

  // Writes a tuple index buffer for the root of 'device_buffer', which must
  // be a tuple. Unlike WriteTupleIndexTables, this only writes the root
  // buffer, rather than writing all subbuffers. This method is always
  // asynchronous.
  Status WriteRootTupleIndexTable(se::Stream* stream,
                                  const ShapedBuffer& device_buffer);
  Status WriteRootTupleIndexTable(
      se::Stream* stream,
      const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree);

  // Determines the byte size requirement for the given shape on the
  // underlying architecture. This will be used to allocate an appropriately
  // sized memory region for a host-to-device transfer.
  virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;

  // Chooses a compact layout for 'shape', ignoring any existing layout on
  // 'shape'. What "compact" means is left up to the backend. The intended use
  // case is to choose a layout that avoids excessive padding on devices that
  // have tiled memory architectures.
  // The default implementation always picks a default (major-to-minor)
  // layout. Fails if 'shape' cannot be represented by the device.
  virtual StatusOr<Shape> ChooseCompactLayoutForShape(
      const Shape& host_shape) const;

  typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn;

  // Allocates a ScopedShapedBuffer which can hold data with the given on-host
  // shape. The on-device shape may be different as indicated by
  // HostShapeToDeviceShape.
  StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
      const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
      int device_ordinal,
      DeviceShapeRepresentationFn shape_representation_fn = nullptr);
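
  // For illustration, a typical allocate-then-transfer flow might look like
  // the following hypothetical sketch, where `transfer_manager`, `stream`,
  // `allocator`, and `literal` come from the caller's context:
  //
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
  //                       transfer_manager->AllocateScopedShapedBuffer(
  //                           literal.shape(), allocator,
  //                           /*device_ordinal=*/0));
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->WriteTupleIndexTables(stream, buffer));
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
  //       stream, literal, buffer));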
  // The given ShapedBuffer holds a handle to allocated memory, but it is not
  // in the general case legal to immediately copy or access that allocated
  // memory because queued operations on the device may alias that memory.
  // Memory ordering is enforced by the Stream's happens-before relationship,
  // which allows eager deallocation and reallocation of buffers host-side
  // even if the device hasn't finished with them.
  //
  // In certain cases, it can be known that a ShapedBuffer does not have any
  // conflicting accesses on the device and thus is eligible to be accessed at
  // any time from the host.
  //
  // This function returns true if device_buffer can be accessed immediately
  // without waiting for the Stream's previously enqueued items. This only
  // returns true if all subbuffers in device_buffer can be accessed
  // immediately.
  virtual bool CanShapedBufferBeAccessedNow(
      se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
    return false;
  }

  // Equivalent to CanShapedBufferBeAccessedNow but for a single device
  // buffer.
  virtual bool CanBufferBeAccessedNow(
      se::StreamExecutor* executor,
      const se::DeviceMemoryBase& device_buffer) const {
    return false;
  }

  /////
  // The TransferManager class also serves as a point to register objects for
  // the various platforms.

  // Registers the TransferManager singleton for the platform kind. This is
  // assumed to be a singleton, so no ownership is transferred. See the
  // registration sketch further below.
  //
  // Precondition: a platform kind must not be registered more than once.
  typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
  static void RegisterTransferManager(
      se::Platform::Id platform_id,
      TransferManagerCreationFunction transfer_manager);

  // Returns the transfer manager singleton pointer if it is available for the
  // given platform, or an error status if it is not.
  static StatusOr<TransferManager*> GetForPlatform(
      const se::Platform* platform);

  // Writes the given device-memory pointers in 'elements' to the given region
  // to construct a tuple index table in the platform-specific tuple
  // representation.
  virtual Status WriteSingleTupleIndexTable(
      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
      const Shape& shape, se::DeviceMemoryBase* region) = 0;
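
  // For illustration, a backend would typically register its TransferManager
  // from a module initializer along the lines of the following hypothetical
  // sketch (FooTransferManager and kFooPlatformId are placeholders, not part
  // of this interface):
  //
  //   static std::unique_ptr<TransferManager> CreateFooTransferManager() {
  //     return absl::make_unique<FooTransferManager>();
  //   }
  //   static bool foo_registered = [] {
  //     TransferManager::RegisterTransferManager(kFooPlatformId,
  //                                              &CreateFooTransferManager);
  //     return true;
  //   }();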
 protected:
  // Transfers a memory block of the given size from the device source into
  // the 'destination' buffer.
  //
  // size is the size to transfer to destination in bytes.
  virtual Status TransferBufferFromDevice(se::Stream* stream,
                                          const se::DeviceMemoryBase& source,
                                          int64_t size, void* destination);

  // Transfers a memory block of the given size from the 'source' buffer to
  // the given destination of the device.
  //
  // size is the size to transfer from source in bytes.
  virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size,
                                        const void* source,
                                        se::DeviceMemoryBase* destination);

 private:
  // The mutex that guards the platform-to-transfer-manager map.
  static tensorflow::mutex platform_transfer_manager_mutex_;

  // State kept for each kind of TransferManager. Registration functions set
  // up creation_function, which is then used to lazily create "manager" the
  // first time GetForPlatform is invoked for a particular id.
  struct State {
    std::unique_ptr<TransferManager> manager;
    TransferManagerCreationFunction creation_function = nullptr;
  };

  // Map from platform kind to transfer manager singleton.
  static absl::flat_hash_map<se::Platform::Id, State>*
  GetPlatformTransferManagers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_