# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.tpu import tpu
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset

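# Illustrative only: `get_dataset_from_tensor_slices([5., 6.]).batch(2)` yields
# a single batch [5., 6.]; the tests below distribute such datasets so that
# each replica receives a slice of the global batch.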

class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The expected results of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

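# For example, with two replicas a global batch [5., 6.] typically comes back
# as per-replica tensors ([5.], [6.]); assert_equal_flattened joins them back
# into [5., 6.] before comparing, since the grouping depends on the strategy.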

class InputIterationTest(test.TestCase, parameterized.TestCase,
                         AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testConstantNumpyInput(self, distribution):

    @def_function.function
    def run(x):

      def computation(x):
        return math_ops.square(x)

      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(x,)))
      return outputs

    self.assertAllEqual(
        constant_op.constant(4., shape=(distribution.num_replicas_in_sync,)),
        run(2.))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
    with distribution.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.run(assign_add)
      return array_ops.zeros([])

    train_step()
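    # Each replica runs assign_add(1.0); with MEAN aggregation the merged
    # value is 1.0 regardless of the number of replicas.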
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

    train_fn(distributed_iterator)
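    # 8 elements with a global batch of 2 yield 4 populated optionals, so the
    # loop above breaks on the 5th of the 6 requested steps instead of raising
    # OutOfRangeError.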

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with self.assertRaisesRegex(NotImplementedError,
                                "does not support pure eager execution"):
      distribution.run(train_step, args=(next(input_iterator),))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.run(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so each replica's input value is the
      # same as its replica id.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.run(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)
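    # The two global batches are [5., 6.] and [7., 8.], with means 5.5 and
    # 7.5, so the expected product is 2.0 * 5.5 * 7.5 = 82.5.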

    # We set the initial value of `a` to 1 and iterate through the dataset 2
    # times (4/2, where 4 is the number of dataset elements and 2 is the batch
    # size). Hence the final result is 3.
    self.assertEqual(3.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently depending
    # on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithoutFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    self.assertAllEqual(
        distribution.experimental_local_results(input_iterator.get_next()),
        data[0:distribution.num_replicas_in_sync])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(iterator.get_next())

    local_results = run(input_iterator)
    self.assertAllEqual(local_results,
                        data[0:distribution.num_replicas_in_sync])
    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))
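    # Replica 0 receives the first two elements [5., 6.] (mean 5.5) and
    # replica 1 receives the remainder [7.] (mean 7.).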

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsBucketizing(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_bucketizing_dynamic_shape=True)

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.reduce_mean(x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(
              computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategy, mode=["eager"]))
  def testDynamicShapesWithRunOptionsDisableDynamicPadder(self, distribution):
    dataset = get_dataset_from_tensor_slices([5, 6, 7]).batch(4)
    mask_dataset = get_dataset_from_tensor_slices([1, 0, 1]).batch(4)
    dataset = dataset_ops.DatasetV2.zip((dataset, mask_dataset))

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_xla_options=tpu.XLAOptions(
            enable_xla_dynamic_padder=False))

    @def_function.function
    def run(iterator):

      def computation(inputs):
        x, mask = inputs
        y = x * mask
        return math_ops.reduce_sum(y)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5, 7], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicOutputsWithX64(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.add(x, x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    result = run(input_iterator)
    self.assertAllEqual([10], result[0])
    self.assertAllEqual([], result[1])
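    # There is only one element, so replica 0 computes 5 + 5 = 10 while
    # replica 1 receives an empty batch and produces an empty output.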

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
      # If the number of segments is dynamic, the output should have a dynamic
      # shape as well.
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)
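    # The 5 elements batched by 3 arrive as consecutive per-replica batches of
    # sizes 3 and 2, which the dynamically shaped segment sums preserve.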

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      # Reshape1 is duplicated in order to test dynamic dimension on copies.
      return [reshape1, reshape2, reshape3, reshape4, reshape1]

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].shape)
    self.assertAllEqual((9, 2), outputs[0][4].shape)

    self.assertAllEqual((4, 2), outputs[1][0].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][1].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][2].shape)
    self.assertAllEqual((2, 2, 2), outputs[1][3].shape)
    self.assertAllEqual((4, 2), outputs[1][4].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testMapFnWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(16)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    def embedding_lookup(inputs):
      embedding_weights = array_ops.zeros((1, 128))
      flat_inputs = array_ops.reshape(inputs, [-1])
      embeddings = array_ops.gather(embedding_weights, flat_inputs)
      embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128])
      return embeddings

    @def_function.function
    def step_fn(example):
      return map_fn.map_fn(
          embedding_lookup, example, fn_output_signature=dtypes.float32)

    # This assumes that there are exactly 2 replicas.
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((16, 300, 32, 128), outputs[0].shape)
    self.assertAllEqual((4, 300, 32, 128), outputs[1].shape)
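    # The 20 elements batched by 16 arrive as consecutive per-replica batches
    # of sizes 16 and 4; map_fn appends the 128-wide embedding dimension.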

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we set
    # drop_remainder=True on the dataset, then DistributedIterator will use a
    # different (and more efficient) code path which avoids some control flow
    # ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # A dynamic output size with a mix of static and dynamic outputs.
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # A fixed-size output alongside a dynamically sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.run(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # The first result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives data, so only it produces a non-empty
    # dynamic result.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The other replicas receive no data and produce empty results.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # we iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset))
    with distribution.scope():
      var = variables.Variable(1.0)

    @def_function.function
    def train_step(input_iterator):

      def func(inputs):
        return math_ops.square(inputs) + var

      per_replica_outputs = distribution.run(
          func, (next(input_iterator),))
      mean = distribution.reduce(
          reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      for _ in dataset_ops.Dataset.range(1):
        per_replica_outputs = distribution.run(
            func, (next(input_iterator),))
        mean = distribution.reduce(
            reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      return mean

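    # func computes x**2 + var: 2. -> 5.0 and 3. -> 10.0. With one replica the
    # returned mean comes from the second batch alone (10.0); with two
    # replicas it is the mean of 5.0 and 10.0, i.e. 7.5.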
    with distribution.scope():
      if distribution.num_replicas_in_sync == 1:
        self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator)))
      else:
        self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetOutOfRange(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.SUM)

    def train_step(val):
      a.assign_add(math_ops.reduce_sum(val))

    @def_function.function
    def f_train_step(iterator):
      distribution.run(train_step, args=(next(iterator),))
      return a

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    with self.assertRaises(errors.OutOfRangeError):
      for _ in range(100):
        f_train_step(iterator)

    self.assertAlmostEqual(26.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testComputeLossWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return losses.compute_weighted_loss(x, weights=array_ops.ones_like(x))

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas.
    self.assertAllEqual([5.5, 7.], run(input_iterator))


if __name__ == "__main__":
  test_util.main()