• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras layers API."""
16
17from tensorflow.python import tf2
18
19# Generic layers.
20# pylint: disable=g-bad-import-order
21# pylint: disable=g-import-not-at-top
22from tensorflow.python.keras.engine.input_layer import Input
23from tensorflow.python.keras.engine.input_layer import InputLayer
24from tensorflow.python.keras.engine.input_spec import InputSpec
25from tensorflow.python.keras.engine.base_layer import Layer
26from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer
27
28# Image preprocessing layers.
29from tensorflow.python.keras.layers.preprocessing.image_preprocessing import CenterCrop
30from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomCrop
31from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip
32from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomContrast
33from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomHeight
34from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomRotation
35from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomTranslation
36from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomWidth
37from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomZoom
38from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Resizing
39from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling
40
41# Preprocessing layers.
42from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing
43from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding
44from tensorflow.python.keras.layers.preprocessing.discretization import Discretization
45from tensorflow.python.keras.layers.preprocessing.hashing import Hashing
46from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup
47from tensorflow.python.keras.layers.preprocessing.normalization import Normalization
48from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup
49from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization
50
51# Advanced activations.
52from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
53from tensorflow.python.keras.layers.advanced_activations import PReLU
54from tensorflow.python.keras.layers.advanced_activations import ELU
55from tensorflow.python.keras.layers.advanced_activations import ReLU
56from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU
57from tensorflow.python.keras.layers.advanced_activations import Softmax
58
59# Convolution layers.
60from tensorflow.python.keras.layers.convolutional import Conv1D
61from tensorflow.python.keras.layers.convolutional import Conv2D
62from tensorflow.python.keras.layers.convolutional import Conv3D
63from tensorflow.python.keras.layers.convolutional import Conv1DTranspose
64from tensorflow.python.keras.layers.convolutional import Conv2DTranspose
65from tensorflow.python.keras.layers.convolutional import Conv3DTranspose
66from tensorflow.python.keras.layers.convolutional import SeparableConv1D
67from tensorflow.python.keras.layers.convolutional import SeparableConv2D
68
69# Convolution layer aliases.
70from tensorflow.python.keras.layers.convolutional import Convolution1D
71from tensorflow.python.keras.layers.convolutional import Convolution2D
72from tensorflow.python.keras.layers.convolutional import Convolution3D
73from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose
74from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose
75from tensorflow.python.keras.layers.convolutional import SeparableConvolution1D
76from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D
77from tensorflow.python.keras.layers.convolutional import DepthwiseConv2D
78
79# Image processing layers.
80from tensorflow.python.keras.layers.convolutional import UpSampling1D
81from tensorflow.python.keras.layers.convolutional import UpSampling2D
82from tensorflow.python.keras.layers.convolutional import UpSampling3D
83from tensorflow.python.keras.layers.convolutional import ZeroPadding1D
84from tensorflow.python.keras.layers.convolutional import ZeroPadding2D
85from tensorflow.python.keras.layers.convolutional import ZeroPadding3D
86from tensorflow.python.keras.layers.convolutional import Cropping1D
87from tensorflow.python.keras.layers.convolutional import Cropping2D
88from tensorflow.python.keras.layers.convolutional import Cropping3D
89
90# Core layers.
91from tensorflow.python.keras.layers.core import Masking
92from tensorflow.python.keras.layers.core import Dropout
93from tensorflow.python.keras.layers.core import SpatialDropout1D
94from tensorflow.python.keras.layers.core import SpatialDropout2D
95from tensorflow.python.keras.layers.core import SpatialDropout3D
96from tensorflow.python.keras.layers.core import Activation
97from tensorflow.python.keras.layers.core import Reshape
98from tensorflow.python.keras.layers.core import Permute
99from tensorflow.python.keras.layers.core import Flatten
100from tensorflow.python.keras.layers.core import RepeatVector
101from tensorflow.python.keras.layers.core import Lambda
102from tensorflow.python.keras.layers.core import Dense
103from tensorflow.python.keras.layers.core import ActivityRegularization
104
105# Dense Attention layers.
106from tensorflow.python.keras.layers.dense_attention import AdditiveAttention
107from tensorflow.python.keras.layers.dense_attention import Attention
108
109# Embedding layers.
110from tensorflow.python.keras.layers.embeddings import Embedding
111
112# Einsum-based dense layer.
113from tensorflow.python.keras.layers.einsum_dense import EinsumDense
114
115# Multi-head Attention layer.
116from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention
117
118# Locally-connected layers.
119from tensorflow.python.keras.layers.local import LocallyConnected1D
120from tensorflow.python.keras.layers.local import LocallyConnected2D
121
122# Merge layers.
123from tensorflow.python.keras.layers.merge import Add
124from tensorflow.python.keras.layers.merge import Subtract
125from tensorflow.python.keras.layers.merge import Multiply
126from tensorflow.python.keras.layers.merge import Average
127from tensorflow.python.keras.layers.merge import Maximum
128from tensorflow.python.keras.layers.merge import Minimum
129from tensorflow.python.keras.layers.merge import Concatenate
130from tensorflow.python.keras.layers.merge import Dot
131from tensorflow.python.keras.layers.merge import add
132from tensorflow.python.keras.layers.merge import subtract
133from tensorflow.python.keras.layers.merge import multiply
134from tensorflow.python.keras.layers.merge import average
135from tensorflow.python.keras.layers.merge import maximum
136from tensorflow.python.keras.layers.merge import minimum
137from tensorflow.python.keras.layers.merge import concatenate
138from tensorflow.python.keras.layers.merge import dot
139
140# Noise layers.
141from tensorflow.python.keras.layers.noise import AlphaDropout
142from tensorflow.python.keras.layers.noise import GaussianNoise
143from tensorflow.python.keras.layers.noise import GaussianDropout
144
145# Normalization layers.
146from tensorflow.python.keras.layers.normalization.layer_normalization import LayerNormalization
147from tensorflow.python.keras.layers.normalization.batch_normalization import SyncBatchNormalization
148
149if tf2.enabled():
150  from tensorflow.python.keras.layers.normalization.batch_normalization import BatchNormalization
151  from tensorflow.python.keras.layers.normalization.batch_normalization_v1 import BatchNormalization as BatchNormalizationV1
152  BatchNormalizationV2 = BatchNormalization
153else:
154  from tensorflow.python.keras.layers.normalization.batch_normalization_v1 import BatchNormalization
155  from tensorflow.python.keras.layers.normalization.batch_normalization import BatchNormalization as BatchNormalizationV2
156  BatchNormalizationV1 = BatchNormalization
157
158# Kernelized layers.
159from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures
160
161# Pooling layers.
162from tensorflow.python.keras.layers.pooling import MaxPooling1D
163from tensorflow.python.keras.layers.pooling import MaxPooling2D
164from tensorflow.python.keras.layers.pooling import MaxPooling3D
165from tensorflow.python.keras.layers.pooling import AveragePooling1D
166from tensorflow.python.keras.layers.pooling import AveragePooling2D
167from tensorflow.python.keras.layers.pooling import AveragePooling3D
168from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D
169from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D
170from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D
171from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D
172from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D
173from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D
174
175# Pooling layer aliases.
176from tensorflow.python.keras.layers.pooling import MaxPool1D
177from tensorflow.python.keras.layers.pooling import MaxPool2D
178from tensorflow.python.keras.layers.pooling import MaxPool3D
179from tensorflow.python.keras.layers.pooling import AvgPool1D
180from tensorflow.python.keras.layers.pooling import AvgPool2D
181from tensorflow.python.keras.layers.pooling import AvgPool3D
182from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D
183from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D
184from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D
185from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D
186from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D
187from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D
188
189# Recurrent layers.
190from tensorflow.python.keras.layers.recurrent import RNN
191from tensorflow.python.keras.layers.recurrent import AbstractRNNCell
192from tensorflow.python.keras.layers.recurrent import StackedRNNCells
193from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
194from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
195from tensorflow.python.keras.layers.recurrent import SimpleRNN
196
197if tf2.enabled():
198  from tensorflow.python.keras.layers.recurrent_v2 import GRU
199  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
200  from tensorflow.python.keras.layers.recurrent_v2 import LSTM
201  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
202  from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
203  from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
204  from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
205  from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
206  GRUV2 = GRU
207  GRUCellV2 = GRUCell
208  LSTMV2 = LSTM
209  LSTMCellV2 = LSTMCell
210else:
211  from tensorflow.python.keras.layers.recurrent import GRU
212  from tensorflow.python.keras.layers.recurrent import GRUCell
213  from tensorflow.python.keras.layers.recurrent import LSTM
214  from tensorflow.python.keras.layers.recurrent import LSTMCell
215  from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
216  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
217  from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
218  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
219  GRUV1 = GRU
220  GRUCellV1 = GRUCell
221  LSTMV1 = LSTM
222  LSTMCellV1 = LSTMCell
223
224# Convolutional-recurrent layers.
225from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D
226
227# CuDNN recurrent layers.
228from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM
229from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNGRU
230
231# Wrapper layers.
232from tensorflow.python.keras.layers.wrappers import Wrapper
233from tensorflow.python.keras.layers.wrappers import Bidirectional
234from tensorflow.python.keras.layers.wrappers import TimeDistributed
235
236# RNN cell wrappers.
237from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DeviceWrapper
238from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
239from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper
240
241# Serialization functions
242from tensorflow.python.keras.layers import serialization
243from tensorflow.python.keras.layers.serialization import deserialize
244from tensorflow.python.keras.layers.serialization import serialize
245
246
class VersionAwareLayers(object):
  """Utility to be used internally to access layers in a V1/V2-aware fashion.

  When using layers within the Keras codebase, under the constraint that
  e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
  corresponding to the current runtime (TF1 or TF2), do not simply access
  `layers.BatchNormalization` since it would ignore e.g. an early
  `compat.v2.disable_v2_behavior()` call. Instead, use an instance
  of `VersionAwareLayers` (which you can use just like the `layers` module).
  """

  def __getattr__(self, name):
    """Resolves `name` against the serialization registry at access time.

    `__getattr__` only runs when normal attribute lookup fails. So every
    layer name accessed on an instance is looked up in
    `serialization.LOCAL.ALL_OBJECTS`, after calling
    `serialization.populate_deserializable_objects()`. Doing this on each
    access, instead of using the module-level names bound at import time,
    is what the class docstring relies on. Presumably `populate_...`
    (re)builds the registry for the currently active TF1/TF2 mode; confirm
    against the `serialization` module.

    Args:
      name: Attribute name being accessed, e.g. `'BatchNormalization'`.

    Returns:
      The object registered under `name` in
      `serialization.LOCAL.ALL_OBJECTS`.

    Raises:
      AttributeError: If `name` is not present in the registry.
    """
    serialization.populate_deserializable_objects()
    if name in serialization.LOCAL.ALL_OBJECTS:
      return serialization.LOCAL.ALL_OBJECTS[name]
    # `object` defines no `__getattr__`, so delegating to
    # `super(...).__getattr__` would itself fail with a misleading
    # "'super' object has no attribute '__getattr__'" error. Raise a
    # conventional AttributeError that names the missing attribute; being
    # an AttributeError, `hasattr` and `getattr(obj, name, default)` keep
    # working as before.
    raise AttributeError(
        '%r object has no attribute %r' % (type(self).__name__, name))
263