1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for coordinator.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import functools 24import gc 25import os 26import platform 27import sys 28import threading 29import time 30import traceback 31from absl.testing import parameterized 32 33from tensorflow.python.compat import v2_compat 34from tensorflow.python.data.ops import dataset_ops 35from tensorflow.python.distribute import distribute_utils 36from tensorflow.python.distribute import distribution_strategy_context 37from tensorflow.python.distribute import input_lib 38from tensorflow.python.distribute import multi_worker_test_base 39from tensorflow.python.distribute import parameter_server_strategy_v2 40from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver 41from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib 42from tensorflow.python.distribute.coordinator import values as values_lib 43from tensorflow.python.eager import cancellation 44from tensorflow.python.eager import def_function 45from tensorflow.python.eager import test 46from tensorflow.python.framework import constant_op 47from tensorflow.python.framework import dtypes 48from tensorflow.python.framework import errors 49from tensorflow.python.framework import random_seed 50from tensorflow.python.framework import tensor_spec 51from tensorflow.python.framework import test_util 52from tensorflow.python.ops import array_ops 53from tensorflow.python.ops import check_ops 54from tensorflow.python.ops import math_ops 55from tensorflow.python.ops import random_ops 56from tensorflow.python.ops import variable_scope 57from tensorflow.python.ops import variables 58from tensorflow.python.platform import tf_logging as logging 59from tensorflow.python.training import coordinator 60from tensorflow.python.training.server_lib import ClusterSpec 61 62 63class ClosureWithOutput(coordinator_lib.Closure): 64 65 def __init__(self, function, cancellation_mgr=None, args=None, kwargs=None): 66 super(ClosureWithOutput, self).__init__( 67 function, cancellation_mgr=cancellation_mgr, args=args, kwargs=kwargs) 68 self.output_remote_value = self.build_output_remote_value() 69 70 71class CoordinatedClosureQueueTest(test.TestCase): 72 73 def testBasic(self): 74 queue = coordinator_lib._CoordinatedClosureQueue() 75 closure1 = self._create_closure(queue._cancellation_mgr) 76 queue.put(closure1) 77 self.assertIs(closure1, queue.get()) 78 self.assertFalse(queue.done()) 79 queue.put_back(closure1) 80 self.assertEqual(closure1, queue.get()) 81 queue.mark_finished() 82 self.assertTrue(queue.done()) 83 queue.wait() 84 85 def testProcessAtLeaseOnce(self): 86 closure_queue = coordinator_lib._CoordinatedClosureQueue() 87 labels = ['A', 'B', 'C', 'D', 'E'] 88 processed_count = collections.defaultdict(int) 89 90 coord = coordinator.Coordinator(clean_stop_exception_types=[]) 91 92 def process_queue(): 93 with coord.stop_on_exception(): 94 has_been_put_back = False 95 while True: 96 closure = closure_queue.get(timeout=30) 97 if closure is None: 98 break 99 if not has_been_put_back: 100 has_been_put_back = True 101 closure_queue.put_back(closure) 102 continue 103 closure._function() 104 closure_queue.mark_finished() 105 106 def get_func(label): 107 108 def func(): 109 time.sleep(3) 110 processed_count[label] += 1 111 112 return func 113 114 cm = cancellation.CancellationManager() 115 for label in labels: 116 closure_queue.put(ClosureWithOutput(get_func(label), cm)) 117 t1 = threading.Thread(target=process_queue, daemon=True) 118 t1.start() 119 t2 = threading.Thread(target=process_queue, daemon=True) 120 t2.start() 121 122 # Make sure multiple wait() calls are fine. 123 closure_queue.wait() 124 closure_queue.wait() 125 closure_queue.wait() 126 closure_queue.wait() 127 128 self.assertEqual(processed_count, collections.Counter(labels)) 129 130 coord.join([t1, t2]) 131 132 def testNotifyBeforeWait(self): 133 closure_queue = coordinator_lib._CoordinatedClosureQueue() 134 135 def func(): 136 logging.info('func running') 137 138 coord = coordinator.Coordinator(clean_stop_exception_types=[]) 139 140 def process_queue(): 141 with coord.stop_on_exception(): 142 closure_queue.get() 143 closure_queue.mark_finished() 144 145 closure_queue.put(ClosureWithOutput(func, closure_queue._cancellation_mgr)) 146 t = threading.Thread(target=process_queue) 147 t.start() 148 coord.join([t]) 149 150 # This test asserts that waiting at the time the function has been processed 151 # doesn't time out. 152 closure_queue.wait() 153 154 def _assert_one_unblock_the_other(self, first_fn, second_fn): 155 """Asserts `second_fn` wouldn't return before `first_fn` is finished.""" 156 first_fn_done = threading.Event() 157 second_fn_done = threading.Event() 158 coord = coordinator.Coordinator(clean_stop_exception_types=[]) 159 160 def wrapped_first_fn(): 161 with coord.stop_on_exception(): 162 self.assertFalse(second_fn_done.is_set()) 163 first_fn() 164 first_fn_done.set() 165 166 self.assertFalse(first_fn_done.is_set()) 167 t = threading.Thread(target=wrapped_first_fn) 168 t.start() 169 170 second_fn() 171 self.assertTrue(first_fn_done.is_set()) 172 second_fn_done.set() 173 174 coord.join([t]) 175 176 def testWaitRaiseErrorAfterMarkFailure(self): 177 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 178 # TODO(b/165013260): Fix this 179 self.skipTest('Test is currently broken on Windows with Python 3.8') 180 181 closure_queue = coordinator_lib._CoordinatedClosureQueue() 182 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 183 closure = closure_queue.get() 184 185 wait_finish_event = threading.Event() 186 coord = coordinator.Coordinator(clean_stop_exception_types=[]) 187 188 # Using a thread to verify that closure_queue.wait() will not return until 189 # all inflight closures are finished. 190 191 def mark_finished_fn(): 192 try: 193 raise ValueError('Some error.') 194 except ValueError as e: 195 closure_queue.mark_failed(e) 196 197 def wait_fn(): 198 with self.assertRaises(ValueError): 199 closure_queue.wait() 200 201 self._assert_one_unblock_the_other(mark_finished_fn, wait_fn) 202 203 self.assertTrue(closure_queue.done()) 204 205 def _create_closure(self, cancellation_mgr): 206 207 @def_function.function() 208 def some_function(): 209 return 1.0 210 211 return ClosureWithOutput(some_function, cancellation_mgr) 212 213 def _put_two_closures_and_get_one(self): 214 closure_queue = coordinator_lib._CoordinatedClosureQueue() 215 closure1 = self._create_closure(closure_queue._cancellation_mgr) 216 closure_queue.put(closure1) 217 218 closure2 = self._create_closure(closure_queue._cancellation_mgr) 219 closure_queue.put(closure2) 220 221 closure_got = closure_queue.get() # returns closure1 222 self.assertIs(closure_got, closure1) 223 self.assertIsNot(closure_got, closure2) 224 return closure_queue, closure1, closure2 225 226 def testPutRaiseError(self): 227 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 228 # TODO(b/165013260): Fix this 229 self.skipTest('Test is currently broken on Windows with Python 3.8') 230 231 closure_queue, _, closure2 = self._put_two_closures_and_get_one() 232 233 closure_queue.mark_failed(ValueError()) 234 235 with self.assertRaises(ValueError): 236 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 237 238 self.assertTrue(closure_queue.done()) 239 240 with self.assertRaisesRegex( 241 errors.CancelledError, 242 'The corresponding function is cancelled. Please reschedule the ' 243 'function.'): 244 closure2.output_remote_value.fetch() 245 246 # The error is cleared. 247 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 248 249 def testWaitRaiseError(self): 250 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 251 # TODO(b/165013260): Fix this 252 self.skipTest('Test is currently broken on Windows with Python 3.8') 253 254 closure_queue, _, closure2 = self._put_two_closures_and_get_one() 255 256 closure_queue.mark_failed(ValueError()) 257 258 with self.assertRaises(ValueError): 259 closure_queue.wait() 260 self.assertTrue(closure_queue.done()) 261 262 with self.assertRaisesRegex( 263 errors.CancelledError, 264 'The corresponding function is cancelled. Please reschedule the ' 265 'function.'): 266 closure2.output_remote_value.fetch() 267 268 # The error is cleared. 269 closure_queue.wait() 270 271 def testDoneRaiseError(self): 272 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 273 # TODO(b/165013260): Fix this 274 self.skipTest('Test is currently broken on Windows with Python 3.8') 275 276 closure_queue, _, _ = self._put_two_closures_and_get_one() 277 278 self.assertFalse(closure_queue.done()) 279 closure_queue.mark_failed(ValueError()) 280 with self.assertRaises(ValueError): 281 closure_queue.done() 282 283 def _set_error(self, closure_queue, closure, error): 284 try: 285 raise error 286 except Exception as e: # pylint: disable=broad-except 287 closure.output_remote_value._set_error(e) 288 closure_queue.mark_failed(e) 289 290 def _test_cancel_closure_when_error(self, call_wait): 291 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 292 # TODO(b/165013260): Fix this 293 self.skipTest('Test is currently broken on Windows with Python 3.8') 294 295 closure_queue, closure1, closure2 = self._put_two_closures_and_get_one() 296 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 297 closure_queue.get() 298 # At this moment, there are two inflight, one in queue. 299 self.assertEqual(closure_queue._inflight_closure_count, 2) 300 301 # Hold a copy of the queue's cancellation manager at this point 302 initial_cm = closure_queue._cancellation_mgr 303 304 # Simulating closure1 fails. 305 self._set_error(closure_queue, closure1, ValueError('Some error.')) 306 307 # At this moment, there are one inflight, one in queue. 308 self.assertEqual(closure_queue._queue.qsize(), 1) 309 self.assertEqual(closure_queue._inflight_closure_count, 1) 310 311 closure3 = self._create_closure(closure_queue._cancellation_mgr) 312 313 def fake_cancellation(): 314 self._set_error(closure_queue, closure2, 315 ValueError('Fake cancellation error.')) 316 317 def report_error(): 318 # It should not report the fake cancellation error. 319 with self.assertRaisesRegex(ValueError, 'Some error.'): 320 # Verifying `wait()` or `put()` raises even if one closure is in 321 # flight. 322 if call_wait: 323 closure_queue.wait() 324 else: 325 closure_queue.put(closure3) 326 327 self._assert_one_unblock_the_other(fake_cancellation, report_error) 328 329 # The original cancellation manager of the queue has been cancelled. 330 self.assertTrue(initial_cm.is_cancelled) 331 332 # At this moment, there is zero inflight, nothing in queue. 333 self.assertTrue(closure_queue._queue.empty()) 334 self.assertEqual(closure_queue._inflight_closure_count, 0) 335 self.assertIsNone(closure_queue._error) 336 337 # This asserts that closure1 has errored. 338 with self.assertRaisesRegex(ValueError, 'Some error.'): 339 closure1.output_remote_value.fetch() 340 341 # The following asserts that closure3 should have been cancelled. 342 if not call_wait: 343 with self.assertRaisesRegex( 344 errors.CancelledError, 345 'The corresponding function is cancelled. Please reschedule the ' 346 'function.'): 347 closure3.output_remote_value.fetch() 348 349 # Closure2 was an inflight closure when it got cancelled. 350 self.assertEqual(closure2.output_remote_value._status, 351 values_lib.RemoteValueStatus.READY) 352 with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'): 353 closure2.output_remote_value.fetch() 354 355 # This asserts that the queue has a clear state. 356 self.testBasic() 357 358 def testWaitRaiseErrorAfterCancelClosure(self): 359 self._test_cancel_closure_when_error(call_wait=True) 360 361 def testPutRaiseErrorAfterCancelClosure(self): 362 self._test_cancel_closure_when_error(call_wait=False) 363 364 def testStateIsRestoredAfterJoinIsCalled(self): 365 if sys.version_info >= (3, 8) and platform.system() == 'Windows': 366 # TODO(b/165013260): Fix this 367 self.skipTest('Test is currently broken on Windows with Python 3.8') 368 369 closure_queue, _, _ = self._put_two_closures_and_get_one() 370 self.assertEqual(closure_queue._inflight_closure_count, 1) 371 closure_queue.mark_failed(ValueError('test error')) 372 with self.assertRaises(ValueError): 373 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 374 375 # Its error should have been cleared. 376 self.assertIsNone(closure_queue._error) 377 closure_queue.put(self._create_closure(closure_queue._cancellation_mgr)) 378 self.assertIsNone(closure_queue._error) 379 380 def testThreadSafey(self): 381 thread_count = 10 382 queue = coordinator_lib._CoordinatedClosureQueue() 383 384 # Each thread performs 20 queue actions: 10 are `put_back` and 10 are 385 # `mark_finished`. 386 action_count = 20 387 388 def func(): 389 for i in range(action_count): 390 closure = queue.get() 391 if i % 2 == 0: 392 queue.put_back(closure) 393 else: 394 queue.mark_finished() 395 396 threads = [threading.Thread(target=func) for i in range(thread_count)] 397 for t in threads: 398 t.start() 399 400 for _ in range(thread_count * action_count // 2): 401 queue.put(self._create_closure(queue._cancellation_mgr)) 402 queue.wait() 403 self.assertTrue(queue.done()) 404 405 406class ErrorReportingThread(threading.Thread): 407 408 error = None 409 410 def __init__(self, *args, **kwargs): 411 assert 'target' in kwargs 412 target = kwargs['target'] 413 414 @functools.wraps(target) 415 def wrapped_target(*args, **kwargs): 416 try: 417 return target(*args, **kwargs) 418 except Exception as e: # pylint: disable=broad-except 419 traceback.print_exception(*sys.exc_info()) 420 ErrorReportingThread.error = e 421 422 kwargs['target'] = wrapped_target 423 super(ErrorReportingThread, self).__init__(*args, **kwargs) 424 425 426class TestCaseWithErrorReportingThread(test.TestCase): 427 428 @classmethod 429 def setUpClass(cls): 430 cls._threading_thread = threading.Thread 431 threading.Thread = ErrorReportingThread 432 super(TestCaseWithErrorReportingThread, cls).setUpClass() 433 434 @classmethod 435 def tearDownClass(cls): 436 super(TestCaseWithErrorReportingThread, cls).tearDownClass() 437 threading.Thread = cls._threading_thread 438 439 def setUp(self): 440 ErrorReportingThread.error = None 441 super(TestCaseWithErrorReportingThread, self).setUp() 442 443 def tearDown(self): 444 super(TestCaseWithErrorReportingThread, self).tearDown() 445 if ErrorReportingThread.error: 446 raise ErrorReportingThread.error # pylint: disable=raising-bad-type 447 448 449def make_coordinator(num_workers, num_ps): 450 # TODO(rchao): Test the internal rpc_layer version. 451 cluster_def = multi_worker_test_base.create_in_process_cluster( 452 num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc') 453 cluster_def['chief'] = [ 454 'localhost:%d' % multi_worker_test_base.pick_unused_port() 455 ] 456 cluster_resolver = SimpleClusterResolver( 457 ClusterSpec(cluster_def), rpc_layer='grpc') 458 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 459 cluster_resolver) 460 return coordinator_lib.ClusterCoordinator(strategy) 461 462 463class ClusterCoordinatorTest(TestCaseWithErrorReportingThread, 464 parameterized.TestCase): 465 466 @classmethod 467 def setUpClass(cls): 468 super(ClusterCoordinatorTest, cls).setUpClass() 469 cls.coordinator = make_coordinator(num_workers=5, num_ps=2) 470 cls.strategy = cls.coordinator.strategy 471 472 def testClusterCoordinatorOnlyInitOnce(self): 473 cluster = self.coordinator._cluster 474 same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy) 475 self.assertIs(self.coordinator, same_coordinator) 476 self.assertIs(cluster, same_coordinator._cluster) 477 478 def testFnReturnNestedValues(self): 479 x = constant_op.constant(1) 480 481 @def_function.function 482 def f(): 483 return x + 1, (x + 2, x + 3), [x + 4], {'v': x} 484 485 got = self.coordinator.schedule(f) 486 want = 2, (3, 4), [5], {'v': 1} 487 self.assertEqual(got.fetch(), want) 488 self.assertEqual(self.coordinator.fetch(got), want) 489 490 def testFetchingRemoteValueStructure(self): 491 x = constant_op.constant(1) 492 493 @def_function.function 494 def f(): 495 return x + 1, (x + 2, x + 3), [x + 4], {'v': x} 496 497 want = 2, (3, 4), [5], {'v': 1} 498 remote_value_list = [self.coordinator.schedule(f) for _ in range(5)] 499 self.assertAllEqual( 500 self.coordinator.fetch(remote_value_list), [want for _ in range(5)]) 501 502 def testInputFunction(self): 503 504 def input_fn(): 505 return dataset_ops.DatasetV2.range(1, 2) 506 507 with self.strategy.scope(): 508 v = variables.Variable(initial_value=0, dtype=dtypes.int64) 509 510 @def_function.function 511 def worker_fn(iterator): 512 x = next(iterator) 513 v.assign_add(x) 514 return x 515 516 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 517 result = self.coordinator.schedule( 518 worker_fn, args=(iter(distributed_dataset),)) 519 result = self.coordinator.fetch(result) 520 self.assertEqual(result, (1,)) 521 result = self.coordinator.schedule( 522 worker_fn, args=(iter(distributed_dataset),)) 523 result = self.coordinator.fetch(result) 524 525 self.assertEqual(result, (1,)) 526 self.assertAlmostEqual(v.read_value(), 2, delta=1e-6) 527 528 def testAsyncScheduleAndJoin(self): 529 if test_util.is_xla_enabled(): 530 self.skipTest('Assign_add is not deterministic across threads in XLA') 531 532 def input_fn(): 533 return dataset_ops.DatasetV2.from_tensor_slices([2] * 10) 534 535 with self.strategy.scope(): 536 v = variables.Variable(initial_value=0, dtype=dtypes.int32) 537 538 # TODO(yuefengz): the following tf.function has a return value which is None 539 # in its structured_outputs. 540 @def_function.function 541 def worker_fn(iterator): 542 x = next(iterator) 543 v.assign_add(x) 544 545 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 546 547 iterator = iter(distributed_dataset) 548 549 # Verifying joining without any scheduling doesn't hang. 550 self.coordinator.join() 551 self.assertEqual(v.read_value().numpy(), 0) 552 553 for _ in range(5): 554 self.coordinator.schedule(worker_fn, args=(iterator,)) 555 self.coordinator.join() 556 557 # With 5 addition it should be 2*5 = 10. 558 self.assertEqual(v.read_value().numpy(), 10) 559 560 for _ in range(5): 561 self.coordinator.schedule(worker_fn, args=(iterator,)) 562 563 # Verifying multiple join is fine. 564 self.coordinator.join() 565 self.coordinator.join() 566 self.coordinator.join() 567 568 self.assertTrue(self.coordinator.done()) 569 570 # Likewise, it's now 20. 571 self.assertEqual(v.read_value().numpy(), 20.) 572 573 @parameterized.parameters(True, False) 574 def testInputFunctionWithMap(self, use_input_fn): 575 self._map_fn_tracing_count = 0 576 577 def input_fn(): 578 579 def map_fn(x): 580 self._map_fn_tracing_count += 1 581 return x + 10 582 583 return dataset_ops.DatasetV2.range(0, 10).map(map_fn) 584 585 @def_function.function 586 def worker_fn(iterator): 587 return next(iterator) 588 589 if use_input_fn: 590 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 591 else: 592 distributed_dataset = self.coordinator.create_per_worker_dataset( 593 input_fn()) 594 595 result = self.coordinator.schedule( 596 worker_fn, args=(iter(distributed_dataset),)) 597 self.assertEqual(result.fetch(), (10,)) 598 self.assertEqual(self._map_fn_tracing_count, 1) 599 600 def testInputFunctionCreateVariables(self): 601 602 def input_fn(): 603 v = variables.Variable(initial_value=0.0) 604 return v.read_value() 605 606 with self.assertRaises(ValueError): 607 self.coordinator.create_per_worker_dataset(input_fn) 608 609 @parameterized.parameters(True, False) 610 def testDatasetsShuffledDifferently(self, use_input_fn): 611 # This test requires at least two workers in the cluster. 612 self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2) 613 614 random_seed.set_random_seed(None) 615 616 def input_fn(): 617 dataset = dataset_ops.DatasetV2.range(0, 100).shuffle(100).batch(1) 618 return self.strategy.experimental_distribute_dataset(dataset) 619 620 if use_input_fn: 621 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 622 else: 623 distributed_dataset = self.coordinator.create_per_worker_dataset( 624 input_fn()) 625 distributed_iterator = iter(distributed_dataset) 626 # Get elements from the first two iterators. 627 iterator_1 = distributed_iterator._values[0] 628 iterator_1._rebuild_on(self.coordinator._cluster.workers[0]) 629 iterator_1 = iterator_1.fetch() 630 elements_in_iterator_1 = [ 631 self.strategy.experimental_local_results(e) 632 for e in iterator_1 633 ] 634 iterator_2 = distributed_iterator._values[1] 635 iterator_2._rebuild_on(self.coordinator._cluster.workers[1]) 636 iterator_2 = iterator_2.fetch() 637 elements_in_iterator_2 = [ 638 self.strategy.experimental_local_results(e) 639 for e in iterator_2 640 ] 641 642 self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2) 643 644 def testPerWorkerValue(self): 645 self.skipTest('b/168569314') 646 var_shape = tuple() 647 var_dtype = dtypes.float32 648 var_name = 'var' 649 650 def create_var(): 651 var = variables.Variable( 652 initial_value=0.0, dtype=var_dtype, name=var_name) 653 self.assertIn('worker', var.device) 654 return var 655 656 worker_local_var = self.coordinator._create_per_worker_resources(create_var) 657 658 # The following is a workaround to allow `worker_local_var` to be passed in 659 # as args to the `coordinator.schedule` method which requires tensor specs 660 # to trace tf.function but _create_worker_resources' return values don't 661 # have tensor specs. We can get rid of this workaround once 662 # _create_worker_resources is able to infer the tensor spec of the return 663 # value of the function passed in. See b/154675763. 664 for var in worker_local_var._values: 665 var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name) 666 667 def worker_fn(var): 668 var.assign_add(1.0) 669 670 for _ in range(10): 671 # Which slice of `worker_local_var` will be used will depend on which 672 # worker the `worker_fn` gets scheduled on. 673 self.coordinator.schedule(worker_fn, args=(worker_local_var,)) 674 self.coordinator.join() 675 676 var_sum = sum(self.coordinator.fetch(worker_local_var._values)) 677 self.assertEqual(var_sum, 10.0) 678 679 def testDisallowRemoteValueAsInput(self): 680 681 @def_function.function 682 def func_0(): 683 return 1.0 684 685 @def_function.function 686 def func_1(x): 687 return x + 1.0 688 689 remote_v = self.coordinator.schedule(func_0) 690 with self.assertRaises(ValueError): 691 self.coordinator.schedule(func_1, args=(remote_v,)) 692 693 def testPythonFunctionNotAllowedToSchedule(self): 694 695 def func(a): 696 return array_ops.identity(a) 697 698 with self.assertRaisesRegexp( 699 TypeError, 700 '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` ' 701 'only accepts a `tf.function` or a concrete function.'): 702 self.coordinator.schedule(func, args=(1,)) 703 704 def testDatasetPartiallyCreatedOnCoordinator(self): 705 dataset = dataset_ops.DatasetV2.range(1, 10) 706 707 @def_function.function 708 def input_fn(): 709 return dataset.shuffle(9) 710 711 @def_function.function 712 def worker_fn(iterator): 713 x = next(iterator) 714 return x 715 716 per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn) 717 self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),)) 718 719 with self.assertRaisesRegexp( 720 coordinator_lib.InputError, 721 'error message is Failed copying input tensor from'): 722 self.coordinator.join() 723 724 def testPassDatasetToCreatePerWorkerDataset(self): 725 dataset = dataset_ops.DatasetV2.range(1, 11).batch(4) 726 727 @def_function.function 728 def worker_fn(iterator): 729 return next(iterator) 730 731 per_worker_dataset = self.coordinator.create_per_worker_dataset(dataset) 732 result = self.coordinator.schedule( 733 worker_fn, args=(iter(per_worker_dataset),)) 734 result = result.fetch() 735 expected_result = math_ops.range(1., 5.) 736 737 self.assertAllEqual(result, (expected_result)) 738 739 def testMultipleDatasets(self): 740 741 def input_fn1(): 742 return dataset_ops.DatasetV2.range(0, 5) 743 744 def input_fn2(): 745 return dataset_ops.DatasetV2.range(5, 10) 746 747 per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn1) 748 per_worker_iterator1 = iter(per_worker_dataset1) 749 per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn2) 750 per_worker_iterator2 = iter(per_worker_dataset2) 751 752 @def_function.function 753 def worker_fn(iterator1, iterator2): 754 return next(iterator1) + next(iterator2) 755 756 result = self.coordinator.schedule( 757 worker_fn, args=(per_worker_iterator1, per_worker_iterator2)) 758 self.assertEqual(result.fetch(), 5.0) 759 760 per_worker_dataset3 = self.coordinator.create_per_worker_dataset(input_fn1) 761 per_worker_iterator3 = iter(per_worker_dataset3) 762 763 result = self.coordinator.schedule( 764 worker_fn, args=(per_worker_iterator3, per_worker_iterator2)) 765 self.assertGreaterEqual(result.fetch(), 5.0) 766 767 def testRepeatedIteratorCreation(self): 768 769 def input_fn(): 770 return dataset_ops.DatasetV2.range(1, 100) 771 772 per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn) 773 per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn) 774 775 @def_function.function 776 def worker_fn(iterator1, iterator2): 777 return next(iterator1) + next(iterator2) 778 779 for _ in range(10): 780 per_worker_iterator1 = iter(per_worker_dataset1) 781 per_worker_iterator2 = iter(per_worker_dataset2) 782 result = self.coordinator.schedule( 783 worker_fn, args=(per_worker_iterator1, per_worker_iterator2)) 784 for _ in range(10): 785 self.coordinator.schedule( 786 worker_fn, args=(per_worker_iterator1, per_worker_iterator2)) 787 self.coordinator.join() 788 self.assertGreaterEqual(result.fetch(), 2.0) 789 del per_worker_iterator1, per_worker_iterator2 790 gc.collect() 791 792 # There shouldn't be any live iterator objects. 793 for w in self.coordinator._cluster.workers: 794 for r in w._resource_remote_value_refs: 795 self.assertIsNone(r()) 796 797 798class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest): 799 """Test basic functionality works with explicit maximum closure queue size. 800 801 Execute the same set of test cases as in `ClusterCoordinatorTest`, with an 802 explicit size limit for the closure queue. Note that even when the queue size 803 is set to infinite, there is still a maximum practical size (depends on host 804 memory limit) that might cause the queue.put operations to be blocking when 805 scheduling a large number of closures on a big cluster. These tests make sure 806 that the coordinator does not run into deadlocks in such scenario. 807 """ 808 809 @classmethod 810 def setUpClass(cls): 811 super(LimitedClosureQueueSizeBasicTest, cls).setUpClass() 812 coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2 813 cls.coordinator = make_coordinator(num_workers=5, num_ps=2) 814 cls.strategy = cls.coordinator.strategy 815 816 817class ScheduleStartDelayTest(ClusterCoordinatorTest): 818 """Test basic functionality works with worker scheduling delay. 819 820 This is basically to make sure that setting environment variables 821 `TF_COORDINATOR_SCHEDULE_START_DELAY` and 822 `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` will cause any failure. 823 """ 824 825 @classmethod 826 def setUpClass(cls): 827 super(ScheduleStartDelayTest, cls).setUpClass() 828 os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2' 829 os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4' 830 cls.coordinator = make_coordinator(num_workers=3, num_ps=2) 831 cls.strategy = cls.coordinator.strategy 832 833 @classmethod 834 def tearDownClass(cls): 835 del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] 836 del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] 837 super(ScheduleStartDelayTest, cls).tearDownClass() 838 839 840class ErrorReportingTest(TestCaseWithErrorReportingThread): 841 842 @classmethod 843 def setUpClass(cls): 844 super(ErrorReportingTest, cls).setUpClass() 845 cls.coordinator = make_coordinator(num_workers=3, num_ps=2) 846 cls.strategy = cls.coordinator.strategy 847 848 with cls.strategy.scope(): 849 cls.iteration = variables.Variable(initial_value=0.0) 850 851 @def_function.function 852 def _normal_function(self): 853 x = random_ops.random_uniform((2, 10)) 854 y = random_ops.random_uniform((10, 2)) 855 self.iteration.assign_add(1.0) 856 return math_ops.reduce_mean(math_ops.matmul(x, y)) 857 858 @def_function.function 859 def _error_function(self): 860 x = random_ops.random_uniform((2, 10)) 861 y = random_ops.random_uniform((10, 2)) 862 check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y))) 863 self.iteration.assign_add(1.0) 864 return self.iteration 865 866 @def_function.function 867 def _long_function(self): 868 x = random_ops.random_uniform((1000, 1000)) 869 for _ in math_ops.range(10000): 870 a = random_ops.random_uniform((1000, 1000)) 871 b = random_ops.random_uniform((1000, 1000)) 872 x += math_ops.matmul(a, b) 873 return x 874 875 def testJoinRaiseError(self): 876 for _ in range(3): 877 self.coordinator.schedule(self._normal_function) 878 self.coordinator.schedule(self._error_function) 879 with self.assertRaises(errors.InvalidArgumentError): 880 self.coordinator.join() 881 882 def testScheduleRaiseError(self): 883 for _ in range(3): 884 self.coordinator.schedule(self._normal_function) 885 self.coordinator.schedule(self._error_function) 886 with self.assertRaises(errors.InvalidArgumentError): 887 while True: 888 self.coordinator.schedule(self._normal_function) 889 890 def testScheduleRaiseErrorWithMultipleFailure(self): 891 for _ in range(3): 892 self.coordinator.schedule(self._normal_function) 893 self.coordinator.schedule(self._error_function) 894 with self.assertRaises(errors.InvalidArgumentError): 895 while True: 896 self.coordinator.schedule(self._error_function) 897 self.coordinator.join() 898 899 def testErrorWillbeCleared(self): 900 self.coordinator.schedule(self._error_function) 901 with self.assertRaises(errors.InvalidArgumentError): 902 self.coordinator.join() 903 904 for _ in range(3): 905 self.coordinator.schedule(self._normal_function) 906 self.coordinator.schedule(self._error_function) 907 with self.assertRaises(errors.InvalidArgumentError): 908 self.coordinator.join() 909 910 def testRemoteValueReturnError(self): 911 result = self.coordinator.schedule(self._error_function) 912 913 with self.assertRaises(errors.InvalidArgumentError): 914 result.fetch() 915 916 # Clear the error. 917 with self.assertRaises(errors.InvalidArgumentError): 918 self.coordinator.join() 919 920 def testInputError(self): 921 922 worker_local_val = self.coordinator._create_per_worker_resources( 923 self._error_function) 924 925 @def_function.function 926 def func(x): 927 return x + 1 928 929 result = self.coordinator.schedule(func, args=(worker_local_val,)) 930 with self.assertRaises(coordinator_lib.InputError): 931 self.coordinator.join() 932 933 with self.assertRaises(coordinator_lib.InputError): 934 result.fetch() 935 936 def testCancellation(self): 937 for _ in range(3): 938 self.coordinator.schedule(self._normal_function) 939 long_function = self.coordinator.schedule(self._long_function) 940 self.coordinator.schedule(self._error_function) 941 942 with self.assertRaises(errors.InvalidArgumentError): 943 self.coordinator.join() 944 945 with self.assertRaises(errors.CancelledError): 946 long_function.fetch() 947 948 for _ in range(3): 949 self.coordinator.schedule(self._normal_function) 950 self.coordinator.join() 951 952 953class LimitedClosureQueueErrorTest(ErrorReportingTest): 954 """Test error reporting works with explicit maximum closure queue size. 955 956 Execute the same set of test cases as in ErrorReportingTest, with an explicit 957 size limit for the closure queue. 958 """ 959 960 @classmethod 961 def setUpClass(cls): 962 super(LimitedClosureQueueErrorTest, cls).setUpClass() 963 coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2 964 cls.coordinator = make_coordinator(num_workers=3, num_ps=2) 965 cls.strategy = cls.coordinator.strategy 966 967 with cls.coordinator.strategy.scope(): 968 cls.iteration = variables.Variable(initial_value=0.0) 969 970 971class StrategyIntegrationTest(test.TestCase, parameterized.TestCase): 972 973 @classmethod 974 def setUpClass(cls): 975 super(StrategyIntegrationTest, cls).setUpClass() 976 cls.coordinator = make_coordinator(num_workers=1, num_ps=1) 977 cls.strategy = cls.coordinator.strategy 978 979 def testRunNotUsedWithClusterCoordinatorSchedule(self): 980 981 @def_function.function 982 def input_fn(): 983 return dataset_ops.DatasetV2.range(1, 3) 984 985 with self.strategy.scope(): 986 v = variables.Variable(initial_value=1, dtype=dtypes.int64) 987 988 def replica_fn(input_tensor): 989 return input_tensor + v, input_tensor - v 990 991 @def_function.function 992 def worker_fn(iterator): 993 return self.strategy.run(replica_fn, args=(next(iterator),)) 994 995 per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn) 996 997 @contextlib.contextmanager 998 def _assert_logs_usage_warning(): 999 with self.assertLogs(level='WARNING') as logs: 1000 yield 1001 1002 self.assertIn( 1003 'It is detected that a function used with ' 1004 '`tf.distribute.experimental.ParameterServerStrategy` ' 1005 'is executed locally on the coordinator. This is inefficient but may ' 1006 'be valid for one-off tasks such as inferring output signature. ' 1007 'To properly distribute functions to run on workers, `run` or ' 1008 '`reduce` should be used within a function passed to `' 1009 'tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`' 1010 '.', 1011 logs.output[0]) 1012 1013 with _assert_logs_usage_warning(): 1014 # Invoking `run` without `coordinator.schedule` should result in a 1015 # warning. 1016 self.strategy.run( 1017 replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),)) 1018 1019 # A proper `schedule` should succeed. 1020 rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),)) 1021 1022 with _assert_logs_usage_warning(): 1023 # Invoking `run` without `coordinator.schedule` again should result in a 1024 # warning. 1025 self.strategy.run( 1026 replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),)) 1027 1028 all_results = [(2, 0)] * self.strategy.num_replicas_in_sync 1029 expected_result = [] 1030 for i in range(self.strategy.num_replicas_in_sync): 1031 expected_result.append(all_results[i]) 1032 1033 self.assertAllEqual( 1034 tuple(expected_result), 1035 self.strategy.experimental_local_results(rv.fetch())) 1036 1037 def testBasicVariableAssignment(self): 1038 self.strategy.extended._variable_count = 0 1039 with self.strategy.scope(): 1040 v1 = variables.Variable(initial_value=0.0) 1041 v2 = variables.Variable(initial_value=1.0) 1042 self.assertEqual(self.strategy.extended._variable_count, 2) 1043 1044 @def_function.function 1045 def worker_fn(): 1046 v1.assign_add(0.1) 1047 v2.assign_sub(0.2) 1048 return v1.read_value() / v2.read_value() 1049 1050 results = self.coordinator.schedule(worker_fn) 1051 logging.info('Results of experimental_run_v2: %f', 1052 self.coordinator.fetch(results)) 1053 1054 self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6) 1055 self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6) 1056 1057 def testRunAndReduce(self): 1058 self.assertFalse(distribution_strategy_context.in_cross_replica_context()) 1059 with self.strategy.scope(): 1060 self.assertTrue(distribution_strategy_context.in_cross_replica_context()) 1061 v = variables.Variable(initial_value=1.) 1062 1063 expected_result = (4. * self.strategy.num_replicas_in_sync, 1064 2. * self.strategy.num_replicas_in_sync) 1065 1066 @def_function.function 1067 def worker_fn(input_tensor): 1068 1069 def replica_fn(input_tensor): 1070 # Within `replica_fn`, it has to be in a replica context. 1071 self.assertFalse( 1072 distribution_strategy_context.in_cross_replica_context()) 1073 return input_tensor + v, input_tensor - v 1074 1075 run_result = self.strategy.run(replica_fn, args=(input_tensor,)) 1076 reduced_result = self.strategy.reduce('SUM', run_result, axis=None) 1077 check_ops.assert_equal_v2(reduced_result, expected_result) 1078 return reduced_result 1079 1080 # Asserting scheduling in scope has the expected behavior. 1081 result = self.coordinator.schedule( 1082 worker_fn, args=(constant_op.constant(3.),)) 1083 self.assertIsInstance(result, coordinator_lib.RemoteValue) 1084 self.assertEqual(result.fetch(), expected_result) 1085 1086 # Asserting scheduling out of scope has the expected behavior. 1087 result = self.coordinator.schedule( 1088 worker_fn, args=(constant_op.constant(3.),)) 1089 self.assertEqual(result.fetch(), expected_result) 1090 1091 def testRunAndReduceWithAssignAdd(self): 1092 self.assertFalse(distribution_strategy_context.in_cross_replica_context()) 1093 with self.strategy.scope(): 1094 self.assertTrue(distribution_strategy_context.in_cross_replica_context()) 1095 v = variables.Variable(initial_value=1.) 1096 v1 = variables.Variable( 1097 initial_value=0., 1098 aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA) 1099 1100 expected_result = (4. * self.strategy.num_replicas_in_sync, 1101 2. * self.strategy.num_replicas_in_sync) 1102 1103 @def_function.function 1104 def worker_fn(input_tensor): 1105 1106 def replica_fn(input_tensor): 1107 # Within `replica_fn`, it has to be in a replica context. 1108 self.assertFalse( 1109 distribution_strategy_context.in_cross_replica_context()) 1110 1111 v1.assign_add(input_tensor) 1112 return input_tensor + v, input_tensor - v 1113 1114 run_result = self.strategy.run(replica_fn, args=(input_tensor,)) 1115 reduced_result = self.strategy.reduce('SUM', run_result, axis=None) 1116 check_ops.assert_equal_v2(reduced_result, expected_result) 1117 return reduced_result 1118 1119 # Asserting scheduling in scope has the expected behavior. 1120 result = self.coordinator.schedule( 1121 worker_fn, args=(constant_op.constant(3.),)) 1122 self.assertIsInstance(result, coordinator_lib.RemoteValue) 1123 self.assertEqual(result.fetch(), expected_result) 1124 1125 # Asserting scheduling out of scope has the expected behavior. 1126 result = self.coordinator.schedule( 1127 worker_fn, args=(constant_op.constant(3.),)) 1128 self.assertEqual(result.fetch(), expected_result) 1129 self.assertEqual(v1, 6.) 1130 1131 def testVariableAggregation(self): 1132 self.assertFalse(distribution_strategy_context.in_cross_replica_context()) 1133 with self.strategy.scope(): 1134 self.assertTrue(distribution_strategy_context.in_cross_replica_context()) 1135 v = variables.Variable( 1136 initial_value=1., 1137 aggregation=variable_scope.VariableAggregation.SUM) 1138 1139 @def_function.function 1140 def worker_fn(): 1141 1142 def replica_fn(): 1143 value = math_ops.cast( 1144 distribution_strategy_context.get_replica_context() 1145 .replica_id_in_sync_group + 1, v.dtype) 1146 v.assign(value) 1147 1148 self.strategy.run(replica_fn) 1149 1150 self.coordinator.schedule(worker_fn) 1151 self.coordinator.join() 1152 expected_result = 0. 1153 for i in range(self.strategy.num_replicas_in_sync): 1154 expected_result = expected_result + i + 1 1155 self.assertEqual(v, expected_result) 1156 1157 def testVariableCaching(self): 1158 self.assertFalse(distribution_strategy_context.in_cross_replica_context()) 1159 with self.strategy.scope(): 1160 self.assertTrue(distribution_strategy_context.in_cross_replica_context()) 1161 v = variables.Variable( 1162 initial_value=1., 1163 aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA) 1164 1165 # Test read value inside caching scope 1166 with distribute_utils.cache_variable_reads(): 1167 v.read_value() # Reads value 1.0 1168 v.assign(constant_op.constant(5.0)) # v changes to 5.0 1169 self.assertEqual(v.read_value(), 1.0) # should be cached 1.0 value. 1170 1171 # Reset v to 2.0 1172 v.assign(2.0) 1173 1174 # Test convert to tensor value inside caching scope 1175 with distribute_utils.cache_variable_reads(): 1176 t = v * 3.0 1177 self.assertEqual(t, 6.0) 1178 v.assign(3.0) 1179 t1 = v * 3.0 1180 self.assertEqual(t1, 6.0) # should be cached 2.0 * 3.0 value. 1181 1182 # Reset v to 1.0 1183 v.assign(1.0) 1184 1185 # Verify caching scope inside tf.function 1186 @def_function.function 1187 def worker_fn(): 1188 with distribute_utils.cache_variable_reads(): 1189 def replica_fn(): 1190 t = v.read_value() # Reads value 1.0 1191 v.assign(constant_op.constant(5.0)) # v changes to 5.0 1192 t = v.read_value() # should return 1.0 1193 return t # Should be 1.0 instead of 5.0 1194 1195 return self.strategy.run(replica_fn) 1196 1197 result = self.coordinator.schedule(worker_fn) 1198 result = result.fetch() 1199 expected_result = 1. 1200 self.assertEqual(result, expected_result) 1201 1202 # Verify that v.read_value works as expected outside of scope. 1203 v.assign(4.0) 1204 self.assertEqual(v.read_value(), 4.0) 1205 1206 v.assign(constant_op.constant(2.0)) # v changes to 2.0 1207 # Check with scope outside of tf function and check that cache is reset 1208 @def_function.function 1209 def worker_fn1(): 1210 def replica_fn(): 1211 t = v.read_value() # Reads value 2.0 ==> Should be cached 1212 v.assign(constant_op.constant(5.0)) # v changes to 5.0 1213 t = v.read_value() # should return cached value 2.0 1214 return t # Should be 2.0 instead of 5.0 1215 1216 return self.strategy.run(replica_fn) 1217 1218 with distribute_utils.cache_variable_reads(): 1219 result = self.coordinator.schedule(worker_fn1) 1220 result = result.fetch() 1221 expected_result = 2. 1222 self.assertEqual(result, expected_result) 1223 1224 # Verify scope nesting is not permitted. 1225 with self.assertRaises(ValueError): 1226 with distribute_utils.cache_variable_reads(): 1227 with distribute_utils.cache_variable_reads(): 1228 v.read_value() 1229 1230 @parameterized.parameters(True, False) 1231 def testDistributedDatasetInsidePerWorkerDatasetFn(self, from_function): 1232 if from_function: 1233 1234 def per_worker_dataset_fn(): 1235 dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4) 1236 return self.strategy.distribute_datasets_from_function(dataset_fn) 1237 else: 1238 1239 def per_worker_dataset_fn(): 1240 dataset = dataset_ops.DatasetV2.range(1, 11).batch(4) 1241 return self.strategy.experimental_distribute_dataset(dataset) 1242 1243 @def_function.function 1244 def worker_fn(iterator): 1245 return self.strategy.experimental_local_results(next(iterator)) 1246 1247 per_worker_dataset = self.coordinator.create_per_worker_dataset( 1248 per_worker_dataset_fn) 1249 result = self.coordinator.schedule( 1250 worker_fn, args=(iter(per_worker_dataset),)) 1251 result = result.fetch() 1252 expected_result = array_ops.split( 1253 math_ops.range(1., 5.), 1254 num_or_size_splits=self.strategy.num_replicas_in_sync, 1255 axis=0) 1256 1257 self.assertAllEqual(result, (expected_result)) 1258 1259 @parameterized.parameters(True, False) 1260 def testPassDistributedDatasetToCreatePerWorkerDataset(self, from_function): 1261 if from_function: 1262 dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4) 1263 distributed_dataset = self.strategy.distribute_datasets_from_function( 1264 dataset_fn) 1265 else: 1266 dataset = dataset_ops.DatasetV2.range(1, 11).batch(4) 1267 distributed_dataset = self.strategy.experimental_distribute_dataset( 1268 dataset) 1269 1270 @def_function.function 1271 def worker_fn(iterator): 1272 return self.strategy.experimental_local_results(next(iterator)) 1273 1274 per_worker_dataset = self.coordinator.create_per_worker_dataset( 1275 distributed_dataset) 1276 result = self.coordinator.schedule( 1277 worker_fn, args=(iter(per_worker_dataset),)) 1278 result = result.fetch() 1279 expected_result = array_ops.split( 1280 math_ops.range(1., 5.), 1281 num_or_size_splits=self.strategy.num_replicas_in_sync, 1282 axis=0) 1283 1284 self.assertAllEqual(result, (expected_result)) 1285 1286 def testDistributeDatasetsFromFunction(self): 1287 1288 def per_worker_dataset_fn(): 1289 1290 def input_worker_device_fn(input_context): 1291 self.assertIsNotNone(input_context) 1292 return dataset_ops.DatasetV2.range(1, 11).batch(1) 1293 1294 return self.strategy.distribute_datasets_from_function( 1295 input_worker_device_fn) 1296 1297 @def_function.function 1298 def worker_fn(iterator): 1299 result = self.strategy.experimental_local_results(next(iterator)) 1300 return result 1301 1302 distributed_dataset = self.coordinator.create_per_worker_dataset( 1303 per_worker_dataset_fn) 1304 result = self.coordinator.schedule( 1305 worker_fn, args=(iter(distributed_dataset),)) 1306 result = result.fetch() 1307 expected_result = [] 1308 for i in range(self.strategy.num_replicas_in_sync): 1309 expected_result.append([1 + i]) 1310 self.assertAllEqual(result, expected_result) 1311 1312 def testAsyncScheduleWithDistributedDataset(self): 1313 1314 def input_fn(): 1315 dataset = dataset_ops.DatasetV2.from_tensor_slices([2.]).repeat().batch( 1316 self.strategy.num_replicas_in_sync) 1317 return self.strategy.experimental_distribute_dataset(dataset) 1318 1319 with self.strategy.scope(): 1320 v = variables.Variable(initial_value=[0], dtype=dtypes.float32) 1321 1322 # TODO(yuefengz): the following tf.function has a return value which is None 1323 # in its structured_outputs. 1324 @def_function.function 1325 def worker_fn(iterator): 1326 x = next(iterator) 1327 # Reduce to convert PerReplica values to single value 1328 reduced_value = self.strategy.reduce('MEAN', x, axis=None) 1329 v.assign_add(reduced_value) 1330 1331 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 1332 1333 iterator = iter(distributed_dataset) 1334 1335 # Verifying joining without any scheduling doesn't hang. 1336 self.coordinator.join() 1337 self.assertAllEqual(v.read_value(), (0,)) 1338 1339 for _ in range(5): 1340 self.coordinator.schedule(worker_fn, args=(iterator,)) 1341 self.coordinator.join() 1342 1343 # With 5 addition it should be 2*5 = 10. 1344 self.assertAllEqual( 1345 self.strategy.experimental_local_results(v.read_value()), ([[10]])) 1346 1347 for _ in range(5): 1348 self.coordinator.schedule(worker_fn, args=(iterator,)) 1349 1350 # Verifying multiple join is fine. 1351 self.coordinator.join() 1352 self.coordinator.join() 1353 self.coordinator.join() 1354 1355 self.assertTrue(self.coordinator.done()) 1356 1357 # Likewise, it's now 20. 1358 self.assertAllEqual( 1359 self.strategy.experimental_local_results(v.read_value()), ([[20]])) 1360 1361 def testInputFunctionWithMapWithDistributedDataset(self): 1362 self._map_fn_tracing_count = 0 1363 1364 def input_fn(): 1365 1366 def map_fn(x): 1367 self._map_fn_tracing_count += 1 1368 return x + 10 1369 1370 dataset = dataset_ops.DatasetV2.range(0, 10).batch( 1371 self.strategy.num_replicas_in_sync).map(map_fn) 1372 return self.strategy.experimental_distribute_dataset(dataset) 1373 1374 @def_function.function 1375 def worker_fn(iterator): 1376 return next(iterator) 1377 1378 distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn) 1379 result = self.coordinator.schedule( 1380 worker_fn, args=(iter(distributed_dataset),)) 1381 1382 expected_result = array_ops.split( 1383 math_ops.range(10., 10. + self.strategy.num_replicas_in_sync), 1384 num_or_size_splits=self.strategy.num_replicas_in_sync, 1385 axis=0) 1386 1387 self.assertAllEqual( 1388 self.strategy.experimental_local_results(result.fetch()), 1389 tuple(expected_result)) 1390 self.assertEqual(self._map_fn_tracing_count, 1) 1391 1392 def testPerWorkerDistributeDatasetsElementSpec(self): 1393 1394 def per_worker_dataset_fn(): 1395 return self.strategy.distribute_datasets_from_function( 1396 lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 2])) 1397 1398 dataset = dataset_ops.DatasetV2.from_tensor_slices([1, 2]) 1399 per_worker_distribute_dataset = self.coordinator.create_per_worker_dataset( 1400 per_worker_dataset_fn) 1401 1402 self.assertAllEqual( 1403 # Converts to PerReplicaSpec when num_replicas_in_sync are > 1 1404 input_lib._create_distributed_tensor_spec(self.strategy, 1405 dataset.element_spec), 1406 per_worker_distribute_dataset.element_spec) 1407 1408 def testPerWorkerDistributedIteratorTypeSpec(self): 1409 self._tracing_count = 0 1410 1411 def per_worker_dataset_fn(): 1412 self._tracing_count += 1 1413 return self.strategy.distribute_datasets_from_function( 1414 lambda _: dataset_ops.DatasetV2.range(1, 2)) 1415 1416 @def_function.function 1417 def worker_fn(iterator): 1418 return next(iterator) 1419 1420 distributed_iterator = iter( 1421 self.coordinator.create_per_worker_dataset(per_worker_dataset_fn)) 1422 worker_fn.get_concrete_function(distributed_iterator) 1423 1424 self.coordinator.schedule(worker_fn, args=(distributed_iterator,)) 1425 self.assertEqual(self._tracing_count, 1) 1426 1427 1428if __name__ == '__main__': 1429 v2_compat.enable_v2_behavior() 1430 test.main() 1431