# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API to simulate quantization on a python graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops


def _create_graph(input_graph=None,
                  is_training=True,
                  weight_bits=8,
                  activation_bits=8,
                  symmetric=False,
                  quant_delay=None,
                  freeze_bn_delay=None,
                  scope=None):
  """Rewrites an input_graph in place for simulated quantization.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place,
  the expected behavior of previously held references to nodes and tensors may
  change.

  Args:
    input_graph: The tf.Graph to be transformed, if None then defaults to the
      default graph.
    is_training: Whether quantizing training or eval graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.
    freeze_bn_delay: Number of steps after which moving mean and variance are
      frozen and used instead of batch statistics during training.
      freeze_bn_delay should be greater than quant_delay and should correspond
      to the number of steps at which training has almost converged.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.

  Raises:
    ValueError: If the graph already contains training ops (see
      _check_for_training_ops).
  """

  if input_graph is None:
    input_graph = ops.get_default_graph()

  # Check whether the graph already contains training ops; if so, raise an
  # error, since the rewrite must be applied before training ops are added.
  _check_for_training_ops(input_graph)
  with input_graph.as_default():
    fold_batch_norms.FoldBatchNorms(
        input_graph,
        freeze_batch_norm_delay=freeze_bn_delay,
        is_training=is_training)
    quantize.Quantize(
        input_graph,
        is_training,
        quant_delay=quant_delay,
        weight_bits=weight_bits,
        activation_bits=activation_bits,
        symmetric=symmetric,
        scope=scope)


def create_training_graph(input_graph=None, quant_delay=0):
  """Rewrites a training input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function must be invoked prior to insertion of gradient ops in a graph
  as quantization should be modeled in both forward and backward passes.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place,
  the expected behavior of previously held references to nodes and tensors may
  change.

  The default value of quant_delay is suitable for fine-tuning an already
  trained floating point model (recommended).
  If one wants to train a quantized model from scratch, quant_delay should be
  set to the number of steps it takes the floating point model to converge.
  Quantization will be activated at this point and will effectively fine-tune
  the model. If quant_delay is not provided when training from scratch,
  training can often fail.
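
  A minimal usage sketch, assuming a hypothetical `build_model` function and
  placeholder `inputs`/`labels` tensors; the rewrite is applied before any
  gradient or optimizer ops are added:

    g = tf.Graph()
    with g.as_default():
      logits = build_model(inputs)  # hypothetical forward-pass builder
      # Rewrite the forward pass first, then add the training ops.
      tf.contrib.quantize.create_training_graph(input_graph=g, quant_delay=0)
      loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
      train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)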

  Args:
    input_graph: The tf.Graph to be transformed.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.

  Raises:
    ValueError: If the graph already contains training ops (see
      _check_for_training_ops).
  """
  # TODO(raghuramank) Need to have freeze_bn_delay be a function of batch size
  # Currently the values below are hardcoded for mobilenetV1 on imagenet
  # Please use the experimental API if you need to tune these values.
  freeze_bn_delay = None
  _create_graph(
      input_graph=input_graph,
      is_training=True,
      quant_delay=quant_delay,
      freeze_bn_delay=freeze_bn_delay)


def create_eval_graph(input_graph=None):
  """Rewrites an eval input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place,
  the expected behavior of previously held references to nodes and tensors may
  change.
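
  A minimal usage sketch, assuming a hypothetical `build_model` function; the
  output file name is only illustrative. The rewritten graph_def is typically
  written out and later frozen/converted for inference:

    g = tf.Graph()
    with g.as_default():
      logits = build_model(inputs)  # hypothetical forward-pass builder
      tf.contrib.quantize.create_eval_graph(input_graph=g)
    with open('eval_graph.pb', 'w') as f:
      f.write(str(g.as_graph_def()))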

  Args:
    input_graph: The tf.Graph to be transformed, if None then defaults to the
      default graph.

  Raises:
    ValueError: If the graph already contains training ops (see
      _check_for_training_ops).
  """
  _create_graph(input_graph=input_graph, is_training=False)


def experimental_create_training_graph(input_graph=None,
                                       weight_bits=8,
                                       activation_bits=8,
                                       symmetric=False,
                                       quant_delay=0,
                                       freeze_bn_delay=None,
                                       scope=None):
  """Rewrites a training input_graph in place for simulated quantization.

  This function must be invoked prior to insertion of gradient ops in a graph
  as quantization should be modeled in both forward and backward passes.

  Variables added by the rewrite get added to the global variables collection.

  This function has additional experimental options not (yet) available to
  create_training_graph. The resulting behavior may be undefined.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place,
  the expected behavior of previously held references to nodes and tensors may
  change.

  The default value of quant_delay is suitable for fine-tuning an already
  trained floating point model (recommended).
  If one wants to train a quantized model from scratch, quant_delay should be
  set to the number of steps it takes the floating point model to converge.
  Quantization will be activated at this point and will effectively fine-tune
  the model. If quant_delay is not provided when training from scratch,
  training can often fail.
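
  A minimal usage sketch showing the extra experimental options; the
  `build_model` function and the delay values below are illustrative
  placeholders, not recommendations:

    g = tf.Graph()
    with g.as_default():
      logits = build_model(inputs)  # hypothetical forward-pass builder
      tf.contrib.quantize.experimental_create_training_graph(
          input_graph=g,
          symmetric=True,
          quant_delay=2000000,      # roughly when the float model converges
          freeze_bn_delay=2500000)  # somewhat later than quant_delay
      # Gradient and optimizer ops are added after the rewrite, as above.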

  Args:
    input_graph: The tf.Graph to be transformed, if None then defaults to the
      default graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during training.
    freeze_bn_delay: Number of steps after which moving mean and variance are
      frozen and used instead of batch statistics during training.
      freeze_bn_delay should be greater than quant_delay and should correspond
      to the point at which training has almost converged.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.

  Raises:
    ValueError: If the graph already contains training ops (see
      _check_for_training_ops).
  """

  _create_graph(
      input_graph=input_graph,
      is_training=True,
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      symmetric=symmetric,
      quant_delay=quant_delay,
      freeze_bn_delay=freeze_bn_delay,
      scope=scope)


def experimental_create_eval_graph(input_graph=None,
                                   weight_bits=8,
                                   activation_bits=8,
                                   symmetric=False,
                                   quant_delay=None,
                                   scope=None):
  """Rewrites an eval input_graph in place for simulated quantization.

  Variables added by the rewrite get added to the global variables collection.

  This function has additional experimental options not (yet) available to
  create_eval_graph. The resulting behavior may be undefined.

  The graph has fake quantization ops inserted to simulate the error
  introduced by quantization. Since the graph is transformed in place,
  the expected behavior of previously held references to nodes and tensors may
  change.
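
  A minimal usage sketch; the `build_model` function and the scope name are
  illustrative placeholders:

    g = tf.Graph()
    with g.as_default():
      logits = build_model(inputs)  # hypothetical forward-pass builder
      tf.contrib.quantize.experimental_create_eval_graph(
          input_graph=g, symmetric=True, scope='MyModel')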

  Args:
    input_graph: The tf.Graph to be transformed, if None then defaults to the
      default graph.
    weight_bits: Number of bits to use for quantizing weights.
    activation_bits: Number of bits to use for quantizing activations.
    symmetric: If true, use symmetric quantization limits instead of training
      the minimum and maximum of each quantization range separately.
    quant_delay: Number of steps after which weights and activations are
      quantized during eval.
    scope: The scope to be transformed. If it's not None, only the ops which
      are in this scope will be transformed.

  Raises:
    ValueError: If the graph already contains training ops (see
      _check_for_training_ops).
  """
  _create_graph(
      input_graph=input_graph,
      is_training=False,
      weight_bits=weight_bits,
      activation_bits=activation_bits,
      symmetric=symmetric,
      quant_delay=quant_delay,
      scope=scope)


def _check_for_training_ops(g):
  """Check if training ops are present in the graph.

  Args:
    g: The tf.Graph on which the check for training ops needs to be
      performed.

  Raises:
    ValueError: If a training op is seen in the graph.
  """

  # The list here is obtained
  # from https://www.tensorflow.org/api_docs/cc/group/training-ops
  training_ops = frozenset([
      'ApplyAdagrad', 'ApplyAdagradDA', 'ApplyAdam', 'ApplyAddSign',
      'ApplyCenteredRMSProp', 'ApplyFtrl', 'ApplyFtrlV2',
      'ApplyGradientDescent', 'ApplyMomentum', 'ApplyPowerSign',
      'ApplyProximalAdagrad', 'ApplyProximalGradientDescent', 'ApplyRMSProp',
      'ResourceApplyAdadelta', 'ResourceApplyAdagrad', 'ResourceApplyAdagradDA',
      'ResourceApplyAdam', 'ResourceApplyAddSign',
      'ResourceApplyCenteredRMSProp', 'ResourceApplyFtrl',
      'ResourceApplyFtrlV2', 'ResourceApplyGradientDescent',
      'ResourceApplyMomentum', 'ResourceApplyPowerSign',
      'ResourceApplyProximalAdagrad', 'ResourceApplyProximalGradientDescent',
      'ResourceApplyRMSProp', 'ResourceSparseApplyAdadelta',
      'ResourceSparseApplyAdagrad', 'ResourceSparseApplyAdagradDA',
      'ResourceSparseApplyCenteredRMSProp', 'ResourceSparseApplyFtrl',
      'ResourceSparseApplyFtrlV2', 'ResourceSparseApplyMomentum',
      'ResourceSparseApplyProximalAdagrad',
      'ResourceSparseApplyProximalGradientDescent',
      'ResourceSparseApplyRMSProp', 'SparseApplyAdadelta', 'SparseApplyAdagrad',
      'SparseApplyAdagradDA', 'SparseApplyCenteredRMSProp', 'SparseApplyFtrl',
      'SparseApplyFtrlV2', 'SparseApplyMomentum', 'SparseApplyProximalAdagrad',
      'SparseApplyProximalGradientDescent', 'SparseApplyRMSProp'
  ])

  op_types = set([op.type for op in g.get_operations()])
  train_op_list = op_types.intersection(training_ops)
  if train_op_list:
    raise ValueError('Training op found in graph, exiting %s' % train_op_list)