1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Classes for keeping track of on-device state. 17 18 #ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ 19 #define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_ 20 21 #include <functional> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 26 #include "tensorflow/compiler/xla/literal.h" 27 #include "tensorflow/compiler/xla/service/backend.h" 28 #include "tensorflow/compiler/xla/service/device_memory_allocator.h" 29 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 30 #include "tensorflow/compiler/xla/shape_util.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/framework/resource_mgr.h" 33 #include "tensorflow/core/lib/core/refcount.h" 34 #include "tensorflow/core/lib/core/status.h" 35 #include "tensorflow/core/lib/gtl/array_slice.h" 36 #include "tensorflow/core/platform/types.h" 37 #include "tensorflow/stream_executor/stream_executor.h" 38 39 namespace tensorflow { 40 41 // TODO(misard) make this a Tensor if and when that makes sense. 42 // A reference-counted wrapper around a buffer allocation. This maps an XLA 43 // tuple index or a non-tuple XLA shape to a region of device memory. The device 44 // memory buffer is freed when the reference count drops to zero. 
class XRTBufferAllocation : public core::RefCounted {
 public:
  // Wraps `allocation`, a region of device memory on device `device_ordinal`
  // that is managed by `allocator`. The buffer is deallocated through
  // `allocator` when the reference count drops to zero, unless
  // DiscardAllocation() has been called first.
  XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                      int device_ordinal,
                      xla::DeviceMemoryAllocator* allocator);
  ~XRTBufferAllocation() override;

  // The region of device memory being wrapped.
  const se::DeviceMemoryBase& allocation();

  // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
  // when ownership of the underlying buffer has been transferred, e.g., to an
  // output buffer when input and output buffers are aliased during
  // execution. The call to DiscardAllocation prevents any device buffer being
  // freed when the reference count drops to zero.
  void DiscardAllocation();

  // Returns the expected size of the allocation. Since DiscardAllocation()
  // will set allocation_ to {null,0}, and since later we might want to replace
  // the discarded buffer with a new one, we need to be able to verify the size
  // compatibility.
  uint64 size() const { return size_; }

 private:
  // Size in bytes of the wrapped region; retained so the allocation size can
  // still be checked after allocation_ has been discarded.
  uint64 size_ = 0;
  // The wrapped device memory; set to {null,0} by DiscardAllocation().
  se::DeviceMemoryBase allocation_;
  // Device on which allocation_ lives.
  int device_ordinal_;
  // Allocator used to free allocation_ on destruction (unless discarded).
  // Not owned; assumed to outlive this object — TODO confirm with callers.
  xla::DeviceMemoryAllocator* allocator_;
};

// Entry in the resource manager corresponding to an allocation handle returned
// to a client. The handle identifies an immutable tuple of data in device
// memory. New handles can be created in three ways: by passing a literal in
// which case device memory is allocated and the literal is transferred to that
// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
// aliasing a vector of existing handles to create a new tuple. The underlying
// storage is reference-counted. When a handle is released, the reference count
// of each storage buffer is decremented, and buffers with no outstanding
// references are freed.
class XRTTupleAllocation : public ResourceBase {
 public:
  ~XRTTupleAllocation() override;

  // Allocates new device memory buffers sufficient to store literal, transfers
  // literal to that memory, and returns a XRTTupleAllocation handle to the
  // allocated buffers.
  static Status CreateAndTransfer(const xla::LiteralBase& literal,
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation);

  // Aliases a sub-shape of parent and returns a XRTTupleAllocation handle
  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
  // sub-shape will be shared between parent and the returned allocation,
  // otherwise the overlapping buffers in parent will be replaced by
  // nullptr.
  static Status MakeSubBuffer(XRTTupleAllocation* parent,
                              const xla::ShapeIndex& subshape,
                              XRTTupleAllocation** allocation,
                              bool alias_parent_allocation);

  // A structure describing a leaf of a tree of tuples to expand. Each leaf
  // contains an allocation and indicates whether or not the allocation's
  // handle should be freed after incorporating its buffers into the expanded
  // tree.
  struct ExpandedTupleInput {
    XRTTupleAllocation* allocation;
    bool release_allocation_after_use;
  };

  // Returns a handle to a new tuple where the subtree of the new tuple at an
  // index corresponding to a leaf of 'elements' is constructed from the
  // allocation (i.e., a tuple or array) pointed to by that leaf. If
  // release_allocation_after_use is false at a leaf, the new tuple will alias
  // the input allocation at that leaf, otherwise the input allocation will be
  // released. Input allocations may be repeated (appear in more than one leaf)
  // in which case the corresponding buffers in the output tuple will alias. If
  // an input is repeated, release_allocation_after_use must be false for every
  // leaf where that input appears. The latter property is not validated by
  // MakeTuple and must be enforced by the caller.
  static Status MakeTuple(xla::Backend* backend, int device_ordinal,
                          const xla::ShapeTree<ExpandedTupleInput>& elements,
                          XRTTupleAllocation** allocation);

  // Retrieves the allocation interned under key from rm. The caller owns a
  // reference to allocation after looking it up.
  static Status Lookup(ResourceMgr* rm, int64 key,
                       XRTTupleAllocation** allocation);

  // Deletes the reference in the rm to an allocation interned under key.
  static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);

  // Releases all the device memory allocated by XRT within the resource
  // manager.
  static Status ReleaseAllAllocations(ResourceMgr* rm);

  // Adds the allocation to a ResourceMgr and returns the key that will be used
  // to retrieve it. Transfers a reference on *this to rm.
  Status Intern(ResourceMgr* rm, int64* key);

  // Copies the allocation from device to host and returns it in literal.
  Status ToLiteral(xla::Backend* backend, int device_ordinal,
                   xla::MutableLiteralBase* literal);

  // Write a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);

  // True if none of the buffers in the allocation are aliased by any other
  // live handle.
  bool IsExclusiveOwner();

  // The ordinal of the device holding this tuple.
  int device_ordinal();

  // Returns the shape of the tuple as seen by the host.
  const xla::Shape& on_host_shape();

  // Returns the shape of the tuple as stored on the device.
  const xla::Shape& on_device_shape();

  // Returns the buffer pointed to by the root of the tuple.
  const se::DeviceMemoryBase& root_allocation();

  // Stops managing the storage for the allocation at buffer_index, e.g.,
  // because it has been aliased to the output buffer of a computation.
  void DiscardAllocation(const xla::ShapeIndex& buffer_index);

  // Returns the tree of allocations as a ShapedBuffer. This tree may not have
  // the same shape as on_host_shape.
  xla::ShapedBuffer ToShapedBuffer();

  // Aliases the source buffer at source_index into the current tuple
  // allocation dest_index.
  Status AliasBufferFrom(const XRTTupleAllocation& source,
                         const xla::ShapeIndex& source_index,
                         const xla::ShapeIndex& dest_index);

  // Returns the device memory tree of this allocation. If the release_checker
  // function returns true for a given index, the ownership of the device
  // memory at that index is transferred to the result. Every attempt to read
  // the value at that index will fail.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(
      const std::function<bool(const xla::ShapeIndex&)>& release_checker);

  // Human-readable tag shown by the ResourceMgr for this resource type.
  string DebugString() const override { return "XLA allocation handle"; }

 private:
  // Creates a new handle with (tuple) shape.
  XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator,
                     const xla::Shape& on_host_shape,
                     const xla::Shape& on_device_shape);

  // Inherits the allocations represented in buffer, which must have the same
  // shape as buffers_.
  void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  xla::DeviceMemoryAllocator* allocator,
                                  int device_ordinal);

  // Takes a tree 'elements' where each leaf is an allocation, validates that
  // they are all on device_ordinal managed by allocator, and returns in
  // host_shape and device_shape the host/device shapes of the expanded tree,
  // where at each leaf of elements the shape of the allocation at elements is
  // grafted on.
  static Status ExpandTreeOfTuples(
      const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
      xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
      xla::Shape* device_shape);

  // Location of the memory that is being managed.
  int device_ordinal_;
  // Allocator managing the device memory; not owned.
  xla::DeviceMemoryAllocator* allocator_;

  // The shape that the caller thinks the tuple has.
  const xla::Shape on_host_shape_;
  // The shape that the tuple has on device. Store this explicitly instead of
  // using a shape stored in ShapeTree because ShapeTree discards the layout.
  const xla::Shape on_device_shape_;
  // The tree of reference-counted buffers, which uses on_device_shape_ as its
  // shape.
  xla::ShapeTree<XRTBufferAllocation*> buffers_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_