# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests keras.layers.Layer works properly with mixed precision."""
16
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import get_layer_policy
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.mixed_precision import test_util as mp_test_util
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils


class MultiplyLayerWithFunction(mp_test_util.MultiplyLayer):
  """Same as MultiplyLayer, but _multiply is decorated with a tf.function."""

  @def_function.function
  def _multiply(self, x, y):
    return super(MultiplyLayerWithFunction, self)._multiply(x, y)


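# MultiplyLayerWithFunction is exercised by test_layer_calling_tf_function
# below: the layer's float32 variable must still be cast to float16 when it
# is read inside the tf.function-decorated _multiply.
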
# If called outside any strategy.scope() calls, this will return the default
# strategy.
default_strategy_fn = distribution_strategy_context.get_strategy


def create_mirrored_strategy():
  """Create a MirroredStrategy, using a GPU if it is available."""
  if tf_config.list_logical_devices('GPU'):
    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
  else:
    return mirrored_strategy.MirroredStrategy(['cpu:0'])

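# Note: when a GPU is available, create_mirrored_strategy() above runs with
# two replicas, which is why test_gradient below divides its per-replica loss
# by strategy.num_replicas_in_sync.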

def create_central_storage_strategy():
  """Create a CentralStorageStrategy, using a GPU if it is available."""
  compute_devices = ['cpu:0', 'gpu:0'] if (
      tf_config.list_logical_devices('GPU')) else ['cpu:0']
  return central_storage_strategy.CentralStorageStrategy(
      compute_devices, parameter_device='cpu:0')

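# create_central_storage_strategy is only used by test_unsupported_strategy
# below, which checks that creating a layer with a mixed policy under
# CentralStorageStrategy raises a ValueError.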

TESTCASES = ({
    'testcase_name': 'base',
    'strategy_fn': default_strategy_fn
}, {
    'testcase_name': 'distribute',
    'strategy_fn': create_mirrored_strategy
})

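# Each test below runs once per TESTCASES entry: once under the default
# strategy and once under a MirroredStrategy. A minimal sketch of the
# behavior the mixed-precision tests assert (assuming the v2 dtype behavior
# enabled in main() below):
#
#   with policy.policy_scope('mixed_float16'):
#     layer = mp_test_util.MultiplyLayer()
#   # Variables are created in float32; computations run in float16.
#   assert layer.variable_dtype == 'float32'
#   assert layer.compute_dtype == 'float16'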

@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class LayerTest(keras_parameterized.TestCase):
  """Test mixed precision with Keras layers."""

  @parameterized.named_parameters(*TESTCASES)
  def test_mixed_policies(self, strategy_fn):
    strategy = strategy_fn()
    for dtype in 'float16', 'bfloat16':
      x = constant_op.constant([1.])
      policy_name = 'mixed_' + dtype
      with strategy.scope(), policy.policy_scope(policy_name):
        layer = mp_test_util.MultiplyLayer(assert_type=dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtype)
        self.assertEqual(layer.dtype_policy.name, policy_name)
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(layer.variable_dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  def test_layer_with_int_variable(self):
    class LayerWithIntVar(base_layer.Layer):

      def build(self, _):
        self.v = self.add_weight('v', dtype='int32', trainable=False)

      def call(self, inputs):
        # Only float variables should be autocasted. This will fail if
        # self.v is autocasted to float32.
        return math_ops.cast(inputs, 'int32') + self.v

    x = constant_op.constant([1.])
    layer = LayerWithIntVar(dtype='mixed_float16')
    self.assertEqual(layer(x).dtype, 'int32')
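    # Under a mixed policy, add_weight only wraps floating-point variables
    # for autocasting, so the int32 weight above keeps its dtype in call().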

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_with_non_autocast_variable(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = mp_test_util.MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_calling_tf_function(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = MultiplyLayerWithFunction(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_regularizer_runs_in_var_dtype(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        # Test on MultiplyLayer
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

        # Test on MultiplyLayerWithoutAutoCast
        layer = mp_test_util.MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_passing_policy_to_layer(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      # Passing a Policy to 'dtype' sets the policy for that layer.
      layer = mp_test_util.MultiplyLayer(
          assert_type=dtypes.float16, dtype=policy.Policy('mixed_float16'))
      # layer.dtype refers to the variable dtype.
      self.assertEqual(layer.dtype, dtypes.float32)
      layer(x)
      self.assertEqual(layer.v.dtype, dtypes.float32)
      with policy.policy_scope('mixed_float16'):
        # Passing a Policy to 'dtype' overrides the global policy.
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float64, dtype=policy.Policy('float64'))
        self.assertEqual(layer.dtype_policy.name, 'float64')
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtypes.float64)
        self.assertEqual(layer.dtype, dtypes.float64)
        self.assertEqual(layer.variable_dtype, dtypes.float64)
        self.assertEqual(layer(x).dtype, dtypes.float64)
        self.assertEqual(layer.v.dtype, dtypes.float64)

  @parameterized.named_parameters(*TESTCASES)
  def test_gradient(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope() as strategy:
      with policy.policy_scope('mixed_float16'):
        layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
        # The learning rate is small enough that, if it were applied to a
        # float16 variable, the variable would not change. So this tests that
        # the learning rate is applied to the float32 variable, not to a
        # float16 value.
        opt = gradient_descent.SGD(2**-14)

        def run_fn():
          with backprop.GradientTape() as tape:
            y = layer(x)
            # Divide by num_replicas_in_sync, as the effective total loss is
            # the sum of the per-replica losses.
            y /= strategy.num_replicas_in_sync

          grad = tape.gradient(y, layer.v)
          return opt.apply_gradients([(grad, layer.v)])

        op = strategy.experimental_run(run_fn)
        if not context.executing_eagerly():
          self.evaluate(variables.global_variables_initializer())
          self.evaluate(op)
        # The gradient with respect to the variable is 1. Since the variable
        # is initialized with 1 and the learning rate is 2**-14, the new
        # variable value should be init_val - gradient * learning_rate, which
        # is 1 - 1 * 2**-14 = 1 - 2**-14.
        self.assertEqual(self.evaluate(layer.v), 1 - 2**-14)

  def _test_checkpointing_layer_weights(self, strategy_fn,
                                        mixed_prec_when_saving,
                                        mixed_prec_when_loading):
    # In this test, we potentially save with mixed precision enabled and load
    # with mixed precision disabled, or vice versa. This is possible because
    # variables are float32 regardless of whether mixed precision is enabled.
    save_policy = 'mixed_float16' if mixed_prec_when_saving else 'float32'
    load_policy = 'mixed_float16' if mixed_prec_when_loading else 'float32'
    save_input_dtype = 'float16' if mixed_prec_when_saving else 'float32'
    load_input_dtype = 'float16' if mixed_prec_when_loading else 'float32'

    # Create a layer and save a checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(save_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=save_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(100.)])
    self.assertEqual(self.evaluate(layer(x)), 100.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    save_path = checkpoint.save(prefix)

    # Create a new layer and restore the checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(load_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=load_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(200.)])
    self.assertEqual(self.evaluate(layer(x)), 200.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.assertEqual(layer.get_weights(), [100.])
    self.assertEqual(self.evaluate(layer(x)), 100.)

  @parameterized.named_parameters(*TESTCASES)
  def test_checkpointing_layer_weights(self, strategy_fn):
    with self.test_session():
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=True)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=False)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=False,
          mixed_prec_when_loading=True)

  @parameterized.named_parameters(*TESTCASES)
  def test_config(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      for layer, dtype in (
          (mp_test_util.MultiplyLayer(), 'float32'),
          (mp_test_util.MultiplyLayer(dtype='float64'), 'float64'),
          (mp_test_util.MultiplyLayer(dtype=policy.Policy('float64')),
           'float64')):
        config = layer.get_config()
        self.assertEqual(config['dtype'], dtype)
        self.assertIsInstance(config['dtype'], str)
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, dtype)
        self.assertEqual(layer(x).dtype, dtype)
        self.assertEqual(layer.v.dtype, dtype)

      layer = mp_test_util.MultiplyLayer(dtype='mixed_float16')
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})
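      # Note the asymmetry above: non-mixed dtypes serialize to plain strings
      # ('float32', 'float64'), while 'mixed_float16' serializes to a Policy
      # dict, which survives the get_config/from_config round trip unchanged.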

      layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
      config = layer.get_config()
      self.assertIsNone(config['dtype'])
      layer = mp_test_util.MultiplyLayer.from_config(config)
      # If a layer is serialized with the "_infer" policy, when deserialized
      # into TF 2 it will have the global policy instead of "_infer". This is
      # because "_infer" is serialized into None, and passing dtype=None in
      # TensorFlow 2 indicates to use the global policy.
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float32')
      self.assertEqual(layer.v.dtype, 'float32')

  @parameterized.named_parameters(*TESTCASES)
  def test_config_policy_v1(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16',
                                                               loss_scale=None))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'mixed_float16',
                                   'loss_scale': None}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      # Restoring a PolicyV1 silently converts it to a Policy and drops the
      # loss scale.
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      # The loss_scale is silently dropped.
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'float64',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float64')
      self.assertEqual(layer(x).dtype, 'float64')
      self.assertEqual(layer.v.dtype, 'float64')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float64')

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': '_infer',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, None)
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float16')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float16')
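      # In each case above, from_config converts the PolicyV1 into a Policy
      # and drops the loss scale, so re-serializing yields either a Policy
      # dict or, for non-mixed policies, a plain dtype string.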

  def test_delete_variable(self):
    layer = base_layer.Layer(dtype='mixed_float16')
    layer.x = layer.add_weight('x')
    self.assertEqual(layer.trainable_weights, [layer.x])
    del layer.x
    self.assertEqual(layer.trainable_weights, [])

  def test_build_and_call_layer_in_function(self):
    layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))

    @def_function.function
    def f():
      return layer(1.)

    y = f()
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(layer.v.dtype, 'float32')
    self.assertEqual(self.evaluate(y), 1.)

  def test_unsupported_strategy(self):
    strategy = create_central_storage_strategy()
    with strategy.scope(), self.assertRaisesRegex(
        ValueError, 'Mixed precision is not supported with the '
        'tf.distribute.Strategy: CentralStorageStrategy. Either '
        'stop using mixed precision by removing the use of the '
        '"mixed_float16" policy or use a different Strategy, e.g. '
        'a MirroredStrategy.'):
      mp_test_util.MultiplyLayer(dtype='mixed_float16')
    # Non-mixed policies are fine.
    mp_test_util.MultiplyLayer(dtype=policy.Policy('float64'))

  def test_input_spec_dtype(self):
    # Test that the InputSpec's dtype is compared against the inputs before
    # the layer casts them, not after.
    layer = mp_test_util.MultiplyLayer(dtype='float64')
    layer.input_spec = input_spec.InputSpec(dtype='float16')

    # Test passing Eager tensors.
    x = array_ops.ones((2, 2), dtype='float16')
    layer(x)
    x = array_ops.ones((2, 2), dtype='float64')
    with self.assertRaisesRegex(
        ValueError, 'expected dtype=float16, found dtype=.*float64'):
      layer(x)

    # Test passing symbolic tensors.
    x = layers.Input((2,), dtype='float16')
    y = layer(x)
    model = models.Model(x, y)
    model(array_ops.ones((2, 2)))

    x = layers.Input((2,), dtype='float64')
    with self.assertRaisesRegex(
        ValueError, 'expected dtype=float16, found dtype=.*float64'):
      # In TF 2, the error is only raised when the model is run.
      y = layer(x)
      model = models.Model(x, y)
      model(array_ops.ones((2, 2)))


if __name__ == '__main__':
  base_layer_utils.enable_v2_dtype_behavior()
  test.main()