/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return Status::OK();
}

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      absl::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}
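
// Usage note: callers typically pass the backend owned by the local XLA
// client together with the device ordinal, mirroring the lookup performed in
// XlaDevice::GetAllocatorLocked() below (a sketch, not an additional API):
//
//   XlaDeviceAllocator* allocator =
//       XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
//           client->mutable_backend(), device_ordinal);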

namespace {

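// Builds the DeviceAttributes advertised for an XLA device: the fully
// qualified name "<name_prefix>/device:<device_name>:<device_ordinal>", a
// nominal 16GiB memory limit, and a human-readable description string.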
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << this;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));
  // We have multiple device-to-device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}
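
// A minimal sketch of how a device factory might fill in XlaDevice::Options
// before constructing an XlaDevice. The concrete values below (device names,
// prefix, platform) are illustrative assumptions, not defaults defined in
// this file:
//
//   XlaDevice::Options options;
//   options.platform = platform;               // se::Platform* for the backend
//   options.device_name_prefix = name_prefix;  // e.g. "/job:localhost/replica:0/task:0"
//   options.device_name = "XLA_CPU";           // user-visible device type
//   options.compilation_device_name = "XLA_CPU_JIT";
//   options.device_ordinal = 0;
//   options.use_multiple_streams = false;
//   auto device = absl::make_unique<XlaDevice>(session_options, options);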

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
}

xla::StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

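// (Re)borrows a stream from the backend's StreamPool when `*stream` is either
// unset or has entered an error state, and flips `*stream_was_changed` so the
// caller knows the existing XlaDeviceContext must be rebuilt.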
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return Status::OK();
}

xla::StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = !device_context_;
  TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                          &need_new_device_context));

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // Device-to-host transfer requests can arrive out of order, so funneling
    // them through a single stream could deadlock. In that case the
    // XlaDeviceContext borrows a fresh stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return std::make_pair(device_context_, fast_mem_device_context_);
  }

  // At this point we know we need a new device context.
  // Call GetAllocatorLocked for the side effect of ensuring the allocator is
  // created.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
  // The XlaDeviceContext holds references to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, host_to_device_stream, device_to_host_stream,
      device_to_device_streams, client, shape_representation_fn_,
      thread_pool_.get(), false);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=false) "
          << device_context_;

  fast_mem_device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client, shape_representation_fn_, thread_pool_.get(), true);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=true) "
          << fast_mem_device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() and ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return std::make_pair(device_context_, fast_mem_device_context_);
}

Status XlaDevice::UseGpuDeviceInfo() {
  mutex_lock lock(mu_);
  use_gpu_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  mutex_lock lock(mu_);

  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  device_contexts.first->Ref();
  *out_context = device_contexts.first;
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return Status::OK();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return Status::OK();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(Status::OK());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback
  // at the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? Status::OK()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

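// Parses `tensor_proto` into a host tensor and, unless host memory was
// requested via `alloc_attrs`, copies the result to the device synchronously
// through `device_context`.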
Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    mutex_lock lock(mu_);
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.first, tensor_proto, alloc_attrs,
                             tensor);
}

Status XlaDevice::MakeFastMemTensorFromProto(
    const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
    Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeFastMemTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.second, tensor_proto, alloc_attrs,
                             tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return Status::OK();
}

Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return Status::OK();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError: the status is already non-ok, so
    // there is nothing extra to report if HandleDeviceError itself returns an
    // error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

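    // Inputs that XLA treats as compile-time constants must be placed in host
    // memory so their values can be read when the kernel is compiled.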
    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}

}  // namespace tensorflow