• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/jit/defs.h"
21 #include "tensorflow/compiler/jit/flags.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/client_library.h"
27 #include "tensorflow/compiler/xla/client/local_client.h"
28 #include "tensorflow/compiler/xla/service/compiler.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/core/common_runtime/dma_helper.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/platform/env.h"
42 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
43 #include "tensorflow/core/util/stream_executor_util.h"
44 
// OP_REQUIRES_OK_RETURN is a variant of OP_REQUIRES_OK for functions that
// return a value. It evaluates the status expression passed in __VA_ARGS__;
// when that status is not OK it records the failure on CTX via
// CtxFailureWithWarning and returns RET from the enclosing function.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                           \
  do {                                                                 \
    const ::tensorflow::Status _op_requires_status(__VA_ARGS__);       \
    if (!TF_PREDICT_TRUE(_op_requires_status.ok())) {                  \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__,                 \
                                   _op_requires_status);               \
      return RET;                                                      \
    }                                                                  \
  } while (0)
55 
56 namespace tensorflow {
57 
58 namespace {
59 
PlatformInfoFromContext(OpKernelConstruction * ctx)60 XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
61   DeviceType device_type = ctx->device_type();
62   se::Platform::Id platform_id = nullptr;
63   const XlaDevice::Metadata* xla_device_metadata = nullptr;
64   std::unique_ptr<XlaAllocator> xla_allocator;
65   xla::DeviceMemoryAllocator* device_allocator = nullptr;
66 
67   if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
68     platform_id = se::host::kHostPlatformId;
69   } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
70     platform_id = ctx->device()
71                       ->tensorflow_gpu_device_info()
72                       ->stream->parent()
73                       ->platform()
74                       ->id();
75   } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
76     // If we are on an XlaDevice, use the underlying XLA platform's allocator
77     // directly. We could use the StreamExecutor's allocator which may
78     // theoretically be more correct, but XLA returns a nice OOM message in a
79     // Status and StreamExecutor does not.
80     //
81     // Importantly we can't use ctx->device()->GetAllocator() as the allocator
82     // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
83     // allocator that returns XlaTensor objects. The XlaCompiler needs a real
84     // allocator to allocate real buffers.
85 
86     platform_id = xla_device_metadata->platform()->id();
87     device_allocator =
88         xla_device_metadata->client()->backend().memory_allocator();
89   }
90 
91   if (!device_allocator) {
92     xla::StatusOr<se::Platform*> maybe_platform =
93         se::MultiPlatformManager::PlatformWithId(platform_id);
94     OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
95 
96     xla_allocator = absl::make_unique<XlaAllocator>(
97         maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
98   }
99 
100   return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
101                          std::move(xla_allocator), device_allocator);
102 }
103 
104 // A closure describing how to run a compiled version of a TensorFlow function.
105 //
106 // It may seem unusual to stick the resource variable snapshots in this class.
107 // This is necessary: we need to use the snapshots observed by the compiler as
108 // the initial values for the resource variables (and cannot snapshot them again
109 // during execution) because otherwise we risk observing a different snapshot
110 // with shapes different from what we compiled for.
111 class XlaExecutableClosure {
112  public:
XlaExecutableClosure(xla::LocalClient * client,xla::LocalExecutable * executable,const XlaCompiler::CompilationResult * compilation_result,std::map<int,OptionalTensor> resource_var_snapshots,int num_constant_args)113   explicit XlaExecutableClosure(
114       xla::LocalClient* client, xla::LocalExecutable* executable,
115       const XlaCompiler::CompilationResult* compilation_result,
116       std::map<int, OptionalTensor> resource_var_snapshots,
117       int num_constant_args)
118       : client_(client),
119         executable_(executable),
120         compilation_result_(compilation_result),
121         resource_var_snapshots_(std::move(resource_var_snapshots)),
122         num_constant_args_(num_constant_args) {}
123 
124   XlaExecutableClosure(XlaExecutableClosure&&) = default;
125   XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
126 
client() const127   xla::LocalClient* client() const { return client_; }
executable() const128   xla::LocalExecutable* executable() const { return executable_; }
compilation_result() const129   const XlaCompiler::CompilationResult* compilation_result() const {
130     return compilation_result_;
131   }
resource_var_snapshots() const132   const std::map<int, OptionalTensor>& resource_var_snapshots() const {
133     return resource_var_snapshots_;
134   }
num_constant_args() const135   int num_constant_args() const { return num_constant_args_; }
136 
137  private:
138   xla::LocalClient* client_;
139   xla::LocalExecutable* executable_;
140   const XlaCompiler::CompilationResult* compilation_result_;
141   std::map<int, OptionalTensor> resource_var_snapshots_;
142   int num_constant_args_;
143 
144   TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
145 };
146 
147 // This maintains a mapping from a globally unique ID to XlaExecutableClosure
148 // instances.
149 class XlaExecutableClosureStore {
150  public:
XlaExecutableClosureStore()151   XlaExecutableClosureStore() : key_counter_(0) {}
152 
153   using KeyT = string;
154 
Produce(XlaExecutableClosure result)155   KeyT Produce(XlaExecutableClosure result) {
156     mutex_lock l(mutex_);
157     KeyT key = absl::StrCat(key_counter_++);
158     bool insert_successful = closures_.emplace(key, std::move(result)).second;
159     DCHECK(insert_successful);
160     (void)insert_successful;
161     return key;
162   }
163 
Consume(const KeyT & key)164   XlaExecutableClosure Consume(const KeyT& key) {
165     mutex_lock l(mutex_);
166     auto it = closures_.find(key);
167     DCHECK(it != closures_.end());
168     XlaExecutableClosure value = std::move(it->second);
169     closures_.erase(it);
170     return value;
171   }
172 
Global()173   static XlaExecutableClosureStore* Global() {
174     static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
175     return instance;
176   }
177 
178  private:
179   mutex mutex_;
180   int64 key_counter_ GUARDED_BY(mutex_);
181   absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
182 
183   TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
184 };
185 
186 }  // namespace
187 
XlaLocalLaunchBase(OpKernelConstruction * ctx,const std::vector<int> & constants,const std::vector<int> & resources,const NameAttrList & function)188 XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
189                                        const std::vector<int>& constants,
190                                        const std::vector<int>& resources,
191                                        const NameAttrList& function)
192     : OpKernel(ctx),
193       constants_(constants),
194       resources_(resources),
195       function_(function),
196       platform_info_(PlatformInfoFromContext(ctx)) {}
197 
BuildCompilationCache(OpKernelContext * ctx,const XlaPlatformInfo & platform_info,XlaCompilationCache ** cache)198 static Status BuildCompilationCache(OpKernelContext* ctx,
199                                     const XlaPlatformInfo& platform_info,
200                                     XlaCompilationCache** cache) {
201   if (platform_info.xla_device_metadata()) {
202     *cache = new XlaCompilationCache(
203         platform_info.xla_device_metadata()->client(),
204         platform_info.xla_device_metadata()->jit_device_type());
205     return Status::OK();
206   }
207 
208   auto platform =
209       se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
210   if (!platform.ok()) {
211     return platform.status();
212   }
213 
214   xla::StatusOr<xla::Compiler*> compiler_for_platform =
215       xla::Compiler::GetForPlatform(platform.ValueOrDie());
216   if (!compiler_for_platform.ok()) {
217     // In some rare cases (usually in unit tests with very small clusters) we
218     // may end up transforming an XLA cluster with at least one GPU operation
219     // (which would normally force the cluster to be compiled using XLA:GPU)
220     // into an XLA cluster with no GPU operations (i.e. containing only CPU
221     // operations).  Such a cluster can fail compilation (in way that
222     // MarkForCompilation could not have detected) if the CPU JIT is not linked
223     // in.
224     //
225     // So bail out of _XlaCompile in this case, and let the executor handle the
226     // situation for us.
227     const Status& status = compiler_for_platform.status();
228     if (status.code() == error::NOT_FOUND) {
229       return errors::Unimplemented("Could not find compiler for platform ",
230                                    platform.ValueOrDie()->Name(), ": ",
231                                    status.ToString());
232     }
233   }
234 
235   xla::LocalClientOptions client_options;
236   client_options.set_platform(platform.ValueOrDie());
237   client_options.set_intra_op_parallelism_threads(
238       ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
239   auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
240   if (!client.ok()) {
241     return client.status();
242   }
243   const XlaOpRegistry::DeviceRegistration* registration;
244   if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
245                                            &registration)) {
246     return errors::InvalidArgument("No JIT device registered for ",
247                                    platform_info.device_type().type());
248   }
249   *cache = new XlaCompilationCache(
250       client.ValueOrDie(), DeviceType(registration->compilation_device_name));
251   return Status::OK();
252 }
253 
CompileToLocalExecutable(OpKernelContext * ctx,const NameAttrList & function,const XlaPlatformInfo & platform_info,absl::Span<const int> resources,absl::Span<const int> constants,bool lazy,xla::LocalClient ** client,std::map<int,OptionalTensor> * variables,const XlaCompiler::CompilationResult ** kernel,xla::LocalExecutable ** executable)254 static Status CompileToLocalExecutable(
255     OpKernelContext* ctx, const NameAttrList& function,
256     const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
257     absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
258     std::map<int, OptionalTensor>* variables,
259     const XlaCompiler::CompilationResult** kernel,
260     xla::LocalExecutable** executable) {
261   // We store information about the JIT-compiled XLA computation
262   // in the ResourceMgr.
263   ResourceMgr* rm = ctx->resource_manager();
264   if (!rm) {
265     return errors::Internal("No resource manager.");
266   }
267 
268   XlaCompilationCache* cache;
269   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
270       rm->default_container(), "xla_cache", &cache,
271       [&](XlaCompilationCache** cache) {
272         return BuildCompilationCache(ctx, platform_info, cache);
273       }));
274   // Hold the reference to the JIT during evaluation. (We could probably
275   // free it sooner because the ResourceMgr will retain a reference, but
276   // this is more obviously correct.)
277   core::ScopedUnref cache_ref(cache);
278 
279   TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
280   *client = static_cast<xla::LocalClient*>(cache->client());
281 
282   XlaCompiler::Options options;
283   options.client = *client;
284   if (ctx->op_device_context() != nullptr) {
285     options.device_ordinal =
286         ctx->op_device_context()->stream()->parent()->device_ordinal();
287   }
288   options.device_type = cache->device_type();
289   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
290   options.graph_def_version = ctx->function_library()->graph_def_version();
291   options.allow_cpu_custom_calls =
292       (platform_info.platform_id() == se::host::kHostPlatformId);
293   options.device_allocator = platform_info.allocator();
294   if (platform_info.xla_device_metadata()) {
295     options.shape_representation_fn =
296         platform_info.xla_device_metadata()->shape_representation_fn();
297   }
298 
299   std::map<int, Tensor> constant_args;
300   for (int i : constants) {
301     constant_args.insert({i, ctx->input(i)});
302   }
303   XlaCompiler::CompileOptions compile_options;
304   compile_options.is_entry_computation = true;
305   // If we resolve constants we never emit them on the device, meaning that if
306   // they are needed by a following computation the host has to transfer
307   // them. Not resolving constants is expected to be faster than resolving
308   // constants.
309   compile_options.resolve_compile_time_constants = true;
310   // Optimization: where possible, have the computation return a naked array
311   // rather than a one-element tuple.
312   compile_options.always_return_tuple = false;
313 
314   std::vector<XlaCompiler::Argument> args;
315   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
316       constant_args, *variables, ctx, &args));
317   return cache->Compile(options, function, args, compile_options,
318                         lazy ? XlaCompilationCache::CompileMode::kLazy
319                              : XlaCompilationCache::CompileMode::kStrict,
320                         kernel, executable);
321 }
322 
Compute(OpKernelContext * ctx)323 void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
324   VLOG(1) << "XlaLocalLaunchOpBase::Compute "
325           << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
326 
327   xla::LocalClient* client;
328   const XlaCompiler::CompilationResult* kernel;
329   xla::LocalExecutable* executable;
330   std::map<int, OptionalTensor> variables;
331 
332   {
333     Status s = CompileToLocalExecutable(
334         ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
335         &client, &variables, &kernel, &executable);
336     if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
337                     platform_info_.device_type().type_string() == DEVICE_GPU)) {
338       // Suggest auto jit if the failure was with GPU or CPU.
339       errors::AppendToMessage(&s,
340                               xla::status_macros::kPossibleAutoJitAlternative);
341     }
342 
343     OP_REQUIRES_OK(ctx, s);
344   }
345 
346   se::Stream* stream =
347       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
348 
349   VLOG(1) << "Executing XLA Computation...";
350 
351   XlaComputationLaunchContext launch_context(
352       client, platform_info_.allocator(),
353       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
354       platform_info_.UseMultipleStreams());
355   launch_context.PopulateInputs(ctx, kernel, variables,
356                                 /*missing_ctx_input_prefix=*/0);
357 
358   // Execute the computation.
359   VLOG(2) << "Executing computation.";
360   xla::ExecutableRunOptions run_options;
361   run_options.set_stream(stream);
362   run_options.set_allocator(platform_info_.allocator());
363   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
364   run_options.set_rng_seed(GetXLARandomSeed());
365   Env* env = Env::Default();
366   auto start_time = env->NowMicros();
367 
368   auto run_result = executable->Run(launch_context.arguments(), run_options);
369   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
370 
371   auto elapsed = env->NowMicros() - start_time;
372   VLOG(2) << "Elapsed time: " << elapsed << "us";
373 
374   OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
375                           ctx, kernel, run_result.ConsumeValueOrDie(),
376                           /*missing_ctx_input_prefix=*/0));
377   VLOG(1) << "Done";
378 }
379 
380 namespace {
381 // Helper static functions to construct parameters for
382 // XlaLocalLaunchBase constructor from OpKernelConstruction.
ConstantsVector(OpKernelConstruction * ctx)383 std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
384   DataTypeVector constant_types;
385   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
386                         ctx->GetAttr("Tconstants", &constant_types));
387   std::vector<int> constants(constant_types.size());
388   std::iota(constants.begin(), constants.end(), 0);
389   return constants;
390 }
391 
ResourcesVector(OpKernelConstruction * ctx)392 std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
393   DataTypeVector constant_types;
394   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
395                         ctx->GetAttr("Tconstants", &constant_types));
396 
397   DataTypeVector arg_types;
398   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
399                         ctx->GetAttr("Targs", &arg_types));
400 
401   int num_resources;
402   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
403                         ctx->GetAttr("Nresources", &num_resources));
404 
405   std::vector<int> resources(num_resources);
406   std::iota(resources.begin(), resources.end(),
407             constant_types.size() + arg_types.size());
408   return resources;
409 }
410 
FunctionAttr(OpKernelConstruction * ctx)411 NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
412   const NameAttrList* func;
413   OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
414   return *func;
415 }
416 
MustCompileAttr(OpKernelConstruction * ctx)417 bool MustCompileAttr(OpKernelConstruction* ctx) {
418   bool must_compile;
419   OP_REQUIRES_OK_RETURN(ctx, false,
420                         ctx->GetAttr("must_compile", &must_compile));
421   return must_compile;
422 }
423 }  // namespace
424 
XlaLocalLaunchOp(OpKernelConstruction * ctx)425 XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
426     : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
427                          FunctionAttr(ctx)) {}
428 
~XlaLocalLaunchOp()429 XlaLocalLaunchOp::~XlaLocalLaunchOp() {
430   VLOG(1) << "XlaLocalLaunchOp destroyed";
431 }
432 
XlaCompileOp(OpKernelConstruction * ctx)433 XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
434     : OpKernel(ctx),
435       constants_(ConstantsVector(ctx)),
436       resources_(ResourcesVector(ctx)),
437       function_(FunctionAttr(ctx)),
438       platform_info_(PlatformInfoFromContext(ctx)),
439       must_compile_(MustCompileAttr(ctx)) {}
440 
Compute(OpKernelContext * ctx)441 void XlaCompileOp::Compute(OpKernelContext* ctx) {
442   VLOG(3) << "XlaCompileOp " << def().name()
443           << (must_compile_ ? "(must-compile)" : "");
444   xla::LocalClient* client;
445   const XlaCompiler::CompilationResult* kernel;
446   xla::LocalExecutable* executable;
447   std::map<int, OptionalTensor> variables;
448 
449   bool cannot_compile_cluster;
450   {
451     mutex_lock guard(cannot_compile_cluster_mu_);
452     cannot_compile_cluster = cannot_compile_cluster_;
453   }
454 
455   if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
456       cannot_compile_cluster) {
457     executable = nullptr;
458   } else {
459     Status status = CompileToLocalExecutable(
460         ctx, function_, platform_info_, resources_, constants_,
461         /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
462     if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
463       OP_REQUIRES_OK(ctx, status);
464     }
465 
466     if (status.code() == error::UNIMPLEMENTED) {
467       LOG(WARNING) << "Compilation failed:" << status.ToString()
468                    << ".  Falling back to TF function call.";
469       executable = nullptr;
470       mutex_lock guard(cannot_compile_cluster_mu_);
471       cannot_compile_cluster_ = true;
472     }
473   }
474 
475   AllocatorAttributes host_alloc_attrs;
476   host_alloc_attrs.set_gpu_compatible(true);
477   host_alloc_attrs.set_on_host(true);
478   Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);
479 
480   if (!executable) {
481     DCHECK(!must_compile_);
482     Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
483 
484     Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
485     compilation_successful.scalar<bool>()() = false;
486     ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
487     ctx->set_output(1, compilation_successful);
488     return;
489   }
490 
491   // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
492   // if it didn't have to compile the cluster because of a compilation-cache
493   // hit.  This is because we at least need new snapshots of the resource
494   // variables.
495   XlaExecutableClosureStore::KeyT key =
496       XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
497           client, executable, kernel, std::move(variables), constants_.size()));
498 
499   Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
500   compilation_key.flat<string>()(0) = key;
501 
502   Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
503   compilation_successful.flat<bool>()(0) = true;
504 
505   ctx->set_output(0, compilation_key);
506   ctx->set_output(1, compilation_successful);
507 }
508 
XlaRunOp(OpKernelConstruction * ctx)509 XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
510     : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
511 
Compute(OpKernelContext * ctx)512 void XlaRunOp::Compute(OpKernelContext* ctx) {
513   VLOG(3) << "XlaRunOp " << def().name();
514   Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
515   const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
516 
517   XlaExecutableClosure closure =
518       XlaExecutableClosureStore::Global()->Consume(key);
519 
520   XlaComputationLaunchContext launch_context(
521       closure.client(), platform_info_.allocator(),
522       /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
523       /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
524 
525   // We're missing the must-be-constant inputs, tell `PopulateInputs`
526   // about this.  We don't actually need these inputs because they've
527   // already been baked into the compiled kernel.
528   launch_context.PopulateInputs(
529       ctx, closure.compilation_result(), closure.resource_var_snapshots(),
530       /*missing_ctx_input_prefix=*/closure.num_constant_args());
531 
532   se::Stream* stream =
533       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
534   xla::ExecutableRunOptions run_options;
535   run_options.set_stream(stream);
536   run_options.set_allocator(platform_info_.allocator());
537   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
538   run_options.set_rng_seed(GetXLARandomSeed());
539   Env* env = Env::Default();
540   auto start_time = env->NowMicros();
541 
542   auto run_result =
543       closure.executable()->Run(launch_context.arguments(), run_options);
544   OP_REQUIRES(ctx, run_result.ok(), run_result.status());
545 
546   auto elapsed = env->NowMicros() - start_time;
547   VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
548 
549   OP_REQUIRES_OK(
550       ctx,
551       launch_context.PopulateOutputs(
552           ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
553           /*missing_ctx_input_prefix=*/closure.num_constant_args()));
554 }
555 
556 REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
557 
558 REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
559                             .Device(DEVICE_GPU)
560                             .HostMemory("constants")
561                             .HostMemory("resources"),
562                         XlaLocalLaunchOp);
563 
564 REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
565 REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
566                             .Device(DEVICE_GPU)
567                             .HostMemory("constants")
568                             .HostMemory("key")
569                             .HostMemory("compilation_successful")
570                             .HostMemory("resources"),
571                         XlaCompileOp);
572 
573 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
574 REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
575 
576 }  // namespace tensorflow
577