/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_execute.h"

#include <cstdlib>
#include <memory>
#include <string>
#include <tuple>
#include <utility>

#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executable.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"

namespace tensorflow {

namespace {

using ::tensorflow::tpu::TpuNodeContext;

static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;
// Host-side runtime for transfers between TPU and host.
// TODO(b/161940519): Implement this class.
class HostTransferManager {
 public:
  explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}

  using HostCommandHandler = xla::TpuExecutable::HostCommandHandler;

  // Returns a function to be called when the TPU triggers a host command
  // interrupt while executing the current program.
  xla::StatusOr<HostCommandHandler> Initialize(
      const TPUHostTransferInfoProto& program,
      const std::string& rendezvous_key_base, OpKernelContext* ctx);

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};

xla::StatusOr<HostTransferManager::HostCommandHandler>
HostTransferManager::Initialize(const TPUHostTransferInfoProto& program,
                                const string& rendezvous_key_base,
                                OpKernelContext* ctx) {
  return HostCommandHandler([](uint32, int64) {
    LOG(WARNING) << "HostTransferManager is unimplemented.";
  });
}

// Sleep for 5 seconds, then call std::quick_exit(42) to quickly restart.
void ExitCountdown(Env* env) {
  const int kSleepSeconds = 5;
  LOG(INFO) << "TpuExecute was cancelled. Sleeping for " << kSleepSeconds
            << " seconds before terminating the process to give time "
               "for other errors to propagate";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  LOG(ERROR) << "Aborting process due to cancelled TPUExecute. Consult "
                "the anomalies reported above (if any), run state of job "
                "(including failed RPCs) and worker logs. This "
                "termination is to ensure a consistent state, if your job "
                "does not restart, modify the retries allowed. See "
                "b/62262381 and b/65223927.";
  std::quick_exit(42);
}

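// Converts a host-side shape to the corresponding on-device shape via the
// HardwareLayout C API.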
xla::Shape HostShapeToDeviceShape(const xla::Shape& host_shape) {
  XLA_Shape c_host_shape;
  XLA_Shape c_device_shape;
  ApiConverter::ToC(host_shape, &c_host_shape);
  tensorflow::tpu::OpsApiFn()->HardwareLayout_HostShapeToDeviceShapeFn(
      &c_host_shape, &c_device_shape);
  xla::Shape device_shape = ApiConverter::FromC(&c_device_shape);
  ApiConverter::Free(&c_host_shape);
  ApiConverter::Free(&c_device_shape);
  return device_shape;
}

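// Returns the size in bytes of `shape` in its compact on-device layout, as
// reported by the HardwareLayout C API.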
int64 ShapeSizeCompact(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactFn(&c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}

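// Same as ShapeSizeCompact, but queries the "raw" compact size through
// HardwareLayout_ShapeSizeCompactRawFn.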
int64 ShapeSizeCompactRaw(const xla::Shape& shape) {
  XLA_Shape c_shape;
  ApiConverter::ToC(shape, &c_shape);
  int64 size =
      tensorflow::tpu::OpsApiFn()->HardwareLayout_ShapeSizeCompactRawFn(
          &c_shape);
  ApiConverter::Free(&c_shape);
  return size;
}

// Given a tuple, fix all non-leaf nodes (tuples) such that the tuple tables
// point to the correct leaf nodes.
xla::Status FixTupleTableAsync(se::Stream* stream,
                               const xla::Shape& tuple_shape,
                               xla::ExecutionInput* mem,
                               xla::TransferManager* transfer_manager) {
  return xla::ShapeUtil::ForEachSubshapeWithStatus(
      tuple_shape,
      [&](const xla::Shape& element_shape,
          const xla::ShapeIndex& index) -> Status {
        if (!element_shape.IsTuple()) {
          return Status::OK();
        }
        std::vector<se::DeviceMemoryBase> elements;
        xla::ShapeIndex element_index = index;
        element_index.push_back(0);
        for (int64 i = 0; i < element_shape.tuple_shapes_size(); ++i) {
          // Gather all children of the tuple element.
          element_index.back() = i;
          elements.push_back(mem->Buffer(element_index).AsDeviceMemoryBase());
        }
        se::DeviceMemoryBase tuple_table_addr =
            mem->Buffer(index).AsDeviceMemoryBase();
        return transfer_manager->WriteSingleTupleIndexTable(
            stream, elements, element_shape, &tuple_table_addr);
      });
}

// Returns true if `dynamic_shape` is compatible with `bounded_shape`, i.e. it
// has the same rank and each of its dimensions is less than or equal to the
// corresponding dimension of `bounded_shape`.
bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
                              const xla::Shape& bounded_shape) {
  if (dynamic_shape.rank() != bounded_shape.rank()) {
    return false;
  }
  for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
    if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
      return false;
    }
  }
  return true;
}


// For dynamic inputs, copy them and attach metadata of shape sizes to the
// end of the tensor.
//
// The buffer for dynamic shapes contains three parts:
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// The metadata holds the sizes of the shape without padding, i.e. it describes
// the extent of the valid data.
xla::Status UpdateDynamicInputs(
    se::Stream* stream, se::DeviceMemoryAllocator* allocator,
    std::vector<xla::ExecutionInput>* runtime_inputs,
    const std::vector<xla::Shape>& compile_time_shapes) {
  TF_RET_CHECK(runtime_inputs->size() == compile_time_shapes.size());
  for (int64 i = 0; i < compile_time_shapes.size(); i++) {
    // TODO(yunxing): Iterating over thousands of elements can be slow. One way
    // to optimize the fast path when there are no dynamic shapes is to add a
    // field to the compilation result indicating whether a dynamic input is
    // present.
    if (compile_time_shapes[i].is_static()) {
      continue;
    }
    auto& runtime_input = (*runtime_inputs)[i];
    xla::Shape compile_time_shapes_on_device =
        HostShapeToDeviceShape(compile_time_shapes[i]);
    bool element_modified = false;
    TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
        compile_time_shapes_on_device,
        [&](const xla::Shape& compile_time_shape,
            const xla::ShapeIndex& index) -> Status {
          if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
            return Status::OK();
          }

          const xla::Shape& runtime_shape =
              xla::ShapeUtil::GetSubshape(runtime_input.shape(), index);

          TF_RET_CHECK(!runtime_shape.IsTuple());
          TF_RET_CHECK(
              DynamicShapeIsCompatible(runtime_shape, compile_time_shape));

          xla::MaybeOwningDeviceMemory* mutable_input_mem =
              runtime_input.MutableBuffer(index);
          auto padded_data = std::make_shared<std::vector<int8>>(
              ShapeSizeCompact(compile_time_shape), -1);
          auto raw_input_runtime = std::make_shared<std::vector<uint32>>(
              ShapeSizeCompact(runtime_shape) / sizeof(uint32));
          stream->ThenMemcpyD2H(
              se::DeviceMemory<int8>(mutable_input_mem->AsDeviceMemoryBase()),
              absl::MakeSpan(absl::bit_cast<int8*>(raw_input_runtime->data()),
                             ShapeSizeCompactRaw(runtime_shape)));
          stream->ThenDoHostCallback([raw_input_runtime, padded_data,
                                      runtime_shape, compile_time_shape]() {
            // After getting the data onto the host, transpose the data to
            // the correct layout by delinearizing it and linearizing it again.
            XLA_Shape c_runtime_shape, c_compile_time_shape;
            ApiConverter::ToC(runtime_shape, &c_runtime_shape);
            ApiConverter::ToC(compile_time_shape, &c_compile_time_shape);
            StatusHelper status;

            TpuExecute_RuntimeInputToPaddedData_Params params;
            params.struct_size =
                TpuExecute_RuntimeInputToPaddedData_Params_SIZE;
            params.priv = nullptr;
            params.runtime_input_ptr = raw_input_runtime->data();
            params.runtime_input_size = raw_input_runtime->size();
            params.padded_data_ptr = padded_data->data();
            params.padded_data_size = padded_data->size();
            params.runtime_shape = &c_runtime_shape;
            params.compile_time_shape = &c_compile_time_shape;
            params.status = status.c_status;

            tensorflow::tpu::OpsApiFn()->TpuExecute_RuntimeInputToPaddedDataFn(
                &params);
            ApiConverter::Free(&c_runtime_shape);
            ApiConverter::Free(&c_compile_time_shape);
            return status.status();
          });
          // Allocate new input and transfer the padded and transposed data to
          // the new input location.
          TF_ASSIGN_OR_RETURN(
              auto new_input,
              allocator->Allocate(stream->parent()->device_ordinal(),
                                  ShapeSizeCompact(compile_time_shape)));
          auto typed_new_input_memory =
              se::DeviceMemory<int8>(new_input.cref());
          stream->ThenMemcpyH2D<int8>(*padded_data, &typed_new_input_memory);

          // Retain the memory until the end of the transfer.
          stream->ThenDoHostCallback([padded_data]() { return Status::OK(); });

          // Modify the memory location in the input shape tree to point to the
          // new input.
          *mutable_input_mem =
              xla::MaybeOwningDeviceMemory(std::move(new_input));
          element_modified = true;
          return Status::OK();
        }));
    if (element_modified) {
      // The input location has been modified, need to fix tuple table to
      // point to the correct address.
      TF_ASSIGN_OR_RETURN(
          auto transfer_manager,
          xla::TransferManager::GetForPlatform(stream->parent()->platform()));
      TF_RETURN_IF_ERROR(FixTupleTableAsync(stream,
                                            compile_time_shapes_on_device,
                                            &runtime_input, transfer_manager));
    }
  }
  return Status::OK();
}

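// Cancellation callback for an in-flight TPUExecute. Depending on the
// module-level flags above, this either stops chip heartbeats and terminates
// the process, closes all TPU chips attached to this host, or does nothing.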
void TPUCancelExecution(Env* env, int device_ordinal) {
  if (tpu_cancellation_terminates_process) {
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats on device "
              << device_ordinal;
    Status status = TpuNodeContext::StopChipHeartbeats();
    LOG(INFO) << "TPUCancelExecution StopChipHeartbeats done: " << status
              << " on device " << device_ordinal;
    // Sleep and exit in another thread so the cancellation manager can
    // continue running callbacks. The new thread will call quick_exit,
    // so we discard the returned Thread pointer because we won't have
    // an opportunity to delete it.
    (void)env->StartThread(ThreadOptions(), "tpu_execute_exit_countdown",
                           [env]() { ExitCountdown(env); });
  } else if (tpu_cancellation_closes_chips) {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal;
    Status status = TpuNodeContext::CloseTpuHost();
    LOG(INFO) << "TPUCancelExecution CloseTPUHost done: " << status
              << " on device " << device_ordinal;
  } else {
    LOG(INFO) << "TPUCancelExecution CloseTPUHost on device " << device_ordinal
              << " is suppressed";
  }
}

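// Registers a callback with `cancellation_manager` that invokes
// TPUCancelExecution for `device_ordinal` when the step is cancelled. Returns
// the cancellation token and whether cancellation had already been requested
// before the callback could be registered.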
std::pair<CancellationToken, bool> RegisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    int device_ordinal) {
  // Set up a cancellation callback, to ensure the TPU program we run will
  // halt if the RPC is cancelled. Without this the TPU program might block
  // forever. The mechanism itself is a big hammer; we close all devices
  // attached to this host on each cancellation callback. This is necessary to
  // ensure the system will eventually halt, since the TensorNodes on each
  // chip may be stuck waiting for mutual communication.
  //
  // By closing all devices, we ensure all subsequent attempts to use the
  // device will fail, until the devices are re-initialized via a new call to
  // tpu.initialize_system.
  //
  // In a multi-TensorNode setup, CloseTPUHost may be called once for each
  // TensorNode, and each call will close all TensorNodes. This quadratic
  // behavior ensures the mechanism is robust to various orderings
  // (i.e. races) between the TPU programs, which are run on separate threads.
  // In practice the quadratic behavior isn't that bad; the first call will
  // actually halt any running TPU programs (which may be expensive), while
  // subsequent calls will attempt to close an already-closed device (which is
  // cheap).
  //
  // TODO(b/62262381): The cancellation manager is shared between multiple TPU
  // execute ops, so the callback is not invoked only when the RPC is cancelled
  // (it may also be triggered by OOM errors from a different TPU execute).
  // This results in a pretty coarse cancellation domain; the callback should
  // only execute in a narrower scope so it is not triggered in such cases.
  CancellationToken token = cancellation_manager->get_cancellation_token();
  // Don't rely on OpKernelContext being available when the callback runs.
  Env* env = ctx->env();
  bool already_cancelled = !cancellation_manager->RegisterCallback(
      token,
      [device_ordinal, env]() { TPUCancelExecution(env, device_ordinal); });
  return std::pair<CancellationToken, bool>(token, already_cancelled);
}

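// Schedules deregistration of the cancellation callback on a borrowed
// substream that waits for `stream`, so the compute stream itself is not
// blocked; the deferred-op count on `ctx` is held until the host callback
// runs.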
void UnregisterCancellation(
    OpKernelContext* ctx, CancellationManager* cancellation_manager,
    se::Stream* stream, int device_ordinal, CancellationToken token,
    std::shared_ptr<HostTransferManager> host_transfer_manager) {
  // If execution reaches this point, the host callback enqueued below will get
  // called regardless of stream status. Call inc_num_deferred_ops_function
  // here and dec_num_deferred_ops_function in the host callback.
  ctx->inc_num_deferred_ops_function()();
  auto dec_num_deferred_ops_function = ctx->dec_num_deferred_ops_function();

  // Try to avoid running callbacks on the compute stream, because this reduces
  // the frequency of back-to-back programs (which are most efficient because
  // they don't require host synchronization). Instead, borrow a substream and
  // have the substream wait on the compute stream.
  se::Stream* deregister_stream = stream->GetOrCreateSubStream();
  deregister_stream->ThenWaitFor(stream);
  deregister_stream->ThenDoHostCallback([=]() {
    // Ensure the host_transfer_manager is copied into the callback scope.
    (void)host_transfer_manager;

    // We must deregister the callback in the success case, to avoid closing
    // all devices. In the failure case we must NOT call DeregisterCallback as
    // that waits for all previous cancellation callbacks to complete and any
    // call to XlaDevice::Sync() will cause deadlock. Consider:
    //   1) CancellationManager::StartCancel() is in progress (state is
    //      cancelling_).
    //   2) The call below to DeregisterCallback will block until state is
    //      cancelled_ (all callbacks are completed).
    //   3) A different cancellation callback has called XlaDevice::Sync(),
    //      which will block until (2) is done.
    //   4) StartCancel() in (1) cannot complete until (3) is done.
    //
    // Instead, call TryDeregisterCallback. The functional difference is that
    // TryDeregisterCallback will not block if cancellation is in progress, and
    // so it makes no guarantees as to the state of any callbacks.
    // This is not a problem, as our cancellation handler does not rely on
    // any external state.
    VLOG(1) << "cancellation_manager->TryDeregisterCallback on device "
            << device_ordinal;
    cancellation_manager->TryDeregisterCallback(token);
    VLOG(1) << "cancellation_manager->TryDeregisterCallback done on device "
            << device_ordinal;

    // ExecutorState is held alive until at least this point to ensure
    // cancellation_manager is valid. After all outstanding
    // dec_num_deferred_ops_function are called, ExecutorState::Finish will be
    // allowed to proceed.
    dec_num_deferred_ops_function();
  });
  stream->ReturnSubStream(deregister_stream);
}

}  // namespace

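// Runs a single compiled TPU program: sets up the host transfer handler,
// reconstructs HloModule metadata from the executable protos, adjusts dynamic
// inputs, registers a cancellation callback, and executes the program
// asynchronously on `stream`.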
xla::StatusOr<xla::ExecutionOutput> TPUExecute(
    const TPUExecutableInfoProto& executable,
    const TPUHostTransferInfoProto& host_transfers,
    const xla::HloProto& hlo_metadata,
    std::vector<xla::ExecutionInput> arguments,
    const string& rendezvous_key_base, uint32 rng_seed,
    TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment,
    CancellationManager* cancellation_manager, OpKernelContext* ctx,
    stream_executor::Stream* stream,
    stream_executor::Stream* host_to_device_stream,
    const XLA_TpuProgram* tpu_program) {
  profiler::TraceMe traceme("TPUExecute", 2);
  TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
  TF_RET_CHECK(tpu_program != nullptr);
  VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();

  xla::Backend* backend = node_context->backend();

  // Create a HostTransferManager to handle Send/Recv operations from the TPU.
  std::shared_ptr<HostTransferManager> host_transfer_manager =
      std::make_shared<HostTransferManager>(node_context, backend);
  TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommandHandler handler,
                      host_transfer_manager->Initialize(
                          host_transfers, rendezvous_key_base, ctx));

  VLOG(2) << "Cloud TPU: Executing computation on device "
          << node_context->device_ordinal();

  xla::ExecutableRunOptions run_options;
  run_options.set_stream(stream);
  run_options.set_device_assignment(device_assignment);
  run_options.set_rng_seed(rng_seed);
  run_options.set_allocator(backend->memory_allocator());
  run_options.set_host_to_device_stream(host_to_device_stream);

  const xla::ServiceExecutableRunOptions service_run_options(run_options);

  std::unique_ptr<xla::HloModule> module;
  std::vector<xla::Shape> input_shapes;
  {
    xla::ComputationLayout computation_layout(
        xla::ShapeLayout(xla::Shape(executable.output_shape())));
    for (const xla::ShapeProto& shape_proto : executable.input_shapes()) {
      xla::Shape shape(shape_proto);
      computation_layout.add_parameter_layout(xla::ShapeLayout(shape));
      input_shapes.push_back(std::move(shape));
    }
    module = absl::make_unique<xla::HloModule>(
        "TpuExecutableModule",
        xla::HloModuleConfig(std::move(computation_layout)));
  }

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config(),
      xla::HloInputOutputAliasConfig::CreateFromProto(
          backend->transfer_manager()->HostShapeToDeviceShape(
              module->config().entry_computation_layout().result_shape()),
          hlo_metadata.hlo_module().input_output_alias()));
  TF_RET_CHECK(executable.input_shapes().size() == arguments.size());

  for (auto& prefetch : hlo_metadata.hlo_module().cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
                                         &arguments, input_shapes));

  auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
      tpu_program, std::move(module), /*host_command_handler=*/handler);

  const int32 device_ordinal = node_context->device_ordinal();
  CancellationToken token;
  bool already_cancelled;
  std::tie(token, already_cancelled) =
      RegisterCancellation(ctx, cancellation_manager, device_ordinal);

  // If the RPC was already cancelled before we managed to register the
  // cancellation callback, we shouldn't attempt to run the TPU program, since
  // it might block forever.
  if (already_cancelled) {
    return errors::Cancelled(
        "RPC cancelled, not running TPU program on device ", device_ordinal);
  }

  xla::StatusOr<xla::ExecutionOutput> output =
      tpu_executable->ExecuteAsyncOnStream(&service_run_options,
                                           std::move(arguments),
                                           /*hlo_execution_profile=*/nullptr);

  // If !output.ok(), it means we failed to enqueue the program on the TPU.
  // This is possibly caused by a failed cancellation callback closing the
  // chips.
  if (!output.ok()) {
    // If the cancellation manager is already cancelled or cancelling, another
    // failure has occurred earlier and this TpuExecuteOp is cancelled
    // regardless of whether it failed on its own.
    already_cancelled = cancellation_manager->IsCancelling() ||
                        cancellation_manager->IsCancelled();
    if (already_cancelled) {
      return errors::Cancelled(
          "RPC cancelled, not running TPU program on device ", device_ordinal);
    }
  }
  UnregisterCancellation(ctx, cancellation_manager, stream, device_ordinal,
                         token, host_transfer_manager);
  VLOG(1) << "Cloud TPU: TPUExecute done";
  return output;
}

}  // namespace tensorflow