# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DCT operations."""

import importlib
import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops.signal import dct_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def try_import(name):  # pylint: disable=invalid-name
  """Imports the module `name`, tolerating its absence.

  Args:
    name: Fully qualified module name passed to `importlib.import_module`.

  Returns:
    The imported module, or None if an `ImportError` was raised (in which case
    a warning is logged).
  """
  module = None
  try:
    module = importlib.import_module(name)
  except ImportError as e:
    # Pass the arguments lazily so formatting happens inside the logger.
    tf_logging.warning("Could not import %s: %s", name, e)
  return module


# Optional dependency: None when SciPy is not installed.
fftpack = try_import("scipy.fftpack")


def _modify_input_for_dct(signals, n=None):
  """Pads or trims the provided NumPy array's innermost axis to length n.

  Args:
    signals: Array-like input; it is converted with `np.array`.
    n: Requested length of the innermost axis, or None to leave it unchanged.

  Returns:
    A NumPy array. If `n` is None or already equals the innermost length, the
    converted input is returned as is. If `n` is smaller, the innermost axis is
    truncated to its first `n` entries. If `n` is larger, the innermost axis is
    zero-padded at the end, keeping the input's dtype.

  Raises:
    ValueError: If `n` is less than 1 and differs from the innermost length.
  """
  signal = np.array(signals)
  signal_len = signal.shape[-1]
  if n is None or n == signal_len:
    return signal
  if n < 1:
    # Previously this fell through and failed with an UnboundLocalError.
    raise ValueError("n must be None or a positive integer, got %r." % (n,))
  if n <= signal_len:
    return signal[..., 0:n]
  output_shape = list(signal.shape)
  output_shape[-1] = n
  # Keep the input dtype so padding does not silently promote to float64.
  signal_mod = np.zeros(output_shape, dtype=signal.dtype)
  signal_mod[..., 0:signal_len] = signal
  return signal_mod


def _np_dct1(signals, n=None, norm=None):
  """Computes the DCT-I manually with NumPy.

  Args:
    signals: Array-like input; the transform is taken over the innermost axis.
    n: Optional length to pad or trim the innermost axis to before the
      transform.
    norm: Ignored.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # X_k = x_0 + (-1)**k * x_{N-1} +
  #       2 * sum_{n=1}^{N-2} x_n * cos(\frac{pi}{N-1} * n * k)  k=0,...,N-1
  del norm
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  dct = np.zeros_like(signals_mod)
  # Indices of the interior samples x_1 ... x_{N-2}; loop invariant.
  interior = np.arange(1, dct_size - 1)
  for k in range(dct_size):
    phi = np.cos(np.pi * interior * k / (dct_size - 1))
    dct[..., k] = 2 * np.sum(
        signals_mod[..., 1:-1] * phi, axis=-1) + (
            signals_mod[..., 0] + (-1)**k * signals_mod[..., -1])
  return dct


def _np_dct2(signals, n=None, norm=None):
  """Computes the DCT-II manually with NumPy.

  Args:
    signals: Array-like input; the transform is taken over the innermost axis.
    n: Optional length to pad or trim the innermost axis to before the
      transform.
    norm: "ortho" for orthonormal scaling; any other value applies the plain
      factor of 2.0.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k)  k=0,...,N-1
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  dct = np.zeros_like(signals_mod)
  # Half-sample shifted positions (n + 0.5); loop invariant.
  positions = np.arange(dct_size) + 0.5
  for k in range(dct_size):
    phi = np.cos(np.pi * positions * k / dct_size)
    dct[..., k] = np.sum(signals_mod * phi, axis=-1)
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  if norm == "ortho":
    # The orthonormal scaling includes a factor of 0.5 which we combine with
    # the overall scaling of 2.0 to cancel.
    dct[..., 0] *= np.sqrt(1.0 / dct_size)
    dct[..., 1:] *= np.sqrt(2.0 / dct_size)
  else:
    dct *= 2.0
  return dct


def _np_dct3(signals, n=None, norm=None):
  """Computes the DCT-III manually with NumPy.

  Args:
    signals: Array-like input; the transform is taken over the innermost axis.
    n: Optional length to pad or trim the innermost axis to before the
      transform.
    norm: "ortho" for orthonormal scaling; any other value applies the plain
      factor of 2.0.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  signals_mod = np.array(signals_mod)  # make a copy so we can modify
  if norm == "ortho":
    # The first sample is halved below, so sqrt(4/N) here yields sqrt(1/N).
    signals_mod[..., 0] *= np.sqrt(4.0 / dct_size)
    signals_mod[..., 1:] *= np.sqrt(2.0 / dct_size)
  else:
    signals_mod *= 2.0
  dct = np.zeros_like(signals_mod)
  # X_k = 0.5 * x_0 +
  #       sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5))  k=0,...,N-1
  half_x0 = 0.5 * signals_mod[..., 0]
  # Indices of the samples x_1 ... x_{N-1}; loop invariant.
  tail = np.arange(1, dct_size)
  for k in range(dct_size):
    phi = np.cos(np.pi * tail * (k + 0.5) / dct_size)
    dct[..., k] = half_x0 + np.sum(signals_mod[..., 1:] * phi, axis=-1)
  return dct


def _np_dct4(signals, n=None, norm=None):
  """Computes the DCT-IV manually with NumPy.

  Args:
    signals: Array-like input; the transform is taken over the innermost axis.
    n: Optional length to pad or trim the innermost axis to before the
      transform.
    norm: "ortho" for orthonormal scaling; any other value applies the plain
      factor of 2.0.

  Returns:
    A NumPy array with the same shape and dtype as the padded/trimmed input.
  """
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  signals_mod = np.array(signals_mod)  # make a copy so we can modify
  if norm == "ortho":
    signals_mod *= np.sqrt(2.0 / dct_size)
  else:
    signals_mod *= 2.0
  dct = np.zeros_like(signals_mod)
  # X_k = sum_{n=0}^{N-1}
  #       x_n * cos(\frac{pi}{4N} * (2n + 1) * (2k + 1))  k=0,...,N-1
  # Odd multipliers (2n + 1); loop invariant.
  odd_n = 2 * np.arange(0, dct_size) + 1
  for k in range(dct_size):
    phi = np.cos(np.pi *
                 odd_n * (2 * k + 1) /
                 (4.0 * dct_size))
    dct[..., k] = np.sum(signals_mod * phi, axis=-1)
  return dct


# Reference implementations keyed by DCT type. The inverse table swaps types
# 2 and 3 and maps types 1 and 4 to themselves.
NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4}
NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4}


@test_util.run_all_in_graph_and_eager_modes
class DCTOpsTest(parameterized.TestCase, test.TestCase):
  """Checks `dct_ops.dct` and `dct_ops.idct` against reference transforms."""

  def _compare(self, signals, n, norm, dct_type, atol, rtol):
    """Compares (I)DCT to SciPy (if available) and a NumPy implementation.

    Args:
      signals: NumPy array of input signals.
      n: Length passed to the forward DCT, or None.
      norm: Normalization mode passed through to every transform.
      dct_type: DCT type, used as a key into `NP_DCT` / `NP_IDCT`.
      atol: Absolute tolerance for `assertAllClose`.
      rtol: Relative tolerance for `assertAllClose`.
    """
    np_dct = NP_DCT[dct_type](signals, n=n, norm=norm)
    tf_dct = dct_ops.dct(signals, n=n, type=dct_type, norm=norm)
    self.assertEqual(tf_dct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
    np_idct = NP_IDCT[dct_type](signals, n=None, norm=norm)
    tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
    self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
    # The SciPy cross-check only runs when SciPy imported, and never for
    # type 4.
    if fftpack and dct_type != 4:
      scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
      self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
      scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
      self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
    # Verify inverse(forward(s)) == s, up to a normalization factor.
    # Since `n` is not implemented for IDCT operation, re-calculating tf_dct
    # without n.
    tf_dct = dct_ops.dct(signals, type=dct_type, norm=norm)
    tf_idct_dct = dct_ops.idct(tf_dct, type=dct_type, norm=norm)
    tf_dct_idct = dct_ops.dct(tf_idct, type=dct_type, norm=norm)
    if norm is None:
      # Without normalization the round trip is rescaled before comparing:
      # by 0.5 / (N - 1) for type 1 and by 0.5 / N otherwise.
      if dct_type == 1:
        tf_idct_dct *= 0.5 / (signals.shape[-1] - 1)
        tf_dct_idct *= 0.5 / (signals.shape[-1] - 1)
      else:
        tf_idct_dct *= 0.5 / signals.shape[-1]
        tf_dct_idct *= 0.5 / signals.shape[-1]
    self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
    self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)

  @parameterized.parameters(itertools.product(
      [1, 2, 3, 4],
      [None, "ortho"],
      [[2], [3], [10], [2, 20], [2, 3, 25]],
      [np.float32, np.float64]))
  def test_random(self, dct_type, norm, shape, dtype):
    """Test randomly generated batches of data."""
    # "ortho" normalization is not implemented for type I.
    if dct_type == 1 and norm == "ortho":
      return
    with self.session():
      tol = 5e-4 if dtype == np.float32 else 1e-7
      signals = np.random.rand(*shape).astype(dtype)
      # Randomly exercise either the default length (None) or an explicit
      # length that may trim or pad the innermost axis.
      n = np.random.randint(1, 2 * signals.shape[-1])
      n = np.random.choice([None, n])
      self._compare(signals, n, norm=norm, dct_type=dct_type,
                    rtol=tol, atol=tol)

  def test_error(self):
    """Checks that invalid arguments to `dct_ops.dct` raise."""
    signals = np.random.rand(10)
    # Unsupported type.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=5)
    # Invalid n.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, n=-2)
    # DCT-I normalization not implemented.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=1, norm="ortho")
    # DCT-I requires at least two inputs.
    with self.assertRaises(ValueError):
      dct_ops.dct(np.random.rand(1), type=1)
    # Unknown normalization.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, norm="bad")
    # axis=0 is expected to raise NotImplementedError.
    with self.assertRaises(NotImplementedError):
      dct_ops.dct(signals, axis=0)


if __name__ == "__main__":
  test.main()