• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""Xception V1 model for Keras.
17
18On ImageNet, this model gets to a top-1 validation accuracy of 0.790
19and a top-5 validation accuracy of 0.945.
20
21Reference paper:
22  - [Xception: Deep Learning with Depthwise Separable Convolutions](
23      https://arxiv.org/abs/1610.02357) (CVPR 2017)
24
25"""
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30import os
31
32from tensorflow.python.keras import backend
33from tensorflow.python.keras import layers
34from tensorflow.python.keras.applications import imagenet_utils
35from tensorflow.python.keras.engine import training
36from tensorflow.python.keras.utils import data_utils
37from tensorflow.python.keras.utils import layer_utils
38from tensorflow.python.util.tf_export import keras_export
39
40
# URLs of the weight files that `Xception` fetches when called with
# `weights='imagenet'`: one for the model with the classification top and one
# for the model without it. Both live under the same storage prefix.
_WEIGHTS_URL_PREFIX = (
    'https://storage.googleapis.com/tensorflow/keras-applications/xception/')
TF_WEIGHTS_PATH = (
    _WEIGHTS_URL_PREFIX + 'xception_weights_tf_dim_ordering_tf_kernels.h5')
TF_WEIGHTS_PATH_NO_TOP = (
    _WEIGHTS_URL_PREFIX +
    'xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
47
48
@keras_export('keras.applications.xception.Xception',
              'keras.applications.Xception')
def Xception(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000):
  """Instantiates the Xception architecture.

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.
  Note that the default input image size for this model is 299x299.

  The network is built as 14 named blocks (`block1` ... `block14`):
  two plain convolutions, three downsampling residual blocks of
  separable convolutions, eight identical residual blocks with an
  identity shortcut, one more downsampling residual block, and two
  final separable convolutions.

  Arguments:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(299, 299, 3)`).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 71.
      E.g. `(150, 150, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False` (it is not consulted when
      `include_top` is `True`).
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
      Any other value leaves the output of the last
      convolutional block unchanged, like `None`.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True,
      and if no `weights` argument is specified.

  Returns:
    A Keras model instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape, or if `classes` is not 1000 while
      `weights='imagenet'` and `include_top` is true.
  """
  # `weights` must be None, 'imagenet', or a path that exists on disk.
  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  # The ImageNet weights that include the top layer are only accepted
  # together with `classes=1000`.
  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Determine proper input shape (default side 299, minimum side 71) for the
  # active image data format.
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=299,
      min_size=71,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  # Pick the model input: a fresh Input layer, the caller's Keras tensor used
  # as-is, or an Input layer wrapping a caller tensor that is not a Keras
  # tensor.
  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  # Axis handed to every BatchNormalization layer: 1 for 'channels_first',
  # the last axis otherwise.
  channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1

  # Block 1: two 3x3 convolutions (32 then 64 filters), each followed by
  # batch normalization and relu. Only the first one uses stride 2.
  x = layers.Conv2D(
      32, (3, 3),
      strides=(2, 2),
      use_bias=False,
      name='block1_conv1')(img_input)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv1_bn')(x)
  x = layers.Activation('relu', name='block1_conv1_act')(x)
  x = layers.Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block1_conv2_bn')(x)
  x = layers.Activation('relu', name='block1_conv2_act')(x)

  # Block 2 shortcut: 1x1 convolution with stride 2 and 128 filters, mirroring
  # the filter count of the main branch and the stride of `block2_pool` so the
  # two branches can be added. These shortcut layers carry no explicit name.
  residual = layers.Conv2D(
      128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 2 main branch: two 128-filter separable convolutions with batch
  # normalization, then a stride-2 max pool. Unlike blocks 3 and 4 there is no
  # leading relu here (block 1 already ends with one). The activation between
  # the two convolutions is named after the convolution it feeds
  # (`block2_sepconv2`).
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block2_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block2_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block2_pool')(x)
  x = layers.add([x, residual])

  # Block 3 shortcut: same pattern as block 2, with 256 filters.
  residual = layers.Conv2D(
      256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 3 main branch: (relu -> 256-filter separable conv -> batch norm)
  # twice, then a stride-2 max pool.
  x = layers.Activation('relu', name='block3_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block3_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block3_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block3_pool')(x)
  x = layers.add([x, residual])

  # Block 4 shortcut: same pattern again, with 728 filters.
  residual = layers.Conv2D(
      728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 4 main branch: (relu -> 728-filter separable conv -> batch norm)
  # twice, then a stride-2 max pool.
  x = layers.Activation('relu', name='block4_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block4_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
  x = layers.BatchNormalization(axis=channel_axis, name='block4_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block4_pool')(x)
  x = layers.add([x, residual])

  # Blocks 5-12: eight identical residual blocks. Each applies
  # (relu -> 728-filter separable conv -> batch norm) three times and adds the
  # block's own input back in; there is no pooling and no shortcut convolution.
  for i in range(8):
    residual = x
    prefix = 'block' + str(i + 5)

    x = layers.Activation('relu', name=prefix + '_sepconv1_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv1')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv1_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv2_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv2')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv2_bn')(x)
    x = layers.Activation('relu', name=prefix + '_sepconv3_act')(x)
    x = layers.SeparableConv2D(
        728, (3, 3),
        padding='same',
        use_bias=False,
        name=prefix + '_sepconv3')(x)
    x = layers.BatchNormalization(
        axis=channel_axis, name=prefix + '_sepconv3_bn')(x)

    x = layers.add([x, residual])

  # Block 13 shortcut: 1x1 stride-2 convolution with 1024 filters, matching
  # the output of the main branch below.
  residual = layers.Conv2D(
      1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
  residual = layers.BatchNormalization(axis=channel_axis)(residual)

  # Block 13 main branch: separable convolutions with 728 and then 1024
  # filters (each preceded by relu and followed by batch norm), then a
  # stride-2 max pool.
  x = layers.Activation('relu', name='block13_sepconv1_act')(x)
  x = layers.SeparableConv2D(
      728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block13_sepconv2_act')(x)
  x = layers.SeparableConv2D(
      1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block13_sepconv2_bn')(x)

  x = layers.MaxPooling2D((3, 3),
                          strides=(2, 2),
                          padding='same',
                          name='block13_pool')(x)
  x = layers.add([x, residual])

  # Block 14: two separable convolutions (1536 then 2048 filters) with no
  # shortcut. Here the relu comes after each batch norm rather than before the
  # convolution.
  x = layers.SeparableConv2D(
      1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv1_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv1_act')(x)

  x = layers.SeparableConv2D(
      2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
  x = layers.BatchNormalization(
      axis=channel_axis, name='block14_sepconv2_bn')(x)
  x = layers.Activation('relu', name='block14_sepconv2_act')(x)

  # Head: with the top, global average pooling followed by a softmax Dense
  # layer of `classes` units; without it, optional global pooling selected by
  # `pooling` (any value other than 'avg'/'max' leaves `x` untouched).
  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(classes, activation='softmax', name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D()(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input
  # Create model.
  model = training.Model(inputs, x, name='xception')

  # Load weights. For 'imagenet', the file matching `include_top` is resolved
  # through `data_utils.get_file` (with its expected hash and the 'models'
  # cache sub-directory); any other non-None value is used directly as the
  # path to load from.
  if weights == 'imagenet':
    if include_top:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels.h5',
          TF_WEIGHTS_PATH,
          cache_subdir='models',
          file_hash='0a58e3b7378bc2990ea3b43d5981f1f6')
    else:
      weights_path = data_utils.get_file(
          'xception_weights_tf_dim_ordering_tf_kernels_notop.h5',
          TF_WEIGHTS_PATH_NO_TOP,
          cache_subdir='models',
          file_hash='b0042744bf5b25fce3cb969f33bebb97')
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model
298
299
@keras_export('keras.applications.xception.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses inputs for the Xception model.

  Thin wrapper that forwards to `imagenet_utils.preprocess_input` with the
  preprocessing mode fixed to `'tf'`.

  Arguments:
    x: the input to preprocess, passed through unchanged to
      `imagenet_utils.preprocess_input` (presumably a batch of images as an
      array or tensor -- confirm against `imagenet_utils`).
    data_format: optional image data format, forwarded as-is. Handling of
      `None` is left to `imagenet_utils.preprocess_input`.

  Returns:
    Whatever `imagenet_utils.preprocess_input` returns for `mode='tf'`.
  """
  preprocessed = imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='tf')
  return preprocessed
303
304
@keras_export('keras.applications.xception.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the predictions of the Xception model.

  Thin wrapper that forwards to `imagenet_utils.decode_predictions`.

  Arguments:
    preds: the model predictions to decode, passed through unchanged.
    top: forwarded as the `top` argument (defaults to 5); presumably the
      number of highest-scoring classes to keep per sample -- confirm
      against `imagenet_utils`.

  Returns:
    Whatever `imagenet_utils.decode_predictions` returns.
  """
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
308