• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""DenseNet models for Keras.
17
18Reference paper:
19  - [Densely Connected Convolutional Networks]
20    (https://arxiv.org/abs/1608.06993) (CVPR 2017 Best Paper Award)
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import os
27
28from tensorflow.python.keras import backend
29from tensorflow.python.keras import layers
30from tensorflow.python.keras.applications import imagenet_utils
31from tensorflow.python.keras.engine import training
32from tensorflow.python.keras.utils import data_utils
33from tensorflow.python.keras.utils import layer_utils
34from tensorflow.python.util.tf_export import keras_export
35
36
# Root URL of the bucket that hosts the pre-trained DenseNet weight files.
BASE_WEIGHTS_PATH = ('https://storage.googleapis.com/tensorflow/'
                     'keras-applications/densenet/')
# Backward-compatible alias: the constant was first published under this
# misspelled name, so it is kept for any code that still refers to it.
BASE_WEIGTHS_PATH = BASE_WEIGHTS_PATH

# Download URLs, one pair per architecture: the regular weights file and the
# `_notop` file (selected in `DenseNet` when `include_top` is False).
DENSENET121_WEIGHT_PATH = (
    BASE_WEIGHTS_PATH + 'densenet121_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET121_WEIGHT_PATH_NO_TOP = (
    BASE_WEIGHTS_PATH +
    'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5')
DENSENET169_WEIGHT_PATH = (
    BASE_WEIGHTS_PATH + 'densenet169_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET169_WEIGHT_PATH_NO_TOP = (
    BASE_WEIGHTS_PATH +
    'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5')
DENSENET201_WEIGHT_PATH = (
    BASE_WEIGHTS_PATH + 'densenet201_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET201_WEIGHT_PATH_NO_TOP = (
    BASE_WEIGHTS_PATH +
    'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5')
54
55
def dense_block(x, blocks, name, growth_rate=32):
  """A dense block.

  Chains `blocks` building blocks (see `conv_block`). Each building block
  concatenates its newly computed feature maps onto its own input, so the
  tensor handed to the next building block keeps growing along the channel
  axis.

  Arguments:
    x: input tensor.
    blocks: integer, the number of building blocks. With 0 the input is
      returned unchanged.
    name: string, block label. Building block `i` (1-based) is labelled
      `name + '_block' + str(i)`.
    growth_rate: integer, number of feature maps each building block adds.
      Defaults to 32, the value that used to be hard-coded here.

  Returns:
    Output tensor for the block.
  """
  for index in range(1, blocks + 1):
    x = conv_block(x, growth_rate, name=name + '_block' + str(index))
  return x
70
71
def transition_block(x, reduction, name):
  """A transition block.

  Applies batch normalization and ReLU, then a 1x1 convolution whose filter
  count is the input channel count scaled by `reduction`, and finally
  average pooling with pool size 2 and stride 2.

  Arguments:
    x: input tensor.
    reduction: float, compression rate at transition layers. The
      convolution gets `int(channels * reduction)` filters.
    name: string, block label, used as the prefix of every layer name.

  Returns:
    Output tensor for the block.
  """
  if backend.image_data_format() == 'channels_last':
    channel_axis = 3
  else:
    channel_axis = 1

  normalize = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_bn')
  out = normalize(x)

  activate = layers.Activation('relu', name=name + '_relu')
  out = activate(out)

  # The filter count is derived from the static channel dimension of the
  # activated tensor.
  n_filters = int(backend.int_shape(out)[channel_axis] * reduction)
  compress = layers.Conv2D(
      n_filters, 1, use_bias=False, name=name + '_conv')
  out = compress(out)

  downsample = layers.AveragePooling2D(2, strides=2, name=name + '_pool')
  return downsample(out)
96
97
def conv_block(x, growth_rate, name):
  """A building block for a dense block.

  Runs `x` through BN -> ReLU -> 1x1 convolution (`4 * growth_rate`
  filters) -> BN -> ReLU -> 3x3 'same' convolution (`growth_rate` filters)
  and concatenates the result onto `x` along the channel axis.

  Arguments:
    x: input tensor.
    growth_rate: integer, growth rate at dense layers (it is used as a
      convolution filter count).
    name: string, block label, used as the prefix of every layer name.

  Returns:
    Output tensor for the block.
  """
  if backend.image_data_format() == 'channels_last':
    channel_axis = 3
  else:
    channel_axis = 1

  # 1x1 convolution stage.
  norm_0 = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_0_bn')
  branch = norm_0(x)
  relu_0 = layers.Activation('relu', name=name + '_0_relu')
  branch = relu_0(branch)
  conv_1 = layers.Conv2D(
      4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')
  branch = conv_1(branch)

  # 3x3 convolution stage producing the `growth_rate` new feature maps.
  norm_1 = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_1_bn')
  branch = norm_1(branch)
  relu_1 = layers.Activation('relu', name=name + '_1_relu')
  branch = relu_1(branch)
  conv_2 = layers.Conv2D(
      growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')
  branch = conv_2(branch)

  # The block input is kept and the new feature maps are appended to it.
  merge = layers.Concatenate(axis=channel_axis, name=name + '_concat')
  return merge([x, branch])
126
127
# Maps the dense-block layout of each published variant to its model name.
# Keys are tuples so that list and tuple `blocks` arguments both match.
_DENSENET_VARIANTS = {
    (6, 12, 24, 16): 'densenet121',
    (6, 12, 32, 32): 'densenet169',
    (6, 12, 48, 32): 'densenet201',
}


def _imagenet_weights_spec(variant, include_top):
  """Returns `(fname, origin, file_hash)` for a variant's ImageNet weights.

  Arguments:
    variant: one of the model names stored in `_DENSENET_VARIANTS`.
    include_top: whether the regular weights file (True) or the `_notop`
      file (False) is wanted.

  Returns:
    Tuple of cache file name, download URL and expected file hash, in the
    form expected by `data_utils.get_file`.
  """
  # The tables are built at call time so that the URL constants are only
  # looked up when a download is actually requested.
  if include_top:
    suffix = '_weights_tf_dim_ordering_tf_kernels.h5'
    origins_and_hashes = {
        'densenet121': (DENSENET121_WEIGHT_PATH,
                        '9d60b8095a5708f2dcce2bca79d332c7'),
        'densenet169': (DENSENET169_WEIGHT_PATH,
                        'd699b8f76981ab1b30698df4c175e90b'),
        'densenet201': (DENSENET201_WEIGHT_PATH,
                        '1ceb130c1ea1b78c3bf6114dbdfd8807'),
    }
  else:
    suffix = '_weights_tf_dim_ordering_tf_kernels_notop.h5'
    origins_and_hashes = {
        'densenet121': (DENSENET121_WEIGHT_PATH_NO_TOP,
                        '30ee3e1110167f948a6b9946edeeb738'),
        'densenet169': (DENSENET169_WEIGHT_PATH_NO_TOP,
                        'b8c4d4c20dd625c148057b9ff1c1176b'),
        'densenet201': (DENSENET201_WEIGHT_PATH_NO_TOP,
                        'c13680b51ded0fb44dff2d8f86ac8bb1'),
    }
  origin, file_hash = origins_and_hashes[variant]
  return variant + suffix, origin, file_hash


def DenseNet(blocks,
             include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000):
  """Instantiates the DenseNet architecture.

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Arguments:
    blocks: sequence (list or tuple) with the numbers of building blocks
      for the four dense blocks. ImageNet weights exist only for
      `[6, 12, 24, 16]` (DenseNet121), `[6, 12, 32, 32]` (DenseNet169)
      and `[6, 12, 48, 32]` (DenseNet201).
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format)).
      It should have exactly 3 inputs channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      or invalid input shape, or when `weights='imagenet'` is requested
      for a `blocks` layout that has no pre-trained weights.
  """
  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Identify the published variant, if any. Going through a tuple makes the
  # lookup independent of whether `blocks` was given as a list or a tuple.
  variant = _DENSENET_VARIANTS.get(tuple(blocks))

  # Fail before building anything: without this check a custom layout would
  # only crash at weight-loading time, with an unbound `weights_path`.
  if weights == 'imagenet' and variant is None:
    raise ValueError('ImageNet weights are only available for the '
                     'DenseNet121, DenseNet169 and DenseNet201 layouts '
                     '([6, 12, 24, 16], [6, 12, 32, 32], [6, 12, 48, 32]); '
                     'received `blocks={}`. Use `weights=None` or the path '
                     'to a weights file for custom layouts.'.format(blocks))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=224,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  elif not backend.is_keras_tensor(input_tensor):
    img_input = layers.Input(tensor=input_tensor, shape=input_shape)
  else:
    img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling,
  # each preceded by explicit zero padding.
  x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
  x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(
          x)
  x = layers.Activation('relu', name='conv1/relu')(x)
  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
  x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)

  # Four dense blocks, with a transition block between consecutive ones.
  x = dense_block(x, blocks[0], name='conv2')
  x = transition_block(x, 0.5, name='pool2')
  x = dense_block(x, blocks[1], name='conv3')
  x = transition_block(x, 0.5, name='pool3')
  x = dense_block(x, blocks[2], name='conv4')
  x = transition_block(x, 0.5, name='pool4')
  x = dense_block(x, blocks[3], name='conv5')

  x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
  x = layers.Activation('relu', name='relu')(x)

  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
  elif pooling == 'avg':
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
  elif pooling == 'max':
    x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model. Custom layouts get the generic name.
  model = training.Model(inputs, x, name=variant or 'densenet')

  # Load weights.
  if weights == 'imagenet':
    fname, origin, file_hash = _imagenet_weights_spec(variant, include_top)
    weights_path = data_utils.get_file(
        fname, origin, cache_subdir='models', file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model
300
301
@keras_export('keras.applications.densenet.DenseNet121',
              'keras.applications.DenseNet121')
def DenseNet121(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000):
  """Instantiates the Densenet121 architecture."""
  # The docstring stays a single line: the shared argument documentation
  # (`DOC`) is appended to it at the bottom of this module.
  return DenseNet(
      blocks=[6, 12, 24, 16],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes)
313
314
@keras_export('keras.applications.densenet.DenseNet169',
              'keras.applications.DenseNet169')
def DenseNet169(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000):
  """Instantiates the Densenet169 architecture."""
  # The docstring stays a single line: the shared argument documentation
  # (`DOC`) is appended to it at the bottom of this module.
  return DenseNet(
      blocks=[6, 12, 32, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes)
326
327
@keras_export('keras.applications.densenet.DenseNet201',
              'keras.applications.DenseNet201')
def DenseNet201(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000):
  """Instantiates the Densenet201 architecture."""
  # The docstring stays a single line: the shared argument documentation
  # (`DOC`) is appended to it at the bottom of this module.
  return DenseNet(
      blocks=[6, 12, 48, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes)
339
340
@keras_export('keras.applications.densenet.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses input data for the DenseNet models.

  Thin wrapper that forwards to `imagenet_utils.preprocess_input` with the
  mode fixed to `'torch'`.

  Arguments:
    x: input data, handed to `imagenet_utils.preprocess_input` as-is.
    data_format: optional data format, forwarded unchanged.

  Returns:
    The result of `imagenet_utils.preprocess_input`.
  """
  preprocessed = imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='torch')
  return preprocessed
345
346
@keras_export('keras.applications.densenet.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the predictions of a DenseNet model.

  Thin wrapper that forwards to `imagenet_utils.decode_predictions`.

  Arguments:
    preds: model predictions, handed to `imagenet_utils` as-is.
    top: forwarded as the `top` argument; defaults to 5.

  Returns:
    The result of `imagenet_utils.decode_predictions`.
  """
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
350
351
# Argument documentation shared by the three public constructors. It is
# appended to each constructor's one-line docstring below.
DOC = """

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Arguments:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format).
      It should have exactly 3 inputs channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.
"""

# Extend each constructor's summary line with the shared documentation.
DenseNet121.__doc__ = DenseNet121.__doc__ + DOC
DenseNet169.__doc__ = DenseNet169.__doc__ + DOC
DenseNet201.__doc__ = DenseNet201.__doc__ + DOC
395