# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities used across this package."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope

# Skip all operations that are backprop-related or export summaries.
SKIPPED_PREFIXES = (
    'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary',
    'ScalarSummary')

# Valid activation ops for quantization end points.
_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']

# Regular expression for recognizing nodes that are part of a batch norm group.
_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')


def BatchNormGroups(graph):
  """Finds batch norm layers, returns their prefixes as a list of strings.

  Args:
    graph: Graph to inspect.

  Returns:
    List of strings, prefixes of batch norm group names found.
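
  Example (an illustrative sketch; the layer name is hypothetical):
    # For a graph containing an op named
    # 'InceptionV3/Conv2d_1a_3x3/BatchNorm/batchnorm/mul', this returns
    # ['InceptionV3/Conv2d_1a_3x3/'].
    prefixes = BatchNormGroups(graph)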
  """
  bns = []
  for op in graph.get_operations():
    match = _BATCHNORM_RE.search(op.name)
    if match:
      bn = match.group(1)
      if not bn.startswith(SKIPPED_PREFIXES):
        bns.append(bn)
  # Filter out duplicates while preserving discovery order.
  return list(collections.OrderedDict.fromkeys(bns))


def GetEndpointActivationOp(graph, prefix):
  """Returns an Operation with the given prefix and a valid end point suffix.

  Args:
    graph: Graph in which to look for the operation.
    prefix: String, prefix of the Operation to return.

  Returns:
    The Operation with the given prefix and a valid end point suffix, or None
    if there are no matching operations in the graph for any valid suffix.
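
  Example (an illustrative sketch; the op name is hypothetical):
    # Tries 'InceptionV3/Conv2d_1a_3x3/Relu6', then '.../Relu', then
    # '.../Identity', returning the first op found.
    act = GetEndpointActivationOp(graph, 'InceptionV3/Conv2d_1a_3x3/')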
  """
  for suffix in _ACTIVATION_OP_SUFFIXES:
    activation = _GetOperationByNameDontThrow(graph, prefix + suffix)
    if activation:
      return activation
  return None


def _GetOperationByNameDontThrow(graph, name):
  """Returns an Operation with the given name.

  Args:
    graph: Graph in which to look for the operation.
    name: String, name of the Operation to return.

  Returns:
    The Operation with the given name, or None if the name does not correspond
    to any operation in the graph.
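
  Example (illustrative; the op name is hypothetical):
    op = _GetOperationByNameDontThrow(graph, 'InceptionV3/Conv2d_1a_3x3/Relu6')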
  """
  try:
    return graph.get_operation_by_name(name)
  except KeyError:
    return None


def CreateOrGetQuantizationStep():
  """Returns a Tensor of the number of steps the quantized graph has run.

  Returns:
    Quantization step Tensor.
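
  Example (an illustrative sketch of typical use during graph construction):
    step = CreateOrGetQuantizationStep()
    # A second call returns the same tensor instead of creating a new
    # variable, so the step counter is shared across callers.
    same_step = CreateOrGetQuantizationStep()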
  """
  quantization_step_name = 'fake_quantization_step'
  quantization_step_tensor_name = quantization_step_name + '/Identity:0'
  g = ops.get_default_graph()
  try:
    return g.get_tensor_by_name(quantization_step_tensor_name)
  except KeyError:
    # Create in the proper graph and at the base name_scope.
    with g.name_scope(None):
      quantization_step_tensor = variable_scope.get_variable(
          quantization_step_name,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])
      with g.name_scope(quantization_step_tensor.op.name + '/'):
        # We return the incremented variable tensor. Since this is used in
        # conds for quant_delay and freeze_bn_delay, it will run once per graph
        # execution. We return an identity to force resource variables and
        # normal variables to return a tensor of the same name.
        return array_ops.identity(
            state_ops.assign_add(quantization_step_tensor, 1))


def DropStringPrefix(s, prefix):
  """Returns the string with the given prefix dropped, if present."""
  if s.startswith(prefix):
    return s[len(prefix):]
  else:
    return s


def RerouteTensor(t0, t1, can_modify=None):
  """Reroutes the consumers of the tensor t1 to the tensor t0.

  Every operation that consumed t1 (optionally restricted to can_modify) is
  rewritten to consume t0 instead.

  Args:
    t0: a tf.Tensor, the replacement tensor.
    t1: a tf.Tensor, the tensor whose consumers are rerouted.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.

  Returns:
    The number of individual modifications made by the function.
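
  Example (an illustrative sketch; the names are hypothetical):
    # Make `consumer_ops` read `quant_t` wherever they previously read
    # `conv_t`; ops outside `can_modify` keep their original input.
    n = RerouteTensor(quant_t, conv_t, can_modify=consumer_ops)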
  """
  nb_update_inputs = 0
  consumers = t1.consumers()
  if can_modify is not None:
    consumers = [c for c in consumers if c in can_modify]
  # First record, for each consumer, the input indices at which it reads t1,
  # then rewrite those inputs to read t0.
  consumers_indices = {}
  for c in consumers:
    consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
  for c in consumers:
    for i in consumers_indices[c]:
      c._update_input(i, t0)  # pylint: disable=protected-access
      nb_update_inputs += 1
  return nb_update_inputs