/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>
#include <numeric>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
}  // anonymous namespace

VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}
VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error. It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`. The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
    OpKernelContext* ctx, absl::Span<const int> variable_indices,
    std::vector<VariableInfo>* result) {
  std::vector<const ResourceHandle*> resource_handles;
  absl::c_transform(
      variable_indices, std::back_inserter(resource_handles),
      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });

  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
  TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));

  result->clear();
  result->reserve(variable_indices.size());
  for (int i = 0; i < variable_indices.size(); i++) {
    // *Release* the variable because we're going to unref it later in
    // ~VariableInfo.
    Var* variable = variables[i].release();
    result->emplace_back(variable_indices[i], variable);
  }

  return Status::OK();
}

Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // VariableInfoComparator orders all empty VariableInfo instances as
  // equivalent, so it looks like we may want to stable-sort these to maintain
  // a deterministic order between the empty VariableInfo instances. However,
  // since we're sorting by pointer value the sort is pretty non-deterministic
  // anyway, so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly. Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      return errors::Internal("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}
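
// Design note: sorting by mutex address before acquiring the locks is the
// standard deadlock-avoidance idiom. Every caller that locks an overlapping
// set of variables through LockVariables() takes the mutexes in the same
// global order, so two concurrent XLA launches that share variables cannot
// deadlock against each other. SnapshotResourceVariables() below and the
// variable-update path in PopulateOutputs() are the two callers in this file.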

Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 std::map<int, OptionalTensor>* result) {
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  for (int i = 0; i < variable_indices.size(); i++) {
    if (variable_infos[i].var()) {
      OptionalTensor& tensor = (*result)[variable_indices[i]];
      tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
      tensor.present = true;
      tensor.value = *variable_infos[i].var()->tensor();
    } else {
      (*result)[variable_indices[i]] = OptionalTensor();
    }
  }
  return Status::OK();
}
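
// Illustrative usage (a sketch, not code from this file): a launch kernel
// would typically snapshot its resource-variable inputs before compiling and
// running the cluster, roughly:
//
//   std::map<int, OptionalTensor> variables;
//   OP_REQUIRES_OK(ctx,
//                  SnapshotResourceVariables(ctx, resource_input_indices,
//                                            &variables));
//
// where `resource_input_indices` is a placeholder for whatever list of
// resource-variable input indices the caller maintains. The snapshot is then
// consumed by BuildXlaCompilerArguments() and PopulateInputs() below.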

XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
    : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}

XlaAllocator::~XlaAllocator() {}

xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure) {
  AllocationAttributes attrs;
  attrs.no_retry_on_failure = !retry_on_failure;
  void* data = nullptr;
  if (size != 0) {
    data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
    if (data == nullptr) {
      return errors::ResourceExhausted(
          "Out of memory while trying to allocate ", size, " bytes.");
    }
  }
  return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
                                 device_ordinal, this);
}

Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
  wrapped_->DeallocateRaw(mem.opaque());
  return Status::OK();
}
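
// Note: XlaAllocator adapts a TensorFlow Allocator to the
// xla::DeviceMemoryAllocator interface so the XLA client allocates out of the
// device's TensorFlow allocator. A hypothetical wiring sketch:
//
//   XlaAllocator xla_allocator(client->platform(),
//                              ctx->device()->GetAllocator({}));
//
// Zero-byte allocations deliberately skip the wrapped allocator and return a
// null DeviceMemoryBase of size zero.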

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
    bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

void XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    const std::map<int, OptionalTensor>& variables,
    int missing_ctx_input_prefix) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  // Build ShapedBuffers that point directly to the Tensor buffers.
  arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
  arg_buffers_.resize(kernel->xla_input_shapes.size());
  arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());

  // Pass remaining parameters.
  const Tensor* t;
  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
    int arg_num = kernel->input_mapping[i];
    DCHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = kernel->xla_input_shapes[i];
    if (variables.count(arg_num)) {
      t = &(variables.at(arg_num).value);
      CHECK(t);
    } else {
      t = &(ctx->input(arg_num - missing_ctx_input_prefix));
    }

    if (use_multiple_streams_) {
      CHECK(stream) << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }

    const xla::Shape on_device_shape =
        client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
    if (on_device_shape.IsTuple()) {
      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
    } else {
      CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
          << "On-device shape "
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
          << " not the same as on-host shape "
          << xla::ShapeUtil::HumanStringWithLayout(shape);
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
          client_->platform(), client_->default_device_ordinal());
      arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
      arg_ptrs_[i] = arg_buffers_[i].get();
    }
  }
}
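
// Note: the ShapedBuffers built by PopulateInputs() alias device memory owned
// by the input Tensors (or by their XlaTensors); no ownership is transferred.
// The inputs must therefore remain alive until the XLA execution that consumes
// arg_ptrs_ no longer needs those buffers.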

Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    ScopedShapedBuffer output, int missing_ctx_input_prefix) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
    VLOG(2) << "Result tuple shape (on device): "
            << output.on_device_shape().DebugString();
  }
  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());

  // If the on-host shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to
  // treat output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.platform(), output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    if (kernel->outputs[i].is_constant) {
      // Output is a constant.
      const Tensor& const_tensor = kernel->outputs[i].constant_value;
      Tensor* output_tensor;
      const size_t total_bytes = const_tensor.TotalBytes();
      if (stream && total_bytes > 0) {
        // Copy host -> device. (Empty tensors don't have backing buffers.)
        // Manually allocate memory using an XlaTensorBuffer so we can allocate
        // as much memory as the device requires (as given by
        // GetByteSizeRequirement). This avoids XlaTransferManager having to
        // reallocate the device buffer later.
        VLOG(1) << "Constant output tensor on device";

        TF_RETURN_IF_ERROR(
            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));

        Device* device = dynamic_cast<Device*>(ctx->device());
        if (device == nullptr) {
          return errors::Internal("DeviceBase was not a Device.");
        }
        ctx->op_device_context()->CopyCPUTensorToDevice(
            &const_tensor, device, output_tensor,
            [&](Status status) { TF_CHECK_OK(status); });

        if (device->device_type() == DEVICE_GPU) {
          // The GPUDeviceContext enqueues the host->device transfer in a
          // separate stream from the main compute stream. We must ensure the
          // compute stream is synchronized with the host->device transfer
          // stream now, otherwise we will create a race condition.
          auto* gpu_device_context =
              static_cast<GPUDeviceContext*>(ctx->op_device_context());
          gpu_device_context->stream()->ThenWaitFor(
              gpu_device_context->host_to_device_stream());
        }
      } else {
        // No copy required.
        ctx->set_output(i, const_tensor);
        output_tensor = ctx->mutable_output(i);
      }
      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
        xla_tensor->set_host_tensor(const_tensor);
      }
    } else {
      const TensorShape& shape = kernel->outputs[i].shape;
      const DataType& type = kernel->outputs[i].type;
      VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
              << DataTypeString(type);
      if (type == DT_RESOURCE) {
        TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
            << "Invalid input for output " << i;
        ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        if (allocate_xla_tensors_) {
          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
          XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
          if (xla_tensor) {
            xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
            if (use_multiple_streams_) {
              xla_tensor->ResetDefinitionEvent(definition_event, stream);
            }
          } else {
            // xla_tensor wasn't valid, which must mean this is a zero-element
            // tensor.
            CHECK_EQ(output_tensor->TotalBytes(), 0);
          }
        } else {
          Tensor output_tensor = XlaTensorBuffer::MakeTensor(
              ctx->expected_output_dtype(i), shape, buffer, allocator);
          output.set_buffer(xla::OwningDeviceMemory(), {output_num});
          ctx->set_output(i, output_tensor);
        }
        ++output_num;
      }
    }

    if (VLOG_IS_ON(3)) {
      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
    }
  }

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(kernel->resource_updates.size());

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    Var* variable = nullptr;
    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
        ctx, HandleFromInput(ctx, actual_input_index), &variable,
        [&write](Var** ptr) {
          *ptr = new Var(write.type);
          return Status::OK();
        }));
    variable_infos.emplace_back(actual_input_index, variable);
  }

  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];

    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    if (allocate_xla_tensors_) {
      Tensor output_tensor;
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
      if (write.shape.num_elements() > 0) {
        XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
        CHECK(xla_tensor);
        xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
        if (use_multiple_streams_) {
          xla_tensor->ResetDefinitionEvent(definition_event, stream);
        }
      }
      *variable_infos[i].var()->tensor() = output_tensor;
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output.set_buffer(xla::OwningDeviceMemory(), {output_num});
      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
          write.type, write.shape, buffer, allocator);
      *variable_infos[i].var()->tensor() = output_tensor;
    }
    ++output_num;
  }
  return Status::OK();
}
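
// Illustrative end-to-end sketch (hypothetical caller, heavily simplified):
// a launch op roughly strings these helpers together as
//
//   XlaComputationLaunchContext launch_context(
//       client, &xla_allocator,
//       /*allocate_xla_tensors=*/true,
//       /*use_multiple_streams=*/true);
//   launch_context.PopulateInputs(ctx, compilation_result, variables,
//                                 /*missing_ctx_input_prefix=*/0);
//   // ... run the compiled executable on launch_context's argument buffers,
//   // yielding a ScopedShapedBuffer `execution_output` ...
//   OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
//                           ctx, compilation_result,
//                           std::move(execution_output),
//                           /*missing_ctx_input_prefix=*/0));
//
// Names such as `compilation_result`, `variables`, and `execution_output` are
// placeholders for values the real caller obtains from the XLA compilation
// cache, SnapshotResourceVariables(), and the executable run, respectively.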

Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
    const std::map<int, Tensor>& constant_args,
    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
    std::vector<XlaCompiler::Argument>* args) {
  args->resize(ctx->num_inputs());

  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
    XlaCompiler::Argument& arg = (*args)[input_num];
    if (constant_args.count(input_num) > 0) {
      // Handles compile-time constants.
      const Tensor& input = constant_args.at(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input.dtype();
      arg.shape = input.shape();
      arg.constant_value = input;
    } else if (variable_args.count(input_num) == 0) {
      // Handles the non-constant arguments.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      if (input.NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = input;
      }
      arg.type = input.dtype();
      arg.shape = input.shape();
    } else {
      // Handles resource variables.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
      const OptionalTensor& variable = variable_args.at(input_num);
      arg.name = variable.name;
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.present) {
        const Tensor& value = variable.value;
        arg.type = value.dtype();
        arg.shape = value.shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }
    }
  }

  return Status::OK();
}
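
// Illustrative sketch (hypothetical usage): the argument vector built here
// describes every kernel input (compile-time constant, parameter, or resource
// variable) and is what the caller hands to the XLA compiler, roughly:
//
//   std::vector<XlaCompiler::Argument> args;
//   TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
//       constant_args, variable_args, ctx, &args));
//
// where `constant_args` comes from the kernel's compile-time-constant inputs
// and `variable_args` from SnapshotResourceVariables().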

}  // namespace tensorflow