• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for coordinator.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import contextlib
23import functools
24import gc
25import os
26import platform
27import sys
28import threading
29import time
30import traceback
31from absl.testing import parameterized
32
33from tensorflow.python.compat import v2_compat
34from tensorflow.python.data.ops import dataset_ops
35from tensorflow.python.distribute import distribute_utils
36from tensorflow.python.distribute import distribution_strategy_context
37from tensorflow.python.distribute import input_lib
38from tensorflow.python.distribute import multi_worker_test_base
39from tensorflow.python.distribute import parameter_server_strategy_v2
40from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
41from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
42from tensorflow.python.distribute.coordinator import values as values_lib
43from tensorflow.python.eager import cancellation
44from tensorflow.python.eager import def_function
45from tensorflow.python.eager import test
46from tensorflow.python.framework import constant_op
47from tensorflow.python.framework import dtypes
48from tensorflow.python.framework import errors
49from tensorflow.python.framework import random_seed
50from tensorflow.python.framework import tensor_spec
51from tensorflow.python.framework import test_util
52from tensorflow.python.ops import array_ops
53from tensorflow.python.ops import check_ops
54from tensorflow.python.ops import math_ops
55from tensorflow.python.ops import random_ops
56from tensorflow.python.ops import variable_scope
57from tensorflow.python.ops import variables
58from tensorflow.python.platform import tf_logging as logging
59from tensorflow.python.training import coordinator
60from tensorflow.python.training.server_lib import ClusterSpec
61
62
63class ClosureWithOutput(coordinator_lib.Closure):
64
65  def __init__(self, function, cancellation_mgr=None, args=None, kwargs=None):
66    super(ClosureWithOutput, self).__init__(
67        function, cancellation_mgr=cancellation_mgr, args=args, kwargs=kwargs)
68    self.output_remote_value = self.build_output_remote_value()
69
70
71class CoordinatedClosureQueueTest(test.TestCase):
72
73  def testBasic(self):
74    queue = coordinator_lib._CoordinatedClosureQueue()
75    closure1 = self._create_closure(queue._cancellation_mgr)
76    queue.put(closure1)
77    self.assertIs(closure1, queue.get())
78    self.assertFalse(queue.done())
79    queue.put_back(closure1)
80    self.assertEqual(closure1, queue.get())
81    queue.mark_finished()
82    self.assertTrue(queue.done())
83    queue.wait()
84
85  def testProcessAtLeaseOnce(self):
86    closure_queue = coordinator_lib._CoordinatedClosureQueue()
87    labels = ['A', 'B', 'C', 'D', 'E']
88    processed_count = collections.defaultdict(int)
89
90    coord = coordinator.Coordinator(clean_stop_exception_types=[])
91
92    def process_queue():
93      with coord.stop_on_exception():
94        has_been_put_back = False
95        while True:
96          closure = closure_queue.get(timeout=30)
97          if closure is None:
98            break
99          if not has_been_put_back:
100            has_been_put_back = True
101            closure_queue.put_back(closure)
102            continue
103          closure._function()
104          closure_queue.mark_finished()
105
106    def get_func(label):
107
108      def func():
109        time.sleep(3)
110        processed_count[label] += 1
111
112      return func
113
114    cm = cancellation.CancellationManager()
115    for label in labels:
116      closure_queue.put(ClosureWithOutput(get_func(label), cm))
117    t1 = threading.Thread(target=process_queue, daemon=True)
118    t1.start()
119    t2 = threading.Thread(target=process_queue, daemon=True)
120    t2.start()
121
122    # Make sure multiple wait() calls are fine.
123    closure_queue.wait()
124    closure_queue.wait()
125    closure_queue.wait()
126    closure_queue.wait()
127
128    self.assertEqual(processed_count, collections.Counter(labels))
129
130    coord.join([t1, t2])
131
132  def testNotifyBeforeWait(self):
133    closure_queue = coordinator_lib._CoordinatedClosureQueue()
134
135    def func():
136      logging.info('func running')
137
138    coord = coordinator.Coordinator(clean_stop_exception_types=[])
139
140    def process_queue():
141      with coord.stop_on_exception():
142        closure_queue.get()
143        closure_queue.mark_finished()
144
145    closure_queue.put(ClosureWithOutput(func, closure_queue._cancellation_mgr))
146    t = threading.Thread(target=process_queue)
147    t.start()
148    coord.join([t])
149
150    # This test asserts that waiting at the time the function has been processed
151    # doesn't time out.
152    closure_queue.wait()
153
154  def _assert_one_unblock_the_other(self, first_fn, second_fn):
155    """Asserts `second_fn` wouldn't return before `first_fn` is finished."""
156    first_fn_done = threading.Event()
157    second_fn_done = threading.Event()
158    coord = coordinator.Coordinator(clean_stop_exception_types=[])
159
160    def wrapped_first_fn():
161      with coord.stop_on_exception():
162        self.assertFalse(second_fn_done.is_set())
163        first_fn()
164        first_fn_done.set()
165
166    self.assertFalse(first_fn_done.is_set())
167    t = threading.Thread(target=wrapped_first_fn)
168    t.start()
169
170    second_fn()
171    self.assertTrue(first_fn_done.is_set())
172    second_fn_done.set()
173
174    coord.join([t])
175
176  def testWaitRaiseErrorAfterMarkFailure(self):
177    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
178      # TODO(b/165013260): Fix this
179      self.skipTest('Test is currently broken on Windows with Python 3.8')
180
181    closure_queue = coordinator_lib._CoordinatedClosureQueue()
182    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
183    closure = closure_queue.get()
184
185    wait_finish_event = threading.Event()
186    coord = coordinator.Coordinator(clean_stop_exception_types=[])
187
188    # Using a thread to verify that closure_queue.wait() will not return until
189    # all inflight closures are finished.
190
191    def mark_finished_fn():
192      try:
193        raise ValueError('Some error.')
194      except ValueError as e:
195        closure_queue.mark_failed(e)
196
197    def wait_fn():
198      with self.assertRaises(ValueError):
199        closure_queue.wait()
200
201    self._assert_one_unblock_the_other(mark_finished_fn, wait_fn)
202
203    self.assertTrue(closure_queue.done())
204
205  def _create_closure(self, cancellation_mgr):
206
207    @def_function.function()
208    def some_function():
209      return 1.0
210
211    return ClosureWithOutput(some_function, cancellation_mgr)
212
213  def _put_two_closures_and_get_one(self):
214    closure_queue = coordinator_lib._CoordinatedClosureQueue()
215    closure1 = self._create_closure(closure_queue._cancellation_mgr)
216    closure_queue.put(closure1)
217
218    closure2 = self._create_closure(closure_queue._cancellation_mgr)
219    closure_queue.put(closure2)
220
221    closure_got = closure_queue.get()  # returns closure1
222    self.assertIs(closure_got, closure1)
223    self.assertIsNot(closure_got, closure2)
224    return closure_queue, closure1, closure2
225
226  def testPutRaiseError(self):
227    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
228      # TODO(b/165013260): Fix this
229      self.skipTest('Test is currently broken on Windows with Python 3.8')
230
231    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
232
233    closure_queue.mark_failed(ValueError())
234
235    with self.assertRaises(ValueError):
236      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
237
238    self.assertTrue(closure_queue.done())
239
240    with self.assertRaisesRegex(
241        errors.CancelledError,
242        'The corresponding function is cancelled. Please reschedule the '
243        'function.'):
244      closure2.output_remote_value.fetch()
245
246    # The error is cleared.
247    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
248
249  def testWaitRaiseError(self):
250    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
251      # TODO(b/165013260): Fix this
252      self.skipTest('Test is currently broken on Windows with Python 3.8')
253
254    closure_queue, _, closure2 = self._put_two_closures_and_get_one()
255
256    closure_queue.mark_failed(ValueError())
257
258    with self.assertRaises(ValueError):
259      closure_queue.wait()
260    self.assertTrue(closure_queue.done())
261
262    with self.assertRaisesRegex(
263        errors.CancelledError,
264        'The corresponding function is cancelled. Please reschedule the '
265        'function.'):
266      closure2.output_remote_value.fetch()
267
268    # The error is cleared.
269    closure_queue.wait()
270
271  def testDoneRaiseError(self):
272    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
273      # TODO(b/165013260): Fix this
274      self.skipTest('Test is currently broken on Windows with Python 3.8')
275
276    closure_queue, _, _ = self._put_two_closures_and_get_one()
277
278    self.assertFalse(closure_queue.done())
279    closure_queue.mark_failed(ValueError())
280    with self.assertRaises(ValueError):
281      closure_queue.done()
282
283  def _set_error(self, closure_queue, closure, error):
284    try:
285      raise error
286    except Exception as e:  # pylint: disable=broad-except
287      closure.output_remote_value._set_error(e)
288      closure_queue.mark_failed(e)
289
290  def _test_cancel_closure_when_error(self, call_wait):
291    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
292      # TODO(b/165013260): Fix this
293      self.skipTest('Test is currently broken on Windows with Python 3.8')
294
295    closure_queue, closure1, closure2 = self._put_two_closures_and_get_one()
296    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
297    closure_queue.get()
298    # At this moment, there are two inflight, one in queue.
299    self.assertEqual(closure_queue._inflight_closure_count, 2)
300
301    # Hold a copy of the queue's cancellation manager at this point
302    initial_cm = closure_queue._cancellation_mgr
303
304    # Simulating closure1 fails.
305    self._set_error(closure_queue, closure1, ValueError('Some error.'))
306
307    # At this moment, there are one inflight, one in queue.
308    self.assertEqual(closure_queue._queue.qsize(), 1)
309    self.assertEqual(closure_queue._inflight_closure_count, 1)
310
311    closure3 = self._create_closure(closure_queue._cancellation_mgr)
312
313    def fake_cancellation():
314      self._set_error(closure_queue, closure2,
315                      ValueError('Fake cancellation error.'))
316
317    def report_error():
318      # It should not report the fake cancellation error.
319      with self.assertRaisesRegex(ValueError, 'Some error.'):
320        # Verifying `wait()` or `put()` raises even if one closure is in
321        # flight.
322        if call_wait:
323          closure_queue.wait()
324        else:
325          closure_queue.put(closure3)
326
327    self._assert_one_unblock_the_other(fake_cancellation, report_error)
328
329    # The original cancellation manager of the queue has been cancelled.
330    self.assertTrue(initial_cm.is_cancelled)
331
332    # At this moment, there is zero inflight, nothing in queue.
333    self.assertTrue(closure_queue._queue.empty())
334    self.assertEqual(closure_queue._inflight_closure_count, 0)
335    self.assertIsNone(closure_queue._error)
336
337    # This asserts that closure1 has errored.
338    with self.assertRaisesRegex(ValueError, 'Some error.'):
339      closure1.output_remote_value.fetch()
340
341    # The following asserts that closure3 should have been cancelled.
342    if not call_wait:
343      with self.assertRaisesRegex(
344          errors.CancelledError,
345          'The corresponding function is cancelled. Please reschedule the '
346          'function.'):
347        closure3.output_remote_value.fetch()
348
349    # Closure2 was an inflight closure when it got cancelled.
350    self.assertEqual(closure2.output_remote_value._status,
351                     values_lib.RemoteValueStatus.READY)
352    with self.assertRaisesRegex(ValueError, 'Fake cancellation error.'):
353      closure2.output_remote_value.fetch()
354
355    # This asserts that the queue has a clear state.
356    self.testBasic()
357
358  def testWaitRaiseErrorAfterCancelClosure(self):
359    self._test_cancel_closure_when_error(call_wait=True)
360
361  def testPutRaiseErrorAfterCancelClosure(self):
362    self._test_cancel_closure_when_error(call_wait=False)
363
364  def testStateIsRestoredAfterJoinIsCalled(self):
365    if sys.version_info >= (3, 8) and platform.system() == 'Windows':
366      # TODO(b/165013260): Fix this
367      self.skipTest('Test is currently broken on Windows with Python 3.8')
368
369    closure_queue, _, _ = self._put_two_closures_and_get_one()
370    self.assertEqual(closure_queue._inflight_closure_count, 1)
371    closure_queue.mark_failed(ValueError('test error'))
372    with self.assertRaises(ValueError):
373      closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
374
375    # Its error should have been cleared.
376    self.assertIsNone(closure_queue._error)
377    closure_queue.put(self._create_closure(closure_queue._cancellation_mgr))
378    self.assertIsNone(closure_queue._error)
379
380  def testThreadSafey(self):
381    thread_count = 10
382    queue = coordinator_lib._CoordinatedClosureQueue()
383
384    # Each thread performs 20 queue actions: 10 are `put_back` and 10 are
385    # `mark_finished`.
386    action_count = 20
387
388    def func():
389      for i in range(action_count):
390        closure = queue.get()
391        if i % 2 == 0:
392          queue.put_back(closure)
393        else:
394          queue.mark_finished()
395
396    threads = [threading.Thread(target=func) for i in range(thread_count)]
397    for t in threads:
398      t.start()
399
400    for _ in range(thread_count * action_count // 2):
401      queue.put(self._create_closure(queue._cancellation_mgr))
402    queue.wait()
403    self.assertTrue(queue.done())
404
405
406class ErrorReportingThread(threading.Thread):
407
408  error = None
409
410  def __init__(self, *args, **kwargs):
411    assert 'target' in kwargs
412    target = kwargs['target']
413
414    @functools.wraps(target)
415    def wrapped_target(*args, **kwargs):
416      try:
417        return target(*args, **kwargs)
418      except Exception as e:  # pylint: disable=broad-except
419        traceback.print_exception(*sys.exc_info())
420        ErrorReportingThread.error = e
421
422    kwargs['target'] = wrapped_target
423    super(ErrorReportingThread, self).__init__(*args, **kwargs)
424
425
426class TestCaseWithErrorReportingThread(test.TestCase):
427
428  @classmethod
429  def setUpClass(cls):
430    cls._threading_thread = threading.Thread
431    threading.Thread = ErrorReportingThread
432    super(TestCaseWithErrorReportingThread, cls).setUpClass()
433
434  @classmethod
435  def tearDownClass(cls):
436    super(TestCaseWithErrorReportingThread, cls).tearDownClass()
437    threading.Thread = cls._threading_thread
438
439  def setUp(self):
440    ErrorReportingThread.error = None
441    super(TestCaseWithErrorReportingThread, self).setUp()
442
443  def tearDown(self):
444    super(TestCaseWithErrorReportingThread, self).tearDown()
445    if ErrorReportingThread.error:
446      raise ErrorReportingThread.error  # pylint: disable=raising-bad-type
447
448
449def make_coordinator(num_workers, num_ps):
450  # TODO(rchao): Test the internal rpc_layer version.
451  cluster_def = multi_worker_test_base.create_in_process_cluster(
452      num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
453  cluster_def['chief'] = [
454      'localhost:%d' % multi_worker_test_base.pick_unused_port()
455  ]
456  cluster_resolver = SimpleClusterResolver(
457      ClusterSpec(cluster_def), rpc_layer='grpc')
458  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
459      cluster_resolver)
460  return coordinator_lib.ClusterCoordinator(strategy)
461
462
463class ClusterCoordinatorTest(TestCaseWithErrorReportingThread,
464                             parameterized.TestCase):
465
466  @classmethod
467  def setUpClass(cls):
468    super(ClusterCoordinatorTest, cls).setUpClass()
469    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
470    cls.strategy = cls.coordinator.strategy
471
472  def testClusterCoordinatorOnlyInitOnce(self):
473    cluster = self.coordinator._cluster
474    same_coordinator = coordinator_lib.ClusterCoordinator(self.strategy)
475    self.assertIs(self.coordinator, same_coordinator)
476    self.assertIs(cluster, same_coordinator._cluster)
477
478  def testFnReturnNestedValues(self):
479    x = constant_op.constant(1)
480
481    @def_function.function
482    def f():
483      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
484
485    got = self.coordinator.schedule(f)
486    want = 2, (3, 4), [5], {'v': 1}
487    self.assertEqual(got.fetch(), want)
488    self.assertEqual(self.coordinator.fetch(got), want)
489
490  def testFetchingRemoteValueStructure(self):
491    x = constant_op.constant(1)
492
493    @def_function.function
494    def f():
495      return x + 1, (x + 2, x + 3), [x + 4], {'v': x}
496
497    want = 2, (3, 4), [5], {'v': 1}
498    remote_value_list = [self.coordinator.schedule(f) for _ in range(5)]
499    self.assertAllEqual(
500        self.coordinator.fetch(remote_value_list), [want for _ in range(5)])
501
502  def testInputFunction(self):
503
504    def input_fn():
505      return dataset_ops.DatasetV2.range(1, 2)
506
507    with self.strategy.scope():
508      v = variables.Variable(initial_value=0, dtype=dtypes.int64)
509
510    @def_function.function
511    def worker_fn(iterator):
512      x = next(iterator)
513      v.assign_add(x)
514      return x
515
516    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
517    result = self.coordinator.schedule(
518        worker_fn, args=(iter(distributed_dataset),))
519    result = self.coordinator.fetch(result)
520    self.assertEqual(result, (1,))
521    result = self.coordinator.schedule(
522        worker_fn, args=(iter(distributed_dataset),))
523    result = self.coordinator.fetch(result)
524
525    self.assertEqual(result, (1,))
526    self.assertAlmostEqual(v.read_value(), 2, delta=1e-6)
527
528  def testAsyncScheduleAndJoin(self):
529    if test_util.is_xla_enabled():
530      self.skipTest('Assign_add is not deterministic across threads in XLA')
531
532    def input_fn():
533      return dataset_ops.DatasetV2.from_tensor_slices([2] * 10)
534
535    with self.strategy.scope():
536      v = variables.Variable(initial_value=0, dtype=dtypes.int32)
537
538    # TODO(yuefengz): the following tf.function has a return value which is None
539    # in its structured_outputs.
540    @def_function.function
541    def worker_fn(iterator):
542      x = next(iterator)
543      v.assign_add(x)
544
545    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
546
547    iterator = iter(distributed_dataset)
548
549    # Verifying joining without any scheduling doesn't hang.
550    self.coordinator.join()
551    self.assertEqual(v.read_value().numpy(), 0)
552
553    for _ in range(5):
554      self.coordinator.schedule(worker_fn, args=(iterator,))
555    self.coordinator.join()
556
557    # With 5 addition it should be 2*5 = 10.
558    self.assertEqual(v.read_value().numpy(), 10)
559
560    for _ in range(5):
561      self.coordinator.schedule(worker_fn, args=(iterator,))
562
563    # Verifying multiple join is fine.
564    self.coordinator.join()
565    self.coordinator.join()
566    self.coordinator.join()
567
568    self.assertTrue(self.coordinator.done())
569
570    # Likewise, it's now 20.
571    self.assertEqual(v.read_value().numpy(), 20.)
572
573  @parameterized.parameters(True, False)
574  def testInputFunctionWithMap(self, use_input_fn):
575    self._map_fn_tracing_count = 0
576
577    def input_fn():
578
579      def map_fn(x):
580        self._map_fn_tracing_count += 1
581        return x + 10
582
583      return dataset_ops.DatasetV2.range(0, 10).map(map_fn)
584
585    @def_function.function
586    def worker_fn(iterator):
587      return next(iterator)
588
589    if use_input_fn:
590      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
591    else:
592      distributed_dataset = self.coordinator.create_per_worker_dataset(
593          input_fn())
594
595    result = self.coordinator.schedule(
596        worker_fn, args=(iter(distributed_dataset),))
597    self.assertEqual(result.fetch(), (10,))
598    self.assertEqual(self._map_fn_tracing_count, 1)
599
600  def testInputFunctionCreateVariables(self):
601
602    def input_fn():
603      v = variables.Variable(initial_value=0.0)
604      return v.read_value()
605
606    with self.assertRaises(ValueError):
607      self.coordinator.create_per_worker_dataset(input_fn)
608
609  @parameterized.parameters(True, False)
610  def testDatasetsShuffledDifferently(self, use_input_fn):
611    # This test requires at least two workers in the cluster.
612    self.assertGreaterEqual(len(self.coordinator._cluster.workers), 2)
613
614    random_seed.set_random_seed(None)
615
616    def input_fn():
617      dataset = dataset_ops.DatasetV2.range(0, 100).shuffle(100).batch(1)
618      return self.strategy.experimental_distribute_dataset(dataset)
619
620    if use_input_fn:
621      distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
622    else:
623      distributed_dataset = self.coordinator.create_per_worker_dataset(
624          input_fn())
625    distributed_iterator = iter(distributed_dataset)
626    # Get elements from the first two iterators.
627    iterator_1 = distributed_iterator._values[0]
628    iterator_1._rebuild_on(self.coordinator._cluster.workers[0])
629    iterator_1 = iterator_1.fetch()
630    elements_in_iterator_1 = [
631        self.strategy.experimental_local_results(e)
632        for e in iterator_1
633    ]
634    iterator_2 = distributed_iterator._values[1]
635    iterator_2._rebuild_on(self.coordinator._cluster.workers[1])
636    iterator_2 = iterator_2.fetch()
637    elements_in_iterator_2 = [
638        self.strategy.experimental_local_results(e)
639        for e in iterator_2
640    ]
641
642    self.assertNotAllEqual(elements_in_iterator_1, elements_in_iterator_2)
643
644  def testPerWorkerValue(self):
645    self.skipTest('b/168569314')
646    var_shape = tuple()
647    var_dtype = dtypes.float32
648    var_name = 'var'
649
650    def create_var():
651      var = variables.Variable(
652          initial_value=0.0, dtype=var_dtype, name=var_name)
653      self.assertIn('worker', var.device)
654      return var
655
656    worker_local_var = self.coordinator._create_per_worker_resources(create_var)
657
658    # The following is a workaround to allow `worker_local_var` to be passed in
659    # as args to the `coordinator.schedule` method which requires tensor specs
660    # to trace tf.function but _create_worker_resources' return values don't
661    # have tensor specs. We can get rid of this workaround once
662    # _create_worker_resources is able to infer the tensor spec of the return
663    # value of the function passed in. See b/154675763.
664    for var in worker_local_var._values:
665      var._type_spec = tensor_spec.TensorSpec(var_shape, var_dtype, var_name)
666
667    def worker_fn(var):
668      var.assign_add(1.0)
669
670    for _ in range(10):
671      # Which slice of `worker_local_var` will be used will depend on which
672      # worker the `worker_fn` gets scheduled on.
673      self.coordinator.schedule(worker_fn, args=(worker_local_var,))
674    self.coordinator.join()
675
676    var_sum = sum(self.coordinator.fetch(worker_local_var._values))
677    self.assertEqual(var_sum, 10.0)
678
679  def testDisallowRemoteValueAsInput(self):
680
681    @def_function.function
682    def func_0():
683      return 1.0
684
685    @def_function.function
686    def func_1(x):
687      return x + 1.0
688
689    remote_v = self.coordinator.schedule(func_0)
690    with self.assertRaises(ValueError):
691      self.coordinator.schedule(func_1, args=(remote_v,))
692
693  def testPythonFunctionNotAllowedToSchedule(self):
694
695    def func(a):
696      return array_ops.identity(a)
697
698    with self.assertRaisesRegexp(
699        TypeError,
700        '`tf.distribute.experimental.coordinator.ClusterCoordinator.schedule` '
701        'only accepts a `tf.function` or a concrete function.'):
702      self.coordinator.schedule(func, args=(1,))
703
704  def testDatasetPartiallyCreatedOnCoordinator(self):
705    dataset = dataset_ops.DatasetV2.range(1, 10)
706
707    @def_function.function
708    def input_fn():
709      return dataset.shuffle(9)
710
711    @def_function.function
712    def worker_fn(iterator):
713      x = next(iterator)
714      return x
715
716    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
717    self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
718
719    with self.assertRaisesRegexp(
720        coordinator_lib.InputError,
721        'error message is Failed copying input tensor from'):
722      self.coordinator.join()
723
724  def testPassDatasetToCreatePerWorkerDataset(self):
725    dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
726
727    @def_function.function
728    def worker_fn(iterator):
729      return next(iterator)
730
731    per_worker_dataset = self.coordinator.create_per_worker_dataset(dataset)
732    result = self.coordinator.schedule(
733        worker_fn, args=(iter(per_worker_dataset),))
734    result = result.fetch()
735    expected_result = math_ops.range(1., 5.)
736
737    self.assertAllEqual(result, (expected_result))
738
739  def testMultipleDatasets(self):
740
741    def input_fn1():
742      return dataset_ops.DatasetV2.range(0, 5)
743
744    def input_fn2():
745      return dataset_ops.DatasetV2.range(5, 10)
746
747    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn1)
748    per_worker_iterator1 = iter(per_worker_dataset1)
749    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn2)
750    per_worker_iterator2 = iter(per_worker_dataset2)
751
752    @def_function.function
753    def worker_fn(iterator1, iterator2):
754      return next(iterator1) + next(iterator2)
755
756    result = self.coordinator.schedule(
757        worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
758    self.assertEqual(result.fetch(), 5.0)
759
760    per_worker_dataset3 = self.coordinator.create_per_worker_dataset(input_fn1)
761    per_worker_iterator3 = iter(per_worker_dataset3)
762
763    result = self.coordinator.schedule(
764        worker_fn, args=(per_worker_iterator3, per_worker_iterator2))
765    self.assertGreaterEqual(result.fetch(), 5.0)
766
767  def testRepeatedIteratorCreation(self):
768
769    def input_fn():
770      return dataset_ops.DatasetV2.range(1, 100)
771
772    per_worker_dataset1 = self.coordinator.create_per_worker_dataset(input_fn)
773    per_worker_dataset2 = self.coordinator.create_per_worker_dataset(input_fn)
774
775    @def_function.function
776    def worker_fn(iterator1, iterator2):
777      return next(iterator1) + next(iterator2)
778
779    for _ in range(10):
780      per_worker_iterator1 = iter(per_worker_dataset1)
781      per_worker_iterator2 = iter(per_worker_dataset2)
782      result = self.coordinator.schedule(
783          worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
784      for _ in range(10):
785        self.coordinator.schedule(
786            worker_fn, args=(per_worker_iterator1, per_worker_iterator2))
787      self.coordinator.join()
788      self.assertGreaterEqual(result.fetch(), 2.0)
789    del per_worker_iterator1, per_worker_iterator2
790    gc.collect()
791
792    # There shouldn't be any live iterator objects.
793    for w in self.coordinator._cluster.workers:
794      for r in w._resource_remote_value_refs:
795        self.assertIsNone(r())
796
797
798class LimitedClosureQueueSizeBasicTest(ClusterCoordinatorTest):
799  """Test basic functionality works with explicit maximum closure queue size.
800
801  Execute the same set of test cases as in `ClusterCoordinatorTest`, with an
802  explicit size limit for the closure queue. Note that even when the queue size
803  is set to infinite, there is still a maximum practical size (depends on host
804  memory limit) that might cause the queue.put operations to be blocking when
805  scheduling a large number of closures on a big cluster. These tests make sure
806  that the coordinator does not run into deadlocks in such scenario.
807  """
808
809  @classmethod
810  def setUpClass(cls):
811    super(LimitedClosureQueueSizeBasicTest, cls).setUpClass()
812    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
813    cls.coordinator = make_coordinator(num_workers=5, num_ps=2)
814    cls.strategy = cls.coordinator.strategy
815
816
817class ScheduleStartDelayTest(ClusterCoordinatorTest):
818  """Test basic functionality works with worker scheduling delay.
819
820  This is basically to make sure that setting environment variables
821  `TF_COORDINATOR_SCHEDULE_START_DELAY` and
822  `TF_COORDINATOR_SCHEDULE_START_DELAY_MAX` will cause any failure.
823  """
824
825  @classmethod
826  def setUpClass(cls):
827    super(ScheduleStartDelayTest, cls).setUpClass()
828    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY'] = '2'
829    os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX'] = '4'
830    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
831    cls.strategy = cls.coordinator.strategy
832
833  @classmethod
834  def tearDownClass(cls):
835    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY']
836    del os.environ['TF_COORDINATOR_SCHEDULE_START_DELAY_MAX']
837    super(ScheduleStartDelayTest, cls).tearDownClass()
838
839
840class ErrorReportingTest(TestCaseWithErrorReportingThread):
841
842  @classmethod
843  def setUpClass(cls):
844    super(ErrorReportingTest, cls).setUpClass()
845    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
846    cls.strategy = cls.coordinator.strategy
847
848    with cls.strategy.scope():
849      cls.iteration = variables.Variable(initial_value=0.0)
850
851  @def_function.function
852  def _normal_function(self):
853    x = random_ops.random_uniform((2, 10))
854    y = random_ops.random_uniform((10, 2))
855    self.iteration.assign_add(1.0)
856    return math_ops.reduce_mean(math_ops.matmul(x, y))
857
858  @def_function.function
859  def _error_function(self):
860    x = random_ops.random_uniform((2, 10))
861    y = random_ops.random_uniform((10, 2))
862    check_ops.assert_non_positive_v2(math_ops.reduce_sum(math_ops.matmul(x, y)))
863    self.iteration.assign_add(1.0)
864    return self.iteration
865
866  @def_function.function
867  def _long_function(self):
868    x = random_ops.random_uniform((1000, 1000))
869    for _ in math_ops.range(10000):
870      a = random_ops.random_uniform((1000, 1000))
871      b = random_ops.random_uniform((1000, 1000))
872      x += math_ops.matmul(a, b)
873    return x
874
875  def testJoinRaiseError(self):
876    for _ in range(3):
877      self.coordinator.schedule(self._normal_function)
878    self.coordinator.schedule(self._error_function)
879    with self.assertRaises(errors.InvalidArgumentError):
880      self.coordinator.join()
881
882  def testScheduleRaiseError(self):
883    for _ in range(3):
884      self.coordinator.schedule(self._normal_function)
885    self.coordinator.schedule(self._error_function)
886    with self.assertRaises(errors.InvalidArgumentError):
887      while True:
888        self.coordinator.schedule(self._normal_function)
889
890  def testScheduleRaiseErrorWithMultipleFailure(self):
891    for _ in range(3):
892      self.coordinator.schedule(self._normal_function)
893    self.coordinator.schedule(self._error_function)
894    with self.assertRaises(errors.InvalidArgumentError):
895      while True:
896        self.coordinator.schedule(self._error_function)
897    self.coordinator.join()
898
899  def testErrorWillbeCleared(self):
900    self.coordinator.schedule(self._error_function)
901    with self.assertRaises(errors.InvalidArgumentError):
902      self.coordinator.join()
903
904    for _ in range(3):
905      self.coordinator.schedule(self._normal_function)
906    self.coordinator.schedule(self._error_function)
907    with self.assertRaises(errors.InvalidArgumentError):
908      self.coordinator.join()
909
910  def testRemoteValueReturnError(self):
911    result = self.coordinator.schedule(self._error_function)
912
913    with self.assertRaises(errors.InvalidArgumentError):
914      result.fetch()
915
916    # Clear the error.
917    with self.assertRaises(errors.InvalidArgumentError):
918      self.coordinator.join()
919
920  def testInputError(self):
921
922    worker_local_val = self.coordinator._create_per_worker_resources(
923        self._error_function)
924
925    @def_function.function
926    def func(x):
927      return x + 1
928
929    result = self.coordinator.schedule(func, args=(worker_local_val,))
930    with self.assertRaises(coordinator_lib.InputError):
931      self.coordinator.join()
932
933    with self.assertRaises(coordinator_lib.InputError):
934      result.fetch()
935
936  def testCancellation(self):
937    for _ in range(3):
938      self.coordinator.schedule(self._normal_function)
939    long_function = self.coordinator.schedule(self._long_function)
940    self.coordinator.schedule(self._error_function)
941
942    with self.assertRaises(errors.InvalidArgumentError):
943      self.coordinator.join()
944
945    with self.assertRaises(errors.CancelledError):
946      long_function.fetch()
947
948    for _ in range(3):
949      self.coordinator.schedule(self._normal_function)
950    self.coordinator.join()
951
952
953class LimitedClosureQueueErrorTest(ErrorReportingTest):
954  """Test error reporting works with explicit maximum closure queue size.
955
956  Execute the same set of test cases as in ErrorReportingTest, with an explicit
957  size limit for the closure queue.
958  """
959
960  @classmethod
961  def setUpClass(cls):
962    super(LimitedClosureQueueErrorTest, cls).setUpClass()
963    coordinator_lib._CLOSURE_QUEUE_MAX_SIZE = 2
964    cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
965    cls.strategy = cls.coordinator.strategy
966
967    with cls.coordinator.strategy.scope():
968      cls.iteration = variables.Variable(initial_value=0.0)
969
970
971class StrategyIntegrationTest(test.TestCase, parameterized.TestCase):
972
973  @classmethod
974  def setUpClass(cls):
975    super(StrategyIntegrationTest, cls).setUpClass()
976    cls.coordinator = make_coordinator(num_workers=1, num_ps=1)
977    cls.strategy = cls.coordinator.strategy
978
979  def testRunNotUsedWithClusterCoordinatorSchedule(self):
980
981    @def_function.function
982    def input_fn():
983      return dataset_ops.DatasetV2.range(1, 3)
984
985    with self.strategy.scope():
986      v = variables.Variable(initial_value=1, dtype=dtypes.int64)
987
988      def replica_fn(input_tensor):
989        return input_tensor + v, input_tensor - v
990
991      @def_function.function
992      def worker_fn(iterator):
993        return self.strategy.run(replica_fn, args=(next(iterator),))
994
995    per_worker_dataset = self.coordinator.create_per_worker_dataset(input_fn)
996
997    @contextlib.contextmanager
998    def _assert_logs_usage_warning():
999      with self.assertLogs(level='WARNING') as logs:
1000        yield
1001
1002      self.assertIn(
1003          'It is detected that a function used with '
1004          '`tf.distribute.experimental.ParameterServerStrategy` '
1005          'is executed locally on the coordinator. This is inefficient but may '
1006          'be valid for one-off tasks such as inferring output signature. '
1007          'To properly distribute functions to run on workers, `run` or '
1008          '`reduce` should be used within a function passed to `'
1009          'tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`'
1010          '.',
1011          logs.output[0])
1012
1013    with _assert_logs_usage_warning():
1014      # Invoking `run` without `coordinator.schedule` should result in a
1015      # warning.
1016      self.strategy.run(
1017          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))
1018
1019    # A proper `schedule` should succeed.
1020    rv = self.coordinator.schedule(worker_fn, args=(iter(per_worker_dataset),))
1021
1022    with _assert_logs_usage_warning():
1023      # Invoking `run` without `coordinator.schedule` again should result in a
1024      # warning.
1025      self.strategy.run(
1026          replica_fn, args=(constant_op.constant(1, dtype=dtypes.int64),))
1027
1028    all_results = [(2, 0)] * self.strategy.num_replicas_in_sync
1029    expected_result = []
1030    for i in range(self.strategy.num_replicas_in_sync):
1031      expected_result.append(all_results[i])
1032
1033    self.assertAllEqual(
1034        tuple(expected_result),
1035        self.strategy.experimental_local_results(rv.fetch()))
1036
1037  def testBasicVariableAssignment(self):
1038    self.strategy.extended._variable_count = 0
1039    with self.strategy.scope():
1040      v1 = variables.Variable(initial_value=0.0)
1041      v2 = variables.Variable(initial_value=1.0)
1042    self.assertEqual(self.strategy.extended._variable_count, 2)
1043
1044    @def_function.function
1045    def worker_fn():
1046      v1.assign_add(0.1)
1047      v2.assign_sub(0.2)
1048      return v1.read_value() / v2.read_value()
1049
1050    results = self.coordinator.schedule(worker_fn)
1051    logging.info('Results of experimental_run_v2: %f',
1052                 self.coordinator.fetch(results))
1053
1054    self.assertAlmostEqual(v1.read_value().numpy(), 0.1, delta=1e-6)
1055    self.assertAlmostEqual(v2.read_value().numpy(), 0.8, delta=1e-6)
1056
1057  def testRunAndReduce(self):
1058    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1059    with self.strategy.scope():
1060      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1061      v = variables.Variable(initial_value=1.)
1062
1063      expected_result = (4. * self.strategy.num_replicas_in_sync,
1064                         2. * self.strategy.num_replicas_in_sync)
1065
1066      @def_function.function
1067      def worker_fn(input_tensor):
1068
1069        def replica_fn(input_tensor):
1070          # Within `replica_fn`, it has to be in a replica context.
1071          self.assertFalse(
1072              distribution_strategy_context.in_cross_replica_context())
1073          return input_tensor + v, input_tensor - v
1074
1075        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
1076        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
1077        check_ops.assert_equal_v2(reduced_result, expected_result)
1078        return reduced_result
1079
1080      # Asserting scheduling in scope has the expected behavior.
1081      result = self.coordinator.schedule(
1082          worker_fn, args=(constant_op.constant(3.),))
1083      self.assertIsInstance(result, coordinator_lib.RemoteValue)
1084      self.assertEqual(result.fetch(), expected_result)
1085
1086    # Asserting scheduling out of scope has the expected behavior.
1087    result = self.coordinator.schedule(
1088        worker_fn, args=(constant_op.constant(3.),))
1089    self.assertEqual(result.fetch(), expected_result)
1090
1091  def testRunAndReduceWithAssignAdd(self):
1092    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1093    with self.strategy.scope():
1094      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1095      v = variables.Variable(initial_value=1.)
1096      v1 = variables.Variable(
1097          initial_value=0.,
1098          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)
1099
1100      expected_result = (4. * self.strategy.num_replicas_in_sync,
1101                         2. * self.strategy.num_replicas_in_sync)
1102
1103      @def_function.function
1104      def worker_fn(input_tensor):
1105
1106        def replica_fn(input_tensor):
1107          # Within `replica_fn`, it has to be in a replica context.
1108          self.assertFalse(
1109              distribution_strategy_context.in_cross_replica_context())
1110
1111          v1.assign_add(input_tensor)
1112          return input_tensor + v, input_tensor - v
1113
1114        run_result = self.strategy.run(replica_fn, args=(input_tensor,))
1115        reduced_result = self.strategy.reduce('SUM', run_result, axis=None)
1116        check_ops.assert_equal_v2(reduced_result, expected_result)
1117        return reduced_result
1118
1119      # Asserting scheduling in scope has the expected behavior.
1120      result = self.coordinator.schedule(
1121          worker_fn, args=(constant_op.constant(3.),))
1122      self.assertIsInstance(result, coordinator_lib.RemoteValue)
1123      self.assertEqual(result.fetch(), expected_result)
1124
1125    # Asserting scheduling out of scope has the expected behavior.
1126    result = self.coordinator.schedule(
1127        worker_fn, args=(constant_op.constant(3.),))
1128    self.assertEqual(result.fetch(), expected_result)
1129    self.assertEqual(v1, 6.)
1130
1131  def testVariableAggregation(self):
1132    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1133    with self.strategy.scope():
1134      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1135      v = variables.Variable(
1136          initial_value=1.,
1137          aggregation=variable_scope.VariableAggregation.SUM)
1138
1139      @def_function.function
1140      def worker_fn():
1141
1142        def replica_fn():
1143          value = math_ops.cast(
1144              distribution_strategy_context.get_replica_context()
1145              .replica_id_in_sync_group + 1, v.dtype)
1146          v.assign(value)
1147
1148        self.strategy.run(replica_fn)
1149
1150      self.coordinator.schedule(worker_fn)
1151      self.coordinator.join()
1152      expected_result = 0.
1153      for i in range(self.strategy.num_replicas_in_sync):
1154        expected_result = expected_result + i + 1
1155      self.assertEqual(v, expected_result)
1156
1157  def testVariableCaching(self):
1158    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
1159    with self.strategy.scope():
1160      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
1161      v = variables.Variable(
1162          initial_value=1.,
1163          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)
1164
1165      # Test read value inside caching scope
1166      with distribute_utils.cache_variable_reads():
1167        v.read_value()  # Reads value 1.0
1168        v.assign(constant_op.constant(5.0))  # v changes to 5.0
1169        self.assertEqual(v.read_value(), 1.0)  # should be cached 1.0 value.
1170
1171      # Reset v to 2.0
1172      v.assign(2.0)
1173
1174      # Test convert to tensor value inside caching scope
1175      with distribute_utils.cache_variable_reads():
1176        t = v * 3.0
1177        self.assertEqual(t, 6.0)
1178        v.assign(3.0)
1179        t1 = v * 3.0
1180        self.assertEqual(t1, 6.0)  # should be cached 2.0 * 3.0 value.
1181
1182      # Reset v to 1.0
1183      v.assign(1.0)
1184
1185      # Verify caching scope inside tf.function
1186      @def_function.function
1187      def worker_fn():
1188        with distribute_utils.cache_variable_reads():
1189          def replica_fn():
1190            t = v.read_value()  # Reads value 1.0
1191            v.assign(constant_op.constant(5.0))  # v changes to 5.0
1192            t = v.read_value()  # should return 1.0
1193            return t  # Should be 1.0 instead of 5.0
1194
1195          return self.strategy.run(replica_fn)
1196
1197      result = self.coordinator.schedule(worker_fn)
1198      result = result.fetch()
1199      expected_result = 1.
1200      self.assertEqual(result, expected_result)
1201
1202      # Verify that v.read_value works as expected outside of scope.
1203      v.assign(4.0)
1204      self.assertEqual(v.read_value(), 4.0)
1205
1206      v.assign(constant_op.constant(2.0))  # v changes to 2.0
1207      # Check with scope outside of tf function and check that cache is reset
1208      @def_function.function
1209      def worker_fn1():
1210        def replica_fn():
1211          t = v.read_value()  # Reads value 2.0 ==> Should be cached
1212          v.assign(constant_op.constant(5.0))  # v changes to 5.0
1213          t = v.read_value()  # should return cached value 2.0
1214          return t  # Should be 2.0 instead of 5.0
1215
1216        return self.strategy.run(replica_fn)
1217
1218      with distribute_utils.cache_variable_reads():
1219        result = self.coordinator.schedule(worker_fn1)
1220      result = result.fetch()
1221      expected_result = 2.
1222      self.assertEqual(result, expected_result)
1223
1224    # Verify scope nesting is not permitted.
1225    with self.assertRaises(ValueError):
1226      with distribute_utils.cache_variable_reads():
1227        with distribute_utils.cache_variable_reads():
1228          v.read_value()
1229
1230  @parameterized.parameters(True, False)
1231  def testDistributedDatasetInsidePerWorkerDatasetFn(self, from_function):
1232    if from_function:
1233
1234      def per_worker_dataset_fn():
1235        dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
1236        return self.strategy.distribute_datasets_from_function(dataset_fn)
1237    else:
1238
1239      def per_worker_dataset_fn():
1240        dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
1241        return self.strategy.experimental_distribute_dataset(dataset)
1242
1243    @def_function.function
1244    def worker_fn(iterator):
1245      return self.strategy.experimental_local_results(next(iterator))
1246
1247    per_worker_dataset = self.coordinator.create_per_worker_dataset(
1248        per_worker_dataset_fn)
1249    result = self.coordinator.schedule(
1250        worker_fn, args=(iter(per_worker_dataset),))
1251    result = result.fetch()
1252    expected_result = array_ops.split(
1253        math_ops.range(1., 5.),
1254        num_or_size_splits=self.strategy.num_replicas_in_sync,
1255        axis=0)
1256
1257    self.assertAllEqual(result, (expected_result))
1258
1259  @parameterized.parameters(True, False)
1260  def testPassDistributedDatasetToCreatePerWorkerDataset(self, from_function):
1261    if from_function:
1262      dataset_fn = lambda _: dataset_ops.DatasetV2.range(1, 11).batch(4)
1263      distributed_dataset = self.strategy.distribute_datasets_from_function(
1264          dataset_fn)
1265    else:
1266      dataset = dataset_ops.DatasetV2.range(1, 11).batch(4)
1267      distributed_dataset = self.strategy.experimental_distribute_dataset(
1268          dataset)
1269
1270    @def_function.function
1271    def worker_fn(iterator):
1272      return self.strategy.experimental_local_results(next(iterator))
1273
1274    per_worker_dataset = self.coordinator.create_per_worker_dataset(
1275        distributed_dataset)
1276    result = self.coordinator.schedule(
1277        worker_fn, args=(iter(per_worker_dataset),))
1278    result = result.fetch()
1279    expected_result = array_ops.split(
1280        math_ops.range(1., 5.),
1281        num_or_size_splits=self.strategy.num_replicas_in_sync,
1282        axis=0)
1283
1284    self.assertAllEqual(result, (expected_result))
1285
1286  def testDistributeDatasetsFromFunction(self):
1287
1288    def per_worker_dataset_fn():
1289
1290      def input_worker_device_fn(input_context):
1291        self.assertIsNotNone(input_context)
1292        return dataset_ops.DatasetV2.range(1, 11).batch(1)
1293
1294      return self.strategy.distribute_datasets_from_function(
1295          input_worker_device_fn)
1296
1297    @def_function.function
1298    def worker_fn(iterator):
1299      result = self.strategy.experimental_local_results(next(iterator))
1300      return result
1301
1302    distributed_dataset = self.coordinator.create_per_worker_dataset(
1303        per_worker_dataset_fn)
1304    result = self.coordinator.schedule(
1305        worker_fn, args=(iter(distributed_dataset),))
1306    result = result.fetch()
1307    expected_result = []
1308    for i in range(self.strategy.num_replicas_in_sync):
1309      expected_result.append([1 + i])
1310    self.assertAllEqual(result, expected_result)
1311
1312  def testAsyncScheduleWithDistributedDataset(self):
1313
1314    def input_fn():
1315      dataset = dataset_ops.DatasetV2.from_tensor_slices([2.]).repeat().batch(
1316          self.strategy.num_replicas_in_sync)
1317      return self.strategy.experimental_distribute_dataset(dataset)
1318
1319    with self.strategy.scope():
1320      v = variables.Variable(initial_value=[0], dtype=dtypes.float32)
1321
1322    # TODO(yuefengz): the following tf.function has a return value which is None
1323    # in its structured_outputs.
1324    @def_function.function
1325    def worker_fn(iterator):
1326      x = next(iterator)
1327      # Reduce to convert PerReplica values to single value
1328      reduced_value = self.strategy.reduce('MEAN', x, axis=None)
1329      v.assign_add(reduced_value)
1330
1331    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1332
1333    iterator = iter(distributed_dataset)
1334
1335    # Verifying joining without any scheduling doesn't hang.
1336    self.coordinator.join()
1337    self.assertAllEqual(v.read_value(), (0,))
1338
1339    for _ in range(5):
1340      self.coordinator.schedule(worker_fn, args=(iterator,))
1341    self.coordinator.join()
1342
1343    # With 5 addition it should be 2*5 = 10.
1344    self.assertAllEqual(
1345        self.strategy.experimental_local_results(v.read_value()), ([[10]]))
1346
1347    for _ in range(5):
1348      self.coordinator.schedule(worker_fn, args=(iterator,))
1349
1350    # Verifying multiple join is fine.
1351    self.coordinator.join()
1352    self.coordinator.join()
1353    self.coordinator.join()
1354
1355    self.assertTrue(self.coordinator.done())
1356
1357    # Likewise, it's now 20.
1358    self.assertAllEqual(
1359        self.strategy.experimental_local_results(v.read_value()), ([[20]]))
1360
1361  def testInputFunctionWithMapWithDistributedDataset(self):
1362    self._map_fn_tracing_count = 0
1363
1364    def input_fn():
1365
1366      def map_fn(x):
1367        self._map_fn_tracing_count += 1
1368        return x + 10
1369
1370      dataset = dataset_ops.DatasetV2.range(0, 10).batch(
1371          self.strategy.num_replicas_in_sync).map(map_fn)
1372      return self.strategy.experimental_distribute_dataset(dataset)
1373
1374    @def_function.function
1375    def worker_fn(iterator):
1376      return next(iterator)
1377
1378    distributed_dataset = self.coordinator.create_per_worker_dataset(input_fn)
1379    result = self.coordinator.schedule(
1380        worker_fn, args=(iter(distributed_dataset),))
1381
1382    expected_result = array_ops.split(
1383        math_ops.range(10., 10. + self.strategy.num_replicas_in_sync),
1384        num_or_size_splits=self.strategy.num_replicas_in_sync,
1385        axis=0)
1386
1387    self.assertAllEqual(
1388        self.strategy.experimental_local_results(result.fetch()),
1389        tuple(expected_result))
1390    self.assertEqual(self._map_fn_tracing_count, 1)
1391
1392  def testPerWorkerDistributeDatasetsElementSpec(self):
1393
1394    def per_worker_dataset_fn():
1395      return self.strategy.distribute_datasets_from_function(
1396          lambda _: dataset_ops.DatasetV2.from_tensor_slices([1, 2]))
1397
1398    dataset = dataset_ops.DatasetV2.from_tensor_slices([1, 2])
1399    per_worker_distribute_dataset = self.coordinator.create_per_worker_dataset(
1400        per_worker_dataset_fn)
1401
1402    self.assertAllEqual(
1403        # Converts to PerReplicaSpec when num_replicas_in_sync are > 1
1404        input_lib._create_distributed_tensor_spec(self.strategy,
1405                                                  dataset.element_spec),
1406        per_worker_distribute_dataset.element_spec)
1407
1408  def testPerWorkerDistributedIteratorTypeSpec(self):
1409    self._tracing_count = 0
1410
1411    def per_worker_dataset_fn():
1412      self._tracing_count += 1
1413      return self.strategy.distribute_datasets_from_function(
1414          lambda _: dataset_ops.DatasetV2.range(1, 2))
1415
1416    @def_function.function
1417    def worker_fn(iterator):
1418      return next(iterator)
1419
1420    distributed_iterator = iter(
1421        self.coordinator.create_per_worker_dataset(per_worker_dataset_fn))
1422    worker_fn.get_concrete_function(distributed_iterator)
1423
1424    self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
1425    self.assertEqual(self._tracing_count, 1)
1426
1427
1428if __name__ == '__main__':
1429  v2_compat.enable_v2_behavior()
1430  test.main()
1431