# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""VGG16 model for Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export


WEIGHTS_PATH = ('https://storage.googleapis.com/tensorflow/keras-applications/'
                'vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://storage.googleapis.com/tensorflow/'
                       'keras-applications/vgg16/'
                       'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5')


@keras_export('keras.applications.vgg16.VGG16', 'keras.applications.VGG16')
def VGG16(include_top=True,
          weights='imagenet',
          input_tensor=None,
          input_shape=None,
          pooling=None,
          classes=1000):
  """Instantiates the VGG16 model.

  By default, it loads weights pre-trained on ImageNet. See the `weights`
  argument for other options.

  This model can be built with either the 'channels_first' data format
  (channels, height, width) or the 'channels_last' data format
  (height, width, channels).

  The default input size for this model is 224x224.

  Arguments:
    include_top: whether to include the 3 fully-connected
      layers at the top of the network.
    weights: one of `None` (random initialization),
      `'imagenet'` (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)`
      (with `channels_last` data format)
      or `(3, 224, 224)` (with `channels_first` data format)).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
        the 4D tensor output of the
        last convolutional block.
      - `avg` means that global average pooling
        will be applied to the output of the
        last convolutional block, and thus
        the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
        be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape.
  """
  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')
  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=224,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor
  # Block 1
  x = layers.Conv2D(
      64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
          img_input)
  x = layers.Conv2D(
      64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

  # Block 2
  x = layers.Conv2D(
      128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
  x = layers.Conv2D(
      128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

  # Block 3
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
  x = layers.Conv2D(
      256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

  # Block 4
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

  # Block 5
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
  x = layers.Conv2D(
      512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

  if include_top:
    # Classification block
    x = layers.Flatten(name='flatten')(x)
    x = layers.Dense(4096, activation='relu', name='fc1')(x)
    x = layers.Dense(4096, activation='relu', name='fc2')(x)
    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='vgg16')

  # Load weights.
  if weights == 'imagenet':
    if include_top:
      weights_path = data_utils.get_file(
          'vgg16_weights_tf_dim_ordering_tf_kernels.h5',
          WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='64373286793e3c8b2b4e3219cbf3544b')
    else:
      weights_path = data_utils.get_file(
          'vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5',
          WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='6d6bbae143d832006294945121d1f1fc')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model


@keras_export('keras.applications.vgg16.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses a batch of images for input to the VGG16 model."""
  return imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='caffe')


@keras_export('keras.applications.vgg16.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the prediction results from the VGG16 model."""
  return imagenet_utils.decode_predictions(preds, top=top)
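

# Minimal usage sketch (not part of the exported API): it shows how `VGG16`,
# `preprocess_input` and `decode_predictions` fit together. The image path
# 'elephant.jpg' is only a placeholder; substitute any RGB image. Running this
# file directly downloads the pretrained ImageNet weights on first use.
if __name__ == '__main__':
  import numpy as np  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.preprocessing import image  # pylint: disable=g-import-not-at-top

  # Classification: the full model, including the fully-connected top.
  model = VGG16(weights='imagenet')
  img = image.load_img('elephant.jpg', target_size=(224, 224))
  x = np.expand_dims(image.img_to_array(img), axis=0)  # add batch dimension
  x = preprocess_input(x)  # 'caffe'-style preprocessing (BGR, mean-subtracted)
  # Print the top-5 ImageNet classes with their probabilities.
  print(decode_predictions(model.predict(x), top=5)[0])

  # Feature extraction: drop the classifier and pool to a 512-d vector.
  feature_extractor = VGG16(
      weights='imagenet', include_top=False, pooling='avg')
  features = feature_extractor.predict(x)  # shape (1, 512)
  print(features.shape)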