# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in activation functions."""

from tensorflow.python.keras import backend
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import keras_export

# b/123041942
# In TF 2.x, if `tf.nn.softmax` is used as an activation function in Keras
# layers, it gets serialized as 'softmax_v2' instead of 'softmax', because the
# internal method name is returned during serialization. This results in errors
# when exporting and loading models, as Keras can't find any activation
# function named `softmax_v2`.
# This dict maps the activation function name from its v2 version to its
# canonical name.
_TF_ACTIVATIONS_V2 = {
    'softmax_v2': 'softmax',
}
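
# For example (an illustrative note based on the comment above): in TF 2.x,
# `tf.nn.softmax.__name__` is 'softmax_v2', so `serialize(tf.nn.softmax)` looks
# the name up here and returns the canonical 'softmax' instead.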


@keras_export('keras.activations.softmax')
@dispatch.add_dispatch_support
def softmax(x, axis=-1):
  """Softmax converts a vector of values to a probability distribution.

  The elements of the output vector are in range (0, 1) and sum to 1.

  Each vector is handled independently. The `axis` argument sets which axis
  of the input the function is applied along.

  Softmax is often used as the activation for the last
  layer of a classification network because the result can be interpreted as
  a probability distribution.

  The softmax of each vector x is computed as
  `exp(x) / tf.reduce_sum(exp(x))`.

  The input values are the log-odds of the resulting probabilities.

  Args:
    x: Input tensor.
    axis: Integer, or tuple of integers, along which the softmax
      normalization is applied.

  Returns:
    Tensor, output of softmax transformation (all values are non-negative
      and sum to 1).

  Examples:

  **Example 1: standalone usage**

  >>> inputs = tf.random.normal(shape=(32, 10))
  >>> outputs = tf.keras.activations.softmax(inputs)
  >>> tf.reduce_sum(outputs[0, :])  # Each sample in the batch now sums to 1
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0000001>

  **Example 2: usage in a `Dense` layer**

  >>> layer = tf.keras.layers.Dense(32, activation=tf.keras.activations.softmax)
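
  **Example 3: tuple of axes (an added sketch; the all-ones input and the
  `axis` tuple are arbitrary choices for illustration)**

  >>> inputs = tf.ones((1, 2, 2))
  >>> outputs = tf.keras.activations.softmax(inputs, axis=(1, 2))
  >>> tf.reduce_sum(outputs)  # Normalizes jointly over both axes
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>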
78  """
79  if x.shape.rank > 1:
80    if isinstance(axis, int):
81      output = nn.softmax(x, axis=axis)
82    else:
83      # nn.softmax does not support tuple axis.
84      e = math_ops.exp(x - math_ops.reduce_max(x, axis=axis, keepdims=True))
85      s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
86      output = e / s
87  else:
88    raise ValueError('Cannot apply softmax to a tensor that is 1D. '
89                     'Received input: %s' % (x,))
90
91  # Cache the logits to use for crossentropy loss.
92  output._keras_logits = x  # pylint: disable=protected-access
93  return output


@keras_export('keras.activations.elu')
@dispatch.add_dispatch_support
def elu(x, alpha=1.0):
  """Exponential Linear Unit.

  The exponential linear unit (ELU) with `alpha > 0` is:
  `x` if `x > 0` and
  `alpha * (exp(x) - 1)` if `x < 0`.
  The ELU hyperparameter `alpha` controls the value to which an
  ELU saturates for negative net inputs. ELUs diminish the
  vanishing gradient effect.

  ELUs have negative values, which pushes the mean of the activations
  closer to zero.
  Mean activations that are closer to zero enable faster learning as they
  bring the gradient closer to the natural gradient.
  ELUs saturate to a negative value when the argument gets smaller.
  Saturation means a small derivative which decreases the variation
  and the information that is propagated to the next layer.

  Example Usage:

  >>> import tensorflow as tf
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='elu',
  ...          input_shape=(28, 28, 1)))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))
  >>> model.add(tf.keras.layers.MaxPooling2D((2, 2)))
  >>> model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='elu'))
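
  As a standalone illustration (an added check; the input value is arbitrary),
  a strongly negative input saturates at `-alpha`:

  >>> tf.keras.activations.elu(tf.constant(-1e4, dtype=tf.float32)).numpy()
  -1.0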

  Args:
      x: Input tensor.
      alpha: A scalar, slope of negative section. `alpha` controls the value to
        which an ELU saturates for negative net inputs.

  Returns:
      The exponential linear unit (ELU) activation function: `x` if `x > 0` and
      `alpha * (exp(x) - 1)` if `x < 0`.

  Reference:
      [Fast and Accurate Deep Network Learning by Exponential Linear Units
      (ELUs) (Clevert et al., 2016)](https://arxiv.org/abs/1511.07289)
  """
  return backend.elu(x, alpha)


@keras_export('keras.activations.selu')
@dispatch.add_dispatch_support
def selu(x):
  """Scaled Exponential Linear Unit (SELU).

  The Scaled Exponential Linear Unit (SELU) activation function is defined as:

  - `if x > 0: return scale * x`
  - `if x < 0: return scale * alpha * (exp(x) - 1)`

  where `alpha` and `scale` are pre-defined constants
  (`alpha=1.67326324` and `scale=1.05070098`).

  Basically, the SELU activation function multiplies the output of the
  `tf.keras.activations.elu` function by `scale` (> 1) to ensure a slope
  larger than one for positive inputs.

  The values of `alpha` and `scale` are
  chosen so that the mean and variance of the inputs are preserved
  between two consecutive layers as long as the weights are initialized
  correctly (see the `tf.keras.initializers.LecunNormal` initializer)
  and the number of input units is "large enough"
  (see the reference paper for more information).

  Example Usage:

  >>> num_classes = 10  # 10-class problem
  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Dense(64, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(32, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(16, kernel_initializer='lecun_normal',
  ...                                 activation='selu'))
  >>> model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
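
  An added sanity check (illustrative only; the input values are arbitrary):
  `selu(x)` matches `scale * elu(x, alpha)` with the constants above.

  >>> x = tf.constant([-1.0, 0.0, 1.0])
  >>> manual = 1.05070098 * tf.keras.activations.elu(x, alpha=1.67326324)
  >>> diff = tf.abs(tf.keras.activations.selu(x) - manual)
  >>> bool(tf.reduce_all(diff < 1e-6))
  True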

  Args:
      x: A tensor or variable to compute the activation function for.

  Returns:
      The scaled exponential linear unit activation: `scale * elu(x, alpha)`.

  Notes:
      - To be used together with the
        `tf.keras.initializers.LecunNormal` initializer.
      - To be used together with the dropout variant
        `tf.keras.layers.AlphaDropout` (not regular dropout).

  References:
      - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
  """
  return nn.selu(x)


@keras_export('keras.activations.softplus')
@dispatch.add_dispatch_support
def softplus(x):
  """Softplus activation function, `softplus(x) = log(exp(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.softplus(a)
  >>> b.numpy()
  array([2.0611537e-09, 3.1326166e-01, 6.9314718e-01, 1.3132616e+00,
           2.0000000e+01], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      The softplus activation: `log(exp(x) + 1)`.
  """
  return math_ops.softplus(x)


@keras_export('keras.activations.softsign')
@dispatch.add_dispatch_support
def softsign(x):
  """Softsign activation function, `softsign(x) = x / (abs(x) + 1)`.

  Example Usage:

  >>> a = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float32)
  >>> b = tf.keras.activations.softsign(a)
  >>> b.numpy()
  array([-0.5,  0. ,  0.5], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      The softsign activation: `x / (abs(x) + 1)`.
  """
  return nn.softsign(x)


@keras_export('keras.activations.swish')
@dispatch.add_dispatch_support
def swish(x):
  """Swish activation function, `swish(x) = x * sigmoid(x)`.

  The swish activation, `x * sigmoid(x)`, is a smooth, non-monotonic function
  that consistently matches or outperforms ReLU on deep networks. It is
  unbounded above and bounded below.

  Example Usage:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.swish(a)
  >>> b.numpy()
  array([-4.1223075e-08, -2.6894143e-01,  0.0000000e+00,  7.3105860e-01,
            2.0000000e+01], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      The swish activation applied to `x` (see reference paper for details).

  Reference:
    - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)
  """
  return nn.swish(x)


@keras_export('keras.activations.relu')
@dispatch.add_dispatch_support
def relu(x, alpha=0., max_value=None, threshold=0):
  """Applies the rectified linear unit activation function.

  With default values, this returns the standard ReLU activation:
  `max(x, 0)`, the element-wise maximum of 0 and the input tensor.

  Modifying default parameters allows you to use non-zero thresholds,
  change the max value of the activation,
  and to use a non-zero multiple of the input for values below the threshold.

  For example:

  >>> foo = tf.constant([-10, -5, 0.0, 5, 10], dtype=tf.float32)
  >>> tf.keras.activations.relu(foo).numpy()
  array([ 0.,  0.,  0.,  5., 10.], dtype=float32)
  >>> tf.keras.activations.relu(foo, alpha=0.5).numpy()
  array([-5. , -2.5,  0. ,  5. , 10. ], dtype=float32)
  >>> tf.keras.activations.relu(foo, max_value=5).numpy()
  array([0., 0., 0., 5., 5.], dtype=float32)
  >>> tf.keras.activations.relu(foo, threshold=5).numpy()
  array([-0., -0.,  0.,  0., 10.], dtype=float32)

  Args:
      x: Input `tensor` or `variable`.
      alpha: A `float` that governs the slope for values lower than the
        threshold.
      max_value: A `float` that sets the saturation threshold (the largest value
        the function will return).
      threshold: A `float` giving the threshold value of the activation function
        below which values will be damped or set to zero.

  Returns:
      A `Tensor` representing the input tensor,
      transformed by the relu activation function.
      Tensor will be of the same shape and dtype of input `x`.
  """
  return backend.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)


@keras_export('keras.activations.gelu', v1=[])
@dispatch.add_dispatch_support
def gelu(x, approximate=False):
  """Applies the Gaussian error linear unit (GELU) activation function.

  Gaussian error linear unit (GELU) computes
  `x * P(X <= x)`, where `P(X) ~ N(0, 1)`.
  The GELU nonlinearity weights inputs by their value, rather than gating
  inputs by their sign as in ReLU.

  For example:

  >>> x = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> y = tf.keras.activations.gelu(x)
  >>> y.numpy()
  array([-0.00404951, -0.15865529,  0.        ,  0.8413447 ,  2.9959507 ],
      dtype=float32)
  >>> y = tf.keras.activations.gelu(x, approximate=True)
  >>> y.numpy()
  array([-0.00363752, -0.15880796,  0.        ,  0.841192  ,  2.9963627 ],
      dtype=float32)

  Args:
      x: Input tensor.
      approximate: A `bool`, whether to enable approximation.

  Returns:
      The Gaussian error linear activation:
      `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
      if `approximate` is `True` or
      `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,
      where `P(X) ~ N(0, 1)`,
      if `approximate` is `False`.

  Reference:
    - [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
  """
  return nn.gelu(x, approximate)


@keras_export('keras.activations.tanh')
@dispatch.add_dispatch_support
def tanh(x):
  """Hyperbolic tangent activation function.

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.tanh(a)
  >>> b.numpy()
  array([-0.9950547, -0.7615942,  0.,  0.7615942,  0.9950547], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      Tensor of same shape and dtype of input `x`, with tanh activation:
      `tanh(x) = sinh(x)/cosh(x) = ((exp(x) - exp(-x))/(exp(x) + exp(-x)))`.
  """
  return nn.tanh(x)


@keras_export('keras.activations.sigmoid')
@dispatch.add_dispatch_support
def sigmoid(x):
  """Sigmoid activation function, `sigmoid(x) = 1 / (1 + exp(-x))`.

  Applies the sigmoid activation function. For small values (<-5),
  `sigmoid` returns a value close to zero, and for large values (>5)
  the result of the function gets close to 1.

  Sigmoid is equivalent to a 2-element Softmax, where the second element is
  assumed to be zero. The sigmoid function always returns a value between
  0 and 1.

  For example:

  >>> a = tf.constant([-20, -1.0, 0.0, 1.0, 20], dtype=tf.float32)
  >>> b = tf.keras.activations.sigmoid(a)
  >>> b.numpy()
  array([2.0611537e-09, 2.6894143e-01, 5.0000000e-01, 7.3105860e-01,
           1.0000000e+00], dtype=float32)
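
  An added consistency check (illustrative only; the input value is arbitrary):
  `sigmoid(x)` agrees with a 2-element softmax over `[x, 0.]`.

  >>> x = tf.constant(1.0)
  >>> two_class = tf.keras.activations.softmax(tf.constant([[1.0, 0.0]]))
  >>> bool(tf.abs(two_class[0, 0] - tf.keras.activations.sigmoid(x)) < 1e-6)
  True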

  Args:
      x: Input tensor.

  Returns:
      Tensor with the sigmoid activation: `1 / (1 + exp(-x))`.
  """
  output = nn.sigmoid(x)
  # Cache the logits to use for crossentropy loss.
  output._keras_logits = x  # pylint: disable=protected-access
  return output


@keras_export('keras.activations.exponential')
@dispatch.add_dispatch_support
def exponential(x):
  """Exponential activation function.

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.exponential(a)
  >>> b.numpy()
  array([0.04978707,  0.36787945,  1.,  2.7182817 , 20.085537], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      Tensor with exponential activation: `exp(x)`.
  """
  return math_ops.exp(x)


@keras_export('keras.activations.hard_sigmoid')
@dispatch.add_dispatch_support
def hard_sigmoid(x):
  """Hard sigmoid activation function.

  A faster, piecewise linear approximation of the sigmoid activation.
  Ref: https://en.wikipedia.org/wiki/Hard_sigmoid

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.hard_sigmoid(a)
  >>> b.numpy()
  array([0. , 0.3, 0.5, 0.7, 1. ], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
    The hard sigmoid activation, defined as:

      - `if x < -2.5: return 0`
      - `if x > 2.5: return 1`
      - `if -2.5 <= x <= 2.5: return 0.2 * x + 0.5`
  """
  return backend.hard_sigmoid(x)


@keras_export('keras.activations.linear')
@dispatch.add_dispatch_support
def linear(x):
  """Linear activation function (pass-through).

  For example:

  >>> a = tf.constant([-3.0, -1.0, 0.0, 1.0, 3.0], dtype=tf.float32)
  >>> b = tf.keras.activations.linear(a)
  >>> b.numpy()
  array([-3., -1.,  0.,  1.,  3.], dtype=float32)

  Args:
      x: Input tensor.

  Returns:
      The input, unmodified.
  """
  return x


@keras_export('keras.activations.serialize')
@dispatch.add_dispatch_support
def serialize(activation):
  """Returns the string identifier of an activation function.

  Args:
      activation: Function object.

  Returns:
      String denoting the name attribute of the input function.

  For example:

  >>> tf.keras.activations.serialize(tf.keras.activations.tanh)
  'tanh'
  >>> tf.keras.activations.serialize(tf.keras.activations.sigmoid)
  'sigmoid'
  >>> tf.keras.activations.serialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: ('Cannot serialize', 'abcd')

  Raises:
      ValueError: The input function is not a valid one.
  """
  if (hasattr(activation, '__name__') and
      activation.__name__ in _TF_ACTIVATIONS_V2):
    return _TF_ACTIVATIONS_V2[activation.__name__]
  return serialize_keras_object(activation)


# Add additional globals so that deserialize can find these common activation
# functions
leaky_relu = nn.leaky_relu
log_softmax = nn.log_softmax
relu6 = nn.relu6
silu = nn.swish
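
# For example (an illustrative note): with the aliases above,
# `tf.keras.activations.get('silu')`, `get('leaky_relu')`, `get('relu6')` and
# `get('log_softmax')` all resolve to the corresponding `tf.nn` ops through
# `deserialize` below.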


@keras_export('keras.activations.deserialize')
@dispatch.add_dispatch_support
def deserialize(name, custom_objects=None):
  """Returns activation function given a string identifier.

  Args:
    name: The name of the activation function.
    custom_objects: Optional `{function_name: function_obj}`
      dictionary listing user-provided activation functions.

  Returns:
      Corresponding activation function.

  For example:

  >>> tf.keras.activations.deserialize('linear')
   <function linear at 0x1239596a8>
  >>> tf.keras.activations.deserialize('sigmoid')
   <function sigmoid at 0x123959510>
  >>> tf.keras.activations.deserialize('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd
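
  A user-defined activation can also be looked up through `custom_objects`
  (an added sketch; `my_act` is a made-up function, not part of Keras):

  >>> def my_act(x):
  ...   return x ** 2
  >>> fn = tf.keras.activations.deserialize(
  ...     'my_act', custom_objects={'my_act': my_act})
  >>> fn is my_act
  True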

  Raises:
      ValueError: `Unknown activation function` if the input string does not
      denote any defined TensorFlow activation function.
  """
  globs = globals()

  # Only replace missing activations.
  advanced_activations_globs = advanced_activations.get_globals()
  for key, val in advanced_activations_globs.items():
    if key not in globs:
      globs[key] = val

  return deserialize_keras_object(
      name,
      module_objects=globs,
      custom_objects=custom_objects,
      printable_module_name='activation function')


@keras_export('keras.activations.get')
@dispatch.add_dispatch_support
def get(identifier):
  """Returns the activation function matching the given identifier.

  Args:
      identifier: Function or string identifier (`None` maps to `linear`).

  Returns:
      Function corresponding to the input string or input function.

  For example:

  >>> tf.keras.activations.get('softmax')
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(tf.keras.activations.softmax)
   <function softmax at 0x1222a3d90>
  >>> tf.keras.activations.get(None)
   <function linear at 0x1239596a8>
  >>> tf.keras.activations.get(abs)
   <built-in function abs>
  >>> tf.keras.activations.get('abcd')
  Traceback (most recent call last):
  ...
  ValueError: Unknown activation function:abcd

  Raises:
      ValueError: Input is an unknown function or string, i.e., the input does
      not denote any defined function.
  """
  if identifier is None:
    return linear
  if isinstance(identifier, str):
    identifier = str(identifier)
    return deserialize(identifier)
  elif isinstance(identifier, dict):
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise TypeError(
        'Could not interpret activation function identifier: {}'.format(
            identifier))