• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Dequantize Operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.platform import test
27
28
class DequantizeOpTest(test.TestCase):
  """Checks `array_ops.dequantize` against a NumPy reference computation."""

  # NumPy storage type for each quantized dtype the reference helper supports;
  # `np.iinfo` on these gives the representable integer range.
  # TODO(vrv): Add support for DT_QINT32 quantization if needed.
  _NP_TYPES = {
      dtypes.quint8: np.uint8,
      dtypes.qint8: np.int8,
      dtypes.quint16: np.uint16,
      dtypes.qint16: np.int16,
  }

  def __init__(self, method_name="runTest"):
    # Accepts a snake_case `method_name` and forwards it positionally to the
    # base TestCase constructor.
    super(DequantizeOpTest, self).__init__(method_name)

  def _testDequantizeOp(self, inputs, min_range, max_range, dtype,
                        mode="MIN_COMBINED", narrow_range=False):
    """Dequantizes `inputs` with TF and compares against a NumPy reference.

    Args:
      inputs: 1-D NumPy array of raw quantized integer values.
      min_range: Lower bound of the float range the values map onto.
      max_range: Upper bound of the float range the values map onto.
      dtype: Quantized TF dtype; must be a key of `_NP_TYPES`.
      mode: "MIN_COMBINED" or "SCALED" (the modes the reference implements).
      narrow_range: Forwarded to the op; in the reference it only affects
        "SCALED" mode for signed types.
    """
    # Fail fast with a clear message instead of a NameError on `np_ans` later.
    self.assertIn(dtype, self._NP_TYPES)
    self.assertIn(mode, ("MIN_COMBINED", "SCALED"))

    with self.cached_session():
      input_op = constant_op.constant(inputs, shape=[len(inputs)], dtype=dtype)
      dequantized = array_ops.dequantize(input_op, min_range, max_range,
                                         mode=mode, narrow_range=narrow_range)
      tf_ans = self.evaluate(dequantized)

    type_info = np.iinfo(self._NP_TYPES[dtype])
    v_max = type_info.max
    v_min = type_info.min
    self.assertGreaterEqual(min_range, v_min)
    self.assertLessEqual(max_range, v_max)

    if mode == "MIN_COMBINED":
      type_range = v_max - v_min
      # Signed types are shifted by half the range so the minimum stored value
      # maps onto `min_range`; unsigned types need no shift.
      half_range = (type_range + 1) / 2 if v_min < 0 else 0.0
      np_ans = ((inputs.astype(np.float32) + half_range) *
                (max_range - min_range) / type_range) + min_range
    else:  # mode == "SCALED"
      if v_min == 0:
        # Unsigned types: dividing by v_min would be a division by zero, so
        # scale by the positive bound only. NOTE(review): intended to mirror
        # the unsigned branch of the TF dequantize kernel — confirm there.
        scale_factor = max_range / v_max
      else:
        # narrow_range drops the lowest representable value (e.g. -128).
        min_output = v_min + 1 if narrow_range else v_min
        scale_factor = max(min_range / min_output, max_range / v_max)
      np_ans = inputs.astype(np.float32) * scale_factor

    self.assertAllClose(tf_ans, np_ans, rtol=1e-5, atol=1e-5)

  def testBasicQuint8(self):
    self._testDequantizeOp(np.array([0, 128, 255]), 0.0, 6.0, dtypes.quint8)
    self._testDequantizeOp(np.array([0, 128, 255]), 0.0, 123.456, dtypes.quint8)
    self._testDequantizeOp(
        np.array([0, 4, 42, 108, 243]), 5.0, 200.2, dtypes.quint8)

  def testBasicQint8(self):
    self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8)
    self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8)
    self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8)

  def testScaledMode(self):
    self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8,
                           mode="SCALED")
    self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8,
                           mode="SCALED")
    self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8,
                           mode="SCALED")

  def testNarrowRange(self):
    self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8,
                           mode="SCALED", narrow_range=True)
    self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8,
                           mode="SCALED", narrow_range=True)
    self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8,
                           mode="SCALED", narrow_range=True)

  def testAxis(self):
    """Per-axis SCALED dequantization, with both positive and negative axis."""

    def scale_per_slice(shape, axis, values):
      """Tiles `values` into `shape`; slice i along `axis` is scaled by i + 1.

      With `axis=None` the tiled values are returned unscaled.
      """
      # Note: repeats the values if the shape is larger than values.
      out = np.take(values, np.remainder(np.arange(np.prod(shape)),
                                         len(values))).reshape(shape)
      if axis is not None:
        scale_shape = [1] * len(shape)
        scale_shape[axis] = shape[axis]
        out *= np.arange(1, shape[axis] + 1).reshape(scale_shape)
      return out

    shape = np.array([2, 3, 4, 5])
    values = np.array([-128, -64, 0, 38, 102, 71, 64], dtype=np.int32)
    dequant_values = np.array([-2, -1.0, 0, 0.59375, 1.59375, 1.109375, 1.0],
                              dtype=np.float32)
    # The quantized input is the same for every axis (only the ranges and the
    # expected output change), so build it once outside the loop.
    inputs = constant_op.constant(
        scale_per_slice(shape, None, values), dtype=dtypes.qint8)
    for axis in [None, 0, 1, 2, 3]:
      expected_dequantized = scale_per_slice(shape, axis, dequant_values)
      if axis is None:
        min_range, max_range = -2.0, 1.6
      else:
        # One (min, max) pair per slice, scaled like the expected values.
        multipliers = range(1, shape[axis] + 1)
        min_range = [-2.0 * m for m in multipliers]
        max_range = [1.6 * m for m in multipliers]
      dequantized = self.evaluate(
          array_ops.dequantize(
              inputs, min_range, max_range, mode="SCALED", axis=axis))
      self.assertAllEqual(dequantized, expected_dequantized)
      if axis is not None:
        # The equivalent negative axis (axis - rank) must give the same result.
        dequantized = self.evaluate(
            array_ops.dequantize(
                inputs, min_range, max_range, mode="SCALED",
                axis=axis - len(shape)))
        self.assertAllClose(dequantized, expected_dequantized)
136
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner imported as `test`.
  test.main()
139