/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_execute.h"

#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

using ::tensorflow::tpu::TpuNodeContext;

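// Flags controlling what happens when a running TPUExecute is cancelled; see
// TPUCancelExecution below. By default the chips attached to this host are
// closed; the process itself is not terminated.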
static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;

// Host-side runtime for transfers between TPU and host.
// TODO(b/161940519): Implement this class.
class HostTransferManager {
 public:
  explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}

  using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;

  // Returns a function to be called when the TPU triggers a host command
  // interrupt while executing the current program.
  xla::StatusOr<HostCommmandHandler> Initialize(
      const TPUHostTransferInfoProto& program,
      const std::string& rendezvous_key_base, OpKernelContext* ctx);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};

xla::StatusOr<HostTransferManager::HostCommmandHandler>
HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
                                const string& rendezvous_key_base,
                                OpKernelContext* ctx) {
  return HostCommmandHandler([](uint32, int64) {
    LOG(WARNING) << "HostTransferManager is unimplemented.";
  });
}

// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
void ExitCountdown(Env* env) {
  const int kSleepSeconds = 5;
  LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
            << " seconds before terminating the process to give time "
               "for other errors to propagate";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
                "the anomalies reported above (if any), run state of job "
                "(including failed RPCs) and worker logs. This "
                "termination is to ensure a consistent state, if your job "
                "does not restart, modify the retries allowed. See "
                "b/62262381 and b/65223927.";
  std::quick_exit(42);
}

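// Converts a host-side shape to the corresponding on-device shape (with the
// TPU hardware layout) via the HardwareLayout C API.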
xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
  ApiConverter::Free(&c_device_shape);
  return device_shape;
}

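// Returns the compact on-device size of `shape` in bytes, as computed by the
// HardwareLayout C API.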
int64 ShapeSizeCompact(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}

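// Like ShapeSizeCompact, but uses the "raw" variant of the HardwareLayout C
// API (the size of the payload only, without the dynamic-shape metadata).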
int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
          &c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}

// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
// point to the correct leaf nodes.
xla::Status FixTupleTableAsync(se::Stream* stream,
                               const xla::Shape& tuple_shape,
                               xla::ExecutionInput* mem,
                               xla::TransferManager* transfer_manager) {
  return xla::ShapeUtil::ForEachSubshapeWithStatus(
      tuple_shape,
      [&](const xla::Shape& element_shape,
          const xla::ShapeIndex& index) -> Status {
        if (!element_shape.IsTuple()) {
          return Status::OK();
        }
        std::vector<se::DeviceMemoryBase> elements;
        xla::ShapeIndex element_index = index;
        element_index.push_back(0);
        for (int64 i = 0; i < element_shape.tuple_shapes_size(); ++i) {
          // Gather all children of the tuple element.
          element_index.back() = i;
          elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
        }
        se::DeviceMemoryBase tuple_table_addr =
            mem->Buffer(index).AsDeviceMemoryBase();
        return transfer_manager->WriteSingleTupleIndexTable(
            stream, elements, element_shape, &tuple_table_addr);
      });
}

// Returns true if every dimension of `dynamic_shape` is less than or equal to
// the corresponding dimension of `bounded_shape`.
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                              const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

// For dynamic inputs, copy them and attach metadata of shape sizes to the
// end of the tensor.
//
// The buffer for dynamic shapes contains three parts:
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// Metadata contains the sizes of shape without padding, eventually
// representing the size of valid data.
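//
// For example (illustrative): a parameter compiled with bounded shape
// f32[<=8] but passed at runtime as f32[5] carries 20 bytes of payload,
// padding up to the 32-byte bound, and metadata recording the true dimension
// size of 5.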
xla::Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* runtime_inputs,
    const std::vector<xla::Shape>& compile_time_shapes) {
  TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
  for (int64 i = 0; i < compile_time_shapes.size(); i++) {
    // TODO(yunxing): Iterating over thousands of elements can be slow. One way
    // to optimize the fast path without dynamic shapes is to add a field to
    // the compilation result indicating whether a dynamic input is present.
    if (compile_time_shapes[i].is_static()) {
      continue;
    }
    auto& runtime_input = (*runtime_inputs)[i];
    xla::Shape compile_time_shapes_on_device =
        HostShapeToDeviceShape(compile_time_shapes[i]);
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shapes_on_device,
        [&](const xla::Shape& compile_time_shape,
            const xla::ShapeIndex& index) -> Status {
          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
            return Status::OK();
          }

          const xla::Shape& runtime_shape =
              xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);

          TF_RET_CHECK(!runtime_shape.IsTuple());
          TF_RET_CHECK(
              DynamicShapeIsCompatible(runtime_shape, compile_time_shape));

          xla::MaybeOwningDeviceMemory* mutable_input_mem =
              runtime_input.MutableBuffer(index);
          auto padded_data = std::make_shared<std::vector<int8>>(
              ShapeSizeCompact(compile_time_shape), -1);
          auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
              ShapeSizeCompact(runtime_shape) / sizeof(uint32));
          stream->ThenMemcpyD2H(
              se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
              absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
                             ShapeSizeCompactRaw(runtime_shape)));
          stream->ThenDoHostCallback([raw_input_runtime, padded_data,
                                      runtime_shape, compile_time_shape]() {
            // After getting the data onto the host, transpose the data to
            // the correct layout by delinearizing it and linearizing it again.
            XLA_Shape c_runtime_shape, c_compile_time_shape;
            ApiConverter::ToC(runtime_shape, &c_runtime_shape);
            ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
            StatusHelper status;

            TpuExecute_RuntimeInputToPaddedData_Params params;
            params.struct_size =
                TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
            params.priv = nullptr;
            params.runtime_input_ptr = raw_input_runtime->data();
            params.runtime_input_size = raw_input_runtime->size();
            params.padded_data_ptr = padded_data->data();
            params.padded_data_size = padded_data->size();
            params.runtime_shape = &c_runtime_shape;
            params.compile_time_shape = &c_compile_time_shape;
            params.status = status.c_status;

            tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
                &params);
            ApiConverter::Free(&c_runtime_shape);
            ApiConverter::Free(&c_compile_time_shape);
            return status.status();
          });
          // Allocate new input and transfer the padded and transposed data to
          // the new input location.
          TF_ASSIGN_OR_RETURN(
              auto new_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  ShapeSizeCompact(compile_time_shape)));
          auto typed_new_input_memory =
              se::DeviceMemory<int8>(new_input.cref());
          stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);

          // Retain the memory until the end of the transfer.
          stream->ThenDoHostCallback([padded_data]() { return Status::OK(); });

          // Modify the memory location in the input shape tree to point to the
          // new input.
          *mutable_input_mem =
              xla::MaybeOwningDeviceMemory(std::move(new_input));
          element_modified = true;
          return Status::OK();
        }));
    if (element_modified) {
      // The input location has been modified, need to fix tuple table to
      // point to the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
                                            compile_time_shapes_on_device,
                                            &runtime_input, transfer_manager));
    }
  }
  return Status::OK();
}

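// Cancellation callback for a running TPU program. Depending on the flags
// above, it either stops chip heartbeats and schedules a process exit, closes
// all TPU chips attached to this host, or does nothing.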
void TPUCancelExecution(Env* env, int device_ordinal) {
  if (tpu_cancellation_terminates_process) {
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
              << device_ordinal;
    Status status = TpuNodeContext::StopChipHeartbeats();
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
              << " on device " << device_ordinal;
    // Sleep and exit in another thread so the cancellation manager can
    // continue running callbacks. The new thread will call quick_exit,
    // so we discard the returned Thread pointer because we won't have
    // an opportunity to delete it.
    (void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
                           [env]() { ExitCountdown(env); });
  } else if (tpu_cancellation_closes_chips) {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal;
    Status status = TpuNodeContext::CloseTpuHost();
    LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
              << " on device " << device_ordinal;
  } else {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal
              << " is suppressed";
  }
}

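// Registers TPUCancelExecution as a callback on `cancellation_manager`.
// Returns the cancellation token and whether cancellation had already been
// requested before the callback could be registered.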
std::pair<CancellationToken, bool> RegisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    int device_ordinal) {
  // Set up a cancellation callback, to ensure the TPU program we run will
  // halt if the RPC is cancelled. Without this the TPU program might block
  // forever. The mechanism itself is a big hammer; we close all devices
  // attached to this host on each cancellation callback. This is necessary to
  // ensure the system will eventually halt, since the TensorNodes on each
  // chip may be stuck waiting for mutual communication.
  //
  // By closing all devices, we ensure all subsequent attempts to use the
  // device will fail, until the devices are re-initialized via a new call to
  // tpu.initialize_system.
  //
  // In a multi-TensorNode setup, CloseTPUHost may be called once for each
  // TensorNode, and each call will close all TensorNodes. This quadratic
  // behavior ensures the mechanism is robust to various orderings
  // (i.e. races) between the TPU programs, which are run on separate threads.
  // In practice the quadratic behavior isn't that bad; the first call will
  // actually halt any running TPU programs (which may be expensive), while
  // subsequent calls will attempt to close an already-closed device (which is
  // cheap).
  //
  // TODO(b/62262381): The cancellation manager is shared between multiple TPU
  // execute ops, so the callback is not invoked only when the RPC is cancelled
  // (it may also be induced by an OOM error from a different TPU execute).
  // This results in a pretty coarse cancellation domain; the callback should
  // only execute in a narrower scope so that it is not triggered in such
  // cases.
  CancellationToken token = cancellation_manager->get_cancellation_token();
  // Don't rely on OpKernelContext being available when the callback runs.
  Env* env = ctx->env();
  bool already_cancelled = !cancellation_manager->RegisterCallback(
      token,
      [device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); });
  return std::pair<CancellationToken, bool>(token, already_cancelled);
}

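// Deregisters the cancellation callback once the work enqueued on `stream` so
// far has completed. The deregistration runs as a host callback on a borrowed
// substream to avoid blocking the compute stream.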
void UnregisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    se::Stream* stream, int device_ordinal, CancellationToken token,
    std::shared_ptr<HostTransferManager> host_transfer_manager) {
  // If execution reaches this point, the host callback enqueued below will get
  // called regardless of stream status. Call inc_num_deferred_ops_function
  // here and dec_num_deferred_ops_function in the host callback.
  ctx->inc_num_deferred_ops_function()();
  auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();

  // Try to avoid running callbacks on the compute stream, because this reduces
  // the frequency of back-to-back programs (which are most efficient because
  // they don't require host synchronization). Instead, borrow a substream and
  // have the substream wait on the compute stream.
  se::Stream* deregister_stream = stream->GetOrCreateSubStream();
  deregister_stream->ThenWaitFor(stream);
  deregister_stream->ThenDoHostCallback([=]() {
    // Ensure the host_transfer_manager is copied into the callback scope.
    (void)host_transfer_manager;

    // We must deregister the callback in the success case, to avoid closing
    // all devices. In the failure case we must NOT call DeregisterCallback as
    // that waits for all previous cancellation callbacks to complete and any
    // call to XlaDevice::Sync() will cause deadlock. Consider:
    //   1) CancellationManager::StartCancel() is in progress (state is
    //      cancelling_).
    //   2) The call below to DeregisterCallback will block until state is
    //      cancelled_ (all callbacks are completed).
    //   3) A different cancellation callback has called XlaDevice::Sync(),
    //      which will block until (2) is done.
    //   4) StartCancel() in (1) cannot complete until (3) is done.
    //
    // Instead, call TryDeregisterCallback. The functional difference is that
    // TryDeregisterCallback will not block if cancellation is in progress,
    // and so makes no guarantees as to the state of any callbacks.
    // This is not a problem, as our cancellation handler does not rely on
    // any external state.
    VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
            << device_ordinal;
    cancellation_manager->TryDeregisterCallback(token);
    VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
            << device_ordinal;

    // ExecutorState is held alive until at least this point to ensure
    // cancellation_manager is valid. After all outstanding
    // dec_num_deferred_ops_function are called, ExecutorState::Finish will be
    // allowed to proceed.
    dec_num_deferred_ops_function();
  });
  stream->ReturnSubStream(deregister_stream);
}

}  // namespace

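// Runs `tpu_program` on the TPU associated with `node_context`: sets up host
// transfers and cancellation handling, rewrites dynamic inputs, and enqueues
// the executable on `stream`.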
xla::StatusOr<xla::ExecutionOutput> TPUExecute(
    const TPUExecutableInfoProto& executable,
    const TPUHostTransferInfoProto& host_transfers,
    const xla::HloProto& hlo_metadata,
    std::vector<xla::ExecutionInput> arguments,
    const string& rendezvous_key_base, uint32 rng_seed,
    TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
    CancellationManager* cancellation_manager, OpKernelContext* ctx,
    stream_executor::Stream* stream,
    stream_executor::Stream* host_to_device_stream,
    const XLA_TpuProgram* tpu_program) {
  profiler::TraceMe traceme("TPUExecute", 2);
  TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
  TF_RET_CHECK(tpu_program != nullptr);
  VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();

  xla::Backend* backend = node_context->backend();

  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
  std::shared_ptr<HostTransferManager> host_transfer_manager =
      std::make_shared<HostTransferManager>(node_context, backend);
  TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommmandHandler handler,
                      host_transfer_manager->Initialize(
                          host_transfers, rendezvous_key_base, ctx));

  VLOG(2) << "Cloud TPU: Executing computation on device "
          << node_context->device_ordinal();

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_device_assignment(device_assignment);
  run_options.set_rng_seed(rng_seed);
  run_options.set_allocator(backend->memory_allocator());
  run_options.set_host_to_device_stream(host_to_device_stream);

  const xla::ServiceExecutableRunOptions service_run_options(run_options);

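  // Build a minimal HloModule that carries only the entry computation layout;
  // it lets the aliasing and prefetch metadata below be attached before the
  // module is handed to the TpuExecutable.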
  std::unique_ptr<xla::HloModule> module;
  std::vector<xla::Shape> input_shapes;
  {
    xla::ComputationLayout computation_layout(
        xla::ShapeLayout(xla::Shape(executable.output_shape())));
    for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
      xla::Shape shape(shape_proto);
      computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
      input_shapes.push_back(std::move(shape));
    }
    module = absl::make_unique<xla::HloModule>(
        "TpuExecutableModule",
        xla::HloModuleConfig(std::move(computation_layout)));
  }

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config(),
      xla::HloInputOutputAliasConfig::CreateFromProto(
          backend->transfer_manager()->HostShapeToDeviceShape(
              module->config().entry_computation_layout().result_shape()),
          hlo_metadata.hlo_module().input_output_alias()));
  TF_RET_CHECK(executable.input_shapes().size() == arguments.size());

  for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
                                         &arguments, input_shapes));

  auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
      tpu_program, std::move(module), /*host_command_handler=*/handler);

  const int32 device_ordinal = node_context->device_ordinal();
  CancellationToken token;
  bool already_cancelled;
  std::tie(token, already_cancelled) =
      RegisterCancellation(ctx, cancellation_manager, device_ordinal);

  // If the RPC was already cancelled before we managed to register the
  // cancellation callback, we shouldn't attempt to run the TPU program, since
  // it might block forever.
  if (already_cancelled) {
    return errors::Cancelled(
        "RPC cancelled, not running TPU program on device ", device_ordinal);
  }

  xla::StatusOr<xla::ExecutionOutput> output =
      tpu_executable->ExecuteAsyncOnStream(&service_run_options,
                                           std::move(arguments),
                                           /*hlo_execution_profile=*/nullptr);

  // If !output.ok(), it means we failed to enqueue the program on the TPU.
  // This is possibly caused by a failed cancellation callback closing the
  // chips.
  if (!output.ok()) {
    // If the cancellation manager is already cancelled or cancelling, it means
    // another failure has occurred earlier and this TpuExecuteOp is cancelled
    // regardless of whether it itself hit an error.
    already_cancelled = cancellation_manager->IsCancelling() ||
                        cancellation_manager->IsCancelled();
    if (already_cancelled) {
      return errors::Cancelled(
          "RPC cancelled, not running TPU program on device ", device_ordinal);
    }
  }
  UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
                         token, host_transfer_manager);
  VLOG(1) << "Cloud TPU: TPUExecute done";
  return output;
}

}  // namespace tensorflow