# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Constraints: functions that impose constraints on weight values.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


def _weight_norms(w, axis):
  """Returns the L2 norm of `w` reduced over `axis`, keeping reduced dims.

  Shared by the norm-based constraints below so that they all compute the
  norm with exactly the same ops.

  Arguments:
    w: the weight tensor.
    axis: integer or list of integers, axis/axes to reduce over.

  Returns:
    `sqrt(sum(square(w)))` over `axis`, with the reduced axes kept
    (`keepdims=True`) so the result can be combined element-wise with `w`.
  """
  return K.sqrt(
      math_ops.reduce_sum(math_ops.square(w), axis=axis, keepdims=True))


@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  A constraint is a callable that takes a weight tensor and returns the
  constrained weight tensor. This base implementation is the identity and
  has an empty config; subclasses override `__call__` and, when they take
  constructor arguments, `get_config`.
  """

  def __call__(self, w):
    """Returns `w` unchanged."""
    return w

  def get_config(self):
    """Returns the constructor arguments as a dict (empty for the base)."""
    return {}


@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
      max_value: the maximum norm for the incoming weights.
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    """Rescales `w` so that its norms along `axis` are clipped to `max_value`."""
    norms = _weight_norms(w, self.axis)
    desired = K.clip(norms, 0, self.max_value)
    # K.epsilon() keeps the division finite when a norm is zero.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}


@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.
  """

  def __call__(self, w):
    """Zeroes the negative entries of `w` by multiplying with a 0/1 mask."""
    # The boolean mask `w >= 0` is cast to the backend float type
    # (`K.floatx()`) before the multiplication.
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), K.floatx())


@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    """Divides `w` by its norms along `axis` (plus `K.epsilon()`)."""
    return w / (K.epsilon() + _weight_norms(w, self.axis))

  def get_config(self):
    return {'axis': self.axis}


@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Arguments:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
          rescaled to yield
          `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
          Effectively, this means that rate=1.0 stands for strict
          enforcement of the constraint, while rate<1.0 means that
          weights will be rescaled at each step to slowly move
          towards a value inside the desired interval.
      axis: integer, axis along which to calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    """Rescales `w` towards norms inside `[min_value, max_value]`."""
    norms = _weight_norms(w, self.axis)
    # Blend the clipped norm with the current norm according to `rate`.
    desired = (
        self.rate * K.clip(norms, self.min_value, self.max_value) +
        (1 - self.rate) * norms)
    # K.epsilon() keeps the division finite when a norm is zero.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {
        'min_value': self.min_value,
        'max_value': self.max_value,
        'rate': self.rate,
        'axis': self.axis
    }


@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  For example, the desired output for the following 4-by-4 kernel::

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this::

  ```
      kernel = [[v_11, v_11, v_11, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_33, v_33, v_11],
                [v_11, v_11, v_11, v_11]]
  ```

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)`.
  """

  def __call__(self, w):
    """Applies the radial constraint to every 2-D slice of a rank-4 kernel.

    Arguments:
      w: weight tensor of rank 4.

    Returns:
      A tensor reshaped back to the four dimensions read from `w.shape`.

    Raises:
      ValueError: if the static rank of `w` is unknown or is not 4.
    """
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    # Fold the last two axes together so each (height, width) slice can be
    # processed independently.
    w = K.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    # Move the folded axis to the front (unstack on -1, stack on 0) because
    # the mapped function is applied along the leading axis.
    w = K.map_fn(
        self._kernel_constraint,
        K.stack(array_ops.unstack(w, axis=-1), axis=0))
    # Move the folded axis back to the end and restore the rank-4 shape.
    return K.reshape(K.stack(array_ops.unstack(w, axis=0), axis=-1),
                     (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains one 2-D kernel slice of shape (height, width).

    The result is grown from a central seed block by repeatedly padding one
    ring of a single constant value around it.

    Arguments:
      kernel: a 2-D slice of the weight tensor, as produced by the
        unstack/stack in `__call__`.

    Returns:
      The 2-D tensor built by the padding loop.
    """
    # One element of padding on each side of both dimensions.
    padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    # Only the first dimension is read; the same size is used for both axes
    # (presumably the kernel is expected to be square -- confirm).
    kernel_shape = K.shape(kernel)[0]
    # True division (see the `__future__` import) followed by an int32 cast.
    start = K.cast(kernel_shape / 2, 'int32')
    is_odd = K.cast(math_ops.floormod(kernel_shape, 2), 'bool')

    # Seed block: a single element for odd sizes; for even sizes the same
    # element broadcast to a 2x2 block by adding zeros of shape (2, 2).
    kernel_new = K.switch(
        is_odd,
        lambda: kernel[start - 1:start, start - 1:start],
        lambda: kernel[start - 1:start, start - 1:start] + K.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # The loop counter starts at 0 for odd sizes and at 1 for even sizes.
    index = K.switch(
        is_odd,
        lambda: K.constant(0, dtype='int32'),
        lambda: K.constant(1, dtype='int32'))

    def while_condition(index, *args):
      del args  # Only the counter decides when to stop.
      return K.less(index, start)

    def body_fn(i, array):
      # Surround the current block with one ring filled with the diagonal
      # element kernel[start + i, start + i].
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    # NOTE(review): tracing a 4x4 kernel by hand, the seed block is built from
    # kernel[1, 1] and the single padding ring uses kernel[3, 3], which looks
    # like the reverse of the example in the class docstring -- confirm which
    # of the two is intended.
    # The block grows by two in each dimension per iteration, hence the fully
    # unknown 2-D shape invariant.
    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new


# Aliases.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases.
maxnorm = max_norm
nonneg = non_neg
unitnorm = unit_norm


@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes `constraint` by delegating to `serialize_keras_object`."""
  return serialize_keras_object(constraint)


@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Builds a constraint from `config` via `deserialize_keras_object`.

  Arguments:
    config: the serialized constraint, forwarded unchanged.
    custom_objects: optional value forwarded as `custom_objects`.

  Returns:
    Whatever `deserialize_keras_object` returns when given this module's
    globals as `module_objects`.
  """
  return deserialize_keras_object(
      config,
      module_objects=globals(),
      custom_objects=custom_objects,
      printable_module_name='constraint')


@keras_export('keras.constraints.get')
def get(identifier):
  """Resolves `identifier` to a constraint.

  Arguments:
    identifier: `None`, a config dict, a string name, or a callable.

  Returns:
    `None` if `identifier` is `None`; the result of `deserialize` for a dict
    or for a string (wrapped as a class name with an empty config); the
    identifier itself if it is callable.

  Raises:
    ValueError: if `identifier` is none of the above.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    config = {'class_name': str(identifier), 'config': {}}
    return deserialize(config)
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))