• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Inception-ResNet V2 model for Keras.
17
18
19Reference paper:
20  - [Inception-v4, Inception-ResNet and the Impact of
21     Residual Connections on Learning](https://arxiv.org/abs/1602.07261)
22    (AAAI 2017)
23"""
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import os
29
30from tensorflow.python.keras import backend
31from tensorflow.python.keras import layers
32from tensorflow.python.keras.applications import imagenet_utils
33from tensorflow.python.keras.engine import training
34from tensorflow.python.keras.utils import data_utils
35from tensorflow.python.keras.utils import layer_utils
36from tensorflow.python.util.tf_export import keras_export
37
38
# URL prefix under which the pre-trained weight files are hosted; the
# weight file name is appended to it when `weights='imagenet'` is requested.
BASE_WEIGHT_URL = ('https://storage.googleapis.com/tensorflow/'
                   'keras-applications/inception_resnet_v2/')
41
42
@keras_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
              'keras.applications.InceptionResNetV2')
def InceptionResNetV2(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling=None,
                      classes=1000,
                      **kwargs):
  """Instantiates the Inception-ResNet v2 architecture.

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`
  (it is read through `backend.image_data_format()`).

  Arguments:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is `False`. Otherwise the input shape
      has to be `(299, 299, 3)` (with `'channels_last'` data format)
      or `(3, 299, 299)` (with `'channels_first'` data format).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 75.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the last convolutional block.
      - `'avg'` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `'max'` means that global max pooling will be applied.
      Any other value leaves the output untouched, exactly like `None`.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is `True`, and
      if no `weights` argument is specified.
    **kwargs: For backwards compatibility only. The single accepted key
      is `layers`; when given, it replaces the module-level `layers`
      object used to build the network (the replacement persists after
      the call returns). Any other key raises `ValueError`.

  Returns:
    A Keras `Model` instance.

  Raises:
    ValueError: in case of unknown keyword arguments, an invalid
      argument for `weights`, `classes != 1000` together with
      `weights='imagenet'` and `include_top=True`,
      or invalid input shape.
  """
  # Backwards compatibility hook: rebinding the module global means that
  # `conv2d_bn` and `inception_resnet_block` also pick up the override.
  if 'layers' in kwargs:
    global layers
    layers = kwargs.pop('layers')
  # Anything left over in `kwargs` is not a recognised option.
  if kwargs:
    raise ValueError('Unknown argument(s): %s' % (kwargs,))
  # `weights` must be a known keyword or a path that exists on disk.
  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  # The ImageNet classification head only exists for 1000 classes.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=75,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    # Wrap non-Keras tensors in an `Input` layer; Keras tensors are used
    # directly.
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # Stem block: 35 x 35 x 192
  x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid')
  x = conv2d_bn(x, 32, 3, padding='valid')
  x = conv2d_bn(x, 64, 3)
  x = layers.MaxPooling2D(3, strides=2)(x)
  x = conv2d_bn(x, 80, 1, padding='valid')
  x = conv2d_bn(x, 192, 3, padding='valid')
  x = layers.MaxPooling2D(3, strides=2)(x)

  # Mixed 5b (Inception-A block): 35 x 35 x 320
  branch_0 = conv2d_bn(x, 96, 1)
  branch_1 = conv2d_bn(x, 48, 1)
  branch_1 = conv2d_bn(branch_1, 64, 5)
  branch_2 = conv2d_bn(x, 64, 1)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_2 = conv2d_bn(branch_2, 96, 3)
  branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x)
  branch_pool = conv2d_bn(branch_pool, 64, 1)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  # Axis along which the parallel branches are concatenated; reused by all
  # `mixed_*` concatenations below.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
  x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches)

  # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
  # Block indices are 1-based, giving layer names `block35_1`..`block35_10`.
  for block_idx in range(1, 11):
    x = inception_resnet_block(
        x, scale=0.17, block_type='block35', block_idx=block_idx)

  # Mixed 6a (Reduction-A block): 17 x 17 x 1088
  branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 256, 3)
  branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches)

  # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
  for block_idx in range(1, 21):
    x = inception_resnet_block(
        x, scale=0.1, block_type='block17', block_idx=block_idx)

  # Mixed 7a (Reduction-B block): 8 x 8 x 2080
  branch_0 = conv2d_bn(x, 256, 1)
  branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')
  branch_1 = conv2d_bn(x, 256, 1)
  branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')
  branch_2 = conv2d_bn(x, 256, 1)
  branch_2 = conv2d_bn(branch_2, 288, 3)
  branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')
  branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
  branches = [branch_0, branch_1, branch_2, branch_pool]
  x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches)

  # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
  # The first nine blocks use scale 0.2 and the default activation; the
  # tenth is added separately with scale 1 and no activation.
  for block_idx in range(1, 10):
    x = inception_resnet_block(
        x, scale=0.2, block_type='block8', block_idx=block_idx)
  x = inception_resnet_block(
      x, scale=1., activation=None, block_type='block8', block_idx=10)

  # Final convolution block: 8 x 8 x 1536
  x = conv2d_bn(x, 1536, 1, name='conv_7b')

  if include_top:
    # Classification block
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
  else:
    # Optional global pooling for feature extraction; no pooling layer is
    # added for any other `pooling` value.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name='inception_resnet_v2')

  # Load weights.
  if weights == 'imagenet':
    # Fetch the weight file matching the chosen head (with or without the
    # classification top) through `data_utils.get_file`.
    if include_top:
      fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='e693bd0210a403b3192acc6073ad2e96')
    else:
      fname = ('inception_resnet_v2_weights_'
               'tf_dim_ordering_tf_kernels_notop.h5')
      weights_path = data_utils.get_file(
          fname,
          BASE_WEIGHT_URL + fname,
          cache_subdir='models',
          file_hash='d19885ff4a710c122648d3b5c3b684e4')
    model.load_weights(weights_path)
  elif weights is not None:
    # `weights` is a user-supplied path (validated to exist above).
    model.load_weights(weights)

  return model
231
232
def conv2d_bn(x,
              filters,
              kernel_size,
              strides=1,
              padding='same',
              activation='relu',
              use_bias=False,
              name=None):
  """Applies a convolution, then optional batch norm and activation.

  Batch normalization is inserted only when the convolution has no bias
  (`use_bias=False`); the activation layer is inserted only when
  `activation` is not `None`.

  Arguments:
    x: input tensor.
    filters: number of filters passed to `Conv2D`.
    kernel_size: kernel size passed to `Conv2D`.
    strides: strides passed to `Conv2D`.
    padding: padding mode passed to `Conv2D`.
    activation: activation applied after the convolution (and after the
        batch norm, when present), or `None` for no activation layer.
    use_bias: whether `Conv2D` uses a bias; when true, batch
        normalization is skipped.
    name: name of the `Conv2D` layer; the batch norm layer is named
        `name + '_bn'` and the activation layer `name + '_ac'`. When
        `None`, all layers are left unnamed.

  Returns:
    Output tensor of the last layer that was applied.
  """
  conv_layer = layers.Conv2D(
      filters,
      kernel_size,
      strides=strides,
      padding=padding,
      use_bias=use_bias,
      name=name)
  out = conv_layer(x)

  is_named = name is not None

  if not use_bias:
    # Normalize over the channel axis of the active data format.
    if backend.image_data_format() == 'channels_first':
      norm_axis = 1
    else:
      norm_axis = 3
    norm_layer = layers.BatchNormalization(
        axis=norm_axis,
        scale=False,
        name=name + '_bn' if is_named else None)
    out = norm_layer(out)

  if activation is None:
    return out

  act_layer = layers.Activation(
      activation, name=name + '_ac' if is_named else None)
  return act_layer(out)
273
274
def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
  """Adds one Inception-ResNet block on top of `x`.

  Three block variants from the paper are supported, selected by
  `block_type` (the block names follow the official TF-slim implementation):
  - Inception-ResNet-A: `block_type='block35'`
  - Inception-ResNet-B: `block_type='block17'`
  - Inception-ResNet-C: `block_type='block8'`

  The parallel branches are concatenated along the channel axis, projected
  back to the channel count of `x` by a 1x1 convolution with bias and no
  activation, scaled, and added to `x`.

  Arguments:
    x: input tensor.
    scale: factor applied to the residual before it is added to the
      shortcut. With `r` the output of the residual branch, the block
      computes `x + scale * r`.
    block_type: `'block35'`, `'block17'` or `'block8'`; selects the
      structure of the residual branch.
    block_idx: value used to build layer names; it is converted with
      `str()`. All named layers of the block share the prefix
      `block_type + '_' + str(block_idx)`, e.g. `'block35_1'` for
      `block_type='block35', block_idx=1`.
    activation: activation applied to the block output. When `None`,
      no activation layer is added.

  Returns:
      Output tensor for the block.

  Raises:
    ValueError: if `block_type` is not one of `'block35'`,
      `'block17'` or `'block8'`.
  """
  # Each entry describes one branch as a sequence of (filters, kernel_size)
  # convolutions applied one after another, starting from `x`.
  if block_type == 'block35':
    layout = (
        ((32, 1),),
        ((32, 1), (32, 3)),
        ((32, 1), (48, 3), (64, 3)),
    )
  elif block_type == 'block17':
    layout = (
        ((192, 1),),
        ((128, 1), (160, [1, 7]), (192, [7, 1])),
    )
  elif block_type == 'block8':
    layout = (
        ((192, 1),),
        ((192, 1), (224, [1, 3]), (256, [3, 1])),
    )
  else:
    raise ValueError('Unknown Inception-ResNet block type. '
                     'Expects "block35", "block17" or "block8", '
                     'but got: ' + str(block_type))

  # Build the branches in declaration order so that layers are created in
  # a fixed sequence.
  branches = []
  for tower_spec in layout:
    tower = x
    for n_filters, kernel in tower_spec:
      tower = conv2d_bn(tower, n_filters, kernel)
    branches.append(tower)

  prefix = block_type + '_' + str(block_idx)
  if backend.image_data_format() == 'channels_first':
    channel_axis = 1
  else:
    channel_axis = 3

  concat_layer = layers.Concatenate(axis=channel_axis, name=prefix + '_mixed')
  mixed = concat_layer(branches)

  # Project the concatenated branches back to the channel count of `x`.
  in_shape = backend.int_shape(x)
  residual = conv2d_bn(
      mixed,
      in_shape[channel_axis],
      1,
      activation=None,
      use_bias=True,
      name=prefix + '_conv')

  # Scaled residual addition: shortcut + scale * residual.
  merge_layer = layers.Lambda(
      lambda inputs, scale: inputs[0] + inputs[1] * scale,
      output_shape=in_shape[1:],
      arguments={'scale': scale},
      name=prefix)
  merged = merge_layer([x, residual])

  if activation is None:
    return merged
  return layers.Activation(activation, name=prefix + '_ac')(merged)
361
362
@keras_export('keras.applications.inception_resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses `x` for this model.

  Forwards to `imagenet_utils.preprocess_input` with the mode fixed to
  `'tf'`.

  Arguments:
    x: the input to preprocess, passed through unchanged to
      `imagenet_utils.preprocess_input`.
    data_format: optional data format, forwarded as-is.

  Returns:
    Whatever `imagenet_utils.preprocess_input` returns.
  """
  preprocessed = imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='tf')
  return preprocessed
366
367
@keras_export('keras.applications.inception_resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes model predictions.

  Forwards to `imagenet_utils.decode_predictions`.

  Arguments:
    preds: the predictions to decode, passed through unchanged.
    top: forwarded as the `top` argument (defaults to 5).

  Returns:
    Whatever `imagenet_utils.decode_predictions` returns.
  """
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
371