/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include <memory>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace tensorflow {

// The allocator used for Tensors assigned to the XLA device.
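//
// An illustrative sketch (comment only, not compiled) of the round trip a
// buffer from this allocator goes through. The calls mirror the ones made
// later in this file; `dtype`, `shape`, `client`, and `device_ordinal` stand
// in for values the caller already has:
//
//   void* opaque = allocator.AllocateRaw(/*alignment=*/64, /*num_bytes=*/0);
//   XlaTensor* xla_tensor = XlaTensor::FromOpaquePointer(opaque);
//   // Device memory is allocated lazily, once the on-device shape is known:
//   TF_RETURN_IF_ERROR(xla_tensor->AllocateShapedBuffer(
//       dtype, shape, client, device_ordinal));
//   ...
//   allocator.DeallocateRaw(opaque);  // deletes the XlaTensor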
XlaDeviceAllocator::XlaDeviceAllocator(
    stream_executor::StreamExecutor* stream_executor)
    : stream_executor_(stream_executor) {}

XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // We always return an empty XlaTensor object, encoded as an opaque tagged
  // pointer. We can return an empty object and ignore num_bytes here because
  // we have control over all uses of this device tensor, and can lazily
  // allocate memory when it is used. This also lets us know the shape of the
  // allocated Tensor, which is useful if the device's tensor representation
  // differs from the host's.
  return XlaTensor::ToOpaquePointer(new XlaTensor());
}

void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  delete XlaTensor::FromOpaquePointer(ptr);
}

absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  absl::optional<stream_executor::AllocatorStats> se_stats =
      stream_executor_->GetAllocatorStats();
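  // The underlying StreamExecutor may not track allocator statistics at all;
  // in that case, report no stats rather than zeros.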
  if (!se_stats) {
    return absl::nullopt;
  }

  tensorflow::AllocatorStats tf_stats;
  tf_stats.num_allocs = se_stats->num_allocs;
  tf_stats.bytes_in_use = se_stats->bytes_in_use;
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
  tf_stats.bytes_reserved = se_stats->bytes_reserved;
  tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
  tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
  return tf_stats;
}

XlaDeviceContext::XlaDeviceContext(
    std::shared_ptr<se::Stream> compute_stream,
    std::shared_ptr<se::Stream> host_to_device_stream,
    std::shared_ptr<se::Stream> device_to_host_stream,
    std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
    xla::LocalClient* client,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    thread::ThreadPool* thread_pool, bool use_fast_mem)
    : stream_(std::move(compute_stream)),
      host_to_device_stream_(std::move(host_to_device_stream)),
      device_to_host_stream_(std::move(device_to_host_stream)),
      device_to_device_streams_(std::move(device_to_device_streams)),
      client_(client),
      transfer_manager_(client->backend().transfer_manager()),
      shape_representation_fn_(std::move(shape_representation_fn)),
      thread_pool_(thread_pool),
      use_fast_mem_(use_fast_mem) {
  CHECK(host_to_device_stream_ != nullptr);
  CHECK(stream_ != nullptr);
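  // If the caller did not supply a shape representation function, fall back to
  // converting the TensorShape and DataType directly into an xla::Shape. Note
  // that this default ignores the fast-memory hint.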
  if (!shape_representation_fn_) {
    shape_representation_fn_ =
        [](const TensorShape& shape, DataType dtype,
           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
      return xla_shape;
    };
  }
}

void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
                                              Device* device,
                                              Tensor* output_tensor,
                                              StatusCallback done) const {
  done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done,
                                             bool sync_dst_compute) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(Status::OK());
    return;
  }

  VLOG(2) << "CopyCPUTensorToDevice use_fast_mem " << use_fast_mem_ << " "
          << this << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " " << cpu_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  CHECK(xla_tensor);

  Status status = [&]() -> Status {
    TF_ASSIGN_OR_RETURN(
        xla::Shape shape,
        shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(),
                                 use_fast_mem_));

    // The device tensor should always be fresh.
    TF_RET_CHECK(!xla_tensor->has_shaped_buffer());
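
    // This is where the lazy allocation promised by
    // XlaDeviceAllocator::AllocateRaw happens: device memory is reserved only
    // now that the tensor's on-device shape is known.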
    TF_RETURN_IF_ERROR(
        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));

    // The cpu_tensor and the literal created here hold the host tensor's data
    // in descending layout. That layout may differ from the layout of
    // device_tensor (although the logical shapes must match); the
    // transfer_manager is responsible for any relayout needed when
    // transferring the data to the device.
    xla::BorrowingLiteral literal(
        static_cast<const char*>(DMAHelper::base(cpu_tensor)),
        xla::ShapeUtil::MakeShape(shape.element_type(),
                                  xla::AsInt64Slice(shape.dimensions())));

    VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
            << xla_tensor->shaped_buffer().ToString();
    if (UseMultipleStreams() &&
        !transfer_manager_->CanShapedBufferBeAccessedNow(
            stream_->parent(), xla_tensor->shaped_buffer())) {
      // Initially wait for the compute stream so that memory allocations are
      // synchronized.
      host_to_device_stream_->ThenWaitFor(stream_.get());
    }

    TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
        host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));
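
    // When multiple streams are in use, consumers of this tensor may run on a
    // stream other than the one doing the transfer. Record a definition event
    // on the host-to-device stream so that they can wait until the data is
    // actually ready (see the assumptions spelled out below).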
    if (UseMultipleStreams()) {
      auto event = std::make_shared<se::Event>(stream_->parent());
      TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
      host_to_device_stream_->ThenRecordEvent(event.get());
      xla_tensor->ResetDefinitionEvent(std::move(event),
                                       host_to_device_stream_.get());
    }

    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }

  // Create a reference to hold onto cpu_tensor until after the literal has
  // been transferred.
  TensorReference ref(*cpu_tensor);
  if (UseMultipleStreams()) {
    // Unref the host tensor when the transfer completes.
    // We don't defer the call to done() onto the stream here, and the reasons
    // why this is correct are subtle. We assume that:
    // a) all consumers of the device tensor will wait for its definition event.
    // b) if the tensor is destroyed, then the memory allocator will not hand
    //    out the same buffers until the transfer has completed.
    host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
    done(status);
  } else {
    host_to_device_stream_->ThenDoHostCallback([ref, done]() {
      ref.Unref();
      done(Status::OK());
    });
  }
}

void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             absl::string_view tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  if (device_tensor->NumElements() == 0) {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
    done(Status::OK());
    return;
  }
  VLOG(2) << "CopyDeviceTensorToCPU "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();
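
  // Use the dedicated device-to-host stream if this context was given one;
  // otherwise borrow a stream from the backend's pool for the duration of the
  // transfer.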
  std::shared_ptr<se::Stream> device_to_host_stream;
  if (device_to_host_stream_) {
    device_to_host_stream = device_to_host_stream_;
  } else {
    stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
        client_->mutable_backend()->BorrowStream(
            stream_->parent()->device_ordinal());
    if (!ptr_or_status.status().ok()) {
      done(ptr_or_status.status());
      return;
    }
    device_to_host_stream =
        std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
  }

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
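  // If the tensor's contents are still being produced on another stream, make
  // this transfer wait for its definition event first (see, for example, the
  // event recorded in CopyCPUTensorToDevice above).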
  xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());

  // The transfer manager requires the shape of the shaped buffer to match the
  // literal's shape, except for the layout. Build the literal using
  // xla_tensor's shape, since it was derived from cpu_tensor's shape via
  // shape_representation_fn_.
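  // The literal borrows cpu_tensor's buffer, so the transfer below writes its
  // result directly into the output host tensor.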
  xla::MutableBorrowingLiteral literal;
  TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(
          xla_tensor->shaped_buffer().on_host_shape()),
      cpu_tensor, &literal));

  TensorReference ref(*device_tensor);
  const bool device_allows_sync_on_completion =
      device->AllowsSyncOnCompletion();
  // Explicitly capture device_to_host_stream to keep the stream alive until
  // the transfer finishes.
  transfer_manager_->TransferLiteralFromDevice(
      device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
      [this, ref, xla_tensor, done, device_to_host_stream,
       device_allows_sync_on_completion](xla::Status status) {
        Status done_status = status;
        VLOG(2) << "Transfer from device as literal: "
                << xla_tensor->shaped_buffer().ToString();
        // For devices that don't allow sync on completion, device execution
        // is deferred. We check the execution stream's status here to avoid
        // propagating wrong results from a failed stream to subsequent
        // host-side ops.
        if (!device_allows_sync_on_completion) {
          done_status.Update(xla_tensor->RefreshStatusOfStreams());
        }
        done(done_status);
        ref.Unref();
        // If a stream is in a bad state, it gets deleted when it is returned
        // to the stream pool, i.e. when it leaves this scope. However, a
        // stream deleting itself from a host callback running on that same
        // stream can misbehave on some platforms, so we release it from a
        // callback enqueued on another stream instead.
        if (!device_allows_sync_on_completion &&
            !device_to_host_stream->RefreshStatus().ok()) {
          auto status_or_new_stream = client_->mutable_backend()->BorrowStream(
              stream_->parent()->device_ordinal());
          if (status_or_new_stream.ok()) {
            status_or_new_stream.ValueOrDie()->ThenDoHostCallback(
                [device_to_host_stream] {});
          }
        }
      });
}
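
// Round-robins over the available device-to-device streams so that concurrent
// device-to-device copies can be spread across them.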
se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
  DCHECK_GT(device_to_device_streams_.size(), 0);
  absl::MutexLock lock(&mu_);
  int stream = next_stream_;
  next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_stream(stream);
}

Status XlaDeviceContext::ThenExecute(Device* device,
                                     stream_executor::Stream* stream,
                                     std::function<void()> func) {
  VLOG(2) << "XlaDeviceContext::ThenExecute";
  stream->ThenDoHostCallback(std::move(func));
  return Status::OK();
}

}  // namespace tensorflow