/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
// error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them again
// during execution) because otherwise we risk observing a different snapshot
// with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
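//
// Keys are handed out by the _XlaCompile op (via Produce) and passed as a
// scalar string tensor to the paired _XlaRun op, which takes the closure back
// out of the store (via Consume).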
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

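// Compiles the cluster described by `function` for the device that `ctx` runs
// on, using the per-device XlaCompilationCache stored in the ResourceMgr. On
// success, fills in `client`, `compilation_result`, and `executable`; when
// `lazy` is true, the cache may defer compilation.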
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
    xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info, has_ref_vars, &tf_allocator_adapter);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = !has_ref_vars &&
                                          !platform_info.is_on_xla_device() &&
                                          may_alias_resource_update;

  xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constants, inputs, variable_infos,
          static_cast<Device*>(ctx->device()));
  TF_RETURN_IF_ERROR(args.status());
  return cache->Compile(options, function, *args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        compilation_result, executable);
}

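// Compiles the function strictly (blocking until compilation finishes), then
// runs the resulting executable: synchronously when there is no stream or the
// platform is the host, asynchronously on the device stream otherwise.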
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  std::vector<VariableInfo> variable_infos;
  {
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
        variable_infos, constants_, /*lazy=*/false,
        /*may_alias_resource_update=*/true, &client, &compilation_result,
        &executable);
    OP_REQUIRES_OK(ctx, s);
  }

  std::map<int, const Tensor*> resource_var_ptrs;
  for (int i = 0; i < resources_.size(); i++) {
    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator = GetAllocator(
      &tf_allocator_adapter, ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info_);
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  OP_REQUIRES_OK(ctx, execution_inputs.status());

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        executable->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, compilation_result, execution_output->ConsumeResult(),
               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
               input_output_alias, resource_var_ptrs));

  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

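// Compiles the cluster and produces two outputs: a scalar string key under
// which the compiled executable (plus resource variable snapshots) is stored
// in the XlaExecutableClosureStore, and a scalar bool saying whether
// compilation succeeded. If compilation is deferred or fails with
// UNIMPLEMENTED, the op emits an empty key and `false` so the caller can fall
// back to running the original TF function.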
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

    // Do not alias resource updates as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
        constants_,
        /*lazy=*/!must_compile_,
        /*may_alias_resource_update=*/false, &client, &kernel, &executable);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed:" << status.ToString()
                   << ". Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure,
  // even if it didn't have to compile the cluster because of a
  // compilation-cache hit. This is because we at least need new snapshots of
  // the resource variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

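// Consumes the XlaExecutableClosure produced by the paired _XlaCompile op (the
// key arrives as the last input), populates the XLA inputs from the current
// arguments and the resource-variable snapshots taken at compile time, runs
// the executable, and writes the results (including updated variables) back
// into the op's outputs.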
void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
  se::DeviceMemoryAllocator* allocator = GetAllocator(
      &tf_allocator_adapter, ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info_);
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : closure.client()->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this. We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (auto& p : closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(p.first,
                            p.second.has_value() ? &p.second.value() : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  xla::StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        closure.executable()->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output = closure.executable()->RunAsync(
        std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

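// Forwards whichever of its two inputs is present to output 0, merging the
// XLA path and the TF fallback path after an _XlaCompile / _XlaRun pair (only
// one of the two branches runs).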
void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

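// Kernel registrations. On GPU the small operands and results (constants,
// resource handles, compilation keys, and the compilation-success flag) are
// kept in host memory.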
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow