1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
17
18 #include <utility>
19
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/xla_data.pb.h"
23 #include "tensorflow/core/tpu/tpu_api.h"
24 #include "tensorflow/stream_executor/device_memory.h"
25 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
26 #include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
27 #include "tensorflow/stream_executor/tpu/proto_helper.h"
28 #include "tensorflow/stream_executor/tpu/status_helper.h"
29 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
30 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
31 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
32 #include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
33
34 namespace tensorflow {
35 namespace tpu {
36
37 using Status = stream_executor::port::Status;
38 template <typename T>
39 using StatusOr = stream_executor::port::StatusOr<T>;
40
// Allocates the underlying C-API transfer manager; owned by this object
// and released in the destructor.
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}
44
// Releases the C-API transfer manager allocated in the constructor.
TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}
48
// Identifies this transfer manager as belonging to the TPU platform.
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}
52
// Maps a host-memory shape to the corresponding on-device shape by
// round-tripping through the C API.
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape host_shape_c;
  XLA_Shape device_shape_c;
  ApiConverter::ToC(host_shape, &host_shape_c);

  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &host_shape_c, &device_shape_c);

  xla::Shape result = ApiConverter::FromC(&device_shape_c);
  // Release the temporary C representations before returning.
  ApiConverter::Free(&device_shape_c);
  ApiConverter::Free(&host_shape_c);
  return result;
}
67
TransferLiteralToDeviceAsync(stream_executor::Stream * stream,const xla::LiteralSlice & literal,const xla::ShapedBuffer & device_buffer,const TransferMetadata * transfer_metadata)68 Status TpuTransferManager::TransferLiteralToDeviceAsync(
69 stream_executor::Stream* stream, const xla::LiteralSlice& literal,
70 const xla::ShapedBuffer& device_buffer,
71 const TransferMetadata* transfer_metadata) {
72 StatusHelper status;
73
74 XLA_Literal c_literal;
75 ApiConverter::ToC(literal, &c_literal);
76
77 XLA_ShapedBuffer c_device_buffer;
78 ApiConverter::ToC(device_buffer, &c_device_buffer);
79
80 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
81 manager_,
82 TpuPlatform::GetRegisteredPlatform()->LookupStream(
83 stream->implementation()),
84 &c_literal, &c_device_buffer, status.c_status);
85 ApiConverter::Free(&c_device_buffer);
86 ApiConverter::Free(&c_literal);
87 return status.status();
88 }
89
TransferLiteralToInfeed(stream_executor::StreamExecutor * executor,const xla::LiteralSlice & literal)90 Status TpuTransferManager::TransferLiteralToInfeed(
91 stream_executor::StreamExecutor* executor,
92 const xla::LiteralSlice& literal) {
93 StatusHelper status;
94 XLA_Literal c_literal;
95 ApiConverter::ToC(literal, &c_literal);
96 auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
97
98 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
99 manager_, tpu_executor->se_executor(), &c_literal, status.c_status);
100
101 ApiConverter::Free(&c_literal);
102
103 return status.status();
104 }
105
TransferBuffersToInfeed(se::StreamExecutor * executor,const std::deque<tensorflow::tpu::NoncopyableBuffer> & buffers)106 Status TpuTransferManager::TransferBuffersToInfeed(
107 se::StreamExecutor* executor,
108 const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
109 StatusHelper status;
110 auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
111
112 std::vector<int64_t> buffers_size;
113 std::vector<uint32_t*> buffers_array;
114
115 buffers_size.reserve(buffers.size());
116 buffers_array.reserve(buffers.size());
117
118 for (int64_t i = 0; i < buffers.size(); ++i) {
119 absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
120 buffers_array.push_back(const_cast<uint32_t*>(span.data()));
121 buffers_size.push_back(span.size());
122 }
123
124 tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
125 manager_, tpu_executor->se_executor(), buffers_array.data(),
126 buffers_size.data(), buffers_size.size(), status.c_status);
127 return status.status();
128 }
129
TransferLiteralFromOutfeed(stream_executor::StreamExecutor * executor,xla::MutableBorrowingLiteral literal)130 Status TpuTransferManager::TransferLiteralFromOutfeed(
131 stream_executor::StreamExecutor* executor,
132 xla::MutableBorrowingLiteral literal) {
133 StatusHelper status;
134 XLA_Shape c_shape;
135 XLA_Literal c_literal;
136 auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
137
138 ApiConverter::ToC(literal.shape(), &c_shape);
139 ApiConverter::ToC(literal, &c_literal);
140
141 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
142 manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
143 status.c_status);
144
145 ApiConverter::Free(&c_shape);
146 ApiConverter::Free(&c_literal);
147
148 return status.status();
149 }
150
ResetDevices(absl::Span<stream_executor::StreamExecutor * const> executor)151 Status TpuTransferManager::ResetDevices(
152 absl::Span<stream_executor::StreamExecutor* const> executor) {
153 StatusHelper status;
154 std::vector<SE_StreamExecutor*> se;
155 se.reserve(executor.size());
156 for (int64_t i = 0; i < executor.size(); ++i) {
157 se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
158 ->se_executor());
159 }
160
161 tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
162 manager_, se.data(), executor.size(), status.c_status);
163 return status.status();
164 }
165
// Completion state shared by the callbacks of one TransferLiteralFromDevice
// call. Heap-allocated by the caller and destroyed by the final
// TransferFinished invocation (`delete this` below), so it must not be
// stack-allocated.
struct TransferFromDeviceState {
  // Number of component transfers still outstanding.
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  // User callback invoked exactly once with the overall status.
  std::function<void(Status)> done;

  // Records the result of one finished transfer. Takes ownership of
  // `status` and frees it; the first non-OK status is kept as the overall
  // result (via swap). When the last transfer completes, invokes `done`
  // and deletes this object.
  void TransferFinished(TF_Status* status) {
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      // Last transfer: report the aggregated status and self-destruct.
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;
    }
  }
};
186
// C-compatible trampoline passed to the C API: forwards the completion
// callback to the TransferFromDeviceState carried in `ctx`.
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}
190
TransferLiteralFromDevice(stream_executor::Stream * stream,const xla::ShapedBuffer & device_buffer,xla::MutableBorrowingLiteral literal,std::function<void (Status)> done,const TransferMetadata * transfer_metadata)191 void TpuTransferManager::TransferLiteralFromDevice(
192 stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
193 xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
194 const TransferMetadata* transfer_metadata) {
195 TransferFromDeviceState* state = new TransferFromDeviceState;
196 state->remaining_transfers = 1;
197 state->done = done;
198 XLA_ShapedBuffer c_device_buffer;
199 ApiConverter::ToC(device_buffer, &c_device_buffer);
200 XLA_Literal c_literal;
201 ApiConverter::ToC(literal, &c_literal);
202
203 tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
204 manager_,
205 TpuPlatform::GetRegisteredPlatform()->LookupStream(
206 stream->implementation()),
207 &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
208 ApiConverter::Free(&c_device_buffer);
209 ApiConverter::Free(&c_literal);
210 }
211
GetByteSizeRequirement(const xla::Shape & shape) const212 int64 TpuTransferManager::GetByteSizeRequirement(
213 const xla::Shape& shape) const {
214 XLA_Shape c_shape;
215 ApiConverter::ToC(shape, &c_shape);
216
217 int64_t size_in_bytes =
218 tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
219 manager_, &c_shape);
220
221 ApiConverter::Free(&c_shape);
222 return size_in_bytes;
223 }
224
ChooseCompactLayoutForShape(const xla::Shape & host_shape) const225 StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
226 const xla::Shape& host_shape) const {
227 XLA_Shape c_host_shape;
228 ApiConverter::ToC(host_shape, &c_host_shape);
229 XLA_Shape c_output;
230 StatusHelper status;
231 tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
232 manager_, &c_host_shape, &c_output, status.c_status);
233 // TODO(skyewm): use a scoped version of XLA_Shape
234 ApiConverter::Free(&c_host_shape);
235 if (!status.status().ok()) {
236 ApiConverter::Free(&c_output);
237 return status.status();
238 }
239 xla::Shape output = ApiConverter::FromC(&c_output);
240 ApiConverter::Free(&c_output);
241 return output;
242 }
243
CanShapedBufferBeAccessedNow(stream_executor::StreamExecutor * executor,const xla::ShapedBuffer & device_buffer) const244 bool TpuTransferManager::CanShapedBufferBeAccessedNow(
245 stream_executor::StreamExecutor* executor,
246 const xla::ShapedBuffer& device_buffer) const {
247 auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
248 XLA_ShapedBuffer c_device_buffer;
249 ApiConverter::ToC(device_buffer, &c_device_buffer);
250 auto cleanup = xla::MakeCleanup(
251 [&c_device_buffer]() { ApiConverter::Free(&c_device_buffer); });
252 return tpu::ExecutorApiFn()
253 ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
254 manager_, tpu_executor->se_executor(), &c_device_buffer);
255 }
256
CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer) const257 bool TpuTransferManager::CanBufferBeAccessedNow(
258 se::StreamExecutor* executor,
259 const se::DeviceMemoryBase& device_buffer) const {
260 auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
261 SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
262 device_buffer.size(),
263 device_buffer.payload()};
264 return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
265 manager_, tpu_executor->se_executor(), &c_device_buffer);
266 }
267
WriteSingleTupleIndexTable(stream_executor::Stream * stream,absl::Span<const stream_executor::DeviceMemoryBase> elements,const xla::Shape & shape,stream_executor::DeviceMemoryBase * region)268 Status TpuTransferManager::WriteSingleTupleIndexTable(
269 stream_executor::Stream* stream,
270 absl::Span<const stream_executor::DeviceMemoryBase> elements,
271 const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
272 CHECK_GT(elements.size(), 0);
273 SE_DeviceMemoryBase* elements_bases =
274 new SE_DeviceMemoryBase[elements.size()];
275 for (int i = 0; i < elements.size(); i++) {
276 elements_bases[i] =
277 SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
278 elements[i].size(), elements[i].payload()};
279 }
280 XLA_Shape c_shape;
281 ApiConverter::ToC(shape, &c_shape);
282 SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
283 region->payload()};
284 StatusHelper status;
285
286 tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
287 manager_,
288 TpuPlatform::GetRegisteredPlatform()->LookupStream(
289 stream->implementation()),
290 elements_bases, elements.size(), &c_shape, ®ion_base, status.c_status);
291
292 delete[] elements_bases;
293 ApiConverter::Free(&c_shape);
294 return status.status();
295 }
296
LinearizeToBuffers(const xla::LiteralSlice & literal,std::deque<tensorflow::tpu::NoncopyableBuffer> * buffers)297 Status TpuTransferManager::LinearizeToBuffers(
298 const xla::LiteralSlice& literal,
299 std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
300 XLA_Literal c_literal;
301 ApiConverter::ToC(literal, &c_literal);
302
303 char** buffers_array;
304 int64_t* buffers_size;
305 int64_t buffers_array_size;
306 StatusHelper status;
307
308 tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
309 manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
310 status.c_status);
311
312 for (int64_t i = 0; i < buffers_array_size; ++i) {
313 tpu::NoncopyableBuffer buf(buffers_size[i]);
314 memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
315 buffers_size[i]);
316 buffers->push_back(std::move(buf));
317 }
318
319 tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
320 buffers_array, buffers_size, buffers_array_size);
321
322 ApiConverter::Free(&c_literal);
323 return status.status();
324 }
325
ReadDynamicShapes(se::Stream * stream,xla::ShapedBuffer * device_buffer,xla::Shape * device_shape)326 Status TpuTransferManager::ReadDynamicShapes(se::Stream* stream,
327 xla::ShapedBuffer* device_buffer,
328 xla::Shape* device_shape) {
329 XLA_ShapedBuffer c_device_buffer;
330 XLA_Shape c_device_shape;
331 ApiConverter::ToC(*device_buffer, &c_device_buffer);
332 ApiConverter::ToC(*device_shape, &c_device_shape);
333 XLA_Shape c_updated_shape;
334 StatusHelper status;
335 ExecutorApiFn()->TpuTransferManager_ReadDynamicShapesFn(
336 TpuPlatform::GetRegisteredPlatform()->LookupStream(
337 stream->implementation()),
338 &c_device_buffer, c_device_shape, &c_updated_shape, status.c_status);
339 ApiConverter::Free(&c_device_buffer);
340 ApiConverter::Free(&c_device_shape);
341 if (!status.ok()) {
342 return status.status();
343 }
344 *device_shape = ApiConverter::FromC(&c_updated_shape);
345 ApiConverter::Free(&c_updated_shape);
346 return Status::OK();
347 }
348
349 } // namespace tpu
350 } // namespace tensorflow
351