/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for keeping track of on-device state.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// Cannot include xrt_memory_manager.h here, as it needs to include this file.
class XRTMemoryManager;

// TODO(misard) make this a Tensor if and when that makes sense.
// A reference-counted wrapper around a buffer allocation. This maps an XLA
// tuple index or a non-tuple XLA shape to a region of device memory. The
// device memory buffer is freed when the reference count drops to zero.
class XRTBufferAllocation : public core::RefCounted {
 public:
  XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                      int device_ordinal, se::DeviceMemoryAllocator* allocator);
  ~XRTBufferAllocation() override;

  // The region of device memory being wrapped.
  const se::DeviceMemoryBase& allocation();

  // Stops tracking the wrapped device memory, so that it will not be freed
  // when the reference count drops to zero.
  void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); }

 private:
  se::DeviceMemoryBase allocation_;
  int device_ordinal_;
  se::DeviceMemoryAllocator* allocator_;
};

// An XRTTupleAllocation represents an allocated memory area on the device.
// New tuples can be created in three ways: by passing a literal, in which
// case device memory is allocated and the literal is transferred to that
// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
// aliasing a vector of existing handles to create a new tuple. The underlying
// storage is reference-counted. When a handle is released, the reference
// count of each storage buffer is decremented, and buffers with no
// outstanding references are freed.
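//
// A minimal usage sketch (hypothetical; assumes a valid `literal`,
// `memory_manager`, and `backend`, and omits error handling):
//
//   XRTTupleAllocation* tuple = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, memory_manager, backend, /*device_ordinal=*/0, &tuple));
//   ...  // Use the allocation, e.g. via tuple->ToShapedBuffer().
//   tuple->Unref();  // Buffers are freed once no live handles remain.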
class XRTTupleAllocation : public core::RefCounted {
 public:
  ~XRTTupleAllocation() override;

  // Allocates new device memory buffers sufficient to store literal, transfers
  // literal to that memory, and returns an XRTTupleAllocation handle to the
  // allocated buffers.
  static Status CreateAndTransfer(const xla::LiteralBase& literal,
                                  XRTMemoryManager* memory_manager,
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);

  // Allocates new device memory buffers sufficient to store a tensor of
  // the specified shape, and returns an XRTTupleAllocation handle to the
  // allocated buffers. The allocated buffers are not initialized.
  static Status CreateUninitialized(const xla::Shape& shape,
                                    XRTMemoryManager* memory_manager,
                                    xla::Backend* backend, int device_ordinal,
                                    XRTTupleAllocation** allocation);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation);

  // Same as the CreateFromBuffer() API above, but with the shapes being passed
  // as input. This API is used when creating tuple allocations from the output
  // of XLA computations which emit dynamically shaped output via the output
  // shape table.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 const xla::Shape& on_host_shape,
                                 const xla::Shape& on_device_shape,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation);

  // Aliases a sub-shape of parent and returns an XRTTupleAllocation handle
  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
  // sub-shape will be shared between parent and the returned allocation;
  // otherwise the overlapping buffers in parent will be replaced by nullptr.
  static Status MakeSubBuffer(XRTTupleAllocation* parent,
                              const xla::ShapeIndex& subshape,
                              XRTTupleAllocation** allocation,
                              bool alias_parent_allocation);

  // A structure describing a leaf of a tree of tuples to expand. Each leaf
  // contains an allocation and indicates whether or not the allocation's
  // handle should be freed after incorporating its buffers into the expanded
  // tree.
  struct ExpandedTupleInput {
    RefPtr<XRTTupleAllocation> allocation;
    bool release_allocation_after_use;
  };

  // Returns a handle to a new tuple where the subtree of the new tuple at an
  // index corresponding to a leaf of 'elements' is constructed from the
  // allocation (i.e., a tuple or array) pointed to by that leaf. If
  // release_allocation_after_use is false at a leaf, the new tuple will alias
  // the input allocation at that leaf, otherwise the input allocation will be
  // released. Input allocations may be repeated (appear in more than one
  // leaf), in which case the corresponding buffers in the output tuple will
  // alias. If an input is repeated, release_allocation_after_use must be
  // false for every leaf where that input appears. The latter property is not
  // validated by MakeTuple and must be enforced by the caller. See the usage
  // sketch below.
  static Status MakeTuple(XRTMemoryManager* memory_manager,
                          xla::Backend* backend, int device_ordinal,
                          const xla::ShapeTree<ExpandedTupleInput>& elements,
                          XRTTupleAllocation** allocation);
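
  // A hypothetical usage sketch for MakeTuple, building a two-element tuple
  // that aliases two existing array-shaped allocations `a` and `b`, both
  // assumed to reside on device 0 (error handling omitted):
  //
  //   xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
  //       {a->on_host_shape(), b->on_host_shape()});
  //   xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> elements(
  //       tuple_shape);
  //   *elements.mutable_element({0}) = {RefPtr<XRTTupleAllocation>(a), false};
  //   *elements.mutable_element({1}) = {RefPtr<XRTTupleAllocation>(b), false};
  //   XRTTupleAllocation* tuple = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
  //       memory_manager, backend, /*device_ordinal=*/0, elements, &tuple));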

  // Copies the allocation from device to host and returns it in literal.
  Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal);

  // Writes a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

  // Stores the content of the tuple allocation into the internal literal, and
  // releases all the device buffers. The swap_pinned flag tells whether a
  // pinned allocation should be swapped out; it should be false in all cases
  // except during the memory compaction operation performed by the
  // XRTMemoryManager. Returns a boolean telling whether the allocation was
  // swapped out.
  xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned);

  // Allocates the device memory required to store the tuple value held within
  // the internal literal, and transfers the literal value into the device
  // memory. Returns a boolean telling whether the allocation was swapped in.
  xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager,
                             xla::Backend* backend);

  // Pins the allocation first, then swaps it in (if it is not in already).
  // After this API returns, the allocation is pinned and its content is in
  // device memory. The caller is responsible for releasing the pin-count
  // using the Unpin() API. See the usage sketch below.
  xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager,
                                   xla::Backend* backend);

  // Checks whether the allocation is currently swapped out.
  bool IsSwapped() const;

  // Increases the pin-count of this allocation. If the pin-count is greater
  // than 0, the allocation cannot be swapped. Returns the pin-count value
  // before the increase.
  int64 Pin();

  // Decreases the pin-count of this allocation. Returns the pin-count value
  // before the decrease.
  int64 Unpin();
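
  // A hypothetical sketch of the pin/swap protocol (assumes a valid
  // `memory_manager` and `backend`; error handling omitted):
  //
  //   TF_ASSIGN_OR_RETURN(bool swapped_in,
  //                       tuple->PinAndSwapIn(memory_manager, backend));
  //   ...  // The device buffers cannot be swapped out here.
  //   tuple->Unpin();  // The allocation may now be swapped out again.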

  // Checks whether the allocation is currently pinned.
  bool IsPinned() const;

  // True if none of the buffers in the allocation are aliased by any other
  // live handle.
  bool IsExclusiveOwner() const;

  // Retrieves the footprint, in terms of device memory, of this allocation.
  size_t GetDeviceMemorySize() const;

  // The ordinal of the device holding this tuple.
  int device_ordinal() const;

  // Returns the shape of the tuple as seen by the host.
  const xla::Shape& on_host_shape() const;

  // Returns the shape of the tuple as stored on the device.
  const xla::Shape& on_device_shape() const;

  // Returns the buffer pointed to by the root of the tuple.
  const se::DeviceMemoryBase& root_allocation() const;

  // Stops managing the storage for the allocation at buffer_index, e.g.,
  // because it has been aliased to the output buffer of a computation.
  void DiscardAllocation(const xla::ShapeIndex& buffer_index);

  // Returns the tree of allocations as a ShapedBuffer. This tree may not have
  // the same shape as on_host_shape.
  xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer();

  // Aliases the source buffer at source_index into this tuple allocation at
  // dest_index.
  Status AliasBufferFrom(const XRTTupleAllocation& source,
                         const xla::ShapeIndex& source_index,
                         const xla::ShapeIndex& dest_index);

  // Returns the device memory tree of this allocation. If the alias_checker
  // function returns true for a given index, an owned device memory is
  // returned to the caller. But the tuple allocation cannot release the
  // ownership in full, as the execute operation might fail, so we rely on a
  // call to AliasBufferFrom() to re-alias the buffers back. This is not great
  // (to say the least), but the current aliasing logic relies on
  // MaybeOwningDeviceMemory being owned in order to detect that the user may
  // want to alias a buffer. Unfortunately, doing so requires releasing
  // ownership, which is a problem if the execute fails.
  // This calls for a refactoring of the whole owning/maybe-owning interface
  // to introduce a sharing concept (in other words, a shared_ptr model rather
  // than a unique_ptr one). We would need something similar to
  // XRTTupleAllocation instead of ScopedShapedBuffer, which wants ownership
  // and does not allow sharing. See the usage sketch below.
  xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
          alias_checker);
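
  // A hypothetical sketch of an alias_checker callback; the policy shown,
  // donating only exclusively owned buffers, is an assumption made for
  // illustration:
  //
  //   TF_ASSIGN_OR_RETURN(
  //       xla::ExecutionInput input,
  //       tuple->ToExecutionInput(
  //           [&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
  //             return tuple->IsExclusiveOwner();
  //           }));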

 private:
  // Creates a new handle with (tuple) shape.
  XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator,
                     const xla::Shape& on_host_shape,
                     const xla::Shape& on_device_shape);

  // Inherits the allocations represented in buffer, which must have the same
  // shape as buffers_.
  void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  se::DeviceMemoryAllocator* allocator,
                                  int device_ordinal);

  // Releases all the XRTBufferAllocation buffer references and sets the
  // corresponding shape tree entries to nullptr.
  void ReleaseBuffers();

  // Stores the content of the allocation from device memory to the target
  // host literal.
  Status StoreToLiteral(xla::Backend* backend,
                        xla::MutableLiteralBase* literal);

  // Sets the total size of the buffers held within this allocation. This API
  // should be called once, when an XRTTupleAllocation object is created, as
  // the XRTTupleAllocation shapes never change, and hence neither does the
  // device memory size.
  void SetDeviceMemorySize();

  // Takes a tree 'elements' where each leaf is an allocation, validates that
  // they are all on the device_ordinal managed by allocator, and returns in
  // host_shape and device_shape the host/device shapes of the expanded tree,
  // where at each leaf of elements the shape of the allocation at that leaf
  // is grafted on.
  static Status ExpandTreeOfTuples(
      const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
      se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
      xla::Shape* device_shape);

  // The lock which protects the internal operations of the tuple allocation.
  // Mutable to allow const-like operations to be declared as such.
  mutable mutex lock_;

  // Location of the memory that is being managed.
  const int device_ordinal_;
  se::DeviceMemoryAllocator* const allocator_;

  // The shape that the caller thinks the tuple has.
  const xla::Shape on_host_shape_;
  // The shape that the tuple has on device. Stored explicitly instead of
  // using the shape held by the ShapeTree, because ShapeTree discards the
  // layout.
  const xla::Shape on_device_shape_;
  // The tree of reference-counted buffers, which uses on_device_shape_ as its
  // shape.
  xla::ShapeTree<XRTBufferAllocation*> buffers_;
  // The footprint of the allocation when residing in device memory.
  size_t device_memory_size_ = 0;
  // If the allocation is swapped out, this is the literal storing its
  // content.
  std::unique_ptr<xla::Literal> literal_;
  // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0
  // then the allocation is pinned.
  std::atomic<int64> pin_count_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_