• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/xla_data.pb.h"
22 #include "tensorflow/core/tpu/tpu_api.h"
23 #include "tensorflow/stream_executor/device_memory.h"
24 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
25 #include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
26 #include "tensorflow/stream_executor/tpu/proto_helper.h"
27 #include "tensorflow/stream_executor/tpu/status_helper.h"
28 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
29 #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
30 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
31 #include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
32 
33 namespace tensorflow {
34 namespace tpu {
35 
// Local aliases so the wrappers below can use the short names `Status` and
// `StatusOr<T>` for the stream_executor port types.
using Status = stream_executor::port::Status;
template <typename T>
using StatusOr = stream_executor::port::StatusOr<T>;
39 
// Acquires a new C-API transfer-manager handle from the TPU executor API
// function table. The handle is owned by this object and released in the
// destructor.
TpuTransferManager::TpuTransferManager() {
  manager_ = tpu::ExecutorApiFn()->TpuTransferManager_NewFn();
}
43 
// Releases the C-API transfer-manager handle acquired in the constructor.
TpuTransferManager::~TpuTransferManager() {
  tpu::ExecutorApiFn()->TpuTransferManager_FreeFn(manager_);
}
47 
// Identifies this transfer manager as belonging to the TPU platform.
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
  return GetTpuPlatformId();
}
51 
// Maps a host-side XLA shape to the corresponding on-device shape by
// round-tripping through the C API: convert to C, query, convert back.
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
    const xla::Shape& host_shape) const {
  XLA_Shape host_shape_c;
  XLA_Shape device_shape_c;
  ApiConverter::ToC(host_shape, &host_shape_c);
  tpu::ExecutorApiFn()->TpuTransferManager_HostShapeToDeviceShapeFn(
      manager_, &host_shape_c, &device_shape_c);
  xla::Shape result = ApiConverter::FromC(&device_shape_c);
  // Release the C-side shape storage for both conversions.
  ApiConverter::Free(&device_shape_c);
  ApiConverter::Free(&host_shape_c);
  return result;
}
66 
// Asynchronously copies `literal` into `device_buffer` on the device behind
// `stream`. Returns the status reported by the C API launch call.
Status TpuTransferManager::TransferLiteralToDeviceAsync(
    stream_executor::Stream* stream, const xla::LiteralSlice& literal,
    const xla::ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusHelper status;

  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  XLA_ShapedBuffer c_device_buffer;
  ApiConverter::ToC(device_buffer, &c_device_buffer);

  // Translate the stream_executor stream into the SE_Stream registered with
  // the TPU platform; the map is keyed by the stream's implementation object.
  tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToDeviceAsyncFn(
      manager_,
      TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
          stream->implementation()),
      &c_literal, &c_device_buffer, status.c_status);
  // NOTE(review): the C structs are freed immediately even though the
  // transfer is async — presumably the C API copies what it needs before
  // returning; confirm against the C API contract.
  ApiConverter::Free(&c_device_buffer);
  ApiConverter::Free(&c_literal);
  return status.status();
}
88 
TransferLiteralToInfeed(stream_executor::StreamExecutor * executor,const xla::LiteralSlice & literal)89 Status TpuTransferManager::TransferLiteralToInfeed(
90     stream_executor::StreamExecutor* executor,
91     const xla::LiteralSlice& literal) {
92   StatusHelper status;
93   XLA_Literal c_literal;
94   ApiConverter::ToC(literal, &c_literal);
95   auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
96 
97   tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralToInfeedFn(
98       manager_, tpu_executor->se_executor(), &c_literal, status.c_status);
99 
100   ApiConverter::Free(&c_literal);
101 
102   return status.status();
103 }
104 
TransferBuffersToInfeed(se::StreamExecutor * executor,const std::deque<tensorflow::tpu::NoncopyableBuffer> & buffers)105 Status TpuTransferManager::TransferBuffersToInfeed(
106     se::StreamExecutor* executor,
107     const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
108   StatusHelper status;
109   auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
110 
111   std::vector<int64_t> buffers_size;
112   std::vector<uint32_t*> buffers_array;
113 
114   buffers_size.reserve(buffers.size());
115   buffers_array.reserve(buffers.size());
116 
117   for (int64_t i = 0; i < buffers.size(); ++i) {
118     absl::Span<const uint32_t> span = buffers[i].const_data<uint32_t>();
119     buffers_array.push_back(const_cast<uint32_t*>(span.data()));
120     buffers_size.push_back(span.size());
121   }
122 
123   tpu::ExecutorApiFn()->TpuTransferManager_TransferBuffersToInfeedFn(
124       manager_, tpu_executor->se_executor(), buffers_array.data(),
125       buffers_size.data(), buffers_size.size(), status.c_status);
126   return status.status();
127 }
128 
TransferLiteralFromOutfeed(stream_executor::StreamExecutor * executor,xla::MutableBorrowingLiteral literal)129 Status TpuTransferManager::TransferLiteralFromOutfeed(
130     stream_executor::StreamExecutor* executor,
131     xla::MutableBorrowingLiteral literal) {
132   StatusHelper status;
133   XLA_Shape c_shape;
134   XLA_Literal c_literal;
135   auto* tpu_executor = static_cast<TpuExecutor*>(executor->implementation());
136 
137   ApiConverter::ToC(literal.shape(), &c_shape);
138   ApiConverter::ToC(literal, &c_literal);
139 
140   tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromOutfeedFn(
141       manager_, tpu_executor->se_executor(), &c_shape, &c_literal,
142       status.c_status);
143 
144   ApiConverter::Free(&c_shape);
145   ApiConverter::Free(&c_literal);
146 
147   return status.status();
148 }
149 
ResetDevices(absl::Span<stream_executor::StreamExecutor * const> executor)150 Status TpuTransferManager::ResetDevices(
151     absl::Span<stream_executor::StreamExecutor* const> executor) {
152   StatusHelper status;
153   std::vector<SE_StreamExecutor*> se;
154   se.reserve(executor.size());
155   for (int64_t i = 0; i < executor.size(); ++i) {
156     se.push_back(static_cast<TpuExecutor*>(executor[i]->implementation())
157                      ->se_executor());
158   }
159 
160   tpu::ExecutorApiFn()->TpuTransferManager_ResetDevicesFn(
161       manager_, se.data(), executor.size(), status.c_status);
162   return status.status();
163 }
164 
// Heap-allocated completion state shared by the async device-to-host
// transfers launched in TransferLiteralFromDevice. Each finished transfer
// calls TransferFinished(); when the last one completes, `done` is invoked
// with the aggregated status and the state deletes itself.
struct TransferFromDeviceState {
  // Number of transfers still outstanding; decremented atomically from the
  // C-API completion callbacks.
  std::atomic<int64_t> remaining_transfers;
  TF_Status* overall_status =
      tpu::ExecutorApiFn()->TpuStatus_NewFn();  // OK or the first error
  // User callback invoked exactly once with the final status.
  std::function<void(Status)> done;

  // Records the status of one finished transfer and, if it was the last,
  // fires `done` and destroys this state. Takes ownership of `status`.
  void TransferFinished(TF_Status* status) {
    // Keep only the first error: swap it into overall_status, so the
    // subsequent Free releases whichever status is being discarded.
    if (!tpu::ExecutorApiFn()->TpuStatus_OkFn(status) &&
        tpu::ExecutorApiFn()->TpuStatus_OkFn(overall_status)) {
      std::swap(overall_status, status);
    }
    tpu::ExecutorApiFn()->TpuStatus_FreeFn(status);

    if (--remaining_transfers == 0) {
      done(StatusHelper::FromC(overall_status));
      tpu::ExecutorApiFn()->TpuStatus_FreeFn(overall_status);
      delete this;  // self-owned; no references remain after the callback
    }
  }
};
185 
TransferLiteralFromDeviceTrampoline(void * ctx,TF_Status * status)186 void TransferLiteralFromDeviceTrampoline(void* ctx, TF_Status* status) {
187   reinterpret_cast<TransferFromDeviceState*>(ctx)->TransferFinished(status);
188 }
189 
TransferLiteralFromDevice(stream_executor::Stream * stream,const xla::ShapedBuffer & device_buffer,xla::MutableBorrowingLiteral literal,std::function<void (Status)> done,const TransferMetadata * transfer_metadata)190 void TpuTransferManager::TransferLiteralFromDevice(
191     stream_executor::Stream* stream, const xla::ShapedBuffer& device_buffer,
192     xla::MutableBorrowingLiteral literal, std::function<void(Status)> done,
193     const TransferMetadata* transfer_metadata) {
194   TransferFromDeviceState* state = new TransferFromDeviceState;
195   state->remaining_transfers = 1;
196   state->done = done;
197   XLA_ShapedBuffer c_device_buffer;
198   ApiConverter::ToC(device_buffer, &c_device_buffer);
199   XLA_Literal c_literal;
200   ApiConverter::ToC(literal, &c_literal);
201 
202   tpu::ExecutorApiFn()->TpuTransferManager_TransferLiteralFromDeviceFn(
203       manager_,
204       TpuPlatform::GetRegisteredPlatform()->LookupStream(
205           stream->implementation()),
206       &c_device_buffer, &c_literal, TransferLiteralFromDeviceTrampoline, state);
207   ApiConverter::Free(&c_device_buffer);
208   ApiConverter::Free(&c_literal);
209 }
210 
GetByteSizeRequirement(const xla::Shape & shape) const211 int64 TpuTransferManager::GetByteSizeRequirement(
212     const xla::Shape& shape) const {
213   XLA_Shape c_shape;
214   ApiConverter::ToC(shape, &c_shape);
215 
216   int64 size_in_bytes =
217       tpu::ExecutorApiFn()->TpuTransferManager_GetByteSizeRequirementFn(
218           manager_, &c_shape);
219 
220   ApiConverter::Free(&c_shape);
221   return size_in_bytes;
222 }
223 
// Asks the C API for a compact device layout for `host_shape`. Returns the
// chosen shape, or the C API's error status on failure.
StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
    const xla::Shape& host_shape) const {
  XLA_Shape c_host_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  XLA_Shape c_output;
  StatusHelper status;
  tpu::ExecutorApiFn()->TpuTransferManager_ChooseCompactLayoutForShapeFn(
      manager_, &c_host_shape, &c_output, status.c_status);
  // TODO(skyewm): use a scoped version of XLA_Shape
  ApiConverter::Free(&c_host_shape);
  // On error, c_output is still freed before propagating the status.
  if (!status.status().ok()) {
    ApiConverter::Free(&c_output);
    return status.status();
  }
  xla::Shape output = ApiConverter::FromC(&c_output);
  ApiConverter::Free(&c_output);
  return output;
}
242 
CanShapedBufferBeAccessedNow(stream_executor::StreamExecutor * executor,const xla::ShapedBuffer & device_buffer) const243 bool TpuTransferManager::CanShapedBufferBeAccessedNow(
244     stream_executor::StreamExecutor* executor,
245     const xla::ShapedBuffer& device_buffer) const {
246   auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
247   XLA_ShapedBuffer c_device_buffer;
248   ApiConverter::ToC(device_buffer, &c_device_buffer);
249   auto cleanup = xla::MakeCleanup(
250       [&c_device_buffer]() { ApiConverter::Free(&c_device_buffer); });
251   return tpu::ExecutorApiFn()
252       ->TpuTransferManager_CanShapedBufferBeAccessedNowFn(
253           manager_, tpu_executor->se_executor(), &c_device_buffer);
254 }
255 
CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer) const256 bool TpuTransferManager::CanBufferBeAccessedNow(
257     se::StreamExecutor* executor,
258     const se::DeviceMemoryBase& device_buffer) const {
259   auto* tpu_executor = down_cast<TpuExecutor*>(executor->implementation());
260   SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
261                                       device_buffer.size(),
262                                       device_buffer.payload()};
263   return tpu::ExecutorApiFn()->TpuTransferManager_CanBufferBeAccessedNowFn(
264       manager_, tpu_executor->se_executor(), &c_device_buffer);
265 }
266 
WriteSingleTupleIndexTable(stream_executor::Stream * stream,absl::Span<const stream_executor::DeviceMemoryBase> elements,const xla::Shape & shape,stream_executor::DeviceMemoryBase * region)267 Status TpuTransferManager::WriteSingleTupleIndexTable(
268     stream_executor::Stream* stream,
269     absl::Span<const stream_executor::DeviceMemoryBase> elements,
270     const xla::Shape& shape, stream_executor::DeviceMemoryBase* region) {
271   CHECK_GT(elements.size(), 0);
272   SE_DeviceMemoryBase* elements_bases =
273       new SE_DeviceMemoryBase[elements.size()];
274   for (int i = 0; i < elements.size(); i++) {
275     elements_bases[i] =
276         SE_DeviceMemoryBase{const_cast<void*>(elements[i].opaque()),
277                             elements[i].size(), elements[i].payload()};
278   }
279   XLA_Shape c_shape;
280   ApiConverter::ToC(shape, &c_shape);
281   SE_DeviceMemoryBase region_base{region->opaque(), region->size(),
282                                   region->payload()};
283   StatusHelper status;
284 
285   tpu::ExecutorApiFn()->TpuTransferManager_WriteSingleTupleIndexTableFn(
286       manager_,
287       TpuPlatform::GetRegisteredPlatform()->stream_map()->at(
288           stream->implementation()),
289       elements_bases, elements.size(), &c_shape, &region_base, status.c_status);
290 
291   delete[] elements_bases;
292   ApiConverter::Free(&c_shape);
293   return status.status();
294 }
295 
// Linearizes `literal` into the device wire format, appending the resulting
// chunks to `buffers` as owned NoncopyableBuffers.
Status TpuTransferManager::LinearizeToBuffers(
    const xla::LiteralSlice& literal,
    std::deque<tensorflow::tpu::NoncopyableBuffer>* buffers) {
  XLA_Literal c_literal;
  ApiConverter::ToC(literal, &c_literal);

  // Out-params filled by the C API: an array of `buffers_array_size` chunks,
  // with per-chunk byte sizes in `buffers_size`.
  char** buffers_array;
  int64_t* buffers_size;
  int64_t buffers_array_size;
  StatusHelper status;

  tpu::ExecutorApiFn()->TpuTransferManager_LinearizeToBuffersFn(
      manager_, &c_literal, &buffers_array, &buffers_size, &buffers_array_size,
      status.c_status);

  // NOTE(review): the out-params are consumed even when status is an error —
  // presumably the C API always initializes them; confirm against the C API
  // contract.
  for (int64_t i = 0; i < buffers_array_size; ++i) {
    // Copy each C-owned chunk into a buffer owned by the caller's deque.
    tpu::NoncopyableBuffer buf(buffers_size[i]);
    memcpy(buf.mutable_data<uint8_t>().data(), buffers_array[i],
           buffers_size[i]);
    buffers->push_back(std::move(buf));
  }

  // Release the C-owned chunk array now that everything has been copied.
  tpu::ExecutorApiFn()->TpuTransferManager_FreeBuffersFn(
      buffers_array, buffers_size, buffers_array_size);

  ApiConverter::Free(&c_literal);
  return status.status();
}
324 
325 }  // namespace tpu
326 }  // namespace tensorflow
327