• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""ResNet50 model definition compatible with TensorFlow's eager execution.
16
17Reference [Deep Residual Learning for Image
18Recognition](https://arxiv.org/abs/1512.03385)
19
20Adapted from tf.keras.applications.ResNet50. A notable difference is that the
21model here outputs logits while the Keras model outputs probability.
22"""
23import functools
24
25import tensorflow as tf
26
27layers = tf.keras.layers
28
29
class _IdentityBlock(tf.keras.Model):
  """Bottleneck residual block whose shortcut is the unmodified input.

  The main path is a 1x1 conv, a `kernel_size` conv with 'same' padding and
  another 1x1 conv, each followed by batch normalization. ReLU is applied
  after the first two batch norms and again after the input has been added to
  the main path output.

  Args:
    kernel_size: the kernel size of the middle conv layer on the main path
    filters: list of three integers, the filter counts of the three conv
      layers on the main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    data_format: data_format for the input ('channels_first' or
      'channels_last').
  """

  def __init__(self, kernel_size, filters, stage, block, data_format):
    super(_IdentityBlock, self).__init__(name='')
    n_reduce, n_middle, n_expand = filters

    # Layer names look like 'res2b_branch2a' / 'bn2b_branch2a'.
    branch_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + branch_suffix
    bn_prefix = 'bn' + branch_suffix

    # Batch norm normalizes over the channel axis of the chosen layout.
    if data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = 3

    def batch_norm(tag):
      """Builds the batch norm layer paired with conv layer `tag`."""
      return layers.BatchNormalization(axis=channel_axis, name=bn_prefix + tag)

    self.conv2a = layers.Conv2D(
        n_reduce, (1, 1), data_format=data_format, name=conv_prefix + '2a')
    self.bn2a = batch_norm('2a')

    self.conv2b = layers.Conv2D(
        n_middle,
        kernel_size,
        data_format=data_format,
        padding='same',
        name=conv_prefix + '2b')
    self.bn2b = batch_norm('2b')

    self.conv2c = layers.Conv2D(
        n_expand, (1, 1), data_format=data_format, name=conv_prefix + '2c')
    self.bn2c = batch_norm('2c')

  def call(self, input_tensor, training=False):
    """Runs the main path and adds the input back before the final ReLU."""
    main = self.bn2a(self.conv2a(input_tensor), training=training)
    main = tf.nn.relu(main)

    main = self.bn2b(self.conv2b(main), training=training)
    main = tf.nn.relu(main)

    # No ReLU here: the activation is applied after the residual addition.
    main = self.bn2c(self.conv2c(main), training=training)

    return tf.nn.relu(main + input_tensor)
83
84
class _ConvBlock(tf.keras.Model):
  """Bottleneck residual block with a 1x1 conv + batch norm on the shortcut.

  The main path is a strided 1x1 conv, a `kernel_size` conv with 'same'
  padding and another 1x1 conv, each followed by batch normalization. The
  shortcut applies its own 1x1 conv (same `strides`, same output filters as
  the last main-path conv) and batch norm before the two paths are summed.

  Args:
      kernel_size: the kernel size of the middle conv layer on the main path
      filters: list of three integers, the filter counts of the three conv
        layers on the main path
      stage: integer, current stage label, used for generating layer names
      block: 'a','b'..., current block label, used for generating layer names
      data_format: data_format for the input ('channels_first' or
        'channels_last').
      strides: strides for the convolution. Note that from stage 3, the first
       conv layer at main path is with strides=(2,2), and the shortcut should
       have strides=(2,2) as well.
  """

  def __init__(self,
               kernel_size,
               filters,
               stage,
               block,
               data_format,
               strides=(2, 2)):
    super(_ConvBlock, self).__init__(name='')
    n_reduce, n_middle, n_expand = filters

    # Layer names look like 'res3a_branch2a' / 'bn3a_branch2a'; the shortcut
    # layers use the suffix '1'.
    branch_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + branch_suffix
    bn_prefix = 'bn' + branch_suffix

    # Batch norm normalizes over the channel axis of the chosen layout.
    if data_format == 'channels_first':
      channel_axis = 1
    else:
      channel_axis = 3

    def batch_norm(tag):
      """Builds the batch norm layer paired with conv layer `tag`."""
      return layers.BatchNormalization(axis=channel_axis, name=bn_prefix + tag)

    self.conv2a = layers.Conv2D(
        n_reduce, (1, 1),
        strides=strides,
        data_format=data_format,
        name=conv_prefix + '2a')
    self.bn2a = batch_norm('2a')

    self.conv2b = layers.Conv2D(
        n_middle,
        kernel_size,
        padding='same',
        data_format=data_format,
        name=conv_prefix + '2b')
    self.bn2b = batch_norm('2b')

    self.conv2c = layers.Conv2D(
        n_expand, (1, 1), data_format=data_format, name=conv_prefix + '2c')
    self.bn2c = batch_norm('2c')

    # Projection shortcut: strided like conv2a so both paths end up with the
    # same spatial size, and with `n_expand` filters to match conv2c.
    self.conv_shortcut = layers.Conv2D(
        n_expand, (1, 1),
        strides=strides,
        data_format=data_format,
        name=conv_prefix + '1')
    self.bn_shortcut = batch_norm('1')

  def call(self, input_tensor, training=False):
    """Runs main path and projected shortcut, sums them and applies ReLU."""
    main = self.bn2a(self.conv2a(input_tensor), training=training)
    main = tf.nn.relu(main)

    main = self.bn2b(self.conv2b(main), training=training)
    main = tf.nn.relu(main)

    # No ReLU here: the activation is applied after the residual addition.
    main = self.bn2c(self.conv2c(main), training=training)

    projected = self.bn_shortcut(
        self.conv_shortcut(input_tensor), training=training)

    return tf.nn.relu(main + projected)
161
162
163# pylint: disable=not-callable
class ResNet50(tf.keras.Model):
  """Instantiates the ResNet50 architecture.

  Args:
    data_format: format for the image. Either 'channels_first' or
      'channels_last'.  'channels_first' is typically faster on GPUs while
      'channels_last' is typically faster on CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats
    name: Prefix applied to names of variables created in the model.
    trainable: Is the model trainable? If true, performs backward
        and optimization after call() method.
        NOTE(review): this constructor never reads `trainable` and does not
        forward it to the base class - confirm whether it should have an
        effect.
    include_top: whether to include the fully-connected layer at the top of the
      network.
    pooling: Optional pooling mode for feature extraction when `include_top`
      is `False`.
      - `None` means that the output of the model will be the 4D tensor
          output of the last convolutional layer.
      - `avg` means that global average pooling will be applied to the output of
          the last convolutional layer, and thus the output of the model will be
          a 2D tensor.
      - `max` means that global max pooling will be applied.
      Any other value is treated like `None`.
    block3_strides: whether to add a stride of 2 to block3 to make it compatible
      with tf.slim ResNet implementation.
    average_pooling: whether to do average pooling of block4 features before
      global pooling.
    classes: optional number of classes to classify images into, only to be
      specified if `include_top` is True.

  Raises:
      ValueError: in case of invalid argument for data_format.
  """

  def __init__(self,
               data_format,
               name='',
               trainable=True,
               include_top=True,
               pooling=None,
               block3_strides=False,
               average_pooling=True,
               classes=1000):
    super(ResNet50, self).__init__(name=name)

    valid_channel_values = ('channels_first', 'channels_last')
    if data_format not in valid_channel_values:
      raise ValueError('Unknown data_format: %s. Valid values: %s' %
                       (data_format, valid_channel_values))
    self.include_top = include_top
    self.block3_strides = block3_strides
    self.average_pooling = average_pooling
    self.pooling = pooling

    def conv_block(filters, stage, block, strides=(2, 2)):
      """Builds a 3x3 bottleneck block with a projection shortcut."""
      return _ConvBlock(
          3,
          filters,
          stage=stage,
          block=block,
          data_format=data_format,
          strides=strides)

    def id_block(filters, stage, block):
      """Builds a 3x3 bottleneck block with an identity shortcut."""
      return _IdentityBlock(
          3, filters, stage=stage, block=block, data_format=data_format)

    self.conv1 = layers.Conv2D(
        64, (7, 7),
        strides=(2, 2),
        data_format=data_format,
        padding='same',
        name='conv1')
    bn_axis = 1 if data_format == 'channels_first' else 3
    self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
    self.max_pool = layers.MaxPooling2D((3, 3),
                                        strides=(2, 2),
                                        data_format=data_format)

    # Stage 2 keeps the resolution produced by max_pool, hence strides=(1, 1).
    self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
    self.l2b = id_block([64, 64, 256], stage=2, block='b')
    self.l2c = id_block([64, 64, 256], stage=2, block='c')

    self.l3a = conv_block([128, 128, 512], stage=3, block='a')
    self.l3b = id_block([128, 128, 512], stage=3, block='b')
    self.l3c = id_block([128, 128, 512], stage=3, block='c')
    self.l3d = id_block([128, 128, 512], stage=3, block='d')

    self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
    self.l4b = id_block([256, 256, 1024], stage=4, block='b')
    self.l4c = id_block([256, 256, 1024], stage=4, block='c')
    self.l4d = id_block([256, 256, 1024], stage=4, block='d')
    self.l4e = id_block([256, 256, 1024], stage=4, block='e')
    self.l4f = id_block([256, 256, 1024], stage=4, block='f')

    # Striding layer that can be used on top of block3 to produce feature maps
    # with the same resolution as the TF-Slim implementation.
    if self.block3_strides:
      self.subsampling_layer = layers.MaxPooling2D((1, 1),
                                                   strides=(2, 2),
                                                   data_format=data_format)
      # The subsampling layer already halves the resolution, so the first
      # stage-5 block must not stride again.
      self.l5a = conv_block([512, 512, 2048],
                            stage=5,
                            block='a',
                            strides=(1, 1))
    else:
      self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
    self.l5b = id_block([512, 512, 2048], stage=5, block='b')
    self.l5c = id_block([512, 512, 2048], stage=5, block='c')

    self.avg_pool = layers.AveragePooling2D((7, 7),
                                            strides=(7, 7),
                                            data_format=data_format)

    if self.include_top:
      self.flatten = layers.Flatten()
      self.fc1000 = layers.Dense(classes, name='fc1000')
    else:
      # Height and width axes for the selected layout.
      reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
      reduction_indices = tf.constant(reduction_indices)
      if pooling == 'avg':
        self.global_pooling = functools.partial(
            tf.reduce_mean, axis=reduction_indices, keepdims=False)
      elif pooling == 'max':
        # Use `axis`/`keepdims` like the 'avg' branch; the legacy
        # `reduction_indices`/`keep_dims` keywords are not accepted by the
        # current tf.reduce_max signature.
        self.global_pooling = functools.partial(
            tf.reduce_max, axis=reduction_indices, keepdims=False)
      else:
        self.global_pooling = None

  def call(self, inputs, training=True, intermediates_dict=None):
    """Call the ResNet50 model.

    Args:
      inputs: Images to compute features for.
      training: Whether model is in training phase.
      intermediates_dict: `None` or dictionary. If not None, accumulate feature
        maps from intermediate blocks into the dictionary under the keys
        'block0', 'block0mp', 'block1', 'block2', 'block3' and 'block4'.

    Returns:
      The output of the final dense layer if `include_top` is set; otherwise
      the globally pooled features if `pooling` is 'avg' or 'max'; otherwise
      the last feature map.
    """

    x = self.conv1(inputs)
    x = self.bn_conv1(x, training=training)
    x = tf.nn.relu(x)
    if intermediates_dict is not None:
      intermediates_dict['block0'] = x

    x = self.max_pool(x)
    if intermediates_dict is not None:
      intermediates_dict['block0mp'] = x

    # Block 1 (equivalent to "conv2" in Resnet paper).
    x = self.l2a(x, training=training)
    x = self.l2b(x, training=training)
    x = self.l2c(x, training=training)
    if intermediates_dict is not None:
      intermediates_dict['block1'] = x

    # Block 2 (equivalent to "conv3" in Resnet paper).
    x = self.l3a(x, training=training)
    x = self.l3b(x, training=training)
    x = self.l3c(x, training=training)
    x = self.l3d(x, training=training)
    if intermediates_dict is not None:
      intermediates_dict['block2'] = x

    # Block 3 (equivalent to "conv4" in Resnet paper).
    x = self.l4a(x, training=training)
    x = self.l4b(x, training=training)
    x = self.l4c(x, training=training)
    x = self.l4d(x, training=training)
    x = self.l4e(x, training=training)
    x = self.l4f(x, training=training)

    # 'block3' is recorded after the optional subsampling.
    if self.block3_strides:
      x = self.subsampling_layer(x)
    if intermediates_dict is not None:
      intermediates_dict['block3'] = x

    x = self.l5a(x, training=training)
    x = self.l5b(x, training=training)
    x = self.l5c(x, training=training)

    # 'block4' is recorded after the optional 7x7 average pooling.
    if self.average_pooling:
      x = self.avg_pool(x)
    if intermediates_dict is not None:
      intermediates_dict['block4'] = x

    if self.include_top:
      return self.fc1000(self.flatten(x))
    elif self.global_pooling:
      return self.global_pooling(x)
    else:
      return x
366