/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"

#include <atomic>
#include <cstdint>
#include <cstring>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"

namespace tensorflow {
namespace tpu {

using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;

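// Creates a transfer manager backed by the TPU C API; the underlying handle
// is allocated here and released in the destructor.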
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}

TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}

stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}

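// Asks the TPU implementation for the on-device shape corresponding to
// `host_shape`. The shapes are round-tripped through their C
// representations, which are freed before returning.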
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;

  ApiConverter::ToC(host_shape, &c_host_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
  ApiConverter::Free(&c_device_shape);
  return device_shape;
}

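// Enqueues an asynchronous host-to-device copy of `literal` into
// `device_buffer` on `stream`. The C representations are only needed for
// the duration of the call, so they are freed before returning.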
Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_literal);
  return status.status();
}

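// Transfers `literal` to the infeed queue of the TPU behind `executor`.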
Status TpuTransferManager::TransferLiteralToInfeed(
    stream_executor::StreamExecutor* executor,
    const xla::LiteralSlice& literal) {
  StatusHelper status;
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
      manager_, tpu_executor->se_executor(), &c_literal, status.c_status);

  ApiConverter::Free(&c_literal);

  return status.status();
}

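// Transfers a batch of already-linearized buffers to the infeed queue. The
// buffers are handed to the C API as parallel arrays of data pointers and
// word counts.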
Status TpuTransferManager::TransferBuffersToInfeed(
    se::StreamExecutor* executor,
    const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
  StatusHelper status;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  std::vector<int64_t> buffers_size;
  std::vector<uint32_t*> buffers_array;

  buffers_size.reserve(buffers.size());
  buffers_array.reserve(buffers.size());

  for (int64_t i = 0; i < buffers.size(); ++i) {
    absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
    buffers_array.push_back(const_cast<uint32_t*>(span.data()));
    buffers_size.push_back(span.size());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
      manager_, tpu_executor->se_executor(), buffers_array.data(),
      buffers_size.data(), buffers_size.size(), status.c_status);
  return status.status();
}

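// Reads a literal from the outfeed queue of `executor` into the caller's
// `literal`; the literal's shape tells the C API how much to read.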
Status TpuTransferManager::TransferLiteralFromOutfeed(
    stream_executor::StreamExecutor* executor,
    xla::MutableBorrowingLiteral literal) {
  StatusHelper status;
  XLA_Shape c_shape;
  XLA_Literal c_literal;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  ApiConverter::ToC(literal.shape(), &c_shape);
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
      manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
      status.c_status);

  ApiConverter::Free(&c_shape);
  ApiConverter::Free(&c_literal);

  return status.status();
}

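// Resets the TPU devices behind the given executors via the C API.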
Status TpuTransferManager::ResetDevices(
    absl::Span<stream_executor::StreamExecutor* const> executor) {
  StatusHelper status;
  std::vector<SE_StreamExecutor*> se;
  se.reserve(executor.size());
  for (int64_t i = 0; i < executor.size(); ++i) {
    se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
                     ->se_executor());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
      manager_, se.data(), executor.size(), status.c_status);
  return status.status();
}

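// Shared state for an in-flight TransferLiteralFromDevice call. It records
// the first error seen across the individual transfers and, once the last
// transfer finishes, invokes `done` and deletes itself.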
struct TransferFromDeviceState {
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  std::function<void(Status)> done;

  void TransferFinished(TF_Status* status) {
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;
    }
  }
};

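// C-style completion callback that forwards a finished transfer to the
// owning TransferFromDeviceState.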
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}

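// Starts an asynchronous device-to-host copy of `device_buffer` into
// `literal`; completion is reported to `done` through the trampoline above,
// which also frees the heap-allocated state.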
void TpuTransferManager::TransferLiteralFromDevice(
    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  TransferFromDeviceState* state = new TransferFromDeviceState;
  state->remaining_transfers = 1;
  state->done = done;
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_literal);
}

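// Returns the number of bytes required to hold `shape` on the device.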
int64 TpuTransferManager::GetByteSizeRequirement(
    const xla::Shape& shape) const {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);

  int64 size_in_bytes =
      tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
          manager_, &c_shape);

  ApiConverter::Free(&c_shape);
  return size_in_bytes;
}

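// Delegates layout selection to the TPU implementation, which may choose a
// device layout for `host_shape` that differs from the host layout.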
StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  XLA_Shape c_output;
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
      manager_, &c_host_shape, &c_output, status.c_status);
  // TODO(skyewm): use a scoped version of XLA_Shape
  ApiConverter::Free(&c_host_shape);
  if (!status.status().ok()) {
    ApiConverter::Free(&c_output);
    return status.status();
  }
  xla::Shape output = ApiConverter::FromC(&c_output);
  ApiConverter::Free(&c_output);
  return output;
}

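// Queries the TPU implementation for whether `device_buffer` can be
// accessed immediately.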
bool TpuTransferManager::CanShapedBufferBeAccessedNow(
    stream_executor::StreamExecutor* executor,
    const xla::ShapedBuffer& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  auto cleanup = xla::MakeCleanup(
      [&c_device_buffer]() { ApiConverter::Free(&c_device_buffer); });
  return tpu::ExecutorApiFn()
      ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
          manager_, tpu_executor->se_executor(), &c_device_buffer);
}

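// Same check as above, for a single raw device allocation.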
bool TpuTransferManager::CanBufferBeAccessedNow(
    se::StreamExecutor* executor,
    const se::DeviceMemoryBase& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
                                      device_buffer.size(),
                                      device_buffer.payload()};
  return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
      manager_, tpu_executor->se_executor(), &c_device_buffer);
}

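// Writes the index table for a single tuple buffer: the base addresses of
// `elements`, laid out according to `shape`, are written into `region`.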
Status TpuTransferManager::WriteSingleTupleIndexTable(
    stream_executor::Stream* stream,
    absl::Span<const stream_executor::DeviceMemoryBase> elements,
    const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
  CHECK_GT(elements.size(), 0);
  SE_DeviceMemoryBase* elements_bases =
      new SE_DeviceMemoryBase[elements.size()];
  for (int i = 0; i < elements.size(); i++) {
    elements_bases[i] =
        SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
                            elements[i].size(), elements[i].payload()};
  }
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
                                  region->payload()};
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      elements_bases, elements.size(), &c_shape, &region_base, status.c_status);

  delete[] elements_bases;
  ApiConverter::Free(&c_shape);
  return status.status();
}

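// Linearizes `literal` into a sequence of device-format buffers. The C API
// owns the returned arrays, so each buffer is copied into a
// NoncopyableBuffer before being released via TpuTransferManager_FreeBuffersFn.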
Status TpuTransferManager::LinearizeToBuffers(
    const xla::LiteralSlice& literal,
    std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  char** buffers_array;
  int64_t* buffers_size;
  int64_t buffers_array_size;
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
      manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
      status.c_status);

  for (int64_t i = 0; i < buffers_array_size; ++i) {
    tpu::NoncopyableBuffer buf(buffers_size[i]);
    memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
           buffers_size[i]);
    buffers->push_back(std::move(buf));
  }

  tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
      buffers_array, buffers_size, buffers_array_size);

  ApiConverter::Free(&c_literal);
  return status.status();
}

}  // namespace tpu
}  // namespace tensorflow