/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include <memory>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace tensorflow {

// The allocator used for Tensors assigned to the XLA device.
XlaDeviceAllocator::XlaDeviceAllocator(
    stream_executor::StreamExecutor* stream_executor)
    : stream_executor_(stream_executor) {}

XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // We always return an empty XlaTensor object, encoded as an opaque tagged
  // pointer. We can return an empty object and ignore num_bytes here because we
  // have control over all of the uses of this device tensor, and can lazily
  // allocate memory when used. This allows us to also know the shape of the
  // allocated Tensor, which is useful if the device's tensor representation
  // differs from the host.
  return XlaTensor::ToOpaquePointer(new XlaTensor());
}

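// Deleting the XlaTensor wrapper also releases any device memory that was
// lazily allocated for it.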
void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  delete XlaTensor::FromOpaquePointer(ptr);
}

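// Reports allocator statistics by converting the underlying StreamExecutor's
// stats into TensorFlow's AllocatorStats.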
absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  absl::optional<stream_executor::AllocatorStats> se_stats =
      stream_executor_->GetAllocatorStats();
  if (!se_stats) {
    return absl::nullopt;
  }

  tensorflow::AllocatorStats tf_stats;
  tf_stats.num_allocs = se_stats->num_allocs;
  tf_stats.bytes_in_use = se_stats->bytes_in_use;
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
  return tf_stats;
}

XlaDeviceContext::XlaDeviceContext(
    std::shared_ptr<se::Stream> compute_stream,
    std::shared_ptr<se::Stream> host_to_device_stream,
    std::shared_ptr<se::Stream> device_to_host_stream,
    std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
    xla::LocalClient* client,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    thread::ThreadPool* thread_pool)
    : stream_(std::move(compute_stream)),
      host_to_device_stream_(std::move(host_to_device_stream)),
      device_to_host_stream_(std::move(device_to_host_stream)),
      device_to_device_streams_(std::move(device_to_device_streams)),
      client_(client),
      transfer_manager_(client->backend().transfer_manager()),
      shape_representation_fn_(std::move(shape_representation_fn)),
      thread_pool_(thread_pool) {
  CHECK(host_to_device_stream_ != nullptr);
  CHECK(stream_ != nullptr);
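  // If the caller did not supply a shape representation function, default to
  // translating (dtype, TensorShape) directly into an XLA shape.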
  if (!shape_representation_fn_) {
    shape_representation_fn_ = [](const TensorShape& shape,
                                  DataType dtype) -> xla::StatusOr<xla::Shape> {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
      return xla_shape;
    };
  }
}

void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
                                              Device* device,
                                              Tensor* output_tensor,
                                              StatusCallback done) const {
  done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(Status::OK());
    return;
  }

  VLOG(2) << "CopyCPUTensorToDevice "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " " << cpu_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  CHECK(xla_tensor);

  Status status = [&]() -> Status {
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        shape_representation_fn_(device_tensor->shape(),
                                                 device_tensor->dtype()));

    // The device tensor should always be fresh.
    TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

    xla_tensor->set_host_tensor(*cpu_tensor);
    TF_RETURN_IF_ERROR(
        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));

    // The cpu_tensor and the literal that we create here hold the host
    // tensor's data in descending layout. The layout could differ from the
    // layout in device_tensor (but the logical shape has to be the same). The
    // transfer_manager is responsible for doing the corresponding
    // transposition when transferring the data to the device.
    xla::BorrowingLiteral literal(
        static_cast<const char*>(DMAHelper::base(cpu_tensor)),
        xla::ShapeUtil::MakeShape(shape.element_type(),
                                  xla::AsInt64Slice(shape.dimensions())));

    VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
            << xla_tensor->shaped_buffer().ToString();
    if (UseMultipleStreams() &&
        !transfer_manager_->CanShapedBufferBeAccessedNow(
            stream_->parent(), xla_tensor->shaped_buffer())) {
      // Initially wait for the compute stream so that memory allocations are
      // synchronized.
      host_to_device_stream_->ThenWaitFor(stream_.get());
    }

    TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
        host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));

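    // Record a definition event on the host-to-device stream so that
    // consumers of this tensor on other streams can wait for the transfer to
    // complete before reading the buffer.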
    if (UseMultipleStreams()) {
      auto event = std::make_shared<se::Event>(stream_->parent());
      TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
      host_to_device_stream_->ThenRecordEvent(event.get());
      xla_tensor->ResetDefinitionEvent(std::move(event),
                                       host_to_device_stream_.get());
    }

    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }

  // Create a reference to hold onto cpu_tensor until after the literal has
  // been transferred.
  TensorReference ref(*cpu_tensor);
  if (UseMultipleStreams()) {
    // Unref the host tensor when the transfer completes.
    // We don't defer the call to done() onto the stream here, and the reasons
    // why this is correct are subtle. We assume that:
    // a) all consumers of the device tensor will wait for its definition event.
    // b) if the tensor is destroyed, then the memory allocator will not hand
    //    out the same buffers until the transfer has completed.
    host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
    done(status);
  } else {
    host_to_device_stream_->ThenDoHostCallback([ref, done]() {
      ref.Unref();
      done(Status::OK());
    });
  }
}

void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             absl::string_view tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  if (device_tensor->NumElements() == 0) {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
    done(Status::OK());
    return;
  }
  VLOG(2) << "CopyDeviceTensorToCPU "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

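  // Use the dedicated device-to-host stream if one was provided; otherwise
  // borrow a stream from the backend's stream pool for this transfer.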
  std::shared_ptr<se::Stream> device_to_host_stream;
  if (device_to_host_stream_) {
    device_to_host_stream = device_to_host_stream_;
  } else {
    stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
        client_->mutable_backend()->BorrowStream(
            stream_->parent()->device_ordinal());
    if (!ptr_or_status.status().ok()) {
      done(ptr_or_status.status());
      return;
    }
    device_to_host_stream =
        std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
  }

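  // Have the device-to-host stream wait until the tensor's contents have been
  // computed (its definition event, if any) before starting the copy.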
  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());

  // The transfer manager requires the shape of the shaped buffer to be the
  // same as the literal shape, except for the layout. Set the literal to use
  // xla_tensor's shape, since it is derived from the cpu_tensor's shape using
  // shape_representation_fn_.
  xla::MutableBorrowingLiteral literal;
  TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(
          xla_tensor->shaped_buffer().on_host_shape()),
      cpu_tensor, &literal));

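  // Hold a reference on device_tensor so that its buffers stay alive until
  // the asynchronous transfer below has completed.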
  TensorReference ref(*device_tensor);
  // Explicitly capture device_to_host_stream to make sure the stream stays
  // alive until the transfer finishes.
  transfer_manager_->TransferLiteralFromDevice(
      device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
      [ref, xla_tensor, done, device_to_host_stream](xla::Status status) {
        done([&]() -> Status {
          VLOG(2) << "Transfer from device as literal: "
                  << xla_tensor->shaped_buffer().ToString();
          return status;
        }());
        ref.Unref();
      });
}

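// Hands out device-to-device streams in round-robin order so that concurrent
// copies are spread across the available streams.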
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
  DCHECK_GT(device_to_device_streams_.size(), 0);
  absl::MutexLock lock(&mu_);
  int stream = next_stream_;
  next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_stream(stream);
}

}  // namespace tensorflow