# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Contains testing utilities related to mixed precision."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.keras import regularizers
23from tensorflow.python.keras.engine import base_layer
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import check_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import custom_gradient
28from tensorflow.python.ops import math_ops
29from tensorflow.python.util import nest
30
31
def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
  """Returns a function that asserts its gradient has a certain value.

  This serves as a hook to assert intermediate gradients have a certain value.
  This returns an identity function. The identity's gradient function is also
  the identity function, except it asserts that the gradient equals
  `expected_gradient` and has dtype `expected_dtype`.

  Args:
    expected_gradient: The gradient function asserts that the gradient is this
      value.
    expected_dtype: The gradient function asserts the gradient has this dtype.

  Returns:
    An identity function whose gradient function asserts the gradient has a
    certain value.
  """
  @custom_gradient.custom_gradient
  def _identity_with_grad_check(x):
    """Function that asserts its gradient has a certain value."""
    x = array_ops.identity(x)
    def grad(dx):
      """Gradient function that asserts the gradient has a certain value."""
      if expected_dtype:
        assert dx.dtype == expected_dtype, (
            'dx.dtype should be %s but is: %s' % (expected_dtype, dx.dtype))
      expected_tensor = ops.convert_to_tensor_v2_with_dispatch(
          expected_gradient, dtype=dx.dtype, name='expected_gradient')
      # The control dependency ensures the input is available. It's possible
      # the dataset will throw a StopIteration to indicate there is no more
      # data, in which case we don't want to run the assertion.
      with ops.control_dependencies([x]):
        assert_op = check_ops.assert_equal(dx, expected_tensor)
      with ops.control_dependencies([assert_op]):
        dx = array_ops.identity(dx)
      return dx
    return x, grad
  # Keras sometimes has trouble serializing Lambda layers with a decorated
  # function. So we define and return a non-decorated function.
  def identity_with_grad_check(x):
    return _identity_with_grad_check(x)
  return identity_with_grad_check
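
# Example usage of create_identity_with_grad_check_fn (a minimal sketch, not
# part of the original module; assumes eager execution and the extra imports
# shown below):
#
#   from tensorflow.python.eager import backprop
#   from tensorflow.python.framework import constant_op
#
#   identity_fn = create_identity_with_grad_check_fn(
#       expected_gradient=1.0, expected_dtype=dtypes.float32)
#   x = constant_op.constant(3.0)
#   with backprop.GradientTape() as tape:
#     tape.watch(x)
#     y = identity_fn(x)
#   tape.gradient(y, x)  # Runs the assertion; the gradient of identity is 1.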


def create_identity_with_nan_gradients_fn(have_nan_gradients):
  """Returns a function that optionally has NaN gradients.

  This serves as a hook to introduce NaN gradients to a model. This returns an
  identity function. The identity's gradient function will check if the boolean
  tensor `have_nan_gradients` is True. If so, the gradient will be NaN.
  Otherwise, the gradient will also be the identity.

  Args:
    have_nan_gradients: A scalar boolean tensor. If True, gradients will be NaN.
      Otherwise, the gradient function is the identity function.

  Returns:
    An identity function whose gradient function will return NaNs if
    `have_nan_gradients` is True.
  """
  @custom_gradient.custom_gradient
  def _identity_with_nan_gradients(x):
    """Function whose gradient is NaN iff `have_nan_gradients` is True."""
    x = array_ops.identity(x)
    def grad(dx):
      return control_flow_ops.cond(
          have_nan_gradients,
          lambda: dx * float('NaN'),
          lambda: dx
      )
    return x, grad
  # Keras sometimes has trouble serializing Lambda layers with a decorated
  # function. So we define and return a non-decorated function.
  def identity_with_nan_gradients(x):
    return _identity_with_nan_gradients(x)
  return identity_with_nan_gradients
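
# Example usage of create_identity_with_nan_gradients_fn (a minimal sketch,
# not part of the original module; assumes eager execution and the extra
# imports shown below):
#
#   from tensorflow.python.eager import backprop
#   from tensorflow.python.framework import constant_op
#   from tensorflow.python.ops import variables
#
#   have_nan = variables.Variable(False)
#   identity_fn = create_identity_with_nan_gradients_fn(have_nan)
#   x = constant_op.constant(2.0)
#   with backprop.GradientTape() as tape:
#     tape.watch(x)
#     y = identity_fn(x)
#   tape.gradient(y, x)  # 1.0 here; NaN after have_nan.assign(True).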


class AssertTypeLayer(base_layer.Layer):
  """A layer which asserts its inputs are a certain type."""

  def __init__(self, assert_type=None, **kwargs):
    self._assert_type = (dtypes.as_dtype(assert_type).name if assert_type
                         else None)
    super(AssertTypeLayer, self).__init__(**kwargs)

  def assert_input_types(self, inputs):
    """Asserts `inputs` are of the correct type. Should be called in call()."""
    if self._assert_type:
      inputs_flattened = nest.flatten(inputs)
      for inp in inputs_flattened:
        assert inp.dtype.base_dtype == self._assert_type, (
            'Input tensor has type %s which does not match assert type %s' %
            (inp.dtype.name, self._assert_type))
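
# Example usage of AssertTypeLayer (a minimal sketch, not part of the original
# module; the subclass below is hypothetical):
#
#   class PassthroughLayer(AssertTypeLayer):
#
#     def call(self, inputs):
#       self.assert_input_types(inputs)
#       return inputs
#
#   layer = PassthroughLayer(assert_type='float32')
#   layer(array_ops.ones((2,)))  # Passes: the input is float32.
#
# Note that Keras casts floating-point inputs to the layer's compute dtype
# before call(), so the asserted type is typically the compute dtype of the
# dtype policy under test.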


class MultiplyLayer(AssertTypeLayer):
  """A layer which multiplies its input by a scalar variable."""

  def __init__(self,
               regularizer=None,
               activity_regularizer=None,
               use_operator=False,
               var_name='v',
               **kwargs):
    """Initializes the MultiplyLayer.

    Args:
      regularizer: The weight regularizer on the scalar variable.
      activity_regularizer: The activity regularizer.
      use_operator: If True, multiply using the * operator. If False, multiply
        using tf.multiply.
      var_name: The name of the variable. It can be useful to pass a name other
        than 'v', to test the case where the attribute name (self.v) differs
        from the variable name.
      **kwargs: Passed to the AssertTypeLayer constructor.
    """
    self._regularizer = regularizer
    if isinstance(regularizer, dict):
      self._regularizer = regularizers.deserialize(regularizer,
                                                   custom_objects=globals())
    self._activity_regularizer = activity_regularizer
    if isinstance(activity_regularizer, dict):
      self._activity_regularizer = regularizers.deserialize(
          activity_regularizer, custom_objects=globals())

    self._use_operator = use_operator
    self._var_name = var_name
    super(MultiplyLayer, self).__init__(
        activity_regularizer=self._activity_regularizer, **kwargs)

  def build(self, _):
    self.v = self.add_weight(
        self._var_name, (), initializer='ones', regularizer=self._regularizer)
    self.built = True

  def call(self, inputs):
    self.assert_input_types(inputs)
    return self._multiply(inputs, self.v)

  def _multiply(self, x, y):
    if self._use_operator:
      return x * y
    else:
      return math_ops.multiply(x, y)

  def get_config(self):
    config = super(MultiplyLayer, self).get_config()
    config['regularizer'] = regularizers.serialize(self._regularizer)
    config['activity_regularizer'] = regularizers.serialize(
        self._activity_regularizer)
    config['use_operator'] = self._use_operator
    config['var_name'] = self._var_name
    config['assert_type'] = self._assert_type
    return config
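
# Example usage of MultiplyLayer (a minimal sketch, not part of the original
# module):
#
#   layer = MultiplyLayer(assert_type='float32')
#   y = layer(array_ops.ones((2,)))  # Builds the scalar variable 'v'; y == x.
#   config = layer.get_config()
#   restored = MultiplyLayer.from_config(config)  # Round-trips the config.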


class IdentityRegularizer(regularizers.Regularizer):
  """A regularizer which returns its (float32) weight as the loss."""

  def __call__(self, x):
    assert x.dtype == dtypes.float32
    return array_ops.identity(x)

  def get_config(self):
    return {}


class ReduceSumRegularizer(regularizers.Regularizer):
  """A regularizer which returns the sum of the weight's elements."""

  def __call__(self, x):
    return math_ops.reduce_sum(x)

  def get_config(self):
    return {}
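
# Example usage of the regularizers above (a minimal sketch, not part of the
# original module):
#
#   layer = MultiplyLayer(regularizer=IdentityRegularizer())
#   layer(array_ops.ones((2,)))  # Builds 'v' and attaches the weight loss.
#   layer.losses  # Contains identity(v), i.e. a scalar tensor equal to 1.0.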