• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Built-in regularizers.
16"""
17# pylint: disable=invalid-name
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import math
23
24import six
25
26from tensorflow.python.keras import backend
27from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
28from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
29from tensorflow.python.ops import math_ops
30from tensorflow.python.util.tf_export import keras_export
31
32
33def _check_penalty_number(x):
34  """check penalty number availability, raise ValueError if failed."""
35  if not isinstance(x, (float, int)):
36    raise ValueError(('Value: {} is not a valid regularization penalty number, '
37                      'expected an int or float value').format(x))
38
39  if math.isinf(x) or math.isnan(x):
40    raise ValueError(
41        ('Value: {} is not a valid regularization penalty number, '
42         'a positive/negative infinity or NaN is not a property value'
43        ).format(x))
44
45
46def _none_to_default(inputs, default):
47  return default if inputs is None else default
48
49
50@keras_export('keras.regularizers.Regularizer')
51class Regularizer(object):
52  """Regularizer base class.
53
54  Regularizers allow you to apply penalties on layer parameters or layer
55  activity during optimization. These penalties are summed into the loss
56  function that the network optimizes.
57
58  Regularization penalties are applied on a per-layer basis. The exact API will
59  depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` and
60  `Conv3D`) have a unified API.
61
62  These layers expose 3 keyword arguments:
63
64  - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
65  - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
66  - `activity_regularizer`: Regularizer to apply a penalty on the layer's output
67
68  All layers (including custom layers) expose `activity_regularizer` as a
69  settable property, whether or not it is in the constructor arguments.
70
71  The value returned by the `activity_regularizer` is divided by the input
72  batch size so that the relative weighting between the weight regularizers and
73  the activity regularizers does not change with the batch size.
74
75  You can access a layer's regularization penalties by calling `layer.losses`
76  after calling the layer on inputs.
77
78  ## Example
79
80  >>> layer = tf.keras.layers.Dense(
81  ...     5, input_dim=5,
82  ...     kernel_initializer='ones',
83  ...     kernel_regularizer=tf.keras.regularizers.L1(0.01),
84  ...     activity_regularizer=tf.keras.regularizers.L2(0.01))
85  >>> tensor = tf.ones(shape=(5, 5)) * 2.0
86  >>> out = layer(tensor)
87
88  >>> # The kernel regularization term is 0.25
89  >>> # The activity regularization term (after dividing by the batch size) is 5
90  >>> tf.math.reduce_sum(layer.losses)
91  <tf.Tensor: shape=(), dtype=float32, numpy=5.25>
92
93  ## Available penalties
94
95  ```python
96  tf.keras.regularizers.L1(0.3)  # L1 Regularization Penalty
97  tf.keras.regularizers.L2(0.1)  # L2 Regularization Penalty
98  tf.keras.regularizers.L1L2(l1=0.01, l2=0.01)  # L1 + L2 penalties
99  ```
100
101  ## Directly calling a regularizer
102
103  Compute a regularization loss on a tensor by directly calling a regularizer
104  as if it is a one-argument function.
105
106  E.g.
107  >>> regularizer = tf.keras.regularizers.L2(2.)
108  >>> tensor = tf.ones(shape=(5, 5))
109  >>> regularizer(tensor)
110  <tf.Tensor: shape=(), dtype=float32, numpy=50.0>
111
112
113  ## Developing new regularizers
114
115  Any function that takes in a weight matrix and returns a scalar
116  tensor can be used as a regularizer, e.g.:
117
118  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l1')
119  ... def l1_reg(weight_matrix):
120  ...    return 0.01 * tf.math.reduce_sum(tf.math.abs(weight_matrix))
121  ...
122  >>> layer = tf.keras.layers.Dense(5, input_dim=5,
123  ...     kernel_initializer='ones', kernel_regularizer=l1_reg)
124  >>> tensor = tf.ones(shape=(5, 5))
125  >>> out = layer(tensor)
126  >>> layer.losses
127  [<tf.Tensor: shape=(), dtype=float32, numpy=0.25>]
128
129  Alternatively, you can write your custom regularizers in an
130  object-oriented way by extending this regularizer base class, e.g.:
131
132  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l2')
133  ... class L2Regularizer(tf.keras.regularizers.Regularizer):
134  ...   def __init__(self, l2=0.):  # pylint: disable=redefined-outer-name
135  ...     self.l2 = l2
136  ...
137  ...   def __call__(self, x):
138  ...     return self.l2 * tf.math.reduce_sum(tf.math.square(x))
139  ...
140  ...   def get_config(self):
141  ...     return {'l2': float(self.l2)}
142  ...
143  >>> layer = tf.keras.layers.Dense(
144  ...   5, input_dim=5, kernel_initializer='ones',
145  ...   kernel_regularizer=L2Regularizer(l2=0.5))
146
147  >>> tensor = tf.ones(shape=(5, 5))
148  >>> out = layer(tensor)
149  >>> layer.losses
150  [<tf.Tensor: shape=(), dtype=float32, numpy=12.5>]
151
152  ### A note on serialization and deserialization:
153
154  Registering the regularizers as serializable is optional if you are just
155  training and executing models, exporting to and from SavedModels, or saving
156  and loading weight checkpoints.
157
158  Registration is required for Keras `model_to_estimator`, saving and
159  loading models to HDF5 formats, Keras model cloning, some visualization
160  utilities, and exporting models to and from JSON. If using this functionality,
161  you must make sure any python process running your model has also defined
162  and registered your custom regularizer.
163
164  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
165  beyond. In earlier versions of TensorFlow you must pass your custom
166  regularizer to the `custom_objects` argument of methods that expect custom
167  regularizers to be registered as serializable.
168  """
169
170  def __call__(self, x):
171    """Compute a regularization penalty from an input tensor."""
172    return 0.
173
174  @classmethod
175  def from_config(cls, config):
176    """Creates a regularizer from its config.
177
178    This method is the reverse of `get_config`,
179    capable of instantiating the same regularizer from the config
180    dictionary.
181
182    This method is used by Keras `model_to_estimator`, saving and
183    loading models to HDF5 formats, Keras model cloning, some visualization
184    utilities, and exporting models to and from JSON.
185
186    Args:
187        config: A Python dictionary, typically the output of get_config.
188
189    Returns:
190        A regularizer instance.
191    """
192    return cls(**config)
193
194  def get_config(self):
195    """Returns the config of the regularizer.
196
197    An regularizer config is a Python dictionary (serializable)
198    containing all configuration parameters of the regularizer.
199    The same regularizer can be reinstantiated later
200    (without any saved state) from this configuration.
201
202    This method is optional if you are just training and executing models,
203    exporting to and from SavedModels, or using weight checkpoints.
204
205    This method is required for Keras `model_to_estimator`, saving and
206    loading models to HDF5 formats, Keras model cloning, some visualization
207    utilities, and exporting models to and from JSON.
208
209    Returns:
210        Python dictionary.
211    """
212    raise NotImplementedError(str(self) + ' does not implement get_config()')
213
214
215@keras_export('keras.regularizers.L1L2')
216class L1L2(Regularizer):
217  """A regularizer that applies both L1 and L2 regularization penalties.
218
219  The L1 regularization penalty is computed as:
220  `loss = l1 * reduce_sum(abs(x))`
221
222  The L2 regularization penalty is computed as
223  `loss = l2 * reduce_sum(square(x))`
224
225  L1L2 may be passed to a layer as a string identifier:
226
227  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l1_l2')
228
229  In this case, the default values used are `l1=0.01` and `l2=0.01`.
230
231  Attributes:
232      l1: Float; L1 regularization factor.
233      l2: Float; L2 regularization factor.
234  """
235
236  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
237    # The default value for l1 and l2 are different from the value in l1_l2
238    # for backward compatibility reason. Eg, L1L2(l2=0.1) will only have l2
239    # and no l1 penalty.
240    l1 = 0. if l1 is None else l1
241    l2 = 0. if l2 is None else l2
242    _check_penalty_number(l1)
243    _check_penalty_number(l2)
244
245    self.l1 = backend.cast_to_floatx(l1)
246    self.l2 = backend.cast_to_floatx(l2)
247
248  def __call__(self, x):
249    regularization = backend.constant(0., dtype=x.dtype)
250    if self.l1:
251      regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
252    if self.l2:
253      regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
254    return regularization
255
256  def get_config(self):
257    return {'l1': float(self.l1), 'l2': float(self.l2)}
258
259
260@keras_export('keras.regularizers.L1', 'keras.regularizers.l1')
261class L1(Regularizer):
262  """A regularizer that applies a L1 regularization penalty.
263
264  The L1 regularization penalty is computed as:
265  `loss = l1 * reduce_sum(abs(x))`
266
267  L1 may be passed to a layer as a string identifier:
268
269  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l1')
270
271  In this case, the default value used is `l1=0.01`.
272
273  Attributes:
274      l1: Float; L1 regularization factor.
275  """
276
277  def __init__(self, l1=0.01, **kwargs):  # pylint: disable=redefined-outer-name
278    l1 = kwargs.pop('l', l1)  # Backwards compatibility
279    if kwargs:
280      raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
281
282    l1 = 0.01 if l1 is None else l1
283    _check_penalty_number(l1)
284
285    self.l1 = backend.cast_to_floatx(l1)
286
287  def __call__(self, x):
288    return self.l1 * math_ops.reduce_sum(math_ops.abs(x))
289
290  def get_config(self):
291    return {'l1': float(self.l1)}
292
293
294@keras_export('keras.regularizers.L2', 'keras.regularizers.l2')
295class L2(Regularizer):
296  """A regularizer that applies a L2 regularization penalty.
297
298  The L2 regularization penalty is computed as:
299  `loss = l2 * reduce_sum(square(x))`
300
301  L2 may be passed to a layer as a string identifier:
302
303  >>> dense = tf.keras.layers.Dense(3, kernel_regularizer='l2')
304
305  In this case, the default value used is `l2=0.01`.
306
307  Attributes:
308      l2: Float; L2 regularization factor.
309  """
310
311  def __init__(self, l2=0.01, **kwargs):  # pylint: disable=redefined-outer-name
312    l2 = kwargs.pop('l', l2)  # Backwards compatibility
313    if kwargs:
314      raise TypeError('Argument(s) not recognized: %s' % (kwargs,))
315
316    l2 = 0.01 if l2 is None else l2
317    _check_penalty_number(l2)
318
319    self.l2 = backend.cast_to_floatx(l2)
320
321  def __call__(self, x):
322    return self.l2 * math_ops.reduce_sum(math_ops.square(x))
323
324  def get_config(self):
325    return {'l2': float(self.l2)}
326
327
328@keras_export('keras.regularizers.l1_l2')
329def l1_l2(l1=0.01, l2=0.01):  # pylint: disable=redefined-outer-name
330  r"""Create a regularizer that applies both L1 and L2 penalties.
331
332  The L1 regularization penalty is computed as:
333  `loss = l1 * reduce_sum(abs(x))`
334
335  The L2 regularization penalty is computed as:
336  `loss = l2 * reduce_sum(square(x))`
337
338  Args:
339      l1: Float; L1 regularization factor.
340      l2: Float; L2 regularization factor.
341
342  Returns:
343    An L1L2 Regularizer with the given regularization factors.
344  """
345  return L1L2(l1=l1, l2=l2)
346
347
348# Deserialization aliases.
349l1 = L1
350l2 = L2
351
352
353@keras_export('keras.regularizers.serialize')
354def serialize(regularizer):
355  return serialize_keras_object(regularizer)
356
357
358@keras_export('keras.regularizers.deserialize')
359def deserialize(config, custom_objects=None):
360  if config == 'l1_l2':
361    # Special case necessary since the defaults used for "l1_l2" (string)
362    # differ from those of the L1L2 class.
363    return L1L2(l1=0.01, l2=0.01)
364  return deserialize_keras_object(
365      config,
366      module_objects=globals(),
367      custom_objects=custom_objects,
368      printable_module_name='regularizer')
369
370
371@keras_export('keras.regularizers.get')
372def get(identifier):
373  """Retrieve a regularizer instance from a config or identifier."""
374  if identifier is None:
375    return None
376  if isinstance(identifier, dict):
377    return deserialize(identifier)
378  elif isinstance(identifier, six.string_types):
379    return deserialize(str(identifier))
380  elif callable(identifier):
381    return identifier
382  else:
383    raise ValueError(
384        'Could not interpret regularizer identifier: {}'.format(identifier))
385