1# Lint as: python3 2# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Module for `ClusterCoordinator` and relevant cluster-worker related library. 17 18This is currently under development and the API is subject to change. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import contextlib 26import enum 27import functools 28import os 29import re 30import threading 31import time 32import weakref 33from six.moves import queue 34 35from tensorflow.python.data.ops import iterator_ops 36from tensorflow.python.distribute import input_lib 37from tensorflow.python.distribute import parameter_server_strategy_v2 38from tensorflow.python.distribute.coordinator import metric_utils 39from tensorflow.python.eager import cancellation 40from tensorflow.python.eager import context 41from tensorflow.python.eager import def_function 42from tensorflow.python.eager import executor 43from tensorflow.python.eager import function as tf_function 44from tensorflow.python.framework import errors 45from tensorflow.python.framework import func_graph 46from tensorflow.python.framework import ops 47from tensorflow.python.ops import variable_scope 48from tensorflow.python.platform import tf_logging as logging 49from tensorflow.python.util import nest 50from 
tensorflow.python.util.tf_export import tf_export 51 52# Maximum time for failed worker to come back is 1 hour 53_WORKER_MAXIMUM_RECOVERY_SEC = 3600 54 55# Maximum size for queued closures, "infinite" if set to 0. 56# When the maximum queue size is reached, further schedule calls will become 57# blocking until some previously queued closures are executed on workers. 58# Note that using an "infinite" queue size can take a non-trivial portion of 59# memory, and even lead to coordinator OOM. Modify the size to a smaller value 60# for coordinator with constrained memory resource (only recommended for 61# advanced users). Also used in unit tests to ensure the correctness when the 62# queue is full. 63_CLOSURE_QUEUE_MAX_SIZE = 256 * 1024 64 65# RPC error message from PS 66_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" 67 68# InvalidArgumentError (unknown device) will not have "GRPC error..." string. 69_JOB_WORKER_STRING_IDENTIFIER = "/job:worker" 70 71 72class _RemoteValueStatus(enum.Enum): 73 """The status of a `RemoteValue` object. 74 75 A `RemoteValue` object can have three states: 76 1) not ready: no value, no non-retryable error and not aborted; 77 2) aborted: i.e. the execution of function was aborted because of task 78 failure, but can be retried; 79 3) ready: i.e. has value or has non-tryable error; 80 81 The initial state of a `RemoteValue` is "not ready". When its corresponding 82 closure has 83 been executed at least once, it will become aborted or ready. The state 84 transitions are: 85 1) not ready -> 2) aborted: 86 when the corresponding closure is aborted due to worker failure, and the 87 worker failure is not immediately handled. 88 1) not ready -> 3) ready: 89 when the corresponding closure has been executed successfully. 90 2) aborted -> 3) ready: 91 when the `RemoteValue` is rebuilt by rerunning the corresponding closure 92 and the closure has been executed successfully. 
93 3) ready -> 2) aborted: 94 when the corresponding closure had been executed successfully but later 95 the corresponding remote worker failed. This is currently only implemented 96 for resource `RemoteValue` like iterators. 97 """ 98 NOT_READY = "NOT_READY" 99 ABORTED = "ABORTED" 100 READY = "READY" 101 102 103@tf_export("distribute.experimental.coordinator.RemoteValue", v1=[]) 104class RemoteValue(object): 105 """An asynchronously available value of a scheduled function. 106 107 This class is used as the return value of 108 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` where 109 the underlying value becomes available at a later time once the function has 110 been executed. 111 112 Using `tf.distribute.experimental.coordinator.RemoteValue` as an input to 113 a subsequent function scheduled with 114 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` is 115 currently not supported. 116 117 Example: 118 119 ```python 120 strategy = tf.distribute.experimental.ParameterServerStrategy( 121 cluster_resolver=...) 122 coordinator = ( 123 tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)) 124 125 with strategy.scope(): 126 v1 = tf.Variable(initial_value=0.0) 127 v2 = tf.Variable(initial_value=1.0) 128 129 @tf.function 130 def worker_fn(): 131 v1.assign_add(0.1) 132 v2.assign_sub(0.2) 133 return v1.read_value() / v2.read_value() 134 135 result = coordinator.schedule(worker_fn) 136 # Note that `fetch()` gives the actual result instead of a `tf.Tensor`. 137 assert result.fetch() == 0.125 138 139 for _ in range(10): 140 # `worker_fn` will be run on arbitrary workers that are available. The 141 # `result` value will be available later. 142 result = coordinator.schedule(worker_fn) 143 ``` 144 """ 145 146 def fetch(self): 147 """Wait for the result of `RemoteValue` to be ready and return the result. 148 149 This makes the value concrete by copying the remote value to local. 
150 151 Returns: 152 The actual output of the `tf.function` associated with this `RemoteValue`, 153 previously by a 154 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` call. 155 This can be a single value, or a structure of values, depending on the 156 output of the `tf.function`. 157 158 Raises: 159 tf.errors.CancelledError: If the function that produces this `RemoteValue` 160 is aborted or cancelled due to failure. 161 """ 162 raise NotImplementedError("Must be implemented in subclasses.") 163 164 165class RemoteValueImpl(RemoteValue): 166 """Implementation of `RemoteValue`.""" 167 168 def __init__(self, closure, type_spec): # pylint: disable=super-init-not-called 169 """Initializes a `RemoteValueImpl`. 170 171 Args: 172 closure: The closure from which the `RemoteValue` is created. 173 type_spec: The type spec for this `RemoteValue` which is used to trace 174 functions that take this `RemoteValue` as input. 175 """ 176 self._closure = closure 177 self._type_spec = type_spec 178 self._values = None 179 self._fetched_numpys = None 180 self._error = None 181 self._status_available_event = threading.Event() 182 self._status = _RemoteValueStatus.NOT_READY 183 184 def _set_aborted(self): 185 self._status = _RemoteValueStatus.ABORTED 186 self._values = None 187 self._error = None 188 189 # Wake up any waiting thread and clear the event. 190 self._status_available_event.set() 191 192 def _rebuild_on(self, worker): 193 self._status_available_event.clear() 194 # TODO(yuefengz): we may need to rebuild its inputs as well. 
195 self._closure.execute_on(worker) 196 197 def _set_values(self, tensors): 198 self._status = _RemoteValueStatus.READY 199 self._values = tensors 200 self._error = None 201 self._status_available_event.set() 202 203 def _set_error(self, exception): 204 self._status = _RemoteValueStatus.READY 205 self._values = None 206 self._error = exception 207 self._status_available_event.set() 208 209 def _get_values(self): 210 self._status_available_event.wait() 211 return self._values 212 213 def _get_error(self): 214 self._status_available_event.wait() 215 return self._error 216 217 def fetch(self): 218 self._status_available_event.wait() 219 if self._status is _RemoteValueStatus.ABORTED: 220 raise errors.CancelledError( 221 None, None, 222 "The corresponding function is aborted. Please reschedule the " 223 "function.") 224 if self._error is not None: 225 raise self._error 226 if self._fetched_numpys is None: 227 self._fetched_numpys = nest.map_structure( 228 lambda x: x.numpy() if hasattr(x, "numpy") else x, self._values) 229 return self._fetched_numpys 230 231 232class InputError(Exception): 233 234 def __init__(self, original_exception): 235 message = ("Input has an error, the original exception is %r, " 236 "error message is %s." % 237 (original_exception, str(original_exception))) 238 super().__init__(message) 239 240 241def _maybe_rebuild_remote_values(worker, structure): 242 """Attempts to return errors from `RemoteValue`s. 
Rebuilds them if needed.""" 243 errors_in_structure = [] 244 245 def _get_error(val): 246 if isinstance(val, RemoteValue): 247 if val._status is _RemoteValueStatus.ABORTED: # pylint: disable=protected-access 248 try: 249 with worker.failure_handler.wait_on_failure( 250 on_recovery_fn=functools.partial(val._rebuild_on, worker), # pylint: disable=protected-access 251 worker_device_name=worker.device_name): 252 val._rebuild_on(worker) # pylint: disable=protected-access 253 except Exception as e: # pylint: disable=broad-except 254 val._set_error(e) # pylint: disable=protected-access 255 256 error = val._get_error() # pylint: disable=protected-access 257 if error: 258 errors_in_structure.append(error) 259 260 nest.map_structure(_get_error, structure) 261 if errors_in_structure: 262 return errors_in_structure[0] 263 else: 264 return None 265 266 267def _maybe_get_remote_value(val): 268 """Gets the value of `val` if it is a `RemoteValue`.""" 269 if isinstance(val, RemoteValue): 270 error = val._get_error() # pylint: disable=protected-access 271 if error: 272 raise AssertionError( 273 "RemoteValue doesn't have a value because it has errors.") 274 else: 275 return val._get_values() # pylint: disable=protected-access 276 else: 277 return val 278 279 280def _maybe_as_type_spec(val): 281 if isinstance(val, RemoteValue): 282 if val._type_spec is None: # pylint: disable=protected-access 283 raise ValueError("Output of a scheduled function that is not " 284 "tf.function cannot be the input of another function.") 285 return val._type_spec # pylint: disable=protected-access 286 else: 287 return val 288 289 290@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[]) 291class PerWorkerValues(object): 292 """A container that holds a list of values, one value per worker. 
293 294 `tf.distribute.experimental.coordinator.PerWorkerValues` contains a collection 295 of values, where each of the values is located on its corresponding worker, 296 and upon being used as one of the `args` or `kwargs` of 297 `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule()`, the 298 value specific to a worker will be passed into the function being executed at 299 that corresponding worker. 300 301 Currently, the only supported path to create an object of 302 `tf.distribute.experimental.coordinator.PerWorkerValues` is through calling 303 `iter` on a `ClusterCoordinator.create_per_worker_dataset`-returned 304 distributed dataset instance. The mechanism to create a custom 305 `tf.distribute.experimental.coordinator.PerWorkerValues` is not yet supported. 306 """ 307 308 def __init__(self, values): 309 self._values = tuple(values) 310 311 312def _select_worker_slice(worker_id, structured): 313 """Selects the worker slice of each of the items in `structured`.""" 314 315 def _get(x): 316 return x._values[worker_id] if isinstance(x, PerWorkerValues) else x # pylint: disable=protected-access 317 318 return nest.map_structure(_get, structured) 319 320 321def _disallow_remote_value_as_input(structured): 322 """Raises if any element of `structured` is a RemoteValue.""" 323 324 def _raise_if_remote_value(x): 325 if isinstance(x, RemoteValue): 326 raise ValueError( 327 "`tf.distribute.experimental.coordinator.RemoteValue` used " 328 "as an input to scheduled function is not yet " 329 "supported.") 330 331 nest.map_structure(_raise_if_remote_value, structured) 332 333 334class Closure(object): 335 """Hold a function to be scheduled and its arguments.""" 336 337 def __init__(self, function, cancellation_mgr, args=None, kwargs=None): 338 if not callable(function): 339 raise ValueError("Function passed to `ClusterCoordinator.schedule` must " 340 "be a callable object.") 341 self._args = args or () 342 self._kwargs = kwargs or {} 343 344 
_disallow_remote_value_as_input(self._args) 345 _disallow_remote_value_as_input(self._kwargs) 346 347 if isinstance(function, def_function.Function): 348 replica_args = _select_worker_slice(0, self._args) 349 replica_kwargs = _select_worker_slice(0, self._kwargs) 350 351 # Note: no need to handle function registration failure since this kind of 352 # failure will not raise exceptions as designed in the runtime. The 353 # coordinator has to rely on subsequent operations that raise to catch 354 # function registration failure. 355 356 # Record the function tracing overhead. Note that we pass in the tracing 357 # count of the def_function.Function as a state tracker, so that metrics 358 # will only record the time for actual function tracing (i.e., excluding 359 # function cache lookups). 360 with metric_utils.monitored_timer( 361 "function_tracing", state_tracker=function._get_tracing_count): # pylint: disable=protected-access 362 self._concrete_function = function.get_concrete_function( 363 *nest.map_structure(_maybe_as_type_spec, replica_args), 364 **nest.map_structure(_maybe_as_type_spec, replica_kwargs)) 365 elif isinstance(function, tf_function.ConcreteFunction): 366 self._concrete_function = function 367 368 if hasattr(self, "_concrete_function"): 369 # If we have a concrete function, we get to retrieve the output type spec 370 # via the structured_output. 371 output_type_spec = func_graph.convert_structure_to_signature( 372 self._concrete_function.structured_outputs) 373 self._function = cancellation_mgr.get_cancelable_function( 374 self._concrete_function) 375 else: 376 # Otherwise (i.e. what is passed in is a regular python function), we have 377 # no such information. 
378 output_type_spec = None 379 self._function = function 380 381 self.output_remote_value = RemoteValueImpl(self, output_type_spec) 382 383 def mark_cancelled(self): 384 self.output_remote_value._set_error( # pylint: disable=protected-access 385 errors.CancelledError( 386 None, None, "The corresponding function is " 387 "cancelled. Please reschedule the function.")) 388 389 def execute_on(self, worker): 390 """Executes the closure on the given worker. 391 392 Args: 393 worker: a `Worker` object. 394 """ 395 replica_args = _select_worker_slice(worker.worker_index, self._args) 396 replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs) 397 398 e = ( 399 _maybe_rebuild_remote_values(worker, replica_args) or 400 _maybe_rebuild_remote_values(worker, replica_kwargs)) 401 if e: 402 if not isinstance(e, InputError): 403 e = InputError(e) 404 self.output_remote_value._set_error(e) # pylint: disable=protected-access 405 return 406 407 with ops.device(worker.device_name): 408 with context.executor_scope(worker.executor): 409 with metric_utils.monitored_timer("closure_execution"): 410 output_values = self._function( 411 *nest.map_structure(_maybe_get_remote_value, replica_args), 412 **nest.map_structure(_maybe_get_remote_value, replica_kwargs)) 413 self.output_remote_value._set_values(output_values) # pylint: disable=protected-access 414 415 416class _CoordinatedClosureQueue(object): 417 """Manage a queue of closures, inflight count and errors from execution. 418 419 This class is thread-safe. 420 """ 421 422 def __init__(self): 423 # `self._inflight_closure_count` only tracks the number of inflight closures 424 # that are "in generation". Once an error occurs, error generation is 425 # incremented and all subsequent arriving closures (from inflight) are 426 # considered "out of generation". 
427 self._inflight_closure_count = 0 428 429 self._queue_lock = threading.Lock() 430 431 # Condition indicating that all pending closures (either queued or inflight) 432 # have been processed, failed, or cancelled. 433 self._stop_waiting_condition = threading.Condition(self._queue_lock) 434 435 # Condition indicating that an item becomes available in queue (not empty). 436 self._closures_queued_condition = threading.Condition(self._queue_lock) 437 self._should_process_closures = True 438 439 # Condition indicating that a queue slot becomes available (not full). 440 # Note that even with "infinite" queue size, there is still a "practical" 441 # size limit for the queue depending on host memory capacity, and thus the 442 # queue will eventually become full with a lot of enqueued closures. 443 self._queue_free_slot_condition = threading.Condition(self._queue_lock) 444 445 # Condition indicating there is no inflight closures. 446 self._no_inflight_closure_condition = threading.Condition(self._queue_lock) 447 448 # Use to cancel in-flight closures. 449 self._cancellation_mgr = cancellation.CancellationManager() 450 451 if _CLOSURE_QUEUE_MAX_SIZE <= 0: 452 logging.warning( 453 "In a `ClusterCoordinator`, creating an infinite closure queue can " 454 "consume a significant amount of memory and even lead to OOM.") 455 self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE) 456 self._error = None 457 458 # The following is a lock to make sure when `wait` is called and before it 459 # returns no `put` can be executed during this period. It is because `wait` 460 # won't know what to do with newly put closures. This lock adds an cutoff 461 # for `wait` so that closures put into the queue while waiting would not be 462 # taken responsible by this `wait`. 463 # 464 # We cannot reuse the `self._queue_lock` since when `wait` waits for a 465 # condition, the `self._queue_lock` will be released. 
466 # 467 # We don't use a reader/writer's lock on purpose to reduce the complexity 468 # of the code. 469 self._put_wait_lock = threading.Lock() 470 471 def stop(self): 472 with self._queue_lock: 473 self._should_process_closures = False 474 self._closures_queued_condition.notifyAll() 475 476 def _cancel_all_closures(self): 477 """Clears the queue and sets remaining closures cancelled error. 478 479 This method expects self._queue_lock to be held prior to entry. 480 """ 481 self._cancellation_mgr.start_cancel() 482 while self._inflight_closure_count > 0: 483 self._no_inflight_closure_condition.wait() 484 while True: 485 try: 486 closure = self._queue.get(block=False) 487 self._queue_free_slot_condition.notify() 488 closure.mark_cancelled() 489 except queue.Empty: 490 break 491 # The cancellation manager cannot be reused once cancelled. After all 492 # closures (queued or inflight) are cleaned up, recreate the cancellation 493 # manager with clean state. 494 # Note on thread-safety: this is triggered when one of theses 495 # ClusterCoordinator APIs are called: `schedule`, `wait`, and `done`. At the 496 # same time, no new closures can be constructed (which reads the 497 # _cancellation_mgr to get cancellable functions). 498 self._cancellation_mgr = cancellation.CancellationManager() 499 500 def _raise_if_error(self): 501 """Raises the error if one exists. 502 503 If an error exists, cancel the closures in queue, raises it, and clear 504 the error. 505 506 This method expects self._queue_lock to be held prior to entry. 507 """ 508 if self._error: 509 logging.error("Start cancelling closures due to error %r: %s", 510 self._error, self._error) 511 self._cancel_all_closures() 512 try: 513 raise self._error # pylint: disable=raising-bad-type 514 finally: 515 self._error = None 516 517 def put(self, closure): 518 """Put a closure into the queue for later execution. 
519 520 If `mark_failed` was called before `put`, the error from the first 521 invocation of `mark_failed` will be raised. 522 523 Args: 524 closure: The `Closure` to put into the queue. 525 """ 526 with self._put_wait_lock, self._queue_lock: 527 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 528 self._queue.put(closure, block=False) 529 self._raise_if_error() 530 self._closures_queued_condition.notify() 531 532 def get(self, timeout=None): 533 """Return a closure from the queue to be executed.""" 534 with self._queue_lock: 535 while self._queue.empty() and self._should_process_closures: 536 if not self._closures_queued_condition.wait(timeout=timeout): 537 return None 538 if not self._should_process_closures: 539 return None 540 closure = self._queue.get(block=False) 541 self._queue_free_slot_condition.notify() 542 self._inflight_closure_count += 1 543 return closure 544 545 def mark_finished(self): 546 """Let the queue know that a closure has been successfully executed.""" 547 with self._queue_lock: 548 if self._inflight_closure_count < 1: 549 raise AssertionError("There is no inflight closures to mark_finished.") 550 self._inflight_closure_count -= 1 551 if self._inflight_closure_count == 0: 552 self._no_inflight_closure_condition.notifyAll() 553 if self._queue.empty() and self._inflight_closure_count == 0: 554 self._stop_waiting_condition.notifyAll() 555 556 def put_back(self, closure): 557 """Put the closure back into the queue as it was not properly executed.""" 558 with self._queue_lock: 559 if self._inflight_closure_count < 1: 560 raise AssertionError("There is no inflight closures to put_back.") 561 if self._error: 562 closure.mark_cancelled() 563 else: 564 self._queue_free_slot_condition.wait_for(lambda: not self._queue.full()) 565 self._queue.put(closure, block=False) 566 self._closures_queued_condition.notify() 567 self._inflight_closure_count -= 1 568 if self._inflight_closure_count == 0: 569 
self._no_inflight_closure_condition.notifyAll() 570 571 def wait(self, timeout=None): 572 """Wait for all closures to be finished before returning. 573 574 If `mark_failed` was called before or during `wait`, the error from the 575 first invocation of `mark_failed` will be raised. 576 577 Args: 578 timeout: A float specifying a timeout for the wait in seconds. 579 580 Returns: 581 True unless the given timeout expired, in which case it returns False. 582 """ 583 with self._put_wait_lock, self._queue_lock: 584 while (not self._error and 585 (not self._queue.empty() or self._inflight_closure_count > 0)): 586 if not self._stop_waiting_condition.wait(timeout=timeout): 587 return False 588 self._raise_if_error() 589 return True 590 591 def mark_failed(self, e): 592 """Sets error and unblocks any wait() call.""" 593 with self._queue_lock: 594 # TODO(yuefengz): maybe record all failure and give users more 595 # information? 596 if self._inflight_closure_count < 1: 597 raise AssertionError("There is no inflight closures to mark_failed.") 598 if self._error is None: 599 self._error = e 600 self._inflight_closure_count -= 1 601 if self._inflight_closure_count == 0: 602 self._no_inflight_closure_condition.notifyAll() 603 self._stop_waiting_condition.notifyAll() 604 605 def done(self): 606 """Returns true if the queue is empty and there is no inflight closure. 607 608 If `mark_failed` was called before `done`, the error from the first 609 invocation of `mark_failed` will be raised. 
610 """ 611 with self._queue_lock: 612 self._raise_if_error() 613 return self._queue.empty() and self._inflight_closure_count == 0 614 615 616class WorkerPreemptionHandler(object): 617 """Handles worker preemptions.""" 618 619 def __init__(self, server_def, cluster): 620 self._server_def = server_def 621 self._cluster = cluster 622 self._cluster_update_lock = threading.Lock() 623 self._cluster_due_for_update_or_finish = threading.Event() 624 self._worker_up_cond = threading.Condition(self._cluster_update_lock) 625 self._should_preemption_thread_run = True 626 threading.Thread(target=self._preemption_handler, 627 name="WorkerPreemptionHandler", 628 daemon=True).start() 629 630 def stop(self): 631 """Ensure the worker preemption thread is closed.""" 632 self._should_preemption_thread_run = False 633 with self._cluster_update_lock: 634 self._cluster_due_for_update_or_finish.set() 635 636 def _validate_preemption_failure(self, e): 637 """Validates that the given exception represents worker preemption.""" 638 if _is_worker_failure(e): 639 return 640 raise e 641 642 @contextlib.contextmanager 643 def wait_on_failure(self, 644 on_failure_fn=None, 645 on_recovery_fn=None, 646 worker_device_name="(unknown)"): 647 """Catches worker preemption error and wait until failed workers are back. 648 649 Args: 650 on_failure_fn: an optional function to run if preemption happens. 651 on_recovery_fn: an optional function to run when a worker is recovered 652 from preemption. 653 worker_device_name: the device name of the worker instance that is passing 654 through the failure. 655 656 Yields: 657 None. 658 """ 659 try: 660 yield 661 except errors.OpError as e: 662 # If the error is due to temporary connectivity issues between worker and 663 # ps, put back closure, ignore error and do not mark worker as failure. 
664 if self._cluster._record_and_ignore_transient_ps_failure(e): # pylint: disable=protected-access 665 if on_failure_fn: 666 on_failure_fn() 667 return 668 669 self._validate_preemption_failure(e) 670 logging.error("Worker %s failed with error: %s", worker_device_name, e) 671 if on_failure_fn: 672 on_failure_fn() 673 674 with self._cluster_update_lock: 675 self._cluster_due_for_update_or_finish.set() 676 self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC) 677 logging.info("Worker %s has been recovered.", worker_device_name) 678 679 if on_recovery_fn: 680 with self.wait_on_failure( 681 on_recovery_fn=on_recovery_fn, 682 worker_device_name=worker_device_name): 683 on_recovery_fn() 684 685 def _preemption_handler(self): 686 """A loop that handles preemption. 687 688 This loop waits for signal of worker preemption and upon worker preemption, 689 it waits until all workers are back and updates the cluster about the 690 restarted workers. 691 """ 692 while True: 693 self._cluster_due_for_update_or_finish.wait() 694 if not self._should_preemption_thread_run: 695 break 696 697 with self._cluster_update_lock: 698 try: 699 # TODO(haoyuzhang): support partial cluster recovery 700 logging.info("Cluster now being recovered.") 701 context.context().update_server_def(self._server_def) 702 703 # Cluster updated successfully, clear the update signal, and notify 704 # all workers that they are recovered from failure. 705 logging.info("Cluster successfully recovered.") 706 self._worker_up_cond.notify_all() 707 self._cluster_due_for_update_or_finish.clear() 708 except Exception as e: # pylint: disable=broad-except 709 self._validate_preemption_failure(e) 710 # NOTE: Since the first RPC (GetStatus) of update_server_def is 711 # currently blocking by default, error should only happen if: 712 # (1) More workers failed while waiting for the previous workers to 713 # come back; 714 # (2) Worker failed when exchanging subsequent RPCs after the first 715 # RPC returns. 
716 # Consider adding backoff retry logic if we see the error logged 717 # too frequently. 718 logging.error("Cluster update failed with error: %s. Retrying...", e) 719 720 721class Worker(object): 722 """A worker in a cluster. 723 724 Attributes: 725 worker_index: The index of the worker in the cluster. 726 device_name: The device string of the worker, e.g. "/job:worker/task:1". 727 executor: The worker's executor for remote function execution. 728 failure_handler: The failure handler used to handler worker preemption 729 failure. 730 """ 731 732 def __init__(self, worker_index, device_name, cluster): 733 self.worker_index = worker_index 734 self.device_name = device_name 735 self.executor = executor.new_executor(enable_async=False) 736 self.failure_handler = cluster.failure_handler 737 self._cluster = cluster 738 self._resource_remote_value_refs = [] 739 self._should_worker_thread_run = True 740 741 # Worker threads need to start after `Worker`'s initialization. 742 threading.Thread(target=self._process_queue, 743 name="WorkerClosureProcessingLoop-%d" % self.worker_index, 744 daemon=True).start() 745 746 def stop(self): 747 """Ensure the worker thread is closed.""" 748 self._should_worker_thread_run = False 749 750 def _set_resources_aborted(self): 751 # TODO(yuefengz): maybe we can query whether a tensor is valid or not 752 # instead of marking a tensor aborted? 
753 for weakref_resource in self._resource_remote_value_refs: 754 resource = weakref_resource() 755 if resource: 756 resource._set_aborted() # pylint: disable=protected-access 757 758 def _set_dead(self): 759 raise NotImplementedError("_set_dead is not implemented.") 760 761 def _process_closure(self, closure): 762 """Runs a closure with preemption handling.""" 763 assert closure is not None 764 try: 765 with self._cluster.failure_handler.wait_on_failure( 766 on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure), # pylint: disable=protected-access 767 on_recovery_fn=self._set_resources_aborted, 768 worker_device_name=self.device_name): 769 closure.execute_on(self) 770 # TODO(yuefengz): we don't have to materialize results every step. 771 with metric_utils.monitored_timer("remote_value_fetch"): 772 closure.output_remote_value.fetch() 773 self._cluster._closure_queue.mark_finished() # pylint: disable=protected-access 774 except Exception as e: # pylint: disable=broad-except 775 # Avoid logging the derived cancellation error 776 if not isinstance(e, errors.CancelledError): 777 logging.error( 778 "/job:worker/task:%d encountered the following error when " 779 "processing closure: %r:%s", self.worker_index, e, e) 780 closure.output_remote_value._set_error(e) # pylint: disable=protected-access 781 self._cluster._closure_queue.mark_failed(e) # pylint: disable=protected-access 782 783 def _maybe_delay(self): 784 """Delay if corresponding env vars are set.""" 785 # If the following two env vars variables are set. Scheduling for workers 786 # will start in a staggered manner. Worker i will wait for 787 # `TF_COORDINATOR_SCHEDULE_START_DELAY` * i seconds, not exceeding 788 # `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX`. 
789 delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0")) 790 delay_cap = int( 791 os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0")) 792 if delay_cap: 793 delay_secs = min(delay_secs * self.worker_index, delay_cap) 794 if delay_secs > 0: 795 logging.info("Worker %d sleeping for %d seconds before running function", 796 self.worker_index, delay_secs) 797 time.sleep(delay_secs) 798 799 def _process_queue(self): 800 """Function running in a thread to process closure queues.""" 801 self._maybe_delay() 802 while self._should_worker_thread_run: 803 closure = self._cluster._closure_queue.get() # pylint: disable=protected-access 804 if not self._should_worker_thread_run or closure is None: 805 return 806 self._process_closure(closure) 807 808 def _create_resource(self, function, args=None, kwargs=None): 809 """Synchronously creates a per-worker resource represented by a `RemoteValue`. 810 811 Args: 812 function: the resource function to be run remotely. It should be a 813 `tf.function`, a concrete function or a Python function. 814 args: positional arguments to be passed to the function. 815 kwargs: keyword arguments to be passed to the function. 816 817 Returns: 818 one or several RemoteValue objects depending on the function return 819 values. 820 """ 821 # Some notes about the concurrency: currently all the activities related to 822 # the same worker such as creating resources, setting resources' aborted 823 # status, and executing closures happen on the same thread. This allows us 824 # to have simpler logic of concurrency. 825 closure = Closure( 826 function, 827 self._cluster._closure_queue._cancellation_mgr, # pylint: disable=protected-access 828 args=args, 829 kwargs=kwargs) 830 resource_remote_value = closure.output_remote_value 831 self._register_resource(resource_remote_value) 832 833 # The following is a short-term solution to lazily create resources in 834 # parallel. 
    # TODO(b/160343165): we should create resources eagerly, i.e. schedule the
    # resource creation function as soon as users call this method.
    # Marking the RemoteValue aborted makes the worker thread (re)build the
    # resource lazily before executing closures that depend on it.
    resource_remote_value._set_aborted()  # pylint: disable=protected-access
    return resource_remote_value

  def _register_resource(self, resource_remote_value):
    """Tracks `resource_remote_value` so it can be re-marked on failure.

    Only a weak reference is kept, so registration does not extend the
    resource's lifetime.

    Args:
      resource_remote_value: the `RemoteValue` representing the resource.

    Raises:
      ValueError: if `resource_remote_value` is not a `RemoteValue`.
    """
    if not isinstance(resource_remote_value, RemoteValue):
      raise ValueError("Resource being registered is not of type "
                       "`tf.distribute.experimental.coordinator.RemoteValue`.")
    self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))


class Cluster(object):
  """A cluster with workers.

  We assume all function errors are fatal and based on this assumption our
  error reporting logic is:
  1) Both `schedule` and `join` can raise a non-retryable error which is the
  first error seen by the coordinator from any previously scheduled functions.
  2) When an error is raised, there is no guarantee on how many previously
  scheduled functions have been executed; functions that have not been executed
  will be thrown away and marked as cancelled.
  3) After an error is raised, the internal state of error will be cleared.
  I.e. functions can continue to be scheduled and subsequent calls of `schedule`
  or `join` will not raise the same error again.

  Attributes:
    failure_handler: The failure handler used to handle worker preemption
      failure.
    workers: a list of `Worker` objects in the cluster.
  """

  def __init__(self, strategy):
    """Initializes the cluster instance."""

    self._num_workers = strategy._num_workers
    self._num_ps = strategy._num_ps

    # Ignore PS failures reported by workers due to transient connection errors.
    # Transient connectivity issues between workers and PS are relayed by the
    # workers to the coordinator, leading the coordinator to believe that there
    # are PS failures. The difference between transient vs. permanent PS failure
    # is the number of reports from the workers. When this env var is set to a
    # positive integer K, the coordinator ignores up to K reports of a failed PS
    # task, i.e., only when there are more than K trials of executing closures
    # fail due to errors from the same PS instance do we consider the PS
    # instance encounters a failure.
    # TODO(b/164279603): Remove this workaround when the underlying connectivity
    # issue in gRPC server is resolved.
    self._transient_ps_failures_threshold = int(
        os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
    self._potential_ps_failures_lock = threading.Lock()
    # One counter per PS task, indexed by task id.
    self._potential_ps_failures_count = [0] * self._num_ps

    self._closure_queue = _CoordinatedClosureQueue()
    self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
                                                  self)
    worker_device_strings = [
        "/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
    ]
    self.workers = [
        Worker(i, w, self) for i, w in enumerate(worker_device_strings)
    ]

  def stop(self):
    """Stop worker, worker preemption threads, and the closure queue."""
    self.failure_handler.stop()

    for worker in self.workers:
      worker.stop()
    self._closure_queue.stop()

  def _record_and_ignore_transient_ps_failure(self, e):
    """Records potential PS failures and return if failure should be ignored.

    Returns True (ignore) only while every PS task implicated by `e` is still
    under the transient-failure threshold; once any implicated task reaches
    the threshold, the failure is treated as real.
    """
    if self._transient_ps_failures_threshold <= 0 or not _is_ps_failure(e):
      return False

    ps_tasks = _extract_failed_ps_instances(str(e))
    with self._potential_ps_failures_lock:
      for t in ps_tasks:
        self._potential_ps_failures_count[t] += 1
        # The number of UnavailableError encountered on this PS task exceeds the
        # maximum number of ignored errors.
        if (self._potential_ps_failures_count[t] >=
            self._transient_ps_failures_threshold):
          return False
    return True

  def schedule(self, function, args, kwargs):
    """Schedules `function` to be dispatched to a worker for execution.

    Args:
      function: The function to be dispatched to a worker for execution
        asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `RemoteValue` object.
    """
    closure = Closure(
        function,
        self._closure_queue._cancellation_mgr,  # pylint: disable=protected-access
        args=args,
        kwargs=kwargs)
    self._closure_queue.put(closure)
    return closure.output_remote_value

  def join(self):
    """Blocks until all scheduled functions are executed."""
    self._closure_queue.wait()

  def done(self):
    """Returns true if all scheduled functions are executed."""
    return self._closure_queue.done()


@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
  """An object to schedule and coordinate remote function execution.

  This class is used to create fault-tolerant resources and dispatch functions
  to remote TensorFlow servers.

  Currently, this class is not supported to be used in a standalone manner. It
  should be used in conjunction with a `tf.distribute` strategy that is designed
  to work with it. The `ClusterCoordinator` class currently only works with
  `tf.distribute.experimental.ParameterServerStrategy`.

  __The `schedule`/`join` APIs__

  The most important APIs provided by this class is the `schedule`/`join` pair.
  The `schedule` API is non-blocking in that it queues a `tf.function` and
  returns a `RemoteValue` immediately. The queued functions will be dispatched
  to remote workers in background threads and their `RemoteValue`s will be
  filled asynchronously. Since `schedule` doesn't require worker assignment, the
  `tf.function` passed in can be executed on any available worker. If the worker
  it is executed on becomes unavailable before its completion, it will be
  migrated to another worker. Because of this fact and function execution is not
  atomic, a function may be executed more than once.

  __Handling Task Failure__

  This class when used with
  `tf.distribute.experimental.ParameterServerStrategy`, comes with built-in
  fault tolerance for worker failures. That is, when some workers are not
  available for any reason to be reached from the coordinator, the training
  progress continues to be made with the remaining workers. Upon recovery of a
  failed worker, it will be added for function execution after datasets created
  by `create_per_worker_dataset` are re-built on it.

  When a parameter server fails, a `tf.errors.UnavailableError` is raised by
  `schedule`, `join` or `done`. In this case, in addition to bringing back the
  failed parameter server, users should restart the coordinator so that it
  reconnects to workers and parameter servers, re-creates the variables, and
  loads checkpoints. If the coordinator fails, after the user brings it back,
  the program will automatically connect to workers and parameter servers, and
  continue the progress from a checkpoint.

  It is thus essential that in user's program, a checkpoint file is periodically
  saved, and restored at the start of the program. If an
  `tf.keras.optimizers.Optimizer` is checkpointed, after restoring from a
  checkpoint, its `iterations` property roughly indicates the number of steps
  that have been made. This can be used to decide how many epochs and steps are
  needed before the training completion.

  See `tf.distribute.experimental.ParameterServerStrategy` docstring for an
  example usage of this API.

  This is currently under development, and the API as well as implementation
  are subject to changes.
  """

  def __new__(cls, strategy):
    # `ClusterCoordinator` is kept as a single instance to a given `Strategy`.
    # TODO(rchao): Needs a lock for thread-safety
    if strategy._cluster_coordinator is None:
      strategy._cluster_coordinator = super(
          ClusterCoordinator, cls).__new__(cls)
    return strategy._cluster_coordinator

  def __init__(self, strategy):
    """Initialization of a `ClusterCoordinator` instance.

    Args:
      strategy: a supported `tf.distribute.Strategy` object. Currently, only
        `tf.distribute.experimental.ParameterServerStrategy` is supported.

    Raises:
      ValueError: if the strategy being used is not supported.
    """
    # NOTE(review): even when `__new__` returns the cached singleton above,
    # Python still calls `__init__`, so a second construction re-creates the
    # `Cluster` below — confirm whether repeated construction is expected.
    if not isinstance(strategy,
                      parameter_server_strategy_v2.ParameterServerStrategyV2):
      raise ValueError(
          "Only `tf.distribute.experimental.ParameterServerStrategy` "
          "is supported to work with "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator` "
          "currently.")
    self._strategy = strategy
    self.strategy.extended._used_with_coordinator = True
    self._cluster = Cluster(strategy)

  def __del__(self):
    # Best-effort shutdown of worker/preemption threads when the coordinator
    # is garbage-collected.
    self._cluster.stop()

  @property
  def strategy(self):
    """Returns the `Strategy` associated with the `ClusterCoordinator`."""
    return self._strategy

  def schedule(self, fn, args=None, kwargs=None):
    """Schedules `fn` to be dispatched to a worker for asynchronous execution.

    This method is non-blocking in that it queues the `fn` which will be
    executed later and returns a
    `tf.distribute.experimental.coordinator.RemoteValue` object immediately.
    `fetch` can be called on it to wait for the function execution to finish
    and retrieve its output from a remote worker. On the other hand, call
    `tf.distribute.experimental.coordinator.ClusterCoordinator.join` to wait for
    all scheduled functions to finish.

    `schedule` guarantees that `fn` will be executed on a worker at least once;
    it could be more than once if its corresponding worker fails in the middle
    of its execution. Note that since worker can fail at any point when
    executing the function, it is possible that the function is partially
    executed, but `tf.distribute.experimental.coordinator.ClusterCoordinator`
    guarantees that in those events, the function will eventually be executed on
    any worker that is available.

    If any previously scheduled function raises an error, `schedule` will raise
    any one of those errors, and clear the errors collected so far. When this
    happens, some of the previously scheduled functions may have not been
    executed. User can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
    executed, failed, or cancelled, and reschedule the corresponding function if
    needed.

    When `schedule` raises, it guarantees that there is no function that is
    still being executed.

    At this time, there is no support of worker assignment for function
    execution, or priority of the workers.

    `args` and `kwargs` are the arguments passed into `fn`, when `fn` is
    executed on a worker. They can be
    `tf.distribute.experimental.coordinator.PerWorkerValues` and in this case,
    the argument will be substituted with the corresponding component on the
    target worker. Arguments that are not
    `tf.distribute.experimental.coordinator.PerWorkerValues` will be passed into
    `fn` as-is. Currently, `tf.distribute.experimental.coordinator.RemoteValue`
    is not supported to be input `args` or `kwargs`.

    Args:
      fn: A `tf.function`; the function to be dispatched to a worker for
        execution asynchronously. Regular python function is not supported to
        be scheduled.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.RemoteValue` object that
      represents the output of the function scheduled.

    Raises:
      Exception: one of the exceptions caught by the coordinator from any
        previously scheduled function, since the last time an error was thrown
        or since the beginning of the program.
    """
    if not isinstance(fn,
                      (def_function.Function, tf_function.ConcreteFunction)):
      raise TypeError(
          "`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`"
          " only accepts a `tf.function` or a concrete function.")
    # Slot variables are usually created during function tracing time; thus
    # `schedule` needs to be called within the `strategy.scope()`.
    with self.strategy.scope():
      # NOTE(review): `_being_scheduled` is not reset if `self._cluster
      # .schedule` raises (no try/finally) — confirm whether that is intended.
      self.strategy.extended._being_scheduled = True  # pylint: disable=protected-access
      remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
      self.strategy.extended._being_scheduled = False  # pylint: disable=protected-access
      return remote_value

  def join(self):
    """Blocks until all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `join` will fail by
    raising any one of those errors, and clear the errors collected so far. If
    this happens, some of the previously scheduled functions may have not been
    executed. Users can call `fetch` on the returned
    `tf.distribute.experimental.coordinator.RemoteValue` to inspect if they have
    executed, failed, or cancelled. If some that have been cancelled need to be
    rescheduled, users should call `schedule` with the function again.

    When `join` returns or raises, it guarantees that there is no function that
    is still being executed.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    self._cluster.join()

  def done(self):
    """Returns whether all the scheduled functions have finished execution.

    If any previously scheduled function raises an error, `done` will fail by
    raising any one of those errors.

    When `done` returns True or raises, it guarantees that there is no function
    that is still being executed.

    Returns:
      Whether all the scheduled functions have finished execution.

    Raises:
      Exception: one of the exceptions caught by the coordinator by any
        previously scheduled function since the last time an error was thrown or
        since the beginning of the program.
    """
    return self._cluster.done()

  def create_per_worker_dataset(self, dataset_fn):
    """Create dataset on workers by calling `dataset_fn` on worker devices.

    This creates the given dataset generated by dataset_fn on workers
    and returns an object that represents the collection of those individual
    datasets. Calling `iter` on such collection of datasets returns a
    `tf.distribute.experimental.coordinator.PerWorkerValues`, which is a
    collection of iterators, where the iterators have been placed on respective
    workers.

    Calling `next` on a `PerWorkerValues` of iterator is unsupported. The
    iterator is meant to be passed as an argument into
    `tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`. When
    the scheduled function is about to be executed by a worker, the
    function will receive the individual iterator that corresponds to the
    worker. The `next` method can be called on an iterator inside a
    scheduled function when the iterator is an input of the function.

    Currently the `schedule` method assumes workers are all the same and thus
    assumes the datasets on different workers are the same, except they may be
    shuffled differently if they contain a `dataset.shuffle` operation and a
    random seed is not set. Because of this, we also recommend the datasets to
    be repeated indefinitely and schedule a finite number of steps instead of
    relying on the `OutOfRangeError` from a dataset.


    Example:

    ```python
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=...)
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy=strategy)

    @tf.function
    def worker_fn(iterator):
      return next(iterator)

    def per_worker_dataset_fn():
      return strategy.distribute_datasets_from_function(
          lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
    per_worker_iter = iter(per_worker_dataset)
    remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
    assert remote_value.fetch() == 3
    ```

    NOTE: A known limitation is `tf.data.Options` is ignored in dataset created
    by `create_per_worker_dataset`.

    Args:
      dataset_fn: The dataset function that returns a dataset. This is to be
        executed on the workers.

    Returns:
      An object that represents the collection of those individual
      datasets. `iter` is expected to be called on this object that returns
      a `tf.distribute.experimental.coordinator.PerWorkerValues` of the
      iterators (that are on the workers).
    """
    # Each worker exposes exactly one input-worker pair: the worker device
    # both produces and consumes its own dataset.
    input_workers = input_lib.InputWorkers([
        (w.device_name, [w.device_name]) for w in self._cluster.workers
    ])

    return _PerWorkerDistributedDataset(dataset_fn, input_workers, self)

  def _create_per_worker_resources(self, fn, args=None, kwargs=None):
    """Synchronously create resources on the workers.

    The resources are represented by
    `tf.distribute.experimental.coordinator.RemoteValue`s.

    Args:
      fn: The function to be dispatched to all workers for execution
        asynchronously.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.

    Returns:
      A `tf.distribute.experimental.coordinator.PerWorkerValues` object, which
      wraps a tuple of `tf.distribute.experimental.coordinator.RemoteValue`
      objects.
    """
    results = []
    for w in self._cluster.workers:
      results.append(w._create_resource(fn, args=args, kwargs=kwargs))  # pylint: disable=protected-access
    return PerWorkerValues(tuple(results))

  def fetch(self, val):
    """Blocking call to fetch results from the remote values.

    This is a wrapper around
    `tf.distribute.experimental.coordinator.RemoteValue.fetch` for a
    `RemoteValue` structure; it returns the execution results of
    `RemoteValue`s. If not ready, wait for them while blocking the caller.

    Example:
    ```python
    strategy = ...
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices([1, 1, 1])

    with strategy.scope():
      v = tf.Variable(initial_value=0)

    @tf.function
    def worker_fn(iterator):
      def replica_fn(x):
        v.assign_add(x)
        return v.read_value()
      return strategy.run(replica_fn, args=(next(iterator),))

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)
    result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
    assert coordinator.fetch(result) == 1
    ```

    Args:
      val: The value to fetch the results from. If this is structure of
        `tf.distribute.experimental.coordinator.RemoteValue`, `fetch()` will be
        called on the individual
        `tf.distribute.experimental.coordinator.RemoteValue` to get the result.

    Returns:
      If `val` is a `tf.distribute.experimental.coordinator.RemoteValue` or a
      structure of `tf.distribute.experimental.coordinator.RemoteValue`s,
      return the fetched `tf.distribute.experimental.coordinator.RemoteValue`
      values immediately if they are available, or block the call until they are
      available, and return the fetched
      `tf.distribute.experimental.coordinator.RemoteValue` values with the same
      structure. If `val` is other types, return it as-is.
    """

    def _maybe_fetch(val):
      # Non-RemoteValue leaves pass through unchanged.
      if isinstance(val, RemoteValue):
        return val.fetch()
      else:
        return val

    # TODO(yuefengz): we should fetch values in a batch.
    return nest.map_structure(_maybe_fetch, val)


class _PerWorkerDistributedDataset(object):
  """Represents worker-distributed datasets created from dataset function."""

  def __init__(self, dataset_fn, input_workers, coordinator):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset`.
      input_workers: an `InputWorkers` object.
      coordinator: a `ClusterCoordinator` object, used to create dataset
        resources.
    """
    # Variable creation inside `dataset_fn` would not be tracked by the
    # strategy, so it is explicitly disallowed during tracing.
    def disallow_variable_creation(next_creator, **kwargs):
      raise ValueError("Creating variables in `dataset_fn` is not allowed.")

    if isinstance(dataset_fn, def_function.Function):
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = dataset_fn.get_concrete_function()
    elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
      # Plain Python callables are wrapped into a tf.function and traced here.
      with variable_scope.variable_creator_scope(disallow_variable_creation):
        dataset_fn = def_function.function(dataset_fn).get_concrete_function()
    self._dataset_fn = dataset_fn
    self._input_workers = input_workers
    self._coordinator = coordinator
    self._element_spec = None

  def __iter__(self):
    # We would like users to create iterators outside `tf.function`s so that we
    # can track them.
    if (not context.executing_eagerly() or
        ops.get_default_graph().building_function):
      raise RuntimeError(
          "__iter__() is not supported inside of tf.function or in graph mode.")

    def _create_per_worker_iterator():
      dataset = self._dataset_fn()
      return iter(dataset)

    # If _PerWorkerDistributedDataset.__iter__ is called multiple
    # times, for the same object it should only create and register resource
    # once. Using object id to distinguish different iterator resources.
    # NOTE(review): the code below appears to create new per-worker resources
    # on every `__iter__` call; the dedup described above is not visible here —
    # confirm whether the comment is stale.
    per_worker_iterator = self._coordinator._create_per_worker_resources(
        _create_per_worker_iterator)

    # Setting type_spec of each RemoteValue so that functions taking these
    # RemoteValues as inputs can be traced.
    for iterator_remote_value in per_worker_iterator._values:
      iterator_remote_value._type_spec = (  # pylint: disable=protected-access
          iterator_ops.IteratorSpec(
              self._dataset_fn.structured_outputs.element_spec))
    return _PerWorkerDistributedIterator(per_worker_iterator._values)

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    raise NotImplementedError("Passing `AsyncDistributedDataset` to a "
                              "tf.function is not supported.")


class _PerWorkerDistributedIterator(PerWorkerValues):
  """Distributed iterator for `ClusterCoordinator`."""

  def __next__(self):
    return self.get_next()

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    raise NotImplementedError("Iterating over an `AsyncDistributedIterator` "
                              "is not supported right now.")


def _extract_failed_ps_instances(err_msg):
  """Return a set of potentially failing ps instances from error message."""
  tasks = re.findall("/job:ps/replica:0/task:[0-9]+", err_msg)
  # Each match ends with the task id, e.g. "/job:ps/replica:0/task:3" -> 3.
  return set(int(t.split(":")[-1]) for t in tasks)


def _is_ps_failure(error):
  """Whether the error is considered a parameter server failure."""
  return (isinstance(error, errors.UnavailableError) and
          _RPC_ERROR_FROM_PS in str(error))


def _is_worker_failure(error):
  """Whether the error is considered a worker failure."""
  if _JOB_WORKER_STRING_IDENTIFIER not in str(error):
    return False
  if _RPC_ERROR_FROM_PS in str(error):
    return False

  # TODO(haoyuzhang): Consider using special status code if error from a
  # remote is derived from RPC errors originated from other hosts.
  if isinstance(error, (errors.UnavailableError, errors.AbortedError)):
    return True

  # The following error could happen when the remote task fails and restarts
  # in a very short interval during which no RPCs were exchanged to detect the
  # failure. In that case, gRPC allows channel (which is different from a
  # connection) to be reused for a replaced server listening to same address.
  if isinstance(error, errors.InvalidArgumentError):
    if ("unknown device" in str(error) or
        "Unable to find the relevant tensor remote_handle" in str(error)):
      # TODO(b/159961667): Fix "Unable to find the relevant tensor
      # remote_handle" part.
      return True

  # TODO(b/162541228): The following 2 types of errors are very rare and only
  # observed in large-scale testing. The types of errors should be reduced.
  # This could happen when the function registration fails. In the observed
  # cases this only happens to the dataset related functions.
  if isinstance(error, errors.NotFoundError):
    if ("is neither a type of a primitive operation nor a name of a function "
        "registered" in str(error)):
      return True

  return False