/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>

#include <unordered_set>
#include <utility>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return Status::OK();
}

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ TF_GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      absl::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {

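// Builds DeviceAttributes for an XLA device named
// "<name_prefix>/device:<device_name>:<device_ordinal>", advertising a
// nominal 16GiB memory limit in the returned attributes.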
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ mutex XlaDevice::global_mu_(LINKER_INITIALIZED);
/* static */ std::vector<std::shared_ptr<se::Stream>>*
    XlaDevice::global_compute_streams_ =
        new std::vector<std::shared_ptr<se::Stream>>;

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      intra_op_parallelism_threads_(
          session_options.config.intra_op_parallelism_threads()),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices),
      use_global_compute_stream_(options.use_global_compute_stream) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << options.device_ordinal << " " << this;
  VLOG(1) << "XlaDevice options: use_multiple_streams: "
          << options.use_multiple_streams << " use_global_compute_stream: "
          << options.use_global_compute_stream;
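  // Single-threaded pool handed to the XlaDeviceContexts created in
  // GetDeviceContextLocked() below.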
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device to device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
}

StatusOr<xla::LocalClient*> XlaDevice::GetOrCreateClient() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  xla::LocalClientOptions options;
  options.set_platform(platform_)
      .set_allowed_devices(allowed_devices_)
      .set_intra_op_parallelism_threads(intra_op_parallelism_threads_);
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    // TODO(b/78468222): This can fail, at least when the backend is GPU and
    // there is no GPU on the host.
    xla::Backend* backend = GetOrCreateClient().ValueOrDie()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

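// If *stream is null or has entered an error state, replaces it with a fresh
// stream borrowed from the backend's stream pool and sets *stream_was_changed.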
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return Status::OK();
}

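// Lazily creates (or refreshes) the pair of device contexts for this device:
// the first element is the default XlaDeviceContext, the second is the
// fast-mem variant (constructed with fast_mem=true). Both are
// reference-counted and owned by this XlaDevice.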
StatusOr<std::pair<XlaDeviceContext*, XlaDeviceContext*>>
XlaDevice::GetDeviceContextLocked() {
  TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient());
  xla::Backend* backend = client->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = !device_context_;
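  // When use_global_compute_stream_ is set, all XlaDevices with the same
  // device ordinal share a single process-wide compute stream, stored in
  // global_compute_streams_.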
  if (use_global_compute_stream_) {
    mutex_lock lock(global_mu_);
    if (global_compute_streams_->size() <= device_ordinal_) {
      global_compute_streams_->resize(device_ordinal_ + 1, nullptr);
    }

    auto& global_stream = global_compute_streams_->at(device_ordinal_);
    if (global_stream != nullptr && global_stream->ok()) {
      stream_ = global_stream;
    } else {
      // Directly create the stream here instead of borrowing from the stream
      // pool to avoid potential lifetime issues.
      stream_ = absl::make_unique<se::Stream>(
          backend->stream_executors()[device_ordinal_]);
      stream_->Init();
      TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                              &need_new_device_context));
      (*global_compute_streams_)[device_ordinal_] = stream_;
    }
  } else {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                            &need_new_device_context));
  }

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // Device-to-host transfer requests can arrive out of order, so funneling
    // them through a single stream could deadlock. Instead, leave this null so
    // the XlaDeviceContext borrows a fresh stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return std::make_pair(device_context_, fast_mem_device_context_);
  }

  // At this point we know we need a new device context.
  // Call GetAllocatorLocked for its side effect of ensuring the allocator
  // exists.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  if (fast_mem_device_context_) {
    fast_mem_device_context_->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, host_to_device_stream, device_to_host_stream,
      device_to_device_streams, client, shape_representation_fn_,
      thread_pool_.get(), false);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=false) "
          << device_context_;

  fast_mem_device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client, shape_representation_fn_, thread_pool_.get(), true);
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext(fast_mem=true) "
          << fast_mem_device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() and ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return std::make_pair(device_context_, fast_mem_device_context_);
}

Status XlaDevice::UseGpuDeviceInfo() {
  mutex_lock lock(mu_);
  use_gpu_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) {
  mutex_lock lock(mu_);

  TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked());
  device_contexts.first->Ref();
  *out_context = device_contexts.first;
  return Status::OK();
}

// Warn about XLA_CPU/XLA_GPU exactly once.
static void ShowXlaDeviceDeprecationWarning(
    absl::string_view compilation_device_name) {
  static absl::once_flag once;
  if (absl::StrContains(compilation_device_name, "CPU") ||
      absl::StrContains(compilation_device_name, "GPU")) {
    absl::call_once(once, [] {
      LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be "
                   "removed in subsequent releases. Instead, use either "
                   "@tf.function(jit_compile=True) for must-compile "
                   "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 "
                   "for auto-clustering best-effort compilation.";
    });
  }
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  ShowXlaDeviceDeprecationWarning(jit_device_name_.type_string());
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->ComputeAsync(context, done);
}

Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  profiler::TraceMe activity("XlaDevice::Sync", profiler::TraceMeLevel::kInfo);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return Status::OK();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return Status::OK();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(Status::OK());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at
  // the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread([stream, done](se::StreamExecutor*) {
    profiler::TraceMe activity("XlaDevice::Sync::Callback",
                               profiler::TraceMeLevel::kInfo);
    done(stream->ok() ? Status::OK()
                      : errors::Internal("XlaDevice::Sync() failed."));
  });
}

Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context,
                                      const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    mutex_lock lock(mu_);
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    TF_RETURN_IF_ERROR(
        device_context->CopyCPUTensorToDeviceSync(&parsed, this, &copy));
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.first, tensor_proto, alloc_attrs,
                             tensor);
}

Status XlaDevice::MakeFastMemTensorFromProto(
    const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
    Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeFastMemTensorFromProto";
  std::pair<XlaDeviceContext*, XlaDeviceContext*> device_contexts;
  {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(device_contexts, GetDeviceContextLocked());
  }
  return MakeTensorFromProto(device_contexts.second, tensor_proto, alloc_attrs,
                             tensor);
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return Status::OK();
}

Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return Status::OK();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  auto factory = [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    const std::unordered_set<std::string>* constant_inputs =
        XlaOpRegistry::CompileTimeConstantInputArgNames(def->op());

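    // Mark compile-time constant inputs as host-memory arguments so their
    // values are available on the host when XlaCompileOnDemandOp builds the
    // XLA computation.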
    for (const std::string& arg_name : *constant_inputs) {
      def->add_host_memory_arg(arg_name);
    }

    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}

}  // namespace tensorflow