1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=invalid-name 16"""Inception V3 model for Keras. 17 18Reference paper: 19 - [Rethinking the Inception Architecture for Computer Vision]( 20 http://arxiv.org/abs/1512.00567) (CVPR 2016) 21""" 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import os 27 28from tensorflow.python.keras import backend 29from tensorflow.python.keras import layers 30from tensorflow.python.keras.applications import imagenet_utils 31from tensorflow.python.keras.engine import training 32from tensorflow.python.keras.utils import data_utils 33from tensorflow.python.keras.utils import layer_utils 34from tensorflow.python.util.tf_export import keras_export 35 36 37WEIGHTS_PATH = ( 38 'https://storage.googleapis.com/tensorflow/keras-applications/' 39 'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5') 40WEIGHTS_PATH_NO_TOP = ( 41 'https://storage.googleapis.com/tensorflow/keras-applications/' 42 'inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5') 43 44 45@keras_export('keras.applications.inception_v3.InceptionV3', 46 'keras.applications.InceptionV3') 47def InceptionV3(include_top=True, 48 weights='imagenet', 49 input_tensor=None, 50 input_shape=None, 51 pooling=None, 52 classes=1000): 53 """Instantiates the Inception v3 architecture. 54 55 Reference paper: 56 - [Rethinking the Inception Architecture for Computer Vision]( 57 http://arxiv.org/abs/1512.00567) (CVPR 2016) 58 59 Optionally loads weights pre-trained on ImageNet. 60 Note that the data format convention used by the model is 61 the one specified in the `tf.keras.backend.image_data_format()`. 62 63 Arguments: 64 include_top: Boolean, whether to include the fully-connected 65 layer at the top, as the last layer of the network. Default to `True`. 66 weights: One of `None` (random initialization), 67 `imagenet` (pre-training on ImageNet), 68 or the path to the weights file to be loaded. Default to `imagenet`. 69 input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`) 70 to use as image input for the model. `input_tensor` is useful for sharing 71 inputs between multiple different networks. Default to None. 72 input_shape: Optional shape tuple, only to be specified 73 if `include_top` is False (otherwise the input shape 74 has to be `(299, 299, 3)` (with `channels_last` data format) 75 or `(3, 299, 299)` (with `channels_first` data format). 76 It should have exactly 3 inputs channels, 77 and width and height should be no smaller than 75. 78 E.g. `(150, 150, 3)` would be one valid value. 79 `input_shape` will be ignored if the `input_tensor` is provided. 80 pooling: Optional pooling mode for feature extraction 81 when `include_top` is `False`. 82 - `None` (default) means that the output of the model will be 83 the 4D tensor output of the last convolutional block. 84 - `avg` means that global average pooling 85 will be applied to the output of the 86 last convolutional block, and thus 87 the output of the model will be a 2D tensor. 88 - `max` means that global max pooling will be applied. 89 classes: optional number of classes to classify images 90 into, only to be specified if `include_top` is True, and 91 if no `weights` argument is specified. Default to 1000. 92 93 Returns: 94 A Keras `tf.keras.Model` instance. 95 96 Raises: 97 ValueError: in case of invalid argument for `weights`, 98 or invalid input shape. 99 """ 100 if not (weights in {'imagenet', None} or os.path.exists(weights)): 101 raise ValueError('The `weights` argument should be either ' 102 '`None` (random initialization), `imagenet` ' 103 '(pre-training on ImageNet), ' 104 'or the path to the weights file to be loaded.') 105 106 if weights == 'imagenet' and include_top and classes != 1000: 107 raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 108 ' as true, `classes` should be 1000') 109 110 # Determine proper input shape 111 input_shape = imagenet_utils.obtain_input_shape( 112 input_shape, 113 default_size=299, 114 min_size=75, 115 data_format=backend.image_data_format(), 116 require_flatten=include_top, 117 weights=weights) 118 119 if input_tensor is None: 120 img_input = layers.Input(shape=input_shape) 121 else: 122 if not backend.is_keras_tensor(input_tensor): 123 img_input = layers.Input(tensor=input_tensor, shape=input_shape) 124 else: 125 img_input = input_tensor 126 127 if backend.image_data_format() == 'channels_first': 128 channel_axis = 1 129 else: 130 channel_axis = 3 131 132 x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding='valid') 133 x = conv2d_bn(x, 32, 3, 3, padding='valid') 134 x = conv2d_bn(x, 64, 3, 3) 135 x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 136 137 x = conv2d_bn(x, 80, 1, 1, padding='valid') 138 x = conv2d_bn(x, 192, 3, 3, padding='valid') 139 x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 140 141 # mixed 0: 35 x 35 x 256 142 branch1x1 = conv2d_bn(x, 64, 1, 1) 143 144 branch5x5 = conv2d_bn(x, 48, 1, 1) 145 branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 146 147 branch3x3dbl = conv2d_bn(x, 64, 1, 1) 148 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 149 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 150 151 branch_pool = layers.AveragePooling2D( 152 (3, 3), strides=(1, 1), padding='same')(x) 153 branch_pool = conv2d_bn(branch_pool, 32, 1, 1) 154 x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], 155 axis=channel_axis, 156 name='mixed0') 157 158 # mixed 1: 35 x 35 x 288 159 branch1x1 = conv2d_bn(x, 64, 1, 1) 160 161 branch5x5 = conv2d_bn(x, 48, 1, 1) 162 branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 163 164 branch3x3dbl = conv2d_bn(x, 64, 1, 1) 165 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 166 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 167 168 branch_pool = layers.AveragePooling2D( 169 (3, 3), strides=(1, 1), padding='same')(x) 170 branch_pool = conv2d_bn(branch_pool, 64, 1, 1) 171 x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], 172 axis=channel_axis, 173 name='mixed1') 174 175 # mixed 2: 35 x 35 x 288 176 branch1x1 = conv2d_bn(x, 64, 1, 1) 177 178 branch5x5 = conv2d_bn(x, 48, 1, 1) 179 branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) 180 181 branch3x3dbl = conv2d_bn(x, 64, 1, 1) 182 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 183 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 184 185 branch_pool = layers.AveragePooling2D( 186 (3, 3), strides=(1, 1), padding='same')(x) 187 branch_pool = conv2d_bn(branch_pool, 64, 1, 1) 188 x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], 189 axis=channel_axis, 190 name='mixed2') 191 192 # mixed 3: 17 x 17 x 768 193 branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding='valid') 194 195 branch3x3dbl = conv2d_bn(x, 64, 1, 1) 196 branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) 197 branch3x3dbl = conv2d_bn( 198 branch3x3dbl, 96, 3, 3, strides=(2, 2), padding='valid') 199 200 branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 201 x = layers.concatenate([branch3x3, branch3x3dbl, branch_pool], 202 axis=channel_axis, 203 name='mixed3') 204 205 # mixed 4: 17 x 17 x 768 206 branch1x1 = conv2d_bn(x, 192, 1, 1) 207 208 branch7x7 = conv2d_bn(x, 128, 1, 1) 209 branch7x7 = conv2d_bn(branch7x7, 128, 1, 7) 210 branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 211 212 branch7x7dbl = conv2d_bn(x, 128, 1, 1) 213 branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) 214 branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7) 215 branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) 216 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 217 218 branch_pool = layers.AveragePooling2D( 219 (3, 3), strides=(1, 1), padding='same')(x) 220 branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 221 x = layers.concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], 222 axis=channel_axis, 223 name='mixed4') 224 225 # mixed 5, 6: 17 x 17 x 768 226 for i in range(2): 227 branch1x1 = conv2d_bn(x, 192, 1, 1) 228 229 branch7x7 = conv2d_bn(x, 160, 1, 1) 230 branch7x7 = conv2d_bn(branch7x7, 160, 1, 7) 231 branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 232 233 branch7x7dbl = conv2d_bn(x, 160, 1, 1) 234 branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) 235 branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7) 236 branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) 237 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 238 239 branch_pool = layers.AveragePooling2D((3, 3), 240 strides=(1, 1), 241 padding='same')( 242 x) 243 branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 244 x = layers.concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], 245 axis=channel_axis, 246 name='mixed' + str(5 + i)) 247 248 # mixed 7: 17 x 17 x 768 249 branch1x1 = conv2d_bn(x, 192, 1, 1) 250 251 branch7x7 = conv2d_bn(x, 192, 1, 1) 252 branch7x7 = conv2d_bn(branch7x7, 192, 1, 7) 253 branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) 254 255 branch7x7dbl = conv2d_bn(x, 192, 1, 1) 256 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) 257 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 258 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) 259 branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) 260 261 branch_pool = layers.AveragePooling2D( 262 (3, 3), strides=(1, 1), padding='same')(x) 263 branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 264 x = layers.concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], 265 axis=channel_axis, 266 name='mixed7') 267 268 # mixed 8: 8 x 8 x 1280 269 branch3x3 = conv2d_bn(x, 192, 1, 1) 270 branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, strides=(2, 2), padding='valid') 271 272 branch7x7x3 = conv2d_bn(x, 192, 1, 1) 273 branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7) 274 branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1) 275 branch7x7x3 = conv2d_bn( 276 branch7x7x3, 192, 3, 3, strides=(2, 2), padding='valid') 277 278 branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 279 x = layers.concatenate([branch3x3, branch7x7x3, branch_pool], 280 axis=channel_axis, 281 name='mixed8') 282 283 # mixed 9: 8 x 8 x 2048 284 for i in range(2): 285 branch1x1 = conv2d_bn(x, 320, 1, 1) 286 287 branch3x3 = conv2d_bn(x, 384, 1, 1) 288 branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) 289 branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) 290 branch3x3 = layers.concatenate([branch3x3_1, branch3x3_2], 291 axis=channel_axis, 292 name='mixed9_' + str(i)) 293 294 branch3x3dbl = conv2d_bn(x, 448, 1, 1) 295 branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) 296 branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3) 297 branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1) 298 branch3x3dbl = layers.concatenate([branch3x3dbl_1, branch3x3dbl_2], 299 axis=channel_axis) 300 301 branch_pool = layers.AveragePooling2D((3, 3), 302 strides=(1, 1), 303 padding='same')( 304 x) 305 branch_pool = conv2d_bn(branch_pool, 192, 1, 1) 306 x = layers.concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], 307 axis=channel_axis, 308 name='mixed' + str(9 + i)) 309 if include_top: 310 # Classification block 311 x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 312 x = layers.Dense(classes, activation='softmax', name='predictions')(x) 313 else: 314 if pooling == 'avg': 315 x = layers.GlobalAveragePooling2D()(x) 316 elif pooling == 'max': 317 x = layers.GlobalMaxPooling2D()(x) 318 319 # Ensure that the model takes into account 320 # any potential predecessors of `input_tensor`. 321 if input_tensor is not None: 322 inputs = layer_utils.get_source_inputs(input_tensor) 323 else: 324 inputs = img_input 325 # Create model. 326 model = training.Model(inputs, x, name='inception_v3') 327 328 # Load weights. 329 if weights == 'imagenet': 330 if include_top: 331 weights_path = data_utils.get_file( 332 'inception_v3_weights_tf_dim_ordering_tf_kernels.h5', 333 WEIGHTS_PATH, 334 cache_subdir='models', 335 file_hash='9a0d58056eeedaa3f26cb7ebd46da564') 336 else: 337 weights_path = data_utils.get_file( 338 'inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5', 339 WEIGHTS_PATH_NO_TOP, 340 cache_subdir='models', 341 file_hash='bcbd6486424b2319ff4ef7d526e38f63') 342 model.load_weights(weights_path) 343 elif weights is not None: 344 model.load_weights(weights) 345 346 return model 347 348 349def conv2d_bn(x, 350 filters, 351 num_row, 352 num_col, 353 padding='same', 354 strides=(1, 1), 355 name=None): 356 """Utility function to apply conv + BN. 357 358 Arguments: 359 x: input tensor. 360 filters: filters in `Conv2D`. 361 num_row: height of the convolution kernel. 362 num_col: width of the convolution kernel. 363 padding: padding mode in `Conv2D`. 364 strides: strides in `Conv2D`. 365 name: name of the ops; will become `name + '_conv'` 366 for the convolution and `name + '_bn'` for the 367 batch norm layer. 368 369 Returns: 370 Output tensor after applying `Conv2D` and `BatchNormalization`. 371 """ 372 if name is not None: 373 bn_name = name + '_bn' 374 conv_name = name + '_conv' 375 else: 376 bn_name = None 377 conv_name = None 378 if backend.image_data_format() == 'channels_first': 379 bn_axis = 1 380 else: 381 bn_axis = 3 382 x = layers.Conv2D( 383 filters, (num_row, num_col), 384 strides=strides, 385 padding=padding, 386 use_bias=False, 387 name=conv_name)( 388 x) 389 x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x) 390 x = layers.Activation('relu', name=name)(x) 391 return x 392 393 394@keras_export('keras.applications.inception_v3.preprocess_input') 395def preprocess_input(x, data_format=None): 396 return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf') 397 398 399@keras_export('keras.applications.inception_v3.decode_predictions') 400def decode_predictions(preds, top=5): 401 return imagenet_utils.decode_predictions(preds, top=top) 402