/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Classes for allocating XLA literals in device memory and managing handles
// that refer to them.

#include "tensorflow/compiler/xrt/xrt_state.h"

#include <stdint.h>
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace tensorflow {

namespace {

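// Tracks the count and total size of outstanding XRT buffer allocations per
// device ordinal. Only consulted for logging when VLOG(2) is enabled.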
class BufferAllocStats {
 public:
  struct Stats {
    int64 count = 0;
    int64 size = 0;
  };

  Stats ReportAlloc(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count += 1;
    device_stats->size += msize;
    return *device_stats;
  }

  Stats ReportFree(int64 device, int64 msize) {
    mutex_lock lock(lock_);
    Stats* device_stats = &stats_[device];
    device_stats->count -= 1;
    device_stats->size -= msize;
    return *device_stats;
  }

 private:
  mutable mutex lock_;
  std::map<int64, Stats> stats_;
};

const char* kTupleContainer = "tuples";

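// Generates a random, non-negative int64 suitable for use as a handle key in
// the ResourceMgr.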
int64 get_uid() {
  uint64 unsigned_rand = random::New64() & INT64_MAX;
  return static_cast<int64>(unsigned_rand);
}

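// Returns the process-wide allocation stats tracker (a leaked singleton).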
BufferAllocStats* GetAllocStats() {
  static BufferAllocStats* stats = new BufferAllocStats();
  return stats;
}

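// Allocates device memory for each leaf buffer of the device shape derived
// from 'shape', and writes the tuple index tables. The result is returned as
// a ScopedShapedBuffer so that already-allocated buffers are freed if a later
// allocation fails.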
Status AllocateScopedShapedBuffer(
    xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
    std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // XLA may use a different representation on device than the representation
  // on the host. XLA does not document any contract for the relationship
  // between these representations. Right now, the device shape is always a
  // superset of the host shape, meaning that for any valid ShapeIndex in the
  // host shape that ShapeIndex is also valid in the device shape, but not
  // vice versa. In particular, some host-side types are rewritten to be
  // tuples. We rely on this property when making sub-buffers, because we
  // assume that if the client requests the host-shape sub-buffer at index i,
  // that will correspond to the right device-shape sub-buffer at the same
  // index.
  xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
  VLOG(3) << "Allocating literal buffer: host_shape="
          << xla::ShapeUtil::HumanStringWithLayout(shape) << " device_shape="
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape);

  // The ScopedShapedBuffer frees the buffers that have so far been allocated
  // if it goes out of scope. That's useful if we return early as the result
  // of an error allocating one of the later buffers.
  *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
      shape, on_device_shape, allocator, device_ordinal);
  for (auto& index_to_buffer : (*buffer)->buffers()) {
    xla::Shape subshape =
        xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
    uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
    TF_ASSIGN_OR_RETURN(
        xla::OwningDeviceMemory buffer,
        allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
    // Move our buffer into shaped_buffer, which takes ownership of it.
    index_to_buffer.second = buffer.Forget();
    VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
            << " index " << index_to_buffer.first.ToString();
  }

  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));

  return Status::OK();
}

}  // namespace

XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
                                         int device_ordinal,
                                         xla::DeviceMemoryAllocator* allocator)
    : size_(allocation.size()),
      allocation_(allocation),
      device_ordinal_(device_ordinal),
      allocator_(allocator) {
  if (VLOG_IS_ON(2)) {
    auto stats =
        GetAllocStats()->ReportAlloc(device_ordinal_, allocation_.size());
    LOG(INFO) << "XRT Allocation Stats: device=" << device_ordinal_
              << " count=" << stats.count << " size=" << stats.size;
  }
}

XRTBufferAllocation::~XRTBufferAllocation() {
  if (VLOG_IS_ON(2)) {
    GetAllocStats()->ReportFree(device_ordinal_, allocation_.size());
  }
  // Deallocate() explicitly allows allocation_ to be null.
  Status s = allocator_->Deallocate(device_ordinal_, allocation_);
  // Nothing to do but check-fail here if the memory data structures are
  // corrupted.
  CHECK(s.ok());
  VLOG(2) << "Freed buffer at " << allocation_.opaque();
}

const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
  return allocation_;
}

void XRTBufferAllocation::DiscardAllocation() {
  // Replace the allocation with a null.
  allocation_ = se::DeviceMemoryBase();
}

XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
                                       xla::DeviceMemoryAllocator* allocator,
                                       const xla::Shape& on_host_shape,
                                       const xla::Shape& on_device_shape)
    : device_ordinal_(device_ordinal),
      allocator_(allocator),
      on_host_shape_(on_host_shape),
      on_device_shape_(on_device_shape),
      buffers_(&on_device_shape_) {}

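// Drops this tuple's reference on every buffer it holds; the underlying
// device memory is freed when the last XRTBufferAllocation reference goes
// away.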
XRTTupleAllocation::~XRTTupleAllocation() {
  for (auto& buffer : buffers_) {
    buffer.second->Unref();
  }
}

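// Allocates device buffers matching literal's shape, transfers the literal to
// the device, and wraps the result in a new XRTTupleAllocation owned by the
// caller.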
/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
    const xla::LiteralBase& literal, xla::Backend* backend, int device_ordinal,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();

  std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
  TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
      backend, device_ordinal, literal.shape(), &scoped_buffer));
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
  TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
      stream.get(), literal, *scoped_buffer));

  // By releasing the ScopedShapedBuffer we ensure that the underlying storage
  // won't be freed when the buffer goes out of scope at the end of this
  // call. To avoid a leak, there must be no error-case returns from here
  // until the end of the method.
  auto shaped_buffer = scoped_buffer->release();
  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  return Status::OK();
}

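// Wraps an existing ShapedBuffer's device memory in a new XRTTupleAllocation
// without allocating or copying anything.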
/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
    const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
    int device_ordinal, XRTTupleAllocation** allocation) {
  auto allocator = backend->memory_allocator();

  *allocation = new XRTTupleAllocation(device_ordinal, allocator,
                                       shaped_buffer.on_host_shape(),
                                       shaped_buffer.on_device_shape());
  (*allocation)
      ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
  return Status::OK();
}

Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
                                     xla::MutableLiteralBase* literal) {
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  // Validate the allocation buffers: if a null buffer reaches
  // TransferLiteralFromDevice(), a CHECK is issued.
  xla::ShapedBuffer shaped_buffer = ToShapedBuffer();
  for (auto& index_buffer : shaped_buffer.buffers()) {
    if (index_buffer.second.is_null()) {
      return errors::InvalidArgument("Literal buffer at index ",
                                     index_buffer.first.ToString(),
                                     " has been released");
    }
  }
  return transfer_manager->TransferLiteralFromDevice(stream.get(),
                                                     shaped_buffer, *literal);
}

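// Overwrites the tuple's device buffers with the contents of 'literal', whose
// shape must exactly match the allocation's on-host shape.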
Status XRTTupleAllocation::WriteLiteral(xla::Backend* backend,
                                        const xla::Literal& literal) {
  if (!xla::ShapeUtil::Equal(literal.shape(), on_host_shape())) {
    return errors::InvalidArgument(
        "New literal shape not matching the existing one: literal=",
        xla::ShapeUtil::HumanStringWithLayout(literal.shape()),
        " device=", xla::ShapeUtil::HumanStringWithLayout(on_host_shape()));
  }
  auto transfer_manager = backend->transfer_manager();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal()));
  return transfer_manager->TransferLiteralToDevice(stream.get(), literal,
                                                   ToShapedBuffer());
}

void XRTTupleAllocation::DiscardAllocation(
    const xla::ShapeIndex& buffer_index) {
  buffers_.element(buffer_index)->DiscardAllocation();
}

const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }

const xla::Shape& XRTTupleAllocation::on_device_shape() {
  return on_device_shape_;
}

int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }

const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
  return buffers_.element({})->allocation();
}

/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
                                             XRTTupleAllocation** allocation) {
  string key_string = absl::StrCat(key);
  TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
                                                                int64 key) {
  string key_string = absl::StrCat(key);
  return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
}

/* static */ Status XRTTupleAllocation::ReleaseAllAllocations(ResourceMgr* rm) {
  VLOG(1) << "Releasing all XRT held device memory";
  return rm->Cleanup(kTupleContainer);
}

// Helper typedef to make ShapeTree ForEach helper lambda signatures more
// readable. The lambdas need a parameter of type const T&, where T here is
// the pointer type defined below.
typedef XRTBufferAllocation* XRTBufferAllocationPtr;

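// Creates a new allocation for the sub-tuple of 'parent' rooted at
// 'subshape'. If alias_parent_allocation is true the sub-buffers are shared
// and their refcounts bumped; otherwise ownership of the sub-buffers moves to
// the new allocation and the parent is left with null placeholders.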
/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
    XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
    XRTTupleAllocation** allocation, bool alias_parent_allocation) {
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* host_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
  TF_ASSIGN_OR_RETURN(
      const xla::Shape* device_sub_shape,
      xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));

  *allocation =
      new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
                             *host_sub_shape, *device_sub_shape);
  if (alias_parent_allocation) {
    // Copy the subtree of allocations from the parent allocation.
    (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
    // Increment the refcount on each aliased buffer.
    (*allocation)
        ->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
  } else {
    // Find the buffers in the parent allocation that match the subtree, and
    // move the parent allocation's buffer over to the new allocation.
    (*allocation)
        ->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              // Extend the allocation's index to the parent's frame by adding
              // subshape as a prefix.
              xla::ShapeIndex parent_index = subshape;
              for (int i = 0; i < index.size(); ++i) {
                parent_index.push_back(index[i]);
              }
              *buffer = parent->buffers_.element(parent_index);
              *parent->buffers_.mutable_element(parent_index) =
                  new XRTBufferAllocation(se::DeviceMemoryBase(),
                                          parent->device_ordinal(),
                                          parent->allocator_);
            });
  }

  return Status::OK();
}

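// Computes the host and device shapes of the tuple MakeTuple will build by
// grafting each leaf allocation's shapes onto the 'spine' shape described by
// 'elements', validating that leaves carry allocations on the target device
// and that internal nodes carry none.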
/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
    const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
    xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
    xla::Shape* device_shape) {
  // Initialize both host and device shape to be the 'spine' of the new tuple
  // shape, given by the shape of the tree of tuples.
  *host_shape = elements.shape();
  *device_shape = elements.shape();
  // Now go over the leaves of the tree of tuples, and 'graft' the host/device
  // shapes of the allocation at that leaf onto the expanded host/device
  // shapes at the leaf position.
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (elements.IsLeaf(index)) {
          if (element.allocation == nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a null internal node at index ",
                index.ToString());
          }
          if (device_ordinal != element.allocation->device_ordinal() ||
              allocator != element.allocation->allocator_) {
            return errors::InvalidArgument(
                "MakeTuple elements must all be allocated on the same device "
                "as the destination.");
          }
          *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
              element.allocation->on_host_shape();
          *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
              element.allocation->on_device_shape();
        } else {
          if (element.allocation != nullptr) {
            return errors::InvalidArgument(
                "MakeTuple elements has a non-null internal node at index ",
                index.ToString());
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}

/*static*/ Status XRTTupleAllocation::MakeTuple(
    xla::Backend* backend, int device_ordinal,
    const xla::ShapeTree<ExpandedTupleInput>& elements,
    XRTTupleAllocation** allocation) {
  auto transfer_manager = backend->transfer_manager();
  auto allocator = backend->memory_allocator();
  TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));

  xla::Shape host_shape;
  xla::Shape device_shape;
  TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
                                        &host_shape, &device_shape));

  // The aliasing is determined below based on whether or not all the inputs
  // are released while being transferred. allocation_tmp is a local pointer
  // that is copied to *allocation at the end only if the method succeeds.
  auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
                                               host_shape, device_shape);
  core::ScopedUnref allocation_unref(allocation_tmp);
  // First allocate device memory for the new tuple index tables, one at each
  // internal node of the elements tree. Do this in a separate pass into a
  // ScopedShapedBuffer so that it's easy to free the newly-allocated memory
  // if an allocation fails. Make sure the shape has layout so that the code
  // that writes index tables will be happy lower down.
  xla::Shape spine_shape = elements.shape();
  xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
  auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
      spine_shape, spine_shape, allocator, device_ordinal);
  TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
      [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
        if (!elements.IsLeaf(index)) {
          xla::Shape subshape =
              xla::ShapeUtil::GetSubshape(device_shape, index);
          uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
          TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
                              allocator->Allocate(device_ordinal, size,
                                                  /*retry_on_failure=*/false));
          VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
                  << index.ToString();
          // Move the new buffer into new_tuple_buffers, which takes ownership
          // of it.
          new_tuple_buffers->set_buffer(std::move(buffer), index);
        }
        return Status::OK();
      }));
  // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not
  // own the newly-allocated index tables. Right now there's no owner for the
  // new index tables, so next we will transfer ownership to the new
  // allocation, taking care not to return early on any errors in the
  // meantime.
  xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
  // Now fill in the remaining data structures. After this ForEachElement
  // completes:
  //   1) Every leaf element of tuple_buffers will be the root buffer of an
  //      existing allocation, and every internal element of tuple_buffers
  //      will be a newly-allocated index table. tuple_buffers does not own
  //      any of these.
  //   2) Every element of allocation_tmp->buffers_ will be a correctly
  //      constructed XRTBufferAllocation wrapping the necessary allocations.
  //      For buffers in existing allocations there will be a new reference
  //      owned by the new allocation, and for newly-allocated index tables
  //      there will be a single reference owned by the new allocation.
  elements.ForEachElement([&](const xla::ShapeIndex& index,
                              const ExpandedTupleInput& element) {
    if (elements.IsLeaf(index)) {
      allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
                                               index);
      tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
      if (element.release_allocation_after_use) {
        // Transfer the references from element's buffers to the new
        // allocation rather than incrementing the refcount. The caller should
        // have validated that release_allocation_after_use is false if
        // element.allocation appears in more than one leaf.
        element.allocation->buffers_.ForEachMutableElement(
            [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
              *buffer = new XRTBufferAllocation(
                  se::DeviceMemoryBase(), element.allocation->device_ordinal(),
                  element.allocation->allocator_);
            });
      } else {
        // Increment the refcount on each newly-aliased buffer.
        element.allocation->buffers_.ForEachElement(
            [](const xla::ShapeIndex& index,
               const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
      }
    } else {
      // This is an internal node of the tuple tree so take ownership of the
      // newly-created index table.
      *allocation_tmp->buffers_.mutable_element(index) =
          new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
                                  allocator);
    }
  });
  // Because the internal nodes of tuple_buffers are exactly the new index
  // tables, WriteTupleIndexTables will write only the new index tables and
  // not rewrite the index tables for the existing allocations.
  TF_RETURN_IF_ERROR(
      transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));

  *allocation = allocation_tmp;
  // Get another reference since allocation_tmp will be Unrefed automatically
  // on exit.
  (*allocation)->Ref();
  return Status::OK();
}

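// Registers this allocation with the ResourceMgr under a freshly generated
// random key, returned to the caller via *key.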
Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
  *key = get_uid();
  string key_string = absl::StrCat(*key);
  return rm->Create(kTupleContainer, key_string, this);
}

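// Returns true iff this allocation holds the only reference to each of its
// buffers, i.e. no other allocation aliases any part of it.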
bool XRTTupleAllocation::IsExclusiveOwner() {
  for (const auto& buffer : buffers_) {
    if (!buffer.second->RefCountIsOne()) return false;
  }
  return true;
}

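// Populates buffers_ by wrapping each device buffer of 'shaped_buffer' in a
// new reference-counted XRTBufferAllocation.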
void XRTTupleAllocation::InitializeFromShapedBuffer(
    const xla::ShapedBuffer& shaped_buffer,
    xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
  for (auto& buffer : buffers_) {
    // Make a reference-counted version of the allocated buffer.
    buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
                                            device_ordinal, allocator);
  }
}

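// Builds a ShapedBuffer view of this allocation. The returned ShapedBuffer
// does not own the underlying device memory.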
xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
  xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
                                  allocator_->platform(), device_ordinal_);
  for (const auto& buffer : buffers_) {
    shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
  }
  return shaped_buffer;
}

Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source,
                                           const xla::ShapeIndex& source_index,
                                           const xla::ShapeIndex& dest_index) {
  XRTBufferAllocation* source_buffer = source.buffers_.element(source_index);
  XRTBufferAllocation* dest_buffer = buffers_.element(dest_index);
  // We allow the destination size to be zero, because there are cases where
  // we come in later and fill in null/uninitialized device buffers. In all
  // other cases, the size of the new buffer must match.
  if (source_buffer->size() != dest_buffer->size() &&
      dest_buffer->size() != 0) {
    return errors::InvalidArgument(
        "Source buffer at index ", source_index.ToString(),
        " does not match the size of destination buffer at index ",
        dest_index.ToString(), ": ", source_buffer->size(), " vs ",
        dest_buffer->size());
  }
  *buffers_.mutable_element(dest_index) = source_buffer;
  source_buffer->Ref();
  dest_buffer->Unref();
  return Status::OK();
}

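// Builds a ShapeTree of MaybeOwningDeviceMemory over this allocation's
// buffers. Buffers for which 'release_checker' returns true are handed out as
// owning memory and discarded from this allocation; the rest are passed as
// non-owning references.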
xla::ShapeTree<xla::MaybeOwningDeviceMemory>
XRTTupleAllocation::ToDeviceMemoryTree(
    const std::function<bool(const xla::ShapeIndex&)>& release_checker) {
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
  for (const auto& buffer : buffers_) {
    if (!release_checker(buffer.first)) {
      *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
    } else {
      *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
          buffer.second->allocation(), device_ordinal_, allocator_);
      DiscardAllocation(buffer.first);
    }
  }
  return shaped_tree;
}

}  // namespace tensorflow