# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import nest


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class InputIterationTest(test.TestCase, parameterized.TestCase,
                         AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testConstantNumpyInput(self, distribution):
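    """A constant scalar input is replicated and squared on every replica."""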

    @def_function.function
    def run(x):

      def computation(x):
        return math_ops.square(x)

      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(x,)))
      return outputs

    self.assertAllEqual(
        constant_op.constant(4., shape=(distribution.num_replicas_in_sync)),
        run(2.))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStatefulExperimentalRunAlwaysExecute(self, distribution):
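    """A replica fn with only a stateful variable update still executes."""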
    with distribution.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def assign_add():
        v.assign_add(1.0)

      distribution.experimental_run_v2(assign_add)
      return array_ops.zeros([])

    train_step()
    self.assertAllEqual(1.0, v.numpy())

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]))
  def testFullEager(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(data):
      return math_ops.square(data)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy
          ],
          mode=["eager"]))
  def testNestedOutput(self, distribution):
    dataset = get_dataset_from_tensor_slices([0, 1, 2, 3]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        return [{
            "a": x - 1,
            "b": x + 1
        }]

      inputs = next(iterator)
      outputs = distribution.experimental_run_v2(computation, args=(inputs,))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    results = run(input_iterator)
    for replica in range(distribution.num_replicas_in_sync):
      # The input dataset is range(4), so the replica id is the same as the
      # input value.
      self.assertAllEqual(results[0]["a"][replica], [replica - 1])
      self.assertAllEqual(results[0]["b"][replica], [replica + 1])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionAutoGraphApplication(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetIterationInFunction(self, distribution):
    with distribution.scope():
      a = variables.Variable(
          1.0, aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    def train_step(_):
      a.assign_add(1.0)

    @def_function.function
    def f_train_step(dist_dataset):
      number_of_steps = constant_op.constant(0.0)
      product_of_means = constant_op.constant(2.0)
      for x in dist_dataset:  # loop with values modified each iteration
        number_of_steps += 1
        product_of_means *= math_ops.cast(
            distribution.reduce("MEAN", x, axis=0), product_of_means.dtype)

      for y in dist_dataset:  # loop with no intermediate state
        distribution.experimental_run_v2(train_step, args=(y,))

      return number_of_steps, product_of_means

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    number_of_steps, product_of_means = f_train_step(dist_dataset)
    self.assertEqual(2, number_of_steps.numpy())
    self.assertNear((2 * (5+6)/2 * (7+8)/2), product_of_means.numpy(), 1e-3)

    # We set the initial value of `a` to 1 and iterate through the dataset
    # twice (4 dataset elements / batch size of 2), so the final value is 3.
    self.assertEqual(3.0, (a.numpy()))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetAssertWithDynamicBatch(self, distribution):
    # Regression test for github issue 33517.
    def step_fn(data):
      assert_op = control_flow_ops.Assert(math_ops.less_equal(
          math_ops.reduce_max(data), 100.), [data])
      with ops.control_dependencies([assert_op]):
        return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 3 elements and a
      # global batch of 2 (the second batch is a partial batch of size 1).
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.experimental_run_v2(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7.,]).batch(2)
    # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
    if not tf2.enabled():
      dataset = dataset_ops.Dataset.from_tensor_slices([5., 6., 7.,]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)

    expected_results = [[25., 36.], [49.]]
    self.assertEqual(len(expected_results), len(results))

    # Need to expand results since output will be grouped differently depending
    # on the number of replicas.
    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapes(self, distribution):
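    """Per-replica computation handles a partial batch with dynamic shape."""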
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      def computation(x):
        return math_ops.reduce_mean(x)
      inputs = next(iterator)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithGetNextOutsideFunction(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([5.5, 7.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapes(self, distribution):
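    """strategy.reduce handles a partial batch with dynamic shape."""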
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual(6., run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testStrategyReduceWithDynamicShapesRank2(self, distribution):
    dataset = get_dataset_from_tensor_slices(
        [[1., 1.], [1., 1.], [1., 1.]]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):
      inputs = next(iterator)
      return distribution.reduce(reduce_util.ReduceOp.MEAN, inputs, axis=0)

    self.assertAllEqual([1., 1.], run(input_iterator))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithSizeOp(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7.]).batch(4)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(inputs):
      def computation(x):
        return array_ops.size_v2(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([2, 1], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.multidevice_strategies,
          mode=["eager"]
      ))
  def testDynamicShapesWithFirstReplicaNotMaximumShape(self, distribution):
    def dataset_fn(_):
      dataset1 = get_dataset_from_tensor_slices([[1., 2.], [1., 2.]])
      dataset2 = get_dataset_from_tensor_slices([[1., 2., 3.],
                                                 [1., 2., 3.]])
      dataset = dataset1.concatenate(dataset2)
      dataset = dataset.batch(2, drop_remainder=True)
      return dataset

    input_iterator = iter(
        distribution.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def run(inputs):
      def computation(x):
        return math_ops.reduce_mean(x)
      outputs = distribution.experimental_local_results(
          distribution.experimental_run_v2(computation, args=(inputs,)))
      return outputs

    # This assumes that there are exactly 2 replicas
    self.assertAllEqual([1.5, 2.], run(next(input_iterator)))

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleDrop(self, distribution):
    # If the batch size is evenly divisible by the number of workers and we set
    # drop_remainder=True on the dataset, then DistributedIterator will use a
    # different (and more efficient) code path which avoids some control flow
    # ops.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeNotDivisibleDrop(self, distribution):
    # If each batch is not evenly divisible by the number of workers,
    # the remainder will be dropped.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        1, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetDistributeEvenlyDivisibleNoDrop(self, distribution):
    # Setting drop_remainder=False on the dataset causes DistributedIterator
    # to use get_next_as_optional(), even if the batched dataset is evenly
    # divisible by the number of workers.
    dataset = get_dataset_from_tensor_slices([5., 6.]).batch(
        2, drop_remainder=False)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    data = next(input_iterator)

    expected_result = [5., 6.]
    final_result = []
    actual_result = distribution.experimental_local_results(data)
    for val in actual_result:
      final_result.extend(val)
    self.assertAllEqual(expected_result, final_result)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testDatasetPartialBatchWithMixedOutputs(self, distribution):
    # Dynamic output size with a mix of static and dynamic outputs
    dataset = get_dataset_from_tensor_slices([5.]).batch(2)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def run(iterator):

      def computation(x):
        # Fixed size output with a dynamic sized output.
        return array_ops.zeros([3]), math_ops.square(x)

      return distribution.experimental_run_v2(
          computation, args=(next(iterator),))

    results = run(input_iterator)

    # First result is fixed for all replicas.
    for replica_id in range(distribution.num_replicas_in_sync):
      self.assertAllEqual([0., 0., 0.],
                          distribution.experimental_local_results(
                              results[0])[replica_id])
    # Only first replica has distributed dataset computation.
    self.assertAllEqual([25.],
                        distribution.experimental_local_results(results[1])[0])
    # Other replicas have no distributed dataset computation.
    for replica_id in range(1, distribution.num_replicas_in_sync):
      self.assertAllEqual([],
                          distribution.experimental_local_results(
                              results[1])[replica_id])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationInsideFunction(self, distribution):

    def step_fn(data):
      return math_ops.square(data)

    @def_function.function
    def train(dataset):
      results = []
      iterator = iter(dataset)
      # we iterate through the loop 2 times since we have 4 elements and a
      # global batch of 2.
      for _ in range(2):
        elem = next(iterator)
        output = distribution.experimental_local_results(
            distribution.experimental_run_v2(step_fn, args=(elem,)))
        results.append(output)
      return results

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = train(dist_dataset)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testIterationOutsideFunction(self, distribution):

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)
    results = []
    # we iterate through the loop 2 times since we have 4 elements and a
    # global batch of 2.
    for _ in range(2):
      output = f_train_step(next(iterator))
      results.append(output)
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
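    """Gradients w.r.t. the watched per-replica inputs of x**2 are 2 * x."""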
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
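    """Gradients flow through a model optionally wrapped in tf.function."""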
    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

      @def_function.function
      def train_step():
        def replica_step():
          with backprop.GradientTape() as tape:
            y = model(x)
          return tape.gradient(y, x)
        return distribution.experimental_run_v2(replica_step)

      grads = distribution.experimental_local_results(train_step())
      self.assertLen(grads, distribution.num_replicas_in_sync)
      self.assertTrue(all(g is not None for g in grads))


class KerasModelsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_single_keras_layer_experimental_run(self, distribution):
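    """Computes gradients of a single Dense layer under experimental_run_v2."""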
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = keras.layers.Dense(4, name="dense")

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_creation_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
    def get_subclass_model():

      class KerasSubclassModel(keras.Model):

        def __init__(self):
          super(KerasSubclassModel, self).__init__()
          self.l = keras.layers.Dense(4, name="dense")

        def call(self, x):
          return self.l(x)

      return KerasSubclassModel()
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = get_subclass_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      for _ in range(5):
        distribution.experimental_run_v2(step_fn, args=(next(iterator),))

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_lstm(self, distribution):
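    """Runs one distributed optimizer step on a small LSTM model."""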

    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # The model has only LSTM variables, which makes missing-gradient
      # issues easier to detect.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10

      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()

    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def test_nested_tf_functions(self, distribution):
    # The test builds two computations with keras layers, one with a nested
    # tf.function and the other without. We run these computations
    # independently on two models with the same initial weights, and make sure
    # the variables are still the same after one training step.

    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
      weights_file = os.path.join(self.get_temp_dir(), ".h5")
      model.save_weights(weights_file)
      model2 = get_model()
      model2.load_weights(weights_file)

    # Make sure model and model2 variables are in sync when initialized.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

    def compute_loss(images, targets):
      outputs = model(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_without_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss(images, targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    @def_function.function
    def compute_loss2(images, targets):
      outputs = model2(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_with_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss2(images, targets)
        grads = tape.gradient(loss, model2.variables)
        optimizer.apply_gradients(zip(grads, model2.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    inputs = next(input_iterator)

    train_step_without_nested_tf_function(inputs)
    train_step_with_nested_tf_function(inputs)

    # Make sure model and model2 variables are still in sync.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

  def _get_dataset(self):
    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.zeros((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10, drop_remainder=True)
    return dataset

  def _get_model(self):
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


if __name__ == "__main__":
  test.main()