/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for keeping track of on-device state.

#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_

#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt_refptr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

// Cannot include xrt_memory_manager.h here, as it needs to include this file.
class XRTMemoryManager;

// TODO(misard) make this a Tensor if and when that makes sense.
// A reference-counted wrapper around a buffer allocation. This maps an XLA
// tuple index or a non-tuple XLA shape to a region of device memory. The
// device memory buffer is freed when the reference count drops to zero.
class XRTBufferAllocation : public core::RefCounted {
 public:
  XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                      int device_ordinal, se::DeviceMemoryAllocator* allocator);
  ~XRTBufferAllocation() override;

  // The region of device memory being wrapped.
  const se::DeviceMemoryBase& allocation();

  // Stops tracking the wrapped device memory (e.g. because ownership has been
  // transferred elsewhere), so it is not deallocated when the reference count
  // drops to zero.
  void DiscardAllocation() { allocation_ = se::DeviceMemoryBase(); }

 private:
  se::DeviceMemoryBase allocation_;
  int device_ordinal_;
  se::DeviceMemoryAllocator* allocator_;
};
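
// Illustrative lifecycle sketch (not part of the API): XRTBufferAllocation
// follows the usual core::RefCounted protocol. The names below
// (device_memory, device_ordinal, allocator) are assumed to come from the
// caller's context.
//
//   XRTBufferAllocation* buffer =
//       new XRTBufferAllocation(device_memory, device_ordinal, allocator);
//   // ... additional owners call buffer->Ref() and later buffer->Unref() ...
//   buffer->Unref();  // the device memory is freed when the count hits zero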

// An XRTTupleAllocation represents an allocated memory area on the device.
// New tuples can be created in three ways: by passing a literal, in which case
// device memory is allocated and the literal is transferred to that memory; by
// aliasing a sub-shape of an existing tuple-shaped handle; or by aliasing a
// vector of existing handles to create a new tuple. The underlying storage is
// reference-counted. When a handle is released, the reference count of each
// storage buffer is decremented, and buffers with no outstanding references
// are freed.
class XRTTupleAllocation : public core::RefCounted {
 public:
  ~XRTTupleAllocation() override;

  // Allocates new device memory buffers sufficient to store literal, transfers
  // literal to that memory, and returns an XRTTupleAllocation handle to the
  // allocated buffers.
  static Status CreateAndTransfer(const xla::LiteralBase& literal,
                                  XRTMemoryManager* memory_manager,
                                  xla::Backend* backend, int device_ordinal,
                                  XRTTupleAllocation** allocation);
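  // Example usage (illustrative sketch only; `memory_manager`, `backend` and
  // `device_ordinal` are assumed to be supplied by the surrounding XRT op):
  //
  //   xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   XRTTupleAllocation* allocation = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
  //       literal, memory_manager, backend, device_ordinal, &allocation));
  //   core::ScopedUnref allocation_unref(allocation);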

  // Allocates new device memory buffers sufficient to store a tensor of
  // the specified shape, and returns an XRTTupleAllocation handle to the
  // allocated buffers.  The allocated buffers are not initialized.
  static Status CreateUninitialized(const xla::Shape& shape,
                                    XRTMemoryManager* memory_manager,
                                    xla::Backend* backend, int device_ordinal,
                                    XRTTupleAllocation** allocation);

  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation);

  // Same as the CreateFromBuffer() API above, but with the shapes being passed
  // as input. This API is used when creating tuple allocations from the output
  // of XLA computations which emit dynamically shaped output via the output
  // shape table.
  static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
                                 const xla::Shape& on_host_shape,
                                 const xla::Shape& on_device_shape,
                                 xla::Backend* backend, int device_ordinal,
                                 XRTTupleAllocation** allocation);

  // Aliases a sub-shape of parent and returns an XRTTupleAllocation handle
  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
  // sub-shape will be shared between parent and the returned allocation;
  // otherwise the overlapping buffers in parent will be replaced by
  // nullptr.
  static Status MakeSubBuffer(XRTTupleAllocation* parent,
                              const xla::ShapeIndex& subshape,
                              XRTTupleAllocation** allocation,
                              bool alias_parent_allocation);
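  // Example usage (illustrative sketch; `parent` is assumed to be an existing
  // tuple-shaped XRTTupleAllocation*; index {0} selects its first element):
  //
  //   XRTTupleAllocation* element = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
  //       parent, /*subshape=*/{0}, &element,
  //       /*alias_parent_allocation=*/true));
  //   core::ScopedUnref element_unref(element);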

  // A structure describing a leaf of a tree of tuples to expand. Each leaf
  // contains an allocation and indicates whether or not the allocation's
  // handle should be freed after incorporating its buffers into the expanded
  // tree.
  struct ExpandedTupleInput {
    RefPtr<XRTTupleAllocation> allocation;
    bool release_allocation_after_use;
  };

  // Returns a handle to a new tuple where the subtree of the new tuple at an
  // index corresponding to a leaf of 'elements' is constructed from the
  // allocation (i.e., a tuple or array) pointed to by that leaf. If
  // release_allocation_after_use is false at a leaf, the new tuple will alias
  // the input allocation at that leaf, otherwise the input allocation will be
  // released. Input allocations may be repeated (appear in more than one leaf)
  // in which case the corresponding buffers in the output tuple will alias. If
  // an input is repeated, release_allocation_after_use must be false for every
  // leaf where that input appears. The latter property is not validated by
  // MakeTuple and must be enforced by the caller.
  static Status MakeTuple(XRTMemoryManager* memory_manager,
                          xla::Backend* backend, int device_ordinal,
                          const xla::ShapeTree<ExpandedTupleInput>& elements,
                          XRTTupleAllocation** allocation);
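  // Example usage (illustrative sketch; `a` and `b` are assumed to be existing
  // array-shaped XRTTupleAllocation* handles that become the two leaves of a
  // pair, while `memory_manager`, `backend` and `device_ordinal` come from the
  // caller's context):
  //
  //   xla::Shape pair_shape = xla::ShapeUtil::MakeTupleShape(
  //       {a->on_host_shape(), b->on_host_shape()});
  //   xla::ShapeTree<ExpandedTupleInput> elements(pair_shape);
  //   *elements.mutable_element({0}) = {RefPtr<XRTTupleAllocation>(a), false};
  //   *elements.mutable_element({1}) = {RefPtr<XRTTupleAllocation>(b), false};
  //   XRTTupleAllocation* tuple = nullptr;
  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
  //       memory_manager, backend, device_ordinal, elements, &tuple));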

  // Copies the allocation from device to host and returns it in literal.
  Status ToLiteral(xla::Backend* backend, xla::MutableLiteralBase* literal);

  // Writes a new literal value to the allocation.
  Status WriteLiteral(xla::Backend* backend, const xla::Literal& literal);
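  // Example round-trip (illustrative sketch; `allocation` and `backend` are
  // assumed to come from the caller's context):
  //
  //   xla::Literal host_value(allocation->on_host_shape());
  //   TF_RETURN_IF_ERROR(allocation->ToLiteral(backend, &host_value));
  //   // ... inspect or modify host_value ...
  //   TF_RETURN_IF_ERROR(allocation->WriteLiteral(backend, host_value));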

  // Stores the content of the tuple allocation into the internal literal, and
  // releases all the device buffers. The swap_pinned flag tells whether a
  // pinned allocation should be swapped out. It should be false in all cases
  // except during the memory compaction operation performed by the
  // XRTMemoryManager. Returns a boolean telling whether the allocation was
  // swapped out.
  xla::StatusOr<bool> SwapOut(xla::Backend* backend, bool swap_pinned);

  // Allocates the device memory required to store the tuple value held within
  // the internal literal, and transfers the literal value into the device
  // memory. Returns a boolean telling whether the allocation was swapped in.
  xla::StatusOr<bool> SwapIn(XRTMemoryManager* memory_manager,
                             xla::Backend* backend);

  // Pins the allocation first, then swaps it in (if it is not already swapped
  // in). After this API returns, the allocation is pinned and its content
  // resides in device memory. The caller is responsible for releasing the
  // pin-count using the Unpin() API.
  xla::StatusOr<bool> PinAndSwapIn(XRTMemoryManager* memory_manager,
                                   xla::Backend* backend);

  // Checks whether the allocation is currently swapped out.
  bool IsSwapped() const;

  // Increases the pin-count of this allocation. If the pin-count is greater
  // than 0, the allocation cannot be swapped out. Returns the pin-count value
  // before the increase.
  int64 Pin();

  // Decreases the pin-count of this allocation. Returns the pin-count value
  // before the decrease.
  int64 Unpin();

  // Checks whether the allocation is currently pinned.
  bool IsPinned() const;
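  // Example usage of the pin/swap API (illustrative sketch; whether a swap
  // actually happens depends on the current residency of the allocation):
  //
  //   TF_ASSIGN_OR_RETURN(bool swapped_in,
  //                       allocation->PinAndSwapIn(memory_manager, backend));
  //   // ... use the device buffers; the allocation cannot be swapped out ...
  //   allocation->Unpin();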

  // True if none of the buffers in the allocation are aliased by any other
  // live handle.
  bool IsExclusiveOwner() const;

  // Retrieves the device memory footprint of this allocation.
  size_t GetDeviceMemorySize() const;

  // The ordinal of the device holding this tuple.
  int device_ordinal() const;

  // Returns the shape of the tuple as seen by the host.
  const xla::Shape& on_host_shape() const;

  // Returns the shape of the tuple as stored on the device.
  const xla::Shape& on_device_shape() const;

  // Returns the buffer pointed to by the root of the tuple.
  const se::DeviceMemoryBase& root_allocation() const;

  // Stops managing the storage for the allocation at buffer_index, e.g.,
  // because it has been aliased to the output buffer of a computation.
  void DiscardAllocation(const xla::ShapeIndex& buffer_index);

  // Returns the tree of allocations as a ShapedBuffer. This tree may not have
  // the same shape as on_host_shape.
  xla::StatusOr<xla::ShapedBuffer> ToShapedBuffer();

  // Aliases the source buffer at source_index into this tuple allocation at
  // dest_index.
  Status AliasBufferFrom(const XRTTupleAllocation& source,
                         const xla::ShapeIndex& source_index,
                         const xla::ShapeIndex& dest_index);

  // Returns the device memory tree of this allocation. If the alias_checker
  // function returns true for a given index, owned device memory is returned
  // to the caller. But the tuple allocation cannot fully release the
  // ownership, as the execute operation might fail. So we rely on a call to
  // AliasBufferFrom() to re-alias the buffers back. This is not great (to say
  // the least), but the current aliasing logic relies on
  // MaybeOwningDeviceMemory being owned, to detect the fact that the user may
  // want to alias a buffer. Unfortunately, to do that it needs to release the
  // ownership, which is a problem if the execute fails.
  // This calls for a refactoring of the whole owning/maybe-owning interface to
  // introduce a sharing concept (IOW a shared_ptr model vs. unique_ptr).
  // We'd need something similar to XRTTupleAllocation instead of
  // ScopedShapedBuffer, which wants ownership and does not allow sharing.
  xla::StatusOr<xla::ExecutionInput> ToExecutionInput(
      const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
          alias_checker);
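  // Example usage (illustrative sketch; the checker below marks every buffer
  // as donatable, whereas a real caller would consult the computation's
  // input/output alias configuration):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       xla::ExecutionInput execution_input,
  //       allocation->ToExecutionInput(
  //           [](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
  //             return true;
  //           }));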

 private:
  // Creates a new handle with (tuple) shape.
  XRTTupleAllocation(int device_ordinal, se::DeviceMemoryAllocator* allocator,
                     const xla::Shape& on_host_shape,
                     const xla::Shape& on_device_shape);

  // Inherits the allocations represented in shaped_buffer, which must have the
  // same shape as buffers_.
  void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
                                  se::DeviceMemoryAllocator* allocator,
                                  int device_ordinal);

  // Releases all the XRTBufferAllocation buffer references and sets the
  // corresponding shape tree entries to nullptr.
  void ReleaseBuffers();

  // Stores the content of the allocation from device memory to the target host
  // literal.
  Status StoreToLiteral(xla::Backend* backend,
                        xla::MutableLiteralBase* literal);

  // Sets the total size of the buffers held within this allocation. This API
  // should be called once when an XRTTupleAllocation object is created, as the
  // XRTTupleAllocation shapes never change, and hence neither does the device
  // memory size.
  void SetDeviceMemorySize();

  // Takes a tree 'elements' where each leaf is an allocation, validates that
  // they are all on the device_ordinal managed by allocator, and returns in
  // host_shape and device_shape the host/device shapes of the expanded tree,
  // where at each leaf of elements the shape of the allocation at that leaf is
  // grafted on.
  static Status ExpandTreeOfTuples(
      const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
      se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
      xla::Shape* device_shape);

  // The lock which protects the internal operations of the tuple allocation.
  // It is mutable to allow const-like operations to be declared as such.
  mutable mutex lock_;

  // Location of the memory that is being managed.
  const int device_ordinal_;
  se::DeviceMemoryAllocator* const allocator_;

  // The shape that the caller thinks the tuple has.
  const xla::Shape on_host_shape_;
  // The shape that the tuple has on device. Store this explicitly instead of
  // using a shape stored in ShapeTree because ShapeTree discards the layout.
  const xla::Shape on_device_shape_;
  // The tree of reference-counted buffers, which uses on_device_shape_ as its
  // shape.
  xla::ShapeTree<XRTBufferAllocation*> buffers_;
  // The footprint of the allocation, when residing on device memory.
  size_t device_memory_size_ = 0;
  // If the allocation is swapped out, this is the literal storing its content.
  std::unique_ptr<xla::Literal> literal_;
  // A pinned allocation is one which cannot be swapped out. If pin_count_ > 0
  // then the allocation is pinned.
  std::atomic<int64> pin_count_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_