/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

/* static */ tensorflow::mutex
    TransferManager::platform_transfer_manager_mutex_(
        tensorflow::LINKER_INITIALIZED);

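// Returns the lazily-constructed map from platform id to the TransferManager
// state (creation function plus cached instance) registered for that platform.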
/* static */ absl::flat_hash_map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r =
      new absl::flat_hash_map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

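// Synchronous wrapper around the asynchronous, callback-based
// TransferLiteralFromDevice overload: the transfer is enqueued on a substream
// and a Notification blocks the host until the done-callback fires.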
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

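// As above, but writes the result into a caller-provided
// MutableBorrowingLiteral instead of allocating a new Literal.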
Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

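// Checks that the on-device representation of the literal is an array that
// fits in `dest`, then wraps `dest` in a single-buffer ShapedBuffer and
// forwards to the literal transfer path.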
Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

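// Asynchronous counterpart of TransferArrayFromDevice above: verifies that the
// shape has the same representation on device (comparing layouts only by their
// minor-to-major order) and that `source` is large enough, then wraps `source`
// in a single-buffer ShapedBuffer and forwards to the callback-based
// TransferLiteralFromDevice.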
void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

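// For every dynamic array subshape in `device_shape`, the device buffer holds
// the statically padded array data followed by one int32 per dimension with
// the actual dynamic sizes. Reads that metadata back to the host, patches the
// dimensions of `device_shape` in place, and then marks the shape static.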
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return Status::OK();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return Status::OK();
        }

        // Read the dynamic shape metadata from the device stream.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument(
              "Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64_t i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
        }
        return Status::OK();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return Status::OK();
}

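// Registers `creation_function` for `platform_id`. At most one registration
// per platform id is allowed; backends typically register themselves from a
// static initializer.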
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

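// Returns the TransferManager registered for `platform`, constructing it on
// first use. Typical call site (a minimal sketch; `stream` and `buffer` are
// assumed to be a valid se::Stream and ShapedBuffer for that platform):
//
//   TF_ASSIGN_OR_RETURN(TransferManager* const tm,
//                       TransferManager::GetForPlatform(platform));
//   TF_ASSIGN_OR_RETURN(Literal result,
//                       tm->TransferLiteralFromDevice(
//                           stream, buffer, /*transfer_metadata=*/nullptr));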
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed.
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

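// Synchronous wrapper around WriteTupleIndexTablesAsync.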
Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

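// Walks every tuple subshape of `device_buffer` and, for each non-empty tuple,
// writes an index table (the device addresses of its elements) into the buffer
// backing that tuple.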
Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64_t i = 0;
               i < ShapeUtil::TupleElementCount(device_subshape); ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return Status::OK();
      });
}

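// Writes the index table for the root tuple of `device_buffer` only; nested
// tuples are left untouched.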
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape());
       ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

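// As above, but operates on a ShapeTree of (possibly owning) device memory
// instead of a ShapedBuffer.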
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream,
    const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

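// Enqueues a raw device-to-host copy of `size` bytes on `stream` after
// checking that the source allocation is large enough.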
Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

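// Enqueues a raw host-to-device copy of `size` bytes on `stream` after
// checking that the destination allocation is large enough.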
Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64_t size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

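// Allocates a ScopedShapedBuffer for `on_host_shape`: the on-device shape is
// derived via `shape_representation_fn` (or HostShapeToDeviceShape when it is
// null), and one device allocation is made per leaf array and per tuple index
// table. The index tables are only allocated here, not populated; a caller
// that needs them can follow up with WriteTupleIndexTables, roughly (a sketch,
// assuming a tuple-shaped `host_shape`, an `allocator`, and a `stream`):
//
//   TF_ASSIGN_OR_RETURN(
//       ScopedShapedBuffer buffer,
//       tm->AllocateScopedShapedBuffer(host_shape, allocator,
//                                      /*device_ordinal=*/0,
//                                      /*shape_representation_fn=*/nullptr));
//   TF_RETURN_IF_ERROR(tm->WriteTupleIndexTables(stream, buffer));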
StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = (shape_representation_fn == nullptr)
                              ? HostShapeToDeviceShape(on_host_shape)
                              : shape_representation_fn(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            subshape.layout().memory_space()));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

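// Base implementation: returns `host_shape` with the default layout applied.
// Backends may override this to pick a layout that is more efficient on
// device.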
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

}  // namespace xla