• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/transfer_manager.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/str_cat.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/lib/gtl/cleanup.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/notification.h"
31 
32 using absl::StrCat;
33 
34 namespace xla {
35 /* static */ tensorflow::mutex
36     TransferManager::platform_transfer_manager_mutex_(
37         tensorflow::LINKER_INITIALIZED);
38 
39 /* static */ std::map<se::Platform::Id, TransferManager::State>*
GetPlatformTransferManagers()40 TransferManager::GetPlatformTransferManagers() {
41   static auto* r = new std::map<se::Platform::Id, TransferManager::State>;
42   return r;
43 }
44 
~TransferMetadata()45 TransferManager::TransferMetadata::~TransferMetadata() {}
46 
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const TransferMetadata * transfer_metadata)47 StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
48     se::Stream* stream, const ShapedBuffer& device_buffer,
49     const TransferMetadata* transfer_metadata) {
50   StatusOr<Literal> ret;
51 
52   se::Stream* substream = stream->GetOrCreateSubStream();
53   substream->ThenWaitFor(stream);
54   auto cleanup = tensorflow::gtl::MakeCleanup(
55       [&]() { stream->ReturnSubStream(substream); });
56 
57   tensorflow::Notification n;
58   Status s;
59   Literal literal(device_buffer.on_host_shape());
60   TransferLiteralFromDevice(
61       substream, device_buffer, literal,
62       [&](Status status) {
63         s = status;
64         n.Notify();
65       },
66       transfer_metadata);
67   n.WaitForNotification();
68   if (!s.ok()) {
69     return s;
70   }
71   return std::move(literal);
72 }
73 
TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal,const TransferMetadata * transfer_metadata)74 Status TransferManager::TransferLiteralFromDevice(
75     se::Stream* stream, const ShapedBuffer& device_buffer,
76     const MutableBorrowingLiteral& literal,
77     const TransferMetadata* transfer_metadata) {
78   se::Stream* substream = stream->GetOrCreateSubStream();
79   auto cleanup = tensorflow::gtl::MakeCleanup(
80       [&]() { stream->ReturnSubStream(substream); });
81 
82   Status ret;
83   tensorflow::Notification n;
84   TransferLiteralFromDevice(
85       substream, device_buffer, literal,
86       [&](Status status) {
87         ret = status;
88         n.Notify();
89       },
90       transfer_metadata);
91   n.WaitForNotification();
92   return ret;
93 }
94 
TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer,const TransferMetadata * transfer_metadata)95 Status TransferManager::TransferLiteralToDevice(
96     se::Stream* stream, const LiteralSlice& literal,
97     const ShapedBuffer& device_buffer,
98     const TransferMetadata* transfer_metadata) {
99   // Implement the synchronous version by waiting on the asynchronous version.
100   // Use a substream so that if we are called from a HostCallback we don't
101   // deadlock.
102   se::Stream* substream = stream->GetOrCreateSubStream();
103   substream->ThenWaitFor(stream);
104   auto cleanup = tensorflow::gtl::MakeCleanup(
105       [&]() { stream->ReturnSubStream(substream); });
106   TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
107       substream, literal, device_buffer, transfer_metadata));
108   return substream->BlockHostUntilDone();
109 }
110 
TransferArrayFromDevice(se::Stream * stream,const Shape & shape,const se::DeviceMemoryBase & source,const TransferMetadata * transfer_metadata)111 StatusOr<Literal> TransferManager::TransferArrayFromDevice(
112     se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
113     const TransferMetadata* transfer_metadata) {
114   StatusOr<Literal> ret;
115   // Implement the synchronous version by waiting on the asynchronous version.
116   // Use a substream so that if we are called from a HostCallback we don't
117   // deadlock.
118   se::Stream* substream = stream->GetOrCreateSubStream();
119   auto cleanup = tensorflow::gtl::MakeCleanup(
120       [&]() { stream->ReturnSubStream(substream); });
121 
122   tensorflow::Notification n;
123   Literal literal(shape);
124   Status s;
125   TransferArrayFromDevice(
126       substream, shape, source, literal,
127       [&](Status status) {
128         s = status;
129         n.Notify();
130       },
131       transfer_metadata);
132   n.WaitForNotification();
133   if (!s.ok()) {
134     return s;
135   }
136   return std::move(literal);
137 }
138 
TransferArrayToDevice(se::Stream * stream,const LiteralSlice & literal,const se::DeviceMemoryBase & dest,const TransferMetadata * transfer_metadata)139 Status TransferManager::TransferArrayToDevice(
140     se::Stream* stream, const LiteralSlice& literal,
141     const se::DeviceMemoryBase& dest,
142     const TransferMetadata* transfer_metadata) {
143   // Implement the synchronous version by waiting on the asynchronous version.
144   // Use a substream so that if we are called from a HostCallback we don't
145   // deadlock.
146   se::Stream* substream = stream->GetOrCreateSubStream();
147   auto cleanup = tensorflow::gtl::MakeCleanup(
148       [&]() { stream->ReturnSubStream(substream); });
149   TF_RETURN_IF_ERROR(
150       TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
151   return substream->BlockHostUntilDone();
152 }
153 
TransferArrayToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const se::DeviceMemoryBase & dest,const TransferMetadata * transfer_metadata)154 Status TransferManager::TransferArrayToDeviceAsync(
155     se::Stream* stream, const LiteralSlice& literal,
156     const se::DeviceMemoryBase& dest,
157     const TransferMetadata* transfer_metadata) {
158   const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
159   TF_RET_CHECK(on_device_shape.IsArray())
160       << "On-device representation of "
161       << ShapeUtil::HumanString(literal.shape())
162       << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
163   if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
164     return FailedPrecondition(
165         "Allocation on device not large enough for array: "
166         "%d < %d",
167         dest.size(), GetByteSizeRequirement(on_device_shape));
168   }
169   ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
170                              stream->parent()->platform(),
171                              stream->parent()->device_ordinal());
172   shaped_buffer.set_buffer(dest, /*index=*/{});
173   return TransferLiteralToDevice(stream, literal, shaped_buffer,
174                                  transfer_metadata);
175 }
176 
TransferArrayFromDevice(se::Stream * stream,const Shape & shape,const se::DeviceMemoryBase & source,const MutableBorrowingLiteral & literal,std::function<void (Status)> done,const TransferMetadata * transfer_metadata)177 void TransferManager::TransferArrayFromDevice(
178     se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
179     const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
180     const TransferMetadata* transfer_metadata) {
181   if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) {
182     auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
183                         " has a differently shaped representation on-device: ",
184                         ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
185     return done(FailedPrecondition("%s", error));
186   }
187   if (source.size() < GetByteSizeRequirement(shape)) {
188     return done(
189         FailedPrecondition("Allocation on device not large enough for array: "
190                            "%d < %d",
191                            source.size(), GetByteSizeRequirement(shape)));
192   }
193   ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
194                              stream->parent()->platform(),
195                              stream->parent()->device_ordinal());
196   shaped_buffer.set_buffer(source, /*index=*/{});
197   return TransferLiteralFromDevice(stream, shaped_buffer, literal,
198                                    std::move(done), transfer_metadata);
199 }
200 
RegisterTransferManager(se::Platform::Id platform_id,TransferManagerCreationFunction creation_function)201 /* static */ void TransferManager::RegisterTransferManager(
202     se::Platform::Id platform_id,
203     TransferManagerCreationFunction creation_function) {
204   tensorflow::mutex_lock lock(
205       TransferManager::platform_transfer_manager_mutex_);
206   auto* managers = GetPlatformTransferManagers();
207   CHECK(managers->find(platform_id) == managers->end());
208   (*managers)[platform_id].creation_function = creation_function;
209 }
210 
GetForPlatform(const se::Platform * platform)211 /* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
212     const se::Platform* platform) {
213   tensorflow::mutex_lock lock(
214       TransferManager::platform_transfer_manager_mutex_);
215   auto* managers = GetPlatformTransferManagers();
216 
217   auto it = managers->find(platform->id());
218   if (it == managers->end()) {
219     return NotFound(
220         "could not find registered transfer manager for platform %s -- check "
221         "target linkage",
222         platform->Name());
223   }
224 
225   if (it->second.manager == nullptr) {
226     // Lazily create the transfer manager the first time it is needed
227     it->second.manager = (*it->second.creation_function)();
228   }
229 
230   return it->second.manager.get();
231 }
232 
WriteTupleIndexTables(se::Stream * stream,const ShapedBuffer & device_buffer)233 Status TransferManager::WriteTupleIndexTables(
234     se::Stream* stream, const ShapedBuffer& device_buffer) {
235   TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
236   return stream->BlockHostUntilDone();
237 }
238 
WriteTupleIndexTablesAsync(se::Stream * stream,const ShapedBuffer & device_buffer)239 Status TransferManager::WriteTupleIndexTablesAsync(
240     se::Stream* stream, const ShapedBuffer& device_buffer) {
241   VLOG(2) << "Writing tuple index tables for " << device_buffer;
242 
243   return ShapeUtil::ForEachSubshapeWithStatus(
244       device_buffer.on_device_shape(),
245       [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
246         if (device_subshape.IsTuple()) {
247           se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
248           TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
249                        device_memory.size());
250 
251           std::vector<se::DeviceMemoryBase> elements;
252           ShapeIndex element_index = index;
253           for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
254                ++i) {
255             element_index.push_back(i);
256             elements.push_back(device_buffer.buffer(element_index));
257             element_index.pop_back();
258           }
259           return WriteSingleTupleIndexTable(stream, elements, device_subshape,
260                                             &device_memory);
261         }
262 
263         return Status::OK();
264       });
265 }
266 
WriteRootTupleIndexTable(se::Stream * stream,const ShapedBuffer & device_buffer)267 Status TransferManager::WriteRootTupleIndexTable(
268     se::Stream* stream, const ShapedBuffer& device_buffer) {
269   TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
270   se::DeviceMemoryBase device_memory = device_buffer.buffer({});
271   TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
272                device_memory.size());
273 
274   std::vector<se::DeviceMemoryBase> elements;
275   for (int64 i = 0;
276        i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
277     elements.push_back(device_buffer.buffer({i}));
278   }
279   return WriteSingleTupleIndexTable(
280       stream, elements, device_buffer.on_device_shape(), &device_memory);
281 }
282 
TransferBufferFromDevice(se::Stream * stream,const se::DeviceMemoryBase & source,int64 size,void * destination)283 Status TransferManager::TransferBufferFromDevice(
284     se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
285     void* destination) {
286   if (source.size() < size) {
287     return FailedPrecondition(
288         "Source allocation on device not large enough for data tranfer: "
289         "%d < %d",
290         source.size(), size);
291   }
292   stream->ThenMemcpy(destination, source, size);
293   return Status::OK();
294 }
295 
TransferBufferToDevice(se::Stream * stream,int64 size,const void * source,se::DeviceMemoryBase * destination)296 Status TransferManager::TransferBufferToDevice(
297     se::Stream* stream, int64 size, const void* source,
298     se::DeviceMemoryBase* destination) {
299   if (destination->size() < size) {
300     return FailedPrecondition(
301         "Destination allocation on device not large enough for data tranfer: "
302         "%d < %d",
303         destination->size(), size);
304   }
305   stream->ThenMemcpy(destination, source, size);
306   return Status::OK();
307 }
308 
AllocateScopedShapedBuffer(const Shape & on_host_shape,DeviceMemoryAllocator * allocator,int device_ordinal)309 StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
310     const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
311     int device_ordinal) {
312   if (!LayoutUtil::HasLayout(on_host_shape)) {
313     return InvalidArgument("Shape must have a layout: %s",
314                            ShapeUtil::HumanStringWithLayout(on_host_shape));
315   }
316   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
317   const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
318   TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
319 
320   ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator,
321                                    device_ordinal);
322 
323   // Allocate an appropriate sized buffer for each element in the shape
324   // including the tuple pointer arrays.
325   for (auto& pair : shaped_buffer.buffers()) {
326     const ShapeIndex& index = pair.first;
327     se::DeviceMemoryBase& memory_base = pair.second;
328     const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
329     TF_ASSIGN_OR_RETURN(auto memory,
330                         allocator->Allocate(shaped_buffer.device_ordinal(),
331                                             GetByteSizeRequirement(subshape)));
332     // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
333     memory_base = memory.Forget();
334   }
335 
336   return std::move(shaped_buffer);
337 }
338 
339 }  // namespace xla
340