/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device_context.h"

#include <memory>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace tensorflow {

// The allocator used for Tensors assigned to the XLA device.
XlaDeviceAllocator::XlaDeviceAllocator(
    stream_executor::StreamExecutor* stream_executor)
    : stream_executor_(stream_executor) {}

XlaDeviceAllocator::~XlaDeviceAllocator() = default;

string XlaDeviceAllocator::Name() { return "xla"; }

void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
  // We always return an empty XlaTensor object, encoded as an opaque tagged
  // pointer. We can return an empty object and ignore num_bytes here because
  // we have control over all uses of this device tensor, and can lazily
  // allocate memory when it is used. This also lets us know the shape of the
  // allocated Tensor, which is useful if the device's tensor representation
  // differs from the host's.
  return XlaTensor::ToOpaquePointer(new XlaTensor());
}

void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
  delete XlaTensor::FromOpaquePointer(ptr);
}

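// Reports allocator statistics by translating the StreamExecutor allocator
// stats into TensorFlow's AllocatorStats; returns nullopt if the
// StreamExecutor reports no stats.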
absl::optional<AllocatorStats> XlaDeviceAllocator::GetStats() {
  absl::optional<stream_executor::AllocatorStats> se_stats =
      stream_executor_->GetAllocatorStats();
  if (!se_stats) {
    return absl::nullopt;
  }

  tensorflow::AllocatorStats tf_stats;
  tf_stats.num_allocs = se_stats->num_allocs;
  tf_stats.bytes_in_use = se_stats->bytes_in_use;
  tf_stats.peak_bytes_in_use = se_stats->peak_bytes_in_use;
  tf_stats.largest_alloc_size = se_stats->largest_alloc_size;
  tf_stats.bytes_limit = se_stats->bytes_limit;
  tf_stats.bytes_reserved = se_stats->bytes_reserved;
  tf_stats.peak_bytes_reserved = se_stats->peak_bytes_reserved;
  tf_stats.bytes_reservable_limit = se_stats->bytes_reservable_limit;
  tf_stats.largest_free_block_bytes = se_stats->largest_free_block_bytes;
  return tf_stats;
}

XlaDeviceContext::XlaDeviceContext(
    std::shared_ptr<se::Stream> compute_stream,
    std::shared_ptr<se::Stream> host_to_device_stream,
    std::shared_ptr<se::Stream> device_to_host_stream,
    std::vector<std::shared_ptr<se::Stream>> device_to_device_streams,
    xla::LocalClient* client,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    thread::ThreadPool* thread_pool, bool use_fast_mem)
    : stream_(std::move(compute_stream)),
      host_to_device_stream_(std::move(host_to_device_stream)),
      device_to_host_stream_(std::move(device_to_host_stream)),
      device_to_device_streams_(std::move(device_to_device_streams)),
      client_(client),
      transfer_manager_(client->backend().transfer_manager()),
      shape_representation_fn_(std::move(shape_representation_fn)),
      thread_pool_(thread_pool),
      use_fast_mem_(use_fast_mem) {
  CHECK(host_to_device_stream_ != nullptr);
  CHECK(stream_ != nullptr);
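  // If no shape representation function was provided, default to the XLA
  // shape derived directly from the tensor's dtype and TensorShape.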
  if (!shape_representation_fn_) {
    shape_representation_fn_ =
        [](const TensorShape& shape, DataType dtype,
           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
      return xla_shape;
    };
  }
}

void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor,
                                              Device* device,
                                              Tensor* output_tensor,
                                              StatusCallback done) const {
  done(errors::Unimplemented("XLA->XLA same-device copies not implemented."));
}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                             Device* device,
                                             Tensor* device_tensor,
                                             StatusCallback done,
                                             bool sync_dst_compute) const {
  if (cpu_tensor->NumElements() == 0) {
    VLOG(2) << "CopyCPUTensorToDevice empty tensor";
    done(Status::OK());
    return;
  }

  VLOG(2) << "CopyCPUTensorToDevice use_fast_mem " << use_fast_mem_ << " "
          << this << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " " << cpu_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  CHECK(xla_tensor);

  Status status = [&]() -> Status {
    TF_ASSIGN_OR_RETURN(
        xla::Shape shape,
        shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(),
                                 use_fast_mem_));

    // The device tensor should always be fresh.
    TF_RET_CHECK(!xla_tensor->has_shaped_buffer());

    TF_RETURN_IF_ERROR(
        xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
                                         stream_->parent()->device_ordinal()));

    // The cpu_tensor and the literal created here hold the host tensor's data
    // in descending layout. This layout may differ from the layout of
    // device_tensor (but the logical shape must be the same). The
    // transfer_manager is responsible for any transposition needed when
    // transferring the data to the device.
    xla::BorrowingLiteral literal(
        static_cast<const char*>(DMAHelper::base(cpu_tensor)),
        xla::ShapeUtil::MakeShape(shape.element_type(),
                                  xla::AsInt64Slice(shape.dimensions())));

    VLOG(2) << "Transfer to device as literal: " << literal.ToString() << " "
            << xla_tensor->shaped_buffer().ToString();
    if (UseMultipleStreams() &&
        !transfer_manager_->CanShapedBufferBeAccessedNow(
            stream_->parent(), xla_tensor->shaped_buffer())) {
      // Initially wait for the compute stream so that memory allocations are
      // synchronized.
      host_to_device_stream_->ThenWaitFor(stream_.get());
    }

    TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
        host_to_device_stream_.get(), literal, xla_tensor->shaped_buffer()));

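    // When using separate streams, record an event on the host-to-device
    // stream after the transfer so that consumers on other streams can wait
    // on the tensor's definition event before reading the buffer.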
    if (UseMultipleStreams()) {
      auto event = std::make_shared<se::Event>(stream_->parent());
      TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
      host_to_device_stream_->ThenRecordEvent(event.get());
      xla_tensor->ResetDefinitionEvent(std::move(event),
                                       host_to_device_stream_.get());
    }

    return Status::OK();
  }();
  if (!status.ok()) {
    done(status);
    return;
  }

  // Create a reference to hold onto cpu_tensor until after the literal has
  // been transferred.
  TensorReference ref(*cpu_tensor);
  if (UseMultipleStreams()) {
    // Unref the host tensor when the transfer completes.
    // We don't defer the call to done() onto the stream here, and the reasons
    // why this is correct are subtle. We assume that:
    // a) all consumers of the device tensor will wait for its definition event.
    // b) if the tensor is destroyed, then the memory allocator will not hand
    //    out the same buffers until the transfer has completed.
    host_to_device_stream_->ThenDoHostCallback([ref]() { ref.Unref(); });
    done(status);
  } else {
    host_to_device_stream_->ThenDoHostCallback([ref, done]() {
      ref.Unref();
      done(Status::OK());
    });
  }
}

void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
                                             absl::string_view tensor_name,
                                             Device* device, Tensor* cpu_tensor,
                                             StatusCallback done) {
  if (device_tensor->NumElements() == 0) {
    VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
    done(Status::OK());
    return;
  }
  VLOG(2) << "CopyDeviceTensorToCPU "
          << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
          << " "
          << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
          << " " << device_tensor->NumElements() << " "
          << cpu_tensor->shape().DebugString() << " "
          << device_tensor->shape().DebugString();

  std::shared_ptr<se::Stream> device_to_host_stream;
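  // Use the dedicated device-to-host stream if one was provided; otherwise
  // borrow a stream from the backend's stream pool for this transfer.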
  if (device_to_host_stream_) {
    device_to_host_stream = device_to_host_stream_;
  } else {
    stream_executor::port::StatusOr<xla::StreamPool::Ptr> ptr_or_status =
        client_->mutable_backend()->BorrowStream(
            stream_->parent()->device_ordinal());
    if (!ptr_or_status.status().ok()) {
      done(ptr_or_status.status());
      return;
    }
    device_to_host_stream =
        std::shared_ptr<se::Stream>(std::move(ptr_or_status.ValueOrDie()));
  }

  XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
  xla_tensor->WaitForDefinitionEventOnStream(device_to_host_stream.get());

  // The transfer manager requires the shape of the shaped buffer to match the
  // literal's shape except for the layout.  Set the literal to use
  // xla_tensor's shape, as it is derived from the cpu_tensor's shape using
  // shape_representation_fn_.
  xla::MutableBorrowingLiteral literal;
  TF_CHECK_OK(HostTensorToMutableBorrowingLiteral(
      xla::LayoutUtil::GetWithDefaultLayout(
          xla_tensor->shaped_buffer().on_host_shape()),
      cpu_tensor, &literal));

  TensorReference ref(*device_tensor);
  const bool device_allows_sync_on_completion =
      device->AllowsSyncOnCompletion();
  // Explicitly capture device_to_host_stream to make sure the stream stays
  // alive until the transfer finishes.
  transfer_manager_->TransferLiteralFromDevice(
      device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal,
      [this, ref, xla_tensor, done, device_to_host_stream,
       device_allows_sync_on_completion](xla::Status status) {
        Status done_status = status;
        VLOG(2) << "Transfer from device as literal: "
                << xla_tensor->shaped_buffer().ToString();
        // For devices that don't allow sync on completion, the device
        // execution is deferred. We check the execution stream status here to
        // avoid wrong results from a failed stream being propagated to
        // subsequent host-side ops.
        if (!device_allows_sync_on_completion) {
          done_status.Update(xla_tensor->RefreshStatusOfStreams());
        }
        done(done_status);
        ref.Unref();
        // If a stream is in a bad state, it gets deleted when it's returned to
        // the stream pool, i.e. when it leaves this scope. However, a stream
        // deleting itself in a host callback on itself can cause bad behavior
        // on some platforms, so we release it via a host callback on another
        // stream instead.
        if (!device_allows_sync_on_completion &&
            !device_to_host_stream->RefreshStatus().ok()) {
          auto status_or_new_stream = client_->mutable_backend()->BorrowStream(
              stream_->parent()->device_ordinal());
          if (status_or_new_stream.ok()) {
            status_or_new_stream.ValueOrDie()->ThenDoHostCallback(
                [device_to_host_stream] {});
          }
        }
      });
}

se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() {
  DCHECK_GT(device_to_device_streams_.size(), 0);
  absl::MutexLock lock(&mu_);
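  // Hand out device-to-device streams in round-robin order.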
  int stream = next_stream_;
  next_stream_ = (next_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_stream(stream);
}

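// Enqueues `func` as a host callback on `stream`, so it runs once the work
// already enqueued on that stream has completed.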
Status XlaDeviceContext::ThenExecute(Device* device,
                                     stream_executor::Stream* stream,
                                     std::function<void()> func) {
  VLOG(2) << "XlaDeviceContext::ThenExecute";
  stream->ThenDoHostCallback(std::move(func));
  return Status::OK();
}

}  // namespace tensorflow