# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for running legacy optimizer code with DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.distribute.test_example import batchnorm_example
from tensorflow.python.keras.distribute.test_example import minimize_loss_example
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import test


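# Variable names expected to be created while minimizing the
# `minimize_loss_example` model with each optimizer (layer variables plus the
# optimizer's slot/hyperparameter variables). Consumed by
# `get_expected_variables` in `testOptimizerInsideModelFn` below.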
VAR_MAP_V1 = {
    "GradientDescent": ("dense/kernel", "dense/bias"),
    "Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
                "dense/bias"),
    "Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
             "dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
    "RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
                "dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
}

VAR_MAP_V2 = {
    "SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
            "dense/kernel", "SGD/momentum"),
    "Adagrad":
        ("Adagrad/iter", "dense/bias", "dense/kernel", "Adagrad/learning_rate",
         "Adagrad/decay", "Adagrad/dense/kernel/accumulator",
         "Adagrad/dense/bias/accumulator")
}


class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
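  """Tests single-step loss minimization under various strategy/optimizer combinations."""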

  def _get_iterator(self, strategy, input_fn):
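    """Creates a distributed iterator from `input_fn` and initializes it."""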
    iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
    self.evaluate(iterator.initializer)
    return iterator

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_optimizers(),
          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
          + combinations.combine(mode=["eager"], use_callable_loss=[True])) +
      combinations.times(
          optimizer_combinations.distributions_and_v2_optimizers(),
          combinations.combine(
              mode=["graph", "eager"], use_callable_loss=[True])) +
      combinations.combine(
          distribution=[strategy_combinations.tpu_strategy],
          optimizer_fn=optimizer_combinations.optimizers_v2,
          mode=["graph"],
          use_callable_loss=[True]) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1,
              mode=["graph"],
              use_callable_loss=[True, False]))
  def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
    with distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, layer = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=2).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      weights, biases = [], []
      for _ in range(5):
        run_step()
        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

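      # The example model is trained so that kernel + bias converges towards
      # 1, so the error below is expected to be monotonically non-increasing.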
      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_optimizers(),
          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
          + combinations.combine(mode=["eager"], use_callable_loss=[True])) +
      combinations.times(
          optimizer_combinations.distributions_and_v2_optimizers(),
          combinations.combine(
              mode=["graph", "eager"], use_callable_loss=[True])))
  def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                           use_callable_loss):
    with distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, layer = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(iterator.get_next(),)))

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
        self.evaluate(variables_lib.global_variables_initializer())

      weights, biases = [], []
      for _ in range(10):
        run_step()

        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
          combinations.combine(mode=["graph", "eager"])) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"]))
  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
    if (not context.executing_eagerly() and
        control_flow_v2_toggles.control_flow_v2_enabled()):
      self.skipTest("b/138751864")
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, **kwargs):
      v = next_creator(**kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, _ = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=True)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())
      run_step()

      def get_expected_variables(num_parameter_devices):
        name = optimizer._name

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          variables = VAR_MAP_V2[name]
        else:
          variables = VAR_MAP_V1[name]

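        # Mirrored copies on parameter devices other than the first are
        # expected to be named with a "/replica_<i>" suffix.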
        extended_variables = [
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ]
        variables = list(variables) + extended_variables
        return set(v + ":0" for v in variables)

      self.assertEqual(
          get_expected_variables(len(distribution.extended.parameter_devices)),
          set(created_variables))

  @ds_combinations.generate(
      combinations.times(
          combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]),
          combinations.times(
              optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
              combinations.combine(
                  mode=["graph", "eager"],
                  # TODO(isaprykin):  Allow False here.  Currently subsequent
                  # replicas will re-execute UPDATE_OPS of previous replicas.
                  update_ops_in_cross_replica_mode=[True])) +
          combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"],
              update_ops_in_cross_replica_mode=[False])))
  def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                    renorm, update_ops_in_cross_replica_mode):
    """Verifies that moving mean updates are reduced across replicas."""
    with distribution.scope():
      num_replicas = distribution.num_replicas_in_sync
      model_fn, dataset_fn, batchnorm = batchnorm_example(
          optimizer_fn,
          batch_per_epoch=num_replicas,
          momentum=momentum,
          renorm=renorm,
          update_ops_in_replica_mode=not update_ops_in_cross_replica_mode)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        fetches = distribution.experimental_local_results(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))
        if update_ops_in_cross_replica_mode:
          fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
        return control_flow_ops.group(fetches)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      expected_moving_means = [0.] * 8

      def averaged_batch_mean(i):
        # Each batch has shape [16, 8], where element [j, i] is
        # (8 * j + i + replica_id * 100), so the batch mean of feature i in
        # each replica is (60 + i + replica_id * 100). Below is that mean
        # averaged over all replicas.
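        # For example, with two replicas the expected mean of feature i is
        # 60 + i + 50.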
        return 60. + i + (num_replicas - 1.) / 2. * 100.

      for _ in range(10):
        run_step()
        moving_means = self.evaluate(batchnorm.moving_mean)

        # Make sure that moving_mean is updated as if the sample mean were
        # computed over all replicas.
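        # Equivalently: new_mean = momentum * old_mean +
        # (1 - momentum) * batch_mean.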
        for i, expected_moving_mean in enumerate(expected_moving_means):
          expected_moving_means[i] -= ((
              expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
          self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

  @ds_combinations.generate(
      combinations.times(
          combinations.combine(loss_reduction=[
              losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
              losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
              losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS
          ]),
          combinations.times(
              combinations.combine(distribution=[
                  strategy_combinations.one_device_strategy,
                  strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
                  strategy_combinations.mirrored_strategy_with_two_gpus
              ]),
              combinations.times(
                  combinations.combine(optimizer_fn=optimizer_combinations
                                       .gradient_descent_optimizer_v1_fn),
                  combinations.combine(
                      mode=["graph"], use_callable_loss=[True, False]) +
                  combinations.combine(
                      mode=["eager"], use_callable_loss=[True])) +
              combinations.times(
                  combinations.combine(optimizer_fn=optimizer_combinations
                                       .gradient_descent_optimizer_keras_v2_fn),
                  combinations.combine(
                      mode=["graph", "eager"], use_callable_loss=[True]))) +
          combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations
              .gradient_descent_optimizer_v1_fn,
              mode=["graph"],
              use_callable_loss=[True, False]) + combinations.combine(
                  distribution=[strategy_combinations.tpu_strategy],
                  optimizer_fn=optimizer_combinations
                  .gradient_descent_optimizer_keras_v2_fn,
                  mode=["graph"],
                  use_callable_loss=[True])))
  def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
                    use_callable_loss):
    with distribution.scope():
      all_vars = []

      def model_fn(inputs):
        x, y = inputs
        w = variable_scope.get_variable("w", initializer=[[2.]])
        all_vars.append(w)

        def loss_fn():
          # Use fixed initialization to make the steps deterministic.
          predict = math_ops.matmul(x, w)
          loss = losses_impl.mean_squared_error(
              y, predict, reduction=loss_reduction)
          if loss_reduction == losses_impl.Reduction.SUM:
            return loss
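          # For mean-like reductions, scale the per-replica loss down so that
          # summing the resulting gradients across replicas matches a single
          # unreplicated update (each replica pulls the full batch here).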
          return loss / distribution.num_replicas_in_sync

        optimizer = optimizer_fn()  # GradientDescent with 0.001 learning rate

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          return optimizer.minimize(loss_fn, [w])
        else:
          if use_callable_loss:
            return optimizer.minimize(loss_fn)
          else:
            return optimizer.minimize(loss_fn())

      def dataset_fn():
        features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
        labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
        return dataset_ops.Dataset.zip((features, labels)).repeat()

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      run_step()

      v = all_vars[0]
      self.assertTrue(all(v is vi for vi in all_vars[1:]))
      weight = numpy.squeeze(self.evaluate(v))
      # Our model is:
      #   predict = x * w
      #   loss = (predict - y)^2
      #   dloss/dpredict = 2*(predict - y)
      #   dloss/dw = 2 * x^T @ (predict - y)
      # For our batch size of 2, assuming sum loss reduction:
      #   x = [2, 7]
      #   y = [6, 21]
      #   w_initial = 2
      #   predict = [4, 14]
      #   predict - y = [-2, -7]
      #   dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106
      # So, unreplicated, the update to w with lr=0.001 is
      # -0.001 * -106 = 0.106 with sum loss reduction, or 0.053 with mean.
      if loss_reduction == losses_impl.Reduction.SUM:
        # Note that the "distribution.num_replicas_in_sync" factor will go away
        # once we split the input across replicas, instead of pulling a complete
        # batch of input per replica.
        self.assertNear(weight, 2 + 0.106 * distribution.num_replicas_in_sync,
                        0.0001)
      else:
        # One of the mean loss reductions.
        self.assertNear(weight, 2 + 0.053, 0.0001)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
          combinations.combine(mode=["graph", "eager"]),
          combinations.combine(is_tpu=[False])) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"],
              is_tpu=[True]))
  def testRunStepsWithOutputContext(self, distribution, optimizer_fn, is_tpu):
    with distribution.scope():
      def dataset_fn():
        dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
        # TODO(priyag): batch with drop_remainder=True causes shapes to be
        # fully defined for TPU. Remove this when XLA supports dynamic shapes.
        return dataset.batch(batch_size=1, drop_remainder=True)

      optimizer = optimizer_fn()
      layer = core.Dense(1, use_bias=True)

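      # Arbitrary non-tensor output, used below to check that
      # `set_non_tensor_output` values are surfaced on the step context.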
      key1 = "foo"
      value1 = "bar"

      def model_fn(output_context, x):
        """A very simple model written by the user."""
        def loss_fn():
          y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
          return y * y

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          train_op = optimizer.minimize(
              loss_fn, lambda: layer.trainable_variables)
        else:
          train_op = optimizer.minimize(loss_fn)
        loss = loss_fn()
        output_context.set_last_step_output(
            name="replica_loss_reduced",
            output=loss,
            reduce_op=reduce_util.ReduceOp.MEAN)
        output_context.set_non_tensor_output(key1, value1)
        return (train_op, loss)

      def step_fn(output_context, inputs):
        (train_op, loss) = distribution.extended.call_for_each_replica(
            model_fn, args=(output_context, inputs))
        output_context.set_last_step_output(
            name="cross_replica_loss_reduced",
            output=loss,
            reduce_op=reduce_util.ReduceOp.MEAN)
        output_context.set_last_step_output(
            name="cross_replica_loss_not_reduced",
            output=loss)
        return distribution.group(train_op)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        initial_loss = lambda: constant_op.constant(1e7)
        # Initial values corresponding to reduced losses are just single
        # tensors. But for non-reduced losses we need initial values with
        # the same structure as the non-reduced losses: in MirroredStrategy
        # this is a list of per-replica losses, while in TPUStrategy it is a
        # single tensor. Using `call_for_each_replica` followed by
        # `experimental_local_results` gives us the desired initial value
        # structure.
        not_reduced = distribution.experimental_local_results(
            distribution.extended.call_for_each_replica(initial_loss))
        initial_loop_values = {
            "replica_loss_reduced": initial_loss(),
            "cross_replica_loss_reduced": initial_loss(),
            "cross_replica_loss_not_reduced": not_reduced,
        }
        ctx = distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=2,
            initial_loop_values=initial_loop_values)

        self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["replica_loss_reduced"],
            reduced=True, distribution=distribution)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
            reduced=True, distribution=distribution)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
            reduced=False, distribution=distribution)
        return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      weights, biases, losses = [], [], []
      for _ in range(5):
        _, loss = run_step()
        losses.append(loss)
        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:]))
      self.assertTrue(loss_is_not_increasing)

      error = abs(
          numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(error_is_not_increasing)

  def _verify_loss_output(self, initial_loss, loss_output, reduced,
                          distribution):
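    # Reduced outputs should unwrap to a single tensor; non-reduced outputs
    # should have one value per replica. Either way the dtype and shape must
    # match `initial_loss`.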
    if not reduced:
      self.assertLen(distribution.experimental_local_results(loss_output),
                     distribution.num_replicas_in_sync)
      loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output,
                                        axis=None)
    else:
      unwrapped_output = distribution.experimental_local_results(loss_output)
      self.assertLen(unwrapped_output, 1)
      loss_tensor = unwrapped_output[0]
    self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
    self.assertEqual(initial_loss.shape, loss_tensor.shape)

  @ds_combinations.generate(
      optimizer_combinations.distributions_and_v2_optimizers())
  def test_empty_var_list(self, distribution, optimizer_fn):
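    """Tests that minimize/apply_gradients work with an empty var list."""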
    opt = optimizer_fn()
    with distribution.scope():

      def run_fn():
        opt.minimize(lambda: constant_op.constant(1.), [])
        opt.apply_gradients([])

      distribution.run(run_fn)


if __name__ == "__main__":
  test.main()