/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return Status::OK();
}

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      absl::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {

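// Builds the DeviceAttributes advertised for an XLA device with the given
// name prefix, device name, and ordinal.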
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

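// Extracts the XLA Metadata attached to `device`. Returns an Internal error
// if the underlying device is not an XlaDevice, e.g. because an XLA-specific
// op was placed on a non-XLA device.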
/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

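// Process-wide compute streams, indexed by device ordinal. They are shared
// across XlaDevice instances when Options::use_global_compute_stream is set.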
/* static */ mutex XlaDevice::global_mu_(LINKER_INITIALIZED);
/* static */ std::vector<std::shared_ptr<se::Stream>>*
    XlaDevice::global_compute_streams_ =
        new std::vector<std::shared_ptr<se::Stream>>;

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices),
      use_global_compute_stream_(options.use_global_compute_stream) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << options.device_ordinal << " " << this;
  VLOG(1) << "XlaDevice options: use_multiple_streams: "
          << options.use_multiple_streams << " use_global_compute_stream: "
          << options.use_global_compute_stream;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device-to-device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
}

StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

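// Verifies that the device contexts (and the streams backing them) can be
// created, surfacing any initialization error to the caller.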
Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

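// If *stream is null or has entered an error state, replaces it with a stream
// freshly borrowed from the backend's stream pool and sets
// *stream_was_changed.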
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return Status::OK();
}

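// Lazily creates or refreshes the device contexts: ensures the compute and
// transfer streams are healthy, then builds a regular and a fast-memory
// XlaDeviceContext that share those streams. Returns the existing contexts
// when no stream had to be replaced.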
StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = !device_context_;
  if (use_global_compute_stream_) {
    mutex_lock lock(global_mu_);
    if (global_compute_streams_->size() <= device_ordinal_) {
      global_compute_streams_->resize(device_ordinal_ + 1, nullptr);
    }

    auto& global_stream = global_compute_streams_->at(device_ordinal_);
    if (global_stream != nullptr && global_stream->ok()) {
      stream_ = global_stream;
    } else {
      // Directly create the stream here instead of borrowing from the stream
      // pool to avoid potential lifetime issues.
      stream_ = absl::make_unique<se::Stream>(
          backend->stream_executors()[device_ordinal_]);
      stream_->Init();
      TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                              &need_new_device_context));
      (*global_compute_streams_)[device_ordinal_] = stream_;
    }
  } else {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                            &need_new_device_context));
  }

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // Data transfer requests from device to host can arrive out of order, so
    // a single stream would risk deadlock. In that case, the XlaDeviceContext
    // borrows a stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return std::make_pair(device_context_, fast_mem_device_context_);
  }

  // At this point we know we need a new device context.
  // Call GetAllocator for the side-effect of ensuring the allocator is
  // created.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, host_to_device_stream, device_to_host_stream,
      device_to_device_streams, client, shape_representation_fn_,
      thread_pool_.get(), false);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=false) "
          << device_context_;

  fast_mem_device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client, shape_representation_fn_, thread_pool_.get(), true);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=true) "
          << fast_mem_device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() and ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return std::make_pair(device_context_, fast_mem_device_context_);
}

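// Exposes GpuDeviceInfo (compute stream and default device context) so that
// code paths which query tensorflow_gpu_device_info() also work on this XLA
// device.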
Status XlaDevice::UseGpuDeviceInfo() {
  mutex_lock lock(mu_);
  use_gpu_device_info_ = true;
  return GetDeviceContextLocked().status();
}

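// Returns the default (non-fast-mem) device context, with an extra reference
// added on behalf of the caller.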
Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  mutex_lock lock(mu_);

  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  device_contexts.first->Ref();
  *out_context = device_contexts.first;
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return Status::OK();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return Status::OK();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(Status::OK());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback
  // at the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? Status::OK()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

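// Parses `tensor_proto` into a host tensor and, unless `alloc_attrs` requests
// host memory, copies the result onto the device via `device_context`.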
Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    mutex_lock lock(mu_);
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.first, tensor_proto, alloc_attrs,
                             tensor);
}

Status XlaDevice::MakeFastMemTensorFromProto(
    const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
    Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeFastMemTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.second, tensor_proto, alloc_attrs,
                             tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return Status::OK();
}

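// Polls the compute stream for errors; if the stream has failed, gives the
// registered device error callback a chance to run before returning the
// stream's status.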
Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return Status::OK();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}

}  // namespace tensorflow