# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API to simulate quantization on a python graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops


def _create_graph(input_graph=None,
                  is_training=True,
                  weight_bits=8,
                  activation_bits=8,
                  symmetric=False,
                  quant_delay=None,
                  freeze_bn_delay=None,
                  scope=None):
  """Rewrites an input_graph in place for simulated quantization.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place, the
  expected behavior of previously held references to nodes and tensors may
  change.

  Args:
    input_graph: The tf.Graph to be transformed. If None, defaults to the
      default graph.
    is_training: Whether the graph is being quantized for training or eval.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If True, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.
    freeze_bn_delay: Number of steps after which the moving mean and variance
      are frozen and used instead of batch statistics during training.
      freeze_bn_delay should be greater than quant_delay and should correspond
      to the point at which training has almost converged.
    scope: The scope to be transformed. If it is not None, only ops within
      this scope are transformed.

  Raises:
    ValueError: If the graph already contains training ops; the rewrite must
      be applied before gradient and optimizer ops are added.
  """

  if input_graph is None:
    input_graph = ops.get_default_graph()

  # Check whether the graph already contains training ops; if so, raise an
  # error, since the rewrite must happen before gradients are added.
  _check_for_training_ops(input_graph)
  with input_graph.as_default():
    fold_batch_norms.FoldBatchNorms(
        input_graph,
        freeze_batch_norm_delay=freeze_bn_delay,
        is_training=is_training)
    quantize.Quantize(
        input_graph,
        is_training,
        quant_delay=quant_delay,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        symmetric=symmetric,
        scope=scope)


def create_training_graph(input_graph=None, quant_delay=0):
  """Rewrites a training input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function must be invoked prior to insertion of gradient ops in a graph
  as quantization should be modeled in both forward and backward passes.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place, the
  expected behavior of previously held references to nodes and tensors may
  change.

  The default value of quant_delay is suitable for fine-tuning an already
  trained floating point model (recommended). To train a quantized model from
  scratch, quant_delay should be set to the number of steps it takes the
  floating point model to converge. Quantization is activated at that point
  and effectively fine-tunes the model. If quant_delay is not provided when
  training from scratch, training can often fail.

  Args:
    input_graph: The tf.Graph to be transformed. If None, defaults to the
      default graph.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.

  Raises:
    ValueError: If the graph already contains training ops; this rewrite must
      be applied before gradient and optimizer ops are added.
  """
  # TODO(raghuramank) Need to have freeze_bn_delay be a function of batch size
  # Currently the values below are hardcoded for mobilenetV1 on imagenet
  # Please use the experimental API if you need to tune these values.
  freeze_bn_delay = None
  _create_graph(
      input_graph=input_graph,
      is_training=True,
      quant_delay=quant_delay,
      freeze_bn_delay=freeze_bn_delay)
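

# The following is an illustrative sketch (not part of the public API) of the
# intended call order for `create_training_graph`: build the float model,
# rewrite the graph, and only then create the optimizer so that gradients are
# computed through the fake quantization ops. The tiny conv/batch-norm model
# and hyperparameters are placeholders; `tf` is assumed to be TensorFlow 1.x.
def _example_create_training_graph_usage():
  """Sketch: rewrite the default graph for quantized training."""
  import tensorflow as tf  # Deferred import; avoids a cycle at module load.

  graph = tf.Graph()
  with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    labels = tf.placeholder(tf.float32, [None, 10])
    net = tf.layers.conv2d(images, 32, 3)
    net = tf.layers.batch_normalization(net, training=True)
    net = tf.nn.relu(net)
    logits = tf.layers.dense(tf.layers.flatten(net), 10)
    loss = tf.losses.softmax_cross_entropy(labels, logits)

    # Insert fake quantization ops *before* any training ops are added.
    create_training_graph(input_graph=graph, quant_delay=0)

    # Gradients created after the rewrite model quantization error in the
    # backward pass as well.
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
  return train_op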


def create_eval_graph(input_graph=None):
  """Rewrites an eval input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place, the
  expected behavior of previously held references to nodes and tensors may
  change.

  Args:
    input_graph: The tf.Graph to be transformed. If None, defaults to the
      default graph.

  Raises:
    ValueError: If the graph already contains training ops.
  """
  _create_graph(input_graph=input_graph, is_training=False)
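

# An illustrative sketch (not part of the public API) of how the eval rewrite
# is typically used before exporting a model for conversion to a fully
# quantized format (for example by the TFLite converter). The model-building
# code and file paths are placeholders; `tf` is assumed to be TensorFlow 1.x.
def _example_create_eval_graph_usage():
  """Sketch: rewrite the default graph for quantized eval and export it."""
  import tensorflow as tf  # Deferred import; avoids a cycle at module load.

  graph = tf.Graph()
  with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    net = tf.layers.conv2d(images, 32, 3, activation=tf.nn.relu)
    tf.identity(tf.layers.dense(tf.layers.flatten(net), 10), name='logits')

    # Insert fake quantization ops matching the training rewrite; the learned
    # range variables are restored later from the training checkpoint.
    create_eval_graph(input_graph=graph)

    # The rewritten GraphDef can then be written out and frozen with a
    # training checkpoint before conversion.
    tf.train.write_graph(graph.as_graph_def(), '/tmp', 'eval_graph.pbtxt')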


def experimental_create_training_graph(input_graph=None,
                                       weight_bits=8,
                                       activation_bits=8,
                                       symmetric=False,
                                       quant_delay=0,
                                       freeze_bn_delay=None,
                                       scope=None):
  """Rewrites a training input_graph in place for simulated quantization.

  This function must be invoked prior to insertion of gradient ops in a graph
  as quantization should be modeled in both forward and backward passes.

  Variables added by the rewrite get added to the global variables collection.

  This function has additional experimental options not (yet) available to
  create_training_graph. The resulting behavior may be undefined.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place, the
  expected behavior of previously held references to nodes and tensors may
  change.

  The default value of quant_delay is suitable for fine-tuning an already
  trained floating point model (recommended). To train a quantized model from
  scratch, quant_delay should be set to the number of steps it takes the
  floating point model to converge. Quantization is activated at that point
  and effectively fine-tunes the model. If quant_delay is not provided when
  training from scratch, training can often fail.

  Args:
    input_graph: The tf.Graph to be transformed. If None, defaults to the
      default graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If True, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.
    freeze_bn_delay: Number of steps after which the moving mean and variance
      are frozen and used instead of batch statistics during training.
      freeze_bn_delay should be greater than quant_delay and should correspond
      to the point at which training has almost converged.
    scope: The scope to be transformed. If it is not None, only ops within
      this scope are transformed.

  Raises:
    ValueError: If the graph already contains training ops.
  """

  _create_graph(
      input_graph=input_graph,
      is_training=True,
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      symmetric=symmetric,
      quant_delay=quant_delay,
      freeze_bn_delay=freeze_bn_delay,
      scope=scope)


def experimental_create_eval_graph(input_graph=None,
                                   weight_bits=8,
                                   activation_bits=8,
                                   symmetric=False,
                                   quant_delay=None,
                                   scope=None):
  """Rewrites an eval input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function has additional experimental options not (yet) available to
  create_eval_graph. The resulting behavior may be undefined.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place, the
  expected behavior of previously held references to nodes and tensors may
  change.

  Args:
    input_graph: The tf.Graph to be transformed. If None, defaults to the
      default graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If True, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during eval.
    scope: The scope to be transformed. If it is not None, only ops within
      this scope are transformed.

  Raises:
    ValueError: If the graph already contains training ops.
  """
  _create_graph(
      input_graph=input_graph,
      is_training=False,
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      symmetric=symmetric,
      quant_delay=quant_delay,
      scope=scope)
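

# An illustrative sketch (not part of the public API) of the experimental
# training rewrite with non-default options: symmetric ranges, reduced weight
# precision, an explicit freeze_bn_delay, and a scope restricting the rewrite.
# The model code, step counts, and the 'tower0' scope name are placeholders;
# `tf` is assumed to be TensorFlow 1.x.
def _example_experimental_training_graph_usage():
  """Sketch: experimental rewrite of only the ops under a given scope."""
  import tensorflow as tf  # Deferred import; avoids a cycle at module load.

  graph = tf.Graph()
  with graph.as_default():
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with tf.variable_scope('tower0'):
      net = tf.layers.conv2d(images, 32, 3, activation=tf.nn.relu)
      logits = tf.layers.dense(tf.layers.flatten(net), 10)
    loss = tf.reduce_mean(tf.square(logits))

    # Only ops under 'tower0' are rewritten; quantization starts after
    # 200000 steps and batch norm statistics freeze shortly afterwards.
    experimental_create_training_graph(
        input_graph=graph,
        weight_bits=4,
        activation_bits=8,
        symmetric=True,
        quant_delay=200000,
        freeze_bn_delay=250000,
        scope='tower0')

    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
  return train_op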


def _check_for_training_ops(g):
  """Checks if training ops are present in the graph.

  Args:
    g: The tf.Graph on which the check for training ops needs to be
      performed.

  Raises:
    ValueError: If a training op is seen in the graph.
  """

  # The list here is obtained
  # from https://www.tensorflow.org/api_docs/cc/group/training-ops
  training_ops = frozenset([
      'ApplyAdagrad', 'ApplyAdagradDA', 'ApplyAdam', 'ApplyAddSign',
      'ApplyCenteredRMSProp', 'ApplyFtrl', 'ApplyFtrlV2',
      'ApplyGradientDescent', 'ApplyMomentum', 'ApplyPowerSign',
      'ApplyProximalAdagrad', 'ApplyProximalGradientDescent', 'ApplyRMSProp',
      'ResourceApplyAdadelta', 'ResourceApplyAdagrad', 'ResourceApplyAdagradDA',
      'ResourceApplyAdam', 'ResourceApplyAddSign',
      'ResourceApplyCenteredRMSProp', 'ResourceApplyFtrl',
      'ResourceApplyFtrlV2', 'ResourceApplyGradientDescent',
      'ResourceApplyMomentum', 'ResourceApplyPowerSign',
      'ResourceApplyProximalAdagrad', 'ResourceApplyProximalGradientDescent',
      'ResourceApplyRMSProp', 'ResourceSparseApplyAdadelta',
      'ResourceSparseApplyAdagrad', 'ResourceSparseApplyAdagradDA',
      'ResourceSparseApplyCenteredRMSProp', 'ResourceSparseApplyFtrl',
      'ResourceSparseApplyFtrlV2', 'ResourceSparseApplyMomentum',
      'ResourceSparseApplyProximalAdagrad',
      'ResourceSparseApplyProximalGradientDescent',
      'ResourceSparseApplyRMSProp', 'SparseApplyAdadelta', 'SparseApplyAdagrad',
      'SparseApplyAdagradDA', 'SparseApplyCenteredRMSProp', 'SparseApplyFtrl',
      'SparseApplyFtrlV2', 'SparseApplyMomentum', 'SparseApplyProximalAdagrad',
      'SparseApplyProximalGradientDescent', 'SparseApplyRMSProp'
  ])

  op_types = set([op.type for op in g.get_operations()])
  train_op_list = op_types.intersection(training_ops)
  if train_op_list:
    raise ValueError('Training op found in graph, exiting %s' % train_op_list)
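

# An illustrative sketch (not part of the public API) of the guard enforced by
# `_check_for_training_ops`: once an optimizer has added training ops such as
# ApplyGradientDescent, the rewrite refuses to run. The toy variable and loss
# are placeholders; `tf` is assumed to be TensorFlow 1.x.
def _example_rewrite_after_optimizer_raises():
  """Sketch: calling the rewrite after minimize() raises ValueError."""
  import tensorflow as tf  # Deferred import; avoids a cycle at module load.

  graph = tf.Graph()
  with graph.as_default():
    weight = tf.Variable(1.0)
    loss = tf.square(weight)
    # Adding the optimizer first inserts ApplyGradientDescent into the graph.
    tf.train.GradientDescentOptimizer(0.1).minimize(loss)

  try:
    create_training_graph(input_graph=graph)
  except ValueError:
    # Expected: the graph already contains training ops, so the rewrite must
    # be requested before the optimizer is built.
    pass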