• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=invalid-name
16"""EfficientNet models for Keras.
17
18Reference paper:
19  - [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks]
20    (https://arxiv.org/abs/1905.11946) (ICML 2019)
21"""
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import copy
27import math
28import os
29
30from tensorflow.python.keras import backend
31from tensorflow.python.keras import layers
32from tensorflow.python.keras.applications import imagenet_utils
33from tensorflow.python.keras.engine import training
34from tensorflow.python.keras.utils import data_utils
35from tensorflow.python.keras.utils import layer_utils
36from tensorflow.python.util.tf_export import keras_export
37
38
# URL prefix under which the pre-trained weight files are published; the
# per-model file name is appended to it when `weights='imagenet'` is requested.
BASE_WEIGHTS_PATH = 'https://storage.googleapis.com/keras-applications/'

# Expected file hashes of the pre-trained weight files, keyed by the last two
# characters of the model name ('b0' ... 'b7'). Each value is a pair:
#   index 0 -> hash of the full model file (`include_top=True`, '.h5'),
#   index 1 -> hash of the headless file (`include_top=False`, '_notop.h5').
WEIGHTS_HASHES = dict(
    b0=('902e53a9f72be733fc0bcb005b3ebbac',
        '50bc09e76180e00e4465e1a485ddc09d'),
    b1=('1d254153d4ab51201f1646940f018540',
        '74c4e6b3e1f6a1eea24c589628592432'),
    b2=('b15cce36ff4dcbd00b6dd88e7857a6ad',
        '111f8e2ac8aa800a7a99e3239f7bfb39'),
    b3=('ffd1fdc53d0ce67064dc6a9c7960ede0',
        'af6d107764bb5b1abb91932881670226'),
    b4=('18c95ad55216b8f92d7e70b3a046e2fc',
        'ebc24e6d6c33eaebbd558eafbeedf1ba'),
    b5=('ace28f2a6363774853a83a0b21b9421a',
        '38879255a25d3c92d5e44e04ae6cec6f'),
    b6=('165f6e37dce68623721b423839de8be5',
        '9ecce42647a20130c1f39a5d4cb75743'),
    b7=('8c03f828fec3ef71311cd463b6759d99',
        'cbcfe4450ddf6f3ad90b1b398090fe4a'),
)
59
# Default per-stage configuration consumed by `EfficientNet` when
# `blocks_args='default'`. Each entry describes one stage: `repeats` is the
# base number of blocks in the stage (scaled by the depth coefficient), the
# filter counts are base values (scaled by the width coefficient), and the
# remaining keys are forwarded as keyword arguments to `block`.
DEFAULT_BLOCKS_ARGS = [
    dict(kernel_size=3, repeats=1, filters_in=32, filters_out=16,
         expand_ratio=1, id_skip=True, strides=1, se_ratio=0.25),
    dict(kernel_size=3, repeats=2, filters_in=16, filters_out=24,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=5, repeats=2, filters_in=24, filters_out=40,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=3, repeats=3, filters_in=40, filters_out=80,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=5, repeats=3, filters_in=80, filters_out=112,
         expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25),
    dict(kernel_size=5, repeats=4, filters_in=112, filters_out=192,
         expand_ratio=6, id_skip=True, strides=2, se_ratio=0.25),
    dict(kernel_size=3, repeats=1, filters_in=192, filters_out=320,
         expand_ratio=6, id_skip=True, strides=1, se_ratio=0.25),
]
124
# Serialized initializer config passed as `kernel_initializer` /
# `depthwise_initializer` to every convolution built in this file.
CONV_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(
        scale=2.0,
        mode='fan_out',
        distribution='truncated_normal',
    ),
)

# Serialized initializer config passed as `kernel_initializer` to the final
# classification `Dense` layer.
DENSE_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(
        scale=1. / 3.,
        mode='fan_out',
        distribution='uniform',
    ),
)
142
143
def EfficientNet(width_coefficient,
                 depth_coefficient,
                 default_size,
                 dropout_rate=0.2,
                 drop_connect_rate=0.2,
                 depth_divisor=8,
                 activation='swish',
                 blocks_args='default',
                 model_name='efficientnet',
                 include_top=True,
                 weights='imagenet',
                 input_tensor=None,
                 input_shape=None,
                 pooling=None,
                 classes=1000):
  """Instantiates the EfficientNet architecture using given scaling coefficients.

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  The model starts with `Rescaling(1. / 255.)` and `Normalization` layers, so
  input scaling is part of the network itself.

  Arguments:
    width_coefficient: float, scaling coefficient for network width.
    depth_coefficient: float, scaling coefficient for network depth.
    default_size: integer, default input image size.
    dropout_rate: float, dropout rate before final classifier layer.
    drop_connect_rate: float, dropout rate at skip connections. The rate
        applied to the k-th block (0-based, counted over every block that is
        actually built) is `drop_connect_rate * k / total_blocks`, so it
        grows linearly with depth and stays below `drop_connect_rate`.
    depth_divisor: integer, a unit of network width.
    activation: activation function.
    blocks_args: list of dicts, parameters to construct block modules, or
        the string `'default'` to use `DEFAULT_BLOCKS_ARGS`. The list is
        deep-copied before use and is never modified.
    model_name: string, model name. When `weights='imagenet'`, its last two
        characters select the entry of `WEIGHTS_HASHES` and the name itself
        is used to build the weight file name.
    include_top: whether to include the fully-connected
        layer at the top of the network.
    weights: one of `None` (random initialization),
          'imagenet' (pre-training on ImageNet),
          or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor
        (i.e. output of `layers.Input()`)
        to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
        if `include_top` is False.
        It should have exactly 3 inputs channels.
    pooling: optional pooling mode for feature extraction
        when `include_top` is `False`.
        - `None` means that the output of the model will be
            the 4D tensor output of the
            last convolutional layer.
        - `avg` means that global average pooling
            will be applied to the output of the
            last convolutional layer, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
    classes: optional number of classes to classify images
        into, only to be specified if `include_top` is True, and
        if no `weights` argument is specified.

  Returns:
    A Keras model instance.

  Raises:
    ValueError: in case of invalid argument for `weights`,
      invalid input shape, or a `blocks_args` entry whose `repeats`
      is not positive.
  """
  if blocks_args == 'default':
    blocks_args = DEFAULT_BLOCKS_ARGS

  if not (weights in {'imagenet', None} or os.path.exists(weights)):
    raise ValueError('The `weights` argument should be either '
                     '`None` (random initialization), `imagenet` '
                     '(pre-training on ImageNet), '
                     'or the path to the weights file to be loaded.')

  if weights == 'imagenet' and include_top and classes != 1000:
    raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                     ' as true, `classes` should be 1000')

  # Validate the stage configuration up front with a real exception: an
  # `assert` would be stripped under `python -O` and would only fire after
  # part of the graph had already been built.
  for stage_index, stage_args in enumerate(blocks_args):
    if not stage_args['repeats'] > 0:
      raise ValueError('Each entry of `blocks_args` must have a positive '
                       '`repeats` value, but entry {} has `repeats={}`.'
                       .format(stage_index, stage_args['repeats']))

  # Determine proper input shape
  input_shape = imagenet_utils.obtain_input_shape(
      input_shape,
      default_size=default_size,
      min_size=32,
      data_format=backend.image_data_format(),
      require_flatten=include_top,
      weights=weights)

  if input_tensor is None:
    img_input = layers.Input(shape=input_shape)
  else:
    if not backend.is_keras_tensor(input_tensor):
      img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    else:
      img_input = input_tensor

  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

  def round_filters(filters, divisor=depth_divisor):
    """Round number of filters based on depth multiplier."""
    filters *= width_coefficient
    new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_filters < 0.9 * filters:
      new_filters += divisor
    return int(new_filters)

  def round_repeats(repeats):
    """Round number of repeats based on depth multiplier."""
    return int(math.ceil(depth_coefficient * repeats))

  # Build stem
  x = img_input
  x = layers.Rescaling(1. / 255.)(x)
  x = layers.Normalization(axis=bn_axis)(x)

  x = layers.ZeroPadding2D(
      padding=imagenet_utils.correct_pad(x, 3),
      name='stem_conv_pad')(x)
  x = layers.Conv2D(
      round_filters(32),
      3,
      strides=2,
      padding='valid',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='stem_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
  x = layers.Activation(activation, name='stem_activation')(x)

  # Build blocks
  # Work on a private copy: the loop below mutates each stage dict.
  blocks_args = copy.deepcopy(blocks_args)

  # The drop-connect schedule must be normalised by the number of blocks that
  # are actually built, i.e. the depth-scaled repeats. Using the unscaled
  # `repeats` here would let the per-block rate overshoot `drop_connect_rate`
  # whenever `depth_coefficient > 1`.
  total_blocks = float(
      sum(round_repeats(args['repeats']) for args in blocks_args))
  block_index = 0
  for (i, args) in enumerate(blocks_args):
    # Update block input and output filters based on depth multiplier.
    args['filters_in'] = round_filters(args['filters_in'])
    args['filters_out'] = round_filters(args['filters_out'])

    # `repeats` is popped because the remaining keys are forwarded verbatim
    # to `block`, which has no such parameter.
    for j in range(round_repeats(args.pop('repeats'))):
      # The first block needs to take care of stride and filter size increase.
      if j > 0:
        args['strides'] = 1
        args['filters_in'] = args['filters_out']
      x = block(
          x,
          activation,
          drop_connect_rate * block_index / total_blocks,
          # Names look like 'block1a_', 'block1b_', ...: stage number followed
          # by a letter for the position inside the stage.
          name='block{}{}_'.format(i + 1, chr(j + 97)),
          **args)
      block_index += 1

  # Build top
  x = layers.Conv2D(
      round_filters(1280),
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name='top_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
  x = layers.Activation(activation, name='top_activation')(x)
  if include_top:
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    if dropout_rate > 0:
      x = layers.Dropout(dropout_rate, name='top_dropout')(x)
    x = layers.Dense(
        classes,
        activation='softmax',
        kernel_initializer=DENSE_KERNEL_INITIALIZER,
        name='probs')(x)
  else:
    if pooling == 'avg':
      x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
      x = layers.GlobalMaxPooling2D(name='max_pool')(x)

  # Ensure that the model takes into account
  # any potential predecessors of `input_tensor`.
  if input_tensor is not None:
    inputs = layer_utils.get_source_inputs(input_tensor)
  else:
    inputs = img_input

  # Create model.
  model = training.Model(inputs, x, name=model_name)

  # Load weights.
  if weights == 'imagenet':
    # Index 0 of each hash pair belongs to the full model, index 1 to the
    # headless ('_notop') variant.
    if include_top:
      file_suffix = '.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
    else:
      file_suffix = '_notop.h5'
      file_hash = WEIGHTS_HASHES[model_name[-2:]][1]
    file_name = model_name + file_suffix
    weights_path = data_utils.get_file(
        file_name,
        BASE_WEIGHTS_PATH + file_name,
        cache_subdir='models',
        file_hash=file_hash)
    model.load_weights(weights_path)
  elif weights is not None:
    model.load_weights(weights)
  return model
349
350
def block(inputs,
          activation='swish',
          drop_rate=0.,
          name='',
          filters_in=32,
          filters_out=16,
          kernel_size=3,
          strides=1,
          expand_ratio=1,
          se_ratio=0.,
          id_skip=True):
  """An inverted residual block.

  The block is built from up to five stages:
    1. a 1x1 expansion convolution to `filters_in * expand_ratio` channels
       (skipped when `expand_ratio == 1`),
    2. a depthwise convolution (with explicit zero padding and 'valid'
       padding when `strides == 2`, 'same' padding otherwise),
    3. an optional squeeze-and-excitation gate (when `0 < se_ratio <= 1`),
    4. a 1x1 projection convolution to `filters_out` channels,
    5. an optional residual connection back to `inputs`.

  The channel axis follows `backend.image_data_format()`.

  Arguments:
      inputs: input tensor.
      activation: activation function.
      drop_rate: float between 0 and 1, fraction of the input units to drop.
          Only used on the residual branch, and only when it is > 0.
      name: string, block label; used as a prefix for all layer names.
      filters_in: integer, the number of input filters.
      filters_out: integer, the number of output filters.
      kernel_size: integer, the dimension of the convolution window.
      strides: integer, the stride of the convolution.
      expand_ratio: integer, scaling coefficient for the input filters.
      se_ratio: float between 0 and 1, fraction to squeeze the input filters.
          Values outside `(0, 1]` disable squeeze-and-excitation.
      id_skip: boolean, whether the residual connection may be added. It is
          added only if additionally `strides == 1` and
          `filters_in == filters_out`.

  Returns:
      output tensor for the block.
  """
  channels_last = backend.image_data_format() == 'channels_last'
  bn_axis = 3 if channels_last else 1

  # Expansion phase
  filters = filters_in * expand_ratio
  if expand_ratio != 1:
    x = layers.Conv2D(
        filters,
        1,
        padding='same',
        use_bias=False,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'expand_conv')(
            inputs)
    x = layers.BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
    x = layers.Activation(activation, name=name + 'expand_activation')(x)
  else:
    x = inputs

  # Depthwise Convolution
  if strides == 2:
    x = layers.ZeroPadding2D(
        padding=imagenet_utils.correct_pad(x, kernel_size),
        name=name + 'dwconv_pad')(x)
    conv_pad = 'valid'
  else:
    conv_pad = 'same'
  x = layers.DepthwiseConv2D(
      kernel_size,
      strides=strides,
      padding=conv_pad,
      use_bias=False,
      depthwise_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'dwconv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
  x = layers.Activation(activation, name=name + 'activation')(x)

  # Squeeze and Excitation phase
  if 0 < se_ratio <= 1:
    # The squeeze width is derived from the block's input filters, not from
    # the expanded filter count.
    filters_se = max(1, int(filters_in * se_ratio))
    se = layers.GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
    # Re-insert the pooled-away spatial dimensions as singleton axes, with the
    # channel axis where the active data format expects it. A fixed
    # `(1, 1, filters)` shape would put the channels on the wrong axis under
    # 'channels_first'.
    se_shape = (1, 1, filters) if channels_last else (filters, 1, 1)
    se = layers.Reshape(se_shape, name=name + 'se_reshape')(se)
    se = layers.Conv2D(
        filters_se,
        1,
        padding='same',
        activation=activation,
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_reduce')(
            se)
    se = layers.Conv2D(
        filters,
        1,
        padding='same',
        activation='sigmoid',
        kernel_initializer=CONV_KERNEL_INITIALIZER,
        name=name + 'se_expand')(se)
    x = layers.multiply([x, se], name=name + 'se_excite')

  # Output phase
  x = layers.Conv2D(
      filters_out,
      1,
      padding='same',
      use_bias=False,
      kernel_initializer=CONV_KERNEL_INITIALIZER,
      name=name + 'project_conv')(x)
  x = layers.BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)
  if id_skip and strides == 1 and filters_in == filters_out:
    if drop_rate > 0:
      # noise_shape (None, 1, 1, 1) shares one dropout decision across all
      # non-batch axes, i.e. the whole residual branch of a sample is either
      # kept or dropped.
      x = layers.Dropout(
          drop_rate, noise_shape=(None, 1, 1, 1), name=name + 'drop')(x)
    x = layers.add([x, inputs], name=name + 'add')
  return x
453
454
@keras_export('keras.applications.efficientnet.EfficientNetB0',
              'keras.applications.EfficientNetB0')
def EfficientNetB0(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB0 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.0,
  the depth coefficient to 1.0, the default input size to 224, the dropout
  rate to 0.2 and the model name to 'efficientnetb0'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.0,
      depth_coefficient=1.0,
      default_size=224,
      dropout_rate=0.2,
      model_name='efficientnetb0',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
477
478
@keras_export('keras.applications.efficientnet.EfficientNetB1',
              'keras.applications.EfficientNetB1')
def EfficientNetB1(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB1 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.0,
  the depth coefficient to 1.1, the default input size to 240, the dropout
  rate to 0.2 and the model name to 'efficientnetb1'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.0,
      depth_coefficient=1.1,
      default_size=240,
      dropout_rate=0.2,
      model_name='efficientnetb1',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
501
502
@keras_export('keras.applications.efficientnet.EfficientNetB2',
              'keras.applications.EfficientNetB2')
def EfficientNetB2(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB2 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.1,
  the depth coefficient to 1.2, the default input size to 260, the dropout
  rate to 0.3 and the model name to 'efficientnetb2'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.1,
      depth_coefficient=1.2,
      default_size=260,
      dropout_rate=0.3,
      model_name='efficientnetb2',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
525
526
@keras_export('keras.applications.efficientnet.EfficientNetB3',
              'keras.applications.EfficientNetB3')
def EfficientNetB3(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB3 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.2,
  the depth coefficient to 1.4, the default input size to 300, the dropout
  rate to 0.3 and the model name to 'efficientnetb3'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.2,
      depth_coefficient=1.4,
      default_size=300,
      dropout_rate=0.3,
      model_name='efficientnetb3',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
549
550
@keras_export('keras.applications.efficientnet.EfficientNetB4',
              'keras.applications.EfficientNetB4')
def EfficientNetB4(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB4 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.4,
  the depth coefficient to 1.8, the default input size to 380, the dropout
  rate to 0.4 and the model name to 'efficientnetb4'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.4,
      depth_coefficient=1.8,
      default_size=380,
      dropout_rate=0.4,
      model_name='efficientnetb4',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
573
574
@keras_export('keras.applications.efficientnet.EfficientNetB5',
              'keras.applications.EfficientNetB5')
def EfficientNetB5(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB5 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.6,
  the depth coefficient to 2.2, the default input size to 456, the dropout
  rate to 0.4 and the model name to 'efficientnetb5'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.6,
      depth_coefficient=2.2,
      default_size=456,
      dropout_rate=0.4,
      model_name='efficientnetb5',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
597
598
@keras_export('keras.applications.efficientnet.EfficientNetB6',
              'keras.applications.EfficientNetB6')
def EfficientNetB6(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB6 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 1.8,
  the depth coefficient to 2.6, the default input size to 528, the dropout
  rate to 0.5 and the model name to 'efficientnetb6'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=1.8,
      depth_coefficient=2.6,
      default_size=528,
      dropout_rate=0.5,
      model_name='efficientnetb6',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
621
622
@keras_export('keras.applications.efficientnet.EfficientNetB7',
              'keras.applications.EfficientNetB7')
def EfficientNetB7(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
  """Builds the EfficientNetB7 model.

  Thin wrapper around `EfficientNet` that fixes the width coefficient to 2.0,
  the depth coefficient to 3.1, the default input size to 600, the dropout
  rate to 0.5 and the model name to 'efficientnetb7'. All arguments,
  including any extra `**kwargs`, are forwarded to `EfficientNet`; see its
  docstring for their meaning.

  Returns:
    The value returned by `EfficientNet`.
  """
  return EfficientNet(
      width_coefficient=2.0,
      depth_coefficient=3.1,
      default_size=600,
      dropout_rate=0.5,
      model_name='efficientnetb7',
      include_top=include_top,
      weights=weights,
      input_tensor=input_tensor,
      input_shape=input_shape,
      pooling=pooling,
      classes=classes,
      **kwargs)
645
646
@keras_export('keras.applications.efficientnet.preprocess_input')
def preprocess_input(x, data_format=None):  # pylint: disable=unused-argument
  """Returns `x` unchanged.

  No preprocessing happens here: the models built by `EfficientNet` in this
  file begin with `Rescaling` and `Normalization` layers, so input scaling is
  done inside the model.

  Arguments:
    x: the input to pass through.
    data_format: ignored.

  Returns:
    The same object that was passed as `x`.
  """
  return x
650
651
@keras_export('keras.applications.efficientnet.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes predictions by delegating to `imagenet_utils.decode_predictions`.

  Arguments:
    preds: the predictions to decode; passed through unchanged.
    top: forwarded as the `top` keyword argument (defaults to 5).

  Returns:
    The result of `imagenet_utils.decode_predictions(preds, top=top)`.
  """
  decoded = imagenet_utils.decode_predictions(preds, top=top)
  return decoded
655