/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/kernels/xla_ops.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/stream_executor_util.h"

// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that, in the
// error case, it returns RET instead of void.
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

auto* xla_launch_counter = monitoring::Counter<1>::New(
    "/tensorflow/core/xla_launch_counter",
    "The number of times a XlaLaunch is called.", "device");

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them
// again during execution) because otherwise we risk observing a different
// snapshot with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      ResourceVarsSnapshot resource_var_snapshots, int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const ResourceVarsSnapshot& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  ResourceVarsSnapshot resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ TF_GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_
      TF_GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function,
                                       bool has_ref_vars)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      has_ref_vars_(has_ref_vars) {}

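// Compiles `function` for the device backing `ctx` into an
// xla::LocalExecutable, going through the per-device XlaCompilationCache kept
// in the ResourceMgr. On success, `client`, `compilation_result` and
// `executable` are filled in; with asynchronous compilation the cache may
// return an OK status while leaving `*executable` null (see
// XlaCompileOp::Compute below).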
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
    const XlaPlatformInfo& platform_info,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_infos,
    absl::Span<const int> constants,
    XlaCompilationCache::CompileMode compile_mode,
    bool may_alias_resource_update, xla::LocalClient** client,
    const XlaCompiler::CompilationResult** compilation_result,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildXlaCompilationCache(ctx->device(), ctx->function_library(),
                                        platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options = GenerateCompilerOptions(
      *cache, *ctx->function_library(), ctx->device(),
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr,
      platform_info, has_ref_vars);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update =
      !has_ref_vars && may_alias_resource_update;

  StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constants, inputs, variable_infos,
          static_cast<Device*>(ctx->device()));
  TF_RETURN_IF_ERROR(args.status());
  return cache->Compile(options, function, *args, compile_options,
                        compile_mode, compilation_result, executable);
}

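// Compiles the attached function in strict mode (blocking until compilation
// finishes), runs the resulting executable, and populates the kernel's
// outputs and updated resource variables from the XLA execution result.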
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
  xla_launch_counter->GetCell(platform_info_.device_type().type_string())
      ->IncrementBy(1);

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* compilation_result;
  xla::LocalExecutable* executable;

  std::vector<VariableInfo> variable_infos;
  {
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
    Status s = CompileToLocalExecutable(
        ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs,
        variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict,
        /*may_alias_resource_update=*/true, &client, &compilation_result,
        &executable);
    OP_REQUIRES_OK(ctx, s);
  }

  std::map<int, const Tensor*> resource_var_ptrs;
  for (int i = 0; i < resources_.size(); i++) {
    resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
      GetAllocator(ctx->device(), stream, platform_info_);
  se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : client->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      client, allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
      launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
                                    /*missing_ctx_input_prefix=*/0,
                                    input_output_alias);
  OP_REQUIRES_OK(ctx, execution_inputs.status());

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  StatusOr<absl::optional<xla::DeviceAssignment>> device_assignment =
      ResolveDeviceAssignment(ctx, compilation_result->collective_reduce_info);
  OP_REQUIRES_OK(ctx, device_assignment.status());

  xla::ExecutableRunOptions run_options;
  if (*device_assignment) {
    run_options.set_device_assignment(&**device_assignment);
  }
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());

  // Hardcode run id to always be zero: TF distributed strategy
  // differentiates between subsequent runs using dependency edges. This is
  // safe, as only TF dist-strat can produce distributed ops, and we can rely
  // on TF dist-strat invariants.
  xla::RunId run_id(0);
  run_options.set_run_id(run_id);
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        executable->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output =
        executable->RunAsync(std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";
  OP_REQUIRES_OK(
      ctx, launch_context.PopulateOutputs(
               ctx, compilation_result, execution_output->ConsumeResult(),
               /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
               input_output_alias, resource_var_ptrs));

  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for
// XlaLocalLaunchBase constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}

bool HasRefVars(OpKernelConstruction* ctx) {
  bool has_ref_vars;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr(kXlaHasReferenceVarsAttr, &has_ref_vars));
  return has_ref_vars;
}

}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx), /*has_ref_vars=*/true) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(XlaPlatformInfoFromDevice(ctx->device())),
      must_compile_(MustCompileAttr(ctx)),
      has_ref_vars_(HasRefVars(ctx)) {}

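// Compiles the cluster for `function_` (or finds it in the compilation cache)
// and stashes the result in the global XlaExecutableClosureStore. Emits two
// scalar outputs: the closure key and a bool indicating whether a compiled
// executable is available. When compilation is deferred or fails with
// UNIMPLEMENTED, the bool is false so that execution falls back to calling
// the original TF function.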
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  ResourceVarsSnapshot variables;

  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }
  XlaCompilationCache::CompileMode compile_mode = [&] {
    if (must_compile_) {
      return XlaCompilationCache::CompileMode::kStrict;
    }
    return GetXlaOpsCommonFlags().tf_xla_async_compilation
               ? XlaCompilationCache::CompileMode::kAsync
               : XlaCompilationCache::CompileMode::kLazy;
  }();

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    std::vector<VariableInfo> variable_infos;
    OP_REQUIRES_OK(
        ctx, GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(),
                                        inputs, resources_, &variable_infos));
    OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));

    // Do not alias resource updates as locking variables in XlaCompile and
    // unlocking them in XlaRun may lead to deadlocks.
    Status status = CompileToLocalExecutable(
        ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos,
        constants_, compile_mode, /*may_alias_resource_update=*/false, &client,
        &kernel, &executable);
    OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
                                                  variable_infos, &variables));
    if (compile_mode != XlaCompilationCache::CompileMode::kLazy ||
        status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed:" << status.ToString()
                   << ". Falling back to TF function call.";

      BroadcastOptimizationRemark(
          XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString())
          .IgnoreError();
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  // Async compilation returns nullptr executable without an error.
  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure,
  // even if it didn't have to compile the cluster because of a
  // compilation-cache hit. This is because we at least need new snapshots of
  // the resource variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<tstring>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {}

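// Consumes the XlaExecutableClosure produced by the matching XlaCompileOp
// (identified by the key passed as the last input), runs the executable with
// the resource-variable snapshots captured at compile time, and writes the
// results to the kernel's outputs and variables.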
void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<tstring>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  std::shared_ptr<se::DeviceMemoryAllocator> allocator_ptr =
      GetAllocator(ctx->device(), stream, platform_info_);
  se::DeviceMemoryAllocator* allocator = allocator_ptr.get();
  int device_ordinal = stream ? stream->parent()->device_ordinal()
                              : closure.client()->default_device_ordinal();
  XlaComputationLaunchContext launch_context(
      closure.client(), allocator, device_ordinal,
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs, tell `PopulateInputs`
  // about this. We don't actually need these inputs because they've
  // already been baked into the compiled kernel.
  const xla::HloInputOutputAliasConfig& input_output_alias =
      closure.executable()->executable()->module().input_output_alias_config();
  StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
  std::map<int, const Tensor*> snapshot_ptrs;
  {
    tensorflow::profiler::TraceMe hlo_module_activity(
        [&] {
          return absl::StrCat(
              "Populate Inputs (",
              closure.compilation_result()->xla_input_shapes.size(), ")");
        },
        tensorflow::profiler::TraceMeLevel::kInfo);

    for (auto& p : closure.resource_var_snapshots()) {
      snapshot_ptrs.emplace(p.first,
                            p.second.has_value() ? &p.second.value() : nullptr);
    }
    execution_inputs = launch_context.PopulateInputs(
        ctx, closure.compilation_result(), snapshot_ptrs,
        /*missing_ctx_input_prefix=*/closure.num_constant_args(),
        input_output_alias);
    OP_REQUIRES_OK(ctx, execution_inputs.status());
  }

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(allocator);
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  StatusOr<xla::ExecutionOutput> execution_output;
  if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
    execution_output =
        closure.executable()->Run(std::move(*execution_inputs), run_options);
  } else {
    execution_output = closure.executable()->RunAsync(
        std::move(*execution_inputs), run_options);
  }
  OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  tensorflow::profiler::TraceMe hlo_module_activity(
      [&] {
        return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")");
      },
      tensorflow::profiler::TraceMeLevel::kInfo);

  StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
      ctx, *closure.compilation_result(), closure.num_constant_args());
  OP_REQUIRES_OK(ctx, variable_infos.status());
  OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), execution_output->ConsumeResult(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args(),
          absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}

XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

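// Forwards whichever of its two inputs is present (the compiled cluster's
// result or the fallback function-call result) to output 0.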
void XlaMergeOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaMergeOp " << def().name();
  int i = 0;
  if (ctx->has_input(i) || ctx->has_input(++i)) {
    ctx->set_output(0, ctx->input(i));
  }
}

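// Kernel registrations for CPU and GPU. On GPU, the constant and resource
// arguments, the compilation key, and the compilation status are kept in
// host memory.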
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU).HostMemory("key"),
                        XlaRunOp);

REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_CPU), XlaMergeOp);
REGISTER_KERNEL_BUILDER(Name("_XlaMerge").Device(DEVICE_GPU), XlaMergeOp);

}  // namespace tensorflow