/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 3-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
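  // Element 0 of `key` is the compilation cache key and element 1 is the
  // rendezvous key base; element 2 is not used by this lookup.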
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return Status::OK();
}

struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to (the input index, whether the update is generated from
  // compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // Part of the input indices that are from the compilation, in the compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
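// For example, a variable at computation input 3 whose updated value is
// computation output 5 yields input_to_output[3] == 5 and
// output_to_input[5] == {3, from_compilation}.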
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64_t computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return Status::OK();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation. The
  // update output indices in the XLA computation start after the non-variable
  // outputs.
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64_t compiled_variable_output_index =
      computation_output_count - num_updated_variables;
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64_t i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](xla::MaybeOwningDeviceMemory* buffer) {
          CHECK(buffer);
          return buffer->AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return Status::OK();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable buffers
  // to the computation. If we copied the variable's Tensor instead, its
  // reference count would be greater than one due to the reference the Var
  // object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64_t root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64_t j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
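      // Donating is only safe when this kernel holds the sole reference to
      // the tensor's buffers; otherwise another holder could still read them.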
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return Status::OK();
  };

  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // input i is a variable
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffers that
// owns the root buffer that should be passed to the XLA computation, as well as
// any output buffers that do not have corresponding output tensors. The latter
// may happen for zero-element tensors of type int64 or complex64 which still
// require a tuple buffer but do not have a corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64_t sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->backend()->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64_t i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator =
      node_context->backend()->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

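  // If the program produced dynamically shaped outputs, the true dimensions
  // are only known on the device; read them back and recompute the
  // TensorShapes before allocating output tensors.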
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_device_shape));
    for (int64_t i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor, if
    // the tensor is backed by an XlaTensor. Tensors of size 0 have no
    // backing XlaTensor, so we let 'output_buffers' retain ownership of any
    // buffers in that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64_t j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables including DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                       updated          not_updated
  //                 |------------------|------------------|
  // DT_RESOURCE     | allocate persist |    do nothing    |
  //                 |------------------|------------------|
  //                 |     allocate     | forward Op input |
  // not DT_RESOURCE |      output      |   to Op output   | Op output
  //                 |------------------|------------------|
  //                    program output

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs may come in any order, the updated variable values are always
  // added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      TF_RETURN_IF_ERROR(context->allocate_temp(var.var()->tensor()->dtype(),
                                                output_tensor_shapes[i],
                                                var.var()->tensor()));
      transfer_buffers(i, var.var()->tensor());
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
      transfer_buffers(i, output_tensor);
    }
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal, context] {
        return profiler::TraceMeEncode(
            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                             {"id", context->step_id()},
                             {"iter_num", context->frame_iter().iter_id}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));

  // Shapes of the inputs and outputs, in xla::Shape form.
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  xla::Backend* const backend = node_context->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();
  TF_RET_CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

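  // The op's inputs are packed into a single tuple parameter (see
  // BuildComputationInputs), so the executable reports exactly one input
  // shape.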
  TF_RET_CHECK(executable.input_shapes_size() == 1);

  xla::Shape host_shape(executable.input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(executable.variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             executable.output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map, backend,
                             device_ordinal, stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool
  // antidependencies will occur. XlaBackend has a different pool of streams
  // to the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
                                                     allocator, device_ordinal);
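  // If the device can already safely read the input buffers, write the tuple
  // table on the borrowed transfer stream and make the compute stream wait
  // for it; otherwise enqueue the write on the compute stream itself.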
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (executable.has_session_module()) {
    hlo_snapshot =
        std::make_shared<xla::HloSnapshot>(executable.session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment, xla::DeviceAssignment::Deserialize(
                                               executable.device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
                                         shaped_buffer.on_host_shape()));

  // The buffers to be freed are in `output` and will be automatically freed
  // when it goes out of scope. In async mode, this means the buffers will be
  // freed before anyone calls "BlockHostUntilDone", which indicates that some
  // of the (input) buffers will be freed while the program is running and
  // looks scary. However, this turns out not to be a problem since, although
  // we free memory and reassign it to other users while a program is running,
  // all subsequent writes to the program that could possibly clobber the
  // memory will depend on the program to finish.
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index), std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));
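  // The outputs become valid once this event fires; the output tensors
  // allocated below carry it as their definition event so that consumers on
  // other streams wait for the program before reading them.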
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(
          context, output.ConsumeResult(), executable.output_tensor_shapes(),
          variable_update_map, node_context.get(), stream, device_ordinal,
          input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal =
        std::make_shared<xla::Literal>(output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return Status::OK();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);

}  // namespace tensorflow