/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {}

se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

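// Writes the index table for a tuple of shape `shape` into `region`: a flat
// array of device pointers, one per top-level element in `elements`.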
Status GenericTransferManager::WriteSingleTupleIndexTable(
    se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
    const Shape& shape, se::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  auto element_pointers = std::make_shared<std::vector<const void*>>();
  element_pointers->reserve(elements.size());
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers->push_back(element.opaque());
  }
  TF_RETURN_IF_ERROR(TransferBufferToDevice(
      stream, GetByteSizeRequirement(shape), element_pointers->data(),
      region));
  // Ensure the buffer is transferred before we destroy element_pointers.
  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
    /* holds reference to element_pointers in closure */
  });
  return Status::OK();
}

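// Copies each array subbuffer of `device_buffer` into the corresponding piece
// of `literal`, then invokes `done` with the result of waiting for those
// copies to complete on `stream`.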
void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* /*transfer_metadata*/) {
  VLOG(2) << "transferring literal from device ordinal "
          << stream->parent()->device_ordinal()
          << "; device buffer: " << device_buffer;
  Status status = [&]() -> Status {
    TF_RET_CHECK(stream->parent()->device_ordinal() ==
                 device_buffer.device_ordinal());

    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        device_buffer.on_device_shape(),
        [&](const Shape& subshape, const ShapeIndex& index) -> Status {
          if (subshape.IsArray()) {
            stream->ThenMemcpy(
                /*host_dst=*/literal.untyped_data(index),
                /*gpu_src=*/device_buffer.buffer(index),
                /*size=*/GetByteSizeRequirement(subshape));
          }
          return Status::OK();
        }));
    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }
  done(stream->BlockHostUntilDone());
}

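// Enqueues transfers of each array subshape of `literal` into the matching
// subbuffer of `device_buffer`. When a subliteral's layout differs from the
// device layout, the data is relaid out on the host before the transfer.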
Status GenericTransferManager::TransferLiteralToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* /*transfer_metadata*/) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal shape to device: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
  TF_RET_CHECK(stream->parent()->device_ordinal() ==
               device_buffer.device_ordinal());

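  // Write the pointer tables for any tuple buffers before transferring the
  // leaf array data they point to.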
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (device_subshape.IsArray()) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralSlice(literal, index);
          Literal relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
            return TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory);
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal.untyped_data();
            TF_RETURN_IF_ERROR(TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory));
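            // The relayed-out literal is local to this visit and is destroyed
            // on return, so block until the enqueued transfer has read it.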
            return stream->BlockHostUntilDone();
          }
        }
        return Status::OK();
      });
}

Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const LiteralSlice& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}

Status GenericTransferManager::ResetDevices(
    absl::Span<se::StreamExecutor* const>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

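// Tuples are represented on device as arrays of pointers, so a tuple's size
// is derived from pointer_size_ rather than from its elements' payloads.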
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  return ShapeUtil::ByteSizeOf(shape, pointer_size_);
}

}  // namespace xla