/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 3-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
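  // Of the three key elements, element 0 is the compilation cache key used for
  // the proto lookup below and element 1 is the rendezvous key base returned
  // to the caller; element 2 is not consumed by this lookup.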
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return Status::OK();
}

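// Describes how variable inputs of the computation map to its updated outputs.
// As an illustrative example: for a computation with four outputs whose last
// two outputs are updates to the variables passed as inputs 1 and 3 (in that
// compiled order), the fields below would contain
//   input_to_output = {1: 2, 3: 3}
//   output_to_input = {2: (1, true), 3: (3, true)}
//   input_in_compiled_update_order = {1, 3}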
struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to (the input index, whether the update is generated
  // from compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // Part of the input indices that are from the compilation, in the compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64 computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return Status::OK();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation.
  // The update output indices in the XLA computation start after the
  // non-variable outputs.
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64 compiled_variable_output_index =
      computation_output_count - num_updated_variables;
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64 i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](xla::MaybeOwningDeviceMemory* buffer) {
          CHECK(buffer);
          return buffer->AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return Status::OK();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable
  // buffers to the computation. If we copied the variable's Tensor instead,
  // its reference count would be greater than one due to the reference the Var
  // object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64 root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
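  // For example, the buffer at ShapeIndex {0} of argument 2 is stored at index
  // {2, 0} of the input tuple tree built here.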
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64 j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return Status::OK();
  };

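  // Inputs without an entry in the variable update map may donate their
  // buffers whenever nothing else holds a reference to them; inputs that
  // correspond to variables donate only when the computation produces an
  // updated value for them.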
  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // input i is a variable
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffer that
// owns the root buffer that should be passed to the XLA computation, as well
// as any output buffers that do not have corresponding output tensors. The
// latter may happen for zero-element tensors of type int64 or complex64 which
// still require a tuple buffer but do not have a corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64 sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->backend()->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64 i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator =
      node_context->backend()->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

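  // If the program produced dynamically shaped outputs, read the concrete
  // output shapes back from the device and override the compile-time shapes
  // recorded above.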
  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_device_shape));
    for (int64 i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor if the
    // tensor is backed by an XlaTensor. Tensors of size 0 have no backing
    // XlaTensor, so we let 'output_buffers' retain ownership of any buffers in
    // that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64 j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables including DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                        updated          not_updated
  //                  |------------------|------------------|
  //   DT_RESOURCE    | allocate persist |    do nothing    |
  //                  |------------------|------------------|
  //                  |     allocate     | forward Op input |
  // not DT_RESOURCE  |      output      |   to Op output   | Op output
  //                  |------------------|------------------|
  //                       program output

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs may come in no particular order, the variable values are always
  // added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    PersistentTensor unused;
    Tensor* output_tensor;
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      // TODO(b/35625933): the correct thing to do would be to transfer
      // ownership of the PersistentTensor into the Var object. However, Var
      // contains a Tensor so we can't.
      TF_RETURN_IF_ERROR(context->allocate_persistent(
          var.var()->tensor()->dtype(), output_tensor_shapes[i], &unused,
          &output_tensor));
      *var.var()->tensor() = *output_tensor;
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
    }
    transfer_buffers(i, output_tensor);
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal, context] {
        return profiler::TraceMeEncode(
            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                             {"id", context->step_id()},
                             {"iter_num", context->frame_iter().iter_id}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));

  // Shapes of the inputs and outputs, in xla::Shape form.
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  xla::Backend* const backend = node_context->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();
  TF_RET_CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(executable.input_shapes_size() == 1);

  xla::Shape host_shape(executable.input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(executable.variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             executable.output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map, backend,
                             device_ordinal, stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool,
  // anti-dependencies will occur. XlaBackend has a different pool of streams
  // from the stream->GetOrCreateSubStream() pool that TPUExecute() uses, so
  // these will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
                                                     allocator, device_ordinal);
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (executable.has_session_module()) {
    hlo_snapshot =
        std::make_shared<xla::HloSnapshot>(executable.session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        xla::DeviceAssignment::Deserialize(
                            executable.device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
                                         shaped_buffer.on_host_shape()));

  // The buffers to be freed are in `output` and will be freed automatically
  // when it goes out of scope. In async mode, this means the buffers are freed
  // before anyone calls "BlockHostUntilDone", which implies that some of the
  // (input) buffers are freed while the program is still running; that may
  // look alarming. However, it turns out not to be a problem: although we free
  // memory and reassign it to other users while a program is running, any
  // subsequent write that could possibly clobber that memory depends on the
  // program finishing first.
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index), std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));
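  // Record the definition event once the execution has been enqueued on the
  // stream; the output tensors allocated below register it as their definition
  // event so later consumers can wait for the execution to complete.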
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(
          context, output.ConsumeResult(), executable.output_tensor_shapes(),
          variable_update_map, node_context.get(), stream, device_ordinal,
          input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal = std::make_shared<xla::Literal>(
        output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return Status::OK();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);

}  // namespace tensorflow