/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_launch_util.h"

#include <memory>
#include <numeric>  // for std::iota used in LockVariables

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {
namespace {
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
}  // anonymous namespace

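// VariableInfo wraps a resource variable input for the duration of an XLA
// launch: it holds a reference to the underlying Var (dropped in the
// destructor) and remembers whether this object currently owns the variable's
// mutex, so a lock acquired via LockVariables() below is released exactly
// once, even on an error path.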
VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {}

VariableInfo::VariableInfo(VariableInfo&& other)
    : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) {
  other.index_ = -1;
  other.var_ = nullptr;
}

VariableInfo& VariableInfo::operator=(VariableInfo&& other) {
  index_ = other.index_;
  var_ = other.var_;
  lock_held_ = other.lock_held_;

  other.index_ = -1;
  other.var_ = nullptr;

  return *this;
}

VariableInfo::~VariableInfo() {
  // Release the variable's lock if we hold it. Ensures that the lock is
  // released even on error.  It does not matter in what order we release the
  // locks.
  if (var()) {
    if (lock_held()) {
      var()->mu()->unlock();
    }

    // Unref the variable so it can be released by ResourceManager.
    var()->Unref();
  }
}

// Returns a vector of VariableInfo instances for the resource variable inputs
// to the kernel with context `ctx`.  The input indices for the resource
// variable inputs are in `variable_indices`.
static Status GetVariableInfosFromCtxInputs(
    OpKernelContext* ctx, absl::Span<const int> variable_indices,
    std::vector<VariableInfo>* result) {
  std::vector<const ResourceHandle*> resource_handles;
  absl::c_transform(
      variable_indices, std::back_inserter(resource_handles),
      [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });

  std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables;
  TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables));

  result->clear();
  result->reserve(variable_indices.size());
  for (int i = 0; i < variable_indices.size(); i++) {
    // *Release* the variable because we're going to unref it later in
    // ~VariableInfo.
    Var* variable = variables[i].release();
    result->emplace_back(variable_indices[i], variable);
  }

  return Status::OK();
}

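// Acquires the mutex of every non-empty VariableInfo in `variables`.
//
// To avoid deadlock when several XLA clusters touch overlapping sets of
// variables concurrently, locks are always taken in increasing order of mutex
// address, regardless of the order in which the variables appear as kernel
// inputs.  Passing the same variable handle twice is reported as an error,
// since the kernel would not handle updates to that variable correctly.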
Status LockVariables(absl::Span<VariableInfo> variables) {
  std::vector<int> lock_order(variables.size());
  std::iota(lock_order.begin(), lock_order.end(), 0);

  // The comparator below orders all empty VariableInfo instances as
  // equivalent, so it may look like we want a stable sort to keep a
  // deterministic order between the empty VariableInfo instances.  However,
  // since we're sorting by pointer value the sort is non-deterministic anyway,
  // so we don't bother using std::stable_sort for now.
  absl::c_sort(lock_order, [&](int a, int b) {
    if (variables[a].var() && variables[b].var()) {
      return variables[a].var()->mu() < variables[b].var()->mu();
    }

    // Move all the empty VariableInfo instances to the end.
    return variables[a].var() != nullptr;
  });

  mutex* prev = nullptr;
  for (int i : lock_order) {
    Var* variable = variables[i].var();
    if (variable == nullptr) {
      // All empty VariableInfo instances are at the end of the order
      // so we're done.
      break;
    }
    mutex* mu = variable->mu();
    if (prev == mu) {
      // It is an error to pass the same variable handle twice to the same XLA
      // cluster because we would not handle variable updates correctly.  Any
      // locks we have already acquired will be released when the VariableInfo
      // objects are destroyed.
      return errors::Internal("Duplicate variable passed to XLA cluster");
    }
    VLOG(4) << "Acquiring lock for variable "
            << reinterpret_cast<void*>(variable);
    mu->lock();
    variables[i].set_lock_held();
    prev = mu;
  }
  VLOG(4) << "Finished acquiring variable locks.";
  return Status::OK();
}

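// Takes a consistent snapshot of the resource variables in `variable_indices`:
// the variables are looked up, locked in a deadlock-safe order via
// LockVariables(), and the Tensor currently held by each initialized variable
// is shallow-copied into `result`, keyed by input index.  Uninitialized
// variables are recorded as absent OptionalTensors.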
Status SnapshotResourceVariables(OpKernelContext* ctx,
                                 absl::Span<const int> variable_indices,
                                 std::map<int, OptionalTensor>* result) {
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(
      GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  for (int i = 0; i < variable_indices.size(); i++) {
    if (variable_infos[i].var()) {
      OptionalTensor& tensor = (*result)[variable_indices[i]];
      tensor.name = HandleFromInput(ctx, variable_indices[i]).name();
      tensor.present = true;
      tensor.value = *variable_infos[i].var()->tensor();
    } else {
      (*result)[variable_indices[i]] = OptionalTensor();
    }
  }
  return Status::OK();
}

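// XlaAllocator adapts a TensorFlow Allocator to XLA's DeviceMemoryAllocator
// interface, so XLA allocations for this device go through the same allocator
// (and memory pool) TensorFlow itself uses.  Zero-byte allocations are
// represented by a null DeviceMemoryBase rather than a real allocation.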
XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
    : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}

XlaAllocator::~XlaAllocator() {}

xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
    int device_ordinal, uint64 size, bool retry_on_failure) {
  AllocationAttributes attrs;
  attrs.no_retry_on_failure = !retry_on_failure;
  void* data = nullptr;
  if (size != 0) {
    data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
    if (data == nullptr) {
      return errors::ResourceExhausted(
          "Out of memory while trying to allocate ", size, " bytes.");
    }
  }
  return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
                                 device_ordinal, this);
}

Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) {
  wrapped_->DeallocateRaw(mem.opaque());
  return Status::OK();
}

XlaComputationLaunchContext::XlaComputationLaunchContext(
    xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
    bool allocate_xla_tensors, bool use_multiple_streams)
    : client_(client),
      xla_allocator_(xla_allocator),
      allocate_xla_tensors_(allocate_xla_tensors),
      use_multiple_streams_(use_multiple_streams) {
  if (use_multiple_streams_) {
    CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
                                    "be allocating XLA tensors!";
  }
}

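// A minimal sketch of how a launch kernel typically drives this class, for
// orientation only.  The `executable`, `compilation_result`, `variables`,
// `stream`, and `run_options` names are assumptions standing in for the
// caller's state, and `arguments()` refers to the accessor declared in
// xla_launch_util.h:
//
//   XlaAllocator xla_allocator(client->platform(),
//                              ctx->device()->GetAllocator({}));
//   XlaComputationLaunchContext launch_context(
//       client, &xla_allocator,
//       /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/false);
//   launch_context.PopulateInputs(ctx, compilation_result, variables,
//                                 /*missing_ctx_input_prefix=*/0);
//   xla::ExecutableRunOptions run_options;
//   run_options.set_stream(stream);
//   run_options.set_allocator(&xla_allocator);
//   auto run_result =
//       executable->Run(launch_context.arguments(), run_options);
//   TF_RETURN_IF_ERROR(run_result.status());
//   TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
//       ctx, compilation_result, run_result.ConsumeValueOrDie(),
//       /*missing_ctx_input_prefix=*/0));
//
// PopulateInputs builds, for each compiled argument, a ShapedBuffer that
// aliases the device memory already backing the corresponding input Tensor,
// so no extra copies are made.  Arguments that are resource variables are
// read from the `variables` snapshot; all other arguments come from the
// kernel's inputs, shifted by `missing_ctx_input_prefix` to account for
// compile-time inputs that are not present in `ctx`.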
void XlaComputationLaunchContext::PopulateInputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    const std::map<int, OptionalTensor>& variables,
    int missing_ctx_input_prefix) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
  // Build ShapedBuffers that point directly to the Tensor buffers.
  arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
  arg_buffers_.resize(kernel->xla_input_shapes.size());
  arg_ptrs_ = std::vector<ShapedBuffer*>(arg_buffers_.size());

  // Pass remaining parameters.
  const Tensor* t;
  for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
    int arg_num = kernel->input_mapping[i];
    DCHECK_GE(arg_num, missing_ctx_input_prefix);
    const xla::Shape& shape = kernel->xla_input_shapes[i];
    if (variables.count(arg_num)) {
      t = &(variables.at(arg_num).value);
      CHECK(t);
    } else {
      t = &(ctx->input(arg_num - missing_ctx_input_prefix));
    }

    if (use_multiple_streams_) {
      CHECK(stream) << "Must have a stream available when using XLA tensors!";
      XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor);
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }

    const xla::Shape on_device_shape =
        client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
    if (on_device_shape.IsTuple()) {
      const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
      CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
      arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
    } else {
      CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
          << "On-device shape "
          << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
          << " not the same as on-host shape "
          << xla::ShapeUtil::HumanStringWithLayout(shape);
      se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
      arg_buffers_[i] = absl::make_unique<ShapedBuffer>(
          /*on_host_shape=*/shape, /*on_device_shape=*/shape,
          client_->platform(), client_->default_device_ordinal());
      arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
      arg_ptrs_[i] = arg_buffers_[i].get();
    }
  }
}

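// Transfers the results in `output` back into the kernel's outputs and applies
// any resource-variable writes recorded by the compiler.  Compile-time
// constant outputs are materialized by copying the host constant to the
// device; DT_RESOURCE outputs simply forward the corresponding input handle;
// all remaining outputs (and the variable updates) consume successive elements
// of the XLA result tuple, either by wrapping them in XlaTensors or by
// aliasing the device buffers directly, depending on `allocate_xla_tensors_`.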
Status XlaComputationLaunchContext::PopulateOutputs(
    OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
    ScopedShapedBuffer output, int missing_ctx_input_prefix) {
  se::Stream* stream =
      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

  // Computation output should always be a tuple.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
    VLOG(2) << "Result tuple shape (on device): "
            << output.on_device_shape().DebugString();
  }
  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());

  // If the on-host-shape isn't a tuple, create a new single-element tuple
  // buffer with a nullptr root index table. This allows the code below to treat
  // output as a tuple unconditionally.
  if (!output.on_host_shape().IsTuple()) {
    ShapedBuffer nontuple_buffer = output.release();
    ShapedBuffer buffer(
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
        xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
        output.platform(), output.device_ordinal());
    buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
                                     /*source_base_index=*/{},
                                     /*target_base_index=*/{0});
    output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
  }

  std::shared_ptr<se::Event> definition_event;
  if (use_multiple_streams_) {
    definition_event = std::make_shared<se::Event>(stream->parent());
    if (!definition_event->Init()) {
      return errors::Internal("Failed to initialize tensor definition event.");
    }
    stream->ThenRecordEvent(definition_event.get());
  }

  // Copy XLA results to the OpOutputList.
  int output_num = 0;
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    if (kernel->outputs[i].is_constant) {
      // Output is a constant.
      const Tensor& const_tensor = kernel->outputs[i].constant_value;
      Tensor* output_tensor;
      const size_t total_bytes = const_tensor.TotalBytes();
      if (stream && total_bytes > 0) {
        // Copy host -> device. (Empty tensors don't have backing buffers.)
        // Manually allocate memory using an XlaTensorBuffer so we can allocate
        // as much memory as the device requires (as given by
        // GetByteSizeRequirement). This avoids XlaTransferManager having to
        // reallocate the device buffer later.
        VLOG(1) << "Constant output tensor on device";

        TF_RETURN_IF_ERROR(
            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));

        Device* device = dynamic_cast<Device*>(ctx->device());
        if (device == nullptr) {
          return errors::Internal("DeviceBase was not a Device.");
        }
        ctx->op_device_context()->CopyCPUTensorToDevice(
            &const_tensor, device, output_tensor,
            [&](Status status) { TF_CHECK_OK(status); });

        if (device->device_type() == DEVICE_GPU) {
          // The GPUDeviceContext enqueues the host->device transfer in a
          // separate stream from the main compute stream. We must ensure the
          // compute stream is synchronized with the host->device transfer
          // stream now otherwise we will create a race condition.
          auto* gpu_device_context =
              static_cast<GPUDeviceContext*>(ctx->op_device_context());
          gpu_device_context->stream()->ThenWaitFor(
              gpu_device_context->host_to_device_stream());
        }
      } else {
        // No copy required.
        ctx->set_output(i, const_tensor);
        output_tensor = ctx->mutable_output(i);
      }
      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
        xla_tensor->set_host_tensor(const_tensor);
      }
    } else {
      const TensorShape& shape = kernel->outputs[i].shape;
      const DataType& type = kernel->outputs[i].type;
      VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
              << DataTypeString(type);
      if (type == DT_RESOURCE) {
        TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
            << "Invalid input for outputs " << i;
        ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
      } else {
        se::DeviceMemoryBase buffer = output.buffer({output_num});
        if (allocate_xla_tensors_) {
          Tensor* output_tensor;
          TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
          XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
          if (xla_tensor) {
            xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
            if (use_multiple_streams_) {
              xla_tensor->ResetDefinitionEvent(definition_event, stream);
            }
          } else {
            // xla_tensor wasn't valid, which must mean this is a zero-element
            // tensor.
            CHECK_EQ(output_tensor->TotalBytes(), 0);
          }
        } else {
          Tensor output_tensor = XlaTensorBuffer::MakeTensor(
              ctx->expected_output_dtype(i), shape, buffer, allocator);
          output.set_buffer(xla::OwningDeviceMemory(), {output_num});
          ctx->set_output(i, output_tensor);
        }
        ++output_num;
      }
    }

    if (VLOG_IS_ON(3)) {
      VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
    }
  }

  // Apply variable updates, if any.
  VLOG(2) << "Applying variable updates";
  std::vector<VariableInfo> variable_infos;
  variable_infos.reserve(kernel->resource_updates.size());

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
    int actual_input_index = write.input_index - missing_ctx_input_prefix;
    if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
      return errors::Internal("Invalid input index for variable write.");
    }

    // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
    // not a Tensor.
    Var* variable = nullptr;
    TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
        ctx, HandleFromInput(ctx, actual_input_index), &variable,
        [&write](Var** ptr) {
          *ptr = new Var(write.type);
          return Status::OK();
        }));
    variable_infos.emplace_back(actual_input_index, variable);
  }

  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
    Allocator* allocator = ctx->device()->GetAllocator({});
    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];

    if (variable_infos[i].var()->tensor()->dtype() != write.type) {
      return errors::Internal("Mismatched type in variable write");
    }

    if (allocate_xla_tensors_) {
      Tensor output_tensor;
      TF_RETURN_IF_ERROR(
          ctx->allocate_temp(write.type, write.shape, &output_tensor));
      if (write.shape.num_elements() > 0) {
        XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
        CHECK(xla_tensor);
        xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num}));
        if (use_multiple_streams_) {
          xla_tensor->ResetDefinitionEvent(definition_event, stream);
        }
      }
      *variable_infos[i].var()->tensor() = output_tensor;
    } else {
      se::DeviceMemoryBase buffer = output.buffer({output_num});
      output.set_buffer(xla::OwningDeviceMemory(), {output_num});
      Tensor output_tensor = XlaTensorBuffer::MakeTensor(
          write.type, write.shape, buffer, allocator);
      *variable_infos[i].var()->tensor() = output_tensor;
    }
    ++output_num;
  }
  return Status::OK();
}

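// Builds the XlaCompiler::Argument descriptors the compiler needs for each of
// the kernel's inputs: compile-time constants become kConstant arguments,
// ordinary tensors become kParameter (or kConstant when they are empty, since
// an empty tensor carries no runtime data), and DT_RESOURCE inputs become
// kResource variable arguments whose type and shape come from the snapshot in
// `variable_args`.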
Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
    const std::map<int, Tensor>& constant_args,
    const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
    std::vector<XlaCompiler::Argument>* args) {
  args->resize(ctx->num_inputs());

  for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) {
    XlaCompiler::Argument& arg = (*args)[input_num];
    if (constant_args.count(input_num) > 0) {
      // Handles compile-time constants.
      const Tensor& input = constant_args.at(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.type = input.dtype();
      arg.shape = input.shape();
      arg.constant_value = input;
    } else if (variable_args.count(input_num) == 0) {
      // Handles the non-constant arguments.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() != DT_RESOURCE);
      if (input.NumElements() > 0) {
        arg.kind = XlaCompiler::Argument::kParameter;
      } else {
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = input;
      }
      arg.type = input.dtype();
      arg.shape = input.shape();
    } else {
      // Handles resource variables.
      const Tensor& input = ctx->input(input_num);
      TF_RET_CHECK(input.dtype() == DT_RESOURCE);
      const OptionalTensor& variable = variable_args.at(input_num);
      arg.name = variable.name;
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = XlaResource::kVariable;
      if (variable.present) {
        const Tensor& value = variable.value;
        arg.type = value.dtype();
        arg.shape = value.shape();
        arg.initialized = true;
      } else {
        // The values of uninitialized variables are not passed as inputs, since
        // they are meaningless. However, it is legal to assign to a resource
        // variable for the first time inside the XLA computation, so we do
        // permit uninitialized variables.
        arg.initialized = false;
        arg.type = DT_INVALID;
        arg.shape = TensorShape();
      }
    }
  }

  return Status::OK();
}

}  // namespace tensorflow