1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Computes Receptive Field (RF) information for different models. 16 17The receptive field (and related parameters) for the different models are 18printed to stdout, and may also optionally be written to a CSV file. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import argparse 26import csv 27import sys 28 29from tensorflow.contrib import framework 30from tensorflow.contrib import slim 31from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.ops import array_ops 35from tensorflow.python.platform import app 36from nets import alexnet 37from nets import inception 38from nets import mobilenet_v1 39from nets import resnet_v1 40from nets import resnet_v2 41from nets import vgg 42 43cmd_args = None 44 45# Input node name for all architectures. 46_INPUT_NODE = 'input_image' 47 48# Variants of different network architectures. 49 50# - resnet: different versions and sizes. 
_SUPPORTED_RESNET_VARIANTS = [
    'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200',
    'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200'
]

# - inception_resnet_v2: default, and version with SAME padding.
_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [
    'inception_resnet_v2', 'inception_resnet_v2-same'
]

# - inception_v2: default, and version with no separable conv.
_SUPPORTED_INCEPTIONV2_VARIANTS = [
    'inception_v2', 'inception_v2-no-separable-conv'
]

# - inception_v3: default version.
_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3']

# - inception_v4: default version.
_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4']

# - alexnet_v2: default version.
_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2']

# - vgg: vgg_a (with 11 layers) and vgg_16 (version D).
_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16']

# - mobilenet_v1: 100% and 75%.
_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075']

# Column names of the optional CSV output. Shared by the code that writes the
# header and the code that builds each row, so the two cannot drift apart.
_CSV_FIELD_NAMES = [
    'CNN', 'input resolution', 'end_point', 'RF size hor', 'RF size ver',
    'effective stride hor', 'effective stride ver', 'effective padding hor',
    'effective padding ver'
]


def _construct_model(model_type='resnet_v1_50'):
  """Constructs model for the desired type of CNN.

  The model is built in the current default graph, on top of a float32
  placeholder named `_INPUT_NODE`.

  Args:
    model_type: Type of model to be used.

  Returns:
    end_points: A dictionary from components of the network to the corresponding
      activations.

  Raises:
    ValueError: If the model_type is not supported.
  """
  # Placeholder input. Shape is (1, None, None, 3): only the first and last
  # dimensions are fixed.
  images = array_ops.placeholder(
      dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE)

  # Construct model.
  if model_type == 'inception_resnet_v2':
    _, end_points = inception.inception_resnet_v2_base(images)
  elif model_type == 'inception_resnet_v2-same':
    _, end_points = inception.inception_resnet_v2_base(
        images, align_feature_maps=True)
  elif model_type == 'inception_v2':
    _, end_points = inception.inception_v2_base(images)
  elif model_type == 'inception_v2-no-separable-conv':
    _, end_points = inception.inception_v2_base(
        images, use_separable_conv=False)
  elif model_type == 'inception_v3':
    _, end_points = inception.inception_v3_base(images)
  elif model_type == 'inception_v4':
    _, end_points = inception.inception_v4_base(images)
  elif model_type == 'alexnet_v2':
    _, end_points = alexnet.alexnet_v2(images)
  elif model_type == 'vgg_a':
    _, end_points = vgg.vgg_a(images)
  elif model_type == 'vgg_16':
    _, end_points = vgg.vgg_16(images)
  elif model_type == 'mobilenet_v1':
    _, end_points = mobilenet_v1.mobilenet_v1_base(images)
  elif model_type == 'mobilenet_v1_075':
    _, end_points = mobilenet_v1.mobilenet_v1_base(
        images, depth_multiplier=0.75)
  elif model_type in _SUPPORTED_RESNET_VARIANTS:
    # Every supported resnet variant is built by the function of the same name
    # in the module of the matching version, always with the same arguments.
    if model_type.startswith('resnet_v1_'):
      resnet_module = resnet_v1
    else:
      resnet_module = resnet_v2
    resnet_fn = getattr(resnet_module, model_type)
    _, end_points = resnet_fn(
        images, num_classes=None, is_training=False, global_pool=False)
  else:
    raise ValueError('Unsupported model_type %s.' % model_type)

  return end_points


def _get_desired_end_point_keys(model_type='resnet_v1_50'):
  """Gets list of desired end point keys for a type of CNN.

  Args:
    model_type: Type of model to be used.

  Returns:
    desired_end_point_types: A list containing the desired end-points.

  Raises:
    ValueError: If the model_type is not supported.
  """
  if model_type in _SUPPORTED_RESNET_VARIANTS:
    blocks = ['block1', 'block2', 'block3', 'block4']
    desired_end_point_keys = ['%s/%s' % (model_type, i) for i in blocks]
  elif model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    desired_end_point_keys = [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1'
    ]
  elif model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
    desired_end_point_keys = [
        'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
        'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
        'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'
    ]
  elif model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
    desired_end_point_keys = [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
        'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'
    ]
  elif model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
    desired_end_point_keys = [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
        'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f',
        'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'
    ]
  elif model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
    ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5']
    desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
  elif model_type in _SUPPORTED_VGG_VARIANTS:
    ep = [
        'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1',
        'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4',
        'conv5/conv5_1', 'conv5/conv5_2', 'pool5'
    ]
    desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep]
  elif model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
    desired_end_point_keys = [
        'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
        'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
        'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
        'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
        'Conv2d_12_pointwise', 'Conv2d_13_pointwise'
    ]
  else:
    raise ValueError('Unsupported model_type %s.' % model_type)

  return desired_end_point_keys


def _model_graph_def(model_type='resnet_v1_50', arg_sc=None):
  """Constructs a model graph, returning GraphDef and end-points.

  The model is built in a fresh graph, so repeated calls do not interfere with
  each other or with the default graph.

  Args:
    model_type: Type of model to be used.
    arg_sc: Optional arg scope to use in constructing the graph.

  Returns:
    graph_def: GraphDef of constructed graph.
    end_points: A dictionary from components of the network to the corresponding
      activations.
  """
  if arg_sc is None:
    arg_sc = {}
  g = ops.Graph()
  with g.as_default():
    with framework.arg_scope(arg_sc):
      end_points = _construct_model(model_type)

  return g.as_graph_def(), end_points


def _output_node_name(tensor_name):
  """Strips the trailing ':<suffix>' from a tensor name, if there is one."""
  node_name, separator, _ = tensor_name.rpartition(':')
  # Without a colon there is nothing to strip; keep the name intact instead of
  # dropping its last character.
  return node_name if separator else tensor_name


def _rf_csv_row(model_type, input_resolution, end_point_key, rf_params=None):
  """Builds one CSV row, as a dict keyed by the entries of _CSV_FIELD_NAMES.

  Args:
    model_type: Type of model the row refers to.
    input_resolution: Sequence whose first entry is reported as the input
      resolution, or None if the resolution is unknown.
    end_point_key: Name of the end point the row refers to.
    rf_params: 6-tuple (RF size x, RF size y, effective stride x, effective
      stride y, effective padding x, effective padding y), or None if the RF
      computation failed, in which case every RF column is 'None'.

  Returns:
    A dictionary suitable for csv.DictWriter.writerow.
  """
  if rf_params is None:
    rf_params = (None,) * 6
  (rf_size_x, rf_size_y, stride_x, stride_y, padding_x, padding_y) = rf_params
  resolution = input_resolution[0] if input_resolution is not None else None
  return {
      'CNN': model_type,
      'input resolution': str(resolution),
      'end_point': end_point_key,
      'RF size hor': str(rf_size_x),
      'RF size ver': str(rf_size_y),
      'effective stride hor': str(stride_x),
      'effective stride ver': str(stride_y),
      'effective padding hor': str(padding_x),
      'effective padding ver': str(padding_y)
  }


def _model_rf(graphdef,
              end_points,
              desired_end_point_keys,
              model_type='resnet_v1_50',
              csv_writer=None,
              input_resolution=None):
  """Computes receptive field information for a given CNN model.

  The information will be printed to stdout. If the RF parameters are the same
  for the horizontal and vertical directions, it will be printed only once.
  Otherwise, they are printed once for the horizontal and once for the vertical
  directions. If the computation fails with a ValueError for some end point,
  the error is printed and, if a CSV writer is given, a row with 'None' values
  is written for that end point; processing then continues with the next one.

  Args:
    graphdef: GraphDef of given model.
    end_points: A dictionary from components of the model to the corresponding
      activations.
    desired_end_point_keys: List of desired end points for which receptive field
      information will be computed.
    model_type: Type of model to be used, used only for printing purposes.
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
    input_resolution: Input resolution to use when computing RF parameters. This
      is important for the case where padding can only be defined if the input
      resolution is known, which may happen if using SAME padding. This is
      assumed the resolution for both height and width. If None, we consider the
      resolution is unknown.
  """
  for desired_end_point_key in desired_end_point_keys:
    print('- %s:' % desired_end_point_key)
    output_node = _output_node_name(end_points[desired_end_point_key].name)
    # Only the RF computation is guarded, so that an error raised while
    # printing or writing the CSV row is not misreported as an RF failure.
    try:
      (receptive_field_x, receptive_field_y, effective_stride_x,
       effective_stride_y, effective_padding_x, effective_padding_y
      ) = receptive_field.compute_receptive_field_from_graph_def(
          graphdef, _INPUT_NODE, output_node, input_resolution=input_resolution)
    except ValueError as e:
      print('---->ERROR: Computing RF parameters for model %s with final end '
            'point %s and input resolution %s did not work' %
            (model_type, desired_end_point_key, input_resolution))
      print('---->The returned error is: %s' % e)
      rf_params = None
    else:
      rf_params = (receptive_field_x, receptive_field_y, effective_stride_x,
                   effective_stride_y, effective_padding_x,
                   effective_padding_y)
      # If values are the same in horizontal/vertical directions, just report
      # one of them. Otherwise, report both.
      if (receptive_field_x == receptive_field_y) and (
          effective_stride_x == effective_stride_y) and (
              effective_padding_x == effective_padding_y):
        print(
            'Receptive field size = %5s, effective stride = %5s, effective '
            'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
                               str(effective_padding_x)))
      else:
        print('Receptive field size: horizontal = %5s, vertical = %5s. '
              'Effective stride: horizontal = %5s, vertical = %5s. Effective '
              'padding: horizontal = %5s, vertical = %5s' %
              (str(receptive_field_x), str(receptive_field_y),
               str(effective_stride_x), str(effective_stride_y),
               str(effective_padding_x), str(effective_padding_y)))

    if csv_writer is not None:
      csv_writer.writerow(
          _rf_csv_row(model_type, input_resolution, desired_end_point_key,
                      rf_params))


def _process_model_rf(model_type='resnet_v1_50',
                      csv_writer=None,
                      arg_sc=None,
                      input_resolutions=None):
  """Constructs model graph and desired end-points, and computes RF.

  The computed RF parameters are printed to stdout by the _model_rf function.

  Args:
    model_type: Type of model to be used.
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
    arg_sc: Optional arg scope to use in constructing the graph.
    input_resolutions: List of 1D input resolutions to use when computing RF
      parameters. This is important for the case where padding can only be
      defined if the input resolution is known, which may happen if using SAME
      padding. The entries in the list are assumed the resolution for both
      height and width. If one of the elements in the list is None, we consider
      it to mean that the resolution is unknown. If the list itself is None, we
      use the default list [None, 224, 321].
  """
  # Process default value for this list.
  if input_resolutions is None:
    input_resolutions = [None, 224, 321]

  # Nothing to compute: avoid building a graph that would never be used.
  if not input_resolutions:
    return

  # Neither the graph nor the end-point keys depend on the input resolution,
  # so they are built once and shared by all resolutions.
  graphdef, end_points = _model_graph_def(model_type, arg_sc)
  desired_end_point_keys = _get_desired_end_point_keys(model_type)

  for n in input_resolutions:
    print('********************%s, input resolution = %s' % (model_type, n))
    _model_rf(
        graphdef,
        end_points,
        desired_end_point_keys,
        model_type,
        csv_writer,
        input_resolution=[n, n] if n is not None else None)


def _resnet_rf(csv_writer=None):
  """Computes RF and associated parameters for resnet models.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  # The same arg scope is used for every variant (v1 and v2 alike), so it is
  # created once rather than once per model.
  arg_sc = resnet_v1.resnet_arg_scope()
  for model_type in _SUPPORTED_RESNET_VARIANTS:
    _process_model_rf(model_type, csv_writer, arg_sc)


def _inception_resnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_resnet_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _inception_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _inception_v3_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v3 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _inception_v4_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v4 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _alexnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the alexnet_v2 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _vgg_rf(csv_writer=None):
  """Computes RF and associated parameters for the vgg model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_VGG_VARIANTS:
    _process_model_rf(model_type, csv_writer)


def _mobilenet_v1_rf(csv_writer=None):
  """Computes RF and associated parameters for the mobilenet_v1 model.

  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
    # Batch norm and dropout are put in inference mode for the graph
    # construction.
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=False) as arg_sc:
      _process_model_rf(model_type, csv_writer, arg_sc)


def _compute_rf_for_all_models(csv_writer=None):
  """Computes RF parameters for each supported network architecture.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  _alexnet_v2_rf(csv_writer)
  _vgg_rf(csv_writer)
  _inception_v2_rf(csv_writer)
  _inception_v3_rf(csv_writer)
  _inception_v4_rf(csv_writer)
  _inception_resnet_v2_rf(csv_writer)
  _mobilenet_v1_rf(csv_writer)
  _resnet_rf(csv_writer)


def _open_csv_for_writing(path):
  """Opens `path` as a text file suitable for the csv module to write to."""
  if sys.version_info[0] >= 3:
    # On Python 3 the csv module expects files opened with newline='';
    # otherwise extra blank rows appear on platforms that translate newlines.
    return open(path, 'w', newline='')
  # Python 2's open() has no `newline` argument.
  return open(path, 'w')


def main(unused_argv):
  """Computes RF parameters for all models, optionally writing them to CSV.

  The CSV file is written only if `cmd_args.csv_path` is non-empty. It is
  closed even if the computation for one of the models raises.

  Args:
    unused_argv: Command-line arguments left over after flag parsing; unused.
  """
  if not cmd_args.csv_path:
    _compute_rf_for_all_models(None)
    return

  # Configure CSV file which will be written.
  with _open_csv_for_writing(cmd_args.csv_path) as csv_file:
    rf_writer = csv.DictWriter(csv_file, fieldnames=_CSV_FIELD_NAMES)
    rf_writer.writeheader()
    _compute_rf_for_all_models(rf_writer)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--csv_path',
      type=str,
      default='',
      help=('Path to CSV file that will be written with RF parameters. If '
            'empty, no file will be written.'))
  # Flags not known to this parser are forwarded to app.run.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)