# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
17
18This is currently under development and the API is subject to change.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import contextlib
26import os
27import re
28import threading
29import time
30import weakref
31from six.moves import queue
32
33from tensorflow.python.distribute import parameter_server_strategy_v2
34from tensorflow.python.distribute.coordinator import metric_utils
35from tensorflow.python.distribute.coordinator import values as values_lib
36from tensorflow.python.eager import cancellation
37from tensorflow.python.eager import context
38from tensorflow.python.eager import def_function
39from tensorflow.python.eager import executor
40from tensorflow.python.eager import function as tf_function
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import func_graph
43from tensorflow.python.framework import ops
44from tensorflow.python.platform import tf_logging as logging
45from tensorflow.python.util import nest
46from tensorflow.python.util.tf_export import tf_export
47
# Maximum time for failed worker to come back is 1 hour
_WORKER_MAXIMUM_RECOVERY_SEC = 3600

# Maximum size for queued closures, "infinite" if set to 0.
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for coordinators with constrained memory resources (only recommended for
# advanced users). Also used in unit tests to verify correct behavior when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024

# RPC error message from PS
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"

# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"


RemoteValueStatus = values_lib.RemoteValueStatus
RemoteValue = values_lib.RemoteValue
RemoteValueImpl = values_lib.RemoteValueImpl
PerWorkerValues = values_lib.PerWorkerValues


class InputError(Exception):

  def __init__(self, original_exception):
    self.original_exception = original_exception
    message = ("Input has an error, the original exception is %r, "
               "error message is %s." %
               (original_exception, str(original_exception)))
    super().__init__(message)
    self.with_traceback(original_exception.__traceback__)


def _maybe_rebuild_remote_values(worker, structure):
  """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
  errors_in_structure = []

  def _get_error(val):
    if isinstance(val, RemoteValue):
      if val._status is RemoteValueStatus.ABORTED:  # pylint: disable=protected-access
        try:
          # This attempts to rebuild the resource on the worker, which may fail
          # if the worker or PS is unavailable. If it fails, the original
          # RemoteValue that requests a result will receive an `InputError`,
          # which gets handled by `wait_on_failure` in `process_closure`.
          val._rebuild_on(worker)  # pylint: disable=protected-access
        except Exception as e:  # pylint: disable=broad-except
          val._set_error(e)  # pylint: disable=protected-access

      error = val._get_error()  # pylint: disable=protected-access
      if error:
        errors_in_structure.append(error)

  nest.map_structure(_get_error, structure)
  if errors_in_structure:
    return errors_in_structure[0]
  else:
    return None


def _maybe_get_remote_value(val):
  """Gets the value of `val` if it is a `RemoteValue`."""
  if isinstance(val, RemoteValue):
    error = val._get_error()  # pylint: disable=protected-access
    if error:
      raise AssertionError(
          "RemoteValue doesn't have a value because it has errors.")
    else:
      return val._get_values()  # pylint: disable=protected-access
  else:
    return val


def _maybe_as_type_spec(val):
  if isinstance(val, (RemoteValue, PerWorkerValues)):
    if val._type_spec is None:  # pylint: disable=protected-access
      raise ValueError("Output of a scheduled function that is not a "
                       "tf.function cannot be the input of another function.")
    return val._type_spec  # pylint: disable=protected-access
  else:
    return val


def _select_worker_slice(worker_id, structured):
  """Selects the worker slice of each of the items in `structured`."""

  def _get(x):
    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access

  return nest.map_structure(_get, structured)


def _disallow_remote_value_as_input(structured):
  """Raises if any element of `structured` is a RemoteValue."""

  def _raise_if_remote_value(x):
    if isinstance(x, RemoteValue):
      raise ValueError(
          "`tf.distribute.experimental.coordinator.RemoteValue` used "
          "as an input to a scheduled function is not yet "
          "supported.")

  nest.map_structure(_raise_if_remote_value, structured)


class Closure(object):
  """Holds a function to be scheduled and its arguments."""

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: there is no need to handle function registration failure since
      # such failures do not raise exceptions by design in the runtime. The
      # coordinator has to rely on subsequent operations that raise to catch
      # function registration failure.

      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function

    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we can retrieve the output type spec
      # via its structured outputs.
      self._output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e., what is passed in is a regular Python function), we
      # have no such information.
      self._output_type_spec = None
      self._function = function

    self._output_remote_value_ref = None

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      ret = RemoteValueImpl(None, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      raise ValueError(
          "The output of the Closure cannot be built more than once.")

  def maybe_call_with_output_remote_value(self, method):
    if self._output_remote_value_ref is None:
      return None
    output_remote_value = self._output_remote_value_ref()
    if output_remote_value is not None:
      return method(output_remote_value)
    return None

  def mark_cancelled(self):
    e = errors.CancelledError(
        None, None, "The corresponding function is "
        "cancelled. Please reschedule the function.")
    self.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Args:
      worker: a `Worker` object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      raise e

    with ops.device(worker.device_name):
      with context.executor_scope(worker.executor):
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function(
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.maybe_call_with_output_remote_value(
        lambda r: r._set_values(output_values))  # pylint: disable=protected-access


class ResourceClosure(Closure):
  """A closure that builds a per-worker resource.

  The output `RemoteValue` keeps a reference to this closure so that the
  resource can be rebuilt on a recovered worker.
  """

  def build_output_remote_value(self):
    if self._output_remote_value_ref is None:
      # We need to remember the Closure object in the `RemoteValue` here.
      ret = RemoteValueImpl(self, self._output_type_spec)
      self._output_remote_value_ref = weakref.ref(ret)
      return ret
    else:
      return self._output_remote_value_ref()


class _CoordinatedClosureQueue(object):
  """Manages a queue of closures, the inflight count, and execution errors.

  This class is thread-safe.
  """

  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of inflight
    # closures that are "in generation". Once an error occurs, the error
    # generation is incremented and all subsequently arriving closures (from
    # inflight) are considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or
    # inflight) have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item has become available in the queue
    # (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True

    # Condition indicating that a queue slot has become available (not full).
    # Note that even with an "infinite" queue size, there is still a
    # "practical" size limit for the queue depending on host memory capacity,
    # and thus the queue will eventually become full with a lot of enqueued
    # closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating that there are no inflight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Used to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
          "In a `ClusterCoordinator`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following lock ensures that no `put` can execute between the time
    # `wait` is called and the time it returns, because `wait` would not know
    # what to do with newly put closures. It adds a cutoff for `wait` so that
    # closures put into the queue while waiting are not this `wait`'s
    # responsibility.
    #
    # We cannot reuse `self._queue_lock` since when `wait` waits for a
    # condition, `self._queue_lock` will be released.
    #
    # We don't use a reader/writer lock on purpose, to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()

  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notifyAll()

  def _cancel_all_closures(self):
    """Clears the queue and marks the remaining closures as cancelled.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these
    # ClusterCoordinator APIs is called: `schedule`, `wait`, and `done`. At
    # the same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, cancels the closures in the queue, raises the error,
    and clears it.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  # pylint: disable=raising-bad-type
      finally:
        self._error = None

  def put(self, closure):
    """Puts a closure into the queue for later execution.

    If `mark_failed` was called before `put`, the error from the first
    invocation of `mark_failed` will be raised.

    Args:
      closure: The `Closure` to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Returns a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

  def mark_finished(self):
    """Lets the queue know that a closure has been successfully executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There are no inflight closures to mark_finished.")
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      if self._queue.empty() and self._inflight_closure_count == 0:
        self._stop_waiting_condition.notifyAll()

  def put_back(self, closure):
    """Puts the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There are no inflight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()

  def wait(self, timeout=None):
    """Waits for all closures to be finished before returning.

    If `mark_failed` was called before or during `wait`, the error from the
    first invocation of `mark_failed` will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

  def mark_failed(self, e):
    """Sets the error and unblocks any pending `wait` call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failures and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There are no inflight closures to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notifyAll()
      self._stop_waiting_condition.notifyAll()

  def done(self):
    """Returns True if the queue is empty and there are no inflight closures.

    If `mark_failed` was called before `done`, the error from the first
    invocation of `mark_failed` will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0


class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._error_from_recovery = None
    self._should_preemption_thread_run = True
    self._preemption_handler_thread = threading.Thread(
        target=self._preemption_handler,
        name="WorkerPreemptionHandler",
        daemon=True)
    self._preemption_handler_thread.start()

  def stop(self):
    """Ensures the worker preemption thread is closed."""
    self._should_preemption_thread_run = False
    with self._cluster_update_lock:
      self._cluster_due_for_update_or_finish.set()
    # TODO(yuefengz): The preemption handler thread shouldn't be terminated
    # asynchronously since it touches the eager context, which is a
    # process-wide singleton. The problem is that OSS unit tests will time
    # out.

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""

    # Only categorize the failure as a worker preemption if the cancellation
    # manager did not attempt to cancel the blocking operations.
    if _is_worker_failure(e) and (
        not self._cluster.closure_queue._cancellation_mgr.is_cancelled):  # pylint: disable=protected-access
      return
    raise e

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_transient_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption errors and waits until failed workers return.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_transient_failure_fn: an optional function to run if a transient
        failure happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is
        passing through the failure.

    Yields:
      None.
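
    Example (a sketch mirroring the call site in `Worker._process_closure`;
    `worker`, `closure`, and `queue` stand in for the real objects):

    ```python
    with worker.failure_handler.wait_on_failure(
        on_failure_fn=lambda: queue.put_back(closure),
        on_transient_failure_fn=lambda: queue.put_back(closure),
        on_recovery_fn=worker._set_resources_aborted,
        worker_device_name=worker.device_name):
      closure.execute_on(worker)
    ```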
523    """
524    assert self._should_preemption_thread_run
525    try:
526      yield
527    except (errors.OpError, InputError) as e:
528      # If the error is due to temporary connectivity issues between worker and
529      # ps, put back closure, ignore error and do not mark worker as failure.
530      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
531        logging.error(
532            "Remote function on worker %s failed with %r:%s\n"
533            "It is treated as a transient connectivity failure for now.",
534            worker_device_name, e, e)
535        if on_transient_failure_fn:
536          on_transient_failure_fn()
537        return
538
539      # Ignoring derived CancelledErrors to tolerate transient failures in
540      # PS-worker communication, which initially exposed as an UnavailableError
541      # and then lead to sub-function cancellation, subsequently getting
542      # reported from worker to chief as CancelledError.
543      # We do not mark either worker or PS as failed due to only CancelledError.
544      # If there are real (non-transient) failures, they must also be reported
545      # as other errors (UnavailableError most likely) in closure executions.
546      if isinstance(e, errors.CancelledError) and "/job:" in str(e):
547        logging.error(
548            "Remote function on worker %s failed with %r:%s\n"
549            "This derived error is ignored and not reported to users.",
550            worker_device_name, e, e)
551        if on_transient_failure_fn:
552          on_transient_failure_fn()
553        return
554
555      # This reraises the error, if it's not considered recoverable; otherwise,
556      # the following failure recovery logic run. At this time, only worker
557      # unavailability is recoverable. PS unavailability as well as other
558      # errors in the user function is not recoverable.
559      self._validate_preemption_failure(e)
560
561      logging.error("Worker %s failed with %r:%s", worker_device_name, e, e)
562      if on_failure_fn:
563        on_failure_fn()
564
565      with self._cluster_update_lock:
566        self._cluster_due_for_update_or_finish.set()
567        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
568        if self._error_from_recovery:
569          # TODO(yuefengz): there is only one worker that will get this error.
570          # Ideally we shuold let all workers notified by `_worker_up_cond` get
571          # this error.
572          try:
573            raise self._error_from_recovery
574          finally:
575            self._error_from_recovery = None
576        logging.info("Worker %s has been recovered.", worker_device_name)
577
578      if on_recovery_fn:
579        with self.wait_on_failure(
580            on_recovery_fn=on_recovery_fn,
581            on_transient_failure_fn=on_transient_failure_fn,
582            worker_device_name=worker_device_name):
583          on_recovery_fn()
584
  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for a signal of worker preemption and, upon worker
    preemption, waits until all workers are back, then updates the cluster
    about the restarted workers.
    """
    assert self._should_preemption_thread_run
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        logging.info("Stopping the failure handling thread.")
        break

      with self._cluster_update_lock:
        try:
          # TODO(haoyuzhang): support partial cluster recovery
          logging.info("Cluster now being recovered.")
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          logging.info("Cluster successfully recovered.")
          self._worker_up_cond.notify_all()
          # The check for _should_preemption_thread_run is necessary since
          # `stop` may have already set _cluster_due_for_update_or_finish.
          if self._should_preemption_thread_run:
            self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  # pylint: disable=broad-except
          try:
            self._validate_preemption_failure(e)
          except Exception as ps_e:  # pylint: disable=broad-except
            # In this case, a parameter server fails. So we raise this error
            # to the caller of `wait_on_failure`.
            self._error_from_recovery = ps_e
            self._worker_up_cond.notify_all()
            if self._should_preemption_thread_run:
              self._cluster_due_for_update_or_finish.clear()
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, an error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) A worker failed when exchanging subsequent RPCs after the
          #     first RPC returned.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.
          logging.error("Cluster update failed with error: %s. Retrying...", e)


class Worker(object):
  """A worker in a cluster.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handle worker preemption
      failures.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True

    # Worker threads need to start after `Worker`'s initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

  def stop(self):
    """Ensures the worker thread is closed."""
    self._should_worker_thread_run = False

  def _set_resources_aborted(self):
    # TODO(yuefengz): maybe we can query whether a tensor is valid or not
    # instead of marking a tensor aborted?
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster.closure_queue.put_back(closure),
          on_transient_failure_fn=lambda: self._cluster.closure_queue.put_back(
              closure),
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        with metric_utils.monitored_timer("remote_value_fetch"):
          # Copy the remote tensor to the coordinator in case the worker
          # becomes unavailable at a later time.
          closure.maybe_call_with_output_remote_value(lambda r: r.get())
        self._cluster.closure_queue.mark_finished()
    except Exception as e:  # pylint: disable=broad-except
      # Avoid logging the derived cancellation error.
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.maybe_call_with_output_remote_value(lambda r: r._set_error(e))  # pylint: disable=protected-access
      self._cluster.closure_queue.mark_failed(e)

  def _maybe_delay(self):
    """Delays if the corresponding env vars are set."""
    # If the following two environment variables are set, scheduling for
    # workers will start in a staggered manner. Worker i will wait for
    # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
    # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
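    # For example, with TF_COORDINATOR_SCHEDULE_START_DELAY=2 and
    # TF_COORDINATOR_SCHEDULE_START_DELAY_MAX=10, worker 3 sleeps for
    # min(2 * 3, 10) = 6 seconds and worker 7 for min(2 * 7, 10) = 10 seconds.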
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)

  def _process_queue(self):
    """Function running in a worker thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster.closure_queue.get()
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)
      # To properly stop the worker and preemption threads, it is important
      # that the `ClusterCoordinator` object is not held onto, so that its
      # `__del__` can be called. By removing the reference to the `closure`
      # that has already been processed, we ensure that the `closure` object
      # is released while getting the next `closure` at the
      # `self._cluster.closure_queue.get()` call above.
      del closure

  def create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a `RemoteValue`.

    Args:
      function: the resource function to be run remotely. It should be a
        `tf.function`, a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      One or several `RemoteValue` objects depending on the function's return
      values.
    """
    # A note about concurrency: currently all the activities related to the
    # same worker, such as creating resources, setting resources' aborted
    # status, and executing closures, happen on the same thread. This allows
    # us to have simpler concurrency logic.
    closure = ResourceClosure(
        function,
        self._cluster.closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.build_output_remote_value()
    self._register_resource(resource_remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    # TODO(b/160343165): we should create resources eagerly, i.e. schedule the
    # resource creation function as soon as users call this method.
    resource_remote_value._set_aborted()  # pylint: disable=protected-access
    return resource_remote_value

  def _register_resource(self, resource_remote_value):
    if not isinstance(resource_remote_value, RemoteValue):
      raise ValueError("Resource being registered is not of type "
                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))


class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal, and based on this assumption our
  error reporting logic is:
  1) Both `schedule` and `join` can raise a non-retryable error, which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been
  executed will be thrown away and marked as cancelled.
  3) After an error is raised, the internal error state will be cleared, i.e.
  functions can continue to be scheduled, and subsequent calls of `schedule`
  or `join` will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handle worker preemption
      failures.
    workers: a list of `Worker` objects in the cluster.
    closure_queue: the global Closure queue.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # Ignore PS failures reported by workers due to transient connection
    # errors. Transient connectivity issues between workers and PS are
    # relayed by the workers to the coordinator, leading the coordinator to
    # believe that there are PS failures. The difference between transient
    # and permanent PS failures is the number of reports from the workers.
    # When this env var is set to a positive integer K, the coordinator
    # ignores the first K - 1 reports of a failed PS task; only once K or
    # more closure executions have failed due to errors from the same PS
    # task does the coordinator consider the PS task to have failed.
    # TODO(b/164279603): Remove this workaround when the underlying
    # connectivity issue in gRPC server is resolved.
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
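    # For example, with the default threshold of 3, the first two reports of
    # a given failed PS task are ignored as transient; on the third report,
    # `_record_and_ignore_transient_ps_failure` returns False and the error
    # surfaces as a real PS failure.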
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

    self.closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

  def stop(self):
    """Stops workers, worker preemption threads, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self.closure_queue.stop()

  def _record_and_ignore_transient_ps_failure(self, e):
    """Records potential PS failures and returns whether to ignore `e`."""
    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
      return False

    ps_tasks = _extract_failed_ps_instances(str(e))
    with self._potential_ps_failures_lock:
      for t in ps_tasks:
        self._potential_ps_failures_count[t] += 1
        # The number of UnavailableErrors encountered on this PS task has
        # reached the maximum number of ignored errors.
        if (self._potential_ps_failures_count[t] >=
            self._transient_ps_failures_threshold):
          return False
    return True

  def schedule(self, function, args, kwargs):
    """Schedules `function` to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for `function`.
      kwargs: Keyword arguments for `function`.

    Returns:
      A `RemoteValue` object.
    """
    closure = Closure(
        function,
        self.closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    ret = closure.build_output_remote_value()
    self.closure_queue.put(closure)
    return ret

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self.closure_queue.wait()

  def done(self):
    """Returns True if all scheduled functions are executed."""
    return self.closure_queue.done()


@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  """An object to schedule and coordinate remote function execution.

  This class is used to create fault-tolerant resources and dispatch functions
  to remote TensorFlow servers.

  Currently, this class is not supported for standalone use. It should be
  used in conjunction with a `tf.distribute` strategy that is designed to
  work with it. The `ClusterCoordinator` class currently only works with
  `tf.distribute.experimental.ParameterServerStrategy`.

  __The `schedule`/`join` APIs__

  The most important APIs provided by this class are the `schedule`/`join`
  pair. The `schedule` API is non-blocking in that it queues a `tf.function`
  and returns a `RemoteValue` immediately. The queued functions will be
  dispatched to remote workers in background threads and their `RemoteValue`s
  will be filled asynchronously. Since `schedule` doesn't require worker
  assignment, the `tf.function` passed in can be executed on any available
  worker. If the worker it is executed on becomes unavailable before its
  completion, it will be migrated to another worker. Because of this fact,
  and because function execution is not atomic, a function may be executed
  more than once.
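
  Example (a minimal sketch; it assumes a parameter server cluster is
  already configured and reachable via `cluster_resolver`):

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  @tf.function
  def step_fn():
    return tf.constant(1)

  remote_value = coordinator.schedule(step_fn)
  coordinator.join()
  assert remote_value.fetch() == 1
  ```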

  __Handling Task Failure__

  This class, when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers become
  unreachable from the coordinator for any reason, training progress
  continues to be made with the remaining workers. Upon recovery of a failed
  worker, it will be added for function execution after datasets created by
  `create_per_worker_dataset` are re-built on it.

  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
  `schedule`, `join` or `done`. In this case, in addition to bringing back the
  failed parameter server, users should restart the coordinator so that it
  reconnects to workers and parameter servers, re-creates the variables, and
  loads checkpoints. If the coordinator fails, after the user brings it back,
  the program will automatically connect to workers and parameter servers,
  and continue the progress from a checkpoint.

  It is thus essential that the user's program saves a checkpoint file
  periodically and restores it at the start of the program. If a
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps
  are needed before training completion.

  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
  example usage of this API.

  This is currently under development, and the API as well as implementation
  are subject to changes.
  """

  def __new__(cls, strategy):
    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
    # TODO(rchao): Needs a lock for thread-safety
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a `ClusterCoordinator` instance.

    Args:
      strategy: a supported `tf.distribute.Strategy` object. Currently, only
        `tf.distribute.experimental.ParameterServerStrategy` is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not getattr(self, "_has_initialized", False):
      if not isinstance(strategy,
                        parameter_server_strategy_v2.ParameterServerStrategyV2):
        raise ValueError(
            "Only `tf.distribute.experimental.ParameterServerStrategy` "
            "is currently supported to work with "
            "`tf.distribute.experimental.coordinator.ClusterCoordinator`.")
      self._strategy = strategy
      self.strategy.extended._used_with_coordinator = True
      self._cluster = Cluster(strategy)
      self._has_initialized = True

  def __del__(self):
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
    return self._strategy

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the `fn` which will be
    executed later and returns a
    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
    `fetch` can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait
    for all scheduled functions to finish.

    `schedule` guarantees that `fn` will be executed on a worker at least
    once; it could be more than once if its corresponding worker fails in the
    middle of its execution. Note that since a worker can fail at any point
    when executing the function, it is possible that the function is
    partially executed, but
    `tf.distribute.experimental.coordinator.ClusterCoordinator` guarantees
    that in those events, the function will eventually be executed on any
    worker that is available.

    If any previously scheduled function raises an error, `schedule` will
    raise any one of those errors, and clear the errors collected so far.
    When this happens, some of the previously scheduled functions may not
    have been executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled, and reschedule the
    corresponding function if needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support for worker assignment for function
    execution, or for worker priority.

    `args` and `kwargs` are the arguments passed into `fn` when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues`, in which case
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed
    into `fn` as-is. Currently,
    `tf.distribute.experimental.coordinator.RemoteValue` is not supported as
    input `args` or `kwargs`.

    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously. Regular Python functions are not supported
        to be scheduled.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.RemoteValue` object that
      represents the output of the scheduled function.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was
        thrown or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
          " only accepts a `tf.function` or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    # `schedule` needs to be called within the `strategy.scope()`.
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
      return remote_value

  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far.
    If this happens, some of the previously scheduled functions may not have
    been executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled. If some that have been
    cancelled need to be rescheduled, users should call `schedule` with the
    function again.

    When `join` returns or raises, it guarantees that there is no function
    that is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function since the last time an error was thrown
        or since the beginning of the program.
    """
    self._cluster.join()

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `done` will fail by
    raising any one of those errors.

    When `done` returns True or raises, it guarantees that there is no
    function that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function since the last time an error was thrown
        or since the beginning of the program.
    """
    return self._cluster.done()

  def create_per_worker_dataset(self, dataset_fn):
    """Creates datasets on workers by calling `dataset_fn` on worker devices.

    This creates a dataset generated by `dataset_fn` on each worker, and
    returns an object that represents the collection of those individual
    datasets. Calling `iter` on such a collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on
    respective workers.

    Calling `next` on a `PerWorkerValues` of iterators is unsupported. The
    iterator is meant to be passed as an argument into
    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The `next` method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the `schedule` method assumes workers are all the same, and
    thus assumes the datasets on different workers are the same, except that
    they may be shuffled differently if they contain a `dataset.shuffle`
    operation and a random seed is not set. Because of this, we also
    recommend that the datasets be repeated indefinitely and that a finite
    number of steps be scheduled, instead of relying on the `OutOfRangeError`
    from a dataset.

    Example:

    ```python
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=...)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy=strategy)

    @tf.function
    def worker_fn(iterator):
      return next(iterator)

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
    assert remote_value.fetch() == 3
    ```

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets. `iter` is expected to be called on this object, which returns
      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are on the workers).
    """
    return values_lib.get_per_worker_dataset(dataset_fn, self)

  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
    """Synchronously creates resources on the workers.

    The resources are represented by
    `tf.distribute.experimental.coordinator.RemoteValue`s.

    Args:
      fn: The function to be dispatched to all workers for execution
        asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
      objects.
    """
    results = []
    for w in self._cluster.workers:
      results.append(w.create_resource(fn, args=args, kwargs=kwargs))
    return PerWorkerValues(tuple(results))

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of
    `RemoteValue`s. If the results are not ready, the call blocks until they
    are.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value from which to fetch the results. If this is a structure
        of `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()`
        will be called on the individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the
        result.

    Returns:
      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
      structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
      return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
      values immediately if they are available, or block the call until they
      are available, and return the fetched
      `tf.distribute.experimental.coordinator.RemoteValue` values with the
      same structure. If `val` is of any other type, it is returned as-is.
    """

    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    # TODO(yuefengz): we should fetch values in a batch.
    return nest.map_structure(_maybe_fetch, val)


def _extract_failed_ps_instances(err_msg):
  """Returns a set of potentially failing PS tasks from an error message."""
  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
  return set(int(t.split(":")[-1]) for t in tasks)


def _is_ps_failure(error):
  """Whether the error is considered a parameter server failure."""

  # For an `InputError`, extract the original error and assess it accordingly.
  if isinstance(error, InputError):
    error = error.original_exception

  return (isinstance(error, errors.UnavailableError) and
          _RPC_ERROR_FROM_PS in str(error))


def _is_worker_failure(error):
  """Whether the error is considered a worker failure."""

  # For an `InputError`, extract the original error and assess it accordingly.
  if isinstance(error, InputError):
    error = error.original_exception

  if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
    return False
  if _RPC_ERROR_FROM_PS in str(error):
    return False

  # TODO(haoyuzhang): Consider using a special status code if an error from a
  # remote is derived from RPC errors originating from other hosts.
  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
    return True

  # The following error could happen when the remote task fails and restarts
  # in a very short interval during which no RPCs were exchanged to detect
  # the failure. In that case, gRPC allows a channel (which is different from
  # a connection) to be reused for a replaced server listening on the same
  # address.
  if isinstance(error, errors.InvalidArgumentError):
    if ("unknown device" in str(error) or
        "Unable to find the relevant tensor remote_handle" in str(error)):
      # TODO(b/159961667): Fix the "Unable to find the relevant tensor
      # remote_handle" part.
      return True

  # TODO(b/162541228): The following two types of errors are very rare and
  # only observed in large-scale testing; the types of errors should be
  # reduced. This could happen when function registration fails. In the
  # observed cases this only happens to dataset-related functions.
  if isinstance(error, errors.NotFoundError):
    if ("is neither a type of a primitive operation nor a name of a function "
        "registered" in str(error)):
      return True

  # NOTE(b/179061495): During worker preemptions, if multiple functions are
  # running concurrently (especially with subfunctions spanning chief/PS),
  # a CancelledError can be returned due to the chief/PS cancelling
  # outstanding RPCs to the failing workers.
  if isinstance(error, errors.CancelledError):
    return True

  return False