/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"
// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in the error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)
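
// Example usage (a minimal sketch in the spirit of the attribute helpers
// further below; `ExampleAttr` is illustrative only, not a function defined
// in this file): a helper that returns a value, rather than void, can still
// report an error on the OpKernelConstruction and bail out with a default.
//
//   std::vector<int> ExampleAttr(OpKernelConstruction* ctx) {
//     DataTypeVector types;
//     OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
//                           ctx->GetAttr("Tconstants", &types));
//     return std::vector<int>(types.size());
//   }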

namespace tensorflow {

namespace {

auto* xla_launch_counter = monitoring::Counter<1>::New(
    "/tensorflow/core/xla_launch_counter",
    "The number of times an XlaLaunch is called.", "device");
// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
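
// Typical flow for the ops below: XlaCompileOp calls Produce() and emits the
// returned key as a scalar string tensor; the paired XlaRunOp reads that key
// from its final input and calls Consume() to take ownership of the closure,
// so each key is used exactly once.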

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

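// Compiles `function` into an XLA executable via the XlaCompilationCache
// stored in the context's ResourceMgr (creating the cache on first use).
// On success, sets `client`, `compilation_result` and `executable`.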
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants,
    XlaCompilationCache::CompileMode compile_mode,
    bool may_alias_resource_update, xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
                                        platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info, has_ref_vars);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = !has_ref_vars &&
                                          may_alias_resource_update;

  StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constants, inputs, variable_infos,
          static_cast<Device*>(ctx->device()));
  TF_RETURN_IF_ERROR(args.status());
  return cache->Compile(options, function, *args, compile_options, compile_mode,
                        compilation_result, executable);
}

void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  xla_launch_counter->GetCell(platform_info_.device_type().type_string())
      ->IncrementBy(1);

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  std::vector<VariableInfo> variable_infos;
  {
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
        variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict,
        /*may_alias_resource_update=*/true, &client, &compilation_result,
        &executable);
    OP_REQUIRES_OK(ctx, s);
  }

  std::map<int, const Tensor*> resource_var_ptrs;
  for (int i = 0; i < resources_.size(); i++) {
    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
      GetAllocator(ctx->device(), stream, platform_info_);
  se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  OP_REQUIRES_OK(ctx, execution_inputs.status());

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  StatusOr<absl::optional<xla::DeviceAssignment>> device_assignment =
      ResolveDeviceAssignment(ctx, compilation_result->collective_reduce_info);
  OP_REQUIRES_OK(ctx, device_assignment.status());

  xla::ExecutableRunOptions run_options;
  if (*device_assignment) {
    run_options.set_device_assignment(&**device_assignment);
  }
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());

  // Hardcode run id to always be zero: TF distributed strategy differentiates
  // between subsequent runs using dependency edges.
  // This is safe, as only TF dist-strat can produce distributed ops, and we can
  // rely on TF dist-strat invariants.
  xla::RunId run_id(0);
  run_options.set_run_id(run_id);
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  StatusOr<xla::ExecutionOutput> execution_output;
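  // Run synchronously when there is no stream or we are on the host platform;
  // otherwise enqueue the computation on the device stream with RunAsync.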
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        executable->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, compilation_result, execution_output->ConsumeResult(),
               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
               input_output_alias, resource_var_ptrs));

  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

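  // Resource arguments trail the constants and the regular arguments, so
  // their input indices start at constant_types.size() + arg_types.size().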
  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }
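  // Pick the compile mode: must-compile clusters are compiled strictly;
  // otherwise compile asynchronously when tf_xla_async_compilation is
  // enabled, and lazily otherwise.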
  XlaCompilationCache::CompileMode compile_mode = [&] {
    if (must_compile_) {
      return XlaCompilationCache::CompileMode::kStrict;
    }
    return GetXlaOpsCommonFlags().tf_xla_async_compilation
               ? XlaCompilationCache::CompileMode::kAsync
               : XlaCompilationCache::CompileMode::kLazy;
  }();

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

    // Do not alias resource updates as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
        constants_, compile_mode, /*may_alias_resource_update=*/false, &client,
        &kernel, &executable);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
    if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
        status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ".  Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  // Async compilation returns nullptr executable without an error.
  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
  // if it didn't have to compile the cluster because of a compilation-cache
  // hit.  This is because we at least need new snapshots of the resource
  // variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables), constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
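  // The compilation key produced by the paired XlaCompileOp arrives as the
  // final input; consuming it takes ownership of the closure created there.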
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
      GetAllocator(ctx->device(), stream, platform_info_);
  se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : closure.client()->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this.  We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (auto& p : closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(p.first,
                            p.second.has_value() ? &p.second.value() : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        closure.executable()->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output = closure.executable()->RunAsync(
        std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

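// Forwards whichever of its two inputs is present to output 0. Note the
// short-circuit in the condition below: if input 0 is absent, ++i selects
// input 1 instead.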
void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow