• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-classes-have-attributes
16"""Python module for Keras base types.
17
18All the classes in this module is abstract classes that contains none or minimal
19implementations. It is designed be used as base class for other concrete
20classes, type checks, and python3 type hints.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import abc
28import six
29
30# TODO(scottzhu): Export all the types under this module with API symbol.
31
32
33@six.add_metaclass(abc.ABCMeta)
34class Layer(object):
35  """This is the class from which all layers inherit.
36
37  A layer is a callable object that takes as input one or more tensors and
38  that outputs one or more tensors. It involves *computation*, defined
39  in the `call()` method, and a *state* (weight variables), defined
40  either in the constructor `__init__()` or in the `build()` method.
41
42  Users will just instantiate a layer and then treat it as a callable.
43
44  We recommend that descendants of `Layer` implement the following methods:
45
46  * `__init__()`: Defines custom layer attributes, and creates layer state
47    variables that do not depend on input shapes, using `add_weight()`.
48  * `build(self, input_shape)`: This method can be used to create weights that
49    depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
50    will automatically build the layer (if it has not been built yet) by
51    calling `build()`.
52  * `call(self, *args, **kwargs)`: Called in `__call__` after making sure
53    `build()` has been called. `call()` performs the logic of applying the
54    layer to the input tensors (which should be passed in as argument).
55    Two reserved keyword arguments you can optionally use in `call()` are:
56      - `training` (boolean, whether the call is in
57        inference mode or training mode)
58      - `mask` (boolean tensor encoding masked timesteps in the input, used
59        in RNN layers)
60  * `get_config(self)`: Returns a dictionary containing the configuration used
61    to initialize this layer. If the keys differ from the arguments
62    in `__init__`, then override `from_config(self)` as well.
63    This method is used when saving
64    the layer or a model that contains this layer.
65
66  Examples:
67
68  Here's a basic example: a layer with two variables, `w` and `b`,
69  that returns `y = w . x + b`.
70  It shows how to implement `build()` and `call()`.
71  Variables set as attributes of a layer are tracked as weights
72  of the layers (in `layer.weights`).
73
74  ```python
75  class SimpleDense(Layer):
76
77    def __init__(self, units=32):
78        super(SimpleDense, self).__init__()
79        self.units = units
80
81    def build(self, input_shape):  # Create the state of the layer (weights)
82      w_init = tf.random_normal_initializer()
83      self.w = tf.Variable(
84          initial_value=w_init(shape=(input_shape[-1], self.units),
85                               dtype='float32'),
86          trainable=True)
87      b_init = tf.zeros_initializer()
88      self.b = tf.Variable(
89          initial_value=b_init(shape=(self.units,), dtype='float32'),
90          trainable=True)
91
92    def call(self, inputs):  # Defines the computation from inputs to outputs
93        return tf.matmul(inputs, self.w) + self.b
94
95  # Instantiates the layer.
96  linear_layer = SimpleDense(4)
97
98  # This will also call `build(input_shape)` and create the weights.
99  y = linear_layer(tf.ones((2, 2)))
100  assert len(linear_layer.weights) == 2
101
102  # These weights are trainable, so they're listed in `trainable_weights`:
103  assert len(linear_layer.trainable_weights) == 2
104  ```
105
106  Note that the method `add_weight()` offers a shortcut to create weights:
107
108  ```python
109  class SimpleDense(Layer):
110
111    def __init__(self, units=32):
112        super(SimpleDense, self).__init__()
113        self.units = units
114
115    def build(self, input_shape):
116        self.w = self.add_weight(shape=(input_shape[-1], self.units),
117                                 initializer='random_normal',
118                                 trainable=True)
119        self.b = self.add_weight(shape=(self.units,),
120                                 initializer='random_normal',
121                                 trainable=True)
122
123    def call(self, inputs):
124        return tf.matmul(inputs, self.w) + self.b
125  ```
126
127  Besides trainable weights, updated via backpropagation during training,
128  layers can also have non-trainable weights. These weights are meant to
129  be updated manually during `call()`. Here's a example layer that computes
130  the running sum of its inputs:
131
132  ```python
133  class ComputeSum(Layer):
134
135    def __init__(self, input_dim):
136        super(ComputeSum, self).__init__()
137        # Create a non-trainable weight.
138        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
139                                 trainable=False)
140
141    def call(self, inputs):
142        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
143        return self.total
144
145  my_sum = ComputeSum(2)
146  x = tf.ones((2, 2))
147
148  y = my_sum(x)
149  print(y.numpy())  # [2. 2.]
150
151  y = my_sum(x)
152  print(y.numpy())  # [4. 4.]
153
154  assert my_sum.weights == [my_sum.total]
155  assert my_sum.non_trainable_weights == [my_sum.total]
156  assert my_sum.trainable_weights == []
157  ```
158
159  For more information about creating layers, see the guide
160  [Making new Layers and Models via subclassing](
161    https://www.tensorflow.org/guide/keras/custom_layers_and_models)
162
163  Args:
164    trainable: Boolean, whether the layer's variables should be trainable.
165    name: String name of the layer.
166    dtype: The dtype of the layer's computations and weights (default of
167      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
168      of the first input in TensorFlow 1).
169    dynamic: Set this to `True` if your layer should only be run eagerly, and
170      should not be used to generate a static computation graph.
171      This would be the case for a Tree-RNN or a recursive network,
172      for example, or generally for any layer that manipulates tensors
173      using Python control flow. If `False`, we assume that the layer can
174      safely be used to generate a static computation graph.
175
176  Attributes:
177    name: The name of the layer (string).
178    dtype: The dtype of the layer's computations and weights. If mixed
179      precision is used with a `tf.keras.mixed_precision.Policy`, this is
180      instead just the dtype of the layer's weights, as the computations are
181      done in a different dtype.
182    updates: List of update ops of this layer.
183    losses: List of losses added by this layer.
184    trainable_weights: List of variables to be included in backprop.
185    non_trainable_weights: List of variables that should not be
186      included in backprop.
187    weights: The concatenation of the lists trainable_weights and
188      non_trainable_weights (in this order).
189    trainable: Whether the layer should be trained (boolean).
190    input_spec: Optional (list of) `InputSpec` object(s) specifying the
191      constraints on inputs that can be accepted by the layer.
192
193  Each layer has a dtype, which is typically the dtype of the layer's
194  computations and variables. A layer's dtype can be queried via the
195  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
196  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
197  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
198  layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
199  precision is used, layers may have different computation and variable dtypes.
200  See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
201  """
202  pass
203