/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

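// Constructs a platform-agnostic transfer manager. `platform_id` identifies
// the platform this instance serves; `pointer_size` is the size in bytes of
// a device pointer, used when sizing tuple index tables.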
GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {}

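// Returns the id of the platform this transfer manager serves.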
se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

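// Writes the index table for a single tuple buffer: gathers the opaque
// device pointers of the tuple's elements into a host-side array and copies
// that array into `region` on the device. A stream callback keeps the host
// array alive until the asynchronous copy has completed.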
Status GenericTransferManager::WriteSingleTupleIndexTable(
    se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
    const Shape& shape, se::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  auto element_pointers = std::make_shared<std::vector<const void*>>();
  element_pointers->reserve(elements.size());
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers->push_back(element.opaque());
  }
  TF_RETURN_IF_ERROR(TransferBufferToDevice(
      stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
  // Ensure the buffer is transferred before we destroy element_pointers.
  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
    /* holds reference to element_pointers in closure */
  });
  return Status::OK();
}

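// Enqueues a copy of every array subshape of `device_buffer` into the
// corresponding piece of `literal`, then blocks until the stream is done and
// invokes `done` with the resulting status. With bounded dynamic shapes the
// device allocation can be larger than the literal, so each copy size is
// derived from the literal's subshape rather than the device buffer's.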
void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* /*transfer_metadata*/) {
  VLOG(2) << "transferring literal from device ordinal "
          << stream->parent()->device_ordinal()
          << "; device buffer: " << device_buffer;
  Status status = [&]() -> Status {
    TF_RET_CHECK(stream->parent()->device_ordinal() ==
                 device_buffer.device_ordinal());

    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        device_buffer.on_device_shape(),
        [&](const Shape& subshape, const ShapeIndex& index) -> Status {
          if (subshape.IsArray()) {
            stream->ThenMemcpy(
                /*host_dst=*/literal.untyped_data(index),
                /*gpu_src=*/device_buffer.buffer(index),
                // With bounded dynamic shapes, the shape of the device buffer
                // (bounded allocation) can be bigger than the literal.
                /*size=*/
                GetByteSizeRequirement(
                    ShapeUtil::GetSubshape(literal.shape(), index)));
          }
          return Status::OK();
        }));
    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }
  done(stream->BlockHostUntilDone());
}

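// Enqueues on `stream` the transfers needed to materialize `literal` in
// `device_buffer`: first the tuple index tables, then every array subshape.
// A subliteral whose layout already matches the device layout is transferred
// directly; otherwise it is relaid out on the host first, and the stream is
// drained before returning so the temporary relaid-out literal can be
// destroyed safely.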
Status GenericTransferManager::TransferLiteralToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* /*transfer_metadata*/) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal shape to device: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
  TF_RET_CHECK(stream->parent()->device_ordinal() ==
               device_buffer.device_ordinal());

  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (device_subshape.IsArray()) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralSlice(literal, index);
          Literal relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
            return TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory);
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal.untyped_data();
            TF_RETURN_IF_ERROR(TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory));
            return stream->BlockHostUntilDone();
          }
        }
        return Status::OK();
      });
}

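// Infeed is inherently device-specific, so the generic implementation
// reports it as unimplemented.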
Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const LiteralSlice& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

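// Outfeed likewise requires platform-specific support.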
Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}

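// Device reset is not supported by the generic implementation.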
Status GenericTransferManager::ResetDevices(
    absl::Span<se::StreamExecutor* const>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

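// Returns the number of bytes `shape` occupies on device. Static shapes and
// tuples need exactly ShapeUtil::ByteSizeOf; a dynamic shape additionally
// carries one int32 of metadata per dimension to record the actual dimension
// sizes.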
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  if (shape.is_static() || shape.IsTuple()) {
    return ShapeUtil::ByteSizeOf(shape, pointer_size_);
  }
  int64_t metadata_size = sizeof(int32) * shape.dimensions_size();
  return ShapeUtil::ByteSizeOf(shape, pointer_size_) + metadata_size;
}

}  // namespace xla
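
// Backends make a TransferManager available to the XLA runtime by
// registering a factory via TransferManager::RegisterTransferManager
// (declared in transfer_manager.h). A minimal sketch of such a registration,
// assuming a hypothetical platform id `kMyPlatformId`:
//
//   static std::unique_ptr<xla::TransferManager> CreateMyTransferManager() {
//     return absl::make_unique<xla::GenericTransferManager>(
//         kMyPlatformId, /*pointer_size=*/sizeof(void*));
//   }
//
//   static bool InitModule() {
//     xla::TransferManager::RegisterTransferManager(kMyPlatformId,
//                                                   &CreateMyTransferManager);
//     return true;
//   }
//   static bool module_initialized = InitModule();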