/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

namespace tensorflow {
namespace {

// Helper typedef to make ShapeTree ForEach lambda signatures more readable.
// The lambdas take a const T&, where in this case T is the pointer type below.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

class BufferAllocStats {
 public:
  struct Stats {
    int64 count = 0;
    int64 size = 0;
  };

  Stats ReportAlloc(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64, Stats> stats_;
};

BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

Status AllocateScopedShapedBuffer(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::Shape& shape, std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than the representation
  // on the host. XLA does not document any contract for the relationship
  // between these representations. Right now, the device shape is always a
  // superset of the host shape, meaning that for any valid ShapeIndex in the
  // host shape that ShapeIndex is also valid in the device shape, but not vice
  // versa. In particular, some host-side types are rewritten to be tuples. We
  // rely on this property when making sub-buffers, because we assume that if
  // the client requests the host-shape sub-buffer at index i, that will
  // correspond to the right device-shape sub-buffer at the same index.
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated
  // if it goes out of scope. That's useful if we return early as the result of
  // an error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, backend->memory_allocator(), device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    const xla::Shape& subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        se::OwningDeviceMemory buffer,
        memory_manager->Allocate(backend, device_ordinal, size));
    // Move the allocated device memory into the ScopedShapedBuffer, which
    // takes ownership of it.
    index_to_buffer.second = buffer.Release();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString() << " (" << size
            << " bytes)";
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return Status::OK();
}

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator)
    : allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate explicitly allows allocation_ to be null.
  TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_));
  VLOG(2) << "Freed buffer at " << allocation_.opaque() << " ("
          << allocation_.size() << " bytes)";
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       se::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_),
      pin_count_(0) {}

XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); }

void XRTTupleAllocation::ReleaseBuffers() {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
      index_buffer.second = nullptr;
    }
  }
}

/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, literal.shape(),
                                                &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(
      device_ordinal, backend->memory_allocator(),
      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                                   device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}
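
// Illustrative usage sketch (added for documentation; not part of the original
// implementation). It assumes a hypothetical caller that already holds an
// XRTMemoryManager* `memory_manager` and an xla::Backend* `backend`, includes
// "tensorflow/compiler/xla/literal_util.h", and runs inside a function
// returning Status so the TF_* macros apply:
//
//   xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
//   XRTTupleAllocation* tuple = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, memory_manager, backend, /*device_ordinal=*/0, &tuple));
//   // ... use the device-side allocation ...
//   tuple->Unref();  // Drop the reference created by CreateAndTransfer.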

/*static*/ Status XRTTupleAllocation::CreateUninitialized(
    const xla::Shape& shape, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
      memory_manager, backend, device_ordinal, shape, &scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(
      device_ordinal, backend->memory_allocator(),
      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                                   device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, const xla::Shape& on_host_shape,
    const xla::Shape& on_device_shape, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  auto allocator = backend->memory_allocator();

  *allocation = new XRTTupleAllocation(device_ordinal, allocator, on_host_shape,
                                       on_device_shape);
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  return CreateFromBuffer(shaped_buffer, shaped_buffer.on_host_shape(),
                          shaped_buffer.on_device_shape(), backend,
                          device_ordinal, allocation);
}

Status XRTTupleAllocation::ToLiteral(xla::Backend* backend,
                                     xla::MutableLiteralBase* literal) {
  mutex_lock lock(lock_);
  return literal_ == nullptr ? StoreToLiteral(backend, literal)
                             : literal->CopyFrom(*literal_);
}

Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend,
                                          xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, literal);
}

Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    // The allocation is currently swapped out, and we have a host literal for
    // its content. Just update the host literal with the new value.
    return literal_->CopyFrom(literal);
  }
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   shaped_buffer);
}
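
// Illustrative round-trip sketch (not part of the original implementation),
// assuming hypothetical `tuple` (XRTTupleAllocation*) and `backend`
// (xla::Backend*) pointers obtained elsewhere, and a `new_value` xla::Literal
// whose shape matches the allocation's host shape:
//
//   // Read the device (or swapped-out) contents back into a host literal.
//   xla::Literal host_copy(tuple->on_host_shape());
//   TF_RETURN_IF_ERROR(tuple->ToLiteral(backend, &host_copy));
//
//   // Overwrite the allocation with a new literal of the same host shape.
//   TF_RETURN_IF_ERROR(tuple->WriteLiteral(backend, new_value));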

xla::StatusOr<bool> XRTTupleAllocation::SwapOut(xla::Backend* backend,
                                                bool swap_pinned) {
  mutex_lock lock(lock_);
  if (literal_ == nullptr && (!IsPinned() || swap_pinned)) {
    xla::Literal literal(on_host_shape());
    TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal));
    ReleaseBuffers();
    literal_ = absl::make_unique<xla::Literal>(std::move(literal));
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::SwapIn(XRTMemoryManager* memory_manager,
                                               xla::Backend* backend) {
  // We need to call AllocateScopedShapedBuffer() outside the locks, since the
  // XRTMemoryManager might end up calling back into the SwapOut() API.
  // So we do a quick check with the IsSwapped() API before allocating, and it
  // can happen that the allocation becomes swapped in after that check. This
  // means that we may end up doing an allocation, and then releasing it soon
  // after (via its scoped variables). This is an unlikely scenario (two
  // threads calling SwapIn() on the same allocation) though.
  if (!IsSwapped()) {
    return false;
  }

  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(
      AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(),
                                 on_host_shape(), &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));

  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
        stream.get(), *literal_, *scoped_buffer));

    auto shaped_buffer = scoped_buffer->release();
    InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                               device_ordinal());
    literal_ = nullptr;
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::PinAndSwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend) {
  Pin();
  return SwapIn(memory_manager, backend);
}
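
// Illustrative swap sketch (not part of the original implementation). It shows
// the intended pairing of SwapOut()/PinAndSwapIn()/Unpin(), with hypothetical
// `tuple`, `memory_manager` and `backend` pointers in scope:
//
//   // Move the device buffers to a host literal to relieve memory pressure.
//   TF_ASSIGN_OR_RETURN(bool swapped_out,
//                       tuple->SwapOut(backend, /*swap_pinned=*/false));
//
//   // Later, bring it back and pin it so it cannot be swapped out while used.
//   TF_ASSIGN_OR_RETURN(bool swapped_in,
//                       tuple->PinAndSwapIn(memory_manager, backend));
//   // ... run the computation that needs the device buffers ...
//   tuple->Unpin();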

bool XRTTupleAllocation::IsSwapped() const {
  mutex_lock lock(lock_);
  return literal_ != nullptr;
}

int64 XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); }

int64 XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); }

bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; }

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() const {
  return on_host_shape_;
}

const xla::Shape& XRTTupleAllocation::on_device_shape() const {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const {
  return buffers_.element({})->allocation();
}

/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffers over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) = nullptr;
            });
  }
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}
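
// Illustrative sub-buffer sketch (not part of the original implementation),
// assuming a hypothetical `parent` XRTTupleAllocation* whose host shape is a
// tuple with at least one element:
//
//   XRTTupleAllocation* element0 = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
//       parent, /*subshape=*/xla::ShapeIndex({0}), &element0,
//       /*alias_parent_allocation=*/true));
//   // element0 now shares (and ref-counts) the parent's buffers under {0}.
//   element0->Unref();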

void XRTTupleAllocation::SetDeviceMemorySize() {
  size_t size = 0;
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      size += index_buffer.second->allocation().size();
    }
  }
  device_memory_size_ = size;
}

/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device shapes
  // at the leaf position.
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null leaf node at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::MakeTuple(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // The aliasing is determined below based on whether or not all the inputs
  // are released while being transferred. allocation_tmp is a local pointer
  // that is copied to *allocation at the end only if the method succeeds.
  XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation(
      device_ordinal, allocator, host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
  // an allocation fails. Make sure the shape has layout so that the code that
  // writes index tables will be happy lower down.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          const xla::Shape& subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(
              se::OwningDeviceMemory buffer,
              memory_manager->Allocate(backend, device_ordinal, size));
          VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return Status::OK();
      }));
  // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
  // the newly-allocated index tables. Right now there's no owner for the new
  // index tables, so next we will transfer ownership to the new allocation,
  // taking care not to return early on any errors in the meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining datastructures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of
  //      an existing allocation, and every internal element of tuple_buffers
  //      will be a newly-allocated index table. tuple_buffers does not own any
  //      of these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new allocation
        // rather than incrementing the refcount. The caller should have
        // validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) {
              *buffer = nullptr;
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  allocation_tmp->SetDeviceMemorySize();
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and not
  // rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically
  // on exit.
  (*allocation)->Ref();
  return Status::OK();
}
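
// Illustrative MakeTuple sketch (not part of the original implementation). It
// builds a two-element tuple from two existing allocations `a` and `b`, with
// hypothetical `memory_manager` and `backend` pointers in scope. The exact way
// the ExpandedTupleInput tree is populated may differ in real callers; this is
// only a sketch of the intended shape of the call:
//
//   // The spine only provides the tuple structure; leaf shapes are replaced
//   // by the shapes of the input allocations, so placeholders are fine here.
//   xla::Shape spine = xla::ShapeUtil::MakeTupleShape(
//       {xla::ShapeUtil::MakeShape(xla::F32, {}),
//        xla::ShapeUtil::MakeShape(xla::F32, {})});
//   xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> elements(spine);
//   elements.mutable_element({0})->allocation = a;
//   elements.mutable_element({0})->release_allocation_after_use = false;
//   elements.mutable_element({1})->allocation = b;
//   elements.mutable_element({1})->release_allocation_after_use = false;
//   XRTTupleAllocation* tuple = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
//       memory_manager, backend, /*device_ordinal=*/0, elements, &tuple));
//   // With release_allocation_after_use == false, `a` and `b` keep their own
//   // references and the new tuple ref-counts the shared buffers.
//   tuple->Unref();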

bool XRTTupleAllocation::IsExclusiveOwner() const {
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr &&
        !index_buffer.second->RefCountIsOne()) {
      return false;
    }
  }
  return true;
}

size_t XRTTupleAllocation::GetDeviceMemorySize() const {
  return device_memory_size_;
}

void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    se::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
    }
    // Make a reference-counted version of the allocated buffer.
    index_buffer.second = new XRTBufferAllocation(
        shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator);
  }
}

xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  device_ordinal_);
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    shaped_buffer.set_buffer(index_buffer.second->allocation(),
                             index_buffer.first);
  }
  return std::move(shaped_buffer);
}

Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  if (dest_buffer != nullptr) {
    // We allow the destination size to be zero, because there are cases where
    // we are coming in later to fill in null/uninitialized device buffers. In
    // all other cases, the size of the new buffer must match.
    if (source_buffer->allocation().size() !=
            dest_buffer->allocation().size() &&
        dest_buffer->allocation().size() != 0) {
      return errors::InvalidArgument(
          "Source buffer at index ", source_index.ToString(),
          " does not match the size of destination buffer at index ",
          dest_index.ToString(), ": ", source_buffer->allocation().size(),
          " vs ", dest_buffer->allocation().size());
    }
  } else {
    const xla::Shape& source_subshape =
        xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index);
    const xla::Shape& dest_subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index);
    if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) {
      return errors::InvalidArgument(
          "Source and destination subshapes do not match: source=",
          xla::ShapeUtil::HumanStringWithLayout(source_subshape),
          " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape));
    }
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  if (dest_buffer != nullptr) {
    // If we handed over the ownership of a buffer in ToExecutionInput(), we
    // will be called here on the way back from execution, to alias back the
    // buffer at that index. In that case the buffers will be the same. So we
    // need to discard the memory at the destination buffer, before releasing
    // the reference.
    if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) &&
        dest_buffer != source_buffer) {
      dest_buffer->DiscardAllocation();
    }
    dest_buffer->Unref();
  }
  return Status::OK();
}

xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
        alias_checker) {
  xla::ExecutionInput result(on_device_shape(), on_host_shape());
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
    if (!should_alias) {
      result.SetBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
    } else {
      // We keep the ownership of the device memory here.
      result.SetUnownedBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
              index_buffer.second->allocation(), device_ordinal_, allocator_)));
    }
  }
  return std::move(result);
}
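
// Illustrative aliasing sketch (not part of the original implementation),
// assuming a hypothetical `tuple` XRTTupleAllocation* whose buffers may be
// offered to an XLA execution:
//
//   TF_ASSIGN_OR_RETURN(
//       xla::ExecutionInput input,
//       tuple->ToExecutionInput([&](const xla::ShapeIndex& index)
//                                   -> xla::StatusOr<bool> {
//         // Report only the root buffer as aliasable in this sketch.
//         return index.empty();
//       }));
//   // After execution, AliasBufferFrom() can re-attach a buffer that XLA
//   // handed back at the same index.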

}  // namespace tensorflow