• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16# pylint: disable=g-classes-have-attributes
17"""Constraints: functions that impose constraints on weight values."""
18
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.keras import backend
21from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
22from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.util.tf_export import keras_export
27from tensorflow.tools.docs import doc_controls
28
29
@keras_export('keras.constraints.Constraint')
class Constraint:
  """Base class for weight constraints.

  A `Constraint` behaves like a stateless function: it receives a weight
  parameter and returns a projected version of it (for example, a normalized
  or clipped copy). Subclasses implement the projection by overriding
  `__call__`. Constraints are attached to Keras layers through arguments such
  as `kernel_constraint` or `bias_constraint`.

  Here's a simple example of a non-negative weight constraint:

  >>> class NonNegative(tf.keras.constraints.Constraint):
  ...
  ...  def __call__(self, w):
  ...    return w * tf.cast(tf.math.greater_equal(w, 0.), w.dtype)

  >>> weight = tf.constant((-1.0, 1.0))
  >>> NonNegative()(weight)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.,  1.], dtype=float32)>

  >>> tf.keras.layers.Dense(4, kernel_constraint=NonNegative())
  """

  def __call__(self, w):
    """Applies the constraint to the input weight variable.

    The base implementation is the identity: `w` is handed back unchanged.
    Subclasses override this method to implement their own projection.

    Args:
      w: Input weight variable.

    Returns:
      The projected variable; for this base class, `w` itself.
    """
    projected = w
    return projected

  def get_config(self):
    """Returns a Python dict of the object config.

    The config is a JSON-serializable dictionary from which an equivalent
    constraint object can be re-created. The base class has no parameters,
    so its config is empty.

    Returns:
      Python dict containing the configuration of the constraint object.
    """
    return dict()
80
81
@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm does
  not exceed a desired value. Weight vectors whose norm is already within the
  limit are left (up to `backend.epsilon()` smoothing) unchanged.

  Also available via the shortcut function `tf.keras.constraints.max_norm`.

  Args:
    max_value: the maximum norm value for the incoming weights.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value, self.axis = max_value, axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    # Euclidean norm along `self.axis`, kept broadcastable against `w`.
    sum_of_squares = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = backend.sqrt(sum_of_squares)
    # Target norm: the current norm capped at `max_value`.
    target = backend.clip(norms, 0, self.max_value)
    # epsilon avoids dividing by zero for all-zero weight vectors.
    scale = target / (backend.epsilon() + norms)
    return w * scale

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(max_value=self.max_value, axis=self.axis)
121
122
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Negative entries are set to zero; non-negative entries are kept as is.

  Also available via the shortcut function `tf.keras.constraints.non_neg`.
  """

  def __call__(self, w):
    # Multiply by a 0/1 mask of the non-negative entries. The mask is cast to
    # the weight's own dtype rather than `backend.floatx()`, so the product is
    # well-typed for weights whose dtype differs from the global float type
    # (e.g. float16 or float64 layers).
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), w.dtype)
132
133
@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Each weight vector along `axis` is divided by its Euclidean norm (plus
  `backend.epsilon()` to guard against division by zero).

  Also available via the shortcut function `tf.keras.constraints.unit_norm`.

  Args:
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    sum_of_squares = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    denominator = backend.epsilon() + backend.sqrt(sum_of_squares)
    return w / denominator

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(axis=self.axis)
167
168
@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Rescales the weights incident to each hidden unit so that their norm lies
  between a lower bound and an upper bound (fully when `rate=1.0`, partially
  otherwise).

  Also available via the shortcut function `tf.keras.constraints.min_max_norm`.

  Args:
    min_value: the minimum norm for the incoming weights.
    max_value: the maximum norm for the incoming weights.
    rate: rate for enforcing the constraint: weights will be
      rescaled to yield
      `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
      Effectively, this means that rate=1.0 stands for strict
      enforcement of the constraint, while rate<1.0 means that
      weights will be rescaled at each step to slowly move
      towards a value inside the desired interval.
    axis: integer, axis along which to calculate weight norms.
      For instance, in a `Dense` layer the weight matrix
      has shape `(input_dim, output_dim)`,
      set `axis` to `0` to constrain each weight vector
      of length `(input_dim,)`.
      In a `Conv2D` layer with `data_format="channels_last"`,
      the weight tensor has shape
      `(rows, cols, input_depth, output_depth)`,
      set `axis` to `[0, 1, 2]`
      to constrain the weights of each filter tensor of size
      `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value, self.max_value = min_value, max_value
    self.rate, self.axis = rate, axis

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    sum_of_squares = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    norms = backend.sqrt(sum_of_squares)
    clipped = backend.clip(norms, self.min_value, self.max_value)
    # Blend between the clipped norm and the current norm according to `rate`.
    target = self.rate * clipped + (1 - self.rate) * norms
    scale = target / (backend.epsilon() + norms)
    return w * scale

  @doc_controls.do_not_generate_docs
  def get_config(self):
    return dict(
        min_value=self.min_value,
        max_value=self.max_value,
        rate=self.rate,
        axis=self.axis)
224
225
@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Also available via the shortcut function
  `tf.keras.constraints.radial_constraint`.

  Each 2-D `(rows, cols)` slice of the kernel is split into concentric square
  rings, and every element of a ring is replaced by one value read from that
  same ring on the slice's main diagonal. For example, the desired output for
  the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
      kernel = [[v_33, v_33, v_33, v_33],
                [v_33, v_11, v_11, v_33],
                [v_33, v_11, v_11, v_33],
                [v_33, v_33, v_33, v_33]]
  ```

  For a 3-by-3 kernel the center keeps `v_11` and the outer ring takes
  `v_22`. Because every ring's value is taken from inside that ring, applying
  the constraint a second time leaves the result unchanged.

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)` with `rows == cols`.
  """

  @doc_controls.do_not_generate_docs
  def __call__(self, w):
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    # `_kernel_constraint` sizes its output from the first spatial dimension
    # only, so a non-square kernel could never be reshaped back. Fail early
    # with a clear message when both sizes are statically known.
    static_height = tensor_shape.dimension_value(height)
    static_width = tensor_shape.dimension_value(width)
    if (static_height is not None and static_width is not None and
        static_height != static_width):
      raise ValueError(
          'The weight tensor must have equal rows and cols, but is of '
          'shape: %s' % w_shape)

    w = backend.reshape(w, (height, width, channels * kernels))
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once
    # backend.switch is supported.
    # Move the flattened channel axis to the front so that map_fn visits one
    # (height, width) slice at a time, then move it back afterwards.
    w = backend.map_fn(
        self._kernel_constraint,
        backend.stack(array_ops.unstack(w, axis=-1), axis=0))
    return backend.reshape(backend.stack(array_ops.unstack(w, axis=0), axis=-1),
                           (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains a single square 2-D kernel slice.

    The result is grown from the center outwards. It starts from the center
    element (odd sizes) or a 2x2 center block (even sizes). Each loop
    iteration then surrounds it with one more ring via `array_ops.pad`. With
    `start = size // 2`, ring `i >= 1` is filled with
    `kernel[start + i, start + i]`. The odd-size center is
    `kernel[start, start]` and the even-size center block is filled with
    `kernel[start - 1, start - 1]`.

    Args:
      kernel: 2-D tensor of shape `(size, size)`; only its first dimension is
        read to determine `size`.

    Returns:
      A `(size, size)` tensor that is constant on every ring.
    """
    padding = backend.constant([[1, 1], [1, 1]], dtype='int32')

    kernel_shape = backend.shape(kernel)[0]
    start = backend.cast(kernel_shape / 2, 'int32')
    is_odd = backend.cast(math_ops.floormod(kernel_shape, 2), 'bool')

    # Ring 0. For odd sizes this must be the true center element so that the
    # value chosen for every ring lies inside that ring; this also keeps 1x1
    # kernels non-empty.
    kernel_new = backend.switch(
        is_odd,
        lambda: kernel[start:start + 1, start:start + 1],
        lambda: kernel[start - 1:start, start - 1:start] + backend.zeros(  # pylint: disable=g-long-lambda
            (2, 2), dtype=kernel.dtype))
    # There are ceil(size / 2) rings in total; ring 0 is already built.
    num_rings = (kernel_shape + 1) // 2
    index = backend.constant(1, dtype='int32')
    while_condition = lambda index, *args: backend.less(index, num_rings)

    def body_fn(i, array):
      # Surround the current result with ring `i`, filled with the diagonal
      # element that belongs to that ring.
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[start + i, start + i])

    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new
306
307
# Aliases: snake_case shortcuts for the constraint classes defined above.

max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
radial_constraint = RadialConstraint

# Legacy aliases. `deserialize` resolves names through this module's globals,
# so keeping these lets configs that use the old names still load.
maxnorm, nonneg, unitnorm = max_norm, non_neg, unit_norm
320
321
@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint with `serialize_keras_object`.

  Args:
    constraint: The constraint object to serialize.

  Returns:
    Whatever `serialize_keras_object` produces for `constraint`.
  """
  serialized = serialize_keras_object(constraint)
  return serialized
325
326
@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Rebuilds a constraint from its serialized config.

  Names in `config` are looked up among this module's globals (so the classes
  and aliases defined in this module are found) and in `custom_objects`.

  Args:
    config: Serialized constraint, e.g. as produced by `serialize`.
    custom_objects: Optional mapping of names to user-defined classes or
      functions to consider during deserialization.

  Returns:
    The result of `deserialize_keras_object` for `config`.
  """
  module_objects = globals()
  return deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='constraint')
334
335
@keras_export('keras.constraints.get')
def get(identifier):
  """Retrieves a constraint from a config dict, a name, or a callable.

  Args:
    identifier: `None`, a serialized config dict, the name of a constraint
      (class name or alias) as a string, or a callable.

  Returns:
    `None` for `None`; a deserialized constraint for a dict or string; the
    callable itself otherwise.

  Raises:
    ValueError: If `identifier` is none of the supported kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    # A bare name is treated as a parameterless config for that class.
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))
350