# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for `ClusterCoordinator` and related cluster-worker library code.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import re
import threading
import time
import weakref
from six.moves import queue

from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.distribute.coordinator import values as values_lib
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# Maximum time for a failed worker to come back is 1 hour.
_WORKER_MAXIMUM_RECOVERY_SEC = 3600

# Maximum size for queued closures, "infinite" if set to 0.
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for a coordinator with constrained memory resources (only recommended for
# advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024

# RPC error message from PS.
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"

# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"


RemoteValueStatus = values_lib.RemoteValueStatus
RemoteValue = values_lib.RemoteValue
RemoteValueImpl = values_lib.RemoteValueImpl
PerWorkerValues = values_lib.PerWorkerValues


class InputError(Exception):

  def __init__(self, original_exception):
    self.original_exception = original_exception
    message = ("Input has an error, the original exception is %r, "
               "error message is %s." %
               (original_exception, str(original_exception)))
    super().__init__(message)
    self.with_traceback(original_exception.__traceback__)


def _maybe_rebuild_remote_values(worker, structure):
  """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
  errors_in_structure = []

  def _get_error(val):
    if isinstance(val, RemoteValue):
      if val._status is RemoteValueStatus.ABORTED:  # pylint: disable=protected-access
        try:
          # This attempts to rebuild the resource on the worker, which may fail
          # if the worker or PS is unavailable. If it fails, the original
          # RemoteValue that requests a result will receive an `InputError`,
          # which gets handled by `wait_on_failure` in `process_closure`.
          val._rebuild_on(worker)  # pylint: disable=protected-access
        except Exception as e:  # pylint: disable=broad-except
          val._set_error(e)  # pylint: disable=protected-access

      error = val._get_error()  # pylint: disable=protected-access
      if error:
        errors_in_structure.append(error)

  nest.map_structure(_get_error, structure)
  if errors_in_structure:
    return errors_in_structure[0]
  else:
    return None


def _maybe_get_remote_value(val):
  """Gets the value of `val` if it is a `RemoteValue`."""
  if isinstance(val, RemoteValue):
    error = val._get_error()  # pylint: disable=protected-access
    if error:
      raise AssertionError(
          "RemoteValue doesn't have a value because it has errors.")
    else:
      return val._get_values()  # pylint: disable=protected-access
  else:
    return val


def _maybe_as_type_spec(val):
  if isinstance(val, (RemoteValue, PerWorkerValues)):
    if val._type_spec is None:  # pylint: disable=protected-access
      raise ValueError("Output of a scheduled function that is not "
                       "tf.function cannot be the input of another function.")
    return val._type_spec  # pylint: disable=protected-access
  else:
    return val


def _select_worker_slice(worker_id, structured):
  """Selects the worker slice of each of the items in `structured`."""

  def _get(x):
    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access

  return nest.map_structure(_get, structured)


def _disallow_remote_value_as_input(structured):
  """Raises if any element of `structured` is a RemoteValue."""

  def _raise_if_remote_value(x):
    if isinstance(x, RemoteValue):
      raise ValueError(
          "`tf.distribute.experimental.coordinator.RemoteValue` used "
          "as an input to scheduled function is not yet "
          "supported.")

  nest.map_structure(_raise_if_remote_value, structured)
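

# Illustrative sketch (not part of the library): `_select_worker_slice` maps a
# `PerWorkerValues` argument to the component that belongs to one worker and
# leaves every other argument untouched. Assuming a hypothetical two-worker
# setup:
#
#   per_worker_iters = PerWorkerValues((iter_on_worker_0, iter_on_worker_1))
#   _select_worker_slice(1, (per_worker_iters, 42))
#   # -> (iter_on_worker_1, 42)
#
# This is how a scheduled function receives the per-worker component (e.g. the
# iterator) that corresponds to the worker it happens to run on.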


class Closure(object):
  """Holds a function to be scheduled and its arguments."""

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: no need to handle function registration failure since this kind
      # of failure will not raise exceptions as designed in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.

      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function

    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we get to retrieve the output type spec
      # via the structured_outputs.
      self._output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e. what is passed in is a regular python function), we
      # have no such information.
      self._output_type_spec = None
      self._function = function

    self._output_remote_value_ref = None

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      ret = RemoteValueImpl(None, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      raise ValueError(
          "The output of the Closure cannot be built more than once.")

  def maybe_call_with_output_remote_value(self, method):
    if self._output_remote_value_ref is None:
      return None
    output_remote_value = self._output_remote_value_ref()
    if output_remote_value is not None:
      return method(output_remote_value)
    return None

  def mark_cancelled(self):
    e = errors.CancelledError(
        None, None, "The corresponding function is "
        "cancelled. Please reschedule the function.")
    self.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Args:
      worker: a `Worker` object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      raise e

    with ops.device(worker.device_name):
      with context.executor_scope(worker.executor):
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function(
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.maybe_call_with_output_remote_value(
        lambda r: r._set_values(output_values))  # pylint: disable=protected-access
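

# Illustrative sketch (not part of the library) of how the coordinator drives a
# `Closure` internally, assuming `fn` is a `tf.function` and `worker` is a
# `Worker` instance defined further below:
#
#   closure = Closure(fn, cancellation_mgr, args=(per_worker_args,))
#   remote_value = closure.build_output_remote_value()
#   closure.execute_on(worker)    # fills `remote_value` on success
#   # ...or, if the queue is being drained because of an earlier error:
#   closure.mark_cancelled()      # sets a CancelledError on `remote_value`
#
# `Cluster.schedule` and `Worker._process_closure` below follow this pattern,
# with `_CoordinatedClosureQueue` in between.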


class ResourceClosure(Closure):

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      # We need to remember the Closure object in the `RemoteValue` here.
      ret = RemoteValueImpl(self, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      return self._output_remote_value_ref()


class _CoordinatedClosureQueue(object):
  """Manages a queue of closures, the inflight count and errors from execution.

  This class is thread-safe.
  """

  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of inflight
    # closures that are "in generation". Once an error occurs, the error
    # generation is incremented and all subsequent arriving closures (from
    # inflight) are considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or
    # inflight) have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with an "infinite" queue size, there is still a
    # "practical" size limit for the queue depending on host memory capacity,
    # and thus the queue will eventually become full with a lot of enqueued
    # closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating that there are no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Used to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
          "In a `ClusterCoordinator`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following is a lock to make sure that while `wait` is called and
    # before it returns, no `put` can be executed, because `wait` won't know
    # what to do with newly put closures. This lock adds a cutoff for `wait`
    # so that closures put into the queue while waiting will not be waited on
    # by this `wait`.
    #
    # We cannot reuse `self._queue_lock` since when `wait` waits for a
    # condition, `self._queue_lock` will be released.
    #
    # We don't use a reader/writer lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()

  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notifyAll()
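
  # Illustrative sketch (not part of the library) of the protocol between the
  # coordinator thread and the worker threads, assuming `q` is an instance of
  # this class and `closure` is a `Closure`:
  #
  #   q.put(closure)       # coordinator thread; blocks if the queue is full
  #   c = q.get()          # worker thread; increments the inflight count
  #   ...                  # worker executes `c`
  #   q.mark_finished()    # on success; or q.put_back(c) / q.mark_failed(e)
  #   q.wait()             # coordinator thread; returns once the queue is
  #                        # drained and nothing is inflight, or raises the
  #                        # first recorded error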

  def _cancel_all_closures(self):
    """Clears the queue and sets remaining closures cancelled error.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these
    # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At
    # the same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, cancels the closures in the queue, raises the error,
    and clears it.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  # pylint: disable=raising-bad-type
      finally:
        self._error = None

  def put(self, closure):
    """Puts a closure into the queue for later execution.

    If `mark_failed` was called before `put`, the error from the first
    invocation of `mark_failed` will be raised.

    Args:
      closure: The `Closure` to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Returns a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

  def mark_finished(self):
    """Lets the queue know that a closure has been successfully executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closure to mark_finished.")
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      if self._queue.empty() and self._inflight_closure_count == 0:
        self._stop_waiting_condition.notifyAll()

  def put_back(self, closure):
    """Puts the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closure to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()

  def wait(self, timeout=None):
    """Waits for all closures to be finished before returning.

    If `mark_failed` was called before or during `wait`, the error from the
    first invocation of `mark_failed` will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

  def mark_failed(self, e):
    """Sets the error and unblocks any `wait` call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failures and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There is no inflight closure to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      self._stop_waiting_condition.notifyAll()

  def done(self):
    """Returns True if the queue is empty and there is no inflight closure.

    If `mark_failed` was called before `done`, the error from the first
    invocation of `mark_failed` will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0


class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._error_from_recovery = None
    self._should_preemption_thread_run = True
    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

  def stop(self):
    """Ensures the worker preemption thread is closed."""
    self._should_preemption_thread_run = False
    with self._cluster_update_lock:
      self._cluster_due_for_update_or_finish.set()
    # TODO(yuefengz): The preemption handler thread shouldn't be terminated
    # asynchronously since it touches the eager context, which is a
    # process-wide singleton. The problem is that OSS unit tests will time
    # out.

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""

    # Only categorize the failure as a worker preemption if the cancellation
    # manager did not attempt to cancel the blocking operations.
    if _is_worker_failure(e) and (
        not self._cluster.closure_queue._cancellation_mgr.is_cancelled):  # pylint: disable=protected-access
      return
    raise e
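
  # Illustrative sketch (not part of the library): worker threads wrap remote
  # execution in the `wait_on_failure` context manager defined below, roughly
  # as `Worker._process_closure` does:
  #
  #   with failure_handler.wait_on_failure(
  #       on_failure_fn=lambda: closure_queue.put_back(closure),
  #       on_transient_failure_fn=lambda: closure_queue.put_back(closure),
  #       on_recovery_fn=worker._set_resources_aborted,
  #       worker_device_name="/job:worker/replica:0/task:0"):
  #     closure.execute_on(worker)
  #
  # Recoverable worker-preemption errors raised inside the block are absorbed
  # here until the worker is back; other errors propagate to the caller.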

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_transient_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption errors and waits until failed workers are back.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_transient_failure_fn: an optional function to run if a transient
        failure happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is
        passing through the failure.

    Yields:
      None.
    """
    assert self._should_preemption_thread_run
    try:
      yield
    except (errors.OpError, InputError) as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
        logging.error(
            "Remote function on worker %s failed with %r:%s\n"
            "It is treated as a transient connectivity failure for now.",
            worker_device_name, e, e)
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # Ignoring derived CancelledErrors to tolerate transient failures in
      # PS-worker communication, which initially surface as an UnavailableError
      # and then lead to sub-function cancellation, subsequently getting
      # reported from worker to chief as CancelledError.
      # We do not mark either worker or PS as failed due to only
      # CancelledError. If there are real (non-transient) failures, they must
      # also be reported as other errors (UnavailableError most likely) in
      # closure executions.
      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
        logging.error(
            "Remote function on worker %s failed with %r:%s\n"
            "This derived error is ignored and not reported to users.",
            worker_device_name, e, e)
        if on_transient_failure_fn:
          on_transient_failure_fn()
        return

      # This reraises the error, if it's not considered recoverable; otherwise,
      # the following failure recovery logic runs. At this time, only worker
      # unavailability is recoverable. PS unavailability as well as other
      # errors in the user function are not recoverable.
      self._validate_preemption_failure(e)

      logging.error("Worker %s failed with %r:%s", worker_device_name, e, e)
      if on_failure_fn:
        on_failure_fn()

      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
        if self._error_from_recovery:
          # TODO(yuefengz): there is only one worker that will get this error.
          # Ideally we should let all workers notified by `_worker_up_cond` get
          # this error.
          try:
            raise self._error_from_recovery
          finally:
            self._error_from_recovery = None
        logging.info("Worker %s has been recovered.", worker_device_name)

      if on_recovery_fn:
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            on_transient_failure_fn=on_transient_failure_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()

  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for a signal of worker preemption and, upon worker
    preemption, it waits until all workers are back and updates the cluster
    about the restarted workers.
    """
    assert self._should_preemption_thread_run
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        logging.info("Stopping the failure handling thread.")
        break

      with self._cluster_update_lock:
        try:
          # TODO(haoyuzhang): support partial cluster recovery
          logging.info("Cluster now being recovered.")
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          logging.info("Cluster successfully recovered.")
          self._worker_up_cond.notify_all()
          # The check for _should_preemption_thread_run is necessary since the
          # `stop` may have already set _cluster_due_for_update_or_finish.
          if self._should_preemption_thread_run:
            self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  # pylint: disable=broad-except
          try:
            self._validate_preemption_failure(e)
          except Exception as ps_e:  # pylint: disable=broad-except
            # In this case, a parameter server fails. So we raise this error to
            # the caller of `wait_on_failure`.
            self._error_from_recovery = ps_e
            self._worker_up_cond.notify_all()
            if self._should_preemption_thread_run:
              self._cluster_due_for_update_or_finish.clear()
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, an error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) Worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.
          logging.error("Cluster update failed with error: %s. Retrying...", e)


class Worker(object):
  """A worker in a cluster.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handle worker preemption
      failures.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True

    # Worker threads need to start after `Worker`'s initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

  def stop(self):
    """Ensures the worker thread is closed."""
    self._should_worker_thread_run = False

  def _set_resources_aborted(self):
    # TODO(yuefengz): maybe we can query whether a tensor is valid or not
    # instead of marking a tensor aborted?
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster.closure_queue.put_back(closure),
          on_transient_failure_fn=lambda: self._cluster.closure_queue.put_back(
              closure),
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        with metric_utils.monitored_timer("remote_value_fetch"):
          # Copy the remote tensor to local (the coordinator) in case the
          # worker becomes unavailable at a later time.
          closure.maybe_call_with_output_remote_value(lambda r: r.get())
        self._cluster.closure_queue.mark_finished()
    except Exception as e:  # pylint: disable=broad-except
      # Avoid logging the derived cancellation error.
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access
      self._cluster.closure_queue.mark_failed(e)

  def _maybe_delay(self):
    """Delays if the corresponding env vars are set."""
    # If the following two environment variables are set, scheduling for
    # workers will start in a staggered manner. Worker i will wait for
    # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
    # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)
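
  # Worked example for `_maybe_delay` above (illustrative values): with
  # TF_COORDINATOR_SCHEDULE_START_DELAY=2 and
  # TF_COORDINATOR_SCHEDULE_START_DELAY_MAX=10, worker 0 starts immediately,
  # worker 1 waits min(2 * 1, 10) = 2 seconds, worker 4 waits 8 seconds, and
  # worker 6 waits min(2 * 6, 10) = 10 seconds (capped). If the MAX variable
  # is unset, every worker waits the same TF_COORDINATOR_SCHEDULE_START_DELAY.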

  def _process_queue(self):
    """Function running in a worker thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster.closure_queue.get()
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important
      # that the `ClusterCoordinator` object is not held onto so that its
      # `__del__` can be called. By removing the reference to the `closure`
      # that has already been processed, we ensure that the `closure` object
      # is released while getting the next `closure` at the above
      # `self._cluster.closure_queue.get()` call.
      del closure

  def create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a `RemoteValue`.

    Args:
      function: the resource function to be run remotely. It should be a
        `tf.function`, a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      One or several RemoteValue objects depending on the function return
      values.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler concurrency logic.
    closure = ResourceClosure(
        function,
        self._cluster.closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.build_output_remote_value()
    self._register_resource(resource_remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    # TODO(b/160343165): we should create resources eagerly, i.e. schedule the
    # resource creation function as soon as users call this method.
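    # Marking the resource ABORTED here means it has not actually been created
    # yet; `_maybe_rebuild_remote_values` will run the resource function on
    # this worker the first time a scheduled closure uses the resource (and
    # again after this worker restarts, via `_set_resources_aborted`).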
    resource_remote_value._set_aborted()  # pylint: disable=protected-access
    return resource_remote_value

  def _register_resource(self, resource_remote_value):
    if not isinstance(resource_remote_value, RemoteValue):
      raise ValueError("Resource being registered is not of type "
                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))


class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both `schedule` and `join` can raise a non-retryable error, which is the
     first error seen by the coordinator from any previously scheduled
     functions.
  2) When an error is raised, there is no guarantee on how many previously
     scheduled functions have been executed; functions that have not been
     executed will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
     I.e. functions can continue to be scheduled and subsequent calls of
     `schedule` or `join` will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handle worker preemption
      failures.
    workers: a list of `Worker` objects in the cluster.
    closure_queue: the global Closure queue.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # Ignore PS failures reported by workers due to transient connection
    # errors. Transient connectivity issues between workers and PS are relayed
    # by the workers to the coordinator, leading the coordinator to believe
    # that there are PS failures. The difference between a transient and a
    # permanent PS failure is the number of reports from the workers. When
    # this env var is set to a positive integer K, the coordinator ignores up
    # to K - 1 reports of a failed PS task; i.e., only when closure executions
    # fail due to errors from the same PS instance at least K times do we
    # consider the PS instance to have failed.
    # TODO(b/164279603): Remove this workaround when the underlying
    # connectivity issue in gRPC server is resolved.
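    # Worked example (illustrative): with the default threshold of 3 below,
    # the first two closure failures attributed to the same PS task (say
    # /job:ps/replica:0/task:0) are ignored and the closures are retried; the
    # third such failure is no longer treated as transient and surfaces as a
    # PS failure through `schedule`/`join`/`done`.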
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

    self.closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

  def stop(self):
    """Stops the workers, the worker preemption thread, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self.closure_queue.stop()

  def _record_and_ignore_transient_ps_failure(self, e):
    """Records potential PS failures and returns whether to ignore `e`."""
    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
      return False

    ps_tasks = _extract_failed_ps_instances(str(e))
    with self._potential_ps_failures_lock:
      for t in ps_tasks:
        self._potential_ps_failures_count[t] += 1
        # The number of UnavailableError encountered on this PS task exceeds
        # the maximum number of ignored errors.
        if (self._potential_ps_failures_count[t] >=
            self._transient_ps_failures_threshold):
          return False
    return True

  def schedule(self, function, args, kwargs):
    """Schedules `function` to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for `function`.
      kwargs: Keyword arguments for `function`.

    Returns:
      A `RemoteValue` object.
    """
    closure = Closure(
        function,
        self.closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    ret = closure.build_output_remote_value()
    self.closure_queue.put(closure)
    return ret

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self.closure_queue.wait()

  def done(self):
    """Returns True if all scheduled functions are executed."""
    return self.closure_queue.done()


@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  """An object to schedule and coordinate remote function execution.

  This class is used to create fault-tolerant resources and dispatch functions
  to remote TensorFlow servers.

  Currently, this class is not supported to be used in a standalone manner. It
  should be used in conjunction with a `tf.distribute` strategy that is
  designed to work with it. The `ClusterCoordinator` class currently only
  works with `tf.distribute.experimental.ParameterServerStrategy`.

  __The `schedule`/`join` APIs__

  The most important APIs provided by this class are the `schedule`/`join`
  pair. The `schedule` API is non-blocking in that it queues a `tf.function`
  and returns a `RemoteValue` immediately. The queued functions will be
  dispatched to remote workers in background threads and their `RemoteValue`s
  will be filled asynchronously. Since `schedule` doesn't require worker
  assignment, the `tf.function` passed in can be executed on any available
  worker. If the worker it is executed on becomes unavailable before its
  completion, it will be migrated to another worker. Because of this fact, and
  because function execution is not atomic, a function may be executed more
  than once.

  __Handling Task Failure__

  This class, when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers cannot be
  reached from the coordinator for any reason, training progress continues to
  be made with the remaining workers. Upon recovery of a failed worker, it
  will be added for function execution after datasets created by
  `create_per_worker_dataset` are re-built on it.

  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
  `schedule`, `join` or `done`. In this case, in addition to bringing back the
  failed parameter server, users should restart the coordinator so that it
  reconnects to workers and parameter servers, re-creates the variables, and
  loads checkpoints. If the coordinator fails, after the user brings it back,
  the program will automatically connect to workers and parameter servers, and
  continue the progress from a checkpoint.

  It is thus essential that in the user's program, a checkpoint file is
  periodically saved, and restored at the start of the program. If a
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps
  are needed before training completion.

  See the `tf.distribute.experimental.ParameterServerStrategy` docstring for
  an example usage of this API.

  This is currently under development, and the API as well as implementation
  are subject to changes.
  """

  def __new__(cls, strategy):
    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
    # TODO(rchao): Needs a lock for thread-safety
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a `ClusterCoordinator` instance.

    Args:
      strategy: a supported `tf.distribute.Strategy` object. Currently, only
        `tf.distribute.experimental.ParameterServerStrategy` is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not getattr(self, "_has_initialized", False):
      if not isinstance(strategy,
                        parameter_server_strategy_v2.ParameterServerStrategyV2):
        raise ValueError(
            "Only `tf.distribute.experimental.ParameterServerStrategy` "
            "is supported to work with "
            "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
            "currently.")
      self._strategy = strategy
      self.strategy.extended._used_with_coordinator = True
      self._cluster = Cluster(strategy)
      self._has_initialized = True

  def __del__(self):
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
    return self._strategy
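
  # Illustrative sketch (not part of the library; see the
  # `tf.distribute.experimental.ParameterServerStrategy` docstring for a full
  # example) of the typical `schedule`/`join` flow, assuming `coordinator` is
  # an instance of this class, `worker_fn` is a `tf.function`, and
  # `per_worker_iter` came from `create_per_worker_dataset`:
  #
  #   results = []
  #   for _ in range(num_steps):
  #     results.append(
  #         coordinator.schedule(worker_fn, args=(per_worker_iter,)))
  #   coordinator.join()                      # block until all steps have run
  #   values = [r.fetch() for r in results]   # retrieve the outputs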

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the `fn` which will be
    executed later and returns a
    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
    `fetch` can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait
    for all scheduled functions to finish.

    `schedule` guarantees that `fn` will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since a worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
    guarantees that in those events, the function will eventually be executed
    on any worker that is available.

    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. When this
    happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled, and reschedule the
    corresponding function if needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support for worker assignment for function
    execution, or for worker priorities.

    `args` and `kwargs` are the arguments passed into `fn` when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues`, in which case the
    argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed
    into `fn` as-is. Currently,
    `tf.distribute.experimental.coordinator.RemoteValue` is not supported as
    input `args` or `kwargs`.

    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously. Regular Python functions are not supported
        to be scheduled.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.RemoteValue` object that
      represents the output of the function scheduled.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
          " only accepts a `tf.function` or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    # `schedule` needs to be called within the `strategy.scope()`.
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
      return remote_value

  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled. If some that have been
    cancelled need to be rescheduled, users should call `schedule` with the
    function again.

    When `join` returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown
        or since the beginning of the program.
    """
    self._cluster.join()

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `done` will fail by
    raising any one of those errors.

    When `done` returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown
        or since the beginning of the program.
    """
    return self._cluster.done()

  def create_per_worker_dataset(self, dataset_fn):
    """Creates a dataset on workers by calling `dataset_fn` on worker devices.

    This creates the given dataset generated by `dataset_fn` on workers
    and returns an object that represents the collection of those individual
    datasets. Calling `iter` on such a collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling `next` on a `PerWorkerValues` of iterators is unsupported. The
    iterator is meant to be passed as an argument into
    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The `next` method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the `schedule` method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except that they
    may be shuffled differently if they contain a `dataset.shuffle` operation
    and a random seed is not set. Because of this, we also recommend that the
    datasets be repeated indefinitely and that a finite number of steps be
    scheduled, instead of relying on the `OutOfRangeError` from a dataset.

    Example:

    ```python
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=...)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy=strategy)

    @tf.function
    def worker_fn(iterator):
      return next(iterator)

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
    assert remote_value.fetch() == 3
    ```

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets. `iter` is expected to be called on this object, which returns
      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are on the workers).
    """
    return values_lib.get_per_worker_dataset(dataset_fn, self)

  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
    """Synchronously creates resources on the workers.

    The resources are represented by
    `tf.distribute.experimental.coordinator.RemoteValue`s.

    Args:
      fn: The function to be dispatched to all workers for execution
        asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
      objects.
    """
    results = []
    for w in self._cluster.workers:
      results.append(w.create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
    return PerWorkerValues(tuple(results))

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of
    `RemoteValue`s. If they are not ready, this call blocks until they are.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is a structure of
        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
        called on the individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.

    Returns:
      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
      structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
      return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
      values immediately if they are available, or block the call until they
      are available, and return the fetched
      `tf.distribute.experimental.coordinator.RemoteValue` values with the same
      structure. If `val` is of any other type, return it as-is.
    """

    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    # TODO(yuefengz): we should fetch values in a batch.
    return nest.map_structure(_maybe_fetch, val)


def _extract_failed_ps_instances(err_msg):
  """Returns a set of potentially failing ps instances from an error message."""
  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
  return set(int(t.split(":")[-1]) for t in tasks)


def _is_ps_failure(error):
  """Whether the error is considered a parameter server failure."""

  # For an `InputError`, extract the original error and assess it accordingly.
  if isinstance(error, InputError):
    error = error.original_exception

  return (isinstance(error, errors.UnavailableError) and
          _RPC_ERROR_FROM_PS in str(error))


def _is_worker_failure(error):
  """Whether the error is considered a worker failure."""

  # For an `InputError`, extract the original error and assess it accordingly.
  if isinstance(error, InputError):
    error = error.original_exception

  if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
    return False
  if _RPC_ERROR_FROM_PS in str(error):
    return False

  # TODO(haoyuzhang): Consider using special status code if error from a
  # remote is derived from RPC errors originated from other hosts.
  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
    return True

  # The following error could happen when the remote task fails and restarts
  # in a very short interval during which no RPCs were exchanged to detect the
  # failure. In that case, gRPC allows a channel (which is different from a
  # connection) to be reused for a replaced server listening on the same
  # address.
  if isinstance(error, errors.InvalidArgumentError):
    if ("unknown device" in str(error) or
        "Unable to find the relevant tensor remote_handle" in str(error)):
      # TODO(b/159961667): Fix "Unable to find the relevant tensor
      # remote_handle" part.
      return True

  # TODO(b/162541228): The following 2 types of errors are very rare and only
  # observed in large-scale testing. The types of errors should be reduced.
  # This could happen when the function registration fails. In the observed
  # cases this only happens to the dataset related functions.
  if isinstance(error, errors.NotFoundError):
    if ("is neither a type of a primitive operation nor a name of a function "
        "registered" in str(error)):
      return True

  # NOTE(b/179061495): During worker preemptions, if multiple functions are
  # running concurrently (especially with subfunctions spanning chief/PS),
  # CancelledError can be returned due to chief/PS cancelling outstanding RPCs
  # to the failing workers.
  if isinstance(error, errors.CancelledError):
    return True

  return False
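

# Illustrative example (not part of the library) of the failure-classification
# helpers above: for an `errors.UnavailableError` whose message contains
# "GRPC error information from remote target /job:ps" and mentions
# "/job:ps/replica:0/task:1", `_extract_failed_ps_instances` returns {1},
# `_is_ps_failure` returns True, and `_is_worker_failure` returns False.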