# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Inception-ResNet V2 model for Keras.

This module defines the model constructor `InceptionResNetV2`, the two
layer-building helpers it uses (`conv2d_bn` and `inception_resnet_block`),
and thin wrappers around the shared ImageNet pre-/post-processing utilities.

Reference paper:
  - [Inception-v4, Inception-ResNet and the Impact of
     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
    (AAAI 2017)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.keras import backend
from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.util.tf_export import keras_export


# URL prefix under which the pre-trained ImageNet weight files are hosted;
# the file name is appended to it when downloading.
BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/'
                   'keras-applications/inception_resnet_v2/')


@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
              'keras.applications.InceptionResNetV2')
def InceptionResNetV2(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling=None,
                      classes=1000,
                      **kwargs):
  """Instantiates the Inception-ResNet v2 architecture.

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Arguments:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is `False` (otherwise the input shape
      has to be `(299, 299, 3)` (with `'channels_last'` data format)
      or `(3, 299, 299)` (with `'channels_first'` data format)).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 75.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the last convolutional block.
      - `'avg'` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `'max'` means that global max pooling will be applied.
      Any other value is not rejected; it behaves like `None`
      (no pooling layer is added).
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is `True`, and
      if no `weights` argument is specified.
    **kwargs: For backwards compatibility only. The only accepted key is
      `layers`; when given, its value replaces the module-level `layers`
      global used by this function and by the helpers in this module, and
      the replacement persists after the call returns.

  Returns:
    A Keras `Model` instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape; if `kwargs` contains any key other than
      `layers`; or if `weights='imagenet'` and `include_top=True` are
      combined with `classes != 1000`.
  """
  if 'layers' in kwargs:
    # Rebinds the module-level name, so `conv2d_bn` and
    # `inception_resnet_block` (which read the same global) use the injected
    # object too.
    global layers
    layers = kwargs.pop('layers')
  if kwargs:
    raise ValueError('Unknown argument(s): %s' % (kwargs,))
  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=75,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    # A tensor that is not already a Keras tensor is wrapped in an `Input`
    # layer; a Keras tensor is used as the model input directly.
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # NOTE(review): most layers below are created without explicit names and
  # pre-trained weights are loaded into the finished model further down;
  # presumably the weight files depend on this exact creation order -- confirm
  # before reordering or regrouping any of these calls.
  # The spatial sizes quoted in the section comments correspond to the default
  # 299 x 299 input.

  # Stem block: 35 x 35 x 192
  x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid')
  x = conv2d_bn(x, 32, 3, padding='valid')
  x = conv2d_bn(x, 64, 3)
  x = layers.MaxPooling2D(3, strides=2)(x)
  x = conv2d_bn(x, 80, 1, padding='valid')
  x = conv2d_bn(x, 192, 3, padding='valid')
  x = layers.MaxPooling2D(3, strides=2)(x)

  # Mixed 5b (Inception-A block): 35 x 35 x 320
  # Four parallel branches (96 + 64 + 96 + 64 = 320 channels) concatenated
  # along the channel axis.
  branch_0 = conv2d_bn(x, 96, 1)
  branch_1 = conv2d_bn(x, 48, 1)
  branch_1 = conv2d_bn(branch_1, 64, 5)
  branch_2 = conv2d_bn(x, 64, 1)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x)
  branch_pool = conv2d_bn(branch_pool, 64, 1)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  # Channel axis follows the configured data format; reused by every
  # `Concatenate` in the rest of this function.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
  x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches)

  # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
  # Blocks are numbered 1..10, giving layer-name prefixes `block35_1` ...
  # `block35_10`.
  for block_idx in range(1, 11):
    x = inception_resnet_block(
        x, scale=0.17, block_type='block35', block_idx=block_idx)

  # Mixed 6a (Reduction-A block): 17 x 17 x 1088
  # Two strided conv branches (384 + 384) plus a max-pool branch that keeps
  # the 320 input channels.
  branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 256, 3)
  branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches)

  # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
  for block_idx in range(1, 21):
    x = inception_resnet_block(
        x, scale=0.1, block_type='block17', block_idx=block_idx)

  # Mixed 7a (Reduction-B block): 8 x 8 x 2080
  # Three strided conv branches (384 + 288 + 320) plus a max-pool branch that
  # keeps the 1088 input channels.
  branch_0 = conv2d_bn(x, 256, 1)
  branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')
  branch_2 = conv2d_bn(x, 256, 1)
  branch_2 = conv2d_bn(branch_2, 288, 3)
  branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches)

  # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
  # The first nine blocks use scale=0.2 with the default ReLU; the tenth is
  # built separately with scale=1. and no trailing activation.
  for block_idx in range(1, 10):
    x = inception_resnet_block(
        x, scale=0.2, block_type='block8', block_idx=block_idx)
  x = inception_resnet_block(
      x, scale=1., activation=None, block_type='block8', block_idx=10)

  # Final convolution block: 8 x 8 x 1536
  x = conv2d_bn(x, 1536, 1, name='conv_7b')

  if include_top:
    # Classification block
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
  else:
    # Optional global pooling for feature extraction; any `pooling` value
    # other than 'avg' / 'max' leaves the 4D feature map untouched.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name='inception_resnet_v2')

  # Load weights.
  if weights == 'imagenet':
    # Separate weight files exist for the model with and without the
    # classification top; `data_utils.get_file` is given the file name, its
    # URL, the 'models' cache sub-directory and the expected file hash.
    if include_top:
      fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='e693bd0210a403b3192acc6073ad2e96')
    else:
      fname = ('inception_resnet_v2_weights_'
               'tf_dim_ordering_tf_kernels_notop.h5')
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='d19885ff4a710c122648d3b5c3b684e4')
    model.load_weights(weights_path)
  elif weights is not None:
    # `weights` is a user-supplied path (its existence was checked above).
    model.load_weights(weights)

  return model


def conv2d_bn(x,
              filters,
              kernel_size,
              strides=1,
              padding='same',
              activation='relu',
              use_bias=False,
              name=None):
  """Utility function to apply conv, then optional BN and activation.

  The convolution itself is created without an activation; batch
  normalization is added only when `use_bias` is `False`, and the activation
  is added as a separate `Activation` layer only when `activation` is not
  `None`.

  Arguments:
    x: input tensor.
    filters: filters in `Conv2D`.
    kernel_size: kernel size as in `Conv2D`.
    strides: strides in `Conv2D`.
    padding: padding mode in `Conv2D`.
    activation: activation applied after the convolution (and after batch
      normalization, when present); `None` applies no activation.
    use_bias: whether to use a bias in `Conv2D`. When `True`, the batch
      normalization layer is skipped.
    name: name of the `Conv2D` layer; the batch norm layer is named
      `name + '_bn'` and the activation layer `name + '_ac'`. When `None`,
      all three layers are left unnamed.

  Returns:
    Output tensor after applying `Conv2D` and, depending on the arguments,
    `BatchNormalization` and `Activation`.
  """
  x = layers.Conv2D(
      filters,
      kernel_size,
      strides=strides,
      padding=padding,
      use_bias=use_bias,
      name=name)(
          x)
  if not use_bias:
    # Normalize over the channel axis of the configured data format.
    bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3
    bn_name = None if name is None else name + '_bn'
    x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
  if activation is not None:
    ac_name = None if name is None else name + '_ac'
    x = layers.Activation(activation, name=ac_name)(x)
  return x


def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
  """Adds an Inception-ResNet block.

  This function builds 3 types of Inception-ResNet blocks mentioned
  in the paper, controlled by the `block_type` argument (which is the
  block name used in the official TF-slim implementation):
    - Inception-ResNet-A: `block_type='block35'`
    - Inception-ResNet-B: `block_type='block17'`
    - Inception-ResNet-C: `block_type='block8'`

  Arguments:
    x: input tensor.
    scale: scaling factor to scale the residuals (i.e., the output of
      passing `x` through an inception module) before adding them
      to the shortcut branch.
      Let `r` be the output from the residual branch,
      the output of this block will be `x + scale * r`.
    block_type: `'block35'`, `'block17'` or `'block8'`, determines
      the network structure in the residual branch.
    block_idx: an `int` used for generating layer names.
      The Inception-ResNet blocks
      are repeated many times in this network.
      We use `block_idx` to identify
      each of the repetitions. The layer names of a block share the
      prefix `block_type + '_' + str(block_idx)`. For example, the model
      in this module numbers its repetitions starting from 1, so its first
      Inception-ResNet-A block has `block_type='block35', block_idx=1`
      and the common prefix `'block35_1'`.
    activation: activation function to use at the end of the block
      (see [activations](../activations.md)).
      When `activation=None`, no activation is applied
      (i.e., "linear" activation: `a(x) = x`).

  Returns:
      Output tensor for the block.

  Raises:
    ValueError: if `block_type` is not one of `'block35'`,
      `'block17'` or `'block8'`.
  """
  # Build the parallel branches of the residual path for the requested type.
  if block_type == 'block35':
    branch_0 = conv2d_bn(x, 32, 1)
    branch_1 = conv2d_bn(x, 32, 1)
    branch_1 = conv2d_bn(branch_1, 32, 3)
    branch_2 = conv2d_bn(x, 32, 1)
    branch_2 = conv2d_bn(branch_2, 48, 3)
    branch_2 = conv2d_bn(branch_2, 64, 3)
    branches = [branch_0, branch_1, branch_2]
  elif block_type == 'block17':
    branch_0 = conv2d_bn(x, 192, 1)
    branch_1 = conv2d_bn(x, 128, 1)
    # A 1x7 convolution followed by a 7x1 convolution.
    branch_1 = conv2d_bn(branch_1, 160, [1, 7])
    branch_1 = conv2d_bn(branch_1, 192, [7, 1])
    branches = [branch_0, branch_1]
  elif block_type == 'block8':
    branch_0 = conv2d_bn(x, 192, 1)
    branch_1 = conv2d_bn(x, 192, 1)
    # A 1x3 convolution followed by a 3x1 convolution.
    branch_1 = conv2d_bn(branch_1, 224, [1, 3])
    branch_1 = conv2d_bn(branch_1, 256, [3, 1])
    branches = [branch_0, branch_1]
  else:
    raise ValueError('Unknown Inception-ResNet block type. '
                     'Expects "block35", "block17" or "block8", '
                     'but got: ' + str(block_type))

  block_name = block_type + '_' + str(block_idx)
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
  mixed = layers.Concatenate(
      axis=channel_axis, name=block_name + '_mixed')(
          branches)
  # 1x1 convolution back to the channel count of `x` so that it can be added
  # to the shortcut. `use_bias=True` makes `conv2d_bn` skip batch
  # normalization, and `activation=None` skips the activation.
  up = conv2d_bn(
      mixed,
      backend.int_shape(x)[channel_axis],
      1,
      activation=None,
      use_bias=True,
      name=block_name + '_conv')

  # Residual sum `x + scale * up`; `scale` is passed to the lambda through
  # `arguments`, and the output shape is that of `x` without the batch axis.
  x = layers.Lambda(
      lambda inputs, scale: inputs[0] + inputs[1] * scale,
      output_shape=backend.int_shape(x)[1:],
      arguments={'scale': scale},
      name=block_name)([x, up])
  if activation is not None:
    x = layers.Activation(activation, name=block_name + '_ac')(x)
  return x


@keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses input images for this model.

  Forwards to `imagenet_utils.preprocess_input` with `mode='tf'`.
  NOTE(review): in Keras, `mode='tf'` is documented as scaling pixel values
  to the range [-1, 1] sample-wise -- confirm against `imagenet_utils`.

  Arguments:
    x: the input to preprocess, passed through unchanged to
      `imagenet_utils.preprocess_input`.
    data_format: passed through as `data_format`; defaults to `None`.

  Returns:
    Whatever `imagenet_utils.preprocess_input` returns for these arguments.
  """
  return imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')


@keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes model predictions via `imagenet_utils.decode_predictions`.

  Arguments:
    preds: the predictions to decode, passed through unchanged.
    top: passed through as `top`; defaults to 5.

  Returns:
    Whatever `imagenet_utils.decode_predictions` returns for these arguments.
  """
  return imagenet_utils.decode_predictions(preds, top=top)