# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that mixed precision works correctly with Keras layers and models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision import get_layer_policy
from tensorflow.python.keras.mixed_precision import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.mixed_precision import test_util as mp_test_util
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
# from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.tracking import util as trackable_utils


# Pylint's static analysis incorrectly believes many layers are non-callable, so
# we disable the lint error.
# pylint: disable=not-callable


class MultiplyLayerWithoutAutoCast(mp_test_util.MultiplyLayer):
  """Same as MultiplyLayer, but does not use AutoCastVariables."""

  def build(self, _):
    dtype = self.dtype
    if dtype in ('float16', 'bfloat16'):
      dtype = 'float32'
    self.v = self.add_weight(
        'v', (),
        initializer='ones',
        dtype=dtype,
        experimental_autocast=False,
        regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    self.assert_input_types(inputs)
    assert self.v.dtype in (dtypes.float32, dtypes.float64)
    return self._multiply(inputs, math_ops.cast(self.v, inputs.dtype))


class MultiplyLayerWithFunction(mp_test_util.MultiplyLayer):
  """Same as MultiplyLayer, but _multiply is decorated with a tf.function."""

  @def_function.function
  def _multiply(self, x, y):
    return super(MultiplyLayerWithFunction, self)._multiply(x, y)


# If called outside any strategy.scope() calls, this will return the default
# strategy.
default_strategy_fn = distribution_strategy_context.get_strategy


def create_mirrored_strategy():
  """Create a MirroredStrategy, using a GPU if it is available."""
  if tf_config.list_logical_devices('GPU'):
    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
  else:
    return mirrored_strategy.MirroredStrategy(['cpu:0'])


def create_central_storage_strategy():
  """Create a CentralStorageStrategy, using a GPU if it is available."""
  compute_devices = ['cpu:0', 'gpu:0'] if (
      tf_config.list_logical_devices('GPU')) else ['cpu:0']
  return central_storage_strategy.CentralStorageStrategy(
      compute_devices, parameter_device='cpu:0')

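# Each entry in TESTCASES parameterizes a test to run once under the default
# strategy and once under a MirroredStrategy from create_mirrored_strategy().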
TESTCASES = ({
    'testcase_name': 'base',
    'strategy_fn': default_strategy_fn
}, {
    'testcase_name': 'distribute',
    'strategy_fn': create_mirrored_strategy
})


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class KerasLayerTest(keras_parameterized.TestCase):
  """Test mixed precision with Keras layers."""

  @parameterized.named_parameters(*TESTCASES)
  def test_mixed_policies_(self, strategy_fn):
    strategy = strategy_fn()
    for dtype in 'float16', 'bfloat16':
      x = constant_op.constant([1.])
      policy_name = 'mixed_' + dtype
      with strategy.scope(), policy.policy_scope(policy_name):
        layer = mp_test_util.MultiplyLayer(assert_type=dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtype)
        self.assertEqual(layer.dtype_policy.name, policy_name)
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(layer.variable_dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  def test_layer_with_int_variable(self):
    class LayerWithIntVar(base_layer.Layer):

      def build(self, _):
        self.v = self.add_weight('v', dtype='int32', trainable=False)

      def call(self, inputs):
        # Only float variables should be autocasted. This will fail if self.v is
        # autocasted to float32
        return math_ops.cast(inputs, 'int32') + self.v

    x = constant_op.constant([1.])
    layer = LayerWithIntVar(dtype=policy.Policy('mixed_float16'))
    self.assertEqual(layer(x).dtype, 'int32')

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_with_non_autocast_variable(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = MultiplyLayerWithoutAutoCast(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_calling_tf_function(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = MultiplyLayerWithFunction(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_regularizer_runs_in_var_dtype(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        # Test on MultiplyLayer
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

        # Test on MultiplyLayerWithoutAutoCast
        layer = MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_passing_policy_to_layer(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      # Passing a Policy to 'dtype' sets the policy for that layer.
      layer = mp_test_util.MultiplyLayer(
          assert_type=dtypes.float16, dtype=policy.Policy('mixed_float16'))
      # layer.dtype refers to the variable dtype
      self.assertEqual(layer.dtype, dtypes.float32)
      layer(x)
      self.assertEqual(layer.v.dtype, dtypes.float32)
      with policy.policy_scope('mixed_float16'):
        # Passing a Policy to dtype overrides the global Policy
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float64, dtype=policy.Policy('float64'))
        self.assertEqual(layer.dtype_policy.name, 'float64')
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtypes.float64)
        self.assertEqual(layer.dtype, dtypes.float64)
        self.assertEqual(layer.variable_dtype, dtypes.float64)
        self.assertEqual(layer(x).dtype, dtypes.float64)
        self.assertEqual(layer.v.dtype, dtypes.float64)

  def test_error_passing_policy_string_to_layer(self):
    with self.assertRaisesRegex(
        TypeError, "Cannot convert value 'mixed_float16' to a "
        'TensorFlow DType'):
      # This is not allowed, as otherwise a "mixed_float16" policy could be
      # created without an API call that has the name "experimental" in it.
      mp_test_util.MultiplyLayer(dtype='mixed_float16')

  @parameterized.named_parameters(*TESTCASES)
  def test_gradient(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope() as strategy:
      with policy.policy_scope('mixed_float16'):
        layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
        # Learning rate is small enough that if applied to a float16 variable,
        # the variable will not change. So this tests that the update is
        # applied to the float32 variable and not a float16 copy of it.
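        # (float16 has a 10-bit mantissa, so 1 - 2**-14 rounds back to 1 in
        # float16, whereas float32 can represent the difference.)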
        opt = gradient_descent.SGD(2**-14)

        def run_fn():
          with backprop.GradientTape() as tape:
            y = layer(x)
            # Divide by num_replicas_in_sync, as the effective total loss is
            # the sum of the per-replica losses.
            y /= strategy.num_replicas_in_sync

          grad = tape.gradient(y, layer.v)
          return opt.apply_gradients([(grad, layer.v)])

        op = strategy.experimental_run(run_fn)
        if not context.executing_eagerly():
          self.evaluate(variables.global_variables_initializer())
          self.evaluate(op)
        # The gradient with respect to the variable is 1. Since the variable is
        # initialized with 1 and the learning rate is 2**-14, the new variable
        # value should be init_val - gradient * learning_rate, which is
        # 1 - 1 * 2**-14.
        self.assertEqual(self.evaluate(layer.v), 1 - 2**-14)

  def _test_checkpointing_layer_weights(self, strategy_fn,
                                        mixed_prec_when_saving,
                                        mixed_prec_when_loading):
    # In this test, we potentially save with mixed precision enabled and load
    # with mixed precision disabled, or vice versa. This is possible because
    # variables are float32 regardless of whether mixed precision is enabled.
    save_policy = 'mixed_float16' if mixed_prec_when_saving else 'float32'
    load_policy = 'mixed_float16' if mixed_prec_when_loading else 'float32'
    save_input_dtype = 'float16' if mixed_prec_when_saving else 'float32'
    load_input_dtype = 'float16' if mixed_prec_when_loading else 'float32'

    # Create a layer and save a checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(save_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=save_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(100.)])
    self.assertEqual(self.evaluate(layer(x)), 100.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    save_path = checkpoint.save(prefix)

    # Create a new layer and restore the checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(load_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=load_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(200.)])
    self.assertEqual(self.evaluate(layer(x)), 200.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.assertEqual(layer.get_weights(), [100.])
    self.assertEqual(self.evaluate(layer(x)), 100.)

  @parameterized.named_parameters(*TESTCASES)
  def test_checkpointing_layer_weights(self, strategy_fn):
    with self.test_session():
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=True)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=False)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=False,
          mixed_prec_when_loading=True)

  @parameterized.named_parameters(*TESTCASES)
  def test_config(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      for layer, dtype in (
          (mp_test_util.MultiplyLayer(), 'float32'),
          (mp_test_util.MultiplyLayer(dtype='float64'), 'float64'),
          (mp_test_util.MultiplyLayer(dtype=policy.Policy('float64')),
           'float64')):
        config = layer.get_config()
        self.assertEqual(config['dtype'], dtype)
        self.assertIsInstance(config['dtype'], str)
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, dtype)
        self.assertEqual(layer(x).dtype, dtype)
        self.assertEqual(layer.v.dtype, dtype)

      layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
      config = layer.get_config()
      self.assertIsNone(config['dtype'])
      layer = mp_test_util.MultiplyLayer.from_config(config)
      # If a layer is serialized with the "_infer" policy, when deserialized
      # into TF 2 it will have the global policy instead of "_infer". This is
      # because "_infer" is serialized into None, and passing dtype=None in
      # TensorFlow 2 indicates to use the global policy.
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float32')
      self.assertEqual(layer.v.dtype, 'float32')

  @parameterized.named_parameters(*TESTCASES)
  def test_config_policy_v1(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16',
                                                               loss_scale=None))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'mixed_float16',
                                   'loss_scale': None}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      # Restoring a PolicyV1 silently converts it to a Policy and drops the loss
      # scale.
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      # The loss_scale is silently dropped
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'float64',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float64')
      self.assertEqual(layer(x).dtype, 'float64')
      self.assertEqual(layer.v.dtype, 'float64')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float64')

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': '_infer',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, None)
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float16')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float16')

  def test_delete_variable(self):
    layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
    layer.x = layer.add_weight('x')
    self.assertEqual(layer.trainable_weights, [layer.x])
    del layer.x
    self.assertEqual(layer.trainable_weights, [])

  def test_build_and_call_layer_in_function(self):
    layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
    @def_function.function
    def f():
      return layer(1.)
    y = f()
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(layer.v.dtype, 'float32')
    self.assertEqual(self.evaluate(y), 1.)

  def test_unsupported_strategy(self):
    strategy = create_central_storage_strategy()
    with strategy.scope(), self.assertRaisesRegex(
        ValueError, 'Mixed precision is not supported with the '
        'tf.distribute.Strategy: CentralStorageStrategy. Either '
        'stop using mixed precision by removing the use of the '
        '"mixed_float16" policy or use a different Strategy, e.g. '
        'a MirroredStrategy.'):
      mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
    # Non-mixed policies are fine
    mp_test_util.MultiplyLayer(dtype=policy.Policy('float64'))


class KerasModelTest(keras_parameterized.TestCase):
  """Test mixed precision with Keras models."""

  def _skip_if_strategy_unsupported(self, strategy_fn):
    if (strategy_fn != default_strategy_fn and
        testing_utils.get_model_type() == 'subclass'):
      self.skipTest('Non-default strategies are unsupported with subclassed '
                    'models')

  def _skip_if_save_format_unsupported(self, save_format):
    model_type = testing_utils.get_model_type()
    if save_format == 'h5' and model_type == 'subclass':
      self.skipTest('Saving subclassed models with the HDF5 format is '
                    'unsupported')
    if (save_format == 'tf' and model_type == 'subclass' and
        not context.executing_eagerly()):
      self.skipTest('b/148820505: This combination of features is currently '
                    'broken.')

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'operator',
          'strategy_fn': create_mirrored_strategy,
          'use_operator': True
      }, {
          'testcase_name': 'regularizer',
          'strategy_fn': create_mirrored_strategy,
          'use_regularizer': True
      }, {
          'testcase_name': 'get_config',
          'strategy_fn': create_mirrored_strategy,
          'get_config': True,
          'use_regularizer': True,
      }, {
          'testcase_name': 'saved_model',
          'strategy_fn': default_strategy_fn,
          'save_format': 'tf',
          'use_regularizer': True,
      }, {
          'testcase_name': 'saved_model_input_spec',
          'strategy_fn': default_strategy_fn,
          'save_format': 'tf',
          'use_regularizer': True,
          'use_input_spec': True,
      }, {
          'testcase_name': 'h5',
          'strategy_fn': default_strategy_fn,
          'save_format': 'h5',
          'use_regularizer': True,
      }, {
          'testcase_name': 'saved_model_distribute',
          'strategy_fn': create_mirrored_strategy,
          'save_format': 'tf',
          'use_regularizer': True,
      }, {
          'testcase_name': 'saved_model_input_spec_distribute',
          'strategy_fn': create_mirrored_strategy,
          'save_format': 'tf',
          'use_regularizer': True,
          'use_input_spec': True,
      }, {
          'testcase_name': 'h5_distribute',
          'strategy_fn': create_mirrored_strategy,
          'save_format': 'h5',
          'use_regularizer': True,
      }, {
          'testcase_name': 'saved_model_v1_policy',
          'strategy_fn': create_mirrored_strategy,
          'use_v1_policy': True,
          'save_format': 'tf',
      })
  def test_model(self,
                 strategy_fn,
                 use_operator=False,
                 use_regularizer=False,
                 policy_name='mixed_float16',
                 get_config=False,
                 save_format=None,
                 use_input_spec=False,
                 use_v1_policy=False):
    self._skip_if_strategy_unsupported(strategy_fn)
    self._skip_if_save_format_unsupported(save_format)
    if use_regularizer:
      weight_regularizer = mp_test_util.IdentityRegularizer()
      activity_regularizer = mp_test_util.ReduceSumRegularizer()
    else:
      weight_regularizer = activity_regularizer = None
    with strategy_fn().scope():
      cls = policy.PolicyV1 if use_v1_policy else policy.Policy
      with policy.policy_scope(cls(policy_name)):
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float16,
            use_operator=use_operator,
            regularizer=weight_regularizer,
            activity_regularizer=activity_regularizer,
            input_shape=(1,))
        if use_input_spec:
          layer.input_spec = input_spec.InputSpec(shape=(None, 1))
        model = testing_utils.get_model_from_layers([layer], input_shape=(1,),
                                                    input_dtype=dtypes.float16)
        if get_config:
          config = model.get_config()
          model = model.__class__.from_config(
              config,
              custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer})
          (layer,) = (layer for layer in model.layers
                      if isinstance(layer, mp_test_util.MultiplyLayer))

        def loss_fn(y_true, y_pred):
          del y_true
          return math_ops.reduce_mean(y_pred)

        # Learning rate is small enough that if applied to a float16 variable,
        # the variable will not change. So this tests that the update is
        # applied to the float32 variable and not a float16 copy of it.
        opt = gradient_descent.SGD(2**-14)
        # Use a fixed loss scale, as this test will fail if gradients are
        # skipped for a step due to dynamic loss scaling.
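        # (The loss scale optimizer multiplies the loss by the scale before
        # computing gradients and divides the gradients by the scale before
        # applying them, so the update that is applied is unchanged.)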
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
                                                      initial_scale=8)
        model.compile(
            opt,
            loss=loss_fn,
            run_eagerly=testing_utils.should_run_eagerly())

    x = np.ones((2, 1))
    y = np.ones((2, 1))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
    model.fit(dataset)
    # Variable starts at 1, and should have 2 ** -14 (the gradient of 1 times
    # the learning rate) subtracted from it.
    expected = 1 - 2**-14
    if use_regularizer:
      # Weight and activity regularizer each add another 2 ** -14 to the
      # gradient.
      expected -= 2 * 2**-14
    self.assertEqual(backend.eval(layer.v), expected)

    if save_format:
      with generic_utils.CustomObjectScope(
          {'MultiplyLayer': mp_test_util.MultiplyLayer, 'loss_fn': loss_fn}):
        self._test_saving(model, dataset, save_format, use_regularizer)

  def _test_saving(self, model, dataset, save_format, use_regularizer):
    # Save and load model, asserting variable does not change
    save_path = os.path.join(self.get_temp_dir(), 'model')
    model.save(save_path, save_format=save_format)
    model = save.load_model(save_path)
    (layer,) = (layer for layer in model.layers
                if 'MultiplyLayer' in layer.__class__.__name__)
    expected = 1 - 2**-14
    if use_regularizer:
      expected -= 2 * 2**-14
    self.assertEqual(backend.eval(layer.v), expected)

    # Continue training, and assert variable is correct value
    model.fit(dataset)
    new_expected = expected - 2 ** -14
    if use_regularizer:
      new_expected -= 2 * 2 ** -14
    self.assertEqual(backend.eval(layer.v), new_expected)

    # Load saved model again, and assert variable is previous value
    model = save.load_model(save_path)
    (layer,) = (layer for layer in model.layers
                if 'MultiplyLayer' in layer.__class__.__name__)
    self.assertEqual(backend.eval(layer.v), expected)

    # Ensure various dtype-related aspects of the layer are correct
    self.assertEqual(layer.dtype, 'float32')
    self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                     'mixed_float16')
    self.assertEqual(layer.v.dtype, 'float32')
    self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')

    # Loading a model always loads with a v2 Policy, even if saved with a
    # PolicyV1.
    self.assertEqual(type(model.dtype_policy), policy.Policy)
    self.assertEqual(layer.get_config()['dtype'],
                     {'class_name': 'Policy', 'config': {
                         'name': 'mixed_float16'}})

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      })
  def test_fixed_loss_scaling(self,
                              strategy_fn):
    # Note: We do not test mixed precision in this method, only loss scaling.
    loss_scale = 8.
    batch_size = 4
    with strategy_fn().scope():
      x = layers.Input(shape=(1,), batch_size=batch_size)
      layer = mp_test_util.MultiplyLayer()
      y = layer(x)

      # The gradient of 'y' at this point is 1. With loss scaling, the gradient
      # is 'loss_scale'. We divide by the batch size since the loss is averaged
      # across batch elements.
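      # (Here loss_scale=8 and batch_size=4, so the expected gradient is 2.)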
      expected_gradient = loss_scale / batch_size
      identity_with_grad_check_fn = (
          mp_test_util.create_identity_with_grad_check_fn([expected_gradient]))
      y = core.Lambda(identity_with_grad_check_fn)(y)
      model = models.Model(inputs=x, outputs=y)

      def loss_fn(y_true, y_pred):
        del y_true
        return math_ops.reduce_mean(y_pred)

      opt = gradient_descent.SGD(1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
                                                    initial_scale=loss_scale)
      model.compile(
          opt,
          loss=loss_fn,
          run_eagerly=testing_utils.should_run_eagerly())

    self.assertEqual(backend.eval(layer.v), 1)
    x = np.ones((batch_size, 1))
    y = np.ones((batch_size, 1))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
    model.fit(dataset)
    # Variable starts at 1, and should have gradient of 1 subtracted from it.
    expected = 0
    self.assertEqual(backend.eval(layer.v), expected)

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'loss_scaling',
          'strategy_fn': create_mirrored_strategy,
          'use_loss_scaling': True
      })
  def test_advanced_model(self, strategy_fn, use_loss_scaling=False):
    # The advanced model tests mixed-precision-related features that would occur
    # in a resnet50 model. It tests a model that has:
    #  * Multiple layers, some of which use auto-cast variables and some of
    #    which do not
    #  * Regularization on some variables and not others.
    #  * A fixed loss scale (if use_loss_scaling is True)

    strategy = strategy_fn()
    if use_loss_scaling:
      loss_scale = 8.
    learning_rate = 2**-14

    with strategy.scope():
      with policy.policy_scope(policy.Policy('mixed_float16')):
        x = layers.Input(shape=(1,), batch_size=2)
        layer1 = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer(),
            use_operator=True)
        layer2 = MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16, use_operator=True)
        layer3 = mp_test_util.MultiplyLayer(assert_type=dtypes.float16,
                                            use_operator=False)
        layer4 = MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer(),
            use_operator=False)
        y = layer1(x)
        y = layer2(y)
        y = layer3(y)
        y = layer4(y)
        if use_loss_scaling:
          # The gradient of 'y' at this point is 1. With loss scaling, the
          # gradient is 'loss_scale'. We divide by the batch size of 2 since the
          # loss is averaged across batch elements.
          expected_gradient = loss_scale / 2
          identity_with_grad_check_fn = (
              mp_test_util.create_identity_with_grad_check_fn(
                  expected_dtype=dtypes.float16,
                  expected_gradient=[expected_gradient]))
          y = core.Lambda(identity_with_grad_check_fn)(y)
        model = models.Model(inputs=x, outputs=y)

        def loss_fn(y_true, y_pred):
          del y_true
          return math_ops.reduce_mean(y_pred)

        opt = gradient_descent.SGD(learning_rate)
        if use_loss_scaling:
          opt = loss_scale_optimizer.LossScaleOptimizer(
              opt, dynamic=False, initial_scale=loss_scale)
        model.compile(
            opt,
            loss=loss_fn,
            run_eagerly=testing_utils.should_run_eagerly())

    x = np.ones((2, 1))
    y = np.ones((2, 1))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
    model.fit(dataset)
    for layer in (layer1, layer2, layer3, layer4):
      if layer.losses:
        # Layer has weight regularizer
        self.assertEqual(backend.eval(layer.v), 1 - 2 * learning_rate)
      else:
        # Layer does not have weight regularizer
        self.assertEqual(backend.eval(layer.v), 1 - learning_rate)

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'pass_loss_scale_to_policy',
          'strategy_fn': create_mirrored_strategy,
          'pass_loss_scale_to_policy': True,
      }, {
          'testcase_name': 'get_config',
          'strategy_fn': create_mirrored_strategy,
          'get_config': True,
      }, {
          'testcase_name': 'get_config_v1_lso',
          'strategy_fn': create_mirrored_strategy,
          'get_config': True,
          'use_v1_loss_scale_optimizer': True,
      }, {
          'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
          'strategy_fn': create_mirrored_strategy,
          'get_config': True,
          'pass_loss_scale_to_policy': True,
      })
  def test_dynamic_loss_scaling(self,
                                strategy_fn,
                                pass_loss_scale_to_policy=False,
                                get_config=False,
                                use_v1_loss_scale_optimizer=False):
    strategy = strategy_fn()
    initial_loss_scale = 2.
    batch_size = 4
    expected_gradient = backend.variable([initial_loss_scale / batch_size],
                                         dtype=dtypes.float16)
    # If this variable is set to True, the model below will have NaN gradients
    have_nan_gradients = backend.variable(False, dtype=dtypes.bool)
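    # (With dynamic loss scaling, the scale doubles after dynamic_growth_steps
    # consecutive steps with finite gradients and halves whenever non-finite
    # gradients are seen; the assertions below rely on this behavior.)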
    with strategy.scope():
      opt = gradient_descent.SGD(1.)
      if pass_loss_scale_to_policy:
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=initial_loss_scale, increment_period=2)
        p = policy.PolicyV1('mixed_float16', loss_scale=loss_scale)
      elif use_v1_loss_scale_optimizer:
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=initial_loss_scale, increment_period=2)
        p = policy.Policy('mixed_float16')
        opt = loss_scale_optimizer.LossScaleOptimizerV1(
            opt, loss_scale)
      else:
        p = policy.Policy('mixed_float16')
        opt = loss_scale_optimizer.LossScaleOptimizer(
            opt, initial_scale=initial_loss_scale, dynamic_growth_steps=2)
      with policy.policy_scope(p):
        x = layers.Input(
            shape=(1,), batch_size=batch_size, dtype=dtypes.float16)
        layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
        y = layer(x)
        identity_with_nan_grads = (
            mp_test_util.create_identity_with_nan_gradients_fn(
                have_nan_gradients))
        y = core.Lambda(identity_with_nan_grads)(y)
        identity_with_grad_check_fn = (
            mp_test_util.create_identity_with_grad_check_fn(
                expected_dtype=dtypes.float16,
                expected_gradient=expected_gradient))
        y = core.Lambda(identity_with_grad_check_fn)(y)
        model = models.Model(inputs=x, outputs=y)
        if get_config:
          config = model.get_config()
          model = model.__class__.from_config(
              config,
              custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer})
          (layer,) = (layer for layer in model.layers
                      if isinstance(layer, mp_test_util.MultiplyLayer))

        def loss_fn(y_true, y_pred):
          del y_true
          return math_ops.reduce_mean(y_pred)

        model.compile(
            opt,
            loss=loss_fn,
            run_eagerly=testing_utils.should_run_eagerly())

    self.assertEqual(backend.eval(layer.v), 1)
    x = np.ones((batch_size, 1))
    y = np.ones((batch_size, 1))
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
    model.fit(dataset)
    # The variable starts at 1 and has a gradient of 1, so it will go down by 1
    # each step.
    self.assertEqual(backend.eval(layer.v), 0)

    model.fit(dataset)
    self.assertEqual(backend.eval(layer.v), -1)

    # There have been two steps without NaNs, so the loss scale will double
    backend.set_value(expected_gradient,
                      backend.get_value(expected_gradient * 2))
    model.fit(dataset)
    self.assertEqual(backend.eval(layer.v), -2)

    # Next test with NaN gradients.
    backend.set_value(have_nan_gradients, True)
    model.fit(dataset)
    # Variable should not be updated
    self.assertEqual(backend.eval(layer.v), -2)

    # Test with finite gradients again
    backend.set_value(have_nan_gradients, False)
    # The loss scale will be halved due to the NaNs, so the gradient will also
    # be halved
    backend.set_value(expected_gradient,
                      backend.get_value(expected_gradient / 2))
    model.fit(dataset)
    self.assertEqual(backend.eval(layer.v), -3)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_loss_scale_optimizer_overrides_policy_v1_loss_scale(self):
    with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
      opt = gradient_descent.SGD(1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
                                                    initial_scale=5.)
      x = layers.Input(shape=(1,))
      y = mp_test_util.MultiplyLayer()(x)
      model = models.Model(x, y)
      model.compile(opt, loss='mse')
      self.assertEqual(self.evaluate(model.optimizer.loss_scale), 5.)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_policy_v1_without_loss_scale(self):
    with policy.policy_scope(policy.PolicyV1('mixed_float16',
                                             loss_scale=None)):
      opt = gradient_descent.SGD(1.)
      x = layers.Input(shape=(1,))
      y = mp_test_util.MultiplyLayer()(x)
      model = models.Model(x, y)
      model.compile(opt, loss='mse')
      self.assertNotIsInstance(model.optimizer,
                               loss_scale_optimizer.LossScaleOptimizer)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_pass_invalid_optimizer_with_loss_scaling(self):
    with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
      x = layers.Input(shape=(1,))
      y = mp_test_util.MultiplyLayer()(x)
      model = models.Model(x, y)
      if context.executing_eagerly():
        error_msg = 'Use a `tf.keras` Optimizer instead'
      else:
        error_msg = 'optimizer" must be an instance of '
      with self.assertRaisesRegex(ValueError, error_msg):
        model.compile(optimizer_v1.SGD(1.), 'mse')

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_functional_model_loss_dtype(self):
    with policy.policy_scope('float16'):
      x = layers.Input(shape=(1,))
      y = mp_test_util.MultiplyLayer()(x)
      model = models.Model(x, y)
      model.add_loss(math_ops.cast(y, 'float32'))
      # The loss should not be cast to the policy's dtype.
      self.assertEqual(model.losses[0].dtype, 'float32')

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn,
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'base_h5',
          'strategy_fn': default_strategy_fn,
          'h5': True,
      }, {
          'testcase_name': 'distribute_h5',
          'strategy_fn': create_mirrored_strategy,
          'h5': True,
      })
  def test_save_weights_with_autocast_vars(self, strategy_fn, h5=False):
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        x = layers.Input(shape=(1,), batch_size=2)
        layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
        y = layer(x)
        model = models.Model(inputs=x, outputs=y)

    model.set_weights([np.array(100.)])
    x = np.ones((2, 1))
    self.assertAllClose(backend.get_value(model(x)), x * 100.)
    suffix = '.h5' if h5 else ''
    weights_file = os.path.join(self.get_temp_dir(), 'weights' + suffix)
    model.save_weights(weights_file)

    model.set_weights([np.array(200.)])
    self.assertAllClose(backend.get_value(model(x)), x * 200.)
    model.load_weights(weights_file)
    self.assertAllClose(backend.get_value(model(x)), x * 100.)
    self.assertEqual(model.get_weights(), [np.array(100.)])

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn,
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'different_var_name',
          'strategy_fn': default_strategy_fn,
          'var_name': 'w'
      }, {
          'testcase_name': 'different_var_name_distribute',
          'strategy_fn': create_mirrored_strategy,
          'var_name': 'w'
      })
  def test_save_slot_variables_with_autocast_vars(self,
                                                  strategy_fn,
                                                  var_name='v'):
    p = policy.Policy('mixed_float16')
    with strategy_fn().scope(), policy.policy_scope(p):
      x = layers.Input(shape=(2,), batch_size=2)
      # Having a var_name other than 'v' tests that a fixed bug (b/134713714)
      # does not reoccur. The bug was that a crash would occur when saving a
      # checkpoint where an AutoCastVariable with a slot variable would have a
      # different name than the layer attribute's name (layer.v in this case).
      layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16,
                                         var_name=var_name)
      y = layer(x)
      model = models.Model(inputs=x, outputs=y)
      opt = gradient_descent.SGD(1., 1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
                                                    initial_scale=1)
      model.compile(
          optimizer=opt,
          loss='mse',
          run_eagerly=testing_utils.should_run_eagerly())

    model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
    weights_file = os.path.join(self.get_temp_dir(), 'weights')
    model.save_weights(weights_file)
    saved_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))

    model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
    new_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
    self.assertNotEqual(new_slot, saved_slot)

    model.load_weights(weights_file)
    restored_slot = backend.get_value(opt.get_slot(layer.v, 'momentum'))
    self.assertEqual(restored_slot, saved_slot)

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(*TESTCASES)
  def test_save_weights_with_dynamic_loss_scaling(self, strategy_fn):
    strategy = strategy_fn()
    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
        not context.executing_eagerly()):
      # TODO(b/121381184): Enable running the test in this case.
      return

    # Create and run model.
    with strategy.scope():
      x = layers.Input(shape=(2,), batch_size=2, dtype=dtypes.float32)
      y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x)
      model = models.Model(inputs=x, outputs=y)

      opt = gradient_descent.SGD(1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(
          opt, initial_scale=1., dynamic_growth_steps=2.)
      model.compile(
          optimizer=opt,
          loss='mse',
          run_eagerly=testing_utils.should_run_eagerly())
    # Run for 3 steps (6 examples with a batch size of 2)
    model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2)
    self.assertEqual(backend.get_value(opt.loss_scale), 2)
    self.assertEqual(backend.get_value(opt.dynamic_counter), 1)

    # Save model weights.
    save_prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    model.save_weights(save_prefix)

    # Run model again for 1 step (2 examples with a batch size of 2)
    model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2)
    self.assertEqual(backend.get_value(opt.loss_scale), 4)
    self.assertEqual(backend.get_value(opt.dynamic_counter), 0)

    # Load model weights and ensure loss scale weights are restored.
    model.load_weights(save_prefix)
    self.assertEqual(backend.get_value(opt.loss_scale), 2)
    self.assertEqual(backend.get_value(opt.dynamic_counter), 1)

  @keras_parameterized.run_all_keras_modes
  def test_restore_old_loss_scale_checkpoint(self):
    # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
    # of LossScaleOptimizer changed, but old checkpoints can still be loaded
    opt = gradient_descent.SGD(0.1, momentum=0.1)
    opt = loss_scale_optimizer.LossScaleOptimizer(opt)
    model = sequential.Sequential([core.Dense(2,)])

    # The checkpoint and expected values were obtained from the program in
    # testdata/BUILD.
    ckpt_dir = os.path.join(
        flags.FLAGS['test_srcdir'].value,
        'org_tensorflow/tensorflow/python/keras',
        'mixed_precision/testdata/lso_ckpt_tf2.2')
    # ckpt_dir = test.test_src_dir_path(
    #     'python/keras/mixed_precision/testdata/lso_ckpt_tf2.2')
    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
    model(np.zeros((2, 2)))  # Create model weights
    opt._create_all_weights(model.weights)
    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
    expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    self.assertEqual(self.evaluate(opt.loss_scale), 32768)
    self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

    # Check restoring works even after the model is compiled and the weights
    # have been created.
    model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
    self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertNotAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    self.assertEqual(self.evaluate(opt.loss_scale), 32768)
    self.assertEqual(self.evaluate(opt.dynamic_counter), 1)

  def test_restore_old_saved_model(self):
    saved_model_dir = os.path.join(
        flags.FLAGS['test_srcdir'].value,
        'org_tensorflow/tensorflow/python/keras',
        'mixed_precision/testdata/lso_savedmodel_tf2.2')
    # saved_model_dir = test.test_src_dir_path(
    #     'python/keras/mixed_precision/testdata/'
    #     'lso_savedmodel_tf2.2')
    model = save.load_model(saved_model_dir)
    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
    self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
    self.assertEqual(type(model.optimizer),
                     loss_scale_optimizer.LossScaleOptimizer)

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
          'testcase_name': 'base',
          'strategy_fn': default_strategy_fn,
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'use_v1_lso',
          'strategy_fn': create_mirrored_strategy,
          'use_v1_loss_scale_optimizer': True
      }, {
          'testcase_name': 'base_h5',
          'strategy_fn': default_strategy_fn,
          'h5': True,
      }, {
          'testcase_name': 'distribute_h5',
          'strategy_fn': create_mirrored_strategy,
          'h5': True,
      })
  def test_save_model_with_dynamic_loss_scaling(
      self, strategy_fn, h5=False, use_v1_loss_scale_optimizer=False):
    # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
    # as well.
    strategy = strategy_fn()
    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
        not context.executing_eagerly()):
      # TODO(b/121381184): Enable running the test in this case.
      return

    # Create and run model.
    with strategy.scope():
      x = layers.Input(shape=(2,), batch_size=2, dtype=dtypes.float32)
      y = mp_test_util.MultiplyLayer()(x)
      model = models.Model(inputs=x, outputs=y)

      opt = gradient_descent.SGD(1.)
      if use_v1_loss_scale_optimizer:
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=1., increment_period=2.)
        opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
      else:
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1.,
                                                      dynamic_growth_steps=2.)
      model.compile(
          optimizer=opt,
          loss='mse',
          run_eagerly=testing_utils.should_run_eagerly())
    # Run for 3 steps (6 examples with a batch size of 2)
    model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2)
    self.assertEqual(backend.get_value(opt.loss_scale), 2)
    self.assertEqual(backend.get_value(opt.dynamic_counter), 1)
    (weight,) = model.trainable_weights
    orig_weight = backend.get_value(weight)

    # Save model weights.
    save_path = os.path.join(self.get_temp_dir(), 'model')
    model.save(save_path, save_format='h5' if h5 else 'tf')

    # Run model again for 1 step (2 examples with a batch size of 2)
    model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
    new_weight = backend.get_value(weight)
    self.assertNotEqual(new_weight, orig_weight)
    self.assertEqual(backend.get_value(opt.loss_scale), 4)
    self.assertEqual(backend.get_value(opt.dynamic_counter), 0)

    # Load model weights and ensure loss scale weights are restored.
    model = save.load_model(
        save_path, custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer})
    (weight,) = model.trainable_weights
    loaded_weight = backend.get_value(weight)
    self.assertEqual(loaded_weight, orig_weight)
    # Currently the loss scale isn't always saved when the model is saved with
    # Model.save(). So we assert the loss scale either has the value it had
    # when it was saved, or the value it was initialized with.
    # TODO(reedwm): Always save/restore the loss scale with Model.save().
    self.assertIn(backend.get_value(model.optimizer.loss_scale), (1, 2))
    self.assertIn(backend.get_value(model.optimizer.dynamic_counter), (0, 1))

    # Test optimizer attributes and type
    self.assertEqual(model.optimizer.initial_scale, 1.)
    self.assertEqual(model.optimizer.dynamic_growth_steps, 2.)
    self.assertEqual(type(model.optimizer),
                     loss_scale_optimizer.LossScaleOptimizer)


if __name__ == '__main__':
  base_layer_utils.enable_v2_dtype_behavior()
  test.main()