# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that keras.layers.Layer works properly with mixed precision."""

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision import get_layer_policy
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.mixed_precision import test_util as mp_test_util
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import util as trackable_utils


class MultiplyLayerWithFunction(mp_test_util.MultiplyLayer):
  """Same as MultiplyLayer, but _multiply is decorated with a tf.function."""

  @def_function.function
  def _multiply(self, x, y):
    return super(MultiplyLayerWithFunction, self)._multiply(x, y)


# If called outside any strategy.scope() calls, this will return the default
# strategy.
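# In other words, default_strategy_fn() behaves like
# tf.distribute.get_strategy(): inside a strategy's scope it returns that
# strategy, and outside any scope it returns the single default strategy.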
default_strategy_fn = distribution_strategy_context.get_strategy


def create_mirrored_strategy():
  """Create a MirroredStrategy, using a GPU if it is available."""
  if tf_config.list_logical_devices('GPU'):
    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
  else:
    return mirrored_strategy.MirroredStrategy(['cpu:0'])


def create_central_storage_strategy():
  """Create a CentralStorageStrategy, using a GPU if it is available."""
  compute_devices = ['cpu:0', 'gpu:0'] if (
      tf_config.list_logical_devices('GPU')) else ['cpu:0']
  return central_storage_strategy.CentralStorageStrategy(
      compute_devices, parameter_device='cpu:0')


TESTCASES = ({
    'testcase_name': 'base',
    'strategy_fn': default_strategy_fn
}, {
    'testcase_name': 'distribute',
    'strategy_fn': create_mirrored_strategy
})


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class LayerTest(keras_parameterized.TestCase):
  """Test mixed precision with Keras layers."""

  @parameterized.named_parameters(*TESTCASES)
  def test_mixed_policies_(self, strategy_fn):
    strategy = strategy_fn()
    for dtype in 'float16', 'bfloat16':
      x = constant_op.constant([1.])
      policy_name = 'mixed_' + dtype
      with strategy.scope(), policy.policy_scope(policy_name):
        layer = mp_test_util.MultiplyLayer(assert_type=dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtype)
        self.assertEqual(layer.dtype_policy.name, policy_name)
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtype)
        self.assertEqual(layer.dtype, dtypes.float32)
        self.assertEqual(layer.variable_dtype, dtypes.float32)
        self.assertEqual(get_layer_policy.get_layer_policy(layer).name,
                         policy_name)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  def test_layer_with_int_variable(self):
    class LayerWithIntVar(base_layer.Layer):

      def build(self, _):
        self.v = self.add_weight('v', dtype='int32', trainable=False)

      def call(self, inputs):
        # Only float variables should be autocasted. This will fail if self.v
        # is autocasted to float32.
        return math_ops.cast(inputs, 'int32') + self.v

    x = constant_op.constant([1.])
    layer = LayerWithIntVar(dtype='mixed_float16')
    self.assertEqual(layer(x).dtype, 'int32')

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_with_non_autocast_variable(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = mp_test_util.MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)
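        # MultiplyLayerWithoutAutoCast (see mp_test_util) creates its variable
        # with experimental_autocast=False, so the variable stays float32
        # inside call() and the layer must cast it to float16 itself before
        # multiplying; the assertions above verify that this still holds under
        # the mixed_float16 policy.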

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_calling_tf_function(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        layer = MultiplyLayerWithFunction(assert_type=dtypes.float16)
        y = layer(x)
        self.assertEqual(layer.v.dtype, dtypes.float32)
        self.assertEqual(y.dtype, dtypes.float16)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(y), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_layer_regularizer_runs_in_var_dtype(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope('mixed_float16'):
        # Test on MultiplyLayer.
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

        # Test on MultiplyLayerWithoutAutoCast.
        layer = mp_test_util.MultiplyLayerWithoutAutoCast(
            assert_type=dtypes.float16,
            regularizer=mp_test_util.IdentityRegularizer())
        layer(x)
        (regularizer_loss,) = layer.losses
        self.assertEqual(regularizer_loss.dtype, dtypes.float32)
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(regularizer_loss), 1.)

  @parameterized.named_parameters(*TESTCASES)
  def test_passing_policy_to_layer(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      # Passing a Policy to 'dtype' sets the policy for that layer.
      layer = mp_test_util.MultiplyLayer(
          assert_type=dtypes.float16, dtype=policy.Policy('mixed_float16'))
      # layer.dtype refers to the variable dtype.
      self.assertEqual(layer.dtype, dtypes.float32)
      layer(x)
      self.assertEqual(layer.v.dtype, dtypes.float32)
      with policy.policy_scope('mixed_float16'):
        # Passing a Policy to dtype overrides the global Policy.
        layer = mp_test_util.MultiplyLayer(
            assert_type=dtypes.float64, dtype=policy.Policy('float64'))
        self.assertEqual(layer.dtype_policy.name, 'float64')
        self.assertIsInstance(layer.dtype_policy, policy.Policy)
        self.assertEqual(layer.compute_dtype, dtypes.float64)
        self.assertEqual(layer.dtype, dtypes.float64)
        self.assertEqual(layer.variable_dtype, dtypes.float64)
        self.assertEqual(layer(x).dtype, dtypes.float64)
        self.assertEqual(layer.v.dtype, dtypes.float64)

  @parameterized.named_parameters(*TESTCASES)
  def test_gradient(self, strategy_fn):
    x = constant_op.constant([1.])
    with strategy_fn().scope() as strategy:
      with policy.policy_scope('mixed_float16'):
        layer = mp_test_util.MultiplyLayer(assert_type=dtypes.float16)
        # The learning rate is small enough that, if it were applied to a
        # float16 variable, the variable would not change. So this tests that
        # the learning rate is applied to the float32 variable, not to a
        # float16 value.
        opt = gradient_descent.SGD(2**-14)

        def run_fn():
          with backprop.GradientTape() as tape:
            y = layer(x)
            # Divide by num_replicas_in_sync, as the effective total loss is
            # the sum of each replica's loss.
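            # For example, with two replicas each computing y = v * 1., the
            # summed loss would be 2. * v, doubling the gradient; dividing by
            # num_replicas_in_sync keeps the gradient with respect to v at 1.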
            y /= strategy.num_replicas_in_sync

          grad = tape.gradient(y, layer.v)
          return opt.apply_gradients([(grad, layer.v)])

        op = strategy.experimental_run(run_fn)
        if not context.executing_eagerly():
          self.evaluate(variables.global_variables_initializer())
          self.evaluate(op)
        # The gradient with respect to the variable is 1. Since the variable
        # is initialized with 1 and the learning rate is 2**-14, the new
        # variable value should be init_val - gradient * learning_rate, which
        # is 1 - 1 * 2**-14.
        self.assertEqual(self.evaluate(layer.v), 1 - 2**-14)

  def _test_checkpointing_layer_weights(self, strategy_fn,
                                        mixed_prec_when_saving,
                                        mixed_prec_when_loading):
    # In this test, we potentially save with mixed precision enabled and load
    # with mixed precision disabled, or vice versa. This is possible because
    # variables are float32 regardless of whether mixed precision is enabled.
    save_policy = 'mixed_float16' if mixed_prec_when_saving else 'float32'
    load_policy = 'mixed_float16' if mixed_prec_when_loading else 'float32'
    save_input_dtype = 'float16' if mixed_prec_when_saving else 'float32'
    load_input_dtype = 'float16' if mixed_prec_when_loading else 'float32'

    # Create a layer and save a checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(save_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=save_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(100.)])
    self.assertEqual(self.evaluate(layer(x)), 100.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    prefix = os.path.join(self.get_temp_dir(), 'ckpt')
    save_path = checkpoint.save(prefix)

    # Create a new layer and restore the checkpoint.
    x = constant_op.constant([1.])
    with strategy_fn().scope():
      with policy.policy_scope(load_policy):
        layer = mp_test_util.MultiplyLayer(assert_type=load_input_dtype)
        layer(x)  # Build layer
    layer.set_weights([np.array(200.)])
    self.assertEqual(self.evaluate(layer(x)), 200.)
    checkpoint = trackable_utils.Checkpoint(layer=layer)
    checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    self.assertEqual(layer.get_weights(), [100.])
    self.assertEqual(self.evaluate(layer(x)), 100.)
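    # If the restore had not overwritten the weight, the layer would still
    # output 200., so the two asserts above confirm that the float32
    # checkpoint value round-trips regardless of the dtype policy in effect.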

  @parameterized.named_parameters(*TESTCASES)
  def test_checkpointing_layer_weights(self, strategy_fn):
    with self.test_session():
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=True)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=True,
          mixed_prec_when_loading=False)
      self._test_checkpointing_layer_weights(
          strategy_fn, mixed_prec_when_saving=False,
          mixed_prec_when_loading=True)

  @parameterized.named_parameters(*TESTCASES)
  def test_config(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():
      for layer, dtype in (
          (mp_test_util.MultiplyLayer(), 'float32'),
          (mp_test_util.MultiplyLayer(dtype='float64'), 'float64'),
          (mp_test_util.MultiplyLayer(dtype=policy.Policy('float64')),
           'float64')):
        config = layer.get_config()
        self.assertEqual(config['dtype'], dtype)
        self.assertIsInstance(config['dtype'], str)
        layer = mp_test_util.MultiplyLayer.from_config(config)
        self.assertEqual(layer.dtype, dtype)
        self.assertEqual(layer(x).dtype, dtype)
        self.assertEqual(layer.v.dtype, dtype)

      layer = mp_test_util.MultiplyLayer(dtype='mixed_float16')
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
      config = layer.get_config()
      self.assertIsNone(config['dtype'])
      layer = mp_test_util.MultiplyLayer.from_config(config)
      # If a layer is serialized with the "_infer" policy, when deserialized
      # into TF 2 it will have the global policy instead of "_infer". This is
      # because "_infer" is serialized into None, and passing dtype=None in
      # TensorFlow 2 indicates to use the global policy.
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float32')
      self.assertEqual(layer.v.dtype, 'float32')

  @parameterized.named_parameters(*TESTCASES)
  def test_config_policy_v1(self, strategy_fn):
    x = constant_op.constant([1.], dtype=dtypes.float16)
    with strategy_fn().scope():

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16',
                                                               loss_scale=None))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'mixed_float16',
                                   'loss_scale': None}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float32')
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float32')
      # Restoring a PolicyV1 silently converts it to a Policy and drops the
      # loss scale.
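      # (In TF2, loss scaling is handled by
      # tf.keras.mixed_precision.LossScaleOptimizer rather than by the dtype
      # policy, which is why Policy no longer carries a loss scale.)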
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      # The loss_scale is silently dropped.
      self.assertEqual(config['dtype'],
                       {'class_name': 'Policy',
                        'config': {'name': 'mixed_float16'}})

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': 'float64',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, 'float64')
      self.assertEqual(layer(x).dtype, 'float64')
      self.assertEqual(layer.v.dtype, 'float64')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float64')

      layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer',
                                                               loss_scale=2.))
      config = layer.get_config()
      self.assertEqual(config['dtype'],
                       {'class_name': 'PolicyV1',
                        'config': {'name': '_infer',
                                   'loss_scale': {
                                       'class_name': 'FixedLossScale',
                                       'config': {'loss_scale_value': 2.0}}}})
      layer = mp_test_util.MultiplyLayer.from_config(config)
      self.assertEqual(layer.dtype, None)
      self.assertEqual(layer(x).dtype, 'float16')
      self.assertEqual(layer.v.dtype, 'float16')
      self.assertEqual(type(layer.dtype_policy), policy.Policy)
      config = layer.get_config()
      self.assertEqual(config['dtype'], 'float16')

  def test_delete_variable(self):
    layer = base_layer.Layer(dtype='mixed_float16')
    layer.x = layer.add_weight('x')
    self.assertEqual(layer.trainable_weights, [layer.x])
    del layer.x
    self.assertEqual(layer.trainable_weights, [])

  def test_build_and_call_layer_in_function(self):
    layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16'))
    @def_function.function
    def f():
      return layer(1.)
    y = f()
    self.evaluate(variables.global_variables_initializer())
    self.assertEqual(y.dtype, 'float16')
    self.assertEqual(layer.v.dtype, 'float32')
    self.assertEqual(self.evaluate(y), 1.)

  def test_unsupported_strategy(self):
    strategy = create_central_storage_strategy()
    with strategy.scope(), self.assertRaisesRegex(
        ValueError, 'Mixed precision is not supported with the '
        'tf.distribute.Strategy: CentralStorageStrategy. Either '
        'stop using mixed precision by removing the use of the '
        '"mixed_float16" policy or use a different Strategy, e.g. '
        'a MirroredStrategy.'):
      mp_test_util.MultiplyLayer(dtype='mixed_float16')
    # Non-mixed policies are fine.
    mp_test_util.MultiplyLayer(dtype=policy.Policy('float64'))

  def test_input_spec_dtype(self):
    # Test that the InputSpec's dtype is compared against the inputs before
    # the layer casts them, not after.
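    # A float64 layer with a float16 InputSpec makes this observable: if the
    # check ran after the cast, a float16 input would already have been cast
    # to float64 and would wrongly fail the spec.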
    layer = mp_test_util.MultiplyLayer(dtype='float64')
    layer.input_spec = input_spec.InputSpec(dtype='float16')

    # Test passing Eager tensors.
    x = array_ops.ones((2, 2), dtype='float16')
    layer(x)
    x = array_ops.ones((2, 2), dtype='float64')
    with self.assertRaisesRegex(
        ValueError, 'expected dtype=float16, found dtype=.*float64'):
      layer(x)

    # Test passing symbolic tensors.
    x = layers.Input((2,), dtype='float16')
    y = layer(x)
    model = models.Model(x, y)
    model(array_ops.ones((2, 2)))

    x = layers.Input((2,), dtype='float64')
    with self.assertRaisesRegex(
        ValueError, 'expected dtype=float16, found dtype=.*float64'):
      # In TF2, the error is only raised when the model is run.
      y = layer(x)
      model = models.Model(x, y)
      model(array_ops.ones((2, 2)))


if __name__ == '__main__':
  base_layer_utils.enable_v2_dtype_behavior()
  test.main()