1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for quantizing a Tensorflow graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.layers.python.layers import layers
22from tensorflow.contrib.quantize.python import quantize
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.ops import init_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.platform import googletest
31
# Short aliases for the contrib layer constructors so the long layer calls in
# the tests below do not wrap.
conv2d = layers.conv2d
separable_conv2d = layers.separable_conv2d
34
35
class QuantizeTest(test_util.TensorFlowTestCase):
  """Tests that quantize.Quantize inserts FakeQuant ops where expected.

  Each public test* method delegates to a private _Test* helper through
  _RunTestOverParameters so that every check runs for both the training and
  the eval (inference) rewrite.
  """

  def _RunTestOverParameters(self, test_fn):
    """Invokes test_fn(is_training) for is_training in (True, False)."""
    params = [True, False]
    for is_training in params:
      test_fn(is_training)

  def testInsertQuantOpFailsWhenOpsNotConnected(self):
    # Bug fix: this test body was previously `pass`, so the parameterized
    # check below never executed. Wire it up like the other tests in this
    # class.
    self._RunTestOverParameters(
        self._TestInsertQuantOpFailsWhenOpsNotConnected)

  def _TestInsertQuantOpFailsWhenOpsNotConnected(self, is_training):
    """_InsertQuantOp must reject a producer/consumer pair with no edge.

    Args:
      is_training: Boolean, whether the graph rewrite targets training.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      conv = conv2d(inputs, 32, [5, 5], stride=2, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None, scope='test')
      # relu6 consumes `inputs`, not `conv`, so conv.op and relu.op are
      # deliberately unconnected.
      relu = nn_ops.relu6(inputs)

    # Inserting a quantization op between two unconnected ops should fail with
    # ValueError.
    with self.assertRaises(ValueError) as err:
      quantize._InsertQuantOp('test', is_training, conv.op, [relu.op],
                              'FailingQuantOp')
    self.assertEqual(
        str(err.exception), 'Some inputs not quantized for ops: [Relu6]')

  def testInsertQuantOpForAddAfterConv2d(self):
    self._RunTestOverParameters(self._TestInsertQuantOpForAddAfterConv2d)

  def _TestInsertQuantOpForAddAfterConv2d(self, is_training):
    """An add consuming a conv2d output must receive a FakeQuant op.

    Args:
      is_training: Boolean, whether the graph rewrite targets training.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      # Second addend matches the conv output shape: stride 2 halves the
      # spatial dimensions and the conv produces 32 channels.
      input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
      conv = conv2d(input1, 32, [5, 5], stride=2, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None, scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = array_ops.identity(node, name='test/identity')
      # Give the add a downstream consumer behind a control dependency so the
      # rewriter has an edge to requantize.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

    quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

    # The rewrite must have placed a FakeQuantWithMinMaxVars op after the add.
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    add_quant = graph.get_operation_by_name('test/add_quant/' +
                                            quantization_node_name)
    self.assertEqual(add_quant.type, quantization_node_name)

  def testInsertQuantOpForAddAfterSeparableConv2d(self):
    self._RunTestOverParameters(
        self._TestInsertQuantOpForAddAfterSeparableConv2d)

  def _TestInsertQuantOpForAddAfterSeparableConv2d(self, is_training):
    """An add consuming a separable_conv2d output must receive a FakeQuant op.

    Args:
      is_training: Boolean, whether the graph rewrite targets training.
    """
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      # Depthwise-only conv (pointwise filters=None, depth_multiplier=1)
      # keeps `depth` channels; stride 2 halves the spatial dimensions.
      input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
      conv = separable_conv2d(input1, None, [5, 5], stride=2,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=None, scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = array_ops.identity(node, name='test/identity')
      # Give the add a downstream consumer behind a control dependency so the
      # rewriter has an edge to requantize.
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

    quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

    # The rewrite must have placed a FakeQuantWithMinMaxVars op after the add.
    quantization_node_name = 'FakeQuantWithMinMaxVars'
    add_quant = graph.get_operation_by_name('test/add_quant/' +
                                            quantization_node_name)
    self.assertEqual(add_quant.type, quantization_node_name)

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)
128
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  googletest.main()
131