1# Lint as: python3 2# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Fault tolerance test for parameter server training in TF2.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import gc 23import sys 24import threading 25import time 26 27from tensorflow.python.compat import v2_compat 28from tensorflow.python.data.ops import dataset_ops 29from tensorflow.python.distribute import multi_process_runner 30from tensorflow.python.distribute import multi_worker_test_base 31from tensorflow.python.distribute import parameter_server_strategy_v2 32from tensorflow.python.distribute import test_util 33from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver 34from tensorflow.python.distribute.coordinator import cluster_coordinator 35from tensorflow.python.eager import context 36from tensorflow.python.eager import def_function 37from tensorflow.python.eager import test 38from tensorflow.python.framework import dtypes 39from tensorflow.python.framework import errors 40from tensorflow.python.framework import ops 41from tensorflow.python.ops import array_ops 42from tensorflow.python.ops import check_ops 43from tensorflow.python.ops import math_ops 44from tensorflow.python.ops import random_ops 45from tensorflow.python.ops import variables 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.training import coordinator as thread_coordinator 48from tensorflow.python.training import server_lib 49 50_RPC_ERROR_FROM_WORKER = "GRPC error information from remote target /job:worker" 51_RPC_ERROR_FROM_PS = "GRPC error information from remote target /job:ps" 52_WORKER_PREEMPTION_THREAD_NAME = "WorkerPreemptionHandler" 53_WORKER_THREAD_PREFIX = "WorkerClosureProcessingLoop" 54 55 56class Model(object): 57 58 def __init__(self, coordinator): 59 self.cluster_coord = coordinator 60 self.strategy = self.cluster_coord.strategy 61 with self.cluster_coord.strategy.scope(): 62 self.build() 63 64 def build(self): 65 self.w = variables.Variable( 66 initial_value=random_ops.random_uniform((10, 10)), dtype=dtypes.float32) 67 self.iterations = variables.Variable(initial_value=0, dtype=dtypes.int32) 68 # Allow external control to make the model run its train_fn in an infinite 69 # loop. This allows us to reliably test worker preemption in the middle of 70 # function execution. 71 self.do_infinite_step = variables.Variable(False) 72 73 self.rebuild_iterators() 74 75 def rebuild_iterators(self, use_dataset_fn=True): 76 77 if use_dataset_fn: 78 79 def dataset_fn(): 80 data = random_ops.random_uniform((10, 10)) 81 dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat() 82 return dataset 83 84 self.iterator = iter( 85 self.cluster_coord.create_per_worker_dataset(dataset_fn)) 86 self.iterator2 = iter( 87 self.cluster_coord.create_per_worker_dataset(dataset_fn)) 88 else: 89 data = random_ops.random_uniform((10, 10)) 90 dataset = dataset_ops.DatasetV2.from_tensors([data]).repeat() 91 92 self.iterator = iter( 93 self.cluster_coord.create_per_worker_dataset(dataset)) 94 self.iterator2 = iter( 95 self.cluster_coord.create_per_worker_dataset(dataset)) 96 97 def _train_fn_internal(self, iterator, iterator2): 98 x = math_ops.matmul(array_ops.squeeze(next(iterator)), self.w) 99 x = math_ops.matmul(array_ops.squeeze(next(iterator2)), x) 100 x = math_ops.matmul(random_ops.random_uniform((10, 10)), x) 101 self.w.assign_add(x) 102 103 @def_function.function 104 def train_fn(self, iterator, iterator2): 105 self._train_fn_internal(iterator, iterator2) 106 while self.do_infinite_step: 107 self._train_fn_internal(iterator, iterator2) 108 self.iterations.assign_add(1) 109 110 def schedule_training_functions(self, num_steps): 111 with self.strategy.scope(): 112 for _ in range(num_steps): 113 self.cluster_coord.schedule( 114 self.train_fn, args=(self.iterator, self.iterator2)) 115 116 def join_training_functions(self): 117 self.do_infinite_step.assign(False) 118 self.cluster_coord.join() 119 120 121class BaseFaultToleranceTest(object): # pylint: disable=missing-docstring 122 123 def setUp(self, num_workers, num_ps): 124 super(BaseFaultToleranceTest, self).setUp() 125 126 self._cluster = multi_worker_test_base.create_multi_process_cluster( 127 num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc") 128 self._cluster_def = self._cluster.cluster_resolver.cluster_spec().as_dict() 129 self._cluster_def["chief"] = [ 130 "localhost:%d" % multi_worker_test_base.pick_unused_port() 131 ] 132 cluster_resolver = SimpleClusterResolver( 133 server_lib.ClusterSpec(self._cluster_def), rpc_layer="grpc") 134 135 # The strategy's constructor would connect to the cluster. 136 self.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 137 cluster_resolver) 138 self.cluster_coord = cluster_coordinator.ClusterCoordinator(self.strategy) 139 140 self.thread_coord = thread_coordinator.Coordinator( 141 clean_stop_exception_types=[]) 142 self.num_workers = num_workers 143 self.num_ps = num_ps 144 145 def tearDown(self): 146 super(BaseFaultToleranceTest, self).tearDown() 147 self._cluster.stop() 148 self._cluster = None 149 150 def _restart(self, downtime_secs, job): 151 """Kills `job` (index: 0) and restarts it after `downtime_secs`. 152 153 Args: 154 downtime_secs: secs before restarting the job. 155 job: a string specifying the job to restart. 156 """ 157 self._cluster.kill_task(job, 0) 158 time.sleep(downtime_secs) 159 self.assertFalse(context.check_alive("/job:%s/replica:0/task:0" % job)) 160 self._cluster.start_task(job, 0) 161 while not context.check_alive("/job:%s/replica:0/task:0" % job): 162 time.sleep(1) 163 164 def _restart_in_thread(self, downtime_secs, restart_job): 165 166 def _restart_fn(): 167 with self.thread_coord.stop_on_exception(): 168 self._restart(downtime_secs, restart_job) 169 170 restart_thread = threading.Thread(target=_restart_fn) 171 restart_thread.start() 172 return restart_thread 173 174 def _ensure_threads_closed(self): 175 """Ensures worker and preemption threads are closed.""" 176 # Worker and preemption threads should exist before releasing 177 # ClusterCoordinator. 178 running_threads = test_util.get_running_threads() 179 self.assertTrue( 180 test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads)) 181 self.assertIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads) 182 183 # Print object graph if ClusterCoordinator may leak. 184 if sys.getrefcount(self.cluster_coord) > 2: 185 try: 186 test_util.show_backref(self.cluster_coord) 187 except: # pylint: disable=bare-except 188 pass 189 190 # Wait for threads to close. 191 self.cluster_coord = None 192 self.strategy = None 193 gc.collect() 194 time.sleep(1) 195 196 # Verify thread names. 197 running_threads = test_util.get_running_threads() 198 self.assertNotIn(_WORKER_PREEMPTION_THREAD_NAME, running_threads) 199 self.assertFalse( 200 test_util.has_thread(_WORKER_THREAD_PREFIX, running_threads), 201 "Worker thread is not stopped properly.") 202 203 def _create_model_and_run_indefinitely(self): 204 model = Model(self.cluster_coord) 205 model.do_infinite_step.assign(True) 206 model.schedule_training_functions(10) 207 # Model does infinite training step, so at this moment, we expect to have 208 # `self.num_workers` infinite closures inflight, and `10-self.num_workers` 209 # closures in the queue. 210 while (self.cluster_coord._cluster.closure_queue._inflight_closure_count < 211 self.num_workers): 212 time.sleep(0.1) 213 return model 214 215 def testClusterCoordinatorDestroyed(self): 216 self._ensure_threads_closed() 217 218 def testWorkerPreemptionBetweenFunctions(self): 219 model = Model(self.cluster_coord) 220 model.schedule_training_functions(2) 221 model.join_training_functions() 222 self.assertEqual(model.iterations.numpy(), 2) 223 224 self._restart(downtime_secs=2, job="worker") 225 226 model.schedule_training_functions(2) 227 model.join_training_functions() 228 self.assertEqual(model.iterations.numpy(), 4) 229 230 def testWorkerPreemptionMidstFunction(self): 231 model = Model(self.cluster_coord) 232 model.do_infinite_step.assign(True) 233 234 model.schedule_training_functions(4) 235 # Model does infinite training step, so at this moment, we expect to have 236 # `self.num_workers` infinite closures inflight, and `4-self.num_workers` 237 # closures in the queue. 238 while (self.cluster_coord._cluster.closure_queue._inflight_closure_count < 239 self.num_workers): 240 time.sleep(0.1) 241 self.assertFalse(self.cluster_coord.done()) 242 self._restart(downtime_secs=2, job="worker") 243 model.join_training_functions() 244 self.assertGreaterEqual(model.iterations.numpy(), 4) 245 246 def testOneWorkerPreemptionWithCancellation(self): 247 248 @def_function.function 249 def normal_function(): 250 x = random_ops.random_uniform((2, 10)) 251 y = random_ops.random_uniform((10, 2)) 252 return math_ops.reduce_mean(math_ops.matmul(x, y)) 253 254 @def_function.function 255 def error_function(): 256 x = random_ops.random_uniform((2, 10)) 257 y = random_ops.random_uniform((10, 2)) 258 check_ops.assert_non_positive_v2( 259 math_ops.reduce_sum(math_ops.matmul(x, y))) 260 return x 261 262 @def_function.function 263 def long_function(): 264 x = random_ops.random_uniform((1000, 1000)) 265 for _ in math_ops.range(10000): 266 a = random_ops.random_uniform((1000, 1000)) 267 b = random_ops.random_uniform((1000, 1000)) 268 x += math_ops.matmul(a, b) 269 return x 270 271 for _ in range(3): 272 self.cluster_coord.schedule(normal_function) 273 long_function_result = self.cluster_coord.schedule(long_function) 274 self.cluster_coord.schedule(error_function) 275 276 time.sleep(1) # Let it run a couple steps. 277 self._restart(1, "worker") 278 279 with self.assertRaises(errors.InvalidArgumentError): 280 self.cluster_coord.join() 281 282 with self.assertRaises(errors.CancelledError): 283 long_function_result.fetch() 284 285 for _ in range(3): 286 self.cluster_coord.schedule(normal_function) 287 self.cluster_coord.join() 288 289 # The cluster is likely still being recovered since `join` returned early 290 # due to the error_function. 291 failure_handler = self.cluster_coord._cluster.failure_handler 292 failure_handler.stop() 293 failure_handler._preemption_handler_thread.join() 294 295 def testHandleDatasetCreationFailureWithDatasetFn(self): 296 model = Model(self.cluster_coord) 297 298 restart_thread = self._restart_in_thread(5, "worker") 299 300 model.schedule_training_functions(3) 301 model.rebuild_iterators() 302 model.schedule_training_functions(3) 303 model.rebuild_iterators() 304 model.schedule_training_functions(3) 305 306 model.join_training_functions() 307 308 self.thread_coord.join([restart_thread]) 309 self.assertGreaterEqual(model.iterations.numpy(), 3) 310 311 # TODO(yuefengz): consider using combinations when there is more code 312 # duplication. 313 def testHandleDatasetCreationFailureWithDataset(self): 314 model = Model(self.cluster_coord) 315 316 restart_thread = self._restart_in_thread(5, "worker") 317 318 model.schedule_training_functions(3) 319 model.rebuild_iterators(use_dataset_fn=False) 320 model.schedule_training_functions(3) 321 model.rebuild_iterators(use_dataset_fn=False) 322 model.schedule_training_functions(3) 323 324 model.join_training_functions() 325 326 self.thread_coord.join([restart_thread]) 327 self.assertGreaterEqual(model.iterations.numpy(), 3) 328 329 def testWorkerPreemptionErrorType(self): 330 331 @def_function.function 332 def worker_train_fn(): 333 x = random_ops.random_uniform((2, 10)) 334 y = random_ops.random_uniform((10, 2)) 335 return math_ops.reduce_mean(math_ops.matmul(x, y)) 336 337 def run_fn(): 338 with self.thread_coord.stop_on_exception(): 339 with ops.device("/job:worker/replica:0/task:0"): 340 for _ in range(3): 341 for _ in range(3): 342 worker_train_fn() 343 time.sleep(5) 344 345 run_thread = threading.Thread(target=run_fn) 346 run_thread.start() 347 time.sleep(1) # Let it run a couple steps. 348 self._restart(2, "worker") 349 350 try: 351 self.thread_coord.join([run_thread]) 352 except errors.UnavailableError as e: 353 logging.info("Got exception %r, error message is %s", e, e) 354 355 self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except 356 self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) 357 358 self.assertTrue("failed to connect to all addresses" in str(e) or 359 "Unable to find a context_id" in str(e) or 360 "Socket closed" in str(e) or 361 "Connection reset by peer" in str(e) or 362 "Transport closed" in str(e)) 363 364 def testWorkerPreemptionErrorTypeWithPythonFunction(self): 365 366 def worker_train_fn(): 367 x = random_ops.random_uniform((2, 10)) 368 y = random_ops.random_uniform((10, 2)) 369 return math_ops.reduce_mean(math_ops.matmul(x, y)) 370 371 def run_fn(): 372 with self.thread_coord.stop_on_exception(): 373 with ops.device("/job:worker/replica:0/task:0"): 374 for _ in range(3): 375 for _ in range(3): 376 worker_train_fn() 377 time.sleep(5) 378 379 run_thread = threading.Thread(target=run_fn) 380 run_thread.start() 381 time.sleep(1) # Let it run a couple steps. 382 self._restart(2, "worker") 383 384 try: 385 self.thread_coord.join([run_thread]) 386 except errors.UnavailableError as e: 387 logging.info("Got exception %r, error message is %s", e, e) 388 389 self.assertIn(_RPC_ERROR_FROM_WORKER, str(e)) # pylint: disable=g-assert-in-except 390 self.assertNotIn(_RPC_ERROR_FROM_PS, str(e)) 391 392 self.assertTrue("failed to connect to all addresses" in str(e) or 393 "Unable to find a context_id" in str(e) or 394 "Socket closed" in str(e) or 395 "Connection reset by peer" in str(e) or 396 "Transport closed" in str(e)) 397 398 def testPSPreemptionErrorType(self): 399 400 with ops.device("/job:ps/replica:0/task:0"): 401 v = variables.Variable( 402 initial_value=random_ops.random_uniform((2, 10)), 403 dtype=dtypes.float32) 404 405 @def_function.function 406 def worker_train_fn(): 407 y = random_ops.random_uniform((10, 2)) 408 return math_ops.reduce_mean(math_ops.matmul(v, y)) 409 410 def run_fn(): 411 with self.thread_coord.stop_on_exception(): 412 with ops.device("/job:worker/replica:0/task:0"): 413 for _ in range(3): 414 for _ in range(3): 415 worker_train_fn() 416 time.sleep(5) 417 418 run_thread = threading.Thread(target=run_fn) 419 run_thread.start() 420 time.sleep(1) # Let it run a couple steps. 421 422 # Use a short restart delay to cover the case that RPC channel is reused 423 self._restart(1, "ps") 424 425 try: 426 self.thread_coord.join([run_thread]) 427 except (errors.UnavailableError, errors.AbortedError) as e: 428 logging.info("Got exception %r, error message is %s", e, e) 429 self.assertIn(_RPC_ERROR_FROM_PS, str(e)) # pylint: disable=g-assert-in-except 430 431 if isinstance(e, errors.UnavailableError): 432 self.assertTrue("failed to connect to all addresses" in str(e) or 433 "Unable to find a context_id" in str(e) or 434 "Socket closed" in str(e) or 435 "Connection reset by peer" in str(e) or 436 "Transport closed" in str(e)) 437 438 if isinstance(e, errors.AbortedError): 439 self.assertIn("RecvTensor expects a different device incarnation", 440 str(e)) 441 self._ensure_threads_closed() 442 443 def testTwoWorkersPreempted(self): 444 if self.num_workers < 2: 445 self.skipTest("Worker number is less than 2.") 446 model = self._create_model_and_run_indefinitely() 447 448 self.assertFalse(self.cluster_coord.done()) 449 self._cluster.kill_task("worker", 0) 450 self._cluster.kill_task("worker", 1) 451 time.sleep(2) 452 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 453 self.assertFalse(context.check_alive("/job:worker/replica:0/task:1")) 454 self._cluster.start_task("worker", 0) 455 self._cluster.start_task("worker", 1) 456 time.sleep(2) 457 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 458 self.assertTrue(context.check_alive("/job:worker/replica:0/task:1")) 459 460 model.join_training_functions() 461 self.assertGreaterEqual(model.iterations.numpy(), 10) 462 463 def testWorkerContinuousFailure(self): 464 model = self._create_model_and_run_indefinitely() 465 466 self.assertFalse(self.cluster_coord.done()) 467 self._cluster.kill_task("worker", 0) 468 time.sleep(2) 469 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 470 self._cluster.start_task("worker", 0) 471 time.sleep(2) 472 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 473 self._cluster.kill_task("worker", 0) 474 time.sleep(2) 475 self.assertFalse(context.check_alive("/job:worker/replica:0/task:0")) 476 self._cluster.start_task("worker", 0) 477 time.sleep(2) 478 self.assertTrue(context.check_alive("/job:worker/replica:0/task:0")) 479 480 model.join_training_functions() 481 self.assertGreaterEqual(model.iterations.numpy(), 10) 482 483 def testPSFailureWhileRecoveryFromWokerFailure(self): 484 model = self._create_model_and_run_indefinitely() 485 486 time.sleep(1) 487 self.assertFalse(self.cluster_coord.done()) 488 489 def kill(task): 490 self._cluster.kill_task(task, 0) 491 self.sleep(1) 492 self._cluster.start_task(task, 0) 493 494 kill_thread_1 = threading.Thread(target=kill, args=("worker",)) 495 kill_thread_2 = threading.Thread(target=kill, args=("ps",)) 496 kill_thread_1.start() 497 kill_thread_2.start() 498 kill_thread_1.join() 499 kill_thread_2.join() 500 501 with self.assertRaises( 502 (errors.UnavailableError, errors.InvalidArgumentError)): 503 model.join_training_functions() 504 505 def testNumpyFetchedAfterWorkerFailure(self): 506 507 with self.strategy.scope(): 508 v = variables.Variable(initial_value=0, dtype=dtypes.int32) 509 510 @def_function.function 511 def worker_fn(): 512 return v + 1, v - 1 513 514 remote_value = self.cluster_coord.schedule(worker_fn) 515 # Attempt to fetch before killing worker task should succeed. 516 self.assertEqual((1, -1), remote_value.fetch()) 517 self._cluster.kill_task("worker", 0) 518 # So should attempt to fetch after killing worker task. 519 self.assertEqual((1, -1), remote_value.fetch()) 520 521 def testTensorGotAfterWorkerFailure(self): 522 523 with self.strategy.scope(): 524 v = variables.Variable(initial_value=0, dtype=dtypes.int32) 525 526 @def_function.function 527 def worker_fn(): 528 return v + 1, v - 1 529 530 remote_value = self.cluster_coord.schedule(worker_fn) 531 532 # Attempt to fetch before killing worker task should succeed. 533 fetched = remote_value.get()[0] 534 self.assertIsInstance(fetched, ops.Tensor) 535 self.assertEqual(fetched.device, "/job:chief/replica:0/task:0/device:CPU:0") 536 self.assertEqual((1, -1), remote_value.get()) 537 remote_value.get()[0].numpy() 538 539 # As well as the remote tensors that point to worker0 or worker1. 540 values = remote_value._values[0] 541 self.assertIsInstance(values, ops.Tensor) 542 self.assertRegex(values.device, 543 "/job:worker/replica:0/task:[0-1]/device:CPU:0") 544 self.assertEqual((1, -1), remote_value._values) 545 remote_value._values[0].numpy() 546 547 # Terminate the workers and wait a little so that they are indeed killed. 548 for i in range(self.num_workers): 549 self._cluster.kill_task("worker", i) 550 time.sleep(5) 551 552 # Attempt to fetch after killing worker tasks should succeed as well. 553 remote_value.get()[0].numpy() 554 self.assertEqual((1, -1), remote_value.get()) 555 556 # Attempting to copy the tensor from worker now should fail. 557 with self.assertRaises(errors.UnavailableError) as cm: 558 remote_value._values[0].numpy() 559 self.assertIn("failed to connect to all addresses", cm.exception.message) 560 self.assertIn("/job:worker/replica:0/task:", cm.exception.message) 561 562 def testClusterStateNotDisrupted(self): 563 # This test has side effects and can disrupt other tests, even if the 564 # resource created by it will not be used in following tests. 565 # TODO(b/155209534): enable this test. 566 # self.testPSPreemptionErrorType() 567 568 self.thread_coord = thread_coordinator.Coordinator( 569 clean_stop_exception_types=[]) 570 self.testWorkerPreemptionMidstFunction() 571 572 self.thread_coord = thread_coordinator.Coordinator( 573 clean_stop_exception_types=[]) 574 self.testWorkerPreemptionErrorType() 575 576 # In previous tests, workers may fail after training is done. But the 577 # following tests start with creating resources where failure is not 578 # handled. 579 # TODO(b/153888707): enable the following two tests. 580 # self.testTwoWorkersPreempted() 581 # self.testWorkerContinuousFailure() 582 583 def testJoinRaisesUnavailableErrorAtPsFailure(self): 584 self._create_model_and_run_indefinitely() 585 self._cluster.kill_task("ps", 0) 586 while self.cluster_coord._cluster.closure_queue._error is None: 587 time.sleep(1) 588 with self.assertRaises((errors.UnavailableError, errors.NotFoundError, 589 errors.FailedPreconditionError)): 590 self.cluster_coord.join() 591 592 def testScheduleRaisesUnavailableErrorAtPsFailure(self): 593 self._create_model_and_run_indefinitely() 594 self._cluster.kill_task("ps", 0) 595 while self.cluster_coord._cluster.closure_queue._error is None: 596 time.sleep(1) 597 with self.assertRaises((errors.UnavailableError, errors.NotFoundError, 598 errors.FailedPreconditionError)): 599 self.cluster_coord.schedule(def_function.function(lambda: None)) 600 601 def testWorkerExecutionAfterPsFailureRaisesExpectedError(self): 602 model = self._create_model_and_run_indefinitely() 603 for i in range(self.num_ps): 604 self._cluster.kill_task("ps", i) 605 while self.cluster_coord._cluster.closure_queue._error is None: 606 time.sleep(1) 607 608 @def_function.function 609 def trivial_function(): 610 return model.iterations + 1 611 612 for i in range(self.num_workers): 613 try: 614 with ops.device("/job:worker/replica:0/task:{}".format(i)): 615 trivial_function() 616 except Exception as e: # pylint: disable=broad-except 617 if cluster_coordinator._is_ps_failure(e): 618 if i < self.num_workers - 1: 619 continue 620 return 621 raise AssertionError("Executing a function after PS fails, should " 622 "result in a PS failure.") 623 624 625class MultiWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): 626 """Multi worker fault tolerance tests. 627 628 This covers the ordinary cases where multiple workers and PS are used. 629 """ 630 631 def setUp(self): 632 super(MultiWorkerFaultToleranceTest, self).setUp(2, 2) 633 634 635class SingleWorkerFaultToleranceTest(BaseFaultToleranceTest, test.TestCase): 636 """Single worker fault tolerance tests. 637 638 This covers the cases that ensure training can continue in a single-worker 639 cluster, even if the only worker can become unavailable at some point and 640 recovered (if there are multiple workers, it is possible that the training 641 succeeds with the workers that did not fail). Realistically single worker 642 is very rarely used, but the tests are important to ensure the correct 643 behaviors. 644 """ 645 646 def setUp(self): 647 super(SingleWorkerFaultToleranceTest, self).setUp(1, 1) 648 649 650if __name__ == "__main__": 651 v2_compat.enable_v2_behavior() 652 multi_process_runner.test_main() 653