# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class InputIterationTest(test.TestCase, parameterized.TestCase,
                         AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testConstantNumpyInput(self, distribution):

    @def_function.function
    def run(x):

      def computation(x):
        return math_ops.square(x)

      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(x,)))
      return outputs

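    # Each replica squares the scalar input 2., so the local results are one
    # 4. per replica, i.e. a vector of length num_replicas_in_sync.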
    self.assertAllEqual(
        constant_op.constant(4., shape=(distribution.num_replicas_in_sync,)),
        run(2.))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
    with distribution.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.run(assign_add)
      return array_ops.zeros([])

    train_step()
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

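    # With 8 elements and a global batch of 2 there are only 4 batches, so the
    # loop takes 4 data steps and exits via the break on the 5th of its 6
    # iterations.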
    train_fn(distributed_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with self.assertRaisesRegex(NotImplementedError,
                                "does not support pure eager execution"):
      distribution.run(train_step, args=(next(input_iterator),))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.run(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.run(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so each replica's input is the same as
      # its replica id.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.run(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
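    # product_of_means accumulates 2 * mean([5., 6.]) * mean([7., 8.])
    # = 2 * 5.5 * 7.5 = 82.5.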
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)

    # We set the initial value of `a` to 1 and iterate through the dataset 2
    # times (4/2, where 4 is the number of dataset elements and 2 is the batch
    # size), adding 1 on each step. Hence the final result is 3.
    self.assertEqual(3.0, a.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently depending
    # on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithoutFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    self.assertAllEqual(
        distribution.experimental_local_results(input_iterator.get_next()),
        data[0:distribution.num_replicas_in_sync])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetIteratorWithFunction(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(iterator.get_next())

    local_results = run(input_iterator)
    self.assertAllEqual(local_results,
                        data[0:distribution.num_replicas_in_sync])
    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    backing_devices = [result.backing_device for result in local_results]
    self.assertAllEqual(backing_devices, distribution.extended.worker_devices)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.experimental_distribute_dataset(
            get_dataset_from_tensor_slices(data).batch(2),
            distribute_lib.InputOptions(
                experimental_prefetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies,
          mode=["eager"]
      ))
  def testDistributeDatasetFunctionHostPrefetch(self, distribution):
    data = [5., 6., 7., 8.]
    input_iterator = iter(
        distribution.distribute_datasets_from_function(
            lambda _: get_dataset_from_tensor_slices(data),
            distribute_lib.InputOptions(experimental_prefetch_to_device=False)))

    local_results = distribution.experimental_local_results(
        input_iterator.get_next())

    for result in local_results:
      self.assertEqual(result.backing_device,
                       device_util.resolve("/device:CPU:0"))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
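    # With a global batch of 4 and only 3 elements, replica 0 receives
    # [5., 6.] (mean 5.5) and replica 1 the partial batch [7.] (mean 7.).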
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithRunOptions(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
    options = distribute_lib.RunOptions(
        experimental_bucketizing_dynamic_shape=True)

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.reduce_mean(x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(
              computation, args=(inputs,), options=options))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicOutputsWithX64(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [5]).map(lambda x: math_ops.cast(x, dtypes.int64)).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return math_ops.add(x, x)

      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
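    # Replica 0 receives the single element [5] and returns [10]; replica 1
    # receives an empty batch and returns an empty tensor.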
    result = run(input_iterator)
    self.assertAllEqual([10], result[0])
    self.assertAllEqual([], result[1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

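    # reduce(MEAN, ..., axis=0) averages over the full global batch, including
    # the partial per-replica batch: (5. + 6. + 7.) / 3 = 6.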
    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
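    # Replica 0 receives [5., 6.] (size 2) and replica 1 the partial batch
    # [7.] (size 1).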
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testSegmentSumWithDynamicNumberOfSegments(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros(5, dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      segment_ids = array_ops.zeros_like_v2(example)
      num_segment = array_ops.shape(example)[0]
      # If the number of segments is dynamic, the output has a dynamic shape.
      return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)

    # This assumes that there are exactly 2 replicas
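    # The 5 elements are batched as 3 and 2 across the two replicas, so the
    # segment sums have shapes (3,) and (2,).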
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((3,), outputs[0].shape)
    self.assertAllEqual((2,), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testReshapeWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((5, 1, 2), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(3)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(example):
      # example: [<=3, 1, 2]
      # tile: [<=3, <=3, 2]
      tile = array_ops.tile(example, [1, array_ops.shape(example)[0], 1])
      # reshape1: [<=(3*3 = 9), 2]
      reshape1 = array_ops.reshape(tile, [-1, 2])

      # reshape2: [<=3, <=3, 2]
      reshape2 = array_ops.reshape(
          reshape1,
          [array_ops.shape(example)[0],
           array_ops.shape(example)[0], 2])

      # reshape3: [<=3, -1, 2]
      reshape3 = array_ops.reshape(reshape1,
                                   [array_ops.shape(example)[0], -1, 2])
      # reshape4: [-1, <=3, 2]
      reshape4 = array_ops.reshape(reshape1,
                                   [-1, array_ops.shape(example)[0], 2])
      return [reshape1, reshape2, reshape3, reshape4]

    # This assumes that there are exactly 2 replicas
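    # Replica 0 gets a batch of 3 examples and replica 1 the remaining 2,
    # hence 3*3 = 9 rows versus 2*2 = 4 rows after tiling and flattening.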
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((9, 2), outputs[0][0].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][1].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][2].values[0].shape)
    self.assertAllEqual((3, 3, 2), outputs[0][3].values[0].shape)

    self.assertAllEqual((4, 2), outputs[0][0].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][1].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][2].values[1].shape)
    self.assertAllEqual((2, 2, 2), outputs[0][3].values[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.run(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
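    # Replica 0 receives [[1., 2.], [1., 2.]] (mean 1.5) while replica 1
    # receives the larger-shaped [[1., 2., 3.], [1., 2., 3.]] (mean 2.).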
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]))
  def testMapFnWithDynamicInputs(self, distribution):

    def dataset_fn(_):
      data = array_ops.zeros((20, 300, 32), dtype=dtypes.int32)
      dataset = get_dataset_from_tensor_slices(data)
      dataset = dataset.batch(16)
      return dataset

    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset_fn))

    def embedding_lookup(inputs):
      embedding_weights = array_ops.zeros((1, 128))
      flat_inputs = array_ops.reshape(inputs, [-1])
      embeddings = array_ops.gather(embedding_weights, flat_inputs)
      embeddings = array_ops.reshape(embeddings, inputs.shape.as_list() + [128])
      return embeddings

    @def_function.function
    def step_fn(example):
      return map_fn.map_fn(
          embedding_lookup, example, fn_output_signature=dtypes.float32)

    # This assumes that there are exactly 2 replicas
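    # With 20 examples and a per-replica batch of 16, replica 0 receives 16
    # examples and replica 1 the remaining 4.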
    outputs = distribution.experimental_local_results(
        distribution.run(step_fn, args=(next(input_iterator),)))
    self.assertAllEqual((16, 300, 32, 128), outputs[0].shape)
    self.assertAllEqual((4, 300, 32, 128), outputs[1].shape)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we set
    # drop_remainder=True on the dataset, then DistributedIterator will use a
    # different (and more efficient) code path which avoids some control flow
    # ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # Dynamic output size with a mix of static and dynamic outputs
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # A fixed-size output together with a dynamically sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.run(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # The first result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only the first replica receives data from the partial batch, so only it
    # has a squared output.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # The other replicas receive no data, so their dynamic outputs are empty.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.run(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.run(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # we iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testMultiDeviceDataCapturedFunction(self, distribution):
    inputs = constant_op.constant([2., 3.])
    dataset = lambda _: dataset_ops.Dataset.from_tensor_slices(inputs).repeat(5)
    input_iterator = iter(
        distribution.distribute_datasets_from_function(dataset))
    with distribution.scope():
      var = variables.Variable(1.0)

    @def_function.function
    def train_step(input_iterator):

      def func(inputs):
        return math_ops.square(inputs) + var

      per_replica_outputs = distribution.run(
          func, (next(input_iterator),))
      mean = distribution.reduce(
          reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      for _ in dataset_ops.Dataset.range(1):
        per_replica_outputs = distribution.run(
            func, (next(input_iterator),))
        mean = distribution.reduce(
            reduce_util.ReduceOp.MEAN, per_replica_outputs, axis=None)
      return mean

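    # With one replica the second step consumes 3., giving 3.**2 + var = 10;
    # with two replicas the replicas consume 2. and 3. in the same step,
    # giving mean(2.**2 + var, 3.**2 + var) = 7.5.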
    with distribution.scope():
      if distribution.num_replicas_in_sync == 1:
        self.assertAlmostEqual(10.0, self.evaluate(train_step(input_iterator)))
      else:
        self.assertAlmostEqual(7.5, self.evaluate(train_step(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetOutOfRange(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.SUM)

    def train_step(val):
      a.assign_add(math_ops.reduce_sum(val))

    @def_function.function
    def f_train_step(iterator):
      distribution.run(train_step, args=(next(iterator),))
      return a

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    iterator = iter(dist_dataset)
    with self.assertRaises(errors.OutOfRangeError):
      for _ in range(100):
        f_train_step(iterator)

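    # The two successful steps consume all four elements before the iterator
    # raises OutOfRangeError, so a = 5. + 6. + 7. + 8. = 26.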
    self.assertAlmostEqual(26.0, a.numpy())


if __name__ == "__main__":
  test_util.main()