• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Computes Receptive Field (RF) information for different models.
16
17The receptive field (and related parameters) for the different models are
18printed to stdout, and may also optionally be written to a CSV file.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import csv
27import sys
28
29from tensorflow.contrib import framework
30from tensorflow.contrib import slim
31from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.ops import array_ops
35from tensorflow.python.platform import app
36from nets import alexnet
37from nets import inception
38from nets import mobilenet_v1
39from nets import resnet_v1
40from nets import resnet_v2
41from nets import vgg
42
# Parsed command-line arguments. Assigned in the `__main__` block below and
# read by `main` (for `csv_path`).
cmd_args = None

# Input node name for all architectures. Used both as the name of the image
# placeholder and as the input node passed to the receptive field computation.
_INPUT_NODE = 'input_image'

# Variants of different network architectures. Each list holds the model_type
# strings handled for that family.

# - resnet: different versions and sizes.
_SUPPORTED_RESNET_VARIANTS = [
    'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200',
    'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200'
]

# - inception_resnet_v2: default, and version with SAME padding.
_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [
    'inception_resnet_v2', 'inception_resnet_v2-same'
]

# - inception_v2: default, and version with no separable conv.
_SUPPORTED_INCEPTIONV2_VARIANTS = [
    'inception_v2', 'inception_v2-no-separable-conv'
]

# - inception_v3: default version.
_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3']

# - inception_v4: default version.
_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4']

# - alexnet_v2: default version.
_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2']

# - vgg: vgg_a (with 11 layers) and vgg_16 (version D).
_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16']

# - mobilenet_v1: 100% and 75% (depth multipliers 1.0 and 0.75).
_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075']
80
81
def _construct_model(model_type='resnet_v1_50'):
  """Builds the requested CNN on top of a fresh image placeholder.

  The placeholder is named `_INPUT_NODE` and has shape (1, None, None, 3). It
  is created before the model type is validated, so it exists in the default
  graph even when a ValueError is raised.

  Args:
    model_type: Name of the network variant to build.

  Returns:
    end_points: A dictionary from components of the network to the corresponding
      activations, as returned by the selected network constructor.

  Raises:
    ValueError: If the model_type is not supported.
  """
  images = array_ops.placeholder(
      dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE)

  # Keyword arguments shared by every resnet variant.
  resnet_kwargs = {
      'num_classes': None,
      'is_training': False,
      'global_pool': False,
  }

  # One zero-argument builder per supported model_type. The lambdas defer both
  # the attribute lookup and the graph construction until a builder is chosen.
  builders = {
      'inception_resnet_v2':
          lambda: inception.inception_resnet_v2_base(images),
      'inception_resnet_v2-same':
          lambda: inception.inception_resnet_v2_base(
              images, align_feature_maps=True),
      'inception_v2':
          lambda: inception.inception_v2_base(images),
      'inception_v2-no-separable-conv':
          lambda: inception.inception_v2_base(
              images, use_separable_conv=False),
      'inception_v3':
          lambda: inception.inception_v3_base(images),
      'inception_v4':
          lambda: inception.inception_v4_base(images),
      'alexnet_v2':
          lambda: alexnet.alexnet_v2(images),
      'vgg_a':
          lambda: vgg.vgg_a(images),
      'vgg_16':
          lambda: vgg.vgg_16(images),
      'mobilenet_v1':
          lambda: mobilenet_v1.mobilenet_v1_base(images),
      'mobilenet_v1_075':
          lambda: mobilenet_v1.mobilenet_v1_base(
              images, depth_multiplier=0.75),
      'resnet_v1_50':
          lambda: resnet_v1.resnet_v1_50(images, **resnet_kwargs),
      'resnet_v1_101':
          lambda: resnet_v1.resnet_v1_101(images, **resnet_kwargs),
      'resnet_v1_152':
          lambda: resnet_v1.resnet_v1_152(images, **resnet_kwargs),
      'resnet_v1_200':
          lambda: resnet_v1.resnet_v1_200(images, **resnet_kwargs),
      'resnet_v2_50':
          lambda: resnet_v2.resnet_v2_50(images, **resnet_kwargs),
      'resnet_v2_101':
          lambda: resnet_v2.resnet_v2_101(images, **resnet_kwargs),
      'resnet_v2_152':
          lambda: resnet_v2.resnet_v2_152(images, **resnet_kwargs),
      'resnet_v2_200':
          lambda: resnet_v2.resnet_v2_200(images, **resnet_kwargs),
  }

  build = builders.get(model_type)
  if build is None:
    raise ValueError('Unsupported model_type %s.' % model_type)

  # Every constructor returns a pair; only the end-points dictionary is kept.
  _, end_points = build()
  return end_points
153
154
def _get_desired_end_point_keys(model_type='resnet_v1_50'):
  """Returns the end-point keys to report for a given type of CNN.

  For the resnet, alexnet_v2 and vgg families the keys are prefixed with
  '<model_type>/'; for the other families they are returned bare.

  Args:
    model_type: Type of model to be used.

  Returns:
    desired_end_point_keys: A list containing the desired end-point keys.

  Raises:
    ValueError: If the model_type is not supported.
  """
  if model_type in _SUPPORTED_RESNET_VARIANTS:
    return [
        '%s/%s' % (model_type, block)
        for block in ('block1', 'block2', 'block3', 'block4')
    ]

  if model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1'
    ]

  if model_type in _SUPPORTED_INCEPTIONV2_VARIANTS:
    return [
        'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
        'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
        'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c'
    ]

  if model_type in _SUPPORTED_INCEPTIONV3_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b',
        'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
        'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'
    ]

  if model_type in _SUPPORTED_INCEPTIONV4_VARIANTS:
    return [
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
        'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f',
        'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'
    ]

  if model_type in _SUPPORTED_ALEXNETV2_VARIANTS:
    alexnet_layers = [
        'conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5'
    ]
    return ['%s/%s' % (model_type, layer) for layer in alexnet_layers]

  if model_type in _SUPPORTED_VGG_VARIANTS:
    vgg_layers = [
        'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1',
        'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4',
        'conv5/conv5_1', 'conv5/conv5_2', 'pool5'
    ]
    return ['%s/%s' % (model_type, layer) for layer in vgg_layers]

  if model_type in _SUPPORTED_MOBILENETV1_VARIANTS:
    return [
        'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
        'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
        'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
        'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
        'Conv2d_12_pointwise', 'Conv2d_13_pointwise'
    ]

  # No family claimed this model_type.
  raise ValueError('Unsupported model_type %s.' % model_type)
218
219
def _model_graph_def(model_type='resnet_v1_50', arg_sc=None):
  """Builds the model in a new graph and returns its GraphDef and end-points.

  Args:
    model_type: Type of model to be used.
    arg_sc: Optional arg scope to use in constructing the graph. If None, an
      empty scope ({}) is used.

  Returns:
    graph_def: GraphDef of constructed graph.
    end_points: A dictionary from components of the network to the corresponding
      activations.
  """
  graph = ops.Graph()
  scope = {} if arg_sc is None else arg_sc

  # The model is built inside its own graph so that repeated calls do not add
  # nodes to a shared default graph.
  with graph.as_default(), framework.arg_scope(scope):
    end_points = _construct_model(model_type)

  graph_def = graph.as_graph_def()
  return graph_def, end_points
240
241
def _model_rf(graphdef,
              end_points,
              desired_end_point_keys,
              model_type='resnet_v1_50',
              csv_writer=None,
              input_resolution=None):
  """Computes receptive field information for a given CNN model.

  The information will be printed to stdout. If the RF parameters are the same
  for the horizontal and vertical directions, it will be printed only once.
  Otherwise, they are printed once for the horizontal and once for the vertical
  directions.

  If the RF computation for an end point raises ValueError, the error is
  printed and processing continues with the next end point; when a CSV writer
  is given, a row with 'None' in all RF columns is written for that end point.

  Args:
    graphdef: GraphDef of given model.
    end_points: A dictionary from components of the model to the corresponding
      activations.
    desired_end_point_keys: List of desired end points for which receptive field
      information will be computed.
    model_type: Type of model to be used, used only for printing purposes.
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
      It is called as `csv_writer.writerow(dict)`.
    input_resolution: Input resolution to use when computing RF parameters. This
      is important for the case where padding can only be defined if the input
      resolution is known, which may happen if using SAME padding. This is
      assumed the resolution for both height and width. If None, we consider the
      resolution is unknown.
  """
  # CSV columns holding the six RF parameters, in the order in which the values
  # are collected below.
  rf_field_names = ('RF size hor', 'RF size ver', 'effective stride hor',
                    'effective stride ver', 'effective padding hor',
                    'effective padding ver')

  for desired_end_point_key in desired_end_point_keys:
    print('- %s:' % desired_end_point_key)
    # The end point's name is expected to look like '<node>:<suffix>'; strip the
    # suffix to obtain the node name. rsplit leaves a name that contains no ':'
    # untouched instead of chopping off its last character.
    output_node = end_points[desired_end_point_key].name.rsplit(':', 1)[0]

    # Only the RF computation (and the unpacking of its result) is guarded, so
    # that a ValueError raised while printing or writing the CSV row is not
    # misreported as a failed RF computation.
    try:
      (receptive_field_x, receptive_field_y, effective_stride_x,
       effective_stride_y, effective_padding_x, effective_padding_y
      ) = receptive_field.compute_receptive_field_from_graph_def(
          graphdef, _INPUT_NODE, output_node, input_resolution=input_resolution)
    except ValueError as e:
      print('---->ERROR: Computing RF parameters for model %s with final end '
            'point %s and input resolution %s did not work' %
            (model_type, desired_end_point_key, input_resolution))
      print('---->The returned error is: %s' % e)
      rf_values = ['None'] * len(rf_field_names)
    else:
      # If values are the same in horizontal/vertical directions, just report
      # one of them. Otherwise, report both.
      if (receptive_field_x == receptive_field_y) and (
          effective_stride_x == effective_stride_y) and (
              effective_padding_x == effective_padding_y):
        print(
            'Receptive field size = %5s, effective stride = %5s, effective '
            'padding = %5s' % (str(receptive_field_x), str(effective_stride_x),
                               str(effective_padding_x)))
      else:
        print('Receptive field size: horizontal = %5s, vertical = %5s. '
              'Effective stride: horizontal = %5s, vertical = %5s. Effective '
              'padding: horizontal = %5s, vertical = %5s' %
              (str(receptive_field_x), str(receptive_field_y),
               str(effective_stride_x), str(effective_stride_y),
               str(effective_padding_x), str(effective_padding_y)))
      rf_values = [
          str(receptive_field_x), str(receptive_field_y),
          str(effective_stride_x), str(effective_stride_y),
          str(effective_padding_x), str(effective_padding_y)
      ]

    if csv_writer is not None:
      # Only the first entry of input_resolution is recorded, since the same
      # resolution is assumed for height and width.
      row = {
          'CNN':
              model_type,
          'input resolution':
              str(input_resolution[0])
              if input_resolution is not None else 'None',
          'end_point':
              desired_end_point_key,
      }
      row.update(zip(rf_field_names, rf_values))
      csv_writer.writerow(row)
344
345
def _process_model_rf(model_type='resnet_v1_50',
                      csv_writer=None,
                      arg_sc=None,
                      input_resolutions=None):
  """Constructs model graph and desired end-points, and computes RF.

  For every requested input resolution, a new graph is constructed for the
  model and handed to _model_rf, which prints the computed RF parameters to
  stdout (and writes them to the CSV writer, if given).

  Args:
    model_type: Type of model to be used.
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
    arg_sc: Optional arg scope to use in constructing the graph.
    input_resolutions: List of 1D input resolutions to use when computing RF
      parameters. This is important for the case where padding can only be
      defined if the input resolution is known, which may happen if using SAME
      padding. The entries in the list are assumed the resolution for both
      height and width. If one of the elements in the list is None, we consider
      it to mean that the resolution is unknown. If the list itself is None, we
      use the default list [None, 224, 321].
  """
  if input_resolutions is None:
    resolutions = [None, 224, 321]
  else:
    resolutions = input_resolutions

  for resolution in resolutions:
    print('********************%s, input resolution = %s' % (model_type,
                                                             resolution))
    graphdef, end_points = _model_graph_def(model_type, arg_sc)

    # The same value is used for height and width; None stays None to signal
    # an unknown resolution.
    if resolution is None:
      resolution_hw = None
    else:
      resolution_hw = [resolution, resolution]

    _model_rf(
        graphdef,
        end_points,
        _get_desired_end_point_keys(model_type),
        model_type,
        csv_writer,
        input_resolution=resolution_hw)
381
382
def _resnet_rf(csv_writer=None):
  """Computes RF and associated parameters for all supported resnet variants.

  The computed values are written to stdout. The arg scope from
  `resnet_v1.resnet_arg_scope()` is used for every variant in
  `_SUPPORTED_RESNET_VARIANTS`, including the resnet_v2 ones.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for resnet_variant in _SUPPORTED_RESNET_VARIANTS:
    # A fresh arg scope is requested for each variant.
    _process_model_rf(
        resnet_variant, csv_writer, arg_sc=resnet_v1.resnet_arg_scope())
394
395
def _inception_resnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_resnet_v2 models.

  Every variant in `_SUPPORTED_INCEPTIONRESNETV2_VARIANTS` is processed with
  the default arg scope and default input resolutions. The computed values are
  written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
406
407
def _inception_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v2 models.

  Every variant in `_SUPPORTED_INCEPTIONV2_VARIANTS` is processed with the
  default arg scope and default input resolutions. The computed values are
  written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_INCEPTIONV2_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
418
419
def _inception_v3_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v3 model.

  Every variant in `_SUPPORTED_INCEPTIONV3_VARIANTS` is processed with the
  default arg scope and default input resolutions. The computed values are
  written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_INCEPTIONV3_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
430
431
def _inception_v4_rf(csv_writer=None):
  """Computes RF and associated parameters for the inception_v4 model.

  Every variant in `_SUPPORTED_INCEPTIONV4_VARIANTS` is processed with the
  default arg scope and default input resolutions. The computed values are
  written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_INCEPTIONV4_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
442
443
def _alexnet_v2_rf(csv_writer=None):
  """Computes RF and associated parameters for the alexnet_v2 model.

  Every variant in `_SUPPORTED_ALEXNETV2_VARIANTS` is processed with the
  default arg scope and default input resolutions. The computed values are
  written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_ALEXNETV2_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
454
455
def _vgg_rf(csv_writer=None):
  """Computes RF and associated parameters for the vgg models.

  Every variant in `_SUPPORTED_VGG_VARIANTS` is processed with the default arg
  scope and default input resolutions. The computed values are written to
  stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_VGG_VARIANTS:
    _process_model_rf(model_type=variant, csv_writer=csv_writer)
466
467
def _mobilenet_v1_rf(csv_writer=None):
  """Computes RF and associated parameters for the mobilenet_v1 models.

  Each variant in `_SUPPORTED_MOBILENETV1_VARIANTS` is processed under an arg
  scope that sets `is_training=False` for `slim.batch_norm` and `slim.dropout`.
  The computed values are written to stdout.

  Args:
    csv_writer: A CSV writer for RF parameters, which is used if it is not None.
  """
  for variant in _SUPPORTED_MOBILENETV1_VARIANTS:
    inference_scope = slim.arg_scope(
        [slim.batch_norm, slim.dropout], is_training=False)
    # The scope is both entered here and passed along, so that it can be
    # re-applied when the model graph is constructed.
    with inference_scope as arg_sc:
      _process_model_rf(variant, csv_writer, arg_sc)
480
481
def main(unused_argv):
  """Computes RF parameters for all supported models, optionally writing a CSV.

  Results are printed to stdout. If `cmd_args.csv_path` is non-empty, that file
  is opened for writing and one row is written per (model, input resolution,
  end point). The file is closed even if one of the computations raises.

  Args:
    unused_argv: Unused; command-line flags are read from the module-level
      `cmd_args` instead.
  """
  # Configure CSV file which will be written, if desired.
  csv_file = None
  rf_writer = None
  if cmd_args.csv_path:
    csv_file = open(cmd_args.csv_path, 'w')

  try:
    if csv_file is not None:
      field_names = [
          'CNN', 'input resolution', 'end_point', 'RF size hor', 'RF size ver',
          'effective stride hor', 'effective stride ver',
          'effective padding hor', 'effective padding ver'
      ]
      rf_writer = csv.DictWriter(csv_file, fieldnames=field_names)
      rf_writer.writeheader()

    # Compute RF parameters for each network architecture.
    _alexnet_v2_rf(rf_writer)
    _vgg_rf(rf_writer)
    _inception_v2_rf(rf_writer)
    _inception_v3_rf(rf_writer)
    _inception_v4_rf(rf_writer)
    _inception_resnet_v2_rf(rf_writer)
    _mobilenet_v1_rf(rf_writer)
    _resnet_rf(rf_writer)
  finally:
    # Close CSV file, if it was opened, so that rows written before a failure
    # are flushed and the handle is not leaked.
    if csv_file is not None:
      csv_file.close()
509
510
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument(
      '--csv_path',
      type=str,
      default='',
      help="""\
      Path to CSV file that will be written with RF parameters. If empty, no
      file will be written.\
      """)
  # Store the parsed flags in the module-level `cmd_args`, which `main` reads.
  # Arguments not recognized by this parser are forwarded to app.run.
  cmd_args, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
524