/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_execute_op.h"

#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {
namespace {
using ::tensorflow::tpu::CompilationCacheEntryRef;
using ::tensorflow::tpu::TpuCompilationCacheLookup;
using ::tensorflow::tpu::TpuNodeContext;

// Looks up the input `key` in the compilation cache, populating
// `*rendezvous_key_base` and `*entry`.
Status GetComputationCacheEntry(
    OpKernelContext* context, string* rendezvous_key_base,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  const Tensor* key;
  TF_RETURN_IF_ERROR(context->input("key", &key));
  profiler::TraceMe trace_me("TpuExecuteOp::LookupProto", /*level=*/2);
  if (!TensorShapeUtils::IsVector(key->shape()) ||
      key->shape().dim_size(0) != 3) {
    return errors::InvalidArgument(
        "Key argument to TPUExecute must be a 3-element vector");
  }

  ResourceMgr* rmgr = GetTPUConfigResourceMgr();
  TpuCompilationCacheLookup* proto_lookup;
  TF_RETURN_IF_ERROR(rmgr->Lookup(rmgr->default_container(),
                                  tpu::kCompiledProtoCacheResourceName,
                                  &proto_lookup));
  core::ScopedUnref lookup_unref(proto_lookup);
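  // key(0) is the compilation cache key used for the lookup; key(1) is the
  // rendezvous key base returned to the caller via `rendezvous_key_base`.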
  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key->vec<tstring>()(0), entry));
  *rendezvous_key_base = key->vec<tstring>()(1);
  return Status::OK();
}

struct VariableUpdateMap {
  // Maps input index to the updated output index. If the variable doesn't have
  // an updated output, the corresponding output is set to -1.
  absl::flat_hash_map<int, int> input_to_output;
  // Maps output index to (the input index, whether the update is generated
  // from compilation).
  absl::flat_hash_map<int, std::pair<int, bool>> output_to_input;
  // Part of the input indices that are from the compilation, in the compiled
  // order.
  std::vector<int> input_in_compiled_update_order;
};

// Creates a VariableUpdateMap from both the compilation and the fused variable
// reads/updates.
xla::StatusOr<VariableUpdateMap> BuildVariableUpdateMap(
    absl::Span<const TPUExecutableInfoProto::UpdateIndexPair* const>
        compiled_variable_updates,
    absl::Span<int const> fused_device_var_reads_in_computation_inputs,
    const std::vector<int>& fused_device_var_updates_in_computation_outputs,
    int64_t computation_output_count) {
  VariableUpdateMap map;
  auto add_pair = [&](int input, int output, bool from_compilation) -> Status {
    TF_RET_CHECK(map.input_to_output.emplace(input, output).second)
        << "Duplicate variable input index: " << input;
    if (output >= 0) {
      TF_RET_CHECK(map.output_to_input
                       .emplace(output, std::make_pair(input, from_compilation))
                       .second)
          << "Duplicate variable output index: " << output;
    }
    return Status::OK();
  };

  // First add the updates produced by the compilation. Not all variables are
  // updated, and if not, they do not have an output in the XLA computation.
  // The update output indices in the XLA computation start after the
  // non-variable outputs.
  int num_updated_variables = 0;
  for (int i = 0; i < compiled_variable_updates.size(); ++i) {
    const bool updated = compiled_variable_updates[i]->updated();
    if (updated) ++num_updated_variables;
  }
  TF_RET_CHECK(num_updated_variables <= computation_output_count)
      << num_updated_variables << " <= " << computation_output_count;
  int64_t compiled_variable_output_index =
      computation_output_count - num_updated_variables;
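  // Example: with computation_output_count = 10 and 3 of the compiled variable
  // updates marked as updated, the updated values occupy outputs 7, 8, and 9,
  // assigned below in compiled order.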
  for (auto update : compiled_variable_updates) {
    map.input_in_compiled_update_order.push_back(update->index());
    if (!update->updated()) {
      TF_RETURN_IF_ERROR(add_pair(update->index(), -1, true));
      continue;
    }
    TF_RETURN_IF_ERROR(
        add_pair(update->index(), compiled_variable_output_index, true));
    ++compiled_variable_output_index;
  }

  // Now add the updates from the attributes.
  TF_RET_CHECK(fused_device_var_reads_in_computation_inputs.size() ==
               fused_device_var_updates_in_computation_outputs.size());
  for (int64_t i = 0; i < fused_device_var_reads_in_computation_inputs.size();
       ++i) {
    TF_RETURN_IF_ERROR(
        add_pair(fused_device_var_reads_in_computation_inputs[i],
                 fused_device_var_updates_in_computation_outputs[i], false));
  }
  return map;
}

// Buffers representing the inputs to a computation.
struct InputBuffers {
  explicit InputBuffers(xla::Shape device_shape)
      : buffers(std::move(device_shape)) {}

  InputBuffers(const InputBuffers&) = delete;
  InputBuffers& operator=(const InputBuffers&) = delete;

  ~InputBuffers() = default;

  xla::ShapedBuffer ToShapedBuffer(xla::Shape host_shape,
                                   se::DeviceMemoryAllocator* allocator,
                                   int device_ordinal) {
    CHECK_NE(allocator, nullptr);
    xla::ShapedBuffer shaped_buffer(std::move(host_shape), buffers.shape(),
                                    device_ordinal);
    shaped_buffer.set_buffers(buffers.Map<se::DeviceMemoryBase>(
        [](xla::MaybeOwningDeviceMemory* buffer) {
          CHECK(buffer);
          return buffer->AsDeviceMemoryBase();
        }));
    return shaped_buffer;
  }

  // Describes the buffer tree.
  xla::ShapeTree<xla::MaybeOwningDeviceMemory> buffers;

  // Information about resource variables passed directly to TPUExecute.
  std::vector<VariableInfo> variables;

  // Mapping from input index to offsets in 'variables'. < 0 if the input does
  // not correspond to a variable in 'variables'.
  std::vector<int> variable_index;
};

// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
    OpKernelContext* context, const xla::Shape& input_host_shape,
    const VariableUpdateMap& variable_updates, xla::Backend* backend,
    int device_ordinal, se::Stream* stream) {
  profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  if (arg_list.size() != xla::ShapeUtil::TupleElementCount(input_host_shape)) {
    return errors::InvalidArgument(
        "Number of parameters (", arg_list.size(),
        ") does not match input shape: ",
        xla::ShapeUtil::TupleElementCount(input_host_shape));
  }

  auto validate_shape = [&](int i, const Tensor& tensor) {
    const xla::Shape& expected =
        xla::ShapeUtil::GetTupleElementShape(input_host_shape, i);
    VLOG(4) << "Input " << i << " TF shape " << tensor.shape().DebugString();
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    if (xla_tensor == nullptr) {
      // FromTensor failed; tensor must be empty.
      if (!xla::ShapeUtil::IsZeroElementArray(expected)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(),
            "; got empty tensor. If you are running "
            "with TF2 TPU, make sure you set `drop_remainder=False` when "
            "calling `dataset.batch` on the `tf.data.Dataset` so dynamic batch "
            "size can be handled");
      }
    } else {
      // Compare host shapes, easier than getting the expected device shape.
      const xla::Shape& xla_shape = xla_tensor->shaped_buffer().on_host_shape();
      if (!xla::ShapeUtil::Compatible(expected, xla_shape)) {
        return errors::InvalidArgument(
            "Run-time shape mismatch for TPUExecute argument[", i, "] (",
            context->op_kernel().requested_input(i), "). Expected ",
            expected.DebugString(), "; got ", xla_shape.DebugString());
      }
    }

    return Status::OK();
  };

  // Iterate over the inputs, validating the shapes of non-variable inputs,
  // and creating a VariableInfo object for each variable. We consider variable
  // inputs in a separate phase because we must acquire variable locks in
  // order.
  std::vector<VariableInfo> variables;
  std::vector<int> variable_index(arg_list.size(), -1);
  variables.reserve(arg_list.size());
  for (int i = 0; i < arg_list.size(); ++i) {
    // Arguments are assumed to be variables if they have a resource type.
    // (Non-variable resources are not supported.)
    if (context->input_dtype(i) == DT_RESOURCE) {
      variable_index[i] = variables.size();
      // TODO(phawkins): we may be looking up many variables here; it would be
      // better if we did not repeatedly acquire the resource manager's lock.
      const ResourceHandle& handle = HandleFromInput(context, i);
      Var* variable;
      TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
      variables.push_back(VariableInfo(i, handle.name(), variable));
    } else {
      TF_RETURN_IF_ERROR(validate_shape(i, arg_list[i]));
    }
  }

  // Lock the variables, and validate their shapes. We hold the variable locks
  // for the duration of the TPU execution so we can donate the variable
  // buffers to the computation. If we copied the variable's Tensor instead,
  // its reference count would be greater than one due to the reference the
  // Var object holds, and we would never be able to reuse variable buffers.
  // TODO(phawkins): add a 'reuse_buffers' attribute to TPUExecute that allows
  // the user to elect to copy the buffers and permit concurrent access
  // instead.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  for (int i = 0; i < variables.size(); ++i) {
    TF_RETURN_IF_ERROR(
        validate_shape(variables[i].index(), *variables[i].var()->tensor()));
  }

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  auto input_buffers = absl::make_unique<InputBuffers>(
      transfer_manager->HostShapeToDeviceShape(input_host_shape));

  // Allocates a buffer for the root tuple.
  const int64_t root_size =
      transfer_manager->GetByteSizeRequirement(input_buffers->buffers.shape());
  TF_ASSIGN_OR_RETURN(*input_buffers->buffers.mutable_element({}),
                      allocator->Allocate(device_ordinal, root_size));

  // Helper function that sets the input buffers for 'arg_index' to 'buffers'.
  // If 'donate_buffers' is true, donates ownership of the buffers in 'buffers'
  // to the computation and overwrites the entries in 'buffers' with nulls.
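  // For example, for arg_index 2 and a buffer at tuple index {0, 1} within
  // that argument, the destination index in the input tuple tree is {2, 0, 1}.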
  auto set_input_buffers_helper = [&](int arg_index, bool donate_buffers,
                                      xla::ShapedBuffer* buffers) {
    buffers->buffers().ForEachMutableElement([&](const xla::ShapeIndex& index,
                                                 se::DeviceMemoryBase* buffer) {
      xla::ShapeIndex in_index = {arg_index};
      for (int64_t j : index) {
        in_index.push_back(j);
      }
      auto* in_buffer = input_buffers->buffers.mutable_element(in_index);
      if (donate_buffers) {
        *in_buffer = se::OwningDeviceMemory(*buffer, device_ordinal, allocator);
        *buffer = se::DeviceMemoryBase();
      } else {
        *in_buffer = *buffer;
      }
    });
  };

  // Assigns the buffers of 'tensor' as computation input 'i'. Allocates fresh
  // buffers for zero-element tensors where required.
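  // 'may_reuse' is true when the tensor's buffers may be donated to the
  // computation, i.e. when the input is not a variable or when the variable
  // has an updated output; donation additionally requires that the tensor's
  // reference count is one.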
  auto assign_input = [&](int i, const Tensor& tensor,
                          bool may_reuse) -> xla::Status {
    XlaTensor* xla_tensor = XlaTensor::FromTensor(&tensor);

    // Size 0 tensors have no backing XlaTensor, but may still need to have
    // tuple buffers allocated.
    if (xla_tensor == nullptr) {
      CHECK_EQ(tensor.NumElements(), 0);
      const xla::Shape& host_shape =
          xla::ShapeUtil::GetSubshape(input_host_shape, {i});
      TF_ASSIGN_OR_RETURN(xla::ScopedShapedBuffer buffers,
                          transfer_manager->AllocateScopedShapedBuffer(
                              host_shape, allocator, device_ordinal));
      set_input_buffers_helper(/*arg_index=*/i, /*donate_buffers=*/true,
                               &buffers);
    } else {
      bool can_reuse_buffers = tensor.RefCountIsOne() && may_reuse;
      set_input_buffers_helper(/*arg_index=*/i,
                               /*donate_buffers=*/can_reuse_buffers,
                               &xla_tensor->shaped_buffer());
      xla_tensor->WaitForDefinitionEventOnStream(stream);
    }
    return Status::OK();
  };

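  // Wire up each computation input. Non-variable inputs may always donate
  // their buffers; variable inputs may donate only when the computation
  // produces an updated value for them (see VariableUpdateMap).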
  for (int i = 0; i < arg_list.size(); ++i) {
    auto it = variable_updates.input_to_output.find(i);
    if (it == variable_updates.input_to_output.end()) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], /*may_reuse=*/true));
      continue;
    }
    // input i is a variable
    bool updated = it->second >= 0;
    if (arg_list[i].dtype() != DT_RESOURCE) {
      TF_RETURN_IF_ERROR(assign_input(i, arg_list[i], updated));
    } else {
      int vi = variable_index[i];
      TF_RETURN_IF_ERROR(
          assign_input(i, *variables[vi].var()->tensor(), updated));
    }
  }

  input_buffers->variables = std::move(variables);
  input_buffers->variable_index = std::move(variable_index);

  return std::move(input_buffers);
}

struct OutputBuffers {
  OutputBuffers(xla::ScopedShapedBuffer b, se::DeviceMemoryAllocator* allocator)
      : owned_buffers(b.on_device_shape(), true),
        buffers(b.release()),
        memory_allocator(allocator) {}

  ~OutputBuffers() {
    buffers.buffers().ForEachElement(
        [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
          if (owned_buffers.element(index) && !buffer.is_null()) {
            Status status =
                memory_allocator->Deallocate(buffers.device_ordinal(), buffer);
            if (!status.ok()) {
              LOG(ERROR) << "Error deallocating buffer " << status;
            }
          }
        });
  }

  // Which of the buffers do we own?
  xla::ShapeTree<bool> owned_buffers;

  xla::ShapedBuffer buffers;

  se::DeviceMemoryAllocator* const memory_allocator;
};

// Allocates Tensors for the outputs of the computation. Ownership of most
// output buffers is passed to the output Tensors. Returns an OutputBuffer that
// owns the root buffer that should be passed to the XLA computation, as well
// as any output buffers that do not have corresponding output tensors. The
// latter may happen for zero-element tensors of type int64 or complex64 which
// still require a tuple buffer but do not have a corresponding XlaTensor.
xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
    OpKernelContext* context, xla::ScopedShapedBuffer scoped_buffers,
    absl::Span<const TensorShapeProto* const> output_tensor_shape_protos,
    const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
    se::Stream* stream, int device_ordinal, InputBuffers* input_buffers,
    const std::shared_ptr<se::Event>& definition_event) {
  VLOG(4) << "Output buffers: " << scoped_buffers.ToString();

  profiler::TraceMe trace_me("AllocateOutputTensors", /*level=*/2);
  // Shapes of the outputs, in TensorShape form.
  const int64_t sub_elements =
      xla::ShapeUtil::TupleElementCount(scoped_buffers.on_host_shape());
  if (sub_elements != output_tensor_shape_protos.size()) {
    return errors::InvalidArgument(
        "Mismatched numbers of output shapes: ", sub_elements, " vs. ",
        output_tensor_shape_protos.size());
  }

  xla::TransferManager* const transfer_manager =
      node_context->backend()->transfer_manager();

  std::vector<TensorShape> output_tensor_shapes;
  output_tensor_shapes.reserve(sub_elements);
  for (int64_t i = 0; i < sub_elements; ++i) {
    TF_RETURN_IF_ERROR(
        TensorShape::IsValidShape(*output_tensor_shape_protos[i]));
    TensorShape shape(*output_tensor_shape_protos[i]);
    const xla::Shape& xla_shape =
        xla::ShapeUtil::GetSubshape(scoped_buffers.on_host_shape(), {i});
    if (!xla_shape.IsArray() ||
        xla::ShapeUtil::ElementsIn(xla_shape) != shape.num_elements()) {
      return errors::InvalidArgument(
          "Mismatched number of elements in output shape: ",
          xla::ShapeUtil::HumanString(xla_shape), " vs ", shape.DebugString());
    }
    output_tensor_shapes.push_back(shape);
  }

  // Builds a shaped buffer for the outputs.
  TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
  TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));

  se::DeviceMemoryAllocator* const allocator =
      node_context->backend()->memory_allocator();

  auto output_buffers =
      absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);

  xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();

  if (!output_device_shape.is_static()) {
    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
        stream, &output_buffers->buffers, &output_device_shape));
    for (int64_t i = 0; i < sub_elements; ++i) {
      const xla::Shape& subshape =
          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
      TensorShape shape;
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
      output_tensor_shapes[i] = shape;
    }
  }

  // Transfers ownership of the buffers that back XLA computation output 'i'
  // to 'output_tensor'.
  auto transfer_buffers = [&](int i, Tensor* output_tensor) {
    const xla::Shape& device_shape =
        xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);

    // Transfers ownership of the output buffers to the output Tensor if the
    // tensor is backed by an XlaTensor. Tensors of size 0 have no backing
    // XlaTensor, so we let 'output_buffers' retain ownership of any buffers in
    // that case.
    if (output_tensor->NumElements() > 0) {
      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                            device_ordinal);
      shaped_buffer.buffers().ForEachMutableElement(
          [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
            xla::ShapeIndex out_index = {i};
            for (int64_t j : index) {
              out_index.push_back(j);
            }
            *buffer = output_buffers->buffers.buffers().element(out_index);
            *output_buffers->owned_buffers.mutable_element(out_index) = false;
          });

      XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
      xla_tensor->set_shaped_buffer(std::move(shaped_buffer));
      xla_tensor->ResetDefinitionEvent(definition_event, stream);
    }
  };

  const int num_updated_variables = variable_updates.output_to_input.size();
  TF_RET_CHECK(num_updated_variables <= output_tensor_shapes.size())
      << num_updated_variables << " <= " << output_tensor_shapes.size();

  OpInputList arg_list;
  TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));

  // The TPU program outputs the updated variables including DT_RESOURCE and
  // non-DT_RESOURCE. The TPUExecuteOp needs to output all non-DT_RESOURCE
  // variables (updated or not).
  //
  //                        updated          not_updated
  //                  |------------------|------------------|
  //   DT_RESOURCE    | allocate persist |    do nothing    |
  //                  |------------------|------------------|
  //                  |     allocate     | forward Op input |
  //  not DT_RESOURCE |      output      |   to Op output   | Op output
  //                  |------------------|------------------|
  //                      program output

  // Allocates a fresh tensor for each updated variable. While the variable
  // inputs may appear in any order, the updated variable values are always
  // added last by the XlaCompiler class, in the same order as the
  // corresponding input variables.
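  // Example: if the compiled update order of variable inputs is {5, 3, 7} and
  // only input 7 has an updated output, then when that output is reached we
  // first walk past inputs 5 and 3 as non-updated variables (forwarding any
  // non-resource inputs to the op's outputs) before handling input 7.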
  int op_output_index = 0;
  int compiled_update_index = 0;
  auto process_non_updated_variable = [&](int input_index) {
    const int variable_index = input_buffers->variable_index.at(input_index);
    // If a DT_RESOURCE input is not updated, nothing needs to be done
    // because there is no corresponding output. If a non-resource input
    // is not updated, forward the input to the output.
    if (variable_index < 0) {
      context->set_output(op_output_index, arg_list[input_index]);
      ++op_output_index;
    }
  };
  for (int i = 0; i < output_tensor_shapes.size(); ++i) {
    auto it = variable_updates.output_to_input.find(i);
    if (it == variable_updates.output_to_input.end()) {
      // Not a variable update.
      // Allocates a fresh tensor for each output of the operator. We always
      // allocate a new host-side tensor, but the on-device buffers that back
      // that tensor may be aliases of input buffers.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      transfer_buffers(i, output_tensor);
      ++op_output_index;
      continue;
    }
    const int input_index = it->second.first;
    // We must process the compiled updates in order, which includes the
    // non-updated variables, i.e., those without an XLA output.
    const bool from_compilation = it->second.second;
    while (from_compilation &&
           variable_updates
                   .input_in_compiled_update_order[compiled_update_index] !=
               input_index) {
      process_non_updated_variable(
          variable_updates
              .input_in_compiled_update_order[compiled_update_index]);
      ++compiled_update_index;
    }
    ++compiled_update_index;
    const int variable_index = input_buffers->variable_index.at(input_index);
    if (variable_index >= 0) {
      // This output corresponds to a DT_RESOURCE input to the TPUExecute
      // operator. Update the corresponding variable.
      VariableInfo& var = input_buffers->variables[variable_index];
      TF_RETURN_IF_ERROR(context->allocate_temp(var.var()->tensor()->dtype(),
                                                output_tensor_shapes[i],
                                                var.var()->tensor()));
      transfer_buffers(i, var.var()->tensor());
    } else {
      // This output corresponds to a non-resource input to the TPUExecute
      // operator. This case occurs for the distributed TPU rewrite which
      // adds variable values as inputs and outputs rather than passing the
      // variables themselves; reading and writing the variable is handled
      // outside the op.
      // TODO(phawkins): remove this case when placement of variables on TPU
      // devices is well supported and we no longer need to place "remote"
      // variables on CPU devices.
      Tensor* output_tensor;
      TF_RETURN_IF_ERROR(context->allocate_output(
          op_output_index, output_tensor_shapes[i], &output_tensor));
      ++op_output_index;
      transfer_buffers(i, output_tensor);
    }
  }

  // Process any remaining non-updated variables.
  for (; compiled_update_index <
         variable_updates.input_in_compiled_update_order.size();
       ++compiled_update_index) {
    process_non_updated_variable(
        variable_updates.input_in_compiled_update_order[compiled_update_index]);
  }
  return std::move(output_buffers);
}

}  // namespace

// TPUExecuteOp

TPUExecuteOp::TPUExecuteOp(OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {}

AsyncOpKernel* TPUExecuteOp::AsAsync() {
  // If TPU launches are asynchronous, we can perform the launch without
  // blocking the calling thread, and so the executor may treat this kernel as
  // a regular (synchronous) OpKernel.
  return nullptr;
}

void TPUExecuteOp::Compute(OpKernelContext* context) {
  Status s = DoWork(context);
  // NOTE: We can't use `OP_REQUIRES_OK()` here because that macro includes
  // a dynamic check that we are not in an AsyncOpKernel.
  if (TF_PREDICT_FALSE(!s.ok())) {
    context->SetStatus(s);
  }
}

void TPUExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this
  // thread to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUExecuteOp::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUExecuteOp::Compute";

  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the object underlying TpuNodeContext won't be
  // deleted out from under us, while node_context is alive.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<TpuNodeContext> node_context,
                      TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal, context] {
        return profiler::TraceMeEncode(
            "TpuExecuteOp", {{"device_ordinal", device_ordinal},
                             {"id", context->step_id()},
                             {"iter_num", context->frame_iter().iter_id}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(
      GetComputationCacheEntry(context, &rendezvous_key_base, &entry_ref));

  // Shapes of the inputs and outputs, in xla::Shape form.
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  const tpu::TpuProgramGroup* tpu_program_group =
      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
          entry.tpu_program_group());
  CHECK_NE(tpu_program_group, nullptr);
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable =
      tpu_program_group->executable_info(core_index);

  xla::Backend* const backend = node_context->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();
  TF_RET_CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(executable.input_shapes_size() == 1);

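  // All computation parameters are packed into a single tuple;
  // input_shapes(0) is the host shape of that tuple.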
  xla::Shape host_shape(executable.input_shapes(0));

  TF_ASSIGN_OR_RETURN(
      auto variable_update_map,
      BuildVariableUpdateMap(executable.variable_indices(),
                             fused_device_var_reads_in_computation_inputs_,
                             fused_device_var_updates_in_computation_outputs_,
                             executable.output_tensor_shapes().size()));
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<InputBuffers> input_buffers,
      BuildComputationInputs(context, host_shape, variable_update_map, backend,
                             device_ordinal, stream));

  // Ideally this should be the host-to-device stream from XlaDeviceContext.
  // The particular anti-dependency this is avoiding (why we need a separate
  // transfer stream) is between the executable writing tuple tables and
  // TPUExecute()'s deregister_stream; if they come from the same stream pool
  // antidependencies will occur. XlaBackend has a different pool of streams
  // to the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
  // will never refer to the same stream.
  //
  // TODO(jmolloy): Add the necessary plumbing to obtain the proper
  // host-to-device stream here.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));

  se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
  auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
                                                     allocator, device_ordinal);
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  // Snapshot the inputs, if a snapshot was requested.
  std::shared_ptr<xla::HloSnapshot> hlo_snapshot;
  if (executable.has_session_module()) {
    hlo_snapshot =
        std::make_shared<xla::HloSnapshot>(executable.session_module());
    auto literal =
        std::make_shared<xla::Literal>(shaped_buffer.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, shaped_buffer, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (!status.ok()) {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs "
                          "failed: "
                       << status;
            return;
          }
          *hlo_snapshot->add_arguments() = literal->ToProto();
        });
  }

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  const uint32 rng_seed = GetXLARandomSeed();

  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable.has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(device_assignment,
                        xla::DeviceAssignment::Deserialize(
                            executable.device_assignment()));
  }

  VLOG(4) << "Input buffers after alias resolution: "
          << shaped_buffer.ToString();

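  // The input buffer tree is handed to the executable as a single tuple
  // argument. Buffers donated above are owned by this ExecutionInput; the
  // computation may alias them to its outputs.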
  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
                                         shaped_buffer.on_host_shape()));

  // The buffers to be freed are in `output` and will be automatically freed
  // when it goes out of scope. In async mode, this means the buffers will be
  // freed before anyone calls "BlockHostUntilDone", which indicates that some
  // of the (input) buffers will be freed while the program is running and
  // looks scary. However, this turns out not to be a problem: although we free
  // memory and reassign it to other users while a program is running, all
  // subsequent writes to the program that could possibly clobber the memory
  // will depend on the program to finish.
  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);
  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(executable, host_transfer_info,
                 *tpu_program_group->hlo_metadata(core_index), std::move(input),
                 rendezvous_key_base, rng_seed, node_context.get(),
                 device_assignment.get(), context->cancellation_manager(),
                 context, stream, transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));
  stream->ThenRecordEvent(definition_event.get());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<OutputBuffers> output_buffers,
      AllocateOutputTensors(
          context, output.ConsumeResult(), executable.output_tensor_shapes(),
          variable_update_map, node_context.get(), stream, device_ordinal,
          input_buffers.get(), definition_event));

  // Transfer the outputs and save the snapshot to disk.
  if (hlo_snapshot) {
    auto literal = std::make_shared<xla::Literal>(
        output_buffers->buffers.on_host_shape());
    transfer_manager->TransferLiteralFromDevice(
        stream, output_buffers->buffers, literal.get(),
        [hlo_snapshot, literal](Status status) {
          if (status.ok()) {
            *hlo_snapshot->mutable_result() = literal->ToProto();
          } else {
            LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot "
                          "outputs failed: "
                       << status;
          }
          DumpHloSnapshotIfEnabled(*hlo_snapshot,
                                   xla::GetDebugOptionsFromFlags());
        });
  }
  return Status::OK();
}

TPUExecuteOp::~TPUExecuteOp() = default;

TPUExecuteAndUpdateVariablesOp::TPUExecuteAndUpdateVariablesOp(
    OpKernelConstruction* context)
    : TPUExecuteOp(context) {
  OP_REQUIRES_OK(context, context->GetAttr(
                              "device_var_reads_indices",
                              &fused_device_var_reads_in_computation_inputs_));
  OP_REQUIRES_OK(
      context,
      context->GetAttr("device_var_updates_indices",
                       &fused_device_var_updates_in_computation_outputs_));
}

REGISTER_KERNEL_BUILDER(
    Name("TPUExecute").Device(DEVICE_TPU_NODE).HostMemory("key"), TPUExecuteOp);

REGISTER_KERNEL_BUILDER(Name("TPUExecuteAndUpdateVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("key"),
                        TPUExecuteAndUpdateVariablesOp);

}  // namespace tensorflow