# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the quantize_graph graph rewriting API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize_graph
from tensorflow.python import training
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template
from tensorflow.python.platform import googletest


class QuantizeGraphTest(test_util.TensorFlowTestCase):
  # We have a lot of other tests that test the details of the rewrite; here we
  # just test the specific features of the quantize_graph API.

  def _RunTestOverAllRewrites(self, test_fn):
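    """Runs test_fn with each of the training and eval rewrite functions."""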
    rewrite_fns = [
        quantize_graph.create_training_graph,
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverTrainingRewrites(self, test_fn):
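    """Runs test_fn with each of the training-graph rewrite functions."""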
    rewrite_fns = [
        quantize_graph.create_training_graph,
        quantize_graph.experimental_create_training_graph,
        functools.partial(
            quantize_graph.experimental_create_training_graph, symmetric=True),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverEvalRewrites(self, test_fn):
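    """Runs test_fn with each of the eval-graph rewrite functions."""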
    rewrite_fns = [
        quantize_graph.create_eval_graph,
        quantize_graph.experimental_create_eval_graph,
        functools.partial(
            quantize_graph.experimental_create_eval_graph, symmetric=True),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverExperimentalRewrites(self, test_fn):
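    """Runs test_fn with each of the experimental rewrite functions."""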
    rewrite_fns = [
        quantize_graph.experimental_create_training_graph,
        quantize_graph.experimental_create_eval_graph,
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def _RunTestOverExperimentalRewritesWithScope(self, test_fn, scope):
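    """Runs test_fn with the experimental rewrites, with scope pre-bound."""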
    def with_absent_scope(fn):
      def fn_with_absent_scope(*args):
        fn(*args, scope=scope)
      return fn_with_absent_scope
    rewrite_fns = [
        with_absent_scope(
            quantize_graph.experimental_create_training_graph),
        with_absent_scope(
            quantize_graph.experimental_create_eval_graph),
    ]
    for fn in rewrite_fns:
      test_fn(fn)

  def testRewrite(self):
    self._RunTestOverAllRewrites(self._TestRewrite)

  def _TestRewrite(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    orig_variable_names = set(
        [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])

    rewrite_fn(graph)

    q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    # Ensure that variables were added.
    self.assertTrue(len(orig_variable_names) < len(q_variables))

  def testDefaultGraph(self):
    self._RunTestOverAllRewrites(self._TestDefaultGraph)

  def _TestDefaultGraph(self, rewrite_fn):
    # Tests that the default graph is correctly used when no args are provided
    # to rewrite_fn.
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      orig_variable_names = set(
          [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
      rewrite_fn()

      q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      # Ensure that variables were added.
      self.assertTrue(len(orig_variable_names) < len(q_variables))

  def testWithPostActivationBypass(self):
    self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)

  def _TestWithPostActivationBypass(self, rewrite_fn):
    # Tests that a layer with a post-activation bypass gets a quantization node
    # added for the bypass when the default graph is rewritten.
    with ops.Graph().as_default() as g:
      self._ConvLayer(post_activation_bypass=True, scope='scope1')
      rewrite_fn()

      op_names = [op.name for op in g.get_operations()]
      self.assertTrue(any(
          'scope1/post_activation_bypass_quant/' in name for name in op_names))

  def testQuantDelay(self):
    self._RunTestOverTrainingRewrites(self._TestQuantDelay)

  def _TestQuantDelay(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      quant_delay = 100
      rewrite_fn(quant_delay=quant_delay)

    quant_delay_found = False
    for op in g.get_operations():
      # Check to see if the quant_delay is correctly set.
      if 'activate_quant' in op.name and op.type == 'Const':
        quant_delay_found = True
        const_value = str(op.get_attr('value'))
        self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
    self.assertTrue(quant_delay_found)

  def testTrainingOpsCheck(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingOpsCheck)

  def _TestTrainingOpsCheck(self, rewrite_fn):
    with ops.Graph().as_default():
      output = self._ConvLayer()
      output_scalar = math_ops.reduce_sum(output)
      loss = math_ops.square(output_scalar - 1)
      opt = training.gradient_descent.GradientDescentOptimizer(0.0001)
      opt.minimize(loss)
      with self.assertRaisesRegexp(ValueError, 'Training op found in graph'):
        rewrite_fn()

  def testWeightBits(self):
    self._RunTestOverExperimentalRewrites(self._TestWeightBits)

  def _TestWeightBits(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      weight_bits = 4
      rewrite_fn(weight_bits=weight_bits)

    weights_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for weights have the right bits
      # set.
      if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
        weights_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), weight_bits)
    self.assertTrue(weights_quant_found)

  def testActivationBits(self):
    self._RunTestOverExperimentalRewrites(self._TestActivationBits)

  def _TestActivationBits(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      activation_bits = 4
      rewrite_fn(activation_bits=activation_bits)

    act_quant_found = False
    for op in g.get_operations():
      # Check to see if FakeQuant operations for activations have the right bits
      # set.
      act_quant_names = ['act_quant', 'conv_quant', 'add_quant']
      if any(s in op.name
             for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars':
        act_quant_found = True
        self.assertEqual(op.get_attr('num_bits'), activation_bits)
    self.assertTrue(act_quant_found)

  def testTrainingQuantization(self):
    self._RunTestOverTrainingRewrites(self._TestTrainingQuantization)

  def _TestTrainingQuantization(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Ensure that FakeQuant and variable update nodes were found.
    quant_found = False
    assign_min_last_found = False
    assign_min_ema_found = False
    assign_max_last_found = False
    assign_max_ema_found = False
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Check that update operations for the added min max variables exist in
      # the graph.
      if 'AssignMinLast' in op.name:
        assign_min_last_found = True
      elif 'AssignMinEma' in op.name:
        assign_min_ema_found = True
      elif 'AssignMaxLast' in op.name:
        assign_max_last_found = True
      elif 'AssignMaxEma' in op.name:
        assign_max_ema_found = True
    self.assertTrue(assign_min_last_found)
    self.assertTrue(assign_min_ema_found)
    self.assertTrue(assign_max_last_found)
    self.assertTrue(assign_max_ema_found)
    self.assertTrue(quant_found)

  def testEvalQuantization(self):
    self._RunTestOverEvalRewrites(self._TestEvalQuantization)

  def _TestEvalQuantization(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()

    # Ensure that FakeQuant nodes were added, but no variable update nodes.
    quant_found = False
    for op in g.get_operations():
      # Check that FakeQuant operations were added.
      if op.type == 'FakeQuantWithMinMaxVars':
        quant_found = True
      # Check that update operations for the added min max variables don't
      # exist in the graph.
      update_names = [
          'AssignMinLast', 'AssignMinEma', 'AssignMaxLast', 'AssignMaxEma'
      ]
      self.assertFalse(any(s in op.name for s in update_names))
    self.assertTrue(quant_found)

  def testIdempotent(self):
    self._RunTestOverAllRewrites(self._TestIdempotent)

  def _TestIdempotent(self, rewrite_fn):
    with ops.Graph().as_default() as g:
      self._ConvLayer()
      rewrite_fn()
      graph_def_before = str(g.as_graph_def())
      # Ensuring that calling the rewrite again doesn't add more nodes.
      rewrite_fn()
      graph_def_after = str(g.as_graph_def())
      self.assertEqual(graph_def_before, graph_def_after)

  def testIdentityNode(self):
    self._RunTestOverAllRewrites(self._TestIdentityNode)

  def _TestIdentityNode(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._LayerWithIdentity()

    rewrite_fn(graph)
    op_names = [op.name for op in graph.get_operations()]
    self.assertTrue(any('test/Conv/weights_quant' in name for name in op_names))
    self.assertTrue(any('test/Conv/act_quant' in name for name in op_names))
    bn_out_identity = graph.get_operation_by_name('test/bn_out')
    self._AssertInputOpsAre(bn_out_identity, [
        'test/Conv/add_fold',
    ])

    conv_out_identity = graph.get_operation_by_name('test/conv_out')
    self._AssertOutputGoesToOps(conv_out_identity, graph,
                                ['test/BatchNorm/FusedBatchNorm'])

  def testActivationQuantization(self):
    self._RunTestOverAllRewrites(self._TestActivationQuantization)

  def _TestActivationQuantization(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      _ = self._LayerWithActivationProcessing()

    rewrite_fn(graph)
    # Check if outputs of multipliers and adds are quantized.

    mul_op = graph.get_operation_by_name('test/Mul')
    self._AssertOutputGoesToOps(
        mul_op, graph,
        ['test/Mul/activation_Mul_quant/FakeQuantWithMinMaxVars'])
    mul_op = graph.get_operation_by_name('test/Mul_1')
    self._AssertOutputGoesToOps(
        mul_op, graph,
        ['test/Mul_1/activation_Mul_quant/FakeQuantWithMinMaxVars'])
    add_op = graph.get_operation_by_name('test/add')
    self._AssertOutputGoesToOps(
        add_op, graph,
        ['test/add/activation_Add_quant/FakeQuantWithMinMaxVars'])

  def testRewriteWithScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestRewriteWithScope, 'scope1')

  def _TestRewriteWithScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      scope1_output = self._ConvLayer(scope='scope1')
      self._ConvLayer(input_tensor=scope1_output, scope='scope2')

    rewrite_fn(graph)

    op_names = [op.name for op in graph.get_operations()]
    # The weights and activations of scope1 are quantized, but not scope2's.
    self.assertTrue(
        any('scope1/Conv/act_quant' in name for name in op_names))
    self.assertTrue(
        any('scope1/Conv/weights_quant' in name for name in op_names))
    self.assertFalse(
        any('scope2/Conv/act_quant' in name for name in op_names))
    self.assertFalse(
        any('scope2/Conv/weights_quant' in name for name in op_names))

  def testRewriteWithNonMatchingScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestRewriteWithNonMatchingScope, 'NonExistingScope')

  def _TestRewriteWithNonMatchingScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._ConvLayer()

    op_names_before_rewrite = set([op.name for op in graph.get_operations()])
    rewrite_fn(graph)
    op_names_after_rewrite = set([op.name for op in graph.get_operations()])

    # No ops should be inserted or removed.
    self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)

  def testActivationRewriteWithScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestActivationRewriteWithScope, 'scope1')

  def _TestActivationRewriteWithScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      output = self._LayerWithIdentity(scope='scope1')
      with ops.name_scope('scope2'):
        output = nn_ops.relu6(output)
        scaled_output1 = math_ops.mul(2.0, output)
        scaled_output2 = math_ops.mul(3.0, output)
        output = scaled_output1 + scaled_output2
      rewrite_fn(graph)

      op_names = [op.name for op in graph.get_operations()]
      # The weights and activations of scope1 are quantized, but not scope2's.
      self.assertTrue(any('scope1/Conv/act_quant' in name for name in op_names))
      self.assertTrue(
          any('scope1/Conv/weights_quant' in name for name in op_names))

      for op_name in op_names:
        if op_name.startswith('scope2'):
          self.assertTrue('FakeQuant' not in op_name)

  def testActivationRewriteWithNonMatchingScope(self):
    self._RunTestOverExperimentalRewritesWithScope(
        self._TestActivationRewriteWithNonMatchingScope, 'NonExistingScope')

  def _TestActivationRewriteWithNonMatchingScope(self, rewrite_fn):
    graph = ops.Graph()
    with graph.as_default():
      self._LayerWithActivationProcessing()

    rewrite_fn(graph)
    op_types_after_rewrite = set([op.type for op in graph.get_operations()])
    # No fake quant ops should be inserted, since the scope does not match.
    self.assertNotIn('FakeQuantWithMinMaxVars', op_types_after_rewrite)

  def testWithSharedWeights(self):
    self._RunTestOverAllRewrites(self._TestWithSharedWeights)
    self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)

  def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
    self._TestWithSharedWeights(rewrite_fn, quant_delay)

  def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
    with ops.Graph().as_default() as g:
      conv = template.make_template('shared_weights_conv', self._ConvLayer)
      conv()
      conv()
      if quant_delay is None:
        rewrite_fn()
      else:
        rewrite_fn(quant_delay=quant_delay)

    conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
    weights_quants = [
        op for op in g.get_operations()
        if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
    ]
    # Check that the shared weights variable is not quantized multiple times
    self.assertTrue(len(weights_quants) == 1)
    weights_quant_tensor = weights_quants[0].outputs[0]
    if quant_delay:
      delayed_weights_quants = [
          op for op in g.get_operations()
          if 'weights_quant' in op.name and op.type == 'Merge'
      ]
      self.assertTrue(len(delayed_weights_quants) == 1)
      weights_quant_tensor = delayed_weights_quants[0].outputs[0]
    # Check that the Conv2D operations get the quantized weights
    self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))

  def _ConvLayer(
      self, input_tensor=None, scope='test', pre_activation_bypass=False,
      post_activation_bypass=False):
    """Add a basic convolution layer to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None)
      if pre_activation_bypass:
        output += input_tensor
      output = nn_ops.relu6(output)
      if post_activation_bypass:
        output += input_tensor
    return output

  def _LayerWithIdentity(self,
                         input_tensor=None,
                         scope='test',
                         post_activation_bypass=False):
    """Add a basic conv, identity, batch norm with skip to the default graph."""
    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None,
          normalizer_fn=None,
          biases_initializer=None)
      output = array_ops.identity(output, name='conv_out')

      output = layers.batch_norm(
          output, center=True, scale=True, decay=1.0 - 0.003, fused=True)

      output = array_ops.identity(output, name='bn_out')
      if post_activation_bypass:
        output += input_tensor
    return output

  def _LayerWithActivationProcessing(self,
                                     input_tensor=None,
                                     scope='test',
                                     post_activation_bypass=False):
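    """Add a conv, batch norm, relu6 and scaled add to the default graph."""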

    batch_size, height, width, depth = 5, 128, 128, 3
    if input_tensor is None:
      input_tensor = array_ops.zeros((batch_size, height, width, depth))
    weight_init = init_ops.truncated_normal_initializer
    with ops.name_scope(scope):
      output = layers.conv2d(
          input_tensor,
          depth, [5, 5],
          padding='SAME',
          weights_initializer=weight_init(0.09),
          activation_fn=None,
          normalizer_fn=None,
          biases_initializer=None)

      output = layers.batch_norm(
          output, center=True, scale=True, decay=1.0 - 0.003, fused=True)

      output = nn_ops.relu6(output)
      scaled_output1 = math_ops.mul(2.0, output)
      scaled_output2 = math_ops.mul(3.0, output)
      output = scaled_output1 + scaled_output2
    return output

  def _AssertInputOpsAre(self, op, in_op_names):
    """Asserts that all inputs to op come from in_op_names (disregarding order).

    Args:
      op: Operation to check inputs for.
      in_op_names: List of strings, operations where all op's inputs should come
        from.
    """
    expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names]
    self.assertItemsEqual([t.name for t in op.inputs], expected_inputs)

  def _AssertOutputGoesToOps(self, op, graph, out_op_names):
    """Asserts that outputs from op go to out_op_names (and perhaps others).

    Args:
      op: Operation to check outputs for.
      graph: Graph where output operations are located.
      out_op_names: List of strings, operations where op's outputs should go.
    """
    for out_op_name in out_op_names:
      out_op = graph.get_operation_by_name(out_op_name)
      self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs])


if __name__ == '__main__':
  googletest.main()