/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return Status::OK();
}
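
// Example (illustrative sketch, not part of the original file): a backend
// that pads the minor-most dimension could install its own PaddedShapeFn via
// XlaDevice::Options::padded_shape_fn instead of relying on
// DefaultPaddedShapeFn. The padding policy below (round up to a multiple of
// 8) is hypothetical.
//
//   Status ExamplePaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
//     TF_RETURN_IF_ERROR(
//         TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape));
//     if (shape->rank() > 0) {
//       const int64 last = shape->rank() - 1;
//       const int64 padded = (shape->dimensions(last) + 7) / 8 * 8;
//       shape->set_dimensions(last, padded);
//     }
//     return Status::OK();
//   }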

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      absl::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {

static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}
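
// Example (illustrative, not part of the original file): with
// name_prefix = "/job:localhost/replica:0/task:0", device_name = "XLA_GPU",
// and device_ordinal = 0, the attributes built above describe a device named
// "/job:localhost/replica:0/task:0/device:XLA_GPU:0" of type "XLA_GPU", with
// a nominal 16GiB memory limit and the physical description
// "device: XLA_GPU device".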

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}
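
// Example (illustrative sketch, not part of the original file): an op kernel
// placed on an XLA device can look up the device's metadata from its
// OpKernelContext. "MyXlaAwareOp" is a hypothetical kernel.
//
//   void MyXlaAwareOp::Compute(OpKernelContext* ctx) {
//     const XlaDevice::Metadata* metadata = nullptr;
//     OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
//     VLOG(1) << "Running on XLA device ordinal "
//             << metadata->device_ordinal() << " (JIT device type "
//             << metadata->jit_device_type().type() << ")";
//     // ... use metadata->client() / metadata->platform() as needed ...
//   }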

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << this;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device-to-device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}
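
// Example (illustrative sketch, not part of the original file): a device
// factory typically fills in XlaDevice::Options and constructs the device.
// "XLA_MYDEV" and "XLA_MYDEV_JIT" are hypothetical device and compilation
// device names; only fields referenced by the constructor above are shown.
//
//   XlaDevice::Options options;
//   options.platform = platform;               // an se::Platform*
//   options.device_name_prefix = name_prefix;  // e.g. "/job:localhost/replica:0/task:0"
//   options.device_name = "XLA_MYDEV";
//   options.device_ordinal = 0;
//   options.compilation_device_name = "XLA_MYDEV_JIT";
//   options.use_multiple_streams = true;
//   auto device = absl::make_unique<XlaDevice>(session_options, options);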

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
}

xla::StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return Status::OK();
}

xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = !device_context_;
  TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                          &need_new_device_context));

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // The data transfer requests from device to host could arrive out of
    // order, so a single stream would cause deadlock. For this case,
    // xla_device_context would borrow a stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return std::make_pair(device_context_, fast_mem_device_context_);
  }

  // At this point we know we need a new device context.
  // Call GetAllocatorLocked for the side effect of ensuring the allocator is
  // created.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, host_to_device_stream, device_to_host_stream,
      device_to_device_streams, client, shape_representation_fn_,
      thread_pool_.get(), false);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=false) "
          << device_context_;

  fast_mem_device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client, shape_representation_fn_, thread_pool_.get(), true);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=true) "
          << fast_mem_device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() and ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return std::make_pair(device_context_, fast_mem_device_context_);
}

Status XlaDevice::UseGpuDeviceInfo() {
  mutex_lock lock(mu_);
  use_gpu_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  mutex_lock lock(mu_);

  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  device_contexts.first->Ref();
  *out_context = device_contexts.first;
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return Status::OK();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return Status::OK();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(Status::OK());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback
  // at the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? Status::OK()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    mutex_lock lock(mu_);
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.first, tensor_proto, alloc_attrs,
                             tensor);
}

Status XlaDevice::MakeFastMemTensorFromProto(
    const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
    Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeFastMemTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.second, tensor_proto, alloc_attrs,
                             tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}
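
// Example (illustrative, not part of the original file): code that owns an
// XlaDevice can install a callback that is invoked when the device's stream
// enters an error state (see HandleDeviceError/RefreshStatus below).
// "ResetAccelerator" is a hypothetical recovery routine returning a Status.
//
//   xla_device->SetHandleDeviceErrorCallback([] {
//     return ResetAccelerator();
//   });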

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return Status::OK();
}

Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return Status::OK();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}
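
// Example (illustrative, not part of the original file): a device factory for
// an XLA device typically registers the on-demand kernels once at static
// initialization time, pairing the runtime device type with its JIT
// compilation device type. The string literals below stand in for the usual
// DEVICE_* constants.
//
//   static XlaDeviceOpRegistrations* registrations =
//       RegisterXlaDeviceKernels("XLA_CPU", "XLA_CPU_JIT");
//   (void)registrations;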

}  // namespace tensorflow