/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_execute.h"

#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_execute_op_options.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_op_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

using ::tensorflow::tpu::TpuNodeContext;

// These are placeholders for absl flags.
static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;

// Host-side runtime for transfers between TPU and host.
// TODO(b/161940519): Implement this class.
class HostTransferManager {
 public:
  explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}

  using HostCommandHandler = TpuOpExecutable::HostCommandHandler;

  // Returns a function to be called when the TPU triggers a host command
  // interrupt while executing the current program.
  xla::StatusOr<HostCommandHandler> Initialize(
      const TPUHostTransferInfoProto& program,
      const std::string& rendezvous_key_base, OpKernelContext* ctx);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};

xla::StatusOr<HostTransferManager::HostCommandHandler>
HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
                                const string& rendezvous_key_base,
                                OpKernelContext* ctx) {
  return HostCommandHandler([](uint32_t, int64_t) {
    LOG(WARNING) << "HostTransferManager is unimplemented.";
  });
}

// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
void ExitCountdown(Env* env) {
  const int kSleepSeconds = 5;
  LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
            << " seconds before terminating the process to give time "
               "for other errors to propagate";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
                "the anomalies reported above (if any), run state of job "
                "(including failed RPCs) and worker logs. This "
                "termination is to ensure a consistent state, if your job "
                "does not restart, modify the retries allowed. See "
                "b/62262381 and b/65223927.";
  std::quick_exit(42);
}

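// The helpers below wrap the TPU runtime's hardware-layout C API (exposed via
// tensorflow::tpu::OpsApiFn()): each converts its XLA shape argument to the C
// representation, performs the query, and destroys the temporaries.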
xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Destroy(&c_host_shape);
  ApiConverter::Destroy(&c_device_shape);
  return device_shape;
}

int64_t ShapeSizeCompact(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64_t size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Destroy(&c_shape);
  return size;
}

int64_t ShapeSizeCompactRaw(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64_t size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
          &c_shape);
  ApiConverter::Destroy(&c_shape);
  return size;
}

// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
// point to the correct leaf nodes.
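// (On device, a tuple is represented by a "tuple table": a buffer holding the
// addresses of its element buffers. If an element buffer is reallocated, the
// corresponding table entry must be rewritten, which is what
// WriteSingleTupleIndexTable does for one tuple node at a time.)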
xla::Status FixTupleTableAsync(se::Stream* stream,
                               const xla::Shape& tuple_shape,
                               xla::ExecutionInput* mem,
                               xla::TransferManager* transfer_manager) {
  return xla::ShapeUtil::ForEachSubshapeWithStatus(
      tuple_shape,
      [&](const xla::Shape& element_shape,
          const xla::ShapeIndex& index) -> Status {
        if (!element_shape.IsTuple()) {
          return OkStatus();
        }
        std::vector<se::DeviceMemoryBase> elements;
        xla::ShapeIndex element_index = index;
        element_index.push_back(0);
        for (int i = 0; i < element_shape.tuple_shapes_size(); ++i) {
          // Gather all children of the tuple element.
          element_index.back() = i;
          elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
        }
        se::DeviceMemoryBase tuple_table_addr =
            mem->Buffer(index).AsDeviceMemoryBase();
        return transfer_manager->WriteSingleTupleIndexTable(
            stream, elements, element_shape, &tuple_table_addr);
      });
}

// Returns true if `dynamic_shape` has dimensions that are less than or equal
// to the corresponding dimensions of `bounded_shape`.
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                              const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64_t i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}

// For dynamic inputs, copy them and attach metadata of shape sizes to the
// end of the tensor.
//
// The buffer for dynamic shapes contains three parts:
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// Metadata contains the dimension sizes of the shape without padding,
// effectively representing the size of the valid data.
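// For example (illustrative): a parameter compiled as f32[<=4] that receives a
// runtime value of shape f32[2] is copied into a newly allocated buffer sized
// for the bounded shape; the first two elements hold the payload and the
// trailing metadata records the true dimension size (2).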
xla::Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* runtime_inputs,
    const std::vector<xla::Shape>& compile_time_shapes) {
  TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
  for (int64_t i = 0; i < compile_time_shapes.size(); i++) {
    // TODO(yunxing): Iterating over thousands of elements can be slow. One way
    // to optimize the fast path without dynamic shapes is to add a field to
    // the compilation result indicating whether any dynamic input is present.
    if (compile_time_shapes[i].is_static()) {
      continue;
    }
    auto& runtime_input = (*runtime_inputs)[i];
    xla::Shape compile_time_shapes_on_device =
        HostShapeToDeviceShape(compile_time_shapes[i]);
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shapes_on_device,
        [&](const xla::Shape& compile_time_shape,
            const xla::ShapeIndex& index) -> Status {
          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
            return OkStatus();
          }

          const xla::Shape& runtime_shape =
              xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);

          TF_RET_CHECK(!runtime_shape.IsTuple());
          TF_RET_CHECK(
              DynamicShapeIsCompatible(runtime_shape, compile_time_shape));

          xla::MaybeOwningDeviceMemory* mutable_input_mem =
              runtime_input.MutableBuffer(index);
          auto padded_data = std::make_shared<std::vector<int8>>(
              ShapeSizeCompact(compile_time_shape), -1);
          auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
              ShapeSizeCompact(runtime_shape) / sizeof(uint32));
          stream->ThenMemcpyD2H(
              se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
              absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
                             ShapeSizeCompactRaw(runtime_shape)));
          stream->ThenDoHostCallback([raw_input_runtime, padded_data,
                                      runtime_shape, compile_time_shape]() {
            // After getting the data onto the host, transpose the data to
            // the correct layout by delinearizing it and linearizing it again.
            XLA_Shape c_runtime_shape, c_compile_time_shape;
            ApiConverter::ToC(runtime_shape, &c_runtime_shape);
            ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
            StatusHelper status;

            TpuExecute_RuntimeInputToPaddedData_Params params;
            params.struct_size =
                TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
            params.priv = nullptr;
            params.runtime_input_ptr = raw_input_runtime->data();
            params.runtime_input_size = raw_input_runtime->size();
            params.padded_data_ptr = padded_data->data();
            params.padded_data_size = padded_data->size();
            params.runtime_shape = &c_runtime_shape;
            params.compile_time_shape = &c_compile_time_shape;
            params.status = status.c_status;

            tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
                &params);
            ApiConverter::Destroy(&c_runtime_shape);
            ApiConverter::Destroy(&c_compile_time_shape);
            return status.status();
          });
          // Allocate new input and transfer the padded and transposed data to
          // the new input location.
          TF_ASSIGN_OR_RETURN(
              auto new_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  ShapeSizeCompact(compile_time_shape)));
          auto typed_new_input_memory =
              se::DeviceMemory<int8>(new_input.cref());
          stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);

          // Retain the memory until the end of the transfer.
          stream->ThenDoHostCallback([padded_data]() { return OkStatus(); });

          // Modify the memory location in the input shape tree to point to the
          // new input.
          *mutable_input_mem =
              xla::MaybeOwningDeviceMemory(std::move(new_input));
          element_modified = true;
          return OkStatus();
        }));
    if (element_modified) {
      // The input location has been modified; fix the tuple table so that it
      // points to the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
                                            compile_time_shapes_on_device,
                                            &runtime_input, transfer_manager));
    }
  }
  return OkStatus();
}

void TPUCancelExecution(Env* env, int device_ordinal) {
  if (tpu_cancellation_terminates_process) {
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
              << device_ordinal;
    Status status = TpuNodeContext::StopChipHeartbeats();
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
              << " on device " << device_ordinal;
    // Sleep and exit in another thread so the cancellation manager can
    // continue running callbacks. The new thread will call quick_exit,
    // so we discard the returned Thread pointer because we won't have
    // an opportunity to delete it.
    (void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
                           [env]() { ExitCountdown(env); });
  } else if (tpu_cancellation_closes_chips) {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal;
    Status status = TpuNodeContext::CloseTpuHost();
    LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
              << " on device " << device_ordinal;
  } else {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal
              << " is suppressed";
  }
}

std::pair<CancellationToken, bool> RegisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    int device_ordinal) {
  // Set up a cancellation callback, to ensure the TPU program we run will
  // halt if the RPC is cancelled. Without this the TPU program might block
  // forever. The mechanism itself is a big hammer; we close all devices
  // attached to this host on each cancellation callback. This is necessary to
  // ensure the system will eventually halt, since the TensorNodes on each
  // chip may be stuck waiting for mutual communication.
  //
  // By closing all devices, we ensure all subsequent attempts to use the
  // device will fail, until the devices are re-initialized via a new call to
  // tpu.initialize_system.
  //
  // In a multi-TensorNode setup, CloseTPUHost may be called once for each
  // TensorNode, and each call will close all TensorNodes. This quadratic
  // behavior ensures the mechanism is robust to various orderings
  // (i.e. races) between the TPU programs, which are run on separate threads.
  // In practice the quadratic behavior isn't that bad; the first call will
  // actually halt any running TPU programs (which may be expensive), while
  // subsequent calls will attempt to close an already-closed device (which is
  // cheap).
  //
  // TODO(b/62262381): The cancellation manager is shared between multiple TPU
  // execute ops, and the cancellation is not invoked only when the RPC is
  // cancelled (it may also be induced by OOM errors from a different TPU
  // execute); this results in a rather coarse cancellation domain. This
  // cancellation callback should only execute in a narrower scope so that it
  // is not triggered in such cases.
  CancellationToken token = cancellation_manager->get_cancellation_token();
  // Don't rely on OpKernelContext being available when the callback runs.
  Env* env = ctx->env();
  bool already_cancelled =
      !cancellation_manager->RegisterCallbackWithErrorLogging(
          token,
          [device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); },
          absl::StrCat("TPUCancellation on device ", device_ordinal));
  return std::pair<CancellationToken, bool>(token, already_cancelled);
}

void UnregisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    se::Stream* stream, int device_ordinal, CancellationToken token,
    std::shared_ptr<HostTransferManager> host_transfer_manager) {
  // If execution reaches this point, the host callback enqueued below will get
  // called regardless of stream status. Call inc_num_deferred_ops_function here
  // and dec_num_deferred_ops_function in the host callback.
  ctx->inc_num_deferred_ops_function()();
  auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();

  // Try to avoid running callbacks on the compute stream, because this reduces
  // the frequency of back-to-back programs (which are most efficient because
  // they don't require host synchronization). Instead, borrow a substream and
  // have the substream wait on the compute stream.
  se::Stream* deregister_stream = stream->GetOrCreateSubStream();
  deregister_stream->ThenWaitFor(stream);
  deregister_stream->ThenDoHostCallback([=]() {
    // Ensure the host_transfer_manager is copied into the callback scope.
    (void)host_transfer_manager;

    // We must deregister the callback in the success case, to avoid closing all
    // devices. In the failure case we must NOT call DeregisterCallback as that
    // waits for all previous cancellation callbacks to complete and any call
    // to XlaDevice::Sync() will cause deadlock. Consider:
    //   1) CancellationManager::StartCancel() is in progress (state is
    //      cancelling_).
    //   2) The call below to DeregisterCallback will block until state is
    //   cancelled_ (all callbacks are completed).
    //   3) A different cancellation callback has called XlaDevice::Sync(),
    //   which will block until (2) is done.
    //   4) StartCancel() in (1) cannot complete until (3) is done.
    //
    // Instead, call TryDeregisterCallback. The functional difference is that
    // TryDeregisterCallback will not block if cancellation is in progress, and
    // so makes no guarantees as to the state of any callbacks.
    // This is not a problem, as our cancellation handler does not rely on
    // any external state.
    VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
            << device_ordinal;
    cancellation_manager->TryDeregisterCallback(token);
    VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
            << device_ordinal;

    // ExecutorState is held alive until at least this point to ensure
    // cancellation_manager is valid. After all outstanding
    // dec_num_deferred_ops_function are called, ExecutorState::Finish will be
    // allowed to proceed.
    dec_num_deferred_ops_function();
  });
  stream->ReturnSubStream(deregister_stream);
}

}  // namespace

xla::StatusOr<xla::ExecutionOutput> TPUExecute(
    const TPUExecutableInfoProto& executable,
    const TPUHostTransferInfoProto& host_transfers,
    const xla::HloProto& hlo_metadata,
    std::vector<xla::ExecutionInput> arguments,
    const string& rendezvous_key_base, uint32 rng_seed,
    TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
    CancellationManager* cancellation_manager, OpKernelContext* ctx,
    stream_executor::Stream* stream,
    stream_executor::Stream* host_to_device_stream,
    const XLA_TpuProgram* tpu_program) {
  profiler::TraceMe traceme("TPUExecute", 2);
  TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
  TF_RET_CHECK(tpu_program != nullptr);
  VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();

  xla::Backend* backend = node_context->backend();

  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
  std::shared_ptr<HostTransferManager> host_transfer_manager =
      std::make_shared<HostTransferManager>(node_context, backend);
  TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommandHandler handler,
                      host_transfer_manager->Initialize(
                          host_transfers, rendezvous_key_base, ctx));

  VLOG(2) << "Cloud TPU: Executing computation on device "
          << node_context->device_ordinal();

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_device_assignment(device_assignment);
  run_options.set_rng_seed(rng_seed);
  run_options.set_allocator(backend->memory_allocator());
  run_options.set_host_to_device_stream(host_to_device_stream);

  const xla::ServiceExecutableRunOptions service_run_options(run_options);

  std::unique_ptr<xla::HloModule> module;
  std::vector<xla::Shape> input_shapes;
  {
    xla::ComputationLayout computation_layout(
        xla::ShapeLayout(xla::Shape(executable.output_shape())));
    for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
      xla::Shape shape(shape_proto);
      computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
      input_shapes.push_back(std::move(shape));
    }
    module = std::make_unique<xla::HloModule>(
        "TpuExecutableModule",
        xla::HloModuleConfig(std::move(computation_layout)));
  }

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config(),
      xla::HloInputOutputAliasConfig::CreateFromProto(
          backend->transfer_manager()->HostShapeToDeviceShape(
              module->config().entry_computation_layout().result_shape()),
          hlo_metadata.hlo_module().input_output_alias()));
  TF_RET_CHECK(executable.input_shapes().size() == arguments.size());

  for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
                                         &arguments, input_shapes));

  auto tpu_executable = std::make_unique<TpuOpExecutable>(
      tpu_program, std::move(module), /*host_command_handler=*/handler);

  const int32_t device_ordinal = node_context->device_ordinal();
  CancellationToken token;
  bool already_cancelled;
  std::tie(token, already_cancelled) =
      RegisterCancellation(ctx, cancellation_manager, device_ordinal);

  // If the RPC was already cancelled before we managed to register the
  // cancellation callback, we shouldn't attempt to run the TPU program, since
  // it might block forever.
  if (already_cancelled) {
    return errors::Cancelled(
        "RPC cancelled, not running TPU program on device ", device_ordinal);
  }

  xla::StatusOr<xla::ExecutionOutput> output =
      tpu_executable->ExecuteAsyncOnStream(&service_run_options,
                                           std::move(arguments),
                                           /*hlo_execution_profile=*/nullptr);

  // If !output.ok(), it means we failed to enqueue the program to the TPU.
  // This is possibly caused by a failed cancellation callback closing the
  // chips.
  if (!output.ok()) {
    // If the cancellation manager is already cancelled or cancelling, it means
    // another failure has occurred earlier and this TpuExecuteOp is cancelled
    // regardless of whether it itself encountered an error.
    already_cancelled = cancellation_manager->IsCancelling() ||
                        cancellation_manager->IsCancelled();
    if (already_cancelled) {
      return errors::Cancelled(
          "RPC cancelled, not running TPU program on device ", device_ordinal);
    }
  }
  UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
                         token, host_transfer_manager);
  VLOG(1) << "Cloud TPU: TPUExecute done";
  return output;
}

}  // namespace tensorflow