# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ResNet50 model definition compatible with TensorFlow's eager execution.

Reference [Deep Residual Learning for Image
Recognition](https://arxiv.org/abs/1512.03385)

Adapted from tf.keras.applications.ResNet50. A notable difference is that the
model here outputs logits while the Keras model outputs probability.
"""
import functools

import tensorflow as tf

layers = tf.keras.layers


class _IdentityBlock(tf.keras.Model):
  """_IdentityBlock is the block that has no conv layer at shortcut.

  The main path is conv(1x1) -> BN -> ReLU -> conv(kernel_size) -> BN -> ReLU
  -> conv(1x1) -> BN. The block input is then added to the main-path output
  and a final ReLU is applied, so the input must have a shape that can be
  added to the output of the last convolution.

  Args:
    kernel_size: the kernel size of middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    data_format: data_format for the input ('channels_first' or
      'channels_last').
  """

  def __init__(self, kernel_size, filters, stage, block, data_format):
    super(_IdentityBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # Batch norm normalizes over the channel axis, whose position depends on
    # the data format.
    bn_axis = 1 if data_format == 'channels_first' else 3

    self.conv2a = layers.Conv2D(
        filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format)
    self.bn2a = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2a')

    self.conv2b = layers.Conv2D(
        filters2,
        kernel_size,
        padding='same',
        data_format=data_format,
        name=conv_name_base + '2b')
    self.bn2b = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2b')

    self.conv2c = layers.Conv2D(
        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
    self.bn2c = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2c')

  def call(self, input_tensor, training=False):
    """Applies the block to `input_tensor`.

    Args:
      input_tensor: input to the block; also used as the identity shortcut.
      training: forwarded to the batch normalization layers.

    Returns:
      ReLU of the sum of the main-path output and `input_tensor`.
    """
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    # Identity shortcut: no activation before the residual addition.
    x += input_tensor
    return tf.nn.relu(x)


class _ConvBlock(tf.keras.Model):
  """_ConvBlock is the block that has a conv layer at shortcut.

  Same main path as `_IdentityBlock`, except that the first 1x1 convolution
  uses `strides`, and the shortcut is a strided 1x1 convolution followed by
  batch normalization instead of the raw input.

  Args:
    kernel_size: the kernel size of middle conv layer at main path
    filters: list of integers, the filters of 3 conv layer at main path
    stage: integer, current stage label, used for generating layer names
    block: 'a','b'..., current block label, used for generating layer names
    data_format: data_format for the input ('channels_first' or
      'channels_last').
    strides: strides for the convolution. Note that from stage 3, the first
      conv layer at main path is with strides=(2,2), and the shortcut should
      have strides=(2,2) as well.
  """

  def __init__(self,
               kernel_size,
               filters,
               stage,
               block,
               data_format,
               strides=(2, 2)):
    super(_ConvBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # Batch norm normalizes over the channel axis, whose position depends on
    # the data format.
    bn_axis = 1 if data_format == 'channels_first' else 3

    self.conv2a = layers.Conv2D(
        filters1, (1, 1),
        strides=strides,
        name=conv_name_base + '2a',
        data_format=data_format)
    self.bn2a = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2a')

    self.conv2b = layers.Conv2D(
        filters2,
        kernel_size,
        padding='same',
        name=conv_name_base + '2b',
        data_format=data_format)
    self.bn2b = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2b')

    self.conv2c = layers.Conv2D(
        filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format)
    self.bn2c = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '2c')

    # The shortcut uses the same strides and output filters as the main path
    # so that both branches can be added together in `call`.
    self.conv_shortcut = layers.Conv2D(
        filters3, (1, 1),
        strides=strides,
        name=conv_name_base + '1',
        data_format=data_format)
    self.bn_shortcut = layers.BatchNormalization(
        axis=bn_axis, name=bn_name_base + '1')

  def call(self, input_tensor, training=False):
    """Applies the block to `input_tensor`.

    Args:
      input_tensor: input to the block; fed to both the main path and the
        convolutional shortcut.
      training: forwarded to the batch normalization layers.

    Returns:
      ReLU of the sum of the main-path output and the shortcut output.
    """
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    shortcut = self.conv_shortcut(input_tensor)
    shortcut = self.bn_shortcut(shortcut, training=training)

    x += shortcut
    return tf.nn.relu(x)


# pylint: disable=not-callable
class ResNet50(tf.keras.Model):
  """Instantiates the ResNet50 architecture.

  Args:
    data_format: format for the image. Either 'channels_first' or
      'channels_last'.  'channels_first' is typically faster on GPUs while
      'channels_last' is typically faster on CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats
    name: Prefix applied to names of variables created in the model.
    trainable: Is the model trainable? NOTE: this argument is accepted for
      interface compatibility but is not used by this constructor.
    include_top: whether to include the fully-connected layer at the top of the
      network.
    pooling: Optional pooling mode for feature extraction when `include_top`
      is `False`.
      - `None` means that the output of the model will be the 4D tensor
          output of the last convolutional layer.
      - `avg` means that global average pooling will be applied to the output
          of the last convolutional layer, and thus the output of the model
          will be a 2D tensor.
      - `max` means that global max pooling will be applied.
      Any other value is treated like `None`.
    block3_strides: whether to add a stride of 2 to block3 to make it compatible
      with tf.slim ResNet implementation.
    average_pooling: whether to do average pooling of block4 features before
      global pooling.
    classes: optional number of classes to classify images into, only to be
      specified if `include_top` is True.

  Raises:
    ValueError: in case of invalid argument for data_format.
  """

  def __init__(self,
               data_format,
               name='',
               trainable=True,
               include_top=True,
               pooling=None,
               block3_strides=False,
               average_pooling=True,
               classes=1000):
    super(ResNet50, self).__init__(name=name)

    valid_channel_values = ('channels_first', 'channels_last')
    if data_format not in valid_channel_values:
      raise ValueError('Unknown data_format: %s. Valid values: %s' %
                       (data_format, valid_channel_values))
    self.include_top = include_top
    self.block3_strides = block3_strides
    self.average_pooling = average_pooling
    self.pooling = pooling
    # Always defined so the attribute exists regardless of `include_top`; it
    # is only replaced below when `include_top` is False and a supported
    # pooling mode is requested.
    self.global_pooling = None

    def conv_block(filters, stage, block, strides=(2, 2)):
      return _ConvBlock(
          3,
          filters,
          stage=stage,
          block=block,
          data_format=data_format,
          strides=strides)

    def id_block(filters, stage, block):
      return _IdentityBlock(
          3, filters, stage=stage, block=block, data_format=data_format)

    self.conv1 = layers.Conv2D(
        64, (7, 7),
        strides=(2, 2),
        data_format=data_format,
        padding='same',
        name='conv1')
    bn_axis = 1 if data_format == 'channels_first' else 3
    self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')
    self.max_pool = layers.MaxPooling2D((3, 3),
                                        strides=(2, 2),
                                        data_format=data_format)

    self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1))
    self.l2b = id_block([64, 64, 256], stage=2, block='b')
    self.l2c = id_block([64, 64, 256], stage=2, block='c')

    self.l3a = conv_block([128, 128, 512], stage=3, block='a')
    self.l3b = id_block([128, 128, 512], stage=3, block='b')
    self.l3c = id_block([128, 128, 512], stage=3, block='c')
    self.l3d = id_block([128, 128, 512], stage=3, block='d')

    self.l4a = conv_block([256, 256, 1024], stage=4, block='a')
    self.l4b = id_block([256, 256, 1024], stage=4, block='b')
    self.l4c = id_block([256, 256, 1024], stage=4, block='c')
    self.l4d = id_block([256, 256, 1024], stage=4, block='d')
    self.l4e = id_block([256, 256, 1024], stage=4, block='e')
    self.l4f = id_block([256, 256, 1024], stage=4, block='f')

    # Striding layer that can be used on top of block3 to produce feature maps
    # with the same resolution as the TF-Slim implementation.
    if self.block3_strides:
      self.subsampling_layer = layers.MaxPooling2D((1, 1),
                                                   strides=(2, 2),
                                                   data_format=data_format)
      # The subsampling layer already strides by 2, so the first conv block of
      # stage 5 uses unit strides in this configuration.
      self.l5a = conv_block([512, 512, 2048],
                            stage=5,
                            block='a',
                            strides=(1, 1))
    else:
      self.l5a = conv_block([512, 512, 2048], stage=5, block='a')
    self.l5b = id_block([512, 512, 2048], stage=5, block='b')
    self.l5c = id_block([512, 512, 2048], stage=5, block='c')

    self.avg_pool = layers.AveragePooling2D((7, 7),
                                            strides=(7, 7),
                                            data_format=data_format)

    if self.include_top:
      self.flatten = layers.Flatten()
      self.fc1000 = layers.Dense(classes, name='fc1000')
    else:
      # Spatial axes to reduce over for global pooling.
      reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3]
      reduction_indices = tf.constant(reduction_indices)
      if pooling == 'avg':
        self.global_pooling = functools.partial(
            tf.reduce_mean,
            axis=reduction_indices,
            keepdims=False)
      elif pooling == 'max':
        # Use `axis`/`keepdims` (not the removed `reduction_indices`/
        # `keep_dims` spellings) so this matches the 'avg' branch.
        self.global_pooling = functools.partial(
            tf.reduce_max,
            axis=reduction_indices,
            keepdims=False)

  def call(self, inputs, training=True, intermediates_dict=None):
    """Call the ResNet50 model.

    Args:
      inputs: Images to compute features for.
      training: Whether model is in training phase.
      intermediates_dict: `None` or dictionary. If not None, accumulate feature
        maps from intermediate blocks into the dictionary under the keys
        'block0', 'block0mp', 'block1', 'block2', 'block3' and 'block4'.

    Returns:
      Logits from the final dense layer if `include_top` is set; otherwise the
      globally pooled features if a global pooling function was configured,
      else the feature map produced by the last stage.
    """

    x = self.conv1(inputs)
    x = self.bn_conv1(x, training=training)
    x = tf.nn.relu(x)
    if intermediates_dict is not None:
      intermediates_dict['block0'] = x

    x = self.max_pool(x)
    if intermediates_dict is not None:
      intermediates_dict['block0mp'] = x

    # Block 1 (equivalent to "conv2" in Resnet paper).
    x = self.l2a(x, training=training)
    x = self.l2b(x, training=training)
    x = self.l2c(x, training=training)
    if intermediates_dict is not None:
      intermediates_dict['block1'] = x

    # Block 2 (equivalent to "conv3" in Resnet paper).
    x = self.l3a(x, training=training)
    x = self.l3b(x, training=training)
    x = self.l3c(x, training=training)
    x = self.l3d(x, training=training)
    if intermediates_dict is not None:
      intermediates_dict['block2'] = x

    # Block 3 (equivalent to "conv4" in Resnet paper).
    x = self.l4a(x, training=training)
    x = self.l4b(x, training=training)
    x = self.l4c(x, training=training)
    x = self.l4d(x, training=training)
    x = self.l4e(x, training=training)
    x = self.l4f(x, training=training)

    # 'block3' is recorded after the optional subsampling layer.
    if self.block3_strides:
      x = self.subsampling_layer(x)
    if intermediates_dict is not None:
      intermediates_dict['block3'] = x

    x = self.l5a(x, training=training)
    x = self.l5b(x, training=training)
    x = self.l5c(x, training=training)

    # 'block4' is recorded after the optional average pooling.
    if self.average_pooling:
      x = self.avg_pool(x)
    if intermediates_dict is not None:
      intermediates_dict['block4'] = x

    if self.include_top:
      return self.fc1000(self.flatten(x))
    elif self.global_pooling:
      return self.global_pooling(x)
    else:
      return x