/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines the XlaCompileOnDemandOp.

#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

namespace {
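// Returns a map from input index to a snapshot of each resource-variable
// input of `ctx`. Variables that cannot be looked up are recorded as not
// present.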
std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
  std::map<int, OptionalTensor> variables;
  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
    if (ctx->input(i).dtype() == DT_RESOURCE) {
      Var* variable = nullptr;
      ResourceHandle handle = HandleFromInput(ctx, i);
      OptionalTensor& optional = variables[i];
      optional.name = handle.name();
      if (LookupResource(ctx, handle, &variable).ok()) {
        core::ScopedUnref scoped_unref(variable);
        tf_shared_lock lock(*variable->mu());
        optional.present = true;
        optional.value = *variable->tensor();
      }
    }
  }
  return variables;
}
}  // namespace

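// Runs `executable` (the compilation of this op) over the inputs of `ctx` and
// populates the op's outputs from the execution result.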
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
                                 const XlaDevice::Metadata& metadata,
                                 const XlaCompiler::CompilationResult* result,
                                 xla::LocalExecutable* executable) {
  std::map<int, OptionalTensor> variables = GetVariables(ctx);

  xla::LocalClient* client = metadata.client();

  // Builds an XLA allocator for the device.
  XlaComputationLaunchContext launch_context(
      client, client->backend().memory_allocator(),
      /*allocate_xla_tensors=*/true,
      /*use_multiple_streams=*/metadata.UseMultipleStreams());

  launch_context.PopulateInputs(ctx, result, variables,
                                /*missing_ctx_input_prefix=*/0);

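  // The computation must be launched on the stream attached to the op's
  // device context; fail if there is none.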
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  TF_RET_CHECK(stream);

  VLOG(2) << "Executing computation: " << name();
  for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
    VLOG(2) << name() << ": " << *arg;
  }
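  // Run on the device's stream, reusing the client's allocator and seeding
  // XLA's random-number generator.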
  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_allocator(client->backend().memory_allocator());
  run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
  run_options.set_rng_seed(GetXLARandomSeed());

  xla::StatusOr<xla::ScopedShapedBuffer> run_result =
      executable->Run(launch_context.arguments(), run_options);
  TF_RETURN_IF_ERROR(run_result.status());

  TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
      ctx, result, run_result.ConsumeValueOrDie(),
      /*missing_ctx_input_prefix=*/0));
  return Status::OK();
}

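// Sets `*result` to true if the XLA kernel registration requires argument
// `argument_idx` of `op_kernel` to be a compile-time constant.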
Status XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel,
                                                    int64 argument_idx,
                                                    bool* result) {
  *result = false;

  // TODO(jmolloy): This could be expensive, so memoize.
  std::vector<int> constant_input_indices;
  TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
      *op_kernel, &constant_input_indices));
  *result = absl::c_binary_search(constant_input_indices, argument_idx);
  return Status::OK();
}

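// Sets `*result` to true if it would be beneficial to compile argument
// `argument_idx` of `op_kernel` as a compile-time constant.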
Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel,
                                                      int64 argument_idx,
                                                      bool* result) {
  // Right now we only create kConstant arguments when absolutely required, but
  // there may be benefit in eagerly constant-folding a larger subset of
  // arguments in the future.
  return MustArgumentBeConstant(op_kernel, argument_idx, result);
}

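// Compiles the op into an XLA computation, caching the result in a
// per-device XlaCompilationCache. On success, `*result` and `*executable`
// point to data owned by the cache.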
Status XlaCompileOnDemandOp::Compile(
    OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
    const XlaCompiler::CompilationResult** result,
    xla::LocalExecutable** executable) {
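  // Gather the values of all inputs that must be compile-time constants.
  // Prefer a host mirror of the tensor if one already exists; otherwise copy
  // the value from the device synchronously.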
  std::map<int, Tensor> constant_arguments;
  for (int64 i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& device_tensor = ctx->input(i);
    if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
      if (xla_tensor->has_host_tensor()) {
        bool should_arg_be_const;
        TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i,
                                                    &should_arg_be_const));
        if (should_arg_be_const) {
          constant_arguments[i] = xla_tensor->host_tensor();
        }
      }
    }

    if (constant_arguments.count(i) == 0) {
      bool must_argument_be_const;
      TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i,
                                                &must_argument_be_const));

      if (must_argument_be_const) {
        // Slow path; the argument is not available as a host constant so we
        // must fetch it synchronously.
        Tensor host_tensor;
        AllocatorAttributes attrs;
        attrs.set_on_host(true);
        TF_RETURN_IF_ERROR(ctx->allocate_temp(
            device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
        Notification n;
        Status status;
        ctx->op_device_context()->CopyDeviceTensorToCPU(
            &device_tensor, "ConstantArgument",
            reinterpret_cast<Device*>(ctx->device()), &host_tensor,
            [&](Status s) {
              status = s;
              n.Notify();
            });
        n.WaitForNotification();
        if (!status.ok()) {
          LOG(ERROR) << "Copying tensor of shape "
                     << device_tensor.shape().DebugString() << " from "
                     << ctx->device()->name() << " to CPU failed with "
                     << status.ToString();
          return status;
        }
        constant_arguments[i] = host_tensor;
      }
    }
  }

  // We store information about the JIT-compiled XLA computation
  // in the ResourceMgr.
  ResourceMgr* rm = ctx->resource_manager();
  CHECK(rm);

  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
      rm->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache) {
        *cache = new XlaCompilationCache(metadata.client(),
                                         metadata.jit_device_type());
        return Status::OK();
      }));
  // Hold the reference to the JIT during evaluation. (We could probably
  // free it sooner because the ResourceMgr will retain a reference, but
  // this is more obviously correct.)
  core::ScopedUnref cache_ref(cache);

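  // Build the compiler options from the device metadata: target JIT device
  // type, XLA client, function library, and shape representation.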
  XlaCompiler::Options options;
  options.device_type = metadata.jit_device_type();
  options.client = metadata.client();
  options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
  options.shape_representation_fn = metadata.shape_representation_fn();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;
  // Optimization: don't resolve constants. If we resolve constants we never
  // emit them on the device, meaning that if they are needed by a following
  // computation the host has to transfer them.
  compile_options.resolve_compile_time_constants = false;
  // Optimization: where possible, have the computation return a naked array
  // rather than a one-element tuple.
  compile_options.always_return_tuple = false;

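  // Snapshot the resource-variable inputs so the compiler sees their current
  // shapes and dtypes.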
  std::map<int, OptionalTensor> variable_args = GetVariables(ctx);

  std::vector<XlaCompiler::Argument> args;

  TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
      constant_arguments, variable_args, ctx, &args));

  return cache->CompileSingleOp(options, args, ctx, compile_options, result,
                                executable);
}

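// Compiles the op on demand (with caching) and runs the resulting executable.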
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
  const XlaCompiler::CompilationResult* result;
  xla::LocalExecutable* executable;
  const XlaDevice::Metadata* metadata;
  OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
  OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
  OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
}

}  // namespace tensorflow