/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace xla {

GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {}

se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

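// Writes the index table for a single tuple buffer: an array of pointers to
// the tuple's elements is copied into `region`. The host-side pointer array
// is kept alive by a host callback enqueued on the stream, so it is not
// destroyed before the asynchronous copy completes.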
Status GenericTransferManager::WriteSingleTupleIndexTable(
    se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
    const Shape& shape, se::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  auto element_pointers = std::make_shared<std::vector<const void*>>();
  element_pointers->reserve(elements.size());
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers->push_back(element.opaque());
  }
  TF_RETURN_IF_ERROR(TransferBufferToDevice(
      stream, GetByteSizeRequirement(shape), element_pointers->data(), region));
  // Ensure the buffer is transferred before we destroy element_pointers.
  stream->ThenDoHostCallback([element_pointers{std::move(element_pointers)}]() {
    /* holds reference to element_pointers in closure */
  });
  return Status::OK();
}

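// Asynchronously copies every array-shaped sub-buffer of `device_buffer` into
// the corresponding piece of `literal`, then invokes `done` with the result
// of blocking until the stream has finished.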
void GenericTransferManager::TransferLiteralFromDevice(
    se::Stream* stream, const ShapedBuffer& device_buffer,
    MutableBorrowingLiteral literal, std::function<void(Status)> done,
    const TransferMetadata* /*transfer_metadata*/) {
  VLOG(2) << "transferring literal from device ordinal "
          << stream->parent()->device_ordinal()
          << "; device buffer: " << device_buffer;
  Status status = [&]() -> Status {
    TF_RET_CHECK(stream->parent()->device_ordinal() ==
                 device_buffer.device_ordinal());

    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        device_buffer.on_device_shape(),
        [&](const Shape& subshape, const ShapeIndex& index) -> Status {
          if (subshape.IsArray()) {
            stream->ThenMemcpy(
                /*host_dst=*/literal.untyped_data(index),
                /*gpu_src=*/device_buffer.buffer(index),
                /*size=*/GetByteSizeRequirement(subshape));
          }
          return Status::OK();
        }));
    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }
  done(stream->BlockHostUntilDone());
}

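// Enqueues transfers of `literal` into `device_buffer`: tuple index tables
// are written first, then each array-shaped subliteral is copied to its
// device sub-buffer. When a subliteral's layout differs from the device
// layout, the data is relaid out into a temporary literal and the stream is
// drained before returning, since the temporary would otherwise be destroyed
// while the copy is still in flight.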
Status GenericTransferManager::TransferLiteralToDeviceAsync(
    se::Stream* stream, const LiteralSlice& literal,
    const ShapedBuffer& device_buffer,
    const TransferMetadata* /*transfer_metadata*/) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal shape to device: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_device_shape()));
  TF_RET_CHECK(stream->parent()->device_ordinal() ==
               device_buffer.device_ordinal());

  TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_device_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (device_subshape.IsArray()) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralSlice(literal, index);
          Literal relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
            return TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory);
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal.untyped_data();
            TF_RETURN_IF_ERROR(TransferBufferToDevice(
                stream,
                /*size=*/GetByteSizeRequirement(device_subshape), source,
                &device_memory));
            return stream->BlockHostUntilDone();
          }
        }
        return Status::OK();
      });
}

Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const LiteralSlice& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferLiteralFromOutfeed(
    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
  return Unimplemented("Generic transfer from Outfeed");
}

Status GenericTransferManager::ResetDevices(
    absl::Span<se::StreamExecutor* const>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

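// Returns the number of bytes `shape` occupies on device; pointer_size_ is
// used as the per-element size of tuple index tables.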
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  return ShapeUtil::ByteSizeOf(shape, pointer_size_);
}

}  // namespace xla