/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transfer_manager.h"

#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"

using absl::StrCat;

namespace xla {
/* static */ tensorflow::mutex
    TransferManager::platform_transfer_manager_mutex_(
        tensorflow::LINKER_INITIALIZED);

/* static */ std::map<se::Platform::Id, TransferManager::State>*
TransferManager::GetPlatformTransferManagers() {
  static auto* r = new std::map<se::Platform::Id, TransferManager::State>;
  return r;
}

TransferManager::TransferMetadata::~TransferMetadata() {}

StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;

  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Status s;
  Literal literal(device_buffer.on_host_shape());
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}
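
// Example (sketch, not part of the library): a caller holding a `stream` and a
// `ShapedBuffer` named `device_buffer` -- names assumed here, not defined in
// this file -- could read the buffer back synchronously like this:
//
//   TF_ASSIGN_OR_RETURN(
//       Literal result,
//       transfer_manager->TransferLiteralFromDevice(
//           stream, device_buffer, /*transfer_metadata=*/nullptr));
//
// The overload above blocks the host until the copy finishes; the
// callback-based overload below is the non-blocking variant it wraps.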

Status TransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    const MutableBorrowingLiteral& literal,
    const TransferMetadata* transfer_metadata) {
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  Status ret;
  tensorflow::Notification n;
  TransferLiteralFromDevice(
      substream, device_buffer, literal,
      [&](Status status) {
        ret = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  return ret;
}

Status TransferManager::TransferLiteralToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  substream->ThenWaitFor(stream);
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(TransferLiteralToDeviceAsync(
      substream, literal, device_buffer, transfer_metadata));
  return substream->BlockHostUntilDone();
}
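
// Example (sketch, not part of the library): given a host `Literal` named
// `literal` and a destination `ShapedBuffer` named `device_buffer` -- both
// assumed to be set up by the caller -- the synchronous upload above could be
// invoked as:
//
//   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
//       stream, literal, device_buffer, /*transfer_metadata=*/nullptr));
//
// The copy is enqueued on a substream and only that substream is blocked on,
// which is what makes the call safe from a host callback running on `stream`.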

StatusOr<Literal> TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const TransferMetadata* transfer_metadata) {
  StatusOr<Literal> ret;
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });

  tensorflow::Notification n;
  Literal literal(shape);
  Status s;
  TransferArrayFromDevice(
      substream, shape, source, literal,
      [&](Status status) {
        s = status;
        n.Notify();
      },
      transfer_metadata);
  n.WaitForNotification();
  if (!s.ok()) {
    return s;
  }
  return std::move(literal);
}

Status TransferManager::TransferArrayToDevice(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  // Implement the synchronous version by waiting on the asynchronous version.
  // Use a substream so that if we are called from a HostCallback we don't
  // deadlock.
  se::Stream* substream = stream->GetOrCreateSubStream();
  auto cleanup = tensorflow::gtl::MakeCleanup(
      [&]() { stream->ReturnSubStream(substream); });
  TF_RETURN_IF_ERROR(
      TransferArrayToDeviceAsync(substream, literal, dest, transfer_metadata));
  return substream->BlockHostUntilDone();
}

Status TransferManager::TransferArrayToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const se::DeviceMemoryBase& dest,
    const TransferMetadata* transfer_metadata) {
  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
  TF_RET_CHECK(on_device_shape.IsArray())
      << "On-device representation of "
      << ShapeUtil::HumanString(literal.shape())
      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
    return FailedPrecondition(
        "Allocation on device not large enough for array: "
        "%d < %d",
        dest.size(), GetByteSizeRequirement(on_device_shape));
  }
  ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
                             stream->parent()->platform(),
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(dest, /*index=*/{});
  return TransferLiteralToDevice(stream, literal, shaped_buffer,
                                 transfer_metadata);
}
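
// Example (sketch, not part of the library): to upload a rank-1 f32 array into
// a raw device allocation `dest` (a se::DeviceMemoryBase assumed to be large
// enough), a caller might write:
//
//   Literal host = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
//       stream, host, dest, /*transfer_metadata=*/nullptr));
//
// The function wraps `dest` in a single-buffer ShapedBuffer so the regular
// literal-transfer path above can be reused.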

void TransferManager::TransferArrayFromDevice(
    se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
    const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
    const TransferMetadata* transfer_metadata) {
  if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) {
    auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
                        " has a differently shaped representation on-device: ",
                        ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
    return done(FailedPrecondition("%s", error));
  }
  if (source.size() < GetByteSizeRequirement(shape)) {
    return done(
        FailedPrecondition("Allocation on device not large enough for array: "
                           "%d < %d",
                           source.size(), GetByteSizeRequirement(shape)));
  }
  ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
                             stream->parent()->platform(),
                             stream->parent()->device_ordinal());
  shaped_buffer.set_buffer(source, /*index=*/{});
  return TransferLiteralFromDevice(stream, shaped_buffer, literal,
                                   std::move(done), transfer_metadata);
}

/* static */ void TransferManager::RegisterTransferManager(
    se::Platform::Id platform_id,
    TransferManagerCreationFunction creation_function) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();
  CHECK(managers->find(platform_id) == managers->end());
  (*managers)[platform_id].creation_function = creation_function;
}

/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
    const se::Platform* platform) {
  tensorflow::mutex_lock lock(
      TransferManager::platform_transfer_manager_mutex_);
  auto* managers = GetPlatformTransferManagers();

  auto it = managers->find(platform->id());
  if (it == managers->end()) {
    return NotFound(
        "could not find registered transfer manager for platform %s -- check "
        "target linkage",
        platform->Name());
  }

  if (it->second.manager == nullptr) {
    // Lazily create the transfer manager the first time it is needed.
    it->second.manager = (*it->second.creation_function)();
  }

  return it->second.manager.get();
}
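
// Example (sketch, not part of the library): backends typically register a
// factory at static-initialization time and clients later look the manager up
// by platform. Something along these lines, where `MyTransferManager`,
// `my_platform_id`, and `platform` are hypothetical names:
//
//   static bool registered = [] {
//     xla::TransferManager::RegisterTransferManager(
//         my_platform_id,
//         [] { return absl::make_unique<MyTransferManager>(); });
//     return true;
//   }();
//
//   TF_ASSIGN_OR_RETURN(TransferManager * tm,
//                       TransferManager::GetForPlatform(platform));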

Status TransferManager::WriteTupleIndexTables(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
  return stream->BlockHostUntilDone();
}

Status TransferManager::WriteTupleIndexTablesAsync(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  VLOG(2) << "Writing tuple index tables for " << device_buffer;

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        if (device_subshape.IsTuple()) {
          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());

          std::vector<se::DeviceMemoryBase> elements;
          ShapeIndex element_index = index;
          for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
               ++i) {
            element_index.push_back(i);
            elements.push_back(device_buffer.buffer(element_index));
            element_index.pop_back();
          }
          return WriteSingleTupleIndexTable(stream, elements, device_subshape,
                                            &device_memory);
        }

        return Status::OK();
      });
}
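
// Example (sketch, not part of the library): for a device buffer whose shape
// is (f32[4], (f32[2], f32[2])), the walk above writes one index table for the
// outer tuple (pointers to the f32[4] buffer and to the inner tuple buffer)
// and one for the inner tuple (pointers to the two f32[2] buffers). A caller,
// assumed to hold `transfer_manager`, `stream`, and `device_buffer`, needs
// only:
//
//   TF_RETURN_IF_ERROR(
//       transfer_manager->WriteTupleIndexTables(stream, device_buffer));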

Status TransferManager::WriteRootTupleIndexTable(
    se::Stream* stream, const ShapedBuffer& device_buffer) {
  TF_RET_CHECK(device_buffer.on_device_shape().IsTuple());
  se::DeviceMemoryBase device_memory = device_buffer.buffer({});
  TF_RET_CHECK(GetByteSizeRequirement(device_buffer.on_device_shape()) ==
               device_memory.size());

  std::vector<se::DeviceMemoryBase> elements;
  for (int64 i = 0;
       i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) {
    elements.push_back(device_buffer.buffer({i}));
  }
  return WriteSingleTupleIndexTable(
      stream, elements, device_buffer.on_device_shape(), &device_memory);
}

Status TransferManager::TransferBufferFromDevice(
    se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
    void* destination) {
  if (source.size() < size) {
    return FailedPrecondition(
        "Source allocation on device not large enough for data transfer: "
        "%d < %d",
        source.size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

Status TransferManager::TransferBufferToDevice(
    se::Stream* stream, int64 size, const void* source,
    se::DeviceMemoryBase* destination) {
  if (destination->size() < size) {
    return FailedPrecondition(
        "Destination allocation on device not large enough for data transfer: "
        "%d < %d",
        destination->size(), size);
  }
  stream->ThenMemcpy(destination, source, size);
  return Status::OK();
}

StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
    const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
    int device_ordinal) {
  if (!LayoutUtil::HasLayout(on_host_shape)) {
    return InvalidArgument("Shape must have a layout: %s",
                           ShapeUtil::HumanStringWithLayout(on_host_shape));
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
  const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));

  ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator,
                                   device_ordinal);

  // Allocate an appropriately sized buffer for each element in the shape,
  // including the tuple pointer arrays.
  for (auto& pair : shaped_buffer.buffers()) {
    const ShapeIndex& index = pair.first;
    se::DeviceMemoryBase& memory_base = pair.second;
    const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
    TF_ASSIGN_OR_RETURN(auto memory,
                        allocator->Allocate(shaped_buffer.device_ordinal(),
                                            GetByteSizeRequirement(subshape)));
    // Move the allocated buffer into the ScopedShapedBuffer, which owns it.
    memory_base = memory.Forget();
  }

  return std::move(shaped_buffer);
}
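
// Example (sketch, not part of the library): a typical allocate-then-upload
// sequence, assuming the caller already holds a `transfer_manager`,
// `allocator`, `stream`, and a host `Literal` named `literal`:
//
//   TF_ASSIGN_OR_RETURN(
//       ScopedShapedBuffer buffer,
//       transfer_manager->AllocateScopedShapedBuffer(
//           literal.shape(), allocator, /*device_ordinal=*/0));
//   TF_RETURN_IF_ERROR(
//       transfer_manager->WriteTupleIndexTables(stream, buffer));
//   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
//       stream, literal, buffer, /*transfer_metadata=*/nullptr));
//
// The ScopedShapedBuffer releases its allocations when it goes out of scope.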

}  // namespace xla