/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

// Fetches the platform ID for the given device.
se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) {
  auto device = static_cast<Device*>(device_base);
  se::Platform::Id platform_id = nullptr;
  if (device->device_type() == DEVICE_CPU) {
    platform_id = se::host::kHostPlatformId;
  }

  return platform_id;
}

}  // anonymous namespace

VariableInfo::VariableInfo(
    int index, absl::string_view name, Var* var,
    const absl::optional<ManagedStackTrace>& definition_stack_trace)
    : index_(index),
      name_(name),
      var_(var),
      definition_stack_trace_(definition_stack_trace) {}

VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_),
      var_(other.var_),
      definition_stack_trace_(other.definition_stack_trace_),
      lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;
  definition_stack_trace_ = other.definition_stack_trace_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error. It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

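// Illustrative usage of the helpers below (a sketch, not copied verbatim from
// any kernel): a launch op typically gathers the resource-variable inputs,
// locks them, and then snapshots their values. Here `rm`, `device`, `inputs`,
// `variable_indices`, and `ctx` are placeholders supplied by the caller:
//
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       rm, device, inputs, variable_indices, &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//   ResourceVarsSnapshot snapshot;
//   TF_RETURN_IF_ERROR(SnapshotResourceVariables(
//       ctx, variable_indices, variable_infos, &snapshot));
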
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location =
          DefinitionLocationMsg(handle.definition_stack_trace());
      return errors::InvalidArgument("Trying to access resource ",
                                     handle.name(), definition_location,
                                     " located in device ", handle.device(),
                                     " from device ", dev->attributes().name());
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return Status::OK();
        }));
    result->emplace_back(var_idx, handle.name(), variable,
                         handle.definition_stack_trace());
  }
  return Status::OK();
}

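// Returns pointers to all tensors that are inputs to `ctx`, in input order.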
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

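  // Sort the variables by the address of their mutex and acquire the locks in
  // that order, so that every caller locks any given set of variables in the
  // same global order. This is the standard way to avoid deadlock when
  // several clusters sharing variables run concurrently.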
  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent, so it might seem we should use a stable sort to keep a
  // deterministic order among the empty VariableInfo instances. However,
  // since we're sorting by pointer value the sort is non-deterministic
  // anyway, so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly. Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}

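// Records the current tensor (or absl::nullopt for uninitialized variables)
// of each requested resource variable into `result`, keyed by input index.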
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : absl::nullopt;
  }
  return Status::OK();
}

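// An XlaComputationLaunchContext is used in two phases around an XLA
// execution: PopulateInputs() builds the xla::ExecutionInput arguments from
// the kernel's inputs (and snapshotted resource variables), and
// PopulateOutputs() maps the resulting ScopedShapedBuffer back to TensorFlow
// outputs and variable updates.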
XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    // Here we pass ownership of the buffer to execution_input without releasing
    // ownership from the caller of PopulateExecutionInputBuffer. If execution
    // succeeds, we'll take back that duplicate ownership in
    // GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will
    // release that duplicate ownership automatically.
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
  } else {
    *in_buffer = buffer;
  }
}

StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& device_shape =
        transfer_manager->HostShapeToDeviceShape(shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         return update.input_index == i && update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
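    // Donate the input buffer to XLA only when it is safe to do so: the
    // tensor's buffer is not shared (RefCountIsOne), this input is a resource
    // variable that the computation actually updates, and the compiled
    // program declares an input/output alias for this parameter.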
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, shape);
    xla::ExecutionInput& execution_input = arguments.back();
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                   donate_buffer, device_ordinal_,
                                   xla_allocator_);
    } else {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            PopulateExecutionInputBuffer(execution_input, index, *buffer,
                                         donate_buffer, device_ordinal_,
                                         xla_allocator_);
          });
    }
  }
  return std::move(arguments);
}

// Constructs a tensor of the given type and shape that wraps the given device
// buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}

// Returns the input tensor that aliases the given output, or makes a new
// tensor for the corresponding output operation. Transfers ownership of the
// buffer from `output` to the returned tensor.
static StatusOr<Tensor> GetOrCreateTensorForOutput(
    xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream,
    bool use_multiple_streams, std::shared_ptr<se::Event> definition_event) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});
  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    se::DeviceMemoryBase input_buffer =
        XlaTensor::DeviceMemoryFromTensor(input_tensor);
    se::DeviceMemoryBase output_buffer = output.buffer({output_num});
    if (input_buffer.opaque() == output_buffer.opaque()) {
      // In the case of a donated buffer, both input_tensor and output think
      // they have ownership of the buffer (see comment in
      // PopulateExecutionInputBuffer). Release ownership from output to avoid
      // double free.
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      return input_tensor;
    }
  }

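  // No usable alias: either wrap the output buffer in an XlaTensor (when the
  // device expects XLA tensors) or hand the raw device memory to a regular
  // TensorFlow tensor via MakeTensor below.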
  if (allocate_xla_tensors) {
    Tensor output_tensor;
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(output_dtype, output_shape, &output_tensor));
    if (output_tensor.TotalBytes() > 0) {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
      TF_RET_CHECK(xla_tensor);
      xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
      if (use_multiple_streams) {
        xla_tensor->ResetDefinitionEvent(definition_event, stream);
      }
    }
    return output_tensor;
  }

  se::DeviceMemoryBase output_buffer = output.buffer({output_num});
  Tensor output_tensor =
      MakeTensor(output_dtype, output_shape, output_buffer, output_allocator);
  output.set_buffer(se::OwningDeviceMemory(), {output_num});
  return output_tensor;
}

// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now, otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return Status::OK();
}

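// Looks up the resource variable backing `handle` in `ctx`, creating an empty
// Var of the written type if it does not exist yet.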
static StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return Status::OK();
      }));
  return variable;
}

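// Builds a VariableInfo (with its backing Var looked up or created) for every
// resource variable that the compiled computation writes. Input indices are
// shifted by `missing_ctx_input_prefix` to account for inputs that are not
// present in `ctx`.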
StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable,
                     handle.definition_stack_trace());
  }
  return std::move(out);
}

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to
  // treat output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  for (const XlaOutputDescription& descr : compilation_result->outputs) {
    if (descr.type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
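  // When the result shape is dynamic, the actual output shapes are only known
  // after execution, so read them back from the device via the
  // TransferManager and convert each subshape to a TensorShape. Otherwise use
  // the compile-time shapes from the compilation result directly.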
  if (output.on_host_shape().is_dynamic()) {
    const se::Platform* platform = nullptr;
    if (stream != nullptr) {
      platform = stream->parent()->platform();
    } else {
      // Stream is not set for the host platform.
      TF_ASSIGN_OR_RETURN(platform,
                          se::MultiPlatformManager::PlatformWithId(
                              XlaPlatformInfoFromDevice(ctx->device())));
    }
    TF_ASSIGN_OR_RETURN(auto transfer_manager,
                        xla::TransferManager::GetForPlatform(platform));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      TF_ASSIGN_OR_RETURN(
          Tensor output_tensor,
          GetOrCreateTensorForOutput(
              output, output_num, ctx, missing_ctx_input_prefix,
              input_output_alias, compilation_result->input_mapping,
              resource_vars, ctx->expected_output_dtype(i), shape, allocator,
              allocate_xla_tensors_, stream, use_multiple_streams_,
              definition_event));
      ctx->set_output(i, output_tensor);
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    TF_ASSIGN_OR_RETURN(
        Tensor output_tensor,
        GetOrCreateTensorForOutput(output, output_num, ctx,
                                   missing_ctx_input_prefix, input_output_alias,
                                   compilation_result->input_mapping,
                                   resource_vars, write.type, write.shape,
                                   allocator, allocate_xla_tensors_, stream,
                                   use_multiple_streams_, definition_event));
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return Status::OK();
}

StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  VLOG(2) << "Must be const args: {"
          << absl::StrJoin(must_be_constant_idxs, ",") << "} out of "
          << inputs.size() << " args";
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = xla::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_gpu_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

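  // Classify each input: resource variables become kResource (or
  // kConstantResource when their value must be known at compile time),
  // compile-time constants become kConstant, and everything else becomes a
  // regular kParameter; zero-element tensors are passed as constants.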
  for (int64_t input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      arg.definition_stack_trace = variable.definition_stack_trace();
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}

}  // namespace tensorflow