• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_launch_util.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/jit/defs.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
25 #include "tensorflow/compiler/xla/client/client_library.h"
26 #include "tensorflow/compiler/xla/client/local_client.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/core/common_runtime/dma_helper.h"
29 #include "tensorflow/core/common_runtime/function.h"
30 #include "tensorflow/core/common_runtime/gpu_device_context.h"
31 #include "tensorflow/core/framework/allocator.h"
32 #include "tensorflow/core/framework/node_def_util.h"
33 #include "tensorflow/core/framework/op.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/resource_mgr.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/types.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/core/refcount.h"
40 #include "tensorflow/core/util/stream_executor_util.h"
41 
42 namespace tensorflow {
43 namespace {
44 using xla::ScopedShapedBuffer;
45 using xla::ShapedBuffer;
46 
47 // Fetch the platform Id from device.
XlaPlatformInfoFromDevice(DeviceBase * device_base)48 se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) {
49   auto device = static_cast<Device*>(device_base);
50   se::Platform::Id platform_id = nullptr;
51   if (device->device_type() == DEVICE_CPU) {
52     platform_id = se::host::kHostPlatformId;
53   }
54 
55   return platform_id;
56 }
57 
58 }  // anonymous namespace
59 
VariableInfo(int index,absl::string_view name,Var * var,const absl::optional<ManagedStackTrace> & definition_stack_trace)60 VariableInfo::VariableInfo(
61     int index, absl::string_view name, Var* var,
62     const absl::optional<ManagedStackTrace>& definition_stack_trace)
63     : index_(index),
64       name_(name),
65       var_(var),
66       definition_stack_trace_(definition_stack_trace) {}
67 
VariableInfo(VariableInfo && other)68 VariableInfo::VariableInfo(VariableInfo&& other)
69     : index_(other.index_),
70       var_(other.var_),
71       definition_stack_trace_(other.definition_stack_trace_),
72       lock_held_(other.lock_held_) {
73   other.index_ = -1;
74   other.var_ = nullptr;
75 }
76 
operator =(VariableInfo && other)77 VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
78   index_ = other.index_;
79   var_ = other.var_;
80   lock_held_ = other.lock_held_;
81   definition_stack_trace_ = other.definition_stack_trace_;
82 
83   other.index_ = -1;
84   other.var_ = nullptr;
85 
86   return *this;
87 }
88 
~VariableInfo()89 VariableInfo::~VariableInfo() {
90   // Release the variable's lock if we hold it. Ensures that the lock is
91   // released even on error.  It does not matter in what order we release the
92   // locks.
93   if (var()) {
94     if (lock_held()) {
95       var()->mu()->unlock();
96     }
97 
98     // Unref the variable so it can be released by ResourceManager.
99     var()->Unref();
100   }
101 }
102 
GetVariableInfosFromInputs(ResourceMgr * rm,DeviceBase * dev,absl::Span<const Tensor * const> inputs,absl::Span<const int> variable_indices,std::vector<VariableInfo> * result)103 Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
104                                   absl::Span<const Tensor* const> inputs,
105                                   absl::Span<const int> variable_indices,
106                                   std::vector<VariableInfo>* result) {
107   result->clear();
108   result->reserve(variable_indices.size());
109   for (int var_idx : variable_indices) {
110     Var* variable = nullptr;
111     ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
112     if (handle.device() != dev->attributes().name()) {
113       std::string definition_location =
114           DefinitionLocationMsg(handle.definition_stack_trace());
115       return errors::InvalidArgument("Trying to access resource ",
116                                      handle.name(), definition_location,
117                                      " located in device ", handle.device(),
118                                      " from device ", dev->attributes().name());
119     }
120     TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
121         handle.container(), handle.name(), &variable, [](Var** ptr) {
122           // This var is uninitialized for now.
123           *ptr = new Var(DT_INVALID);
124           return Status::OK();
125         }));
126     result->emplace_back(var_idx, handle.name(), variable,
127                          handle.definition_stack_trace());
128   }
129   return Status::OK();
130 }
131 
InputsFromContext(OpKernelContext * ctx)132 std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
133   std::vector<const Tensor*> inputs;
134   inputs.reserve(ctx->num_inputs());
135   for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
136     inputs.push_back(&ctx->input(input_idx));
137   }
138   return inputs;
139 }
140 
LockVariables(absl::Span<VariableInfo> variables)141 Status LockVariables(absl::Span<VariableInfo> variables) {
142   std::vector<int> lock_order(variables.size());
143   std::iota(lock_order.begin(), lock_order.end(), 0);
144 
145   // VariableInfoComparator orders all empty VariableInfo instances as
146   // equivalent so it looks like we may want to stable sort these to maintain a
147   // deterministic order between the empty VariableInfo instances.  However
148   // since we're sorting by pointer value the sort is pretty non-deterministic
149   // anyway so we don't bother using std::stable_sort for now.
150   absl::c_sort(lock_order, [&](int a, int b) {
151     if (variables[a].var() && variables[b].var()) {
152       return variables[a].var()->mu() < variables[b].var()->mu();
153     }
154 
155     // Move all the empty VariableInfo instances to the end.
156     return variables[a].var() != nullptr;
157   });
158 
159   mutex* prev = nullptr;
160   for (int i : lock_order) {
161     Var* variable = variables[i].var();
162     if (variable == nullptr) {
163       // All empty VariableInfo instances are at the end of the order
164       // so we're done.
165       break;
166     }
167     mutex* mu = variable->mu();
168     if (prev == mu) {
169       // It is an error to pass the same variable handle twice to the same XLA
170       // cluster because we would not handle variable updates correctly.  Any
171       // locks we have already acquired will be released when the VariableInfo
172       // objects are destroyed.
173       // TODO(b/128495870) Add support for passing aliased resource variables.
174       return errors::Unimplemented("Duplicate variable passed to XLA cluster");
175     }
176     VLOG(4) << "Acquiring lock for variable "
177             << reinterpret_cast<void*>(variable);
178     mu->lock();
179     variables[i].set_lock_held();
180     prev = mu;
181   }
182   VLOG(4) << "Finished acquiring variable locks.";
183   return Status::OK();
184 }
185 
SnapshotResourceVariables(OpKernelContext * ctx,absl::Span<const int> variable_indices,absl::Span<VariableInfo const> variable_infos,ResourceVarsSnapshot * result)186 Status SnapshotResourceVariables(OpKernelContext* ctx,
187                                  absl::Span<const int> variable_indices,
188                                  absl::Span<VariableInfo const> variable_infos,
189                                  ResourceVarsSnapshot* result) {
190   for (int i = 0, end = variable_indices.size(); i < end; i++) {
191     Var* var = variable_infos[i].var();
192     (*result)[variable_indices[i]] =
193         var ? absl::make_optional(*var->tensor()) : absl::nullopt;
194   }
195   return Status::OK();
196 }
197 
XlaComputationLaunchContext(xla::LocalClient * client,se::DeviceMemoryAllocator * xla_allocator,int device_ordinal,bool allocate_xla_tensors,bool use_multiple_streams)198 XlaComputationLaunchContext::XlaComputationLaunchContext(
199     xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
200     int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
201     : client_(client),
202       xla_allocator_(xla_allocator),
203       allocate_xla_tensors_(allocate_xla_tensors),
204       use_multiple_streams_(use_multiple_streams),
205       device_ordinal_(device_ordinal) {
206   if (use_multiple_streams_) {
207     CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
208                                     "be allocating XLA tensors!";
209   }
210 }
211 
212 // Fills in `execution_input` with `buffer` for `index`.
PopulateExecutionInputBuffer(xla::ExecutionInput & execution_input,xla::ShapeIndex index,se::DeviceMemoryBase buffer,bool donate_buffer,int device_ordinal,se::DeviceMemoryAllocator * allocator)213 static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
214                                          xla::ShapeIndex index,
215                                          se::DeviceMemoryBase buffer,
216                                          bool donate_buffer, int device_ordinal,
217                                          se::DeviceMemoryAllocator* allocator) {
218   xla::MaybeOwningDeviceMemory* in_buffer =
219       execution_input.MutableBuffer(index);
220   if (donate_buffer) {
221     // Here we pass ownership of the buffer to execution_input without releasing
222     // ownership from the caller of PopulateExecutionInputBuffer. If execution
223     // succeeds, we'll take back that duplicate ownership in
224     // GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
225     // release that duplicate ownership automatically.
226     *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
227   } else {
228     *in_buffer = buffer;
229   }
230 }
231 
232 StatusOr<std::vector<xla::ExecutionInput>>
PopulateInputs(OpKernelContext * ctx,const XlaCompiler::CompilationResult * compilation_result,const std::map<int,const Tensor * > & resource_vars,int missing_ctx_input_prefix,const xla::HloInputOutputAliasConfig & input_output_alias)233 XlaComputationLaunchContext::PopulateInputs(
234     OpKernelContext* ctx,
235     const XlaCompiler::CompilationResult* compilation_result,
236     const std::map<int, const Tensor*>& resource_vars,
237     int missing_ctx_input_prefix,
238     const xla::HloInputOutputAliasConfig& input_output_alias) {
239   std::vector<xla::ExecutionInput> arguments;
240   arguments.reserve(compilation_result->xla_input_shapes.size());
241 
242   xla::TransferManager* transfer_manager =
243       client_->backend().transfer_manager();
244   for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
245        ++i) {
246     int arg_num = compilation_result->input_mapping[i];
247     CHECK_GE(arg_num, missing_ctx_input_prefix);
248     const xla::Shape& shape = compilation_result->xla_input_shapes[i];
249     const xla::Shape& device_shape =
250         transfer_manager->HostShapeToDeviceShape(shape);
251 
252     bool is_resource_variable = resource_vars.count(arg_num);
253     bool is_updated_resource_variable =
254         is_resource_variable &&
255         absl::c_any_of(compilation_result->resource_updates,
256                        [&](const XlaCompiler::ResourceUpdate& update) {
257                          return update.input_index == i && update.modified;
258                        });
259 
260     const Tensor* t = is_resource_variable
261                           ? resource_vars.at(arg_num)
262                           : &(ctx->input(arg_num - missing_ctx_input_prefix));
263     CHECK(t);
264     bool donate_buffer =
265         t->RefCountIsOne() && is_updated_resource_variable &&
266         input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
267     VLOG(3) << "Processing input: " << i
268             << "; is_resource_variable=" << is_resource_variable
269             << "; is_updated_resource_variable=" << is_updated_resource_variable
270             << "; donate_buffer=" << donate_buffer;
271 
272     if (use_multiple_streams_) {
273       CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
274           << "Must have a stream available when using XLA tensors!";
275       XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
276       CHECK(xla_tensor);
277       xla_tensor->WaitForDefinitionEventOnStream(
278           ctx->op_device_context()->stream());
279     }
280 
281     arguments.emplace_back(device_shape, shape);
282     xla::ExecutionInput& execution_input = arguments.back();
283     if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
284       se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
285       PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
286                                    donate_buffer, device_ordinal_,
287                                    xla_allocator_);
288     } else {
289       XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
290       CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
291       xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
292           [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
293             PopulateExecutionInputBuffer(execution_input, index, *buffer,
294                                          donate_buffer, device_ordinal_,
295                                          xla_allocator_);
296           });
297     }
298   }
299   return std::move(arguments);
300 }
301 
302 // Construct the tensor for the given type and buffer.
MakeTensor(DataType dtype,const TensorShape & shape,se::DeviceMemoryBase buffer,Allocator * allocator)303 static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
304                          se::DeviceMemoryBase buffer, Allocator* allocator) {
305   size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
306   auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
307                                             buffer.size(), allocator);
308   Tensor t(dtype, shape, tensor_buffer);
309   tensor_buffer->Unref();
310   return t;
311 }
312 
313 // Get aliased tensor from output, or make a new one for the corresponding
314 // output operation. Transfers ownership of the buffer from output to the
315 // returned tensor.
GetOrCreateTensorForOutput(xla::ScopedShapedBuffer & output,int output_num,OpKernelContext * ctx,int missing_ctx_input_prefix,const xla::HloInputOutputAliasConfig & input_output_alias,absl::Span<const int> input_mapping,const std::map<int,const Tensor * > & resource_vars_snapshots,DataType output_dtype,const TensorShape & output_shape,Allocator * output_allocator,bool allocate_xla_tensors,se::Stream * stream,bool use_multiple_streams,std::shared_ptr<se::Event> definition_event)316 static StatusOr<Tensor> GetOrCreateTensorForOutput(
317     xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
318     int missing_ctx_input_prefix,
319     const xla::HloInputOutputAliasConfig& input_output_alias,
320     absl::Span<const int> input_mapping,
321     const std::map<int, const Tensor*>& resource_vars_snapshots,
322     DataType output_dtype, const TensorShape& output_shape,
323     Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
324     bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
325   xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
326                                      ? xla::ShapeIndex({output_num})
327                                      : xla::ShapeIndex({});
328   CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
329   if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
330           input_output_alias.GetAliasedParameter(output_index)) {
331     VLOG(3) << "Found alias: " << alias->ToString();
332     int tf_param =
333         input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
334     const Tensor input_tensor =
335         ctx->input(tf_param).dtype() != DT_RESOURCE
336             ? ctx->input(tf_param)
337             : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
338     se::DeviceMemoryBase input_buffer =
339         XlaTensor::DeviceMemoryFromTensor(input_tensor);
340     se::DeviceMemoryBase output_buffer = output.buffer({output_num});
341     if (input_buffer.opaque() == output_buffer.opaque()) {
342       // In the case of a donated buffer, both input_tensor and output think
343       // they have ownership of the buffer (see comment in
344       // PopulateExecutionInputBuffer). Release ownership from output to avoid
345       // double free.
346       output.set_buffer(se::OwningDeviceMemory(), {output_num});
347       return input_tensor;
348     }
349   }
350 
351   if (allocate_xla_tensors) {
352     Tensor output_tensor;
353     TF_RETURN_IF_ERROR(
354         ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
355     if (output_tensor.TotalBytes() > 0) {
356       XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
357       TF_RET_CHECK(xla_tensor);
358       xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
359       if (use_multiple_streams) {
360         xla_tensor->ResetDefinitionEvent(definition_event, stream);
361       }
362     }
363     return output_tensor;
364   }
365 
366   se::DeviceMemoryBase output_buffer = output.buffer({output_num});
367   Tensor output_tensor =
368       MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
369   output.set_buffer(se::OwningDeviceMemory(), {output_num});
370   return output_tensor;
371 }
372 
373 // Sets output `output_num` for `ctx` provided it is known at a compile time.
SetOutputForConstant(OpKernelContext * ctx,se::Stream * stream,const XlaCompiler::CompilationResult * compilation_result,int output_num)374 static Status SetOutputForConstant(
375     OpKernelContext* ctx, se::Stream* stream,
376     const XlaCompiler::CompilationResult* compilation_result, int output_num) {
377   CHECK(compilation_result->outputs[output_num].is_constant);
378   const Tensor& const_tensor =
379       compilation_result->outputs[output_num].constant_value;
380   Tensor* output_tensor;
381   if (stream && const_tensor.TotalBytes() > 0) {
382     // Copy host -> device. (Empty tensors don't have backing buffers.)
383     // Manually allocate memory using an XlaTensorBuffer so we can allocate
384     // as much memory as the device requires (as given by
385     // GetByteSizeRequirement). This avoids XlaTransferManager having to
386     // reallocate the device buffer later.
387     VLOG(1) << "Constant output tensor on device";
388 
389     TF_RETURN_IF_ERROR(
390         ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
391     Device* device = dynamic_cast<Device*>(ctx->device());
392     if (device == nullptr) {
393       return errors::Internal("DeviceBase was not a Device.");
394     }
395     ctx->op_device_context()->CopyCPUTensorToDevice(
396         &const_tensor, device, output_tensor,
397         [&](Status status) { TF_CHECK_OK(status); });
398 
399     if (device->device_type() == DEVICE_GPU) {
400       // The GPUDeviceContext enqueues the host->device transfer in a
401       // separate stream from the main compute stream. We must ensure the
402       // compute stream is synchronized with the host->device transfer
403       // stream now otherwise we will create a race condition.
404       auto* gpu_device_context =
405           static_cast<GPUDeviceContext*>(ctx->op_device_context());
406       gpu_device_context->stream()->ThenWaitFor(
407           gpu_device_context->host_to_device_stream());
408     }
409   } else {
410     // No copy required.
411     ctx->set_output(output_num, const_tensor);
412     output_tensor = ctx->mutable_output(output_num);
413   }
414   return Status::OK();
415 }
416 
GetOrCreateResourceVar(OpKernelContext * ctx,const ResourceHandle & handle,const XlaCompiler::ResourceUpdate & write)417 static StatusOr<Var*> GetOrCreateResourceVar(
418     OpKernelContext* ctx, const ResourceHandle& handle,
419     const XlaCompiler::ResourceUpdate& write) {
420   Var* variable = nullptr;
421   TF_RETURN_IF_ERROR(
422       LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
423         *ptr = new Var(write.type);
424         return Status::OK();
425       }));
426   return variable;
427 }
428 
GatherVariableInfo(OpKernelContext * ctx,const XlaCompiler::CompilationResult & compilation_result,int missing_ctx_input_prefix)429 StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
430     OpKernelContext* ctx,
431     const XlaCompiler::CompilationResult& compilation_result,
432     int missing_ctx_input_prefix) {
433   std::vector<VariableInfo> out;
434   out.reserve(compilation_result.resource_updates.size());
435   for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
436     const XlaCompiler::ResourceUpdate& write =
437         compilation_result.resource_updates[i];
438     int actual_input_index = write.input_index - missing_ctx_input_prefix;
439     if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
440       return errors::Internal("Invalid input index for variable write.");
441     }
442 
443     const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
444     TF_ASSIGN_OR_RETURN(Var * variable,
445                         GetOrCreateResourceVar(ctx, handle, write));
446     out.emplace_back(actual_input_index, handle.name(), variable,
447                      handle.definition_stack_trace());
448   }
449   return std::move(out);
450 }
451 
PopulateOutputs(OpKernelContext * ctx,const XlaCompiler::CompilationResult * compilation_result,ScopedShapedBuffer output,int missing_ctx_input_prefix,absl::Span<VariableInfo> variable_infos,const xla::HloInputOutputAliasConfig & input_output_alias,const std::map<int,const Tensor * > & resource_vars)452 Status XlaComputationLaunchContext::PopulateOutputs(
453     OpKernelContext* ctx,
454     const XlaCompiler::CompilationResult* compilation_result,
455     ScopedShapedBuffer output, int missing_ctx_input_prefix,
456     absl::Span<VariableInfo> variable_infos,
457     const xla::HloInputOutputAliasConfig& input_output_alias,
458     const std::map<int, const Tensor*>& resource_vars) {
459   se::Stream* stream =
460       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
461   Allocator* allocator = ctx->device()->GetAllocator({});
462 
463   // Computation output should always be a tuple.
464   VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
465   VLOG(2) << "Result tuple shape (on device): "
466           << output.on_device_shape().DebugString();
467   CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
468 
469   // If the on-host-shape isn't a tuple, create a new single-element tuple
470   // buffer with a nullptr root index table. This allows the code below to treat
471   // output as a tuple unconditionally.
472   if (!output.on_host_shape().IsTuple()) {
473     ShapedBuffer nontuple_buffer = output.release();
474     ShapedBuffer buffer(
475         xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
476         xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
477         output.device_ordinal());
478     buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
479                                      /*source_base_index=*/{},
480                                      /*target_base_index=*/{0});
481     output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
482   }
483 
484   std::shared_ptr<se::Event> definition_event;
485   if (use_multiple_streams_) {
486     definition_event = std::make_shared<se::Event>(stream->parent());
487     if (!definition_event->Init()) {
488       return errors::Internal("Failed to initialize tensor definition event.");
489     }
490     stream->ThenRecordEvent(definition_event.get());
491   }
492 
493   for (const XlaOutputDescription& descr : compilation_result->outputs) {
494     if (descr.type == DT_VARIANT) {
495       return errors::Unimplemented(
496           "Support for TensorList crossing the XLA/TF boundary "
497           "is not implemented");
498     }
499   }
500 
501   std::vector<TensorShape> output_tensor_shapes;
502   output_tensor_shapes.reserve(ctx->num_outputs());
503   if (output.on_host_shape().is_dynamic()) {
504     const se::Platform* platform = nullptr;
505     if (stream != nullptr) {
506       platform = stream->parent()->platform();
507     } else {
508       // Stream is not set for the host platform.
509       TF_ASSIGN_OR_RETURN(platform,
510                           se::MultiPlatformManager::PlatformWithId(
511                               XlaPlatformInfoFromDevice(ctx->device())));
512     }
513     TF_ASSIGN_OR_RETURN(auto transfer_manager,
514                         xla::TransferManager::GetForPlatform(platform));
515 
516     xla::Shape output_device_shape = output.on_device_shape();
517     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
518         stream, &output, &output_device_shape));
519 
520     output.set_shapes(output_device_shape, output_device_shape);
521     for (int i = 0; i < ctx->num_outputs(); ++i) {
522       const xla::Shape& subshape =
523           xla::ShapeUtil::GetSubshape(output_device_shape, {i});
524       TensorShape shape;
525       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
526       output_tensor_shapes.push_back(shape);
527     }
528   } else {
529     for (int i = 0; i < ctx->num_outputs(); ++i) {
530       output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
531     }
532   }
533 
534   // Copy XLA results to the OpOutputList.
535   int output_num = 0;
536   for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
537     const TensorShape& shape = output_tensor_shapes[i];
538     const DataType& type = compilation_result->outputs[i].type;
539     VLOG(2) << "Populating output for retval " << i << " shape "
540             << shape.DebugString() << " type " << DataTypeString(type);
541 
542     if (compilation_result->outputs[i].is_constant) {
543       TF_RETURN_IF_ERROR(
544           SetOutputForConstant(ctx, stream, compilation_result, i));
545     } else if (type == DT_RESOURCE) {
546       int input_index =
547           compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
548       TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
549           << "Invalid input for outputs " << i << ": " << input_index;
550       ctx->set_output(i, ctx->input(input_index));
551     } else {
552       TF_ASSIGN_OR_RETURN(
553           Tensor output_tensor,
554           GetOrCreateTensorForOutput(
555               output, output_num, ctx, missing_ctx_input_prefix,
556               input_output_alias, compilation_result->input_mapping,
557               resource_vars, ctx->expected_output_dtype(i), shape, allocator,
558               allocate_xla_tensors_, stream, use_multiple_streams_,
559               definition_event));
560       ctx->set_output(i, output_tensor);
561       ++output_num;
562     }
563   }
564 
565   // input_index -> index into variable_infos.
566   absl::flat_hash_map<int, int> variable_info_lookup;
567   for (int i = 0; i < variable_infos.size(); i++) {
568     variable_info_lookup.emplace(variable_infos[i].index(), i);
569   }
570 
571   // Apply variable updates, if any.
572   for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
573        ++i) {
574     const XlaCompiler::ResourceUpdate& write =
575         compilation_result->resource_updates[i];
576     int actual_input_index = write.input_index - missing_ctx_input_prefix;
577     CHECK_GE(actual_input_index, 0);
578     CHECK_LT(actual_input_index, ctx->num_inputs());
579     Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
580     CHECK(var);
581 
582     VLOG(2) << "Updating variable #" << i
583             << " at input index: " << actual_input_index << " with shape "
584             << write.shape.DebugString() << "; variable tensor has shape: "
585             << var->tensor()->shape().DebugString();
586 
587     if (var->is_initialized && var->tensor()->dtype() != write.type) {
588       return errors::Internal("Mismatched type in variable write");
589     }
590 
591     TF_ASSIGN_OR_RETURN(
592         Tensor output_tensor,
593         GetOrCreateTensorForOutput(output, output_num, ctx,
594                                    missing_ctx_input_prefix, input_output_alias,
595                                    compilation_result->input_mapping,
596                                    resource_vars, write.type, write.shape,
597                                    allocator, allocate_xla_tensors_, stream,
598                                    use_multiple_streams_, definition_event));
599     var->is_initialized |= write.modified;
600     *var->tensor() = output_tensor;
601     ++output_num;
602   }
603   return Status::OK();
604 }
605 
606 StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,absl::Span<const Tensor * const> inputs,absl::Span<VariableInfo const> variable_args,Device * device)607 XlaComputationLaunchContext::BuildXlaCompilerArguments(
608     absl::Span<int const> must_be_constant_idxs,
609     absl::Span<const Tensor* const> inputs,
610     absl::Span<VariableInfo const> variable_args, Device* device) {
611   CHECK(absl::c_is_sorted(must_be_constant_idxs));
612   VLOG(2) << "Must be const args: {"
613           << absl::StrJoin(must_be_constant_idxs, ",") << "} out of "
614           << inputs.size() << " args";
615   std::vector<XlaCompiler::Argument> out;
616   out.resize(inputs.size());
617 
618   // TODO(cheshire): Avoid duplication with framework/op_kernel.h
619   DeviceContext* device_context = nullptr;
620   TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
621   bool using_default_context = false;
622   auto cleanup = xla::MakeCleanup([&] {
623     if (device_context != nullptr && !using_default_context) {
624       device_context->Unref();
625     }
626   });
627   if (device_context == nullptr) {
628     using_default_context = true;
629     auto* dev_info = device->tensorflow_gpu_device_info();
630     if (dev_info) device_context = dev_info->default_context;
631   }
632 
633   absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
634   for (const VariableInfo& info : variable_args) {
635     CHECK(!info.var() || info.lock_held())
636         << "Need to hold the lock on resource variables "
637            "before calling BuildXlaCompilerArguments";
638     variable_info_lookup.emplace(info.index(), &info);
639   }
640 
641   for (int64_t input_num = 0; input_num < inputs.size(); ++input_num) {
642     const Tensor* input = inputs[input_num];
643 
644     XlaCompiler::Argument& arg = out[input_num];
645     if (variable_info_lookup.count(input_num)) {
646       // Handles resource variables.
647       TF_RET_CHECK(input->dtype() == DT_RESOURCE);
648       const VariableInfo& variable = *variable_info_lookup[input_num];
649       arg.name = std::string(variable.name());
650       arg.kind = XlaCompiler::Argument::kResource;
651       arg.resource_kind = XlaResource::kVariable;
652       arg.definition_stack_trace = variable.definition_stack_trace();
653       if (variable.var() && variable.var()->is_initialized) {
654         const Tensor* value = variable.var()->tensor();
655         arg.type = value->dtype();
656         arg.shape = value->shape();
657         arg.initialized = true;
658       } else {
659         // The values of uninitialized variables are not passed as inputs, since
660         // they are meaningless. However, it is legal to assign to a resource
661         // variable for the first time inside the XLA computation, so we do
662         // permit uninitialized variables.
663         arg.initialized = false;
664         arg.type = DT_INVALID;
665         arg.shape = TensorShape();
666       }
667 
668       if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
669         TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
670         const Tensor* value = variable.var()->tensor();
671         Tensor value_on_host(value->dtype(), value->shape());
672         if (!device_context) {
673           value_on_host = *value;
674         } else {
675           TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
676               value, "", device, &value_on_host));
677         }
678         arg.kind = XlaCompiler::Argument::kConstantResource;
679         arg.constant_value = value_on_host;
680       }
681     } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
682       arg.kind = XlaCompiler::Argument::kConstant;
683       arg.type = input->dtype();
684       arg.shape = input->shape();
685       arg.constant_value = *input;
686     } else {
687       // Normal inputs.
688       TF_RET_CHECK(input->dtype() != DT_RESOURCE);
689       if (input->NumElements() > 0) {
690         arg.kind = XlaCompiler::Argument::kParameter;
691       } else {
692         arg.kind = XlaCompiler::Argument::kConstant;
693         arg.constant_value = *input;
694       }
695       arg.type = input->dtype();
696       arg.shape = input->shape();
697     }
698   }
699 
700   return out;
701 }
702 
703 }  // namespace tensorflow
704