/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"

#include <atomic>
#include <cstring>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"

namespace tensorflow {
namespace tpu {

using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;

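// The wrapper owns a handle to the underlying C-API transfer manager:
// the constructor allocates it and the destructor releases it.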
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}

TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}

stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}

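// Converts the host shape to its C representation, asks the C API for the
// corresponding on-device shape, and converts the result back to xla::Shape.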
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;

  ApiConverter::ToC(host_shape, &c_host_shape);

  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
  ApiConverter::Free(&c_device_shape);
  return device_shape;
}

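// Enqueues an asynchronous literal-to-device copy on `stream`. The literal
// and the destination shaped buffer are converted to their C representations
// for the call, and the temporary C structs are freed before returning.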
Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_literal);
  return status.status();
}

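// Writes a literal to the infeed queue of the TPU backing `executor`.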
Status TpuTransferManager::TransferLiteralToInfeed(
    stream_executor::StreamExecutor* executor,
    const xla::LiteralSlice& literal) {
  StatusHelper status;
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
      manager_, tpu_executor->se_executor(), &c_literal, status.c_status);

  ApiConverter::Free(&c_literal);

  return status.status();
}

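// Transfers a batch of raw buffers to the infeed. The deque is flattened into
// parallel arrays of data pointers and sizes (in uint32_t elements) for the
// C API; the pointers alias the caller's buffers, so nothing is copied here.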
Status TpuTransferManager::TransferBuffersToInfeed(
    se::StreamExecutor* executor,
    const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
  StatusHelper status;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  std::vector<int64_t> buffers_size;
  std::vector<uint32_t*> buffers_array;

  buffers_size.reserve(buffers.size());
  buffers_array.reserve(buffers.size());

  for (int64_t i = 0; i < buffers.size(); ++i) {
    absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
    buffers_array.push_back(const_cast<uint32_t*>(span.data()));
    buffers_size.push_back(span.size());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
      manager_, tpu_executor->se_executor(), buffers_array.data(),
      buffers_size.data(), buffers_size.size(), status.c_status);
  return status.status();
}

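// Reads a literal of the given shape from the outfeed queue into `literal`.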
Status TpuTransferManager::TransferLiteralFromOutfeed(
    stream_executor::StreamExecutor* executor,
    xla::MutableBorrowingLiteral literal) {
  StatusHelper status;
  XLA_Shape c_shape;
  XLA_Literal c_literal;
  auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());

  ApiConverter::ToC(literal.shape(), &c_shape);
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
      manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
      status.c_status);

  ApiConverter::Free(&c_shape);
  ApiConverter::Free(&c_literal);

  return status.status();
}

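// Resets all TPU devices behind the given executors in a single C API call.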
Status TpuTransferManager::ResetDevices(
    absl::Span<stream_executor::StreamExecutor* const> executor) {
  StatusHelper status;
  std::vector<SE_StreamExecutor*> se;
  se.reserve(executor.size());
  for (int64_t i = 0; i < executor.size(); ++i) {
    se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
                     ->se_executor());
  }

  tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
      manager_, se.data(), executor.size(), status.c_status);
  return status.status();
}

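// Shared state for an in-flight TransferLiteralFromDevice call. The state is
// heap-allocated and deletes itself once the last pending transfer finishes,
// invoking `done` with either OK or the first error observed.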
struct TransferFromDeviceState {
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  std::function<void(Status)> done;

  void TransferFinished(TF_Status* status) {
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;
    }
  }
};

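// C-style completion callback passed through the C API; forwards to the
// state object's TransferFinished.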
void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
  reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
}

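// Starts an asynchronous device-to-host transfer; `done` runs via the
// trampoline above once the single pending transfer completes.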
void TpuTransferManager::TransferLiteralFromDevice(
    stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
    xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  TransferFromDeviceState* state = new TransferFromDeviceState;
  state->remaining_transfers = 1;
  state->done = done;
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_literal);
}

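// Returns the number of bytes the device needs to hold a value of `shape`.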
int64 TpuTransferManager::GetByteSizeRequirement(
    const xla::Shape& shape) const {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);

  int64_t size_in_bytes =
      tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
          manager_, &c_shape);

  ApiConverter::Free(&c_shape);
  return size_in_bytes;
}

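// Asks the C API for the most compact device layout for `host_shape`,
// freeing the output shape on failure instead of leaking it.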
StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  XLA_Shape c_output;
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
      manager_, &c_host_shape, &c_output, status.c_status);
  // TODO(skyewm): use a scoped version of XLA_Shape
  ApiConverter::Free(&c_host_shape);
  if (!status.status().ok()) {
    ApiConverter::Free(&c_output);
    return status.status();
  }
  xla::Shape output = ApiConverter::FromC(&c_output);
  ApiConverter::Free(&c_output);
  return output;
}

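// Forwards to the C API to ask whether `device_buffer` can be accessed
// without further synchronization; a scoped cleanup frees the temporary
// C struct on every return path.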
bool TpuTransferManager::CanShapedBufferBeAccessedNow(
    stream_executor::StreamExecutor* executor,
    const xla::ShapedBuffer& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);
  auto cleanup = xla::MakeCleanup(
      [&c_device_buffer]() { ApiConverter::Free(&c_device_buffer); });
  return tpu::ExecutorApiFn()
      ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
          manager_, tpu_executor->se_executor(), &c_device_buffer);
}

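// Same query for a single device memory region rather than a shaped buffer.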
bool TpuTransferManager::CanBufferBeAccessedNow(
    se::StreamExecutor* executor,
    const se::DeviceMemoryBase& device_buffer) const {
  auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
  SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
                                      device_buffer.size(),
                                      device_buffer.payload()};
  return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
      manager_, tpu_executor->se_executor(), &c_device_buffer);
}

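// Writes the index table for one tuple level: `elements` holds the device
// addresses of the tuple's elements and `region` is the device memory that
// receives the table.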
Status TpuTransferManager::WriteSingleTupleIndexTable(
    stream_executor::Stream* stream,
    absl::Span<const stream_executor::DeviceMemoryBase> elements,
    const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
  CHECK_GT(elements.size(), 0);
  SE_DeviceMemoryBase* elements_bases =
      new SE_DeviceMemoryBase[elements.size()];
  for (int i = 0; i < elements.size(); i++) {
    elements_bases[i] =
        SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
                            elements[i].size(), elements[i].payload()};
  }
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
                                  region->payload()};
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      elements_bases, elements.size(), &c_shape, &region_base, status.c_status);

  delete[] elements_bases;
  ApiConverter::Free(&c_shape);
  return status.status();
}

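// Linearizes `literal` into raw buffers via the C API. The returned buffers
// are copied into NoncopyableBuffers and then released with
// TpuTransferManager_FreeBuffersFn.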
Status TpuTransferManager::LinearizeToBuffers(
    const xla::LiteralSlice& literal,
    std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  char** buffers_array;
  int64_t* buffers_size;
  int64_t buffers_array_size;
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
      manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
      status.c_status);

  for (int64_t i = 0; i < buffers_array_size; ++i) {
    tpu::NoncopyableBuffer buf(buffers_size[i]);
    memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
           buffers_size[i]);
    buffers->push_back(std::move(buf));
  }

  tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
      buffers_array, buffers_size, buffers_array_size);

  ApiConverter::Free(&c_literal);
  return status.status();
}

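// Refreshes `device_shape` with the dynamic dimension sizes read back from
// the device, replacing it with the updated shape on success.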
Status TpuTransferManager::ReadDynamicShapes(se::Stream* stream,
                                             xla::ShapedBuffer* device_buffer,
                                             xla::Shape* device_shape) {
  XLA_ShapedBuffer c_device_buffer;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(*device_buffer, &c_device_buffer);
  ApiConverter::ToC(*device_shape, &c_device_shape);
  XLA_Shape c_updated_shape;
  StatusHelper status;
  ExecutorApiFn()->TpuTransferManager_ReadDynamicShapesFn(
      TpuPlatform::GetRegisteredPlatform()->LookupStream(
          stream->implementation()),
      &c_device_buffer, c_device_shape, &c_updated_shape, status.c_status);
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_device_shape);
  if (!status.ok()) {
    return status.status();
  }
  *device_shape = ApiConverter::FromC(&c_updated_shape);
  ApiConverter::Free(&c_updated_shape);
  return Status::OK();
}

}  // namespace tpu
}  // namespace tensorflow