1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras layers API.""" 16 17from tensorflow.python import tf2 18 19# Generic layers. 20# pylint: disable=g-bad-import-order 21# pylint: disable=g-import-not-at-top 22from tensorflow.python.keras.engine.input_layer import Input 23from tensorflow.python.keras.engine.input_layer import InputLayer 24from tensorflow.python.keras.engine.input_spec import InputSpec 25from tensorflow.python.keras.engine.base_layer import Layer 26from tensorflow.python.keras.engine.base_preprocessing_layer import PreprocessingLayer 27 28# Image preprocessing layers. 
29from tensorflow.python.keras.layers.preprocessing.image_preprocessing import CenterCrop 30from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomCrop 31from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip 32from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomContrast 33from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomHeight 34from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomRotation 35from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomTranslation 36from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomWidth 37from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomZoom 38from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Resizing 39from tensorflow.python.keras.layers.preprocessing.image_preprocessing import Rescaling 40 41# Preprocessing layers. 42from tensorflow.python.keras.layers.preprocessing.category_crossing import CategoryCrossing 43from tensorflow.python.keras.layers.preprocessing.category_encoding import CategoryEncoding 44from tensorflow.python.keras.layers.preprocessing.discretization import Discretization 45from tensorflow.python.keras.layers.preprocessing.hashing import Hashing 46from tensorflow.python.keras.layers.preprocessing.integer_lookup import IntegerLookup 47from tensorflow.python.keras.layers.preprocessing.normalization import Normalization 48from tensorflow.python.keras.layers.preprocessing.string_lookup import StringLookup 49from tensorflow.python.keras.layers.preprocessing.text_vectorization import TextVectorization 50 51# Advanced activations. 
52from tensorflow.python.keras.layers.advanced_activations import LeakyReLU 53from tensorflow.python.keras.layers.advanced_activations import PReLU 54from tensorflow.python.keras.layers.advanced_activations import ELU 55from tensorflow.python.keras.layers.advanced_activations import ReLU 56from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU 57from tensorflow.python.keras.layers.advanced_activations import Softmax 58 59# Convolution layers. 60from tensorflow.python.keras.layers.convolutional import Conv1D 61from tensorflow.python.keras.layers.convolutional import Conv2D 62from tensorflow.python.keras.layers.convolutional import Conv3D 63from tensorflow.python.keras.layers.convolutional import Conv1DTranspose 64from tensorflow.python.keras.layers.convolutional import Conv2DTranspose 65from tensorflow.python.keras.layers.convolutional import Conv3DTranspose 66from tensorflow.python.keras.layers.convolutional import SeparableConv1D 67from tensorflow.python.keras.layers.convolutional import SeparableConv2D 68 69# Convolution layer aliases. 70from tensorflow.python.keras.layers.convolutional import Convolution1D 71from tensorflow.python.keras.layers.convolutional import Convolution2D 72from tensorflow.python.keras.layers.convolutional import Convolution3D 73from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose 74from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose 75from tensorflow.python.keras.layers.convolutional import SeparableConvolution1D 76from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D 77from tensorflow.python.keras.layers.convolutional import DepthwiseConv2D 78 79# Image processing layers. 
80from tensorflow.python.keras.layers.convolutional import UpSampling1D 81from tensorflow.python.keras.layers.convolutional import UpSampling2D 82from tensorflow.python.keras.layers.convolutional import UpSampling3D 83from tensorflow.python.keras.layers.convolutional import ZeroPadding1D 84from tensorflow.python.keras.layers.convolutional import ZeroPadding2D 85from tensorflow.python.keras.layers.convolutional import ZeroPadding3D 86from tensorflow.python.keras.layers.convolutional import Cropping1D 87from tensorflow.python.keras.layers.convolutional import Cropping2D 88from tensorflow.python.keras.layers.convolutional import Cropping3D 89 90# Core layers. 91from tensorflow.python.keras.layers.core import Masking 92from tensorflow.python.keras.layers.core import Dropout 93from tensorflow.python.keras.layers.core import SpatialDropout1D 94from tensorflow.python.keras.layers.core import SpatialDropout2D 95from tensorflow.python.keras.layers.core import SpatialDropout3D 96from tensorflow.python.keras.layers.core import Activation 97from tensorflow.python.keras.layers.core import Reshape 98from tensorflow.python.keras.layers.core import Permute 99from tensorflow.python.keras.layers.core import Flatten 100from tensorflow.python.keras.layers.core import RepeatVector 101from tensorflow.python.keras.layers.core import Lambda 102from tensorflow.python.keras.layers.core import Dense 103from tensorflow.python.keras.layers.core import ActivityRegularization 104 105# Dense Attention layers. 106from tensorflow.python.keras.layers.dense_attention import AdditiveAttention 107from tensorflow.python.keras.layers.dense_attention import Attention 108 109# Embedding layers. 110from tensorflow.python.keras.layers.embeddings import Embedding 111 112# Einsum-based dense layer/ 113from tensorflow.python.keras.layers.einsum_dense import EinsumDense 114 115# Multi-head Attention layer. 
116from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention 117 118# Locally-connected layers. 119from tensorflow.python.keras.layers.local import LocallyConnected1D 120from tensorflow.python.keras.layers.local import LocallyConnected2D 121 122# Merge layers. 123from tensorflow.python.keras.layers.merge import Add 124from tensorflow.python.keras.layers.merge import Subtract 125from tensorflow.python.keras.layers.merge import Multiply 126from tensorflow.python.keras.layers.merge import Average 127from tensorflow.python.keras.layers.merge import Maximum 128from tensorflow.python.keras.layers.merge import Minimum 129from tensorflow.python.keras.layers.merge import Concatenate 130from tensorflow.python.keras.layers.merge import Dot 131from tensorflow.python.keras.layers.merge import add 132from tensorflow.python.keras.layers.merge import subtract 133from tensorflow.python.keras.layers.merge import multiply 134from tensorflow.python.keras.layers.merge import average 135from tensorflow.python.keras.layers.merge import maximum 136from tensorflow.python.keras.layers.merge import minimum 137from tensorflow.python.keras.layers.merge import concatenate 138from tensorflow.python.keras.layers.merge import dot 139 140# Noise layers. 141from tensorflow.python.keras.layers.noise import AlphaDropout 142from tensorflow.python.keras.layers.noise import GaussianNoise 143from tensorflow.python.keras.layers.noise import GaussianDropout 144 145# Normalization layers. 
146from tensorflow.python.keras.layers.normalization.layer_normalization import LayerNormalization 147from tensorflow.python.keras.layers.normalization.batch_normalization import SyncBatchNormalization 148 149if tf2.enabled(): 150 from tensorflow.python.keras.layers.normalization.batch_normalization import BatchNormalization 151 from tensorflow.python.keras.layers.normalization.batch_normalization_v1 import BatchNormalization as BatchNormalizationV1 152 BatchNormalizationV2 = BatchNormalization 153else: 154 from tensorflow.python.keras.layers.normalization.batch_normalization_v1 import BatchNormalization 155 from tensorflow.python.keras.layers.normalization.batch_normalization import BatchNormalization as BatchNormalizationV2 156 BatchNormalizationV1 = BatchNormalization 157 158# Kernelized layers. 159from tensorflow.python.keras.layers.kernelized import RandomFourierFeatures 160 161# Pooling layers. 162from tensorflow.python.keras.layers.pooling import MaxPooling1D 163from tensorflow.python.keras.layers.pooling import MaxPooling2D 164from tensorflow.python.keras.layers.pooling import MaxPooling3D 165from tensorflow.python.keras.layers.pooling import AveragePooling1D 166from tensorflow.python.keras.layers.pooling import AveragePooling2D 167from tensorflow.python.keras.layers.pooling import AveragePooling3D 168from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D 169from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D 170from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D 171from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D 172from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D 173from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D 174 175# Pooling layer aliases. 
from tensorflow.python.keras.layers.pooling import MaxPool1D
from tensorflow.python.keras.layers.pooling import MaxPool2D
from tensorflow.python.keras.layers.pooling import MaxPool3D
from tensorflow.python.keras.layers.pooling import AvgPool1D
from tensorflow.python.keras.layers.pooling import AvgPool2D
from tensorflow.python.keras.layers.pooling import AvgPool3D
from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D
from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D
from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D
from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D
from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D
from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D

# Recurrent layers.
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.layers.recurrent import AbstractRNNCell
from tensorflow.python.keras.layers.recurrent import StackedRNNCells
from tensorflow.python.keras.layers.recurrent import SimpleRNNCell
from tensorflow.python.keras.layers.recurrent import PeepholeLSTMCell
from tensorflow.python.keras.layers.recurrent import SimpleRNN

# The un-suffixed GRU/GRUCell/LSTM/LSTMCell names resolve to the implementation
# matching the behavior mode detected when this module is imported
# (`recurrent_v2` under TF2, `recurrent` otherwise). The V1/V2-suffixed names
# always refer to one specific implementation regardless of mode.
if tf2.enabled():
  from tensorflow.python.keras.layers.recurrent_v2 import GRU
  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell
  from tensorflow.python.keras.layers.recurrent_v2 import LSTM
  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell
  from tensorflow.python.keras.layers.recurrent import GRU as GRUV1
  from tensorflow.python.keras.layers.recurrent import GRUCell as GRUCellV1
  from tensorflow.python.keras.layers.recurrent import LSTM as LSTMV1
  from tensorflow.python.keras.layers.recurrent import LSTMCell as LSTMCellV1
  GRUV2 = GRU
  GRUCellV2 = GRUCell
  LSTMV2 = LSTM
  LSTMCellV2 = LSTMCell
else:
  from tensorflow.python.keras.layers.recurrent import GRU
  from tensorflow.python.keras.layers.recurrent import GRUCell
  from tensorflow.python.keras.layers.recurrent import LSTM
  from tensorflow.python.keras.layers.recurrent import LSTMCell
  from tensorflow.python.keras.layers.recurrent_v2 import GRU as GRUV2
  from tensorflow.python.keras.layers.recurrent_v2 import GRUCell as GRUCellV2
  from tensorflow.python.keras.layers.recurrent_v2 import LSTM as LSTMV2
  from tensorflow.python.keras.layers.recurrent_v2 import LSTMCell as LSTMCellV2
  GRUV1 = GRU
  GRUCellV1 = GRUCell
  LSTMV1 = LSTM
  LSTMCellV1 = LSTMCell

# Convolutional-recurrent layers.
from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D

# CuDNN recurrent layers.
from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM
from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNGRU

# Wrapper layers.
from tensorflow.python.keras.layers.wrappers import Wrapper
from tensorflow.python.keras.layers.wrappers import Bidirectional
from tensorflow.python.keras.layers.wrappers import TimeDistributed

# RNN cell wrappers.
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DeviceWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import DropoutWrapper
from tensorflow.python.keras.layers.rnn_cell_wrapper_v2 import ResidualWrapper

# Serialization functions.
from tensorflow.python.keras.layers import serialization
from tensorflow.python.keras.layers.serialization import deserialize
from tensorflow.python.keras.layers.serialization import serialize


class VersionAwareLayers(object):
  """Utility to be used internally to access layers in a V1/V2-aware fashion.

  When using layers within the Keras codebase, under the constraint that
  e.g. `layers.BatchNormalization` should be the `BatchNormalization` version
  corresponding to the current runtime (TF1 or TF2), do not simply access
  `layers.BatchNormalization` since it would ignore e.g. an early
  `compat.v2.disable_v2_behavior()` call. Instead, use an instance
  of `VersionAwareLayers` (which you can use just like the `layers` module).

  Example:

    layers = VersionAwareLayers()
    bn_cls = layers.BatchNormalization
  """

  def __getattr__(self, name):
    """Looks up `name` in the layer serialization registry.

    Only invoked by Python when normal attribute lookup on the instance
    fails, i.e. for every layer name accessed through this object.

    Args:
      name: The attribute being accessed, e.g. `'BatchNormalization'`.

    Returns:
      The object registered under `name` in
      `serialization.LOCAL.ALL_OBJECTS` after the registry is repopulated.

    Raises:
      AttributeError: If no object is registered under `name`.
    """
    # Repopulate on every access (not once at import) so that, as the class
    # docstring requires, the lookup reflects the V1/V2 mode in effect now.
    serialization.populate_deserializable_objects()
    all_objects = serialization.LOCAL.ALL_OBJECTS
    if name in all_objects:
      return all_objects[name]
    # `object` defines no `__getattr__`, so delegating to
    # `super().__getattr__` would fail with the misleading
    # "'super' object has no attribute '__getattr__'". Raise a conventional
    # AttributeError naming the missing attribute instead; it is still an
    # AttributeError, so `hasattr`/`getattr(obj, name, default)` keep working.
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))