/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/jit/kernels/xla_ops.h"
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/memory/memory.h"
20 #include "tensorflow/compiler/jit/defs.h"
21 #include "tensorflow/compiler/jit/flags.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/client/client_library.h"
27 #include "tensorflow/compiler/xla/client/local_client.h"
28 #include "tensorflow/compiler/xla/service/compiler.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/core/common_runtime/dma_helper.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/allocator.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/platform/env.h"
42 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
43 #include "tensorflow/core/util/stream_executor_util.h"
44
45 // OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
46 // in error case, it returns RET instead of void.
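//
// For example, ConstantsVector below uses it to return early from a helper
// that returns std::vector<int>:
//
//   OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
//                         ctx->GetAttr("Tconstants", &constant_types));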
#define OP_REQUIRES_OK_RETURN(CTX, RET, ...)                \
  do {                                                      \
    ::tensorflow::Status _s(__VA_ARGS__);                   \
    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
      return RET;                                           \
    }                                                       \
  } while (0)

namespace tensorflow {

namespace {

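// Returns an XlaPlatformInfo for the device this kernel is constructed on:
// its DeviceType and platform id, the XlaDevice::Metadata if the kernel is
// placed on an XlaDevice, and the allocator XLA should use for device
// buffers.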
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
  DeviceType device_type = ctx->device_type();
  se::Platform::Id platform_id = nullptr;
  const XlaDevice::Metadata* xla_device_metadata = nullptr;
  std::unique_ptr<XlaAllocator> xla_allocator;
  xla::DeviceMemoryAllocator* device_allocator = nullptr;

  if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
    platform_id = se::host::kHostPlatformId;
  } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
    platform_id = ctx->device()
                      ->tensorflow_gpu_device_info()
                      ->stream->parent()
                      ->platform()
                      ->id();
  } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
    // If we are on an XlaDevice, use the underlying XLA platform's allocator
    // directly. We could use the StreamExecutor's allocator which may
    // theoretically be more correct, but XLA returns a nice OOM message in a
    // Status and StreamExecutor does not.
    //
    // Importantly, we can't use ctx->device()->GetAllocator() as the allocator
    // (which xla_allocator above uses) because, on an XlaDevice, this is a
    // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
    // real allocator to allocate real buffers.

    platform_id = xla_device_metadata->platform()->id();
    device_allocator =
        xla_device_metadata->client()->backend().memory_allocator();
  }

  if (!device_allocator) {
    xla::StatusOr<se::Platform*> maybe_platform =
        se::MultiPlatformManager::PlatformWithId(platform_id);
    OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());

    xla_allocator = absl::make_unique<XlaAllocator>(
        maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
  }

  return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
                         std::move(xla_allocator), device_allocator);
}

// A closure describing how to run a compiled version of a TensorFlow function.
//
// It may seem unusual to stick the resource variable snapshots in this class.
// This is necessary: we need to use the snapshots observed by the compiler as
// the initial values for the resource variables (and cannot snapshot them
// again during execution) because otherwise we risk observing a different
// snapshot with shapes different from what we compiled for.
class XlaExecutableClosure {
 public:
  explicit XlaExecutableClosure(
      xla::LocalClient* client, xla::LocalExecutable* executable,
      const XlaCompiler::CompilationResult* compilation_result,
      std::map<int, OptionalTensor> resource_var_snapshots,
      int num_constant_args)
      : client_(client),
        executable_(executable),
        compilation_result_(compilation_result),
        resource_var_snapshots_(std::move(resource_var_snapshots)),
        num_constant_args_(num_constant_args) {}

  XlaExecutableClosure(XlaExecutableClosure&&) = default;
  XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;

  xla::LocalClient* client() const { return client_; }
  xla::LocalExecutable* executable() const { return executable_; }
  const XlaCompiler::CompilationResult* compilation_result() const {
    return compilation_result_;
  }
  const std::map<int, OptionalTensor>& resource_var_snapshots() const {
    return resource_var_snapshots_;
  }
  int num_constant_args() const { return num_constant_args_; }

 private:
  xla::LocalClient* client_;
  xla::LocalExecutable* executable_;
  const XlaCompiler::CompilationResult* compilation_result_;
  std::map<int, OptionalTensor> resource_var_snapshots_;
  int num_constant_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
};

// This maintains a mapping from a globally unique ID to XlaExecutableClosure
// instances.
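// Keys are produced by XlaCompileOp::Compute, returned to the graph in its
// DT_STRING "key" output, and consumed exactly once by XlaRunOp::Compute,
// which reads the key back from its last input.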
class XlaExecutableClosureStore {
 public:
  XlaExecutableClosureStore() : key_counter_(0) {}

  using KeyT = string;

  KeyT Produce(XlaExecutableClosure result) {
    mutex_lock l(mutex_);
    KeyT key = absl::StrCat(key_counter_++);
    bool insert_successful = closures_.emplace(key, std::move(result)).second;
    DCHECK(insert_successful);
    (void)insert_successful;
    return key;
  }

  XlaExecutableClosure Consume(const KeyT& key) {
    mutex_lock l(mutex_);
    auto it = closures_.find(key);
    DCHECK(it != closures_.end());
    XlaExecutableClosure value = std::move(it->second);
    closures_.erase(it);
    return value;
  }

  static XlaExecutableClosureStore* Global() {
    static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
    return instance;
  }

 private:
  mutex mutex_;
  int64 key_counter_ GUARDED_BY(mutex_);
  absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);

  TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};

}  // namespace

XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
                                       const std::vector<int>& constants,
                                       const std::vector<int>& resources,
                                       const NameAttrList& function)
    : OpKernel(ctx),
      constants_(constants),
      resources_(resources),
      function_(function),
      platform_info_(PlatformInfoFromContext(ctx)) {}

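// Creates an XlaCompilationCache for the device described by `platform_info`.
// On an XlaDevice the cache reuses the device's XLA client; otherwise a
// LocalClient is created (or reused) for the platform and the registered JIT
// compilation device is looked up. Returns an Unimplemented error if no XLA
// compiler is linked in for the platform, so that callers can fall back to
// running the cluster un-compiled.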
static Status BuildCompilationCache(OpKernelContext* ctx,
                                    const XlaPlatformInfo& platform_info,
                                    XlaCompilationCache** cache) {
  if (platform_info.xla_device_metadata()) {
    *cache = new XlaCompilationCache(
        platform_info.xla_device_metadata()->client(),
        platform_info.xla_device_metadata()->jit_device_type());
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
  if (!platform.ok()) {
    return platform.status();
  }

  xla::StatusOr<xla::Compiler*> compiler_for_platform =
      xla::Compiler::GetForPlatform(platform.ValueOrDie());
  if (!compiler_for_platform.ok()) {
    // In some rare cases (usually in unit tests with very small clusters) we
    // may end up transforming an XLA cluster with at least one GPU operation
    // (which would normally force the cluster to be compiled using XLA:GPU)
    // into an XLA cluster with no GPU operations (i.e. containing only CPU
    // operations). Such a cluster can fail compilation (in a way that
    // MarkForCompilation could not have detected) if the CPU JIT is not
    // linked in.
    //
    // So bail out of _XlaCompile in this case, and let the executor handle
    // the situation for us.
    const Status& status = compiler_for_platform.status();
    if (status.code() == error::NOT_FOUND) {
      return errors::Unimplemented("Could not find compiler for platform ",
                                   platform.ValueOrDie()->Name(), ": ",
                                   status.ToString());
    }
  }

  xla::LocalClientOptions client_options;
  client_options.set_platform(platform.ValueOrDie());
  client_options.set_intra_op_parallelism_threads(
      ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
  auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
  if (!client.ok()) {
    return client.status();
  }
  const XlaOpRegistry::DeviceRegistration* registration;
  if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
                                           &registration)) {
    return errors::InvalidArgument("No JIT device registered for ",
                                   platform_info.device_type().type());
  }
  *cache = new XlaCompilationCache(
      client.ValueOrDie(), DeviceType(registration->compilation_device_name));
  return Status::OK();
}

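// Compiles `function` into an xla::LocalExecutable through the per-device
// XlaCompilationCache held in the ResourceMgr. The inputs listed in
// `constants` are treated as compile-time constants, the resource variables
// listed in `resources` are snapshotted into `variables`, and compilation is
// strict or lazy depending on `lazy`.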
static Status CompileToLocalExecutable(
    OpKernelContext* ctx, const NameAttrList& function,
    const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
    absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
    std::map<int, OptionalTensor>* variables,
    const XlaCompiler::CompilationResult** kernel,
    xla::LocalExecutable** executable) {
  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  if (!rm) {
    return errors::Internal("No resource manager.");
  }

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        return BuildCompilationCache(ctx, platform_info, cache);
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

  TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
  *client = static_cast<xla::LocalClient*>(cache->client());

  XlaCompiler::Options options;
  options.client = *client;
  if (ctx->op_device_context() != nullptr) {
    options.device_ordinal =
        ctx->op_device_context()->stream()->parent()->device_ordinal();
  }
  options.device_type = cache->device_type();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.graph_def_version = ctx->function_library()->graph_def_version();
  options.allow_cpu_custom_calls =
      (platform_info.platform_id() == se::host::kHostPlatformId);
  options.device_allocator = platform_info.allocator();
  if (platform_info.xla_device_metadata()) {
    options.shape_representation_fn =
        platform_info.xla_device_metadata()->shape_representation_fn();
  }

  std::map<int, Tensor> constant_args;
  for (int i : constants) {
    constant_args.insert({i, ctx->input(i)});
  }
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // If we resolve constants we never emit them on the device, meaning that if
  // they are needed by a following computation the host has to transfer
  // them. Not resolving constants is expected to be faster than resolving
  // constants.
  compile_options.resolve_compile_time_constants = true;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;

  std::vector<XlaCompiler::Argument> args;
  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_args, *variables, ctx, &args));
  return cache->Compile(options, function, args, compile_options,
                        lazy ? XlaCompilationCache::CompileMode::kLazy
                             : XlaCompilationCache::CompileMode::kStrict,
                        kernel, executable);
}

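// Compiles the wrapped function strictly (i.e. waits for compilation to
// finish), runs the resulting executable on the op's stream, and publishes
// the results as the op's outputs. On CPU and GPU a compilation failure
// message additionally suggests trying auto-jit.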
void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XlaLocalLaunchOpBase::Compute "
          << Canonicalize(function_.name(), AttrSlice(&function_.attr()));

  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  {
    Status s = CompileToLocalExecutable(
        ctx, function_, platform_info_, resources_, constants_, /*lazy=*/false,
        &client, &variables, &kernel, &executable);
    if (!s.ok() && (platform_info_.device_type().type_string() == DEVICE_CPU ||
                    platform_info_.device_type().type_string() == DEVICE_GPU)) {
      // Suggest auto jit if the failure was with GPU or CPU.
      errors::AppendToMessage(&s,
                              xla::status_macros::kPossibleAutoJitAlternative);
    }

    OP_REQUIRES_OK(ctx, s);
  }

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  VLOG(1) << "Executing XLA Computation...";

  XlaComputationLaunchContext launch_context(
      client, platform_info_.allocator(),
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      platform_info_.UseMultipleStreams());
  launch_context.PopulateInputs(ctx, kernel, variables,
                                /*missing_ctx_input_prefix=*/0);

  // Execute the computation.
  VLOG(2) << "Executing computation.";
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(platform_info_.allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  auto run_result = executable->Run(launch_context.arguments(), run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time: " << elapsed << "us";

  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
                          ctx, kernel, run_result.ConsumeValueOrDie(),
                          /*missing_ctx_input_prefix=*/0));
  VLOG(1) << "Done";
}

namespace {
// Helper static functions to construct parameters for the XlaLocalLaunchBase
// constructor from OpKernelConstruction.
std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));
  std::vector<int> constants(constant_types.size());
  std::iota(constants.begin(), constants.end(), 0);
  return constants;
}

std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
  DataTypeVector constant_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Tconstants", &constant_types));

  DataTypeVector arg_types;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Targs", &arg_types));

  int num_resources;
  OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
                        ctx->GetAttr("Nresources", &num_resources));

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            constant_types.size() + arg_types.size());
  return resources;
}

NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
  const NameAttrList* func;
  OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
  return *func;
}

bool MustCompileAttr(OpKernelConstruction* ctx) {
  bool must_compile;
  OP_REQUIRES_OK_RETURN(ctx, false,
                        ctx->GetAttr("must_compile", &must_compile));
  return must_compile;
}
}  // namespace

XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
    : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
                         FunctionAttr(ctx)) {}

XlaLocalLaunchOp::~XlaLocalLaunchOp() {
  VLOG(1) << "XlaLocalLaunchOp destroyed";
}

XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
    : OpKernel(ctx),
      constants_(ConstantsVector(ctx)),
      resources_(ResourcesVector(ctx)),
      function_(FunctionAttr(ctx)),
      platform_info_(PlatformInfoFromContext(ctx)),
      must_compile_(MustCompileAttr(ctx)) {}

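// Produces two outputs: output 0 is a DT_STRING scalar key identifying an
// XlaExecutableClosure in the global XlaExecutableClosureStore, and output 1
// is a DT_BOOL scalar indicating whether compilation succeeded. When
// compilation is deferred, or fails with Unimplemented and must_compile is
// false, the key is left empty and the flag is set to false so the cluster
// falls back to the ordinary TF function call.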
void XlaCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaCompileOp " << def().name()
          << (must_compile_ ? "(must-compile)" : "");
  xla::LocalClient* client;
  const XlaCompiler::CompilationResult* kernel;
  xla::LocalExecutable* executable;
  std::map<int, OptionalTensor> variables;

  bool cannot_compile_cluster;
  {
    mutex_lock guard(cannot_compile_cluster_mu_);
    cannot_compile_cluster = cannot_compile_cluster_;
  }

  if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
      cannot_compile_cluster) {
    executable = nullptr;
  } else {
    Status status = CompileToLocalExecutable(
        ctx, function_, platform_info_, resources_, constants_,
        /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
    if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
      OP_REQUIRES_OK(ctx, status);
    }

    if (status.code() == error::UNIMPLEMENTED) {
      LOG(WARNING) << "Compilation failed: " << status.ToString()
                   << ". Falling back to TF function call.";
      executable = nullptr;
      mutex_lock guard(cannot_compile_cluster_mu_);
      cannot_compile_cluster_ = true;
    }
  }

  AllocatorAttributes host_alloc_attrs;
  host_alloc_attrs.set_gpu_compatible(true);
  host_alloc_attrs.set_on_host(true);
  Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs);

  if (!executable) {
    DCHECK(!must_compile_);
    Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));

    Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
    compilation_successful.scalar<bool>()() = false;
    ctx->set_output(0, Tensor(cpu_allocator, DT_STRING, TensorShape({})));
    ctx->set_output(1, compilation_successful);
    return;
  }

  // Each execution of an XlaCompile op creates a new XlaExecutableClosure,
  // even if it didn't have to compile the cluster because of a
  // compilation-cache hit. This is because we at least need new snapshots of
  // the resource variables.
  XlaExecutableClosureStore::KeyT key =
      XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
          client, executable, kernel, std::move(variables),
          constants_.size()));

  Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
  compilation_key.flat<string>()(0) = key;

  Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
  compilation_successful.flat<bool>()(0) = true;

  ctx->set_output(0, compilation_key);
  ctx->set_output(1, compilation_successful);
}

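// XlaRunOp consumes the compilation key produced by the paired _XlaCompile op
// (passed as its last input), looks up the corresponding XlaExecutableClosure
// in the global store, and runs the executable. The must-be-constant inputs
// are not forwarded to _XlaRun; they were already baked into the compiled
// kernel.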
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
    : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}

void XlaRunOp::Compute(OpKernelContext* ctx) {
  VLOG(3) << "XlaRunOp " << def().name();
  Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
  const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);

  XlaExecutableClosure closure =
      XlaExecutableClosureStore::Global()->Consume(key);

  XlaComputationLaunchContext launch_context(
      closure.client(), platform_info_.allocator(),
      /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
      /*use_multiple_streams=*/platform_info_.UseMultipleStreams());

  // We're missing the must-be-constant inputs; tell `PopulateInputs` about
  // this. We don't actually need these inputs because they've already been
  // baked into the compiled kernel.
  launch_context.PopulateInputs(
      ctx, closure.compilation_result(), closure.resource_var_snapshots(),
      /*missing_ctx_input_prefix=*/closure.num_constant_args());

  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(platform_info_.allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());
  Env* env = Env::Default();
  auto start_time = env->NowMicros();

  auto run_result =
      closure.executable()->Run(launch_context.arguments(), run_options);
  OP_REQUIRES(ctx, run_result.ok(), run_result.status());

  auto elapsed = env->NowMicros() - start_time;
  VLOG(2) << "Elapsed time in computation: " << elapsed << "us";

  OP_REQUIRES_OK(
      ctx,
      launch_context.PopulateOutputs(
          ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
          /*missing_ctx_input_prefix=*/closure.num_constant_args()));
}

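// Kernel registrations. "XlaLaunch" implements the single-op compile-and-run
// path, while "_XlaCompile"/"_XlaRun" implement the split compilation path.
// On GPU the constants, resources, and the compilation key/status tensors are
// kept in host memory.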
REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("resources"),
                        XlaLocalLaunchOp);

REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
                            .Device(DEVICE_GPU)
                            .HostMemory("constants")
                            .HostMemory("key")
                            .HostMemory("compilation_successful")
                            .HostMemory("resources"),
                        XlaCompileOp);

REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);

}  // namespace tensorflow