• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for third_party.tensorflow.contrib.quantize.python.quant_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.quantize.python import quant_ops
22from tensorflow.python.client import session
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import partitioned_variables
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import googletest
30
# Collection into which the quantize ops under test register their min/max
# variables (passed as `vars_collection`).
_MIN_MAX_VARS = 'min_max_vars'
# Expected ratio of max to -min for symmetric, non-narrow-range quantization.
_SYMMETRIC_RANGE_RATIO = 127.0 / 128.0
33
34
class QuantOpsTest(googletest.TestCase):
  """Tests for the LastValueQuantize and MovingAvgQuantize ops in quant_ops."""

  def testLastValueQuantizeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(quant_ops.LastValueQuantize,
                                                 [[-1, 1]])
    self.assertEqual(min_value, -1.0)
    self.assertEqual(max_value, 1.0)

  def testLastValueSymmetricQuantizeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(
        quant_ops.LastValueQuantize,
        [[-_SYMMETRIC_RANGE_RATIO, _SYMMETRIC_RANGE_RATIO]],
        symmetric=True,
        narrow_range=False)
    self.assertEqual(min_value, -1.0)
    self.assertEqual(max_value, _SYMMETRIC_RANGE_RATIO)

  def testLastValueSymmetricQuantizeNarrowRangeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(
        quant_ops.LastValueQuantize, [[-1, 0.5]],
        symmetric=True,
        narrow_range=True)
    self.assertEqual(min_value, -1.0)
    self.assertEqual(max_value, 1)

  def testMovingAvgQuantizeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(quant_ops.MovingAvgQuantize,
                                                 [[-1, 1], [0, 0]])
    self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
    self.assertAlmostEqual(max_value, 0.5, delta=1e-3)

  def testMovingAvgQuantizeTrainingAssignNoShape(self):
    min_value, max_value = self._GetMinMaxValues(
        quant_ops.MovingAvgQuantize, [[-1, 1], [0, 0]], shape=None)
    self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
    self.assertAlmostEqual(max_value, 0.5, delta=1e-3)

  def testMovingAvgSymmetricQuantizeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(
        quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]], symmetric=True)
    self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
    self.assertAlmostEqual(max_value, 0.5 * _SYMMETRIC_RANGE_RATIO, delta=1e-3)
    self.assertAlmostEqual(max_value, min_value * -_SYMMETRIC_RANGE_RATIO)

  def testMovingAvgSymmetricQuantizeNarrowRangeTrainingAssign(self):
    min_value, max_value = self._GetMinMaxValues(
        quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]],
        symmetric=True,
        narrow_range=True)
    self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
    self.assertAlmostEqual(max_value, 0.5, delta=1e-3)
    self.assertAlmostEqual(max_value, -min_value)

  def testVariablesNotPartitioned_LastValue(self):
    self._AssertVariablesNotPartitioned(quant_ops.LastValueQuantize)

  def testVariablesNotPartitioned_MovingAvg(self):
    self._AssertVariablesNotPartitioned(quant_ops.MovingAvgQuantize)

  def _AssertVariablesNotPartitioned(self, quantize_fn):
    """Builds `quantize_fn` inside a scope that has a default partitioner.

    Variables added by the quantize op should not use the scope's default
    partitioner since they are scalar. If the rewrite respected the
    partitioner, TensorFlow would raise while building the op.

    Args:
      quantize_fn: The quantization op builder under test, e.g.
        quant_ops.LastValueQuantize or quant_ops.MovingAvgQuantize.
    """
    with ops.Graph().as_default() as g:
      with variable_scope.variable_scope(
          'part', partitioner=partitioned_variables.fixed_size_partitioner(2)):
        x = array_ops.placeholder(dtypes.float32, shape=[2])
        _ = quantize_fn(
            x,
            init_min=0.0,
            init_max=0.0,
            is_training=True,
            vars_collection=_MIN_MAX_VARS)
      # The op should have registered exactly its two (min and max) variables.
      self.assertEqual(len(g.get_collection(_MIN_MAX_VARS)), 2)

  def _GetMinMaxValues(self, quantize_fn, input_values, shape=(2,), **kwds):
    """Runs `quantize_fn` in training mode and returns its final min/max.

    Args:
      quantize_fn: The quantization op builder under test, e.g.
        quant_ops.LastValueQuantize or quant_ops.MovingAvgQuantize.
      input_values: Sequence of values to feed to the input placeholder; one
        step is run per element, in order.
      shape: Shape of the input placeholder, or None to leave it unspecified.
      **kwds: Extra keyword arguments forwarded to `quantize_fn`.

    Returns:
      A (min_value, max_value) tuple holding the values of the min and max
      variables after all steps have run.
    """
    g = ops.Graph()
    with session.Session(graph=g) as sess:
      x = array_ops.placeholder(dtypes.float32, shape=shape)
      y = quantize_fn(
          x,
          init_min=0.0,
          init_max=0.0,
          is_training=True,
          vars_collection=_MIN_MAX_VARS,
          **kwds)

      # Initialize the variables, then run one step per input element.
      sess.run(variables.global_variables_initializer())
      for input_elem in input_values:
        sess.run(y, feed_dict={x: input_elem})

      # Now check that the min_max_vars were, in fact, updated.
      min_max_vars = g.get_collection(_MIN_MAX_VARS)
      self.assertEqual(len(min_max_vars), 2)
      # Match on the last name component only, so that a 'min' substring in
      # an enclosing scope name cannot cause the two variables to be swapped.
      if 'min' in min_max_vars[0].name.split('/')[-1]:
        min_var, max_var = min_max_vars
      else:
        max_var, min_var = min_max_vars
      min_value, max_value = sess.run([min_var, max_var])
      return min_value, max_value
143
144
def _main():
  """Run this module's tests through the googletest runner."""
  googletest.main()


if __name__ == '__main__':
  _main()
147