/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>
#include <numeric>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;

}  // anonymous namespace

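// Note on ownership: a VariableInfo holds one reference on its Var (taken by
// whoever constructs it) and optionally records that the Var's mutex is held.
// Moving a VariableInfo transfers both; the destructor unlocks the mutex if
// held and drops the reference.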
VariableInfo::VariableInfo(int index, absl::string_view name, Var* var)
    : index_(index), name_(name), var_(var) {}

VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error. It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

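// Builds a VariableInfo for each resource-variable input listed in
// `variable_indices`, verifying that each handle lives on device `dev` and
// looking the variable up in `rm` (creating an uninitialized one on first
// use).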
Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev,
                                  absl::Span<const Tensor* const> inputs,
                                  absl::Span<const int> variable_indices,
                                  std::vector<VariableInfo>* result) {
  result->clear();
  result->reserve(variable_indices.size());
  for (int var_idx : variable_indices) {
    Var* variable = nullptr;
    ResourceHandle handle = inputs[var_idx]->flat<ResourceHandle>()(0);
    if (handle.device() != dev->attributes().name()) {
      std::string definition_location = [&]() -> std::string {
        if (handle.definition_stack_trace()) {
          std::vector<StackFrame> stack_frames =
              handle.definition_stack_trace()->ToStackFrames(
                  {}, IsInternalFrameForFilename,
                  /*reverse_traversal=*/true,
                  /*limit=*/1);
          if (!stack_frames.empty()) {
            const StackFrame& last_frame = stack_frames[0];
            return absl::StrCat(" (defined @ ", last_frame.file_name, ":",
                                last_frame.line_number, ")");
          }
        }
        return "";
      }();
      return errors::InvalidArgument("Trying to access resource ",
                                     handle.name(), definition_location,
                                     " located in device ", handle.device(),
                                     " from device ", dev->attributes().name());
    }
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<Var>(
        handle.container(), handle.name(), &variable, [](Var** ptr) {
          // This var is uninitialized for now.
          *ptr = new Var(DT_INVALID);
          return Status::OK();
        }));
    result->emplace_back(var_idx, handle.name(), variable);
  }
  return Status::OK();
}

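// Returns pointers to all input tensors of `ctx`, in input order.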
std::vector<const Tensor*> InputsFromContext(OpKernelContext* ctx) {
  std::vector<const Tensor*> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int input_idx = 0; input_idx < ctx->num_inputs(); input_idx++) {
    inputs.push_back(&ctx->input(input_idx));
  }
  return inputs;
}

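// Acquires the mutex of every non-empty VariableInfo in `variables`, ordering
// the acquisitions by mutex address so that two kernels locking an overlapping
// set of variables cannot deadlock. A typical call sequence in a launch kernel
// looks roughly like the following (sketch only; `variable_indices` stands for
// whichever inputs the kernel treats as resource variables):
//
//   std::vector<VariableInfo> variable_infos;
//   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
//       ctx->resource_manager(), ctx->device(), InputsFromContext(ctx),
//       variable_indices, &variable_infos));
//   TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
//   // ... read or update the variables; the locks are released when
//   // `variable_infos` goes out of scope.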
Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent so it looks like we may want to stable sort these to maintain a
  // deterministic order between the empty VariableInfo instances. However
  // since we're sorting by pointer value the sort is pretty non-deterministic
  // anyway so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly. Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      // TODO(b/128495870) Add support for passing aliased resource variables.
      return errors::Unimplemented("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}

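// Records the current value of each variable in `variable_infos` into
// `result`, keyed by its argument index; entries without a Var are recorded as
// absl::nullopt. Callers are expected to hold the variables' locks (see
// LockVariables) while taking the snapshot.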
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 absl::Span<VariableInfo const> variable_infos,
                                 ResourceVarsSnapshot* result) {
  for (int i = 0, end = variable_indices.size(); i < end; i++) {
    Var* var = variable_infos[i].var();
    (*result)[variable_indices[i]] =
        var ? absl::make_optional(*var->tensor()) : absl::nullopt;
  }
  return Status::OK();
}

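// `device_ordinal` identifies the device the computation runs on. When
// `use_multiple_streams` is true the launch context requires XLA tensors so
// that definition events can be used to synchronize work across streams.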
XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
    int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams),
      device_ordinal_(device_ordinal) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

// Fills in `execution_input` at `index` with `buffer`. If `donate_buffer` is
// true, ownership of the buffer is transferred to the execution input and
// `buffer` is cleared; otherwise the buffer is passed by reference only.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
                                         xla::ShapeIndex index,
                                         se::DeviceMemoryBase& buffer,
                                         bool donate_buffer, int device_ordinal,
                                         se::DeviceMemoryAllocator* allocator) {
  xla::MaybeOwningDeviceMemory* in_buffer =
      execution_input.MutableBuffer(index);
  if (donate_buffer) {
    *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
    buffer = se::DeviceMemoryBase();
  } else {
    *in_buffer = buffer;
  }
}

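// Builds the xla::ExecutionInput list for the compiled computation from the
// kernel's inputs and the snapshotted resource variables in `resource_vars`.
// A buffer is donated to XLA (instead of passed by reference) only when the
// tensor's reference count is one, the corresponding resource variable is
// updated by the computation, and the parameter is aliased with an output.
// Together with PopulateOutputs below, a launch kernel uses this roughly as
// follows (sketch only; `launch_context`, `compilation_result`,
// `execution_output`, etc. are placeholders for values produced elsewhere):
//
//   TF_ASSIGN_OR_RETURN(std::vector<xla::ExecutionInput> execution_inputs,
//                       launch_context.PopulateInputs(
//                           ctx, compilation_result, resource_vars,
//                           /*missing_ctx_input_prefix=*/0,
//                           input_output_alias));
//   // ... run the xla::LocalExecutable to obtain `execution_output` ...
//   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
//       ctx, compilation_result, execution_output.ConsumeResult(),
//       /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
//       input_output_alias, resource_vars));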
xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    const std::map<int, const Tensor*>& resource_vars,
    int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  std::vector<xla::ExecutionInput> arguments;
  arguments.reserve(compilation_result->xla_input_shapes.size());

  xla::TransferManager* transfer_manager =
      client_->backend().transfer_manager();
  for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end;
       ++i) {
    int arg_num = compilation_result->input_mapping[i];
    CHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = compilation_result->xla_input_shapes[i];
    const xla::Shape& device_shape =
        transfer_manager->HostShapeToDeviceShape(shape);

    bool is_resource_variable = resource_vars.count(arg_num);
    bool is_updated_resource_variable =
        is_resource_variable &&
        absl::c_any_of(compilation_result->resource_updates,
                       [&](const XlaCompiler::ResourceUpdate& update) {
                         return update.input_index == i && update.modified;
                       });

    const Tensor* t = is_resource_variable
                          ? resource_vars.at(arg_num)
                          : &(ctx->input(arg_num - missing_ctx_input_prefix));
    CHECK(t);
    bool donate_buffer =
        t->RefCountIsOne() && is_updated_resource_variable &&
        input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
    VLOG(3) << "Processing input: " << i
            << "; is_resource_variable=" << is_resource_variable
            << "; is_updated_resource_variable=" << is_updated_resource_variable
            << "; donate_buffer=" << donate_buffer;

    if (use_multiple_streams_) {
      CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
          << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(
          ctx->op_device_context()->stream());
    }

    arguments.emplace_back(device_shape, shape);
    xla::ExecutionInput& execution_input = arguments.back();
    if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
                                   donate_buffer, device_ordinal_,
                                   xla_allocator_);
    } else {
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            PopulateExecutionInputBuffer(execution_input, index, *buffer,
                                         donate_buffer, device_ordinal_,
                                         xla_allocator_);
          });
    }
  }
  return std::move(arguments);
}

// Constructs a Tensor of the given type and shape backed by `buffer`.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
                         se::DeviceMemoryBase buffer, Allocator* allocator) {
  size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
  auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
                                            buffer.size(), allocator);
  Tensor t(dtype, shape, tensor_buffer);
  tensor_buffer->Unref();
  return t;
}

// Returns the input tensor that output `output_num` aliases when the output
// buffer is shared with that input; otherwise wraps `output_buffer` in a
// freshly constructed Tensor.
static Tensor GetOrCreateTensorForOutput(
    int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    absl::Span<const int> input_mapping,
    const std::map<int, const Tensor*>& resource_vars_snapshots,
    DataType output_dtype, const TensorShape& output_shape,
    se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
  xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
                                     ? xla::ShapeIndex({output_num})
                                     : xla::ShapeIndex({});

  CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
  if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
          input_output_alias.GetAliasedParameter(output_index)) {
    VLOG(3) << "Found alias: " << alias->ToString();
    int tf_param =
        input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
    const Tensor input_tensor =
        ctx->input(tf_param).dtype() != DT_RESOURCE
            ? ctx->input(tf_param)
            : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
    if (output_buffer.opaque() == input_tensor.data()) {
      return input_tensor;
    }
  }
  return MakeTensor(output_dtype, output_shape, output_buffer,
                    output_allocator);
}

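// Moves the sub-buffer of `output` at index {output_num} into the XlaTensor
// backing `output_tensor`, recording `definition_event` on it when multiple
// streams are in use.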
static void PopulateXlaTensor(Tensor* output_tensor,
                              xla::ScopedShapedBuffer* output, int output_num,
                              se::Stream* stream, bool use_multiple_streams,
                              std::shared_ptr<se::Event> definition_event) {
  XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
  CHECK(xla_tensor);
  xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num}));
  if (use_multiple_streams) {
    xla_tensor->ResetDefinitionEvent(definition_event, stream);
  }
}

// Sets output `output_num` for `ctx`, whose value is known at compile time.
static Status SetOutputForConstant(
    OpKernelContext* ctx, se::Stream* stream,
    const XlaCompiler::CompilationResult* compilation_result, int output_num) {
  CHECK(compilation_result->outputs[output_num].is_constant);
  const Tensor& const_tensor =
      compilation_result->outputs[output_num].constant_value;
  Tensor* output_tensor;
  if (stream && const_tensor.TotalBytes() > 0) {
    // Copy host -> device. (Empty tensors don't have backing buffers.)
    // Manually allocate memory using an XlaTensorBuffer so we can allocate
    // as much memory as the device requires (as given by
    // GetByteSizeRequirement). This avoids XlaTransferManager having to
    // reallocate the device buffer later.
    VLOG(1) << "Constant output tensor on device";

    TF_RETURN_IF_ERROR(
        ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
    Device* device = dynamic_cast<Device*>(ctx->device());
    if (device == nullptr) {
      return errors::Internal("DeviceBase was not a Device.");
    }
    ctx->op_device_context()->CopyCPUTensorToDevice(
        &const_tensor, device, output_tensor,
        [&](Status status) { TF_CHECK_OK(status); });

    if (device->device_type() == DEVICE_GPU) {
      // The GPUDeviceContext enqueues the host->device transfer in a
      // separate stream from the main compute stream. We must ensure the
      // compute stream is synchronized with the host->device transfer
      // stream now otherwise we will create a race condition.
      auto* gpu_device_context =
          static_cast<GPUDeviceContext*>(ctx->op_device_context());
      gpu_device_context->stream()->ThenWaitFor(
          gpu_device_context->host_to_device_stream());
    }
  } else {
    // No copy required.
    ctx->set_output(output_num, const_tensor);
    output_tensor = ctx->mutable_output(output_num);
  }
  return Status::OK();
}

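// Looks up the resource variable referenced by `handle`, creating it with the
// type of the pending write if it does not exist yet.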
static xla::StatusOr<Var*> GetOrCreateResourceVar(
    OpKernelContext* ctx, const ResourceHandle& handle,
    const XlaCompiler::ResourceUpdate& write) {
  Var* variable = nullptr;
  TF_RETURN_IF_ERROR(
      LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
        *ptr = new Var(write.type);
        return Status::OK();
      }));
  return variable;
}

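// Returns a VariableInfo (with a reference held, but not locked) for every
// resource variable the compiled computation writes to.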
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult& compilation_result,
    int missing_ctx_input_prefix) {
  std::vector<VariableInfo> out;
  out.reserve(compilation_result.resource_updates.size());
  for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result.resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
    TF_ASSIGN_OR_RETURN(Var * variable,
                        GetOrCreateResourceVar(ctx, handle, write));
    out.emplace_back(actual_input_index, handle.name(), variable);
  }
  return std::move(out);
}

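// Copies the XLA result buffers into the kernel's outputs and applies the
// resource-variable updates recorded by the compiler. Ownership of each
// consumed sub-buffer is moved out of `output` so the buffer is not freed when
// the ScopedShapedBuffer is destroyed.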
Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx,
    const XlaCompiler::CompilationResult* compilation_result,
    ScopedShapedBuffer output, int missing_ctx_input_prefix,
    absl::Span<VariableInfo> variable_infos,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    const std::map<int, const Tensor*>& resource_vars) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  Allocator* allocator = ctx->device()->GetAllocator({});

  // Computation output should always be a tuple.
  VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
  VLOG(2) << "Result tuple shape (on device): "
          << output.on_device_shape().DebugString();
  CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to
  // treat output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(ctx->num_outputs());
  if (output.on_host_shape().is_dynamic()) {
    TF_ASSIGN_OR_RETURN(
        auto transfer_manager,
        xla::TransferManager::GetForPlatform(stream->parent()->platform()));

    xla::Shape output_device_shape = output.on_device_shape();
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output, &output_device_shape));

    output.set_shapes(output_device_shape, output_device_shape);
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes.push_back(shape);
    }
  } else {
    for (int i = 0; i < ctx->num_outputs(); ++i) {
      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
    }
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0, end = ctx->num_outputs(); i < end; ++i) {
    const TensorShape& shape = output_tensor_shapes[i];
    const DataType& type = compilation_result->outputs[i].type;
    VLOG(2) << "Populating output for retval " << i << " shape "
            << shape.DebugString() << " type " << DataTypeString(type);
    if (type == DT_VARIANT) {
      return errors::Unimplemented(
          "Support for TensorList crossing the XLA/TF boundary "
          "is not implemented");
    }

    if (compilation_result->outputs[i].is_constant) {
      TF_RETURN_IF_ERROR(
          SetOutputForConstant(ctx, stream, compilation_result, i));
    } else if (type == DT_RESOURCE) {
      int input_index =
          compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
      TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
          << "Invalid input for outputs " << i << ": " << input_index;
      ctx->set_output(i, ctx->input(input_index));
    } else {
      if (allocate_xla_tensors_) {
        Tensor* output_tensor;
        TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
        if (output_tensor->TotalBytes() > 0) {
          PopulateXlaTensor(output_tensor, &output, output_num, stream,
                            use_multiple_streams_, definition_event);
        }
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        Tensor output_tensor = GetOrCreateTensorForOutput(
            output_num, ctx, missing_ctx_input_prefix, input_output_alias,
            compilation_result->input_mapping, resource_vars,
            ctx->expected_output_dtype(i), shape, buffer, allocator);
        ctx->set_output(i, output_tensor);
      }
      output.set_buffer(se::OwningDeviceMemory(), {output_num});
      ++output_num;
    }
  }

  // input_index -> index into variable_infos.
  absl::flat_hash_map<int, int> variable_info_lookup;
  for (int i = 0; i < variable_infos.size(); i++) {
    variable_info_lookup.emplace(variable_infos[i].index(), i);
  }

  // Apply variable updates, if any.
  for (int i = 0, end = compilation_result->resource_updates.size(); i < end;
       ++i) {
    const XlaCompiler::ResourceUpdate& write =
        compilation_result->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    CHECK_GE(actual_input_index, 0);
    CHECK_LT(actual_input_index, ctx->num_inputs());
    Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
    CHECK(var);

    VLOG(2) << "Updating variable #" << i
            << " at input index: " << actual_input_index << " with shape "
            << write.shape.DebugString() << "; variable tensor has shape: "
            << var->tensor()->shape().DebugString();

    if (var->is_initialized && var->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    Tensor output_tensor;
    if (allocate_xla_tensors_) {
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
      if (write.shape.num_elements() > 0) {
        PopulateXlaTensor(&output_tensor, &output, output_num, stream,
                          use_multiple_streams_, definition_event);
      }
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output_tensor = GetOrCreateTensorForOutput(
          output_num, ctx, missing_ctx_input_prefix, input_output_alias,
          compilation_result->input_mapping, resource_vars, write.type,
          write.shape, buffer, allocator);
    }
    output.set_buffer(se::OwningDeviceMemory(), {output_num});
    var->is_initialized |= write.modified;
    *var->tensor() = output_tensor;
    ++output_num;
  }
  return Status::OK();
}

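// Translates the kernel's inputs (plus the already-locked resource variables
// in `variable_args`) into XlaCompiler::Argument descriptors. Inputs listed in
// `must_be_constant_idxs` are materialized as compile-time constants, copying
// variable values from device to host when necessary.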
xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
    absl::Span<int const> must_be_constant_idxs,
    absl::Span<const Tensor* const> inputs,
    absl::Span<VariableInfo const> variable_args, Device* device) {
  CHECK(absl::c_is_sorted(must_be_constant_idxs));
  std::vector<XlaCompiler::Argument> out;
  out.resize(inputs.size());

  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
  DeviceContext* device_context = nullptr;
  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
  bool using_default_context = false;
  auto cleanup = xla::MakeCleanup([&] {
    if (device_context != nullptr && !using_default_context) {
      device_context->Unref();
    }
  });
  if (device_context == nullptr) {
    using_default_context = true;
    auto* dev_info = device->tensorflow_gpu_device_info();
    if (dev_info) device_context = dev_info->default_context;
  }

  absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
  for (const VariableInfo& info : variable_args) {
    CHECK(!info.var() || info.lock_held())
        << "Need to hold the lock on resource variables "
           "before calling BuildXlaCompilerArguments";
    variable_info_lookup.emplace(info.index(), &info);
  }

  for (int64 input_num = 0; input_num < inputs.size(); ++input_num) {
    const Tensor* input = inputs[input_num];

    XlaCompiler::Argument& arg = out[input_num];
    if (variable_info_lookup.count(input_num)) {
      // Handles resource variables.
      TF_RET_CHECK(input->dtype() == DT_RESOURCE);
      const VariableInfo& variable = *variable_info_lookup[input_num];
      arg.name = std::string(variable.name());
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.var() && variable.var()->is_initialized) {
        const Tensor* value = variable.var()->tensor();
        arg.type = value->dtype();
        arg.shape = value->shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }

      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
        const Tensor* value = variable.var()->tensor();
        Tensor value_on_host(value->dtype(), value->shape());
        if (!device_context) {
          value_on_host = *value;
        } else {
          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
              value, "", device, &value_on_host));
        }
        arg.kind = XlaCompiler::Argument::kConstantResource;
        arg.constant_value = value_on_host;
      }
    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input->dtype();
      arg.shape = input->shape();
      arg.constant_value = *input;
    } else {
      // Normal inputs.
      TF_RET_CHECK(input->dtype() != DT_RESOURCE);
      if (input->NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *input;
      }
      arg.type = input->dtype();
      arg.shape = input->shape();
    }
  }

  return out;
}

}  // namespace tensorflow