/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.
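//
// Minimal usage sketch (illustrative only; `backend`, `memory_manager` and
// `literal` are assumed to be provided by the caller, e.g. an XRT op kernel):
//
//   XRTTupleAllocation* allocation = nullptr;
//   TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateAndTransfer(
//       literal, memory_manager, backend, /*device_ordinal=*/0, &allocation));
//   xla::Literal readback(allocation->on_host_shape());
//   TF_RETURN_IF_ERROR(allocation->ToLiteral(backend, &readback));
//   allocation->Unref();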

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xrt/xrt_memory_manager.h"

namespace tensorflow {
namespace {

// Helper typedef to make ShapeTree ForEach helper lambda signatures more
// readable. The lambdas need a parameter of type const T&, where in this case
// T is the pointer type below.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

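// Tracks per-device counts and total byte sizes of live XRT allocations, used
// only for the VLOG statistics emitted by XRTBufferAllocation.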
class BufferAllocStats {
 public:
  struct Stats {
    int64 count = 0;
    int64 size = 0;
  };

  Stats ReportAlloc(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64, Stats> stats_;
};

BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

Status AllocateScopedShapedBuffer(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::Shape& shape, std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than the representation
  // on the host. XLA does not document any contract for the relationship
  // between these representations :/ Right now, the device shape is always a
  // superset of the host shape, meaning that for any valid ShapeIndex in the
  // host shape that ShapeIndex is also valid in the device shape, but not vice
  // versa. In particular, some host-side types are rewritten to be tuples. We
  // rely on this property when making sub-buffers, because we assume that if
  // the client requests the host-shape sub-buffer at index i, that will
  // correspond to the right device-shape sub-buffer at the same index.
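  //
  // For example (illustrative, not a documented XLA guarantee): an array host
  // shape such as f32[2,3] typically maps to the same array shape on device,
  // while a tuple host shape keeps its tuple structure and may gain
  // device-only buffers such as index tables; the only property relied on here
  // is that every host ShapeIndex remains valid in the device shape.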
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated
  // if it goes out of scope. That's useful if we return early as the result of
  // an error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, backend->memory_allocator(), device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    const xla::Shape& subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        se::OwningDeviceMemory buffer,
        memory_manager->Allocate(backend, device_ordinal, size));
    // Move our buffer into shaped_buffer, which takes ownership of it.
    index_to_buffer.second = buffer.Release();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString() << " (" << size
            << " bytes)";
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return Status::OK();
}

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator)
    : allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate explicitly allows allocation_ to be null.
  TF_CHECK_OK(allocator_->Deallocate(device_ordinal_, allocation_));
  VLOG(2) << "Freed buffer at " << allocation_.opaque() << " ("
          << allocation_.size() << " bytes)";
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       se::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_),
      pin_count_(0) {}

XRTTupleAllocation::~XRTTupleAllocation() { ReleaseBuffers(); }

void XRTTupleAllocation::ReleaseBuffers() {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
      index_buffer.second = nullptr;
    }
  }
}

/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(memory_manager, backend,
                                                device_ordinal, literal.shape(),
                                                &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(
      device_ordinal, backend->memory_allocator(),
      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                                   device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateUninitialized(
    const xla::Shape& shape, XRTMemoryManager* memory_manager,
    xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
      memory_manager, backend, device_ordinal, shape, &scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here until
  // the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(
      device_ordinal, backend->memory_allocator(),
      shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                                   device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, const xla::Shape& on_host_shape,
    const xla::Shape& on_device_shape, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  auto allocator = backend->memory_allocator();

  *allocation = new XRTTupleAllocation(device_ordinal, allocator, on_host_shape,
                                       on_device_shape);
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  return CreateFromBuffer(shaped_buffer, shaped_buffer.on_host_shape(),
                          shaped_buffer.on_device_shape(), backend,
                          device_ordinal, allocation);
}

Status XRTTupleAllocation::ToLiteral(xla::Backend* backend,
                                     xla::MutableLiteralBase* literal) {
  mutex_lock lock(lock_);
  return literal_ == nullptr ? StoreToLiteral(backend, literal)
                             : literal->CopyFrom(*literal_);
}

Status XRTTupleAllocation::StoreToLiteral(xla::Backend* backend,
                                          xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, literal);
}

Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    // The allocation is currently swapped out, and we have a host literal for
    // its content. Just update the host literal with the new value.
    return literal_->CopyFrom(literal);
  }
  TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, ToShapedBuffer());
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   shaped_buffer);
}

xla::StatusOr<bool> XRTTupleAllocation::SwapOut(xla::Backend* backend,
                                                bool swap_pinned) {
  mutex_lock lock(lock_);
  if (literal_ == nullptr && (!IsPinned() || swap_pinned)) {
    xla::Literal literal(on_host_shape());
    TF_RETURN_IF_ERROR(StoreToLiteral(backend, &literal));
    ReleaseBuffers();
    literal_ = absl::make_unique<xla::Literal>(std::move(literal));
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::SwapIn(XRTMemoryManager* memory_manager,
                                               xla::Backend* backend) {
  // We need to call AllocateScopedShapedBuffer() outside the locks, since the
  // XRTMemoryManager might end up calling back into the SwapOut() API.
  // So we do a quick IsSwapped() check first, and the allocation may become
  // swapped in after that check. In that case we end up doing an allocation
  // and then releasing it soon after (via its scoped variables). This scenario
  // (two threads calling SwapIn() on the same allocation) is unlikely though.
  if (!IsSwapped()) {
    return false;
  }

  auto transfer_manager = backend->transfer_manager();
  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(
      AllocateScopedShapedBuffer(memory_manager, backend, device_ordinal(),
                                 on_host_shape(), &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));

  mutex_lock lock(lock_);
  if (literal_ != nullptr) {
    TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
        stream.get(), *literal_, *scoped_buffer));

    auto shaped_buffer = scoped_buffer->release();
    InitializeFromShapedBuffer(shaped_buffer, backend->memory_allocator(),
                               device_ordinal());
    literal_ = nullptr;
    return true;
  }
  return false;
}

xla::StatusOr<bool> XRTTupleAllocation::PinAndSwapIn(
    XRTMemoryManager* memory_manager, xla::Backend* backend) {
  Pin();
  return SwapIn(memory_manager, backend);
}

bool XRTTupleAllocation::IsSwapped() const {
  mutex_lock lock(lock_);
  return literal_ != nullptr;
}

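// Pin() and Unpin() adjust the pin count consulted by SwapOut(): a pinned
// allocation (pin_count_ != 0) is only swapped out when swap_pinned is true.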
int64 XRTTupleAllocation::Pin() { return pin_count_.fetch_add(1); }

int64 XRTTupleAllocation::Unpin() { return pin_count_.fetch_sub(1); }

bool XRTTupleAllocation::IsPinned() const { return pin_count_ != 0; }

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() const {
  return on_host_shape_;
}

const xla::Shape& XRTTupleAllocation::on_device_shape() const {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() const { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() const {
  return buffers_.element({})->allocation();
}

/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffer over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
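              // For example, with subshape {1} and index {0, 2}, parent_index
              // becomes {1, 0, 2}.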
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) = nullptr;
            });
  }
  (*allocation)->SetDeviceMemorySize();
  return Status::OK();
}

void XRTTupleAllocation::SetDeviceMemorySize() {
  size_t size = 0;
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      size += index_buffer.second->allocation().size();
    }
  }
  device_memory_size_ = size;
}

/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    se::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device shapes
  // at the leaf position.
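  // For example (illustrative): a two-leaf spine (X, Y) whose leaves hold
  // allocations with host shapes f32[2] and (s32[], f32[4]) expands to the
  // host shape (f32[2], (s32[], f32[4])), and likewise for the device shapes.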
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null leaf node at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::MakeTuple(
    XRTMemoryManager* memory_manager, xla::Backend* backend, int device_ordinal,
    const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // The aliasing is determined below based on whether or not all the inputs
  // are released while being transferred. allocation_tmp is a local pointer
  // that is copied to *allocation at the end only if the method succeeds.
  XRTTupleAllocation* allocation_tmp = new XRTTupleAllocation(
      device_ordinal, allocator, host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
  // an allocation fails. Make sure the shape has layout so that the code that
  // writes index tables will be happy lower down.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          const xla::Shape& subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(
              se::OwningDeviceMemory buffer,
              memory_manager->Allocate(backend, device_ordinal, size));
          VLOG(2) << "Allocated buffer at " << buffer->opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return Status::OK();
      }));
  // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
  // the newly-allocated index tables. Right now there's no owner for the new
  // index tables, so next we will transfer ownership to the new allocation,
  // taking care not to return early on any errors in the meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining data structures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of an
  //      existing allocation, and every internal element of tuple_buffers will
  //      be a newly-allocated index table. tuple_buffers does not own any of
  //      these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new allocation
        // rather than incrementing the refcount. The caller should have
        // validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex&, XRTBufferAllocationPtr* buffer) {
              *buffer = nullptr;
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  allocation_tmp->SetDeviceMemorySize();
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and not
  // rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically
  // on exit.
  (*allocation)->Ref();
  return Status::OK();
}

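// Returns true iff this allocation is the only owner of all of its buffers,
// i.e. every live buffer has a reference count of one.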
bool XRTTupleAllocation::IsExclusiveOwner() const {
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr &&
        !index_buffer.second->RefCountIsOne()) {
      return false;
    }
  }
  return true;
}

size_t XRTTupleAllocation::GetDeviceMemorySize() const {
  return device_memory_size_;
}

void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    se::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& index_buffer : buffers_) {
    if (index_buffer.second != nullptr) {
      index_buffer.second->Unref();
    }
    // Make a reference-counted version of the allocated buffer.
    index_buffer.second = new XRTBufferAllocation(
        shaped_buffer.buffer(index_buffer.first), device_ordinal, allocator);
  }
}

xla::StatusOr<xla::ShapedBuffer> XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  device_ordinal_);
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    shaped_buffer.set_buffer(index_buffer.second->allocation(),
                             index_buffer.first);
  }
  return std::move(shaped_buffer);
}

Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  if (dest_buffer != nullptr) {
    // We allow the destination size to be zero, because there are cases where
    // we come in later to fill in null/uninitialized device buffers. In all
    // other cases, the size of the new buffer must match.
    if (source_buffer->allocation().size() !=
            dest_buffer->allocation().size() &&
        dest_buffer->allocation().size() != 0) {
      return errors::InvalidArgument(
          "Source buffer at index ", source_index.ToString(),
          " does not match the size of destination buffer at index ",
          dest_index.ToString(), ": ", source_buffer->allocation().size(),
          " vs ", dest_buffer->allocation().size());
    }
  } else {
    const xla::Shape& source_subshape =
        xla::ShapeUtil::GetSubshape(source.on_device_shape(), source_index);
    const xla::Shape& dest_subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape(), dest_index);
    if (!xla::ShapeUtil::Equal(source_subshape, dest_subshape)) {
      return errors::InvalidArgument(
          "Source and destination subshapes do not match: source=",
          xla::ShapeUtil::HumanStringWithLayout(source_subshape),
          " dest=", xla::ShapeUtil::HumanStringWithLayout(dest_subshape));
    }
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  if (dest_buffer != nullptr) {
    // If we handed over the ownership of a buffer in ToExecutionInput(), we
    // will be called here on the way back from execution, to alias back the
    // buffer at that index. In that case the buffers will be the same. So we
    // need to discard the memory at the destination buffer before releasing
    // the reference.
    if (dest_buffer->allocation().IsSameAs(source_buffer->allocation()) &&
        dest_buffer != source_buffer) {
      dest_buffer->DiscardAllocation();
    }
    dest_buffer->Unref();
  }
  return Status::OK();
}

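// Usage sketch for ToExecutionInput() below (illustrative only; `tuple` is an
// XRTTupleAllocation*, and the alias checker shown is a placeholder that never
// requests aliasing):
//
//   TF_ASSIGN_OR_RETURN(
//       xla::ExecutionInput input,
//       tuple->ToExecutionInput([](const xla::ShapeIndex&) {
//         return xla::StatusOr<bool>(false);
//       }));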
xla::StatusOr<xla::ExecutionInput> XRTTupleAllocation::ToExecutionInput(
    const std::function<xla::StatusOr<bool>(const xla::ShapeIndex&)>&
        alias_checker) {
  xla::ExecutionInput result(on_device_shape(), on_host_shape());
  for (const auto& index_buffer : buffers_) {
    if (index_buffer.second == nullptr ||
        (index_buffer.second->allocation().is_null() &&
         index_buffer.second->allocation().size() > 0)) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
    TF_ASSIGN_OR_RETURN(bool should_alias, alias_checker(index_buffer.first));
    if (!should_alias) {
      result.SetBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(index_buffer.second->allocation()));
    } else {
      // We keep the ownership of the device memory here.
      result.SetUnownedBuffer(
          index_buffer.first,
          xla::MaybeOwningDeviceMemory(se::OwningDeviceMemory(
              index_buffer.second->allocation(), device_ordinal_, allocator_)));
    }
  }
  return std::move(result);
}

}  // namespace tensorflow