/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_device.h"

#include <stdlib.h>
#include <unordered_set>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Caches an XlaDeviceAllocator per <backend, device ordinal> pair. An
// XlaDeviceAllocator is created on demand and is associated with an
// XlaDevice. It outlives the device itself (for instance, the buffer
// backing a tensor holds a pointer to the allocator for book-keeping,
// and this buffer can outlast the device).
class XlaDeviceAllocatorState {
 public:
  // Creates or returns a cached XlaDeviceAllocator for a given
  // backend and device_ordinal.
  static XlaDeviceAllocator* GetOrCreateXlaDeviceAllocator(
      const xla::Backend* backend, int device_ordinal);

 private:
  // Returns the singleton instance of XlaDeviceAllocatorState.
  static XlaDeviceAllocatorState& Singleton();
  XlaDeviceAllocatorState();
  ~XlaDeviceAllocatorState();

  mutex allocator_mutex_;  // Guards the singleton allocator state.
  std::unordered_map<std::pair<const xla::Backend*, int>,
                     std::unique_ptr<XlaDeviceAllocator>,
                     hash<std::pair<const xla::Backend*, int>>>
      allocators_ GUARDED_BY(allocator_mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceAllocatorState);
};

/* static */ XlaDeviceAllocatorState& XlaDeviceAllocatorState::Singleton() {
  static auto a = new XlaDeviceAllocatorState;
  return *a;
}

XlaDeviceAllocatorState::XlaDeviceAllocatorState() = default;
XlaDeviceAllocatorState::~XlaDeviceAllocatorState() = default;

XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
    const xla::Backend* backend, int device_ordinal) {
  XlaDeviceAllocatorState& state = Singleton();
  mutex_lock lock(state.allocator_mutex_);

  auto it = state.allocators_.find({backend, device_ordinal});
  if (it != state.allocators_.end()) {
    return it->second.get();
  }

  std::unique_ptr<XlaDeviceAllocator> alloc =
      absl::make_unique<XlaDeviceAllocator>(
          backend->stream_executors()[device_ordinal]);
  XlaDeviceAllocator* alloc_ptr = alloc.get();
  state.allocators_[{backend, device_ordinal}] = std::move(alloc);
  return alloc_ptr;
}

namespace {

// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  const tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor == nullptr) {
    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
  }

  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
  *shape = shaped_buffer.on_device_shape();
  return Status::OK();
}

static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
                                                 const string& device_name,
                                                 int device_ordinal) {
  return Device::BuildDeviceAttributes(
      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
      DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
      absl::StrCat("device: ", device_name, " device"));
}

}  // namespace

XlaDevice::Metadata::Metadata(
    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
    : device_ordinal_(device_ordinal),
      device_type_(device_type),
      platform_(platform),
      shape_representation_fn_(std::move(shape_representation_fn)),
      padded_shape_fn_(std::move(padded_shape_fn)),
      use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

se::Platform* XlaDevice::Metadata::platform() const { return platform_; }

xla::LocalClient* XlaDevice::Metadata::client() const {
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
  return client.ValueOrDie();
}

const DeviceType& XlaDevice::Metadata::jit_device_type() const {
  return device_type_;
}

/*static*/ Status XlaDevice::GetMetadataFromDevice(
    DeviceBase* device, const XlaDevice::Metadata** metadata) {
  *metadata = nullptr;
  XlaDevice* xla_device = dynamic_cast<XlaDevice*>(device->UnderlyingDevice());
  if (xla_device == nullptr) {
    return errors::Internal(
        "Cannot get XLA metadata from non-XLA device \"", device->name(),
        "\". GetMetadata must only be called on an XLA device. Either an "
        "internal bug has been triggered, or an XLA-specific op has been "
        "placed on the wrong device.");
  }
  *metadata = &(xla_device->xla_metadata_);
  return Status::OK();
}

/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx,
                                           const Metadata** metadata) {
  return GetMetadataFromDevice(ctx->device(), metadata);
}

XlaDevice::XlaDevice(const SessionOptions& session_options,
                     const Options& options)
    : LocalDevice(session_options,
                  BuildXlaDeviceAttributes(options.device_name_prefix,
                                           options.device_name,
                                           options.device_ordinal)),
      xla_metadata_(options.device_ordinal, options.platform,
                    DeviceType(options.compilation_device_name),
                    options.shape_representation_fn,
                    options.padded_shape_fn ? options.padded_shape_fn
                                            : DefaultPaddedShapeFn,
                    options.use_multiple_streams),
      device_ordinal_(options.device_ordinal),
      jit_device_name_(options.compilation_device_name),
      platform_(options.platform),
      use_multiple_streams_(options.use_multiple_streams),
      shape_representation_fn_(options.shape_representation_fn),
      allowed_devices_(options.allowed_devices) {
  VLOG(1) << "Created XLA device " << options.compilation_device_name << " "
          << this;
  thread_pool_.reset(new thread::ThreadPool(session_options.env, "xla_device",
                                            /*num_threads=*/1));

  // We have multiple device to device streams to allow for some concurrency
  // between transfers. The particular value of '4' is chosen fairly
  // arbitrarily. It may be necessary to make this tunable via
  // XlaDevice::Options.
  static constexpr int kNumDeviceToDeviceStreams = 4;
  device_to_device_streams_.resize(kNumDeviceToDeviceStreams);
}

XlaDevice::~XlaDevice() {
  VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
  mutex_lock lock(mu_);
  if (device_context_) {
    device_context_->Unref();
  }
}

xla::LocalClient* XlaDevice::client() const {
  // We lazily create the client because the platform commits to the
  // details of the host hardware when the client is created, so we
  // don't want to do it until we get a chance to hook the platform up
  // to a simulator.

  // TODO(b/78468222): This can fail, at least when the backend is GPU and
  // there is no GPU on the host.
  return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_)
      .ValueOrDie();
}

Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
  mutex_lock lock(mu_);
  return GetAllocatorLocked(attr);
}

Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) {
  if (attr.on_host()) {
    return cpu_allocator();
  }

  if (xla_allocator_ == nullptr) {
    xla::Backend* backend = client()->mutable_backend();
    xla_allocator_ = XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
        backend, device_ordinal_);
  }
  return xla_allocator_;
}

Status XlaDevice::EnsureDeviceContextOk() {
  mutex_lock lock(mu_);
  return GetDeviceContextLocked().status();
}

Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
                                       const string& name,
                                       std::shared_ptr<se::Stream>* stream,
                                       bool* stream_was_changed) {
  if (!(*stream) || !(*stream)->ok()) {
    xla::StreamPool::Ptr ptr;
    TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
    *stream = std::shared_ptr<se::Stream>(std::move(ptr));
    VLOG(1) << "XlaDevice " << this << " new " << name << " "
            << (*stream)->DebugStreamPointers();
    *stream_was_changed = true;
  }
  return Status::OK();
}

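// Lazily creates (or refreshes) the device's streams and the XlaDeviceContext
// that wraps them. Must be called with mu_ held.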
xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
  xla::Backend* backend = client()->mutable_backend();

  // Ensure all our streams are valid, borrowing new streams if necessary.
  bool need_new_device_context = !device_context_;
  TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
                                          &need_new_device_context));

  std::shared_ptr<se::Stream> host_to_device_stream;
  std::shared_ptr<se::Stream> device_to_host_stream;
  std::vector<std::shared_ptr<se::Stream>> device_to_device_streams;
  if (use_multiple_streams_) {
    TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
                                            &host_to_device_stream_,
                                            &need_new_device_context));
    for (std::shared_ptr<se::Stream>& stream : device_to_device_streams_) {
      TF_RETURN_IF_ERROR(
          EnsureStreamOkLocked(backend, "device_to_device_stream", &stream,
                               &need_new_device_context));
    }
    host_to_device_stream = host_to_device_stream_;
    device_to_device_streams = device_to_device_streams_;
    // The data transfer requests from device to host could arrive out of
    // order, so a single stream would cause deadlock. For this case,
    // xla_device_context would borrow a stream for each transfer request.
    device_to_host_stream = nullptr;
  } else {
    host_to_device_stream = stream_;
    device_to_host_stream = stream_;
    device_to_device_streams = {stream_};
  }

  if (!need_new_device_context) {
    return device_context_;
  }

  // At this point we know we need a new device context.
  // Call GetAllocator for the side-effect of ensuring the allocator is created.
  GetAllocatorLocked({});
  if (device_context_) {
    device_context_->Unref();
  }
  // The XlaDeviceContext keeps a reference count to the streams, and the
  // XlaDeviceContext remains live for the duration of an Executor run. This
  // ensures that the streams remain live for the duration of a run, even if
  // an error is encountered and the streams are replaced with new ones.
  device_context_ = new XlaDeviceContext(
      stream_, std::move(host_to_device_stream),
      std::move(device_to_host_stream), std::move(device_to_device_streams),
      client(), shape_representation_fn_, thread_pool_.get());
  VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
          << device_context_;

  // Create and set a new GpuDeviceInfo, if necessary.
  //
  // TODO(b/78232898): This isn't thread-safe; there is a race between the call
  // to set_tensorflow_gpu_device_info() and ops that call the getter
  // tensorflow_gpu_device_info(). This isn't trivially fixed by adding locking
  // to those methods; see the bug for details. Our only saving grace at the
  // moment is that this race doesn't seem to occur in practice.
  if (use_gpu_device_info_) {
    auto gpu_device_info = absl::make_unique<GpuDeviceInfo>();
    gpu_device_info->stream = stream_.get();
    gpu_device_info->default_context = device_context_;
    set_tensorflow_gpu_device_info(gpu_device_info.get());
    gpu_device_info_ = std::move(gpu_device_info);
    VLOG(1) << "XlaDevice " << this << " new GpuDeviceInfo "
            << gpu_device_info_.get();
  }

  return device_context_;
}

Status XlaDevice::UseGpuDeviceInfo() {
  mutex_lock lock(mu_);
  use_gpu_device_info_ = true;
  return GetDeviceContextLocked().status();
}

Status XlaDevice::FillContextMap(const Graph* graph,
                                 DeviceContextMap* device_context_map) {
  VLOG(1) << "XlaDevice::FillContextMap";
  mutex_lock lock(mu_);
  TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
                      GetDeviceContextLocked());

  device_context_map->resize(graph->num_node_ids());
  for (Node* n : graph->nodes()) {
    VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
    device_context->Ref();
    (*device_context_map)[n->id()] = device_context;
  }
  return Status::OK();
}

void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
  VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
          << op_kernel->type_string();
  op_kernel->Compute(context);
}

void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
                             AsyncOpKernel::DoneCallback done) {
  VLOG(2) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
          << op_kernel->type_string();
  tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
                                   op_kernel->IsExpensive());
  op_kernel->ComputeAsync(context, done);
}

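// Blocks until all work enqueued on the device's compute stream so far has
// completed, and returns an error if the stream ends up in a bad state.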
Status XlaDevice::Sync() {
  VLOG(1) << "XlaDevice::Sync";
  tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) return Status::OK();

  Status status = stream->BlockHostUntilDone();
  TF_RETURN_IF_ERROR(status);
  if (!stream->ok()) {
    return errors::Internal("XlaDevice::Sync() failed.");
  }
  VLOG(1) << "XlaDevice::Sync completed";
  return Status::OK();
}

// TODO(b/112409994): This is no longer necessary. Consolidate it with the
// synchronous version.
void XlaDevice::Sync(const DoneCallback& done) {
  VLOG(1) << "XlaDevice::Sync (asynchronous)";
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    done(Status::OK());
    return;
  }

  // The call to ThenEnqueueOnBackgroundThread below enqueues a host callback at
  // the end of the stream, after everything that has already been enqueued
  // there at this moment. When the host callback is called, everything before
  // it must have already finished, and the host callback will then place the
  // task below onto a background thread. (See the implementation of
  // ThenEnqueueOnBackgroundThread for details.) Therefore, when the done
  // callback is finally called from that background thread, we know for sure
  // that everything enqueued onto the stream (i.e., the device) at this very
  // moment--when ThenEnqueueOnBackgroundThread is called--will have finished.
  // This achieves a device-wide sync.
  stream->ThenEnqueueOnBackgroundThread(
      [stream, done](se::StreamExecutor*) {
        tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
                                         /*is_expensive=*/true);
        done(stream->ok() ? Status::OK()
                          : errors::Internal("XlaDevice::Sync() failed."));
      });
}

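// Parses tensor_proto into a host tensor and, unless host memory was requested
// via alloc_attrs, copies the result onto the device through the
// XlaDeviceContext.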
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
                                      const AllocatorAttributes alloc_attrs,
                                      Tensor* tensor) {
  VLOG(1) << "XlaDevice::MakeTensorFromProto";

  Tensor parsed(tensor_proto.dtype());
  if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
    return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                   tensor_proto.DebugString());
  }

  Status status;
  if (alloc_attrs.on_host()) {
    *tensor = parsed;
  } else {
    mutex_lock lock(mu_);
    TF_ASSIGN_OR_RETURN(XlaDeviceContext * device_context,
                        GetDeviceContextLocked());
    Allocator* allocator = GetAllocatorLocked(alloc_attrs);
    Tensor copy(allocator, parsed.dtype(), parsed.shape());
    Notification n;
    device_context->CopyCPUTensorToDevice(&parsed, this, &copy,
                                          [&n, &status](const Status& s) {
                                            status = s;
                                            n.Notify();
                                          });
    n.WaitForNotification();
    *tensor = copy;
  }
  VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
  return status;
}

void XlaDevice::SetAllowsSyncOnCompletion(bool sync_on_completion) {
  mutex_lock lock(mu_);
  sync_on_completion_ = sync_on_completion;
}

bool XlaDevice::AllowsSyncOnCompletion() const {
  mutex_lock lock(mu_);
  return sync_on_completion_;
}

void XlaDevice::SetHandleDeviceErrorCallback(std::function<Status()> callback) {
  mutex_lock lock(mu_);
  device_error_callback_ = callback;
}

Status XlaDevice::HandleDeviceError() {
  std::function<Status()> local_device_error_callback;
  {
    mutex_lock lock(mu_);
    local_device_error_callback = device_error_callback_;
  }
  if (local_device_error_callback != nullptr) {
    return local_device_error_callback();
  }
  return Status::OK();
}

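// Checks the status of the device's compute stream; if it has failed, invokes
// the registered device-error callback before returning the stream's status.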
Status XlaDevice::RefreshStatus() {
  std::shared_ptr<se::Stream> stream;
  {
    mutex_lock lock(mu_);
    stream = stream_;
  }
  if (!stream) {
    return Status::OK();
  }
  Status status = stream->RefreshStatus();
  if (!status.ok()) {
    // Ignore errors from HandleDeviceError, since by definition the status is
    // already non-ok, so there's nothing extra to report if HandleDeviceError
    // itself returns an error.
    HandleDeviceError().IgnoreError();
  }
  return status;
}

XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                   const char* jit_device) {
  // Any op assigned to the device that isn't rewritten by the graph rewriter
  // gets executed by an XlaCompileOnDemandOp, which compiles it and executes
  // it just-in-time.
  OpKernel* (*factory)(OpKernelConstruction*) =
      [](OpKernelConstruction* context) -> OpKernel* {
    return new XlaCompileOnDemandOp(context);
  };
  XlaOpRegistry::RegisterCompilationKernels();
  XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
  for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(
           jit_device,
           /*include_compilation_only_kernels=*/false)) {
    KernelDef* def = new KernelDef(*jit_def);
    def->set_device_type(device);
    registrations->op_kernel_registrars.emplace_back(
        new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp",
                                              factory));
  }
  return registrations;
}

}  // namespace tensorflow