1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for DCT operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import importlib 22import itertools 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops.signal import dct_ops 29from tensorflow.python.platform import test 30from tensorflow.python.platform import tf_logging 31 32 33def try_import(name): # pylint: disable=invalid-name 34 module = None 35 try: 36 module = importlib.import_module(name) 37 except ImportError as e: 38 tf_logging.warning("Could not import %s: %s" % (name, str(e))) 39 return module 40 41 42fftpack = try_import("scipy.fftpack") 43 44 45def _modify_input_for_dct(signals, n=None): 46 """Pad or trim the provided NumPy array's innermost axis to length n.""" 47 signal = np.array(signals) 48 if n is None or n == signal.shape[-1]: 49 signal_mod = signal 50 elif n >= 1: 51 signal_len = signal.shape[-1] 52 if n <= signal_len: 53 signal_mod = signal[..., 0:n] 54 else: 55 output_shape = list(signal.shape) 56 output_shape[-1] = n 57 signal_mod = np.zeros(output_shape) 58 signal_mod[..., 0:signal.shape[-1]] = signal 59 if n: 60 assert signal_mod.shape[-1] == n 61 return signal_mod 62 63 64def _np_dct1(signals, n=None, norm=None): 65 """Computes the DCT-I manually with NumPy.""" 66 # X_k = (x_0 + (-1)**k * x_{N-1} + 67 # 2 * sum_{n=0}^{N-2} x_n * cos(\frac{pi}{N-1} * n * k) k=0,...,N-1 68 del norm 69 signals_mod = _modify_input_for_dct(signals, n=n) 70 dct_size = signals_mod.shape[-1] 71 dct = np.zeros_like(signals_mod) 72 for k in range(dct_size): 73 phi = np.cos(np.pi * np.arange(1, dct_size - 1) * k / (dct_size - 1)) 74 dct[..., k] = 2 * np.sum( 75 signals_mod[..., 1:-1] * phi, axis=-1) + ( 76 signals_mod[..., 0] + (-1)**k * signals_mod[..., -1]) 77 return dct 78 79 80def _np_dct2(signals, n=None, norm=None): 81 """Computes the DCT-II manually with NumPy.""" 82 # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1 83 signals_mod = _modify_input_for_dct(signals, n=n) 84 dct_size = signals_mod.shape[-1] 85 dct = np.zeros_like(signals_mod) 86 for k in range(dct_size): 87 phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size) 88 dct[..., k] = np.sum(signals_mod * phi, axis=-1) 89 # SciPy's `dct` has a scaling factor of 2.0 which we follow. 90 # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src 91 if norm == "ortho": 92 # The orthonormal scaling includes a factor of 0.5 which we combine with 93 # the overall scaling of 2.0 to cancel. 94 dct[..., 0] *= np.sqrt(1.0 / dct_size) 95 dct[..., 1:] *= np.sqrt(2.0 / dct_size) 96 else: 97 dct *= 2.0 98 return dct 99 100 101def _np_dct3(signals, n=None, norm=None): 102 """Computes the DCT-III manually with NumPy.""" 103 # SciPy's `dct` has a scaling factor of 2.0 which we follow. 104 # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src 105 signals_mod = _modify_input_for_dct(signals, n=n) 106 dct_size = signals_mod.shape[-1] 107 signals_mod = np.array(signals_mod) # make a copy so we can modify 108 if norm == "ortho": 109 signals_mod[..., 0] *= np.sqrt(4.0 / dct_size) 110 signals_mod[..., 1:] *= np.sqrt(2.0 / dct_size) 111 else: 112 signals_mod *= 2.0 113 dct = np.zeros_like(signals_mod) 114 # X_k = 0.5 * x_0 + 115 # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1 116 half_x0 = 0.5 * signals_mod[..., 0] 117 for k in range(dct_size): 118 phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size) 119 dct[..., k] = half_x0 + np.sum(signals_mod[..., 1:] * phi, axis=-1) 120 return dct 121 122 123def _np_dct4(signals, n=None, norm=None): 124 """Computes the DCT-IV manually with NumPy.""" 125 # SciPy's `dct` has a scaling factor of 2.0 which we follow. 126 # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src 127 signals_mod = _modify_input_for_dct(signals, n=n) 128 dct_size = signals_mod.shape[-1] 129 signals_mod = np.array(signals_mod) # make a copy so we can modify 130 if norm == "ortho": 131 signals_mod *= np.sqrt(2.0 / dct_size) 132 else: 133 signals_mod *= 2.0 134 dct = np.zeros_like(signals_mod) 135 # X_k = sum_{n=0}^{N-1} 136 # x_n * cos(\frac{pi}{4N} * (2n + 1) * (2k + 1)) k=0,...,N-1 137 for k in range(dct_size): 138 phi = np.cos(np.pi * 139 (2 * np.arange(0, dct_size) + 1) * (2 * k + 1) / 140 (4.0 * dct_size)) 141 dct[..., k] = np.sum(signals_mod * phi, axis=-1) 142 return dct 143 144 145NP_DCT = {1: _np_dct1, 2: _np_dct2, 3: _np_dct3, 4: _np_dct4} 146NP_IDCT = {1: _np_dct1, 2: _np_dct3, 3: _np_dct2, 4: _np_dct4} 147 148 149@test_util.run_all_in_graph_and_eager_modes 150class DCTOpsTest(parameterized.TestCase, test.TestCase): 151 152 def _compare(self, signals, n, norm, dct_type, atol, rtol): 153 """Compares (I)DCT to SciPy (if available) and a NumPy implementation.""" 154 np_dct = NP_DCT[dct_type](signals, n=n, norm=norm) 155 tf_dct = dct_ops.dct(signals, n=n, type=dct_type, norm=norm) 156 self.assertEqual(tf_dct.dtype.as_numpy_dtype, signals.dtype) 157 self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol) 158 np_idct = NP_IDCT[dct_type](signals, n=None, norm=norm) 159 tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm) 160 self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype) 161 self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol) 162 if fftpack and dct_type != 4: 163 scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm) 164 self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol) 165 scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm) 166 self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol) 167 # Verify inverse(forward(s)) == s, up to a normalization factor. 168 # Since `n` is not implemented for IDCT operation, re-calculating tf_dct 169 # without n. 170 tf_dct = dct_ops.dct(signals, type=dct_type, norm=norm) 171 tf_idct_dct = dct_ops.idct(tf_dct, type=dct_type, norm=norm) 172 tf_dct_idct = dct_ops.dct(tf_idct, type=dct_type, norm=norm) 173 if norm is None: 174 if dct_type == 1: 175 tf_idct_dct *= 0.5 / (signals.shape[-1] - 1) 176 tf_dct_idct *= 0.5 / (signals.shape[-1] - 1) 177 else: 178 tf_idct_dct *= 0.5 / signals.shape[-1] 179 tf_dct_idct *= 0.5 / signals.shape[-1] 180 self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol) 181 self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol) 182 183 @parameterized.parameters(itertools.product( 184 [1, 2, 3, 4], 185 [None, "ortho"], 186 [[2], [3], [10], [2, 20], [2, 3, 25]], 187 [np.float32, np.float64])) 188 def test_random(self, dct_type, norm, shape, dtype): 189 """Test randomly generated batches of data.""" 190 # "ortho" normalization is not implemented for type I. 191 if dct_type == 1 and norm == "ortho": 192 return 193 with self.session(): 194 tol = 5e-4 if dtype == np.float32 else 1e-7 195 signals = np.random.rand(*shape).astype(dtype) 196 n = np.random.randint(1, 2 * signals.shape[-1]) 197 n = np.random.choice([None, n]) 198 self._compare(signals, n, norm=norm, dct_type=dct_type, 199 rtol=tol, atol=tol) 200 201 def test_error(self): 202 signals = np.random.rand(10) 203 # Unsupported type. 204 with self.assertRaises(ValueError): 205 dct_ops.dct(signals, type=5) 206 # Invalid n. 207 with self.assertRaises(ValueError): 208 dct_ops.dct(signals, n=-2) 209 # DCT-I normalization not implemented. 210 with self.assertRaises(ValueError): 211 dct_ops.dct(signals, type=1, norm="ortho") 212 # DCT-I requires at least two inputs. 213 with self.assertRaises(ValueError): 214 dct_ops.dct(np.random.rand(1), type=1) 215 # Unknown normalization. 216 with self.assertRaises(ValueError): 217 dct_ops.dct(signals, norm="bad") 218 with self.assertRaises(NotImplementedError): 219 dct_ops.dct(signals, axis=0) 220 221 222if __name__ == "__main__": 223 test.main() 224