# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Module for `ClusterCoordinator` and relevant cluster-worker related library.
17
18This is currently under development and the API is subject to change.
19"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import enum
import functools
import os
import re
import threading
import time
import weakref
from six.moves import queue

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.coordinator import metric_utils
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# Maximum time for a failed worker to come back is 1 hour.
_WORKER_MAXIMUM_RECOVERY_SEC = 3600

# Maximum size for queued closures, "infinite" if set to 0.
# When the maximum queue size is reached, further schedule calls will become
# blocking until some previously queued closures are executed on workers.
# Note that using an "infinite" queue size can take a non-trivial portion of
# memory, and even lead to coordinator OOM. Modify the size to a smaller value
# for coordinators with constrained memory resources (only recommended for
# advanced users). Also used in unit tests to ensure the correctness when the
# queue is full.
_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024

# RPC error message from PS
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"

# InvalidArgumentError (unknown device) will not have "GRPC error..." string.
_JOB_WORKER_STRING_IDENTIFIER = "/job:worker"


class _RemoteValueStatus(enum.Enum):
  """The status of a `RemoteValue` object.

  A `RemoteValue` object can have three states:
    1) not ready: no value, no non-retryable error, and not aborted;
    2) aborted: i.e. the execution of the function was aborted because of task
       failure, but can be retried;
    3) ready: i.e. has a value or has a non-retryable error.

  The initial state of a `RemoteValue` is "not ready". When its corresponding
  closure has been executed at least once, it will become aborted or ready.
  The state transitions are:
    1) not ready -> 2) aborted:
      when the corresponding closure is aborted due to worker failure, and the
      worker failure is not immediately handled.
    1) not ready -> 3) ready:
      when the corresponding closure has been executed successfully.
    2) aborted -> 3) ready:
      when the `RemoteValue` is rebuilt by rerunning the corresponding closure
      and the closure has been executed successfully.
    3) ready -> 2) aborted:
      when the corresponding closure had been executed successfully but later
      the corresponding remote worker failed. This is currently only
      implemented for resource `RemoteValue`s like iterators.
  """
  NOT_READY = "NOT_READY"
  ABORTED = "ABORTED"
  READY = "READY"


@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[])
class RemoteValue(object):
  """An asynchronously available value of a scheduled function.

  This class is used as the return value of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where
  the underlying value becomes available at a later time once the function has
  been executed.

  Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to
  a subsequent function scheduled with
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is
  currently not supported.

  Example:

  ```python
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.experimental.coordinator.ClusterCoordinator(strategy))

  with strategy.scope():
    v1 = tf.Variable(initial_value=0.0)
    v2 = tf.Variable(initial_value=1.0)

  @tf.function
  def worker_fn():
    v1.assign_add(0.1)
    v2.assign_sub(0.2)
    return v1.read_value() / v2.read_value()

  result = coordinator.schedule(worker_fn)
  # Note that `fetch()` gives the actual result instead of a `tf.Tensor`.
  assert result.fetch() == 0.125

  for _ in range(10):
    # `worker_fn` will be run on arbitrary workers that are available. The
    # `result` value will be available later.
    result = coordinator.schedule(worker_fn)
  ```
  """

  def fetch(self):
    """Waits for the result of `RemoteValue` to be ready and returns it.

    This makes the value concrete by copying the remote value to the local
    host.

    Returns:
      The actual output of the `tf.function` associated with this
      `RemoteValue`, previously scheduled by a
      `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`
      call. This can be a single value, or a structure of values, depending on
      the output of the `tf.function`.

    Raises:
      tf.errors.CancelledError: If the function that produces this
        `RemoteValue` is aborted or cancelled due to failure.
    """
    raise NotImplementedError("Must be implemented in subclasses.")


class RemoteValueImpl(RemoteValue):
  """Implementation of `RemoteValue`."""

  def __init__(self, closure, type_spec):  # pylint: disable=super-init-not-called
    """Initializes a `RemoteValueImpl`.

    Args:
      closure: The closure from which the `RemoteValue` is created.
      type_spec: The type spec for this `RemoteValue` which is used to trace
        functions that take this `RemoteValue` as input.
    """
    self._closure = closure
    self._type_spec = type_spec
    self._values = None
    self._fetched_numpys = None
    self._error = None
    self._status_available_event = threading.Event()
    self._status = _RemoteValueStatus.NOT_READY

  def _set_aborted(self):
    self._status = _RemoteValueStatus.ABORTED
    self._values = None
    self._error = None

    # Wake up any waiting thread and clear the event.
    self._status_available_event.set()

  def _rebuild_on(self, worker):
    self._status_available_event.clear()
    # TODO(yuefengz): we may need to rebuild its inputs as well.
    self._closure.execute_on(worker)

  def _set_values(self, tensors):
    self._status = _RemoteValueStatus.READY
    self._values = tensors
    self._error = None
    self._status_available_event.set()

  def _set_error(self, exception):
    self._status = _RemoteValueStatus.READY
    self._values = None
    self._error = exception
    self._status_available_event.set()

  def _get_values(self):
    self._status_available_event.wait()
    return self._values

  def _get_error(self):
    self._status_available_event.wait()
    return self._error

  def fetch(self):
    self._status_available_event.wait()
    if self._status is _RemoteValueStatus.ABORTED:
      raise errors.CancelledError(
          None, None,
          "The corresponding function is aborted. Please reschedule the "
          "function.")
    if self._error is not None:
      raise self._error
    if self._fetched_numpys is None:
      self._fetched_numpys = nest.map_structure(
          lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values)
    return self._fetched_numpys


class InputError(Exception):

  def __init__(self, original_exception):
    message = ("Input has an error; the original exception is %r and its "
               "error message is %s." %
               (original_exception, str(original_exception)))
    super().__init__(message)


def _maybe_rebuild_remote_values(worker, structure):
  """Attempts to return errors from `RemoteValue`s. Rebuilds them if needed."""
  errors_in_structure = []

  def _get_error(val):
    if isinstance(val, RemoteValue):
      if val._status is _RemoteValueStatus.ABORTED:  # pylint: disable=protected-access
        try:
          with worker.failure_handler.wait_on_failure(
              on_recovery_fn=functools.partial(val._rebuild_on, worker),  # pylint: disable=protected-access
              worker_device_name=worker.device_name):
            val._rebuild_on(worker)  # pylint: disable=protected-access
        except Exception as e:  # pylint: disable=broad-except
          val._set_error(e)  # pylint: disable=protected-access

      error = val._get_error()  # pylint: disable=protected-access
      if error:
        errors_in_structure.append(error)

  nest.map_structure(_get_error, structure)
  if errors_in_structure:
    return errors_in_structure[0]
  else:
    return None


def _maybe_get_remote_value(val):
  """Gets the value of `val` if it is a `RemoteValue`."""
  if isinstance(val, RemoteValue):
    error = val._get_error()  # pylint: disable=protected-access
    if error:
      raise AssertionError(
          "RemoteValue doesn't have a value because it has errors.")
    else:
      return val._get_values()  # pylint: disable=protected-access
  else:
    return val


def _maybe_as_type_spec(val):
  if isinstance(val, RemoteValue):
    if val._type_spec is None:  # pylint: disable=protected-access
      raise ValueError("Output of a scheduled function that is not "
                       "tf.function cannot be the input of another function.")
    return val._type_spec  # pylint: disable=protected-access
  else:
    return val


@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(object):
  """A container that holds a list of values, one value per worker.

  `tf.distribute.experimental.coordinator.PerWorkerValues` contains a
  collection of values, where each of the values is located on its
  corresponding worker, and upon being used as one of the `args` or `kwargs` of
  `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the
  value specific to a worker will be passed into the function being executed at
  that corresponding worker.

  Currently, the only supported path to create an object of
  `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling
  `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned
  distributed dataset instance. The mechanism to create a custom
  `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet
  supported.
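
  Example (a minimal sketch; assumes `coordinator`, `dataset_fn` and
  `worker_fn` are set up as in the `ClusterCoordinator` documentation):

  ```python
  per_worker_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  # `iter` returns a `PerWorkerValues` of per-worker iterators.
  per_worker_iterator = iter(per_worker_dataset)
  # Each worker receives its own component of the `PerWorkerValues`.
  coordinator.schedule(worker_fn, args=(per_worker_iterator,))
  ```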
306  """
307
308  def __init__(self, values):
309    self._values = tuple(values)
310
311
def _select_worker_slice(worker_id, structured):
  """Selects the worker slice of each of the items in `structured`."""

  def _get(x):
    return x._values[worker_id] if isinstance(x, PerWorkerValues) else x  # pylint: disable=protected-access

  return nest.map_structure(_get, structured)


def _disallow_remote_value_as_input(structured):
  """Raises if any element of `structured` is a RemoteValue."""

  def _raise_if_remote_value(x):
    if isinstance(x, RemoteValue):
      raise ValueError(
          "`tf.distribute.experimental.coordinator.RemoteValue` used "
          "as an input to scheduled function is not yet "
          "supported.")

  nest.map_structure(_raise_if_remote_value, structured)


class Closure(object):
  """Holds a function to be scheduled, together with its arguments."""

  def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
    if not callable(function):
      raise ValueError("Function passed to `ClusterCoordinator.schedule` must "
                       "be a callable object.")
    self._args = args or ()
    self._kwargs = kwargs or {}

    _disallow_remote_value_as_input(self._args)
    _disallow_remote_value_as_input(self._kwargs)

    if isinstance(function, def_function.Function):
      replica_args = _select_worker_slice(0, self._args)
      replica_kwargs = _select_worker_slice(0, self._kwargs)

      # Note: there is no need to handle function registration failure since,
      # by design, this kind of failure does not raise exceptions in the
      # runtime. The coordinator has to rely on subsequent operations that
      # raise to catch function registration failure.

      # Record the function tracing overhead. Note that we pass in the tracing
      # count of the def_function.Function as a state tracker, so that metrics
      # will only record the time for actual function tracing (i.e., excluding
      # function cache lookups).
      with metric_utils.monitored_timer(
          "function_tracing", state_tracker=function._get_tracing_count):  # pylint: disable=protected-access
        self._concrete_function = function.get_concrete_function(
            *nest.map_structure(_maybe_as_type_spec, replica_args),
            **nest.map_structure(_maybe_as_type_spec, replica_kwargs))
    elif isinstance(function, tf_function.ConcreteFunction):
      self._concrete_function = function

    if hasattr(self, "_concrete_function"):
      # If we have a concrete function, we get to retrieve the output type spec
      # via the structured_output.
      output_type_spec = func_graph.convert_structure_to_signature(
          self._concrete_function.structured_outputs)
      self._function = cancellation_mgr.get_cancelable_function(
          self._concrete_function)
    else:
      # Otherwise (i.e. what is passed in is a regular python function), we
      # have no such information.
      output_type_spec = None
      self._function = function

    self.output_remote_value = RemoteValueImpl(self, output_type_spec)

  def mark_cancelled(self):
    self.output_remote_value._set_error(  # pylint: disable=protected-access
        errors.CancelledError(
            None, None, "The corresponding function is "
            "cancelled. Please reschedule the function."))

  def execute_on(self, worker):
    """Executes the closure on the given worker.

    Args:
      worker: a `Worker` object.
    """
    replica_args = _select_worker_slice(worker.worker_index, self._args)
    replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)

    e = (
        _maybe_rebuild_remote_values(worker, replica_args) or
        _maybe_rebuild_remote_values(worker, replica_kwargs))
    if e:
      if not isinstance(e, InputError):
        e = InputError(e)
      self.output_remote_value._set_error(e)  # pylint: disable=protected-access
      return

    with ops.device(worker.device_name):
      with context.executor_scope(worker.executor):
        with metric_utils.monitored_timer("closure_execution"):
          output_values = self._function(
              *nest.map_structure(_maybe_get_remote_value, replica_args),
              **nest.map_structure(_maybe_get_remote_value, replica_kwargs))
    self.output_remote_value._set_values(output_values)  # pylint: disable=protected-access


class _CoordinatedClosureQueue(object):
  """Manages a queue of closures, the in-flight count, and execution errors.

  This class is thread-safe.
  """

  def __init__(self):
    # `self._inflight_closure_count` only tracks the number of in-flight
    # closures that are "in generation". Once an error occurs, the error
    # generation is incremented and all subsequent arriving closures (from
    # inflight) are considered "out of generation".
    self._inflight_closure_count = 0

    self._queue_lock = threading.Lock()

    # Condition indicating that all pending closures (either queued or
    # in-flight) have been processed, failed, or cancelled.
    self._stop_waiting_condition = threading.Condition(self._queue_lock)

    # Condition indicating that an item becomes available in queue (not empty).
    self._closures_queued_condition = threading.Condition(self._queue_lock)
    self._should_process_closures = True

    # Condition indicating that a queue slot becomes available (not full).
    # Note that even with an "infinite" queue size, there is still a practical
    # size limit for the queue depending on host memory capacity, and thus the
    # queue will eventually become full with a lot of enqueued closures.
    self._queue_free_slot_condition = threading.Condition(self._queue_lock)

    # Condition indicating that there are no in-flight closures.
    self._no_inflight_closure_condition = threading.Condition(self._queue_lock)

    # Used to cancel in-flight closures.
    self._cancellation_mgr = cancellation.CancellationManager()

    if _CLOSURE_QUEUE_MAX_SIZE <= 0:
      logging.warning(
          "In a `ClusterCoordinator`, creating an infinite closure queue can "
          "consume a significant amount of memory and even lead to OOM.")
    self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
    self._error = None

    # The following is a lock to make sure that while `wait` is called and
    # before it returns, no `put` can be executed, since `wait` would not know
    # what to do with newly put closures. This lock adds a cutoff for `wait` so
    # that closures put into the queue while waiting are not the responsibility
    # of this `wait`.
    #
    # We cannot reuse `self._queue_lock` since when `wait` waits for a
    # condition, `self._queue_lock` will be released.
    #
    # We don't use a reader/writer lock on purpose to reduce the complexity
    # of the code.
    self._put_wait_lock = threading.Lock()

  def stop(self):
    with self._queue_lock:
      self._should_process_closures = False
      self._closures_queued_condition.notify_all()

  def _cancel_all_closures(self):
    """Clears the queue and marks the remaining closures as cancelled.

    This method expects self._queue_lock to be held prior to entry.
    """
    self._cancellation_mgr.start_cancel()
    while self._inflight_closure_count > 0:
      self._no_inflight_closure_condition.wait()
    while True:
      try:
        closure = self._queue.get(block=False)
        self._queue_free_slot_condition.notify()
        closure.mark_cancelled()
      except queue.Empty:
        break
    # The cancellation manager cannot be reused once cancelled. After all
    # closures (queued or inflight) are cleaned up, recreate the cancellation
    # manager with clean state.
    # Note on thread-safety: this is triggered when one of these
    # ClusterCoordinator APIs is called: `schedule`, `wait`, and `done`. At the
    # same time, no new closures can be constructed (which reads the
    # _cancellation_mgr to get cancellable functions).
    self._cancellation_mgr = cancellation.CancellationManager()

  def _raise_if_error(self):
    """Raises the error if one exists.

    If an error exists, this cancels the closures in the queue, raises the
    error, and clears it.

    This method expects self._queue_lock to be held prior to entry.
    """
    if self._error:
      logging.error("Start cancelling closures due to error %r: %s",
                    self._error, self._error)
      self._cancel_all_closures()
      try:
        raise self._error  # pylint: disable=raising-bad-type
      finally:
        self._error = None

  def put(self, closure):
    """Put a closure into the queue for later execution.

    If `mark_failed` was called before `put`, the error from the first
    invocation of `mark_failed` will be raised.

    Args:
      closure: The `Closure` to put into the queue.
    """
    with self._put_wait_lock, self._queue_lock:
      self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
      self._queue.put(closure, block=False)
      self._raise_if_error()
      self._closures_queued_condition.notify()

  def get(self, timeout=None):
    """Return a closure from the queue to be executed."""
    with self._queue_lock:
      while self._queue.empty() and self._should_process_closures:
        if not self._closures_queued_condition.wait(timeout=timeout):
          return None
      if not self._should_process_closures:
        return None
      closure = self._queue.get(block=False)
      self._queue_free_slot_condition.notify()
      self._inflight_closure_count += 1
      return closure

  def mark_finished(self):
    """Lets the queue know that a closure has been successfully executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError(
            "There are no in-flight closures to mark_finished.")
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()
      if self._queue.empty() and self._inflight_closure_count == 0:
        self._stop_waiting_condition.notify_all()

  def put_back(self, closure):
    """Puts the closure back into the queue as it was not properly executed."""
    with self._queue_lock:
      if self._inflight_closure_count < 1:
        raise AssertionError("There are no in-flight closures to put_back.")
      if self._error:
        closure.mark_cancelled()
      else:
        self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
        self._queue.put(closure, block=False)
        self._closures_queued_condition.notify()
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()

  def wait(self, timeout=None):
    """Wait for all closures to be finished before returning.

    If `mark_failed` was called before or during `wait`, the error from the
    first invocation of `mark_failed` will be raised.

    Args:
      timeout: A float specifying a timeout for the wait in seconds.

    Returns:
      True unless the given timeout expired, in which case it returns False.
    """
    with self._put_wait_lock, self._queue_lock:
      while (not self._error and
             (not self._queue.empty() or self._inflight_closure_count > 0)):
        if not self._stop_waiting_condition.wait(timeout=timeout):
          return False
      self._raise_if_error()
      return True

  def mark_failed(self, e):
    """Sets error and unblocks any wait() call."""
    with self._queue_lock:
      # TODO(yuefengz): maybe record all failures and give users more
      # information?
      if self._inflight_closure_count < 1:
        raise AssertionError("There are no in-flight closures to mark_failed.")
      if self._error is None:
        self._error = e
      self._inflight_closure_count -= 1
      if self._inflight_closure_count == 0:
        self._no_inflight_closure_condition.notify_all()
      self._stop_waiting_condition.notify_all()

  def done(self):
    """Returns true if the queue is empty and there are no in-flight closures.

    If `mark_failed` was called before `done`, the error from the first
    invocation of `mark_failed` will be raised.
    """
    with self._queue_lock:
      self._raise_if_error()
      return self._queue.empty() and self._inflight_closure_count == 0


class WorkerPreemptionHandler(object):
  """Handles worker preemptions."""

  def __init__(self, server_def, cluster):
    self._server_def = server_def
    self._cluster = cluster
    self._cluster_update_lock = threading.Lock()
    self._cluster_due_for_update_or_finish = threading.Event()
    self._worker_up_cond = threading.Condition(self._cluster_update_lock)
    self._should_preemption_thread_run = True
    threading.Thread(target=self._preemption_handler,
                     name="WorkerPreemptionHandler",
                     daemon=True).start()

  def stop(self):
    """Ensure the worker preemption thread is closed."""
    self._should_preemption_thread_run = False
    with self._cluster_update_lock:
      self._cluster_due_for_update_or_finish.set()

  def _validate_preemption_failure(self, e):
    """Validates that the given exception represents worker preemption."""
    if _is_worker_failure(e):
      return
    raise e

  @contextlib.contextmanager
  def wait_on_failure(self,
                      on_failure_fn=None,
                      on_recovery_fn=None,
                      worker_device_name="(unknown)"):
    """Catches worker preemption error and waits until failed workers are back.

    Args:
      on_failure_fn: an optional function to run if preemption happens.
      on_recovery_fn: an optional function to run when a worker is recovered
        from preemption.
      worker_device_name: the device name of the worker instance that is
        passing through the failure.

    Yields:
      None.
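
    Example (a sketch of how this module itself uses the context manager, in
    `_maybe_rebuild_remote_values`):

    ```python
    with worker.failure_handler.wait_on_failure(
        on_recovery_fn=functools.partial(val._rebuild_on, worker),
        worker_device_name=worker.device_name):
      val._rebuild_on(worker)
    ```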
    """
    try:
      yield
    except errors.OpError as e:
      # If the error is due to temporary connectivity issues between worker and
      # ps, put back closure, ignore error and do not mark worker as failure.
      if self._cluster._record_and_ignore_transient_ps_failure(e):  # pylint: disable=protected-access
        if on_failure_fn:
          on_failure_fn()
        return

      self._validate_preemption_failure(e)
      logging.error("Worker %s failed with error: %s", worker_device_name, e)
      if on_failure_fn:
        on_failure_fn()

      with self._cluster_update_lock:
        self._cluster_due_for_update_or_finish.set()
        self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
        logging.info("Worker %s has been recovered.", worker_device_name)

      if on_recovery_fn:
        with self.wait_on_failure(
            on_recovery_fn=on_recovery_fn,
            worker_device_name=worker_device_name):
          on_recovery_fn()

  def _preemption_handler(self):
    """A loop that handles preemption.

    This loop waits for a signal of worker preemption and, upon worker
    preemption, waits until all workers are back and updates the cluster
    about the restarted workers.
    """
    while True:
      self._cluster_due_for_update_or_finish.wait()
      if not self._should_preemption_thread_run:
        break

      with self._cluster_update_lock:
        try:
          # TODO(haoyuzhang): support partial cluster recovery
          logging.info("Cluster now being recovered.")
          context.context().update_server_def(self._server_def)

          # Cluster updated successfully, clear the update signal, and notify
          # all workers that they are recovered from failure.
          logging.info("Cluster successfully recovered.")
          self._worker_up_cond.notify_all()
          self._cluster_due_for_update_or_finish.clear()
        except Exception as e:  # pylint: disable=broad-except
          self._validate_preemption_failure(e)
          # NOTE: Since the first RPC (GetStatus) of update_server_def is
          # currently blocking by default, an error should only happen if:
          # (1) More workers failed while waiting for the previous workers to
          #     come back;
          # (2) A worker failed when exchanging subsequent RPCs after the first
          #     RPC returns.
          # Consider adding backoff retry logic if we see the error logged
          # too frequently.
          logging.error("Cluster update failed with error: %s. Retrying...", e)


class Worker(object):
  """A worker in a cluster.

  Attributes:
    worker_index: The index of the worker in the cluster.
    device_name: The device string of the worker, e.g. "/job:worker/task:1".
    executor: The worker's executor for remote function execution.
    failure_handler: The failure handler used to handle worker preemption
      failures.
  """

  def __init__(self, worker_index, device_name, cluster):
    self.worker_index = worker_index
    self.device_name = device_name
    self.executor = executor.new_executor(enable_async=False)
    self.failure_handler = cluster.failure_handler
    self._cluster = cluster
    self._resource_remote_value_refs = []
    self._should_worker_thread_run = True

    # Worker threads need to start after `Worker`'s initialization.
    threading.Thread(target=self._process_queue,
                     name="WorkerClosureProcessingLoop-%d" % self.worker_index,
                     daemon=True).start()

  def stop(self):
    """Ensure the worker thread is closed."""
    self._should_worker_thread_run = False

  def _set_resources_aborted(self):
    # TODO(yuefengz): maybe we can query whether a tensor is valid or not
    # instead of marking a tensor aborted?
    for weakref_resource in self._resource_remote_value_refs:
      resource = weakref_resource()
      if resource:
        resource._set_aborted()  # pylint: disable=protected-access

  def _set_dead(self):
    raise NotImplementedError("_set_dead is not implemented.")

  def _process_closure(self, closure):
    """Runs a closure with preemption handling."""
    assert closure is not None
    try:
      with self._cluster.failure_handler.wait_on_failure(
          on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),  # pylint: disable=protected-access
          on_recovery_fn=self._set_resources_aborted,
          worker_device_name=self.device_name):
        closure.execute_on(self)
        # TODO(yuefengz): we don't have to materialize results every step.
        with metric_utils.monitored_timer("remote_value_fetch"):
          closure.output_remote_value.fetch()
        self._cluster._closure_queue.mark_finished()  # pylint: disable=protected-access
    except Exception as e:  # pylint: disable=broad-except
      # Avoid logging the derived cancellation error
      if not isinstance(e, errors.CancelledError):
        logging.error(
            "/job:worker/task:%d encountered the following error when "
            "processing closure: %r:%s", self.worker_index, e, e)
      closure.output_remote_value._set_error(e)  # pylint: disable=protected-access
      self._cluster._closure_queue.mark_failed(e)  # pylint: disable=protected-access

  def _maybe_delay(self):
    """Delay if corresponding env vars are set."""
    # If the following two environment variables are set, scheduling for
    # workers will start in a staggered manner. Worker i will wait for
    # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding
    # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`.
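    # For example (hypothetical values): with
    # TF_COORDINATOR_SCHEDULE_START_DELAY=2 and
    # TF_COORDINATOR_SCHEDULE_START_DELAY_MAX=10, worker 7 would sleep for
    # min(2 * 7, 10) = 10 seconds before processing closures.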
    delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
    delay_cap = int(
        os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
    if delay_cap:
      delay_secs = min(delay_secs * self.worker_index, delay_cap)
    if delay_secs > 0:
      logging.info("Worker %d sleeping for %d seconds before running function",
                   self.worker_index, delay_secs)
    time.sleep(delay_secs)

  def _process_queue(self):
    """Function running in a thread to process closure queues."""
    self._maybe_delay()
    while self._should_worker_thread_run:
      closure = self._cluster._closure_queue.get()  # pylint: disable=protected-access
      if not self._should_worker_thread_run or closure is None:
        return
      self._process_closure(closure)

  def _create_resource(self, function, args=None, kwargs=None):
    """Synchronously creates a per-worker resource represented by a `RemoteValue`.

    Args:
      function: the resource function to be run remotely. It should be a
        `tf.function`, a concrete function or a Python function.
      args: positional arguments to be passed to the function.
      kwargs: keyword arguments to be passed to the function.

    Returns:
      A `RemoteValue` object that represents the resource(s) created by running
      `function` on this worker.
    """
    # Some notes about the concurrency: currently all the activities related to
    # the same worker such as creating resources, setting resources' aborted
    # status, and executing closures happen on the same thread. This allows us
    # to have simpler logic of concurrency.
    closure = Closure(
        function,
        self._cluster._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    resource_remote_value = closure.output_remote_value
    self._register_resource(resource_remote_value)

    # The following is a short-term solution to lazily create resources in
    # parallel.
    # TODO(b/160343165): we should create resources eagerly, i.e. schedule the
    # resource creation function as soon as users call this method.
    resource_remote_value._set_aborted()  # pylint: disable=protected-access
    return resource_remote_value

  def _register_resource(self, resource_remote_value):
    if not isinstance(resource_remote_value, RemoteValue):
      raise ValueError("Resource being registered is not of type "
                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))


class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both `schedule` and `join` can raise a non-retryable error, which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
  I.e. functions can continue to be scheduled and subsequent calls of `schedule`
  or `join` will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handle worker preemption
      failures.
    workers: a list of `Worker` objects in the cluster.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # Ignore PS failures reported by workers due to transient connection
    # errors. Transient connectivity issues between workers and PS are relayed
    # by the workers to the coordinator, leading the coordinator to believe
    # that there are PS failures. The difference between transient vs.
    # permanent PS failure is the number of reports from the workers. When this
    # env var is set to a positive integer K, the coordinator ignores up to K
    # reports of a failed PS task; that is, only when more than K closure
    # executions fail due to errors from the same PS task do we consider that
    # PS task to have failed.
    # TODO(b/164279603): Remove this workaround when the underlying
    # connectivity issue in gRPC server is resolved.
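    # Illustration (with the default threshold of 3): the first two failure
    # reports against a given PS task are ignored and the closures retried; a
    # third report against the same task is treated as a real PS failure.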
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    self._potential_ps_failures_count = [0] * self._num_ps

    self._closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                   self)
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

  def stop(self):
    """Stops the workers, the worker preemption thread, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self._closure_queue.stop()

  def _record_and_ignore_transient_ps_failure(self, e):
    """Records a potential PS failure and returns whether it should be ignored."""
    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
      return False

    ps_tasks = _extract_failed_ps_instances(str(e))
    with self._potential_ps_failures_lock:
      for t in ps_tasks:
        self._potential_ps_failures_count[t] += 1
        # The number of UnavailableErrors encountered on this PS task exceeds
        # the maximum number of ignored errors, so it is no longer treated as
        # transient.
        if (self._potential_ps_failures_count[t] >=
            self._transient_ps_failures_threshold):
          return False
    return True

  def schedule(self, function, args, kwargs):
    """Schedules `function` to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for `function`.
      kwargs: Keyword arguments for `function`.

    Returns:
      A `RemoteValue` object.
    """
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    self._closure_queue.put(closure)
    return closure.output_remote_value

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self._closure_queue.wait()

  def done(self):
    """Returns True if all scheduled functions are executed."""
    return self._closure_queue.done()


@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  """An object to schedule and coordinate remote function execution.

  This class is used to create fault-tolerant resources and dispatch functions
  to remote TensorFlow servers.

  Currently, this class is not supported for standalone use. It should be used
  in conjunction with a `tf.distribute` strategy that is designed to work with
  it. The `ClusterCoordinator` class currently only works with
  `tf.distribute.experimental.ParameterServerStrategy`.

  __The `schedule`/`join` APIs__

  The most important APIs provided by this class are the `schedule`/`join`
  pair. The `schedule` API is non-blocking in that it queues a `tf.function`
  and returns a `RemoteValue` immediately. The queued functions will be
  dispatched to remote workers in background threads and their `RemoteValue`s
  will be filled asynchronously. Since `schedule` doesn't require worker
  assignment, the `tf.function` passed in can be executed on any available
  worker. If the worker it is executed on becomes unavailable before its
  completion, it will be migrated to another worker. Because of this fact, and
  because function execution is not atomic, a function may be executed more
  than once.

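  A minimal sketch of the `schedule`/`join` pair (assuming `strategy` is a
  `tf.distribute.experimental.ParameterServerStrategy` and `step_fn` is a
  `tf.function` to be scheduled):

  ```python
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy)
  for _ in range(total_steps):  # `total_steps` is user-defined
    coordinator.schedule(step_fn)  # non-blocking; returns a `RemoteValue`
  coordinator.join()  # blocks until all scheduled functions finish
  ```
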
  __Handling Task Failure__

  This class, when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers cannot be
  reached from the coordinator for any reason, the training progress continues
  to be made with the remaining workers. Upon recovery of a failed worker, it
  will be added back for function execution after datasets created by
  `create_per_worker_dataset` are re-built on it.

  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
  `schedule`, `join` or `done`. In this case, in addition to bringing back the
  failed parameter server, users should restart the coordinator so that it
  reconnects to workers and parameter servers, re-creates the variables, and
  loads checkpoints. If the coordinator fails, after the user brings it back,
  the program will automatically connect to workers and parameter servers, and
  continue the progress from a checkpoint.

  It is thus essential that in the user's program, a checkpoint file is
  periodically saved, and restored at the start of the program. If a
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps are
  needed before training completion.

  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
  example usage of this API.

  This is currently under development, and the API as well as implementation
  are subject to change.
  """

  def __new__(cls, strategy):
    # `ClusterCoordinator` is kept as a single instance per given `Strategy`.
    # TODO(rchao): Needs a lock for thread-safety
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a `ClusterCoordinator` instance.

    Args:
      strategy: a supported `tf.distribute.Strategy` object. Currently, only
        `tf.distribute.experimental.ParameterServerStrategy` is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    if not isinstance(strategy,
                      parameter_server_strategy_v2.ParameterServerStrategyV2):
      raise ValueError(
          "Only `tf.distribute.experimental.ParameterServerStrategy` "
          "is supported to work with "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
          "currently.")
    self._strategy = strategy
    self.strategy.extended._used_with_coordinator = True
    self._cluster = Cluster(strategy)

  def __del__(self):
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
    return self._strategy

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the `fn` which will be
    executed later and returns a
    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
    `fetch` can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait
    for all scheduled functions to finish.

    `schedule` guarantees that `fn` will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since a worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
    guarantees that in those events, the function will eventually be executed
    on any worker that is available.

    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. When this
    happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled, and reschedule the
    corresponding function if needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support for worker assignment for function
    execution, or priority of the workers.

    `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed
    into `fn` as-is. Currently,
    `tf.distribute.experimental.coordinator.RemoteValue` is not supported as
    input `args` or `kwargs`.

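    Example (a minimal sketch; assumes `coordinator` and the variable `v` read
    by `step_fn` have been set up as in the examples elsewhere in this module):

    ```python
    @tf.function
    def step_fn():
      return v.read_value()

    remote_value = coordinator.schedule(step_fn)
    coordinator.join()
    print(remote_value.fetch())
    ```
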
    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously. Regular Python functions are not supported
        for scheduling.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.RemoteValue` object that
      represents the output of the function scheduled.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
          " only accepts a `tf.function` or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    # `schedule` needs to be called within the `strategy.scope()`.
    with self.strategy.scope():
      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
      return remote_value

  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may not have been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect whether
    they have executed, failed, or been cancelled. If some that have been
    cancelled need to be rescheduled, users should call `schedule` with the
    function again.

    When `join` returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    self._cluster.join()

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `done` will fail by
    raising any one of those errors.

    When `done` returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.
    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    return self._cluster.done()

  def create_per_worker_dataset(self, dataset_fn):
    """Create datasets on workers by calling `dataset_fn` on worker devices.

    This creates the dataset generated by `dataset_fn` on each worker and
    returns an object that represents the collection of those individual
    datasets. Calling `iter` on such a collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling `next` on a `PerWorkerValues` of iterators is unsupported. The
    iterator is meant to be passed as an argument into
    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The `next` method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the `schedule` method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except that they
    may be shuffled differently if they contain a `dataset.shuffle` operation
    and a random seed is not set. Because of this, we also recommend that the
    datasets be repeated indefinitely and that a finite number of steps be
    scheduled, instead of relying on the `OutOfRangeError` from a dataset.

    Example:

    ```python
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=...)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy=strategy)

    @tf.function
    def worker_fn(iterator):
      return next(iterator)

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
    assert remote_value.fetch() == 3
    ```

    NOTE: A known limitation is that `tf.data.Options` is ignored in datasets
    created by `create_per_worker_dataset`.

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets. Calling `iter` on this object returns a
      `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are on the workers).
    """
    input_workers = input_lib.InputWorkers([
        (w.device_name, [w.device_name]) for w in self._cluster.workers
    ])

    return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)

  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
    """Synchronously create resources on the workers.

    The resources are represented by
    `tf.distribute.experimental.coordinator.RemoteValue`s.

    Args:
      fn: The function to be dispatched to all workers for execution
        asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
      objects.
    """
    results = []
    for w in self._cluster.workers:
      results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
    return PerWorkerValues(tuple(results))

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of
    `RemoteValue`s. If the results are not ready, this call blocks until they
    are.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is a structure of
        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
        called on the individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.

    Returns:
      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
      structure of `tf.distribute.experimental.coordinator.RemoteValue`s, this
      returns the fetched values immediately if they are available, or blocks
      the call until they are available, and returns the fetched values with
      the same structure. If `val` is of any other type, it is returned as-is.
    """

    def _maybe_fetch(val):
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    # TODO(yuefengz): we should fetch values in a batch.
    return nest.map_structure(_maybe_fetch, val)


class _PerWorkerDistributedDataset(object):
  """Represents worker-distributed datasets created from a dataset function."""

  def __init__(self, dataset_fn, input_workers, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      input_workers: an `InputWorkers` object.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._input_workers = input_workers
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If `_PerWorkerDistributedDataset.__iter__` is called multiple times on
    # the same object, it should create and register the resource only once.
    # Object identity is used to distinguish different iterator resources.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting the type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (  # pylint: disable=protected-access
          iterator_ops.IteratorSpec(
              self._dataset_fn.structured_outputs.element_spec))
    return _PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    raise NotImplementedError("Passing `AsyncDistributedDataset` to a "
                              "tf.function is not supported.")


class _PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over an `AsyncDistributedIterator` "
                              "is not supported right now.")


def _extract_failed_ps_instances(err_msg):
  """Returns a set of potentially failing ps instances from an error message."""
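  # Illustration (hypothetical message): an error string mentioning
  # "/job:ps/replica:0/task:0" and "/job:ps/replica:0/task:2" yields {0, 2}.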
  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
  return set(int(t.split(":")[-1]) for t in tasks)


def _is_ps_failure(error):
  """Whether the error is considered a parameter server failure."""
  return (isinstance(error, errors.UnavailableError) and
          _RPC_ERROR_FROM_PS in str(error))


def _is_worker_failure(error):
  """Whether the error is considered a worker failure."""
  if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
    return False
  if _RPC_ERROR_FROM_PS in str(error):
    return False

  # TODO(haoyuzhang): Consider using a special status code if an error from a
  # remote is derived from RPC errors originating from other hosts.
  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
    return True

  # The following error could happen when the remote task fails and restarts
  # in a very short interval during which no RPCs were exchanged to detect the
  # failure. In that case, gRPC allows a channel (which is different from a
  # connection) to be reused for a replaced server listening to the same
  # address.
  if isinstance(error, errors.InvalidArgumentError):
    if ("unknown device" in str(error) or
        "Unable to find the relevant tensor remote_handle" in str(error)):
      # TODO(b/159961667): Fix the "Unable to find the relevant tensor
      # remote_handle" part.
      return True

  # TODO(b/162541228): The following 2 types of errors are very rare and only
  # observed in large-scale testing. The types of errors should be reduced.
  # This could happen when function registration fails. In the observed cases
  # this only happens to dataset-related functions.
  if isinstance(error, errors.NotFoundError):
    if ("is neither a type of a primitive operation nor a name of a function "
        "registered" in str(error)):
      return True

  return False