# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-classes-have-attributes
"""Python module for Keras base types.

All the classes in this module are abstract classes that contain no or minimal
implementation. They are designed to be used as base classes for other concrete
classes, for type checks, and for Python 3 type hints.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

# TODO(scottzhu): Export all the types under this module with API symbol.


@six.add_metaclass(abc.ABCMeta)
class Layer(object):
  """This is the class from which all layers inherit.

  A layer is a callable object that takes one or more tensors as input and
  outputs one or more tensors. It involves *computation*, defined
  in the `call()` method, and a *state* (weight variables), defined
  either in the constructor `__init__()` or in the `build()` method.

  Users simply instantiate a layer and then treat it as a callable.

  We recommend that descendants of `Layer` implement the following methods:

  * `__init__()`: Defines custom layer attributes, and creates layer state
    variables that do not depend on input shapes, using `add_weight()`.
  * `build(self, input_shape)`: This method can be used to create weights that
    depend on the shape(s) of the input(s), using `add_weight()`. `__call__()`
    will automatically build the layer (if it has not been built yet) by
    calling `build()`.
  * `call(self, *args, **kwargs)`: Called in `__call__` after making sure
    `build()` has been called. `call()` performs the logic of applying the
    layer to the input tensors (which should be passed in as arguments).
    Two reserved keyword arguments you can optionally use in `call()` are:
      - `training` (boolean, whether the call is in inference mode or
        training mode)
      - `mask` (boolean tensor encoding masked timesteps in the input, used
        in RNN layers)
  * `get_config(self)`: Returns a dictionary containing the configuration used
    to initialize this layer. If the keys differ from the arguments in
    `__init__`, then override the class method `from_config(cls, config)` as
    well. `get_config()` is used when saving the layer or a model that
    contains this layer.

  Examples:

  Here's a basic example: a layer with two variables, `w` and `b`,
  that returns `y = w . x + b`.
  It shows how to implement `build()` and `call()`.
  Variables set as attributes of a layer are tracked as weights
  of the layer (in `layer.weights`).

  ```python
  class SimpleDense(Layer):

    def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

    def build(self, input_shape):  # Create the state of the layer (weights)
      w_init = tf.random_normal_initializer()
      self.w = tf.Variable(
          initial_value=w_init(shape=(input_shape[-1], self.units),
                               dtype='float32'),
          trainable=True)
      b_init = tf.zeros_initializer()
      self.b = tf.Variable(
          initial_value=b_init(shape=(self.units,), dtype='float32'),
          trainable=True)

    def call(self, inputs):  # Defines the computation from inputs to outputs
      return tf.matmul(inputs, self.w) + self.b

  # Instantiates the layer.
  linear_layer = SimpleDense(4)

  # This will also call `build(input_shape)` and create the weights.
  y = linear_layer(tf.ones((2, 2)))
  assert len(linear_layer.weights) == 2

  # These weights are trainable, so they're listed in `trainable_weights`:
  assert len(linear_layer.trainable_weights) == 2
  ```

  Note that the method `add_weight()` offers a shortcut to create weights:

  ```python
  class SimpleDense(Layer):

    def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

    def build(self, input_shape):
      self.w = self.add_weight(shape=(input_shape[-1], self.units),
                               initializer='random_normal',
                               trainable=True)
      self.b = self.add_weight(shape=(self.units,),
                               initializer='random_normal',
                               trainable=True)

    def call(self, inputs):
      return tf.matmul(inputs, self.w) + self.b
  ```

  Besides trainable weights, updated via backpropagation during training,
  layers can also have non-trainable weights. These weights are meant to
  be updated manually during `call()`. Here's an example layer that computes
  the running sum of its inputs:

  ```python
  class ComputeSum(Layer):

    def __init__(self, input_dim):
      super(ComputeSum, self).__init__()
      # Create a non-trainable weight.
      self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
                               trainable=False)

    def call(self, inputs):
      # Add the column-wise sum of this batch to the running total.
      self.total.assign_add(tf.reduce_sum(inputs, axis=0))
      return self.total

  my_sum = ComputeSum(2)
  x = tf.ones((2, 2))

  y = my_sum(x)
  print(y.numpy())  # [2. 2.]

  y = my_sum(x)
  print(y.numpy())  # [4. 4.]

  assert my_sum.weights == [my_sum.total]
  assert my_sum.non_trainable_weights == [my_sum.total]
  assert my_sum.trainable_weights == []
  ```

  For more information about creating layers, see the guide
  [Making new Layers and Models via subclassing](
    https://www.tensorflow.org/guide/keras/custom_layers_and_models)

  Args:
    trainable: Boolean, whether the layer's variables should be trainable.
    name: String name of the layer.
    dtype: The dtype of the layer's computations and weights (default of
      `None` means use `tf.keras.backend.floatx` in TensorFlow 2, or the type
      of the first input in TensorFlow 1).
    dynamic: Set this to `True` if your layer should only be run eagerly, and
      should not be used to generate a static computation graph.
      This would be the case for a Tree-RNN or a recursive network,
      for example, or generally for any layer that manipulates tensors
      using Python control flow. If `False`, we assume that the layer can
      safely be used to generate a static computation graph.

  Attributes:
    name: The name of the layer (string).
    dtype: The dtype of the layer's computations and weights. If mixed
      precision is used with a `tf.keras.mixed_precision.Policy`, this is
      instead just the dtype of the layer's weights, as the computations are
      done in a different dtype.
    updates: List of update ops of this layer.
    losses: List of losses added by this layer.
    trainable_weights: List of variables to be included in backprop.
    non_trainable_weights: List of variables that should not be
      included in backprop.
    weights: The concatenation of the lists trainable_weights and
      non_trainable_weights (in this order).
    trainable: Whether the layer should be trained (boolean).
    input_spec: Optional (list of) `InputSpec` object(s) specifying the
      constraints on inputs that can be accepted by the layer.
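
  The build-on-first-call lifecycle described for `build()` and `__call__()`,
  and the way `weights` concatenates `trainable_weights` and
  `non_trainable_weights`, can be sketched in plain Python. This is a
  minimal, hypothetical sketch, not the Keras implementation: `SketchLayer`
  and `ScaleAndCount` are illustrative names, plain Python lists stand in
  for variables, and inputs are assumed to be non-empty 2-D lists of rows.

  ```python
  class SketchLayer(object):
    # Hypothetical stand-in (not Keras): mimics only lazy building and the
    # trainable / non-trainable weight bookkeeping.

    def __init__(self):
      self.built = False
      self._trainable = []
      self._non_trainable = []

    def add_weight(self, value, trainable=True):
      # Plain lists stand in for variables in this sketch.
      if trainable:
        self._trainable.append(value)
      else:
        self._non_trainable.append(value)
      return value

    @property
    def trainable_weights(self):
      return list(self._trainable)

    @property
    def non_trainable_weights(self):
      return list(self._non_trainable)

    @property
    def weights(self):
      # Trainable weights first, then non-trainable ones.
      return self.trainable_weights + self.non_trainable_weights

    def build(self, input_shape):
      # Subclasses create shape-dependent weights here.
      self.built = True

    def call(self, inputs):
      raise NotImplementedError

    def __call__(self, inputs):
      if not self.built:
        # Assumes `inputs` is a non-empty 2-D list: (rows, features).
        input_shape = (len(inputs), len(inputs[0]))
        self.build(input_shape)
        self.built = True
      return self.call(inputs)


  class ScaleAndCount(SketchLayer):
    # Hypothetical layer: scales each feature and counts its calls.

    def build(self, input_shape):
      # One trainable scale per feature, sized from the last dimension.
      self.scale = self.add_weight([1.0] * input_shape[-1])
      # A non-trainable counter, updated manually in call().
      self.calls = self.add_weight([0], trainable=False)

    def call(self, inputs):
      self.calls[0] += 1
      return [[x * s for x, s in zip(row, self.scale)] for row in inputs]


  layer = ScaleAndCount()
  y = layer([[1.0, 2.0], [3.0, 4.0]])  # Builds first, then calls.
  print(layer.built)         # True
  print(len(layer.weights))  # 2
  layer([[5.0, 6.0]])        # Already built, so build() is not rerun.
  print(layer.calls)         # [2]
  ```

  Deferring `build()` to the first call is what lets a layer size its
  weights from the input it actually receives, so the constructor does not
  need to know the input shape.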

  Each layer has a dtype, which is typically the dtype of the layer's
  computations and variables. A layer's dtype can be queried via the
  `Layer.dtype` property. The dtype is specified with the `dtype` constructor
  argument. In TensorFlow 2, the dtype defaults to `tf.keras.backend.floatx()`
  if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
  layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
  precision is used, layers may have different computation and variable dtypes.
  See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
  """
  pass