1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Dequantize Operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.ops import array_ops 26from tensorflow.python.platform import test 27 28 29class DequantizeOpTest(test.TestCase): 30 31 def __init__(self, method_name="runTest"): 32 super(DequantizeOpTest, self).__init__(method_name) 33 34 def _testDequantizeOp(self, inputs, min_range, max_range, dtype, 35 mode="MIN_COMBINED", narrow_range=False): 36 with self.cached_session(): 37 input_op = constant_op.constant(inputs, shape=[len(inputs)], dtype=dtype) 38 dequantized = array_ops.dequantize(input_op, min_range, max_range, 39 mode=mode, narrow_range=narrow_range) 40 tf_ans = self.evaluate(dequantized) 41 42 # TODO(vrv): Add support for DT_QINT32 quantization if needed. 43 type_dict = { 44 dtypes.quint8: np.uint8, 45 dtypes.qint8: np.int8, 46 dtypes.quint16: np.uint16, 47 dtypes.qint16: np.int16 48 } 49 self.assertIn(dtype, type_dict.keys()) 50 v_max = np.iinfo(type_dict[dtype]).max 51 v_min = np.iinfo(type_dict[dtype]).min 52 self.assertGreaterEqual(min_range, v_min) 53 self.assertLessEqual(max_range, v_max) 54 type_range = v_max - v_min 55 56 if mode == "MIN_COMBINED": 57 if v_min < 0: 58 half_range = (type_range + 1) / 2 59 else: 60 half_range = 0.0 61 np_ans = ((inputs.astype(np.float32) + half_range) * 62 (max_range - min_range) / type_range) + min_range 63 elif mode == "SCALED": 64 if narrow_range: 65 v_min += 1 66 scale_factor = max(min_range / v_min, max_range / v_max) 67 np_ans = inputs.astype(np.float32) * scale_factor 68 69 self.assertAllClose(tf_ans, np_ans, rtol=1e-5, atol=1e-5) 70 71 def testBasicQuint8(self): 72 self._testDequantizeOp(np.array([0, 128, 255]), 0.0, 6.0, dtypes.quint8) 73 self._testDequantizeOp(np.array([0, 128, 255]), 0.0, 123.456, dtypes.quint8) 74 self._testDequantizeOp( 75 np.array([0, 4, 42, 108, 243]), 5.0, 200.2, dtypes.quint8) 76 77 def testBasicQint8(self): 78 self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8) 79 self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8) 80 self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8) 81 82 def testScaledMode(self): 83 self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8, 84 mode="SCALED") 85 self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8, 86 mode="SCALED") 87 self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8, 88 mode="SCALED") 89 90 def testNarrowRange(self): 91 self._testDequantizeOp(np.array([-128, 0, 127]), -1.0, 2.0, dtypes.qint8, 92 mode="SCALED", narrow_range=True) 93 self._testDequantizeOp(np.array([-2, 4, -17]), -5.0, -3.0, dtypes.qint8, 94 mode="SCALED", narrow_range=True) 95 self._testDequantizeOp(np.array([0, -4, 42, -108]), 5.0, 40.0, dtypes.qint8, 96 mode="SCALED", narrow_range=True) 97 98 def testAxis(self): 99 # Generates a tensor of the specified `shape` using values from `values` 100 # scaled by (slice_idx + 1) along `axis` dimension. 101 def scale_per_slice(shape, axis, values): 102 # Note: repeats the values if the shape is larger than values. 103 out = np.take(values, np.remainder(np.arange(np.prod(shape)), 104 len(values))).reshape(shape) 105 if axis is not None: 106 scale_shape = [1] * len(shape) 107 scale_shape[axis] = shape[axis] 108 out *= np.arange(1, shape[axis] + 1).reshape(scale_shape) 109 return out 110 111 shape = np.array([2, 3, 4, 5]) 112 values = np.array([-128, -64, 0, 38, 102, 71, 64], dtype=np.int32) 113 dequant_values = np.array([-2, -1.0, 0, 0.59375, 1.59375, 1.109375, 1.0], 114 dtype=np.float32) 115 for axis in [None, 0, 1, 2, 3]: 116 inputs = constant_op.constant( 117 scale_per_slice(shape, None, values), dtype=dtypes.qint8) 118 expected_dequantized = scale_per_slice(shape, axis, dequant_values) 119 if axis is None: 120 min_range, max_range = -2.0, 1.6 121 else: 122 num_slices = shape[axis] 123 min_range, max_range = [], [] 124 for slice_idx in range(num_slices): 125 min_range.append(-2.0 * (slice_idx + 1)) 126 max_range.append(1.6 * (slice_idx + 1)) 127 dequantized = self.evaluate( 128 array_ops.dequantize( 129 inputs, min_range, max_range, mode="SCALED", axis=axis)) 130 self.assertAllEqual(dequantized, expected_dequantized) 131 if axis is not None: 132 dequantized = self.evaluate( 133 array_ops.dequantize( 134 inputs, min_range, max_range, mode="SCALED", axis=(axis - 4))) 135 self.assertAllClose(dequantized, expected_dequantized) 136 137if __name__ == "__main__": 138 test.main() 139