• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Constraints: functions that impose constraints on weight values.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import six
23
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.keras import backend as K
26from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
27from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.util.tf_export import keras_export
32
33
@keras_export('keras.constraints.Constraint')
class Constraint(object):
  """Base class for weight constraints.

  A constraint is a callable applied to a weight tensor that returns the
  constrained tensor. This base implementation is the identity; subclasses
  override `__call__` (and `get_config` when they take constructor
  arguments).
  """

  def __call__(self, w):
    """Returns the weight tensor `w` unchanged."""
    return w

  def get_config(self):
    """Returns the constructor arguments as a dict (none for the base class)."""
    return dict()
42
43
@keras_export('keras.constraints.MaxNorm', 'keras.constraints.max_norm')
class MaxNorm(Constraint):
  """MaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have a norm less than or equal to a desired value.

  Arguments:
      max_value: the maximum norm for the incoming weights.
      axis: integer or list of integers, axis (or axes) along which to
          calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.

  """

  def __init__(self, max_value=2, axis=0):
    self.max_value = max_value
    self.axis = axis

  def __call__(self, w):
    # L2 norm along `self.axis`; keepdims=True keeps it broadcastable to `w`.
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    desired = K.clip(norms, 0, self.max_value)
    # Rescale each slice so its norm becomes `desired`; epsilon avoids a
    # division by zero for all-zero slices.
    return w * (desired / (K.epsilon() + norms))

  def get_config(self):
    return {'max_value': self.max_value, 'axis': self.axis}
79
80
@keras_export('keras.constraints.NonNeg', 'keras.constraints.non_neg')
class NonNeg(Constraint):
  """Constrains the weights to be non-negative.

  Entries of the weight tensor that are negative are multiplied by zero;
  entries that are greater than or equal to zero are left unchanged.
  """

  def __call__(self, w):
    # Cast the boolean mask to the weight's own dtype rather than the global
    # `K.floatx()`, so that the product does not fail with a dtype mismatch
    # for weights whose dtype differs from floatx (e.g. float16 or float64).
    return w * math_ops.cast(math_ops.greater_equal(w, 0.), w.dtype)
88
89
@keras_export('keras.constraints.UnitNorm', 'keras.constraints.unit_norm')
class UnitNorm(Constraint):
  """Constrains the weights incident to each hidden unit to have unit norm.

  Arguments:
      axis: integer or list of integers, axis (or axes) along which to
          calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, axis=0):
    self.axis = axis

  def __call__(self, w):
    # Squared L2 norm per slice, kept broadcastable against `w`.
    sum_of_squares = math_ops.reduce_sum(
        math_ops.square(w), axis=self.axis, keepdims=True)
    # Epsilon keeps all-zero slices from dividing by zero.
    denominator = K.epsilon() + K.sqrt(sum_of_squares)
    return w / denominator

  def get_config(self):
    return dict(axis=self.axis)
119
120
@keras_export('keras.constraints.MinMaxNorm', 'keras.constraints.min_max_norm')
class MinMaxNorm(Constraint):
  """MinMaxNorm weight constraint.

  Constrains the weights incident to each hidden unit
  to have the norm between a lower bound and an upper bound.

  Arguments:
      min_value: the minimum norm for the incoming weights.
      max_value: the maximum norm for the incoming weights.
      rate: rate for enforcing the constraint: weights will be
          rescaled to yield
          `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.
          Effectively, this means that rate=1.0 stands for strict
          enforcement of the constraint, while rate<1.0 means that
          weights will be rescaled at each step to slowly move
          towards a value inside the desired interval.
      axis: integer or list of integers, axis (or axes) along which to
          calculate weight norms.
          For instance, in a `Dense` layer the weight matrix
          has shape `(input_dim, output_dim)`,
          set `axis` to `0` to constrain each weight vector
          of length `(input_dim,)`.
          In a `Conv2D` layer with `data_format="channels_last"`,
          the weight tensor has shape
          `(rows, cols, input_depth, output_depth)`,
          set `axis` to `[0, 1, 2]`
          to constrain the weights of each filter tensor of size
          `(rows, cols, input_depth)`.
  """

  def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):
    self.min_value = min_value
    self.max_value = max_value
    self.rate = rate
    self.axis = axis

  def __call__(self, w):
    norms = K.sqrt(
        math_ops.reduce_sum(math_ops.square(w), axis=self.axis, keepdims=True))
    clipped = K.clip(norms, self.min_value, self.max_value)
    # Blend between the clipped norm and the current norm according to `rate`.
    target = self.rate * clipped + (1 - self.rate) * norms
    # Epsilon keeps all-zero slices from dividing by zero.
    scale = target / (K.epsilon() + norms)
    return w * scale

  def get_config(self):
    return dict(
        min_value=self.min_value,
        max_value=self.max_value,
        rate=self.rate,
        axis=self.axis)
172
173
@keras_export('keras.constraints.RadialConstraint',
              'keras.constraints.radial_constraint')
class RadialConstraint(Constraint):
  """Constrains `Conv2D` kernel weights to be the same for each radius.

  Each square spatial slice of the kernel is split into a center (a single
  entry for odd sizes, a 2-by-2 block for even sizes) and concentric square
  rings around it. Every entry of the center and of each ring is set to one
  value read from inside that same region. As a result, applying the
  constraint a second time leaves the kernel unchanged.

  For example, the desired output for the following 4-by-4 kernel:

  ```
      kernel = [[v_00, v_01, v_02, v_03],
                [v_10, v_11, v_12, v_13],
                [v_20, v_21, v_22, v_23],
                [v_30, v_31, v_32, v_33]]
  ```

  is this:

  ```
      kernel = [[v_33, v_33, v_33, v_33],
                [v_33, v_11, v_11, v_33],
                [v_33, v_11, v_11, v_33],
                [v_33, v_33, v_33, v_33]]
  ```

  and for a 3-by-3 kernel the center keeps `v_11` while the outer ring is set
  to `v_22`.

  This constraint can be applied to any `Conv2D` layer version, including
  `Conv2DTranspose` and `SeparableConv2D`, and with either `"channels_last"` or
  `"channels_first"` data format. The method assumes the weight tensor is of
  shape `(rows, cols, input_depth, output_depth)` with `rows == cols`.
  """

  def __call__(self, w):
    w_shape = w.shape
    if w_shape.rank is None or w_shape.rank != 4:
      raise ValueError(
          'The weight tensor must be of rank 4, but is of shape: %s' % w_shape)

    height, width, channels, kernels = w_shape
    if height != width:
      raise ValueError(
          'The kernel must be square, but is of shape: %s' % w_shape)

    w = K.reshape(w, (height, width, channels * kernels))
    # Put the flattened channel axis first so `map_fn` visits one
    # (height, width) slice at a time; the inverse transpose restores it.
    # TODO(cpeter): Switch map_fn for a faster tf.vectorized_map once K.switch
    # is supported.
    w = K.map_fn(self._kernel_constraint,
                 array_ops.transpose(w, perm=[2, 0, 1]))
    w = array_ops.transpose(w, perm=[1, 2, 0])
    return K.reshape(w, (height, width, channels, kernels))

  def _kernel_constraint(self, kernel):
    """Radially constrains one 2-D kernel slice of shape (size, size).

    The center is built first: `kernel[half, half]` as a 1x1 block for odd
    `size`, or `kernel[half - 1, half - 1]` tiled to 2x2 for even `size`.
    Rings are then padded on one at a time; ring `r` (counting outwards from
    1) is filled with `kernel[half + r, half + r]`, an entry of that same ring.

    Args:
      kernel: a square 2-D tensor.

    Returns:
      A tensor of the same size whose center and rings are each constant.
    """
    padding = K.constant([[1, 1], [1, 1]], dtype='int32')

    size = K.shape(kernel)[0]
    half = size // 2
    remainder = math_ops.floormod(size, 2)
    is_odd = K.cast(remainder, 'bool')

    kernel_new = K.switch(
        is_odd,
        lambda: kernel[half:half + 1, half:half + 1],
        lambda: array_ops.tile(kernel[half - 1:half, half - 1:half], [2, 2]))

    # Odd sizes have `half` rings around the 1x1 center; even sizes have
    # `half - 1` rings around the 2x2 center. Ring numbers start at 1.
    index = K.constant(1, dtype='int32')
    stop = half + remainder
    while_condition = lambda i, *args: K.less(i, stop)

    def body_fn(i, array):
      return i + 1, array_ops.pad(
          array,
          padding,
          constant_values=kernel[half + i, half + i])

    _, kernel_new = control_flow_ops.while_loop(
        while_condition,
        body_fn,
        [index, kernel_new],
        shape_invariants=[index.get_shape(),
                          tensor_shape.TensorShape([None, None])])
    return kernel_new
250
251
# Aliases.
# snake_case names bound to the constraint classes defined above.
max_norm, non_neg, unit_norm, min_max_norm, radial_constraint = (
    MaxNorm, NonNeg, UnitNorm, MinMaxNorm, RadialConstraint)

# Legacy aliases: older spellings without underscores.
maxnorm, nonneg, unitnorm = max_norm, non_neg, unit_norm
264
265
@keras_export('keras.constraints.serialize')
def serialize(constraint):
  """Serializes a constraint with `serialize_keras_object`.

  Arguments:
      constraint: the constraint object to serialize.

  Returns:
      The result of `serialize_keras_object(constraint)` (presumably a config
      that `deserialize` can consume — confirm against that helper).
  """
  serialized = serialize_keras_object(constraint)
  return serialized
269
270
@keras_export('keras.constraints.deserialize')
def deserialize(config, custom_objects=None):
  """Reconstructs a constraint from its serialized config.

  Arguments:
      config: the serialized constraint.
      custom_objects: optional mapping of extra names made available to
          `deserialize_keras_object` alongside this module's globals.

  Returns:
      Whatever `deserialize_keras_object` builds from `config`.
  """
  return deserialize_keras_object(
      config,
      custom_objects=custom_objects,
      module_objects=globals(),
      printable_module_name='constraint')
278
279
@keras_export('keras.constraints.get')
def get(identifier):
  """Retrieves a constraint from an identifier.

  Arguments:
      identifier: `None`, a config dict, a string class name, or a callable.

  Returns:
      `None` for `None`; the deserialized constraint for a dict, or for a
      string (treated as a class name with an empty config); otherwise the
      callable itself, unchanged.

  Raises:
      ValueError: if `identifier` is none of the accepted kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, six.string_types):
    return deserialize({'class_name': str(identifier), 'config': {}})
  if callable(identifier):
    return identifier
  raise ValueError('Could not interpret constraint identifier: ' +
                   str(identifier))
294