/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include <numeric>  // for std::iota, used by the helpers below

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
// in the error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)
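
// A minimal usage sketch (hypothetical helper, mirroring the attribute-reading
// helpers such as ConstantsVector further below): in a function that must
// return a value, the macro bails out with RET on error instead of a bare
// `return;`.
//
//   static int NumConstantsOrZero(OpKernelConstruction* ctx) {
//     DataTypeVector types;
//     OP_REQUIRES_OK_RETURN(ctx, 0, ctx->GetAttr("Tconstants", &types));
//     return static_cast<int>(types.size());
//   }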

namespace tensorflow {

namespace {

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};
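
// The closure is move-only (copying is disallowed above), which matches the
// Produce/Consume lifecycle below: ownership moves into the store and then out
// to the single consumer that runs the executable.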

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
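
// How the key travels between ops: XlaCompileOp emits the Produce() key as its
// scalar DT_STRING `key` output, and XlaRunOp reads it back from its last
// input before calling Consume() (see the kernels below).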

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
    xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info, has_ref_vars, &tf_allocator_adapter);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = !has_ref_vars &&
                                          !platform_info.is_on_xla_device() &&
                                          may_alias_resource_update;

  xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constants, inputs, variable_infos,
          static_cast<Device*>(ctx->device()));
  TF_RETURN_IF_ERROR(args.status());
  return cache->Compile(options, function, *args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        compilation_result, executable);
}
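
// Note: with lazy == true the cache compiles in kLazy mode and may defer
// compilation, leaving *executable null on success; XlaCompileOp::Compute
// below handles a null executable by emitting compilation_successful = false
// so callers fall back to the regular TF function call.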

void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  std::vector<VariableInfo> variable_infos;
  {
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
        variable_infos, constants_, /*lazy=*/false,
        /*may_alias_resource_update=*/true, &client, &compilation_result,
        &executable);
    OP_REQUIRES_OK(ctx, s);
  }

  std::map<int, const Tensor*> resource_var_ptrs;
  for (int i = 0; i < resources_.size(); i++) {
    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator = GetAllocator(
      &tf_allocator_adapter, ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info_);
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  OP_REQUIRES_OK(ctx, execution_inputs.status());

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        executable->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, compilation_result, execution_output->ConsumeResult(),
               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
               input_output_alias, resource_var_ptrs));

  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

    // Do not alias resource updates, as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
        constants_,
        /*lazy=*/!must_compile_,
        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ". Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
  // if it didn't have to compile the cluster because of a compilation-cache
  // hit.  This is because we at least need new snapshots of the resource
  // variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables), constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator = GetAllocator(
      &tf_allocator_adapter, ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info_);
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : closure.client()->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this.  We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (auto& p : closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(p.first,
                            p.second.has_value() ? &p.second.value() : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        closure.executable()->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output = closure.executable()->RunAsync(
        std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
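  // Forward whichever of the two inputs is present: the short-circuit below
  // keeps i at 0 if input 0 is present; otherwise ++i selects input 1.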
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow