• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Defines the XlaCompileOnDemandOp.
17 
18 #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/jit/xla_device.h"
22 #include "tensorflow/compiler/jit/xla_launch_util.h"
23 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 
27 namespace tensorflow {
28 
29 namespace {
GetVariables(OpKernelContext * ctx)30 std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
31   std::map<int, OptionalTensor> variables;
32   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
33     if (ctx->input(i).dtype() == DT_RESOURCE) {
34       Var* variable = nullptr;
35       ResourceHandle handle = HandleFromInput(ctx, i);
36       OptionalTensor& optional = variables[i];
37       optional.name = handle.name();
38       if (LookupResource(ctx, handle, &variable).ok()) {
39         core::ScopedUnref scoped_unref(variable);
40         tf_shared_lock lock(*variable->mu());
41         optional.present = true;
42         optional.value = *variable->tensor();
43       }
44     }
45   }
46   return variables;
47 }
48 }  // namespace
49 
Run(OpKernelContext * ctx,const XlaDevice::Metadata & metadata,const XlaCompiler::CompilationResult * result,xla::LocalExecutable * executable)50 Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
51                                  const XlaDevice::Metadata& metadata,
52                                  const XlaCompiler::CompilationResult* result,
53                                  xla::LocalExecutable* executable) {
54   std::map<int, OptionalTensor> variables = GetVariables(ctx);
55 
56   xla::LocalClient* client = metadata.client();
57 
58   // Builds an XLA allocator for the device.
59   XlaComputationLaunchContext launch_context(
60       client, client->backend().memory_allocator(),
61       /*allocate_xla_tensors=*/true,
62       /*use_multiple_streams=*/metadata.UseMultipleStreams());
63 
64   launch_context.PopulateInputs(ctx, result, variables,
65                                 /*missing_ctx_input_prefix=*/0);
66 
67   se::Stream* stream =
68       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
69   TF_RET_CHECK(stream);
70 
71   VLOG(2) << "Executing computation: " << name();
72   for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
73     VLOG(2) << name() << ": " << *arg;
74   }
75   xla::ExecutableRunOptions run_options;
76   run_options.set_stream(stream);
77   run_options.set_allocator(client->backend().memory_allocator());
78   run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
79   run_options.set_rng_seed(GetXLARandomSeed());
80 
81   xla::StatusOr<xla::ScopedShapedBuffer> run_result =
82       executable->Run(launch_context.arguments(), run_options);
83   TF_RETURN_IF_ERROR(run_result.status());
84 
85   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
86       ctx, result, run_result.ConsumeValueOrDie(),
87       /*missing_ctx_input_prefix=*/0));
88   return Status::OK();
89 }
90 
MustArgumentBeConstant(const OpKernel * op_kernel,int64 argument_idx,bool * result)91 Status XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel,
92                                                     int64 argument_idx,
93                                                     bool* result) {
94   *result = false;
95 
96   // TODO(jmolloy): This could be expensive, so memoize.
97   std::vector<int> constant_input_indices;
98   TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
99       *op_kernel, &constant_input_indices));
100   *result = absl::c_binary_search(constant_input_indices, argument_idx);
101   return Status::OK();
102 }
103 
ShouldArgumentBeConstant(const OpKernel * op_kernel,int64 argument_idx,bool * result)104 Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel,
105                                                       int64 argument_idx,
106                                                       bool* result) {
107   // Right now we only create kConstant arguments when absolutely required, but
108   // there may be benefit in eagerly constant-folding a larger subset of
109   // arguments in the future.
110   return MustArgumentBeConstant(op_kernel, argument_idx, result);
111 }
112 
Compile(OpKernelContext * ctx,const XlaDevice::Metadata & metadata,const XlaCompiler::CompilationResult ** result,xla::LocalExecutable ** executable)113 Status XlaCompileOnDemandOp::Compile(
114     OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
115     const XlaCompiler::CompilationResult** result,
116     xla::LocalExecutable** executable) {
117   std::map<int, Tensor> constant_arguments;
118   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
119     const Tensor& device_tensor = ctx->input(i);
120     if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
121       if (xla_tensor->has_host_tensor()) {
122         bool should_arg_be_const;
123         TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i,
124                                                     &should_arg_be_const));
125         if (should_arg_be_const) {
126           constant_arguments[i] = xla_tensor->host_tensor();
127         }
128       }
129     }
130 
131     if (constant_arguments.count(i) == 0) {
132       bool must_argument_be_const;
133       TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i,
134                                                 &must_argument_be_const));
135 
136       if (must_argument_be_const) {
137         // Slow path; the argument is not available as a host constant so we
138         // must fetch it synchronously.
139         Tensor host_tensor;
140         AllocatorAttributes attrs;
141         attrs.set_on_host(true);
142         TF_RETURN_IF_ERROR(ctx->allocate_temp(
143             device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
144         Notification n;
145         Status status;
146         ctx->op_device_context()->CopyDeviceTensorToCPU(
147             &device_tensor, "ConstantArgument",
148             reinterpret_cast<Device*>(ctx->device()), &host_tensor,
149             [&](Status s) {
150               status = s;
151               n.Notify();
152             });
153         n.WaitForNotification();
154         if (!status.ok()) {
155           LOG(ERROR) << "Copying tensor of shape "
156                      << device_tensor.shape().DebugString() << " from "
157                      << ctx->device()->name() << "to CPU failed with "
158                      << status.ToString();
159           return status;
160         }
161         constant_arguments[i] = host_tensor;
162       }
163     }
164   }
165 
166   // We store information about the JIT-compiled XLA computation
167   // in the ResourceMgr.
168   ResourceMgr* rm = ctx->resource_manager();
169   CHECK(rm);
170 
171   XlaCompilationCache* cache;
172   TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
173       rm->default_container(), "xla_cache", &cache,
174       [&](XlaCompilationCache** cache) {
175         *cache = new XlaCompilationCache(metadata.client(),
176                                          metadata.jit_device_type());
177         return Status::OK();
178       }));
179   // Hold the reference to the JIT during evaluation. (We could probably
180   // free it sooner because the ResourceMgr will retain a reference, but
181   // this is more obviously correct.)
182   core::ScopedUnref cache_ref(cache);
183 
184   XlaCompiler::Options options;
185   options.device_type = metadata.jit_device_type();
186   options.client = metadata.client();
187   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
188   options.shape_representation_fn = metadata.shape_representation_fn();
189 
190   XlaCompiler::CompileOptions compile_options;
191   compile_options.is_entry_computation = true;
192   // Optimization: don't resolve constants. If we resolve constants we never
193   // emit them on the device, meaning that if they are needed by a following
194   // computation the host has to transfer them.
195   compile_options.resolve_compile_time_constants = false;
196   // Optimization: where possible, have the computation return a naked array
197   // rather than a one-element tuple.
198   compile_options.always_return_tuple = false;
199 
200   std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
201 
202   std::vector<XlaCompiler::Argument> args;
203 
204   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
205       constant_arguments, variable_args, ctx, &args));
206 
207   return cache->CompileSingleOp(options, args, ctx, compile_options, result,
208                                 executable);
209 }
210 
Compute(OpKernelContext * ctx)211 void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
212   const XlaCompiler::CompilationResult* result;
213   xla::LocalExecutable* executable;
214   const XlaDevice::Metadata* metadata;
215   OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
216   OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable));
217   OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable));
218 }
219 
220 }  // namespace tensorflow
221