• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""DenseNet models for Keras.
17
18Reference:
19  - [Densely Connected Convolutional Networks](
20      https://arxiv.org/abs/1608.06993) (CVPR 2017)
21"""
22
23from tensorflow.python.keras import backend
24from tensorflow.python.keras.applications import imagenet_utils
25from tensorflow.python.keras.engine import training
26from tensorflow.python.keras.layers import VersionAwareLayers
27from tensorflow.python.keras.utils import data_utils
28from tensorflow.python.keras.utils import layer_utils
29from tensorflow.python.lib.io import file_io
30from tensorflow.python.util.tf_export import keras_export
31
32
# Remote directory holding the pre-trained DenseNet weight files.
BASE_WEIGHTS_PATH = (
    'https://storage.googleapis.com/tensorflow/keras-applications/densenet/')

# Download URLs per variant. The `*_NO_TOP` files are the ones fetched when
# the model is built with `include_top` set to false.
DENSENET121_WEIGHT_PATH = ''.join(
    [BASE_WEIGHTS_PATH, 'densenet121_weights_tf_dim_ordering_tf_kernels.h5'])
DENSENET121_WEIGHT_PATH_NO_TOP = ''.join([
    BASE_WEIGHTS_PATH,
    'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5',
])
DENSENET169_WEIGHT_PATH = ''.join(
    [BASE_WEIGHTS_PATH, 'densenet169_weights_tf_dim_ordering_tf_kernels.h5'])
DENSENET169_WEIGHT_PATH_NO_TOP = ''.join([
    BASE_WEIGHTS_PATH,
    'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5',
])
DENSENET201_WEIGHT_PATH = ''.join(
    [BASE_WEIGHTS_PATH, 'densenet201_weights_tf_dim_ordering_tf_kernels.h5'])
DENSENET201_WEIGHT_PATH_NO_TOP = ''.join([
    BASE_WEIGHTS_PATH,
    'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5',
])
50
# Layer namespace shared by every builder in this module; layer classes are
# looked up on it as attributes (e.g. `layers.Conv2D`).
layers = VersionAwareLayers()
52
53
def dense_block(x, blocks, name):
  """Chains `blocks` building blocks created by `conv_block`.

  Every building block uses a growth rate of 32 and is labelled
  `<name>_block<k>`, with `k` counting up from 1.

  Args:
    x: input tensor.
    blocks: integer, how many building blocks to chain.
    name: string, label prefix for the block.

  Returns:
    Output tensor of the last building block.
  """
  for block_number in range(1, blocks + 1):
    block_label = name + '_block' + str(block_number)
    x = conv_block(x, 32, name=block_label)
  return x
68
69
def transition_block(x, reduction, name):
  """Builds a transition block: BN -> ReLU -> 1x1 conv -> 2x2 average pool.

  Args:
    x: input tensor.
    reduction: float, compression rate at transition layers. The 1x1
      convolution gets `int(channels * reduction)` filters, where `channels`
      is the size of the channel axis of its input.
    name: string, label prefix for the layers of the block.

  Returns:
    Output tensor of the average pooling layer.
  """
  if backend.image_data_format() == 'channels_last':
    channel_axis = 3
  else:
    channel_axis = 1

  batch_norm = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_bn')
  x = batch_norm(x)
  x = layers.Activation('relu', name=name + '_relu')(x)

  # Shrink the channel dimension by the `reduction` factor.
  num_filters = int(backend.int_shape(x)[channel_axis] * reduction)
  compress = layers.Conv2D(
      num_filters, 1, use_bias=False, name=name + '_conv')
  x = compress(x)

  return layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
94
95
def conv_block(x, growth_rate, name):
  """Builds one building block of a dense block.

  The input goes through BN -> ReLU -> 1x1 conv (`4 * growth_rate` filters)
  -> BN -> ReLU -> 3x3 conv (`growth_rate` filters), and the result is
  concatenated with the input along the channel axis.

  Args:
    x: input tensor.
    growth_rate: growth rate at dense layers; used as the filter count of
      the 3x3 convolution.
    name: string, label prefix for the layers of the block.

  Returns:
    Output tensor of the concatenation layer.
  """
  channel_axis = 1 if backend.image_data_format() != 'channels_last' else 3

  # Bottleneck stage: normalise, activate, then a bias-free 1x1 convolution.
  branch = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_0_bn')(x)
  branch = layers.Activation('relu', name=name + '_0_relu')(branch)
  bottleneck_conv = layers.Conv2D(
      4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')
  branch = bottleneck_conv(branch)

  # Feature stage: normalise, activate, then a bias-free 3x3 'same' conv.
  branch = layers.BatchNormalization(
      axis=channel_axis, epsilon=1.001e-5, name=name + '_1_bn')(branch)
  branch = layers.Activation('relu', name=name + '_1_relu')(branch)
  feature_conv = layers.Conv2D(
      growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')
  branch = feature_conv(branch)

  # Keep the incoming features and append the newly computed ones.
  merge = layers.Concatenate(axis=channel_axis, name=name + '_concat')
  return merge([x, branch])
124
125
# Model names of the standard DenseNet variants, keyed by block configuration.
_DENSENET_VARIANT_NAMES = {
    (6, 12, 24, 16): 'densenet121',
    (6, 12, 32, 32): 'densenet169',
    (6, 12, 48, 32): 'densenet201',
}

# Pre-trained ImageNet weight files, keyed by (variant name, include_top).
# Each value is the (download URL, expected file hash) pair of the file.
_DENSENET_IMAGENET_WEIGHTS = {
    ('densenet121', True): (
        DENSENET121_WEIGHT_PATH, '9d60b8095a5708f2dcce2bca79d332c7'),
    ('densenet169', True): (
        DENSENET169_WEIGHT_PATH, 'd699b8f76981ab1b30698df4c175e90b'),
    ('densenet201', True): (
        DENSENET201_WEIGHT_PATH, '1ceb130c1ea1b78c3bf6114dbdfd8807'),
    ('densenet121', False): (
        DENSENET121_WEIGHT_PATH_NO_TOP, '30ee3e1110167f948a6b9946edeeb738'),
    ('densenet169', False): (
        DENSENET169_WEIGHT_PATH_NO_TOP, 'b8c4d4c20dd625c148057b9ff1c1176b'),
    ('densenet201', False): (
        DENSENET201_WEIGHT_PATH_NO_TOP, 'c13680b51ded0fb44dff2d8f86ac8bb1'),
}


def DenseNet(
    blocks,
    include_top=True,
    weights='imagenet',
    input_tensor=None,
    input_shape=None,
    pooling=None,
    classes=1000,
    classifier_activation='softmax'):
  """Instantiates the DenseNet architecture.

  Reference:
  - [Densely Connected Convolutional Networks](
      https://arxiv.org/abs/1608.06993) (CVPR 2017)

  This function returns a Keras image classification model,
  optionally loaded with weights pre-trained on ImageNet.

  For image classification use cases, see
  [this page for detailed examples](
    https://keras.io/api/applications/#usage-examples-for-image-classification-models).

  For transfer learning use cases, make sure to read the
  [guide to transfer learning & fine-tuning](
    https://keras.io/guides/transfer_learning/).

  Note: each Keras Application expects a specific kind of input preprocessing.
  For DenseNet, call `tf.keras.applications.densenet.preprocess_input` on your
  inputs before passing them to the model.
  `densenet.preprocess_input` will scale pixels between 0 and 1 and then
  will normalize each channel with respect to the ImageNet dataset statistics.

  Args:
    blocks: sequence (list or tuple) with the numbers of building blocks for
      the four dense layers. ImageNet weights exist only for
      `[6, 12, 24, 16]` (DenseNet121), `[6, 12, 32, 32]` (DenseNet169) and
      `[6, 12, 48, 32]` (DenseNet201).
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
      (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format)).
      It should have exactly 3 inputs channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.
    classifier_activation: A `str` or callable. The activation function to use
      on the "top" layer. Ignored unless `include_top=True`. Set
      `classifier_activation=None` to return the logits of the "top" layer.
      When loading pretrained weights, `classifier_activation` can only
      be `None` or `"softmax"`.

  Returns:
    A `keras.Model` instance.

  Raises:
    ValueError: if `weights` is neither `None`, `'imagenet'` nor an existing
      file; if `weights='imagenet'` is combined with `include_top` and
      `classes != 1000`; or if `weights='imagenet'` is requested for a
      `blocks` configuration that has no pre-trained weights.
  """
  if not (weights in {'imagenet', None} or file_io.file_exists_v2(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Normalise to a tuple so lists and tuples are treated alike and the
  # configuration can be used as a lookup key.
  block_config = tuple(blocks)
  variant_name = _DENSENET_VARIANT_NAMES.get(block_config)

  # Fail before building the graph: without this check an unknown
  # configuration would reach `load_weights` with no weights file selected.
  if weights == 'imagenet' and variant_name is None:
    raise ValueError(
        'ImageNet weights are only available for the DenseNet121, '
        'DenseNet169 and DenseNet201 block configurations; '
        'received `blocks`={}.'.format(list(block_config)))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=224,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  elif not backend.is_keras_tensor(input_tensor):
    img_input = layers.Input(tensor=input_tensor, shape=input_shape)
  else:
    img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
  x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
  x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
  x = layers.BatchNormalization(
      axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(
          x)
  x = layers.Activation('relu', name='conv1/relu')(x)
  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
  x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)

  # Four dense blocks, with a transition block between consecutive ones.
  x = dense_block(x, block_config[0], name='conv2')
  x = transition_block(x, 0.5, name='pool2')
  x = dense_block(x, block_config[1], name='conv3')
  x = transition_block(x, 0.5, name='pool3')
  x = dense_block(x, block_config[2], name='conv4')
  x = transition_block(x, 0.5, name='pool4')
  x = dense_block(x, block_config[3], name='conv5')

  x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
  x = layers.Activation('relu', name='relu')(x)

  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)

    imagenet_utils.validate_activation(classifier_activation, weights)
    x = layers.Dense(classes, activation=classifier_activation,
                     name='predictions')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model. Non-standard configurations get the generic name.
  model = training.Model(inputs, x, name=variant_name or 'densenet')

  # Load weights.
  if weights == 'imagenet':
    weights_url, file_hash = _DENSENET_IMAGENET_WEIGHTS[
        (variant_name, bool(include_top))]
    weights_path = data_utils.get_file(
        # The cache file name is the last component of the download URL.
        weights_url.rsplit('/', 1)[-1],
        weights_url,
        cache_subdir='models',
        file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)

  return model
321
322
@keras_export('keras.applications.densenet.DenseNet121',
              'keras.applications.DenseNet121')
def DenseNet121(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet121 architecture.

  `classifier_activation` is a `str` or callable giving the activation of the
  "top" layer. It is ignored unless `include_top=True`; pass `None` to get the
  logits of the "top" layer. When loading pretrained weights it can only be
  `None` or `"softmax"`.
  """
  # Forward by keyword so the call cannot drift out of step with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 24, 16],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
334
335
@keras_export('keras.applications.densenet.DenseNet169',
              'keras.applications.DenseNet169')
def DenseNet169(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet169 architecture.

  `classifier_activation` is a `str` or callable giving the activation of the
  "top" layer. It is ignored unless `include_top=True`; pass `None` to get the
  logits of the "top" layer. When loading pretrained weights it can only be
  `None` or `"softmax"`.
  """
  # Forward by keyword so the call cannot drift out of step with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 32, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
347
348
@keras_export('keras.applications.densenet.DenseNet201',
              'keras.applications.DenseNet201')
def DenseNet201(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000,
                classifier_activation='softmax'):
  """Instantiates the Densenet201 architecture.

  `classifier_activation` is a `str` or callable giving the activation of the
  "top" layer. It is ignored unless `include_top=True`; pass `None` to get the
  logits of the "top" layer. When loading pretrained weights it can only be
  `None` or `"softmax"`.
  """
  # Forward by keyword so the call cannot drift out of step with the
  # parameter order of `DenseNet`.
  return DenseNet(
      [6, 12, 48, 32],
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      classifier_activation=classifier_activation)
360
361
@keras_export('keras.applications.densenet.preprocess_input')
def preprocess_input(x, data_format=None):
  # Thin wrapper: DenseNet inputs go through the shared ImageNet
  # preprocessing in its 'torch' mode. The docstring is assigned to
  # `preprocess_input.__doc__` after the definition.
  preprocessed = imagenet_utils.preprocess_input(
      x, mode='torch', data_format=data_format)
  return preprocessed
366
367
@keras_export('keras.applications.densenet.decode_predictions')
def decode_predictions(preds, top=5):
  # Thin wrapper around the shared ImageNet decoder. The docstring is
  # assigned to `decode_predictions.__doc__` after the definition.
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
371
372
# Both helpers are thin wrappers around `imagenet_utils`, so their docstrings
# are taken from the shared documentation there.
decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
preprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(
    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,
    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,
    mode='')
378
# Documentation shared by the DenseNet121 / DenseNet169 / DenseNet201
# constructors; it is appended to each one-line summary further down.
DOC = """

  Reference:
  - [Densely Connected Convolutional Networks](
      https://arxiv.org/abs/1608.06993) (CVPR 2017)

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Note: each Keras Application expects a specific kind of input preprocessing.
  For DenseNet, call `tf.keras.applications.densenet.preprocess_input` on your
  inputs before passing them to the model.

  Args:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` (with `'channels_last'` data format)
      or `(3, 224, 224)` (with `'channels_first'` data format).
      It should have exactly 3 inputs channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.
"""

# Extend each constructor's own summary with the shared documentation.
DenseNet121.__doc__ += DOC
DenseNet169.__doc__ += DOC
DenseNet201.__doc__ += DOC
430