1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Common utilities used across this package.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import re 23 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import init_ops 28from tensorflow.python.ops import state_ops 29from tensorflow.python.ops import variable_scope 30 31# Skip all operations that are backprop related or export summaries. 32SKIPPED_PREFIXES = ( 33 'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary', 34 'ScalarSummary') 35 36# Valid activation ops for quantization end points. 37_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity'] 38 39# Regular expression for recognizing nodes that are part of batch norm group. 40_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm') 41 42 43def BatchNormGroups(graph): 44 """Finds batch norm layers, returns their prefixes as a list of strings. 45 46 Args: 47 graph: Graph to inspect. 48 49 Returns: 50 List of strings, prefixes of batch norm group names found. 51 """ 52 bns = [] 53 for op in graph.get_operations(): 54 match = _BATCHNORM_RE.search(op.name) 55 if match: 56 bn = match.group(1) 57 if not bn.startswith(SKIPPED_PREFIXES): 58 bns.append(bn) 59 # Filter out duplicates. 60 return list(collections.OrderedDict.fromkeys(bns)) 61 62 63def GetEndpointActivationOp(graph, prefix): 64 """Returns an Operation with the given prefix and a valid end point suffix. 65 66 Args: 67 graph: Graph where to look for the operation. 68 prefix: String, prefix of Operation to return. 69 70 Returns: 71 The Operation with the given prefix and a valid end point suffix or None if 72 there are no matching operations in the graph for any valid suffix 73 """ 74 for suffix in _ACTIVATION_OP_SUFFIXES: 75 activation = _GetOperationByNameDontThrow(graph, prefix + suffix) 76 if activation: 77 return activation 78 return None 79 80 81def _GetOperationByNameDontThrow(graph, name): 82 """Returns an Operation with the given name. 83 84 Args: 85 graph: Graph where to look for the operation. 86 name: String, name of Operation to return. 87 88 Returns: 89 The Operation with the given name. None if the name does not correspond to 90 any operation in the graph 91 """ 92 try: 93 return graph.get_operation_by_name(name) 94 except KeyError: 95 return None 96 97 98def CreateOrGetQuantizationStep(): 99 """Returns a Tensor of the number of steps the quantized graph has run. 100 101 Returns: 102 Quantization step Tensor. 103 """ 104 quantization_step_name = 'fake_quantization_step' 105 quantization_step_tensor_name = quantization_step_name + '/Identity:0' 106 g = ops.get_default_graph() 107 try: 108 return g.get_tensor_by_name(quantization_step_tensor_name) 109 except KeyError: 110 # Create in proper graph and base name_scope. 111 with g.name_scope(None): 112 quantization_step_tensor = variable_scope.get_variable( 113 quantization_step_name, 114 shape=[], 115 dtype=dtypes.int64, 116 initializer=init_ops.zeros_initializer(), 117 trainable=False, 118 collections=[ops.GraphKeys.GLOBAL_VARIABLES]) 119 with g.name_scope(quantization_step_tensor.op.name + '/'): 120 # We return the incremented variable tensor. Since this is used in conds 121 # for quant_delay and freeze_bn_delay, it will run once per graph 122 # execution. We return an identity to force resource variables and 123 # normal variables to return a tensor of the same name. 124 return array_ops.identity( 125 state_ops.assign_add(quantization_step_tensor, 1)) 126 127 128def DropStringPrefix(s, prefix): 129 """If the string starts with this prefix, drops it.""" 130 if s.startswith(prefix): 131 return s[len(prefix):] 132 else: 133 return s 134 135 136def RerouteTensor(t0, t1, can_modify=None): 137 """Reroute the end of the tensor t0 to the ends of the tensor t1. 138 139 Args: 140 t0: a tf.Tensor. 141 t1: a tf.Tensor. 142 can_modify: iterable of operations which can be modified. Any operation 143 outside within_ops will be left untouched by this function. 144 145 Returns: 146 The number of individual modifications made by the function. 147 """ 148 nb_update_inputs = 0 149 consumers = t1.consumers() 150 if can_modify is not None: 151 consumers = [c for c in consumers if c in can_modify] 152 consumers_indices = {} 153 for c in consumers: 154 consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1] 155 for c in consumers: 156 for i in consumers_indices[c]: 157 c._update_input(i, t0) # pylint: disable=protected-access 158 nb_update_inputs += 1 159 return nb_update_inputs 160