/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_execute.h"

#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_op_options.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_op_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

using ::tensorflow::tpu::TpuNodeContext;

// These are placeholders for absl flags.
static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;

// Host-side runtime for transfers between TPU and host.
// TODO(b/161940519): Implement this class.
class HostTransferManager {
 public:
  explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}

  using HostCommandHandler = TpuOpExecutable::HostCommandHandler;

  // Returns a function to be called when the TPU triggers a host command
  // interrupt while executing the current program.
  xla::StatusOr<HostCommandHandler> Initialize(
      const TPUHostTransferInfoProto& program,
      const std::string& rendezvous_key_base, OpKernelContext* ctx);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};

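// Stub implementation (see the TODO on HostTransferManager above): the
// returned handler only logs a warning and otherwise ignores host commands.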
xla::StatusOr<HostTransferManager::HostCommandHandler>
HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
                                const string& rendezvous_key_base,
                                OpKernelContext* ctx) {
  return HostCommandHandler([](uint32_t, int64_t) {
    LOG(WARNING) << "HostTransferManager is unimplemented.";
  });
}

// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
void ExitCountdown(Env* env) {
  const int kSleepSeconds = 5;
  LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
            << " seconds before terminating the process to give time "
               "for other errors to propagate";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
                "the anomalies reported above (if any), the run state of "
                "the job (including failed RPCs) and worker logs. This "
                "termination is to ensure a consistent state; if your job "
                "does not restart, modify the retries allowed. See "
                "b/62262381 and b/65223927.";
  std::quick_exit(42);
}

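// Converts a host-layout shape to the corresponding on-device shape by
// querying the TPU hardware layout through the C API.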
xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Destroy(&c_host_shape);
  ApiConverter::Destroy(&c_device_shape);
  return device_shape;
}

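// Returns the compact on-device byte size of `shape`, as computed by the TPU
// hardware layout API.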
int64_t ShapeSizeCompact(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64_t size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Destroy(&c_shape);
  return size;
}

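// Like ShapeSizeCompact, but uses the "raw" variant of the hardware layout
// size computation (presumably excluding the dynamic-shape metadata).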
int64_t ShapeSizeCompactRaw(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64_t size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
          &c_shape);
  ApiConverter::Destroy(&c_shape);
  return size;
}

// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
// point to the correct leaf nodes.
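// Writes are enqueued on `stream` and complete asynchronously.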
xla::Status FixTupleTableAsync(se::Stream* stream,
                               const xla::Shape& tuple_shape,
                               xla::ExecutionInput* mem,
                               xla::TransferManager* transfer_manager) {
  return xla::ShapeUtil::ForEachSubshapeWithStatus(
      tuple_shape,
      [&](const xla::Shape& element_shape,
          const xla::ShapeIndex& index) -> Status {
        if (!element_shape.IsTuple()) {
          return OkStatus();
        }
        std::vector<se::DeviceMemoryBase> elements;
        xla::ShapeIndex element_index = index;
        element_index.push_back(0);
        for (int i = 0; i < element_shape.tuple_shapes_size(); ++i) {
          // Gather all children of the tuple element.
          element_index.back() = i;
          elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
        }
        se::DeviceMemoryBase tuple_table_addr =
            mem->Buffer(index).AsDeviceMemoryBase();
        return transfer_manager->WriteSingleTupleIndexTable(
            stream, elements, element_shape, &tuple_table_addr);
      });
}

// Returns true if `dynamic_shape` has the same rank as `bounded_shape` and
// every dimension of `dynamic_shape` is less than or equal to the
// corresponding dimension of `bounded_shape`.
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                              const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

// For dynamic inputs, copy them and attach metadata of shape sizes to the
// end of the tensor.
//
// The buffer for dynamic shapes contains three parts:
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// Metadata contains the dimension sizes of the shape without padding, which
// together describe the extent of the valid data.
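//
// For example (illustrative): with a compile-time bounded shape of f32[<=8]
// and a runtime shape of f32[5], the padded buffer holds the 5 valid elements
// first, then padding up to the bound of 8, with the dynamic size (5) recorded
// in the trailing metadata.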
xla::Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* runtime_inputs,
    const std::vector<xla::Shape>& compile_time_shapes) {
  TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
  for (int64_t i = 0; i < compile_time_shapes.size(); i++) {
    // TODO(yunxing): Iterating over thousands of elements can be slow. One way
    // to optimize for the fast path without dynamic shapes is to add a field
    // to the compilation result indicating whether any dynamic input is
    // present.
    if (compile_time_shapes[i].is_static()) {
      continue;
    }
    auto& runtime_input = (*runtime_inputs)[i];
    xla::Shape compile_time_shapes_on_device =
        HostShapeToDeviceShape(compile_time_shapes[i]);
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shapes_on_device,
        [&](const xla::Shape& compile_time_shape,
            const xla::ShapeIndex& index) -> Status {
          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
            return OkStatus();
          }

          const xla::Shape& runtime_shape =
              xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);

          TF_RET_CHECK(!runtime_shape.IsTuple());
          TF_RET_CHECK(
              DynamicShapeIsCompatible(runtime_shape, compile_time_shape));

          xla::MaybeOwningDeviceMemory* mutable_input_mem =
              runtime_input.MutableBuffer(index);
          auto padded_data = std::make_shared<std::vector<int8>>(
              ShapeSizeCompact(compile_time_shape), -1);
          auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
              ShapeSizeCompact(runtime_shape) / sizeof(uint32));
          stream->ThenMemcpyD2H(
              se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
              absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
                             ShapeSizeCompactRaw(runtime_shape)));
          stream->ThenDoHostCallback([raw_input_runtime, padded_data,
                                      runtime_shape, compile_time_shape]() {
            // After getting the data onto the host, transpose the data to
            // the correct layout by delinearizing it and linearizing it again.
            XLA_Shape c_runtime_shape, c_compile_time_shape;
            ApiConverter::ToC(runtime_shape, &c_runtime_shape);
            ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
            StatusHelper status;

            TpuExecute_RuntimeInputToPaddedData_Params params;
            params.struct_size =
                TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
            params.priv = nullptr;
            params.runtime_input_ptr = raw_input_runtime->data();
            params.runtime_input_size = raw_input_runtime->size();
            params.padded_data_ptr = padded_data->data();
            params.padded_data_size = padded_data->size();
            params.runtime_shape = &c_runtime_shape;
            params.compile_time_shape = &c_compile_time_shape;
            params.status = status.c_status;

            tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
                &params);
            ApiConverter::Destroy(&c_runtime_shape);
            ApiConverter::Destroy(&c_compile_time_shape);
            return status.status();
          });
          // Allocate new input and transfer the padded and transposed data to
          // the new input location.
          TF_ASSIGN_OR_RETURN(
              auto new_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  ShapeSizeCompact(compile_time_shape)));
          auto typed_new_input_memory =
              se::DeviceMemory<int8>(new_input.cref());
          stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);

          // Retain the memory until the end of the transfer.
          stream->ThenDoHostCallback([padded_data]() { return OkStatus(); });

          // Modify the memory location in the input shape tree to point to the
          // new input.
          *mutable_input_mem =
              xla::MaybeOwningDeviceMemory(std::move(new_input));
          element_modified = true;
          return OkStatus();
        }));
    if (element_modified) {
      // The input location has been modified; fix the tuple table to point to
      // the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
                                            compile_time_shapes_on_device,
                                            &runtime_input, transfer_manager));
    }
  }
  return OkStatus();
}

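// Cancellation handler. Depending on the placeholder flags above, this either
// stops chip heartbeats and schedules a process exit, closes all TPU chips
// attached to this host, or does nothing.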
void TPUCancelExecution(Env* env, int device_ordinal) {
  if (tpu_cancellation_terminates_process) {
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
              << device_ordinal;
    Status status = TpuNodeContext::StopChipHeartbeats();
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
              << " on device " << device_ordinal;
    // Sleep and exit in another thread so the cancellation manager can
    // continue running callbacks. The new thread will call quick_exit,
    // so we discard the returned Thread pointer because we won't have
    // an opportunity to delete it.
    (void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
                           [env]() { ExitCountdown(env); });
  } else if (tpu_cancellation_closes_chips) {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal;
    Status status = TpuNodeContext::CloseTpuHost();
    LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
              << " on device " << device_ordinal;
  } else {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal
              << " is suppressed";
  }
}

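// Registers TPUCancelExecution as a cancellation callback for this op.
// Returns the cancellation token and whether cancellation had already been
// requested before the callback could be registered.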
std::pair<CancellationToken, bool> RegisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    int device_ordinal) {
  // Set up a cancellation callback, to ensure the TPU program we run will
  // halt if the RPC is cancelled. Without this the TPU program might block
  // forever. The mechanism itself is a big hammer; we close all devices
  // attached to this host on each cancellation callback. This is necessary to
  // ensure the system will eventually halt, since the TensorNodes on each
  // chip may be stuck waiting for mutual communication.
  //
  // By closing all devices, we ensure all subsequent attempts to use the
  // device will fail, until the devices are re-initialized via a new call to
  // tpu.initialize_system.
  //
  // In a multi-TensorNode setup, CloseTPUHost may be called once for each
  // TensorNode, and each call will close all TensorNodes. This quadratic
  // behavior ensures the mechanism is robust to various orderings
  // (i.e. races) between the TPU programs, which are run on separate threads.
  // In practice the quadratic behavior isn't that bad; the first call will
  // actually halt any running TPU programs (which may be expensive), while
  // subsequent calls will attempt to close an already-closed device (which is
  // cheap).
  //
  // TODO(b/62262381): The cancellation manager is shared between multiple TPU
  // execute ops, and the cancellation callback is not invoked only when the
  // RPC is cancelled (it may also be induced by an OOM error from a different
  // TPU execute op). This results in a pretty coarse cancellation domain; the
  // callback should execute in a narrower scope so that it is not triggered
  // in such cases.
  CancellationToken token = cancellation_manager->get_cancellation_token();
  // Don't rely on OpKernelContext being available when the callback runs.
  Env* env = ctx->env();
  bool already_cancelled =
      !cancellation_manager->RegisterCallbackWithErrorLogging(
          token,
          [device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); },
          absl::StrCat("TPUCancellation on device ", device_ordinal));
  return std::pair<CancellationToken, bool>(token, already_cancelled);
}

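// Deregisters the cancellation callback once the program has been enqueued.
// The deregistration itself is performed in a host callback on a borrowed
// substream, after all work previously enqueued on the compute stream.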
void UnregisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    se::Stream* stream, int device_ordinal, CancellationToken token,
    std::shared_ptr<HostTransferManager> host_transfer_manager) {
  // If execution reaches this point, the host callback enqueued below will get
  // called regardless of stream status. Call inc_num_deferred_ops_function
  // here and dec_num_deferred_ops_function in the host callback.
  ctx->inc_num_deferred_ops_function()();
  auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();

  // Try to avoid running callbacks on the compute stream, because this reduces
  // the frequency of back-to-back programs (which are most efficient because
  // they don't require host synchronization). Instead, borrow a substream and
  // have the substream wait on the compute stream.
  se::Stream* deregister_stream = stream->GetOrCreateSubStream();
  deregister_stream->ThenWaitFor(stream);
  deregister_stream->ThenDoHostCallback([=]() {
    // Ensure the host_transfer_manager is copied into the callback scope.
    (void)host_transfer_manager;

    // We must deregister the callback in the success case, to avoid closing
    // all devices. In the failure case we must NOT call DeregisterCallback as
    // that waits for all previous cancellation callbacks to complete, and any
    // call to XlaDevice::Sync() will cause deadlock. Consider:
    //   1) CancellationManager::StartCancel() is in progress (state is
    //      cancelling_).
    //   2) The call below to DeregisterCallback will block until state is
    //      cancelled_ (all callbacks are completed).
    //   3) A different cancellation callback has called XlaDevice::Sync(),
    //      which will block until (2) is done.
    //   4) StartCancel() in (1) cannot complete until (3) is done.
    //
    // Instead, call TryDeregisterCallback. The functional difference is that
    // TryDeregisterCallback will not block if cancellation is in progress, so
    // it makes no guarantees as to the state of any callbacks. This is not a
    // problem, as our cancellation handler does not rely on any external
    // state.
    VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
            << device_ordinal;
    cancellation_manager->TryDeregisterCallback(token);
    VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
            << device_ordinal;

    // ExecutorState is held alive until at least this point to ensure
    // cancellation_manager is valid. After all outstanding
    // dec_num_deferred_ops_function are called, ExecutorState::Finish will be
    // allowed to proceed.
    dec_num_deferred_ops_function();
  });
  stream->ReturnSubStream(deregister_stream);
}

}  // namespace

xla::StatusOr<xla::ExecutionOutput> TPUExecute(
    const TPUExecutableInfoProto& executable,
    const TPUHostTransferInfoProto& host_transfers,
    const xla::HloProto& hlo_metadata,
    std::vector<xla::ExecutionInput> arguments,
    const string& rendezvous_key_base, uint32 rng_seed,
    TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
    CancellationManager* cancellation_manager, OpKernelContext* ctx,
    stream_executor::Stream* stream,
    stream_executor::Stream* host_to_device_stream,
    const XLA_TpuProgram* tpu_program) {
  profiler::TraceMe traceme("TPUExecute", 2);
  TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
  TF_RET_CHECK(tpu_program != nullptr);
  VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();

  xla::Backend* backend = node_context->backend();

  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
  std::shared_ptr<HostTransferManager> host_transfer_manager =
      std::make_shared<HostTransferManager>(node_context, backend);
  TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommandHandler handler,
                      host_transfer_manager->Initialize(
                          host_transfers, rendezvous_key_base, ctx));

  VLOG(2) << "Cloud TPU: Executing computation on device "
          << node_context->device_ordinal();

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_device_assignment(device_assignment);
  run_options.set_rng_seed(rng_seed);
  run_options.set_allocator(backend->memory_allocator());
  run_options.set_host_to_device_stream(host_to_device_stream);

  const xla::ServiceExecutableRunOptions service_run_options(run_options);

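  // Build a skeleton HloModule that carries only the entry computation layout
  // (plus, below, aliasing and prefetch metadata); the compiled program itself
  // lives in `tpu_program`.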
  std::unique_ptr<xla::HloModule> module;
  std::vector<xla::Shape> input_shapes;
  {
    xla::ComputationLayout computation_layout(
        xla::ShapeLayout(xla::Shape(executable.output_shape())));
    for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
      xla::Shape shape(shape_proto);
      computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
      input_shapes.push_back(std::move(shape));
    }
    module = std::make_unique<xla::HloModule>(
        "TpuExecutableModule",
        xla::HloModuleConfig(std::move(computation_layout)));
  }

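  // Propagate the input/output aliasing configuration recorded in the
  // compiler's HLO metadata onto the skeleton module.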
  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config(),
      xla::HloInputOutputAliasConfig::CreateFromProto(
          backend->transfer_manager()->HostShapeToDeviceShape(
              module->config().entry_computation_layout().result_shape()),
          hlo_metadata.hlo_module().input_output_alias()));
  TF_RET_CHECK(executable.input_shapes().size() == arguments.size());

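  // Carry over any cross-program prefetches recorded at compile time.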
  for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
                                         &arguments, input_shapes));

  auto tpu_executable = std::make_unique<TpuOpExecutable>(
      tpu_program, std::move(module), /*host_command_handler=*/handler);

  const int32_t device_ordinal = node_context->device_ordinal();
  CancellationToken token;
  bool already_cancelled;
  std::tie(token, already_cancelled) =
      RegisterCancellation(ctx, cancellation_manager, device_ordinal);

  // If the RPC was already cancelled before we managed to register the
  // cancellation callback, we shouldn't attempt to run the TPU program, since
  // it might block forever.
  if (already_cancelled) {
    return errors::Cancelled(
        "RPC cancelled, not running TPU program on device ", device_ordinal);
  }

  xla::StatusOr<xla::ExecutionOutput> output =
      tpu_executable->ExecuteAsyncOnStream(&service_run_options,
                                           std::move(arguments),
                                           /*hlo_execution_profile=*/nullptr);

  // If !output.ok(), it means we failed to enqueue the program to the TPU.
  // This is possibly caused by a failed cancellation callback closing the
  // chips.
  if (!output.ok()) {
    // If the cancellation manager is already cancelled or cancelling, it means
    // another failure has occurred earlier and this TpuExecuteOp is cancelled
    // regardless of whether it itself is an error.
    already_cancelled = cancellation_manager->IsCancelling() ||
                        cancellation_manager->IsCancelled();
    if (already_cancelled) {
      return errors::Cancelled(
          "RPC cancelled, not running TPU program on device ", device_ordinal);
    }
  }
  UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
                         token, host_transfer_manager);
  VLOG(1) << "Cloud TPU: TPUExecute done";
  return output;
}

}  // namespace tensorflow