# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Fault tolerance test for parameter server training in TF2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc
import sys
import threading
import time

from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator as thread_coordinator
from tensorflow.python.training import server_lib

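# Error-message substrings and thread names used below to identify RPC
# failures from specific jobs and to verify coordinator thread cleanup.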
_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker"
_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps"
_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler"
_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop"


class Model(object):
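  """A test model whose training steps are scheduled onto remote workers."""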

  def __init__(self, coordinator):
    self.cluster_coord = coordinator
    self.strategy = self.cluster_coord.strategy
    with self.cluster_coord.strategy.scope():
      self.build()

  def build(self):
    self.w = variables.Variable(
        initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32)
    self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32)
    # Allow external control to make the model run its train_fn in an infinite
    # loop. This allows us to reliably test worker preemption in the middle of
    # function execution.
    self.do_infinite_step = variables.Variable(False)

    self.rebuild_iterators()

  def rebuild_iterators(self, use_dataset_fn=True):
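    """Recreates per-worker iterators from a dataset_fn or a dataset."""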

    if use_dataset_fn:

      def dataset_fn():
        data = random_ops.random_uniform((10, 10))
        dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()
        return dataset

      self.iterator = iter(
          self.cluster_coord.create_per_worker_dataset(dataset_fn))
      self.iterator2 = iter(
          self.cluster_coord.create_per_worker_dataset(dataset_fn))
    else:
      data = random_ops.random_uniform((10, 10))
      dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat()

      self.iterator = iter(
          self.cluster_coord.create_per_worker_dataset(dataset))
      self.iterator2 = iter(
          self.cluster_coord.create_per_worker_dataset(dataset))

  def _train_fn_internal(self, iterator, iterator2):
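    """One training step: chained matmuls that update `self.w`."""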
    x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w)
    x = math_ops.matmul(array_ops.squeeze(next(iterator2)), x)
    x = math_ops.matmul(random_ops.random_uniform((10, 10)), x)
    self.w.assign_add(x)

  @def_function.function
  def train_fn(self, iterator, iterator2):
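    """Runs one step, then keeps stepping while `do_infinite_step` is set."""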
    self._train_fn_internal(iterator, iterator2)
    while self.do_infinite_step:
      self._train_fn_internal(iterator, iterator2)
    self.iterations.assign_add(1)

  def schedule_training_functions(self, num_steps):
    with self.strategy.scope():
      for _ in range(num_steps):
        self.cluster_coord.schedule(
            self.train_fn, args=(self.iterator, self.iterator2))

  def join_training_functions(self):
    self.do_infinite_step.assign(False)
    self.cluster_coord.join()


class BaseFaultToleranceTest(object):  # pylint: disable=missing-docstring

  def setUp(self, num_workers, num_ps):
    super(BaseFaultToleranceTest, self).setUp()

    self._cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
    self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict()
    self._cluster_def["chief"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc")

    # The strategy's constructor connects to the cluster.
    self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        cluster_resolver)
    self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy)

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.num_workers = num_workers
    self.num_ps = num_ps

  def tearDown(self):
    super(BaseFaultToleranceTest, self).tearDown()
    self._cluster.stop()
    self._cluster = None

  def _restart(self, downtime_secs, job):
    """Kills `job` (index: 0) and restarts it after `downtime_secs`.

    Args:
      downtime_secs: seconds to wait before restarting the job.
      job: a string specifying the job to restart.
    """
    self._cluster.kill_task(job, 0)
    time.sleep(downtime_secs)
    self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job))
    self._cluster.start_task(job, 0)
    while not context.check_alive("/job:%s/replica:0/task:0" % job):
      time.sleep(1)

  def _restart_in_thread(self, downtime_secs, restart_job):
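    """Runs `_restart` on a background thread and returns that thread."""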

    def _restart_fn():
      with self.thread_coord.stop_on_exception():
        self._restart(downtime_secs, restart_job)

    restart_thread = threading.Thread(target=_restart_fn)
    restart_thread.start()
    return restart_thread

  def _ensure_threads_closed(self):
    """Ensures worker and preemption threads are closed."""
    # Worker and preemption threads should exist before releasing
    # ClusterCoordinator.
    running_threads = test_util.get_running_threads()
    self.assertTrue(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads))
    self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)

    # Print the object graph if the ClusterCoordinator may have leaked. Two
    # references are expected: `self.cluster_coord` itself and the temporary
    # reference held by `getrefcount`.
    if sys.getrefcount(self.cluster_coord) > 2:
      try:
        test_util.show_backref(self.cluster_coord)
      except:  # pylint: disable=bare-except
        pass

    # Wait for threads to close.
    self.cluster_coord = None
    self.strategy = None
    gc.collect()
    time.sleep(1)

    # Verify thread names.
    running_threads = test_util.get_running_threads()
    self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads)
    self.assertFalse(
        test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads),
        "Worker thread is not stopped properly.")

  def _create_model_and_run_indefinitely(self):
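    """Builds a model running infinite steps and waits until steps start."""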
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)
    model.schedule_training_functions(10)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures in flight and
    # `10 - self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster.closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    return model

  def testClusterCoordinatorDestroyed(self):
    self._ensure_threads_closed()

  def testWorkerPreemptionBetweenFunctions(self):
    model = Model(self.cluster_coord)
    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 2)

    self._restart(downtime_secs=2, job="worker")

    model.schedule_training_functions(2)
    model.join_training_functions()
    self.assertEqual(model.iterations.numpy(), 4)

  def testWorkerPreemptionMidstFunction(self):
    model = Model(self.cluster_coord)
    model.do_infinite_step.assign(True)

    model.schedule_training_functions(4)
    # The model runs an infinite training step, so at this point we expect
    # `self.num_workers` infinite closures in flight and
    # `4 - self.num_workers` closures in the queue.
    while (self.cluster_coord._cluster.closure_queue._inflight_closure_count <
           self.num_workers):
      time.sleep(0.1)
    self.assertFalse(self.cluster_coord.done())
    self._restart(downtime_secs=2, job="worker")
    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 4)

  def testOneWorkerPreemptionWithCancellation(self):

    @def_function.function
    def normal_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    @def_function.function
    def error_function():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      check_ops.assert_non_positive_v2(
          math_ops.reduce_sum(math_ops.matmul(x, y)))
      return x

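    # long_function should still be in flight when the worker is restarted, so
    # fetching its result is expected to raise CancelledError.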
    @def_function.function
    def long_function():
      x = random_ops.random_uniform((1000, 1000))
      for _ in math_ops.range(10000):
        a = random_ops.random_uniform((1000, 1000))
        b = random_ops.random_uniform((1000, 1000))
        x += math_ops.matmul(a, b)
      return x

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    long_function_result = self.cluster_coord.schedule(long_function)
    self.cluster_coord.schedule(error_function)

    time.sleep(1)  # Let it run a couple steps.
    self._restart(1, "worker")

    with self.assertRaises(errors.InvalidArgumentError):
      self.cluster_coord.join()

    with self.assertRaises(errors.CancelledError):
      long_function_result.fetch()

    for _ in range(3):
      self.cluster_coord.schedule(normal_function)
    self.cluster_coord.join()

    # The cluster is likely still being recovered since `join` returned early
    # due to the error_function.
    failure_handler = self.cluster_coord._cluster.failure_handler
    failure_handler.stop()
    failure_handler._preemption_handler_thread.join()

  def testHandleDatasetCreationFailureWithDatasetFn(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.rebuild_iterators()
    model.schedule_training_functions(3)
    model.rebuild_iterators()
    model.schedule_training_functions(3)

    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  # TODO(yuefengz): consider using combinations when there is more code
  # duplication.
  def testHandleDatasetCreationFailureWithDataset(self):
    model = Model(self.cluster_coord)

    restart_thread = self._restart_in_thread(5, "worker")

    model.schedule_training_functions(3)
    model.rebuild_iterators(use_dataset_fn=False)
    model.schedule_training_functions(3)
    model.rebuild_iterators(use_dataset_fn=False)
    model.schedule_training_functions(3)

    model.join_training_functions()

    self.thread_coord.join([restart_thread])
    self.assertGreaterEqual(model.iterations.numpy(), 3)

  def testWorkerPreemptionErrorType(self):

    @def_function.function
    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testWorkerPreemptionErrorTypeWithPythonFunction(self):

    def worker_train_fn():
      x = random_ops.random_uniform((2, 10))
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(x, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.
    self._restart(2, "worker")

    try:
      self.thread_coord.join([run_thread])
    except errors.UnavailableError as e:
      logging.info("Got exception %r, error message is %s", e, e)

      self.assertIn(_RPC_ERROR_FROM_WORKER, str(e))  # pylint: disable=g-assert-in-except
      self.assertNotIn(_RPC_ERROR_FROM_PS, str(e))

      self.assertTrue("failed to connect to all addresses" in str(e) or
                      "Unable to find a context_id" in str(e) or
                      "Socket closed" in str(e) or
                      "Connection reset by peer" in str(e) or
                      "Transport closed" in str(e))

  def testPSPreemptionErrorType(self):

    with ops.device("/job:ps/replica:0/task:0"):
      v = variables.Variable(
          initial_value=random_ops.random_uniform((2, 10)),
          dtype=dtypes.float32)

    @def_function.function
    def worker_train_fn():
      y = random_ops.random_uniform((10, 2))
      return math_ops.reduce_mean(math_ops.matmul(v, y))

    def run_fn():
      with self.thread_coord.stop_on_exception():
        with ops.device("/job:worker/replica:0/task:0"):
          for _ in range(3):
            for _ in range(3):
              worker_train_fn()
            time.sleep(5)

    run_thread = threading.Thread(target=run_fn)
    run_thread.start()
    time.sleep(1)  # Let it run a couple steps.

    # Use a short restart delay to cover the case where the RPC channel is
    # reused.
    self._restart(1, "ps")

    try:
      self.thread_coord.join([run_thread])
    except (errors.UnavailableError, errors.AbortedError) as e:
      logging.info("Got exception %r, error message is %s", e, e)
      self.assertIn(_RPC_ERROR_FROM_PS, str(e))  # pylint: disable=g-assert-in-except

      if isinstance(e, errors.UnavailableError):
        self.assertTrue("failed to connect to all addresses" in str(e) or
                        "Unable to find a context_id" in str(e) or
                        "Socket closed" in str(e) or
                        "Connection reset by peer" in str(e) or
                        "Transport closed" in str(e))

      if isinstance(e, errors.AbortedError):
        self.assertIn("RecvTensor expects a different device incarnation",
                      str(e))
      self._ensure_threads_closed()

  def testTwoWorkersPreempted(self):
    if self.num_workers < 2:
      self.skipTest("Worker number is less than 2.")
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    self._cluster.kill_task("worker", 1)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:1"))
    self._cluster.start_task("worker", 0)
    self._cluster.start_task("worker", 1)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:1"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testWorkerContinuousFailure(self):
    model = self._create_model_and_run_indefinitely()

    self.assertFalse(self.cluster_coord.done())
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.kill_task("worker", 0)
    time.sleep(2)
    self.assertFalse(context.check_alive("/job:worker/replica:0/task:0"))
    self._cluster.start_task("worker", 0)
    time.sleep(2)
    self.assertTrue(context.check_alive("/job:worker/replica:0/task:0"))

    model.join_training_functions()
    self.assertGreaterEqual(model.iterations.numpy(), 10)

  def testPSFailureWhileRecoveryFromWorkerFailure(self):
    model = self._create_model_and_run_indefinitely()

    time.sleep(1)
    self.assertFalse(self.cluster_coord.done())

    def kill(task):
      self._cluster.kill_task(task, 0)
      time.sleep(1)
      self._cluster.start_task(task, 0)

    kill_thread_1 = threading.Thread(target=kill, args=("worker",))
    kill_thread_2 = threading.Thread(target=kill, args=("ps",))
    kill_thread_1.start()
    kill_thread_2.start()
    kill_thread_1.join()
    kill_thread_2.join()

    with self.assertRaises(
        (errors.UnavailableError, errors.InvalidArgumentError)):
      model.join_training_functions()

  def testNumpyFetchedAfterWorkerFailure(self):

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)
    # An attempt to fetch before killing the worker task should succeed.
    self.assertEqual((1, -1), remote_value.fetch())
    self._cluster.kill_task("worker", 0)
    # So should an attempt to fetch after killing the worker task.
    self.assertEqual((1, -1), remote_value.fetch())

  def testTensorGotAfterWorkerFailure(self):

    with self.strategy.scope():
      v = variables.Variable(initial_value=0, dtype=dtypes.int32)

    @def_function.function
    def worker_fn():
      return v + 1, v - 1

    remote_value = self.cluster_coord.schedule(worker_fn)

    # An attempt to fetch before killing the worker task should succeed.
    fetched = remote_value.get()[0]
    self.assertIsInstance(fetched, ops.Tensor)
    self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0")
    self.assertEqual((1, -1), remote_value.get())
    remote_value.get()[0].numpy()

    # As should the remote tensors that point to worker0 or worker1.
    values = remote_value._values[0]
    self.assertIsInstance(values, ops.Tensor)
    self.assertRegex(values.device,
                     "/job:worker/replica:0/task:[0-1]/device:CPU:0")
    self.assertEqual((1, -1), remote_value._values)
    remote_value._values[0].numpy()

    # Terminate the workers and wait a little so that they are indeed killed.
    for i in range(self.num_workers):
      self._cluster.kill_task("worker", i)
    time.sleep(5)

    # An attempt to fetch after killing the worker tasks should succeed as
    # well.
    remote_value.get()[0].numpy()
    self.assertEqual((1, -1), remote_value.get())

    # Attempting to copy the tensor from the worker now should fail.
    with self.assertRaises(errors.UnavailableError) as cm:
      remote_value._values[0].numpy()
    self.assertIn("failed to connect to all addresses", cm.exception.message)
    self.assertIn("/job:worker/replica:0/task:", cm.exception.message)

  def testClusterStateNotDisrupted(self):
    # This test has side effects and can disrupt other tests, even though the
    # resources it creates will not be used in the following tests.
    # TODO(b/155209534): enable this test.
    # self.testPSPreemptionErrorType()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionMidstFunction()

    self.thread_coord = thread_coordinator.Coordinator(
        clean_stop_exception_types=[])
    self.testWorkerPreemptionErrorType()

    # In the previous tests, workers may fail after training is done. But the
    # following tests start by creating resources, where failures are not
    # handled.
    # TODO(b/153888707): enable the following two tests.
    # self.testTwoWorkersPreempted()
    # self.testWorkerContinuousFailure()

  def testJoinRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
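    # Wait until the coordinator surfaces the PS failure as a queued error.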
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.join()

  def testScheduleRaisesUnavailableErrorAtPsFailure(self):
    self._create_model_and_run_indefinitely()
    self._cluster.kill_task("ps", 0)
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)
    with self.assertRaises((errors.UnavailableError, errors.NotFoundError,
                            errors.FailedPreconditionError)):
      self.cluster_coord.schedule(def_function.function(lambda: None))

  def testWorkerExecutionAfterPsFailureRaisesExpectedError(self):
    model = self._create_model_and_run_indefinitely()
    for i in range(self.num_ps):
      self._cluster.kill_task("ps", i)
    while self.cluster_coord._cluster.closure_queue._error is None:
      time.sleep(1)

    @def_function.function
    def trivial_function():
      return model.iterations + 1

    for i in range(self.num_workers):
      try:
        with ops.device("/job:worker/replica:0/task:{}".format(i)):
          trivial_function()
      except Exception as e:  # pylint: disable=broad-except
        if cluster_coordinator._is_ps_failure(e):
          if i < self.num_workers - 1:
            continue
          return
      raise AssertionError("Executing a function after the PS fails should "
                           "result in a PS failure.")


class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Multi worker fault tolerance tests.

  These cover the ordinary cases where multiple workers and parameter servers
  are used.
  """

  def setUp(self):
    super(MultiWorkerFaultToleranceTest, self).setUp(2, 2)


class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase):
  """Single worker fault tolerance tests.

  These cover the cases that ensure training can continue in a single-worker
  cluster, even if the only worker becomes unavailable at some point and is
  later recovered. (With multiple workers, training may succeed on the
  workers that did not fail.) Realistically, a single worker is very rarely
  used, but these tests are important to ensure correct behavior.
  """

  def setUp(self):
    super(SingleWorkerFaultToleranceTest, self).setUp(1, 1)

if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  multi_process_runner.test_main()