/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 3-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
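  // Element 0 of `key` is the compilation cache key and element 1 is the
  // rendezvous key base; element 2 is not consumed here.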
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return Status::OK();
}

struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to a pair of (input index, whether the update comes from
  // the compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // The subset of input indices that come from the compilation, in compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64 computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return Status::OK();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation. The
  // update output indices in the XLA computation start after the non-variable
  // outputs.
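  // For example, with 5 computation outputs and 2 updated variables, the
  // updated values occupy output indices 3 and 4, in compiled order.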
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64 compiled_variable_output_index =
      computation_output_count - num_updated_variables;
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64 i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

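  // Returns a non-owning ShapedBuffer view of `buffers`, pairing the given
  // host shape with the device shape already stored in `buffers`.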
  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](xla::MaybeOwningDeviceMemory* buffer) {
          CHECK(buffer);
          return buffer->AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return Status::OK();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable buffers
  // to the computation. If we copied the variable's Tensor instead, its
  // reference count would be greater than one due to the reference the Var
  // object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64 root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64 j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
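      // Donate the tensor's buffers only if this kernel holds the sole
      // reference to them and the caller permits reuse; otherwise another
      // consumer of the tensor could observe clobbered contents.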
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return Status::OK();
  };

  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // input i is a variable
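    // Donate the variable's buffers only if the computation produces an
    // updated value for it; a non-updated variable's buffers must remain
    // valid after execution.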
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffer that
// owns the root buffer that should be passed to the XLA computation, as well as
// any output buffers that do not have corresponding output tensors. The latter
// may happen for zero-element tensors of type int64 or complex64 which still
// require a tuple buffer but do not have a corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64 sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->backend()->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64 i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator =
      node_context->backend()->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

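  // If the executable produces dynamically shaped outputs, the concrete shapes
  // are only known after execution; read them back from the device and
  // override the statically declared tensor shapes.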
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_device_shape));
    for (int64 i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor, if
    // the tensor is backed by an XlaTensor. Tensors of size 0 have no
    // backing XlaTensor, so we let 'output_buffers' retain ownership of any
    // buffers in that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64 j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables including DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                       updated          not_updated
  //                 |------------------|------------------|
  // DT_RESOURCE     | allocate persist |    do nothing    |
  //                 |------------------|------------------|
  //                 |     allocate     | forward Op input |
  // not DT_RESOURCE |      output      |   to Op output   | Op output
  //                 |------------------|------------------|
  //                    program output

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs may come in any order, the updated variable values are always
  // added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
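  // Walk the XLA outputs in order. Outputs that are not variable updates
  // become fresh op outputs; variable updates either overwrite the variable's
  // tensor (DT_RESOURCE inputs) or become op outputs (value-passed variables).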
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    PersistentTensor unused;
    Tensor* output_tensor;
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      // TODO(b/35625933): the correct thing to do would be to transfer
      // ownership of the PersistentTensor into the Var object. However, Var
      // contains a Tensor so we can't.
      TF_RETURN_IF_ERROR(context->allocate_persistent(
          var.var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
          &output_tensor));
      *var.var()->tensor() = *output_tensor;
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
    }
    transfer_buffers(i, output_tensor);
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal, context] {
        return profiler::TraceMeEncode(
            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                             {"id", context->step_id()},
                             {"iter_num", context->frame_iter().iter_id}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));

  // Shapes of the inputs and outputs, in xla::Shape form.
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  xla::Backend* const backend = node_context->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();
  TF_RET_CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(executable.input_shapes_size() == 1);

  xla::Shape host_shape(executable.input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(executable.variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             executable.output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map, backend,
                             device_ordinal, stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool,
  // anti-dependencies will occur. XlaBackend has a different pool of streams
  // from the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
                                                     allocator, device_ordinal);
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (executable.has_session_module()) {
    hlo_snapshot =
        std::make_shared<xla::HloSnapshot>(executable.session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

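  // The definition event is recorded on the compute stream after TPUExecute
  // has been enqueued; output XlaTensors wait on it before their buffers are
  // consumed by downstream ops.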
  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
                                         shaped_buffer.on_host_shape()));

  // The buffers to be freed are held by `output` and are released
  // automatically when it goes out of scope. In async mode this means the
  // buffers may be freed before anyone calls "BlockHostUntilDone", i.e. some
  // of the (input) buffers are freed while the program is still running, which
  // looks alarming. However, this is not a problem: although we free memory
  // and reassign it to other users while a program is running, any subsequent
  // write that could clobber that memory depends on the program finishing.
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index), std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(
          context, output.ConsumeResult(), executable.output_tensor_shapes(),
          variable_update_map, node_context.get(), stream, device_ordinal,
          input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal =
        std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return Status::OK();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

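// The `key` input is looked up in the compilation cache on the host, so it is
// pinned to host memory in both registrations below.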
REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);

}  // namespace tensorflow