• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library to compute order of computations in a graph.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import math
24from tensorflow.contrib.receptive_field.python.util import parse_layer_parameters
25from tensorflow.python.platform import tf_logging as logging
26
27
def parse_graph_nodes(graph_def):
  """Builds a lookup table from node name to NodeDef for a GraphDef.

  Args:
    graph_def: A GraphDef object.

  Returns:
    name_to_node: Dict keyed by node name, each entry containing the node's
      NodeDef. If several nodes share a name, the one appearing last in
      `graph_def.node` is kept.
  """
  return {node.name: node for node in graph_def.node}
44
45
46# Named tuple used to collect information from each node in a computation graph.
47_node_info = collections.namedtuple(
48    'NodeInfo', field_names=['order', 'node', 'input_size', 'output_size'])
49
50
51def _compute_output_resolution(input_spatial_resolution, kernel_size, stride,
52                               total_padding):
53  """Computes output resolution, given input resolution and layer parameters.
54
55  Note that this computation is done only over one dimension (eg, x or y).
56  If any of the inputs is None, returns None.
57
58  Args:
59    input_spatial_resolution: Input spatial resolution (int).
60    kernel_size: Kernel size (int).
61    stride: Stride (int).
62    total_padding: Total padding to be applied (int).
63  Returns:
64    output_resolution: Output dimension (int) or None.
65  """
66  if (input_spatial_resolution is None) or (kernel_size is None) or (
67      stride is None) or (total_padding is None):
68    return None
69  return int(
70      math.ceil((
71          input_spatial_resolution + total_padding - kernel_size + 1) / stride))
72
73
def _get_computed_nodes(name_to_node,
                        current,
                        node_info,
                        input_node_name='',
                        input_node_size=None):
  """Traverses the graph recursively to compute its topological order.

  Optionally, the function may also compute the input and output feature map
  resolutions at each node. In this case, input_node_name and input_node_size
  must be set. Note that if a node's op type is unknown, the input and output
  resolutions are ignored and set to None.

  Results are memoized in `node_info`. Every node visited, including all of
  its transitive non-control-dependency inputs, gets an entry, so each node is
  processed at most once across calls sharing the same `node_info`. The
  recursion depth grows with the length of the longest input chain leading to
  `current`, so very deep graphs may exceed Python's recursion limit.

  Args:
    name_to_node: Dict keyed by node name, each entry containing the node's
      NodeDef.
    current: Current node name.
    node_info: Map of nodes we've already traversed, containing their
      _node_info information. Updated in place.
    input_node_name: Name of node with fixed input resolution (optional).
    input_node_size: Fixed input resolution to use (optional).
  Returns:
    order: Order in topological sort for 'current'. It is 0 for the input node
      and for nodes without data inputs; otherwise it is one more than the
      largest order among the node's data inputs.
    input_size: Tensor spatial resolution at input of current node, taken
      from the output size of its first data input (None for the input node).
    output_size: Tensor spatial resolution at output of current node.
  """
  if current in node_info:
    # Already visited: reuse the memoized result.
    return (node_info[current].order, node_info[current].input_size,
            node_info[current].output_size)

  node_def = name_to_node[current]

  if current == input_node_name:
    # The designated input node is a traversal root. Its own inputs are not
    # visited from here, and its output resolution is fixed by the caller.
    order = 0
    input_size = None
    output_size = input_node_size
    node_info[current] = _node_info(order, node_def, input_size, output_size)
    return (order, input_size, output_size)

  input_size = None
  output_size = None

  order = 0
  number_inputs = 0
  for each in node_def.input:
    # Parses name of input node.
    if each.startswith('^'):
      # The character '^' denotes a control dependency, so this input node can
      # be safely ignored.
      continue
    # Drop any ':<suffix>' so that only the producing node's name remains.
    each = each.split(':')[0]
    # Recursively computes ordering.
    (parent_order, _, parent_output_size) = _get_computed_nodes(
        name_to_node, each, node_info, input_node_name, input_node_size)
    order = max(order, parent_order + 1)
    if number_inputs == 0:
      # For all the types of nodes we consider, the first input corresponds to
      # the feature map.
      input_size = parent_output_size
    number_inputs += 1

  # Figure out output size for this layer.
  logging.vlog(3, 'input_size = %s', input_size)
  if input_size is None:
    output_size = None
  else:
    (kernel_size_x, kernel_size_y, stride_x, stride_y, _, _, total_padding_x,
     total_padding_y) = (
         parse_layer_parameters.get_layer_params(
             node_def, name_to_node, input_size, force=True))
    # Pass the values as logging arguments (not pre-formatted with '%') so the
    # message is only built when verbosity level 3 is enabled, matching the
    # other vlog calls in this function.
    logging.vlog(3, 'kernel_size_x = %s, kernel_size_y = %s, '
                 'stride_x = %s, stride_y = %s, '
                 'total_padding_x = %s, total_padding_y = %s',
                 kernel_size_x, kernel_size_y, stride_x, stride_y,
                 total_padding_x, total_padding_y)
    output_size = [
        _compute_output_resolution(input_size[0], kernel_size_x, stride_x,
                                   total_padding_x),
        _compute_output_resolution(input_size[1], kernel_size_y, stride_y,
                                   total_padding_y),
    ]

  logging.vlog(3, 'output_size = %s', output_size)
  node_info[current] = _node_info(order, node_def, input_size, output_size)

  return order, input_size, output_size
158
159
def get_compute_order(graph_def, input_node_name='', input_node_size=None):
  """Determines the topological computation order of a CNN graph.

  Feature-map resolutions at the input and output of every node can also be
  derived. To enable this, pass both input_node_name and input_node_size.
  Nodes whose op type is unknown get None for both resolutions.

  Args:
    graph_def: GraphDef object.
    input_node_name: Name of node with fixed input resolution (optional). This
      is usually the node name for the input image in a CNN.
    input_node_size: 2D list of integers, fixed input resolution to use
      (optional). This is usually the input resolution used for the input image
      in a CNN (common examples are: [224, 224], [299, 299], [321, 321]).
  Returns:
    node_info: Default dict keyed by node name, mapping to a named tuple with
      the following fields:
      - order: Integer denoting topological order;
      - node: NodeDef for the given node;
      - input_size: 2D list of integers, denoting the input spatial resolution
        to the node;
      - output_size: 2D list of integers, denoting the output spatial resolution
        of the node.
      The default factory is `_node_info`, which requires all four fields, so
      indexing a name that has no entry raises TypeError instead of inserting
      a default.
    name_to_node: Dict keyed by node name, each entry containing the node's
      NodeDef.
  """
  name_to_node = parse_graph_nodes(graph_def)
  node_info = collections.defaultdict(_node_info)
  # Launch a traversal from every node so that each one receives an entry.
  # Nodes already memoized by an earlier traversal return immediately.
  node_names = (node_def.name for node_def in graph_def.node)
  for node_name in node_names:
    _get_computed_nodes(name_to_node, node_name, node_info,
                        input_node_name=input_node_name,
                        input_node_size=input_node_size)
  return node_info, name_to_node
193