/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

}  // anonymous namespace

VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
    : index_(index), name_(name), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error.  It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}
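
// Illustrative usage sketch (not invoked from this file): the RAII behavior
// above means a caller can return early on error and still release every
// lock and reference it acquired.  Assuming a kernel context `ctx` and the
// indices of its DT_RESOURCE inputs in `variable_indices`:
//
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       ctx->resource_manager(), ctx->device(), InputsFromContext(ctx),
//       variable_indices, &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//   // ... read or write variable_infos[i].var()->tensor() ...
//   // Destroying `variable_infos` unlocks and unrefs each variable.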

Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location = [&]() -> std::string {
        if (handle.definition_stack_trace()) {
          std::vector<StackFrame> stack_frames =
              handle.definition_stack_trace()->ToStackFrames(
                  {}, IsInternalFrameForFilename,
                  /*reverse_traversal=*/true,
                  /*limit=*/1);
          if (!stack_frames.empty()) {
            const StackFrame& last_frame = stack_frames[0];
            return absl::StrCat(" (defined @ ", last_frame.file_name, ":",
                                last_frame.line_number, ")");
          }
        }
        return "";
      }();
      return errors::InvalidArgument("Trying to access resource ",
                                     handle.name(), definition_location,
                                     " located in device ", handle.device(),
                                     " from device ", dev->attributes().name());
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return Status::OK();
        }));
    result->emplace_back(var_idx, handle.name(), variable);
  }
  return Status::OK();
}

std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // The comparator below treats all empty VariableInfo instances as
  // equivalent, so it might seem that a stable sort is needed to keep a
  // deterministic order among them.  However, since we sort by pointer value
  // the order is non-deterministic anyway, so we don't bother with
  // std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly.  Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}
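
// Why sorting by mutex address matters (illustrative note, not code used
// here): two kernels that receive the same pair of variables in different
// argument orders would otherwise interleave lock acquisition.
//
//   // Kernel A inputs: {v1, v2};  Kernel B inputs: {v2, v1}
//   // Without sorting: A locks mu(v1) then waits on mu(v2), while
//   //                  B locks mu(v2) then waits on mu(v1)  -> deadlock.
//   // With the address sort above, both acquire the lower-addressed mutex
//   // first, so one caller simply blocks until the other releases it.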

Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : absl::nullopt;
  }
  return Status::OK();
}

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase& buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
    buffer = se::DeviceMemoryBase();
  } else {
    *in_buffer = buffer;
  }
}
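
// Illustrative sketch of the donation contract above (mirrors the call made
// from PopulateInputs below): when `donate_buffer` is true the ExecutionInput
// takes ownership of the memory and the caller's handle is nulled out so the
// same buffer cannot be freed twice.
//
//   se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
//   PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
//                                /*donate_buffer=*/true, device_ordinal,
//                                allocator);
//   // dmem.is_null() is now true; XLA may reuse or free the buffer.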

xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& device_shape =
        transfer_manager->HostShapeToDeviceShape(shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         return update.input_index == i && update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, shape);
    xla::ExecutionInput& execution_input = arguments.back();
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                   donate_buffer, device_ordinal_,
                                   xla_allocator_);
    } else {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            PopulateExecutionInputBuffer(execution_input, index, *buffer,
                                         donate_buffer, device_ordinal_,
                                         xla_allocator_);
          });
    }
  }
  return std::move(arguments);
}

// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}

// Get aliased tensor, or make a new one for the corresponding output
// operation.
static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});

  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    if (output_buffer.opaque() == input_tensor.data()) {
      return input_tensor;
    }
  }
  return MakeTensor(output_dtype, output_shape, output_buffer,
                    output_allocator);
}
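
// Note on the aliasing path above (illustrative): when the compiler declared
// an input/output alias and XLA actually wrote the result into the input's
// buffer, returning the existing input Tensor avoids wrapping the same device
// memory in two different Tensor objects.  For example, for an in-place
// update where parameter 0 is aliased to output 0, `output_buffer.opaque()`
// equals `input_tensor.data()` and the input Tensor is reused as the output.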

static void PopulateXlaTensor(Tensor* output_tensor,
                              xla::ScopedShapedBuffer* output, int output_num,
                              se::Stream* stream, bool use_multiple_streams,
                              std::shared_ptr<se::Event> definition_event) {
  XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
  CHECK(xla_tensor);
  xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
  if (use_multiple_streams) {
    xla_tensor->ResetDefinitionEvent(definition_event, stream);
  }
}

// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return Status::OK();
}

static xla::StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return Status::OK();
      }));
  return variable;
}

xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable);
  }
  return std::move(out);
}

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to
  // treat output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);
    if (type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      if (allocate_xla_tensors_) {
        Tensor* output_tensor;
        TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
        if (output_tensor->TotalBytes() > 0) {
          PopulateXlaTensor(output_tensor, &output, output_num, stream,
                            use_multiple_streams_, definition_event);
        }
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        Tensor output_tensor = GetOrCreateTensorForOutput(
            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
            compilation_result->input_mapping, resource_vars,
            ctx->expected_output_dtype(i), shape, buffer, allocator);
        ctx->set_output(i, output_tensor);
      }
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    Tensor output_tensor;
    if (allocate_xla_tensors_) {
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
      if (write.shape.num_elements() > 0) {
        PopulateXlaTensor(&output_tensor, &output, output_num, stream,
                          use_multiple_streams_, definition_event);
      }
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
    }
    output.set_buffer(se::OwningDeviceMemory(), {output_num});
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return Status::OK();
}
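
// Note on the bookkeeping above (illustrative): `i` indexes the kernel's
// retvals while `output_num` indexes elements of the XLA result tuple.
// Constant and DT_RESOURCE retvals are produced without consuming a tuple
// element, so `output_num` only advances for regular outputs and, afterwards,
// for the updated resource variables, which the compiler appends at the end
// of the result tuple.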

xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = xla::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_gpu_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}
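
// Summary of the classification above (illustrative): each kernel input maps
// to exactly one XlaCompiler::Argument kind.
//
//   resource variable, not must-be-constant  -> kResource
//   resource variable, must-be-constant      -> kConstantResource
//                                               (value copied to host)
//   plain tensor, must-be-constant           -> kConstant
//   plain tensor, non-empty                  -> kParameter
//   plain tensor, empty (0 elements)         -> kConstant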

}  // namespace tensorflow