# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in regularizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.regularizers.Regularizer')
class Regularizer(object):
  """Regularizer base class.

  Regularizers allow you to apply penalties on layer parameters or layer
  activity during optimization. These penalties are summed into the loss
  function that the network optimizes.

  Regularization penalties are applied on a per-layer basis. The exact API will
  depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D` and
  `Conv3D`) have a unified API.

  These layers expose 3 keyword arguments:

  - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel
  - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias
  - `activity_regularizer`: Regularizer to apply a penalty on the layer's
    output

  All layers (including custom layers) expose `activity_regularizer` as a
  settable property, whether or not it is in the constructor arguments.

  The value returned by the `activity_regularizer` is divided by the input
  batch size so that the relative weighting between the weight regularizers and
  the activity regularizers does not change with the batch size.

  You can access a layer's regularization penalties by calling `layer.losses`
  after calling the layer on inputs.

  ## Example

  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5,
  ...     kernel_initializer='ones',
  ...     kernel_regularizer=tf.keras.regularizers.l1(0.01),
  ...     activity_regularizer=tf.keras.regularizers.l2(0.01))
  >>> tensor = tf.ones(shape=(5, 5)) * 2.0
  >>> out = layer(tensor)

  >>> # The kernel regularization term is 0.25
  >>> # The activity regularization term (after dividing by the batch size)
  >>> # is 5
  >>> tf.math.reduce_sum(layer.losses)
  <tf.Tensor: shape=(), dtype=float32, numpy=5.25>

  ## Available penalties

  ```python
  tf.keras.regularizers.l1(0.3)  # L1 Regularization Penalty
  tf.keras.regularizers.l2(0.1)  # L2 Regularization Penalty
  tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)  # L1 + L2 penalties
  ```

  ## Directly calling a regularizer

  Compute a regularization loss on a tensor by directly calling a regularizer
  as if it were a one-argument function.

  E.g.
  >>> regularizer = tf.keras.regularizers.l2(2.)
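  >>> # A regularizer is just a one-argument callable: here the penalty is
  >>> # 2.0 * sum(x**2) = 2.0 * 25 = 50.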
  >>> tensor = tf.ones(shape=(5, 5))
  >>> regularizer(tensor)
  <tf.Tensor: shape=(), dtype=float32, numpy=50.0>

  ## Developing new regularizers

  Any function that takes in a weight matrix and returns a scalar
  tensor can be used as a regularizer, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l1')
  ... def l1_reg(weight_matrix):
  ...   return 0.01 * tf.math.reduce_sum(tf.math.abs(weight_matrix))
  ...
  >>> layer = tf.keras.layers.Dense(5, input_dim=5,
  ...     kernel_initializer='ones', kernel_regularizer=l1_reg)
  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=0.25>]

  Alternatively, you can write your custom regularizers in an
  object-oriented way by extending this regularizer base class, e.g.:

  >>> @tf.keras.utils.register_keras_serializable(package='Custom', name='l2')
  ... class L2Regularizer(tf.keras.regularizers.Regularizer):
  ...   def __init__(self, l2=0.):  # pylint: disable=redefined-outer-name
  ...     self.l2 = l2
  ...
  ...   def __call__(self, x):
  ...     return self.l2 * tf.math.reduce_sum(tf.math.square(x))
  ...
  ...   def get_config(self):
  ...     return {'l2': float(self.l2)}
  ...
  >>> layer = tf.keras.layers.Dense(
  ...     5, input_dim=5, kernel_initializer='ones',
  ...     kernel_regularizer=L2Regularizer(l2=0.5))

  >>> tensor = tf.ones(shape=(5, 5))
  >>> out = layer(tensor)
  >>> layer.losses
  [<tf.Tensor: shape=(), dtype=float32, numpy=12.5>]

  ### A note on serialization and deserialization:

  Registering the regularizers as serializable is optional if you are just
  training and executing models, exporting to and from SavedModels, or saving
  and loading weight checkpoints.

  Registration is required for Keras `model_to_estimator`, saving and
  loading models to HDF5 formats, Keras model cloning, some visualization
  utilities, and exporting models to and from JSON. If using this
  functionality, you must make sure any Python process running your model has
  also defined and registered your custom regularizer.

  `tf.keras.utils.register_keras_serializable` is only available in TF 2.1 and
  beyond. In earlier versions of TensorFlow you must pass your custom
  regularizer to the `custom_objects` argument of methods that expect custom
  regularizers to be registered as serializable.
  """

  def __call__(self, x):
    """Compute a regularization penalty from an input tensor."""
    return 0.

  @classmethod
  def from_config(cls, config):
    """Creates a regularizer from its config.

    This method is the reverse of `get_config`,
    capable of instantiating the same regularizer from the config
    dictionary.

    This method is used by Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Arguments:
      config: A Python dictionary, typically the output of get_config.

    Returns:
      A regularizer instance.
    """
    return cls(**config)

  def get_config(self):
    """Returns the config of the regularizer.

    A regularizer config is a Python dictionary (serializable)
    containing all configuration parameters of the regularizer.
    The same regularizer can be reinstantiated later
    (without any saved state) from this configuration.
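
    For example, a minimal sketch of the round trip through the config,
    using the built-in `L1L2` regularizer:

    >>> reg = tf.keras.regularizers.L1L2(l1=0.01, l2=0.01)
    >>> reg2 = tf.keras.regularizers.L1L2.from_config(reg.get_config())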

    This method is optional if you are just training and executing models,
    exporting to and from SavedModels, or using weight checkpoints.

    This method is required for Keras `model_to_estimator`, saving and
    loading models to HDF5 formats, Keras model cloning, some visualization
    utilities, and exporting models to and from JSON.

    Returns:
      Python dictionary.
    """
    raise NotImplementedError(str(self) + ' does not implement get_config()')


@keras_export('keras.regularizers.L1L2')
class L1L2(Regularizer):
  r"""A regularizer that applies both L1 and L2 regularization penalties.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty = \ell_1\sum_{i=0}^n|x_i|$$

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty = \ell_2\sum_{i=0}^nx_i^2$$

  Attributes:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.
  """

  def __init__(self, l1=0., l2=0.):  # pylint: disable=redefined-outer-name
    self.l1 = K.cast_to_floatx(l1)
    self.l2 = K.cast_to_floatx(l2)

  def __call__(self, x):
    if not self.l1 and not self.l2:
      return K.constant(0.)
    regularization = 0.
    if self.l1:
      regularization += self.l1 * math_ops.reduce_sum(math_ops.abs(x))
    if self.l2:
      regularization += self.l2 * math_ops.reduce_sum(math_ops.square(x))
    return regularization

  def get_config(self):
    return {'l1': float(self.l1), 'l2': float(self.l2)}


# Aliases.


@keras_export('keras.regularizers.l1')
def l1(l=0.01):
  r"""Create a regularizer that applies an L1 regularization penalty.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty = \ell_1\sum_{i=0}^n|x_i|$$

  Arguments:
    l: Float; L1 regularization factor.

  Returns:
    An L1 Regularizer with the given regularization factor.
  """
  return L1L2(l1=l)


@keras_export('keras.regularizers.l2')
def l2(l=0.01):
  r"""Create a regularizer that applies an L2 regularization penalty.

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty = \ell_2\sum_{i=0}^nx_i^2$$

  Arguments:
    l: Float; L2 regularization factor.

  Returns:
    An L2 Regularizer with the given regularization factor.
  """
  return L1L2(l2=l)


@keras_export('keras.regularizers.l1_l2')
def l1_l2(l1=0.01, l2=0.01):  # pylint: disable=redefined-outer-name
  r"""Create a regularizer that applies both L1 and L2 penalties.

  The L1 regularization penalty is computed as:
  $$\ell_1\,\,penalty = \ell_1\sum_{i=0}^n|x_i|$$

  The L2 regularization penalty is computed as:
  $$\ell_2\,\,penalty = \ell_2\sum_{i=0}^nx_i^2$$

  Arguments:
    l1: Float; L1 regularization factor.
    l2: Float; L2 regularization factor.

  Returns:
    An L1L2 Regularizer with the given regularization factors.
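
  For instance, a minimal sketch of direct use, where each penalty term
  evaluates to 0.01 * 25 = 0.25:

  >>> reg = tf.keras.regularizers.l1_l2(l1=0.01, l2=0.01)
  >>> reg(tf.ones(shape=(5, 5)))
  <tf.Tensor: shape=(), dtype=float32, numpy=0.5>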
279 """ 280 return L1L2(l1=l1, l2=l2) 281 282 283@keras_export('keras.regularizers.serialize') 284def serialize(regularizer): 285 return serialize_keras_object(regularizer) 286 287 288@keras_export('keras.regularizers.deserialize') 289def deserialize(config, custom_objects=None): 290 return deserialize_keras_object( 291 config, 292 module_objects=globals(), 293 custom_objects=custom_objects, 294 printable_module_name='regularizer') 295 296 297@keras_export('keras.regularizers.get') 298def get(identifier): 299 if identifier is None: 300 return None 301 if isinstance(identifier, dict): 302 return deserialize(identifier) 303 elif isinstance(identifier, six.string_types): 304 identifier = str(identifier) 305 # We have to special-case functions that return classes. 306 # TODO(omalleyt): Turn these into classes or class aliases. 307 special_cases = ['l1', 'l2', 'l1_l2'] 308 if identifier in special_cases: 309 # Treat like a class. 310 return deserialize({'class_name': identifier, 'config': {}}) 311 return deserialize(str(identifier)) 312 elif callable(identifier): 313 return identifier 314 else: 315 raise ValueError('Could not interpret regularizer identifier:', identifier) 316