1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains testing utilities related to mixed precision.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.framework import dtypes 21from tensorflow.python.framework import ops 22from tensorflow.python.keras import regularizers 23from tensorflow.python.keras.engine import base_layer 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import check_ops 26from tensorflow.python.ops import control_flow_ops 27from tensorflow.python.ops import custom_gradient 28from tensorflow.python.ops import math_ops 29from tensorflow.python.util import nest 30 31 32def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None): 33 """Returns a function that asserts it's gradient has a certain value. 34 35 This serves as a hook to assert intermediate gradients have a certain value. 36 This returns an identity function. The identity's gradient function is also 37 the identity function, except it asserts that the gradient equals 38 `expected_gradient` and has dtype `expected_dtype`. 39 40 Args: 41 expected_gradient: The gradient function asserts that the gradient is this 42 value. 43 expected_dtype: The gradient function asserts the gradient has this dtype. 44 45 Returns: 46 An identity function whose gradient function asserts the gradient has a 47 certain value. 48 """ 49 @custom_gradient.custom_gradient 50 def _identity_with_grad_check(x): 51 """Function that asserts it's gradient has a certain value.""" 52 x = array_ops.identity(x) 53 def grad(dx): 54 """Gradient function that asserts the gradient has a certain value.""" 55 if expected_dtype: 56 assert dx.dtype == expected_dtype, ( 57 'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype)) 58 expected_tensor = ops.convert_to_tensor_v2_with_dispatch( 59 expected_gradient, dtype=dx.dtype, name='expected_gradient') 60 # Control dependency is to ensure input is available. It's possible the 61 # dataset will throw a StopIteration to indicate there is no more data, in 62 # which case we don't want to run the assertion. 63 with ops.control_dependencies([x]): 64 assert_op = check_ops.assert_equal(dx, expected_tensor) 65 with ops.control_dependencies([assert_op]): 66 dx = array_ops.identity(dx) 67 return dx 68 return x, grad 69 # Keras sometimes has trouble serializing Lambda layers with a decorated 70 # function. So we define and return a non-decorated function. 71 def identity_with_grad_check(x): 72 return _identity_with_grad_check(x) 73 return identity_with_grad_check 74 75 76def create_identity_with_nan_gradients_fn(have_nan_gradients): 77 """Returns a function that optionally has NaN gradients. 78 79 This serves as a hook to introduce NaN gradients to a model. This returns an 80 identity function. The identity's gradient function will check if the boolean 81 tensor `have_nan_gradients` is True. If so, the gradient will be NaN. 82 Otherwise, the gradient will also be the identity. 83 84 Args: 85 have_nan_gradients: A scalar boolean tensor. If True, gradients will be NaN. 86 Otherwise, the gradient function is the identity function. 87 88 Returns: 89 An identity function whose gradient function will return NaNs, if 90 `have_nan_gradients` is True. 91 """ 92 @custom_gradient.custom_gradient 93 def _identity_with_nan_gradients(x): 94 """Function whose gradient is NaN iff `have_nan_gradients` is True.""" 95 x = array_ops.identity(x) 96 def grad(dx): 97 return control_flow_ops.cond( 98 have_nan_gradients, 99 lambda: dx * float('NaN'), 100 lambda: dx 101 ) 102 return x, grad 103 # Keras sometimes has trouble serializing Lambda layers with a decorated 104 # function. So we define and return a non-decorated function. 105 def identity_with_nan_gradients(x): 106 return _identity_with_nan_gradients(x) 107 return identity_with_nan_gradients 108 109 110class AssertTypeLayer(base_layer.Layer): 111 """A layer which asserts it's inputs are a certain type.""" 112 113 def __init__(self, assert_type=None, **kwargs): 114 self._assert_type = (dtypes.as_dtype(assert_type).name if assert_type 115 else None) 116 super(AssertTypeLayer, self).__init__(**kwargs) 117 118 def assert_input_types(self, inputs): 119 """Asserts `inputs` are of the correct type. Should be called in call().""" 120 if self._assert_type: 121 inputs_flattened = nest.flatten(inputs) 122 for inp in inputs_flattened: 123 assert inp.dtype.base_dtype == self._assert_type, ( 124 'Input tensor has type %s which does not match assert type %s' % 125 (inp.dtype.name, self._assert_type)) 126 127 128class MultiplyLayer(AssertTypeLayer): 129 """A layer which multiplies its input by a scalar variable.""" 130 131 def __init__(self, 132 regularizer=None, 133 activity_regularizer=None, 134 use_operator=False, 135 var_name='v', 136 **kwargs): 137 """Initializes the MultiplyLayer. 138 139 Args: 140 regularizer: The weight regularizer on the scalar variable. 141 activity_regularizer: The activity regularizer. 142 use_operator: If True, add using the * operator. If False, add using 143 tf.multiply. 144 var_name: The name of the variable. It can be useful to pass a name other 145 than 'v', to test having the attribute name (self.v) being different 146 from the variable name. 147 **kwargs: Passed to AssertTypeLayer constructor. 148 """ 149 self._regularizer = regularizer 150 if isinstance(regularizer, dict): 151 self._regularizer = regularizers.deserialize(regularizer, 152 custom_objects=globals()) 153 self._activity_regularizer = activity_regularizer 154 if isinstance(activity_regularizer, dict): 155 self._activity_regularizer = regularizers.deserialize( 156 activity_regularizer, custom_objects=globals()) 157 158 self._use_operator = use_operator 159 self._var_name = var_name 160 super(MultiplyLayer, self).__init__( 161 activity_regularizer=self._activity_regularizer, **kwargs) 162 163 def build(self, _): 164 self.v = self.add_weight( 165 self._var_name, (), initializer='ones', regularizer=self._regularizer) 166 self.built = True 167 168 def call(self, inputs): 169 self.assert_input_types(inputs) 170 return self._multiply(inputs, self.v) 171 172 def _multiply(self, x, y): 173 if self._use_operator: 174 return x * y 175 else: 176 return math_ops.multiply(x, y) 177 178 def get_config(self): 179 config = super(MultiplyLayer, self).get_config() 180 config['regularizer'] = regularizers.serialize(self._regularizer) 181 config['activity_regularizer'] = regularizers.serialize( 182 self._activity_regularizer) 183 config['use_operator'] = self._use_operator 184 config['var_name'] = self._var_name 185 config['assert_type'] = self._assert_type 186 return config 187 188 189class IdentityRegularizer(regularizers.Regularizer): 190 191 def __call__(self, x): 192 assert x.dtype == dtypes.float32 193 return array_ops.identity(x) 194 195 def get_config(self): 196 return {} 197 198 199class ReduceSumRegularizer(regularizers.Regularizer): 200 201 def __call__(self, x): 202 return math_ops.reduce_sum(x) 203 204 def get_config(self): 205 return {} 206