/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {

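// Guards the lazily constructed platform-id -> TransferManager registry
// returned by GetPlatformTransferManagers() below.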
/* static */ tensorflow::mutex
    TransferManager::platform_transfer_manager_mutex_(
        tensorflow::LINKER_INITIALIZED);

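// Returns the singleton registry mapping platform ids to per-platform
// TransferManager state. The map is created on first use and never freed.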
/* static */ absl::flat_hash_map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r =
      new absl::flat_hash_map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

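// Synchronous transfer of `device_buffer` into a newly allocated Literal.
// Implemented by running the asynchronous overload on a substream and
// blocking on a Notification; the substream avoids deadlock if this is
// called from a host callback enqueued on `stream`.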
StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

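// As above, but writes the result into a caller-provided
// MutableBorrowingLiteral instead of allocating a new Literal.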
Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}

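// Synchronously reads a single array buffer from the device into a newly
// allocated Literal of the given host `shape`.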
StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, &literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

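// Verifies that the on-device representation of `literal`'s shape is an
// array and that `dest` is large enough, then aliases `dest` as a
// single-buffer ShapedBuffer and transfers the literal into it.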
Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(on_device_shape,
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}

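// Asynchronous counterpart of TransferArrayFromDevice above: wraps `source`
// in a single-buffer ShapedBuffer and starts the literal transfer, invoking
// `done` with the resulting status. Fails up front if the shape has a
// different on-device representation or if `source` is too small.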
void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!Shape::Equal().MinorToMajorOnlyInLayout()(HostShapeToDeviceShape(shape),
                                                 shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(shape, stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

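// Reads back the dynamic dimension sizes of `device_buffer`. For every
// dynamic array subshape, an int32 metadata block is stored after the
// statically sized payload; this reads that metadata and patches the
// corresponding dimensions of `device_shape`, then marks the shape static.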
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                          ShapedBuffer* device_buffer,
                                          Shape* device_shape) {
  DCHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  TF_ASSIGN_OR_RETURN(auto compiler,
                      Compiler::GetForPlatform(stream->parent()->platform()));
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return Status::OK();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return Status::OK();
        }

        // Read the dynamic shape metadata from the device stream.
        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
        auto metadata_buffer =
            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
        TF_ASSIGN_OR_RETURN(
            auto metadata,
            TransferArrayFromDevice(
                stream,
                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
                metadata_buffer));

        // Update shape size from metadata.
        for (int64_t i = 0; i < metadata.element_count(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
        }
        return Status::OK();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return Status::OK();
}

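// Registers `creation_function` as the factory for the transfer manager of
// `platform_id`. CHECK-fails if a factory is already registered for that
// platform.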
/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

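// Returns the transfer manager registered for `platform`, constructing it on
// first use. Returns NotFound if no factory was registered for the platform.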
/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}

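// Synchronous wrapper: enqueues all tuple index table writes for
// `device_buffer` and blocks until the stream has finished them.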
Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

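// For every non-empty tuple subshape of `device_buffer`, gathers the device
// addresses of its element buffers and enqueues a write of the corresponding
// device-side pointer table.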
Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple() &&
            ShapeUtil::TupleElementCount(device_subshape) > 0) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64_t i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return Status::OK();
      });
}

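// Writes the pointer table for the root tuple of `device_buffer` only; nested
// tuples are not touched. No-op for empty tuples.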
Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  if (ShapeUtil::TupleElementCount(device_buffer.on_device_shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree) {
  TF_RET_CHECK(buffer_tree.shape().IsTuple());
  if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) {
    return Status::OK();
  }
  se::DeviceMemoryBase device_memory =
      buffer_tree.element({}).AsDeviceMemoryBase();
  TF_RET_CHECK(GetByteSizeRequirement(buffer_tree.shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape());
       ++i) {
    elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase());
  }
  return WriteSingleTupleIndexTable(stream, elements, buffer_tree.shape(),
                                    &device_memory);
}

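// Enqueues a device-to-host copy of `size` bytes from `source` to
// `destination` after checking that the source allocation is large enough.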
Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64_t size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64_t size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

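// Allocates device memory for every buffer of `on_host_shape`'s on-device
// representation (leaf arrays and tuple index tables) and returns it as a
// ScopedShapedBuffer that owns the allocations. The tuple index tables are
// only allocated here, not written; see WriteTupleIndexTables.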
StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
    int device_ordinal, DeviceShapeRepresentationFn shape_representation_fn) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  Shape on_device_shape = (shape_representation_fn == nullptr)
                              ? HostShapeToDeviceShape(on_host_shape)
                              : shape_representation_fn(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(std::move(on_device_shape), allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape =
        ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape),
                                            /*retry_on_failure=*/true,
                                            subshape.layout().memory_space()));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Release();
  }

  return std::move(shaped_buffer);
}

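// Default layout selection: return `host_shape` with the default layout
// applied. Backends may override this to pick a more compact on-device
// layout.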
StatusOr<Shape> TransferManager::ChooseCompactLayoutForShape(
    const Shape& host_shape) const {
  return LayoutUtil::GetWithDefaultLayout(host_shape);
}

}  // namespace xla