/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {}

se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

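// Writes the index table for a single tuple `shape`: the device addresses of
// `elements` are packed into a host-side pointer array and copied into
// `region`. A host callback keeps the pointer array alive until the stream
// has consumed it.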
Status GenericTransferManager::WriteSingleTupleIndexTable(
    se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
    const Shape& shape, se::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  auto element_pointers = std::make_shared<std::vector<const void*>>();
  element_pointers->reserve(elements.size());
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers->push_back(element.opaque());
  }
  TF_RETURN_IF_ERROR(TransferBufferToDevice(
      stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
  // Ensure the buffer is transferred before we destroy element_pointers.
  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
    /* holds reference to element_pointers in closure */
  });
  return Status::OK();
}

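// Copies each array subbuffer of `device_buffer` into the corresponding piece
// of `literal` with plain stream memcpys, then blocks until the stream has
// finished and reports the result through `done`.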
void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* /*transfer_metadata*/) {
  VLOG(2) << "transferring literal from device ordinal "
          << stream->parent()->device_ordinal()
          << "; device buffer: " << device_buffer;
  Status status = [&]() -> Status {
    TF_RET_CHECK(stream->parent()->device_ordinal() ==
                 device_buffer.device_ordinal());

    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        device_buffer.on_device_shape(),
        [&](const Shape& subshape, const ShapeIndex& index) -> Status {
          if (subshape.IsArray()) {
            stream->ThenMemcpy(
                /*host_dst=*/literal.untyped_data(index),
                /*gpu_src=*/device_buffer.buffer(index),
                // With bounded dynamic shapes, the shape of the device buffer
                // (bounded allocation) can be bigger than the literal.
                /*size=*/
                GetByteSizeRequirement(
                    ShapeUtil::GetSubshape(literal.shape(), index)));
          }
          return Status::OK();
        }));
    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }
  done(stream->BlockHostUntilDone());
}

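// Writes the tuple index tables for `device_buffer`, then transfers every
// array subshape of `literal` into the matching device allocation. Data that
// is not already in the device layout is relaid out on the host first; in
// that case the transfer synchronizes on the stream so the temporary host
// copy can be released safely.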
Status GenericTransferManager::TransferLiteralToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* /*transfer_metadata*/) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal shape to device: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
  TF_RET_CHECK(stream->parent()->device_ordinal() ==
               device_buffer.device_ordinal());

  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (device_subshape.IsArray()) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralSlice(literal, index);
          Literal relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
            return TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory);
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal.untyped_data();
            TF_RETURN_IF_ERROR(TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory));
            return stream->BlockHostUntilDone();
          }
        }
        return Status::OK();
      });
}

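// Infeed, outfeed, and device reset are inherently platform-specific; the
// generic implementation only reports them as unimplemented.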
Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const LiteralSlice& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}

Status GenericTransferManager::ResetDevices(
    absl::Span<se::StreamExecutor* const>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

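// Static and tuple shapes take exactly ShapeUtil::ByteSizeOf bytes; shapes
// with dynamic dimensions additionally reserve an int32 per dimension for the
// metadata recording the actual dimension sizes.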
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  if (shape.is_static() || shape.IsTuple()) {
    return ShapeUtil::ByteSizeOf(shape, pointer_size_);
  }
  int64_t metadata_size = sizeof(int32) * shape.dimensions_size();
  return ShapeUtil::ByteSizeOf(shape, pointer_size_) + metadata_size;
}

}  // namespace xla