• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Xception V1 model for Keras.
17
18On ImageNet, this model gets to a top-1 validation accuracy of 0.790
19and a top-5 validation accuracy of 0.945.
20
21Reference:
22  - [Xception: Deep Learning with Depthwise Separable Convolutions](
23      https://arxiv.org/abs/1610.02357) (CVPR 2017)
24"""
25
26from tensorflow.python.keras import backend
27from tensorflow.python.keras.applications import imagenet_utils
28from tensorflow.python.keras.engine import training
29from tensorflow.python.keras.layers import VersionAwareLayers
30from tensorflow.python.keras.utils import data_utils
31from tensorflow.python.keras.utils import layer_utils
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.util.tf_export import keras_export
34
35
# Download locations of the pretrained ImageNet checkpoints: the first is
# fetched when the classification head is included, the second ("notop")
# when it is left out.
TF_WEIGHTS_PATH = (
    'https://storage.googleapis.com/tensorflow/'
    'keras-applications/xception/'
    'xception_weights_tf_dim_ordering_tf_kernels.h5')
TF_WEIGHTS_PATH_NO_TOP = (
    'https://storage.googleapis.com/tensorflow/'
    'keras-applications/xception/'
    'xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
42
# Module-wide layer namespace: every `layers.<Name>` lookup in this file goes
# through this `VersionAwareLayers` instance.
# NOTE(review): presumably this resolves each layer class to the TF1 or TF2
# implementation at access time -- confirm against `VersionAwareLayers`.
layers = VersionAwareLayers()
44
45
@keras_export('keras.applications.xception.Xception',
              'keras.applications.Xception')
def Xception(
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the Xception architecture.

  Reference:
  - [Xception: Deep Learning with Depthwise Separable Convolutions](
      https://arxiv.org/abs/1610.02357) (CVPR 2017)

  For image classification use cases, see
  [this page for detailed examples](
    https://keras.io/api/applications/#usage-examples-for-image-classification-models).

  For transfer learning use cases, make sure to read the
  [guide to transfer learning & fine-tuning](
    https://keras.io/guides/transfer_learning/).

  The default input image size for this model is 299x299.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For Xception, call `tf.keras.applications.xception.preprocess_input` on your
  inputs before passing them to the model.
  `xception.preprocess_input` will scale input pixels between -1 and 1.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(299, 299, 3)`).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 71.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True,
      and if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.
      When loading pretrained weights, `classifier_activation` can only
      be `None` or `"softmax"`.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: if `weights` is neither `None`, `'imagenet'` nor the path of
      an existing file, or if `weights='imagenet'` and `include_top=True`
      are combined with `classes != 1000`. The helpers called here
      (`imagenet_utils.obtain_input_shape`,
      `imagenet_utils.validate_activation`) perform further validation that
      is not visible from this function.
  """
  # Reject anything that is not None, 'imagenet' or an existing file path.
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  # The pretrained classification head only exists for 1000 classes.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=71,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  # Use the caller's tensor as the model input when one is given: a Keras
  # tensor is taken as-is, anything else is wrapped in an `Input` layer.
  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # Axis that batch normalization normalizes over, per the backend data format.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  # Block 1: two regular 3x3 convolutions (32 then 64 filters), each followed
  # by batch norm and ReLU. Only the first one is strided; neither passes
  # `padding`, so the layer's default padding applies.
  x = layers.Conv2D(
      32, (3, 3),
      strides=(2, 2),
      use_bias=False,
      name='block1_conv1')(img_input)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
  x = layers.Activation('relu', name='block1_conv1_act')(x)
  x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
  x = layers.Activation('relu', name='block1_conv2_act')(x)

  # Shortcut branch for block 2: strided 1x1 convolution + batch norm.
  # These shortcut layers (here and in blocks 3, 4 and 13) are created
  # without an explicit `name`.
  residual = layers.Conv2D(
      128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 2: two 128-filter separable convolutions, then strided max pooling.
  # Unlike blocks 3+, there is no ReLU before the first separable conv.
  # NOTE: the ReLU between sepconv1 and sepconv2 is named
  # 'block2_sepconv2_act'; no layer called 'block2_sepconv1_act' exists.
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block2_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block2_pool')(x)
  x = layers.add([x, residual])

  # Block 3: same layout with 256 filters; each separable conv is now preceded
  # by a ReLU.
  residual = layers.Conv2D(
      256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block3_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block3_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block3_pool')(x)
  x = layers.add([x, residual])

  # Block 4: same layout with 728 filters.
  residual = layers.Conv2D(
      728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block4_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block4_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block4_pool')(x)
  x = layers.add([x, residual])

  # Blocks 5 through 12: eight identical blocks of three
  # ReLU -> SeparableConv2D(728) -> BatchNorm units. The block input itself is
  # added back as the shortcut (no convolution on the shortcut, no pooling).
  for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)

    x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv1')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv1_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv2')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv2_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv3')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv3_bn')(x)

    x = layers.add([x, residual])

  # Block 13: separable convs with 728 then 1024 filters, strided max pooling,
  # and a strided 1x1 convolution (1024 filters) on the shortcut.
  residual = layers.Conv2D(
      1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  x = layers.Activation('relu', name='block13_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block13_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block13_pool')(x)
  x = layers.add([x, residual])

  # Block 14: two separable convs (1536 then 2048 filters) with the ReLU
  # placed after each batch norm; this block has no shortcut connection.
  x = layers.SeparableConv2D(
      1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv1_act')(x)

  x = layers.SeparableConv2D(
      2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv2_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv2_act')(x)

  if include_top:
    # Classification head: global average pooling followed by a dense layer.
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    # Checks `classifier_activation` against `weights` before the head is
    # built.
    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    # Headless model: optionally pool the feature map. Any `pooling` value
    # other than 'avg' or 'max' leaves `x` untouched.
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='xception')

  # Load weights.
  if weights == 'imagenet':
    # Pick the checkpoint that matches whether the head was built.
    # NOTE(review): `get_file` presumably downloads into the 'models' cache
    # sub-directory and checks `file_hash` -- confirm against `data_utils`.
    if include_top:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels.h5',
          TF_WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='0a58e3b7378bc2990ea3b43d5981f1f6')
    else:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
          TF_WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='b0042744bf5b25fce3cb969f33bebb97')
    model.load_weights(weights_path)
  elif weights is not None:
    # Any other non-None value was validated above as an existing file path.
    model.load_weights(weights)

  return model
314
315
@keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None):
  # Xception always delegates to the shared ImageNet helper in 'tf' mode; only
  # the data format is left to the caller. The runtime docstring of this
  # function is assigned at the bottom of the module.
  helper_kwargs = {'data_format': data_format, 'mode': 'tf'}
  return imagenet_utils.preprocess_input(x, **helper_kwargs)
319
320
@keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5):
  # Straight pass-through to the shared ImageNet decoder. The runtime
  # docstring of this function is assigned at the bottom of the module.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
324
325
# Neither wrapper carries its own docstring; both borrow documentation from
# `imagenet_utils`. `decode_predictions` reuses the helper's docstring as-is,
# while `preprocess_input` fills the shared template with an empty `mode`
# section plus the TF-specific return text and the common error text.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,
    mode='')
331