# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
# pylint: disable=g-classes-have-attributes
"""Constraints: functions that impose constraints on weight values."""

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls


@keras_export('keras.constraints.Constraint')
class Constraint:
  """Base class for weight constraints.

  A `Constraint` instance works like a stateless function.
  Users who subclass this
  class should override the `__call__` method, which takes a single
  weight parameter and returns a projected version of that parameter
  (e.g. normalized or clipped). Constraints can be used with various Keras
  layers via the `kernel_constraint` or `bias_constraint` arguments.

  Here's a simple example of a non-negative weight constraint:

  >>> class NonNegative(tf.keras.constraints.Constraint):
  ...
  ...  def __call__(self, w):
  ...    return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

  >>> weight = tf.constant((-1.0, 1.0))
  >>> NonNegative()(weight)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.], dtype=float32)>

  >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
  """

  def __call__(self, w):
    """Applies the constraint to the input weight variable.

    By default, the input weight variable is not modified.
    Users should override this method to implement their own projection
    function.

    Args:
      w: Input weight variable.

    Returns:
      Projected variable (by default, returns unmodified inputs).
    """
    return w

  def get_config(self):
    """Returns a Python dict of the object config.

    A constraint config is a Python dictionary (JSON-serializable) that can
    be used to reinstantiate the same object.

    Returns:
      Python dict containing the configuration of the constraint object.
    """
    return {}


@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Rescales `w` so that its norm along `axis` is at most `max_value`."""
    # L2 norm along `self.axis`; `keepdims=True` keeps it broadcastable
    # against `w`.
    norms = backend.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = backend.clip(norms, 0, self.max_value)
    # The epsilon keeps the division finite when a norm is exactly zero.
    return w * (desired / (backend.epsilon() + norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    return {'max_value': self.max_value, 'axis': self.axis}


@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  def __call__(self, w):
    # Boolean mask that is True wherever the weight is already >= 0.
    non_negative = math_ops.greater_equal(w, 0.)
    # Cast the mask to the dtype of `w` itself rather than the global
    # `backend.floatx()`: multiplying a weight by a mask of a different float
    # dtype (e.g. a float64 weight while floatx is float32) is a dtype
    # mismatch. Inputs that expose no `dtype` attribute keep the previous
    # floatx behavior.
    mask_dtype = getattr(w, 'dtype', backend.floatx())
    return w * math_ops.cast(non_negative, mask_dtype)


@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Divides `w` by its norm along `axis` (plus a small epsilon)."""
    return w / (
        backend.epsilon() + backend.sqrt(
            math_ops.reduce_sum(
                math_ops.square(w), axis=self.axis, keepdims=True)))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    return {'axis': self.axis}


@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Rescales `w` towards a norm inside `[min_value, max_value]`."""
    norms = backend.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    # Blend the clipped norm with the current norm according to `rate`.
    desired = (
        self.rate * backend.clip(norms, self.min_value, self.max_value) +
        (1 - self.rate) * norms)
    # The epsilon keeps the division finite when a norm is exactly zero.
    return w * (desired / (backend.epsilon() + norms))

  @doc_controls.do_not_generate_docs
  def get_config(self):
    """Returns the constructor arguments as a dict."""
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  For example, the desired output for the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
      kernel = [[v_11, v_11, v_11, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    """Applies the radial constraint to every 2D slice of a rank-4 kernel.

    Args:
      w: Weight tensor of rank 4, unpacked as
        `(height, width, channels, kernels)`.

    Returns:
      Tensor reshaped back to `(height, width, channels, kernels)` in which
      each 2D slice has been processed by `_kernel_constraint`.

    Raises:
      ValueError: If the rank of `w` is unknown or is not 4.
    """
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    # Merge the two trailing axes so each 2D slice can be mapped independently.
    w = backend.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
    # backend.switch is supported.
    # Unstack on the last axis and restack on the first so that `map_fn`
    # iterates over the individual (height, width) slices.
    w = backend.map_fn(
        self._kernel_constraint,
        backend.stack(array_ops.unstack(w, axis=-1), axis=0))
    # Move the slice axis back to the end and restore the original shape.
    return backend.reshape(backend.stack(array_ops.unstack(w, axis=0), axis=-1),
                           (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a single 2D kernel slice of shape (height, width)."""
    # One cell of padding on every side of a rank-2 tensor.
    padding = backend.constant([[1, 1], [1, 1]], dtype='int32')

    # Only the first dimension is read, and the same index is later used for
    # both axes, so the slice is presumably square -- TODO confirm.
    kernel_shape = backend.shape(kernel)[0]
    start = backend.cast(kernel_shape / 2, 'int32')

    # Seed: a 1x1 slice for odd sizes; for even sizes the same 1x1 slice is
    # broadcast to 2x2 by adding zeros.
    kernel_new = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + backend.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # Odd sizes need one more padding round than even sizes to reach full size.
    index = backend.switch(
        backend.cast(math_ops.floormod(kernel_shape, 2), 'bool'),
        lambda: backend.constant(0, dtype='int32'),
        lambda: backend.constant(1, dtype='int32'))
    while_condition = lambda index, *args: backend.less(index, start)

    def body_fn(i, array):
      # Grow the result by one ring filled with a value taken from the
      # diagonal of the original kernel.
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    # NOTE(review): tracing a 4x4 kernel through this code, the inner 2x2 block
    # is filled from kernel[1, 1] and the outer ring from kernel[3, 3], which
    # looks like the reverse of the example in the class docstring -- confirm
    # which of the two is intended.
    # The result grows on every iteration, hence the unknown shape invariant.
    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new


# Aliases.
max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint by delegating to `serialize_keras_object`.

  Args:
    constraint: The constraint object to serialize.

  Returns:
    Whatever `serialize_keras_object` returns for `constraint`.
  """
  return serialize_keras_object(constraint)


@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Rebuilds a constraint from `config` via `deserialize_keras_object`.

  Args:
    config: Serialized constraint description, forwarded unchanged.
    custom_objects: Optional value forwarded as `custom_objects`.

  Returns:
    Whatever `deserialize_keras_object` returns for `config`.
  """
  # Names are resolved against this module's globals, which include the
  # constraint classes and their aliases.
  known_objects = globals()
  return deserialize_keras_object(
      config,
      module_objects=known_objects,
      custom_objects=custom_objects,
      printable_module_name='constraint')


@keras_export('keras.constraints.get')
def get(identifier):
  """Resolves `identifier` to a constraint.

  Args:
    identifier: `None`, a config dict, a string name, or a callable.

  Returns:
    `None` when `identifier` is `None`; the result of `deserialize` for a
    dict or a string (a string is wrapped into a config with an empty
    `'config'` entry); `identifier` itself when it is callable.

  Raises:
    ValueError: If `identifier` is none of the supported kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    # A bare name stands for that class built with its default arguments.
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))