• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DCT operations."""
16
17import importlib
18import itertools
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops.signal import dct_ops
25from tensorflow.python.platform import test
26from tensorflow.python.platform import tf_logging
27
28
def try_import(name):  # pylint: disable=invalid-name
  """Imports a module by dotted name, returning None if it is unavailable.

  Used in this file to make SciPy an optional dependency of the test.

  Args:
    name: Dotted module name, passed to `importlib.import_module`.

  Returns:
    The imported module, or None if the import raised `ImportError`. A warning
    is logged in that case.
  """
  try:
    return importlib.import_module(name)
  except ImportError as e:
    # Pass the arguments to the logger so formatting happens lazily. The
    # emitted message text is the same as with eager %-formatting.
    tf_logging.warning("Could not import %s: %s", name, e)
    return None
36
37
# SciPy is optional. When it cannot be imported, `fftpack` is None and the
# SciPy cross-checks in the test are skipped.
fftpack = try_import(name="scipy.fftpack")
39
40
def _modify_input_for_dct(signals, n=None):
  """Pad or trim the provided NumPy array's innermost axis to length n.

  Args:
    signals: Array-like input. It is converted with `np.array`, which copies by
      default, so the result never aliases the caller's array.
    n: Target length of the innermost axis, or None to keep it unchanged. A
      smaller value truncates the axis; a larger value zero-pads its end.

  Returns:
    A NumPy array whose innermost axis has length `n`, or its original length
    when `n` is None.

    Boolean and integer inputs are promoted to float64, so every branch returns
    a floating-point array. This matches the zero-padding branch, which always
    allocates with the default float64 dtype of `np.zeros`. The DCT references
    fill `np.zeros_like` of this result with cosine sums, and an integer buffer
    would silently truncate them.

  Raises:
    ValueError: If `n` is less than 1 and differs from the current length.
  """
  signal = np.array(signals)
  if signal.dtype.kind in "biu":
    signal = signal.astype(np.float64)
  if n is None:
    return signal
  signal_len = signal.shape[-1]
  if n == signal_len:
    return signal
  if n < 1:
    # Without this check, a non-positive `n` would leave no result bound.
    raise ValueError("n should be a positive integer or None, got %r" % (n,))
  if n < signal_len:
    return signal[..., :n]
  # Zero-pad: allocate the target shape, then copy the existing samples in.
  signal_mod = np.zeros(signal.shape[:-1] + (n,))
  signal_mod[..., :signal_len] = signal
  return signal_mod
58
59
def _np_dct1(signals, n=None, norm=None):
  """Computes the DCT-I manually with NumPy.

  Uses the unnormalized definition documented for SciPy's DCT-I:

    X_k = x_0 + (-1)**k * x_{N-1}
          + 2 * sum_{m=1}^{N-2} x_m * cos(pi * m * k / (N - 1)),
    k=0,...,N-1

  Args:
    signals: Array-like input, transformed along the innermost axis.
    n: Optional length to pad or trim the innermost axis to first.
    norm: Ignored; only the unnormalized DCT-I is implemented.

  Returns:
    NumPy array with the same shape as the padded or trimmed input.
  """
  del norm  # test_random never requests "ortho" for DCT-I.
  signals_mod = _modify_input_for_dct(signals, n=n)
  dct_size = signals_mod.shape[-1]
  # Loop-invariant pieces: the interior indices 1..N-2, the interior samples
  # and the two endpoint samples.
  interior_idx = np.arange(1, dct_size - 1)
  interior = signals_mod[..., 1:-1]
  first = signals_mod[..., 0]
  last = signals_mod[..., -1]
  dct = np.zeros_like(signals_mod)
  for k in range(dct_size):
    phi = np.cos(np.pi * interior_idx * k / (dct_size - 1))
    dct[..., k] = 2 * np.sum(interior * phi, axis=-1) + (
        first + (-1)**k * last)
  return dct
74
75
def _np_dct2(signals, n=None, norm=None):
  """Computes the DCT-II manually with NumPy.

  Before the SciPy-style scaling applied at the end, this evaluates

    X_k = sum_{m=0}^{N-1} x_m * cos(pi * (m + 0.5) * k / N),  k=0,...,N-1

  Args:
    signals: Array-like input, transformed along the innermost axis.
    n: Optional length to pad or trim the innermost axis to first.
    norm: "ortho" for orthonormal scaling. Any other value applies SciPy's
      overall factor of 2.0.

  Returns:
    NumPy array with the same shape as the padded or trimmed input.
  """
  x = _modify_input_for_dct(signals, n=n)
  size = x.shape[-1]
  angle_base = np.pi * (np.arange(size) + 0.5)
  out = np.zeros_like(x)
  for k in range(size):
    out[..., k] = np.sum(x * np.cos(angle_base * k / size), axis=-1)
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  if norm != "ortho":
    out *= 2.0
    return out
  # SciPy's orthonormal factors are sqrt(1/(4N)) for k=0 and sqrt(1/(2N))
  # otherwise. Multiplied by the overall 2.0, they give the values below.
  out[..., 0] *= np.sqrt(1.0 / size)
  out[..., 1:] *= np.sqrt(2.0 / size)
  return out
95
96
def _np_dct3(signals, n=None, norm=None):
  """Computes the DCT-III manually with NumPy.

  The input is first rescaled in place, by 2.0 (SciPy convention) or by the
  orthonormal factors. Then, with x denoting the rescaled samples:

    X_k = 0.5 * x_0 + sum_{m=1}^{N-1} x_m * cos(pi * m * (k + 0.5) / N),
    k=0,...,N-1

  Args:
    signals: Array-like input, transformed along the innermost axis.
    n: Optional length to pad or trim the innermost axis to first.
    norm: "ortho" for orthonormal scaling. Any other value applies SciPy's
      overall factor of 2.0.

  Returns:
    NumPy array with the same shape as the padded or trimmed input.
  """
  # Work on a copy so the in-place rescaling below cannot leak out.
  x = np.array(_modify_input_for_dct(signals, n=n))
  size = x.shape[-1]
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  if norm != "ortho":
    x *= 2.0
  else:
    x[..., 0] *= np.sqrt(4.0 / size)
    x[..., 1:] *= np.sqrt(2.0 / size)
  half_x0 = 0.5 * x[..., 0]
  tail = x[..., 1:]
  tail_angles = np.pi * np.arange(1, size)
  out = np.zeros_like(x)
  for k in range(size):
    weights = np.cos(tail_angles * (k + 0.5) / size)
    out[..., k] = half_x0 + np.sum(tail * weights, axis=-1)
  return out
117
118
def _np_dct4(signals, n=None, norm=None):
  """Computes the DCT-IV manually with NumPy.

  After rescaling the input by 2.0 (SciPy convention) or by sqrt(2/N)
  ("ortho"), this evaluates

    X_k = sum_{m=0}^{N-1} x_m * cos(pi * (2m + 1) * (2k + 1) / (4N)),
    k=0,...,N-1

  Args:
    signals: Array-like input, transformed along the innermost axis.
    n: Optional length to pad or trim the innermost axis to first.
    norm: "ortho" for orthonormal scaling. Any other value applies SciPy's
      overall factor of 2.0.

  Returns:
    NumPy array with the same shape as the padded or trimmed input.
  """
  # Work on a copy so the in-place rescaling below cannot leak out.
  x = np.array(_modify_input_for_dct(signals, n=n))
  size = x.shape[-1]
  # SciPy's `dct` has a scaling factor of 2.0 which we follow.
  # https://github.com/scipy/scipy/blob/v1.2.1/scipy/fftpack/src/dct.c.src
  x *= np.sqrt(2.0 / size) if norm == "ortho" else 2.0
  odd_indices = 2 * np.arange(0, size) + 1
  denominator = 4.0 * size
  out = np.zeros_like(x)
  for k in range(size):
    weights = np.cos(np.pi * odd_indices * (2 * k + 1) / denominator)
    out[..., k] = np.sum(x * weights, axis=-1)
  return out
139
140
# NumPy reference forward transforms, keyed by DCT type.
NP_DCT = {
    1: _np_dct1,
    2: _np_dct2,
    3: _np_dct3,
    4: _np_dct4,
}
# NumPy reference inverse transforms, keyed by the DCT type being inverted.
# DCT-II and DCT-III invert each other, and DCT-I and DCT-IV invert
# themselves, each up to a scale factor (see the round-trip normalization in
# `_compare`).
NP_IDCT = {
    1: _np_dct1,
    2: _np_dct3,
    3: _np_dct2,
    4: _np_dct4,
}
143
144
@test_util.run_all_in_graph_and_eager_modes
class DCTOpsTest(parameterized.TestCase, test.TestCase):
  """Tests `dct_ops.dct` and `dct_ops.idct` against reference implementations."""

  def _compare(self, signals, n, norm, dct_type, atol, rtol):
    """Compares (I)DCT to SciPy (if available) and a NumPy implementation.

    Also verifies that `idct(dct(signals))` and `dct(idct(signals))` give back
    `signals` once the scale factor of the unnormalized transforms is removed.

    Args:
      signals: NumPy array of input signals; transforms act on the last axis.
      n: Length argument for the forward DCT, or None. The inverse DCT and the
        round-trip checks always run without it.
      norm: Normalization mode, None or "ortho".
      dct_type: DCT type, one of 1, 2, 3 or 4.
      atol: Absolute tolerance used for every comparison.
      rtol: Relative tolerance used for every comparison.
    """
    # Forward DCT: the output dtype follows the input and values match NumPy.
    np_dct = NP_DCT[dct_type](signals, n=n, norm=norm)
    tf_dct = dct_ops.dct(signals, n=n, type=dct_type, norm=norm)
    self.assertEqual(tf_dct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
    # Inverse DCT, checked the same way.
    np_idct = NP_IDCT[dct_type](signals, n=None, norm=norm)
    tf_idct = dct_ops.idct(signals, type=dct_type, norm=norm)
    self.assertEqual(tf_idct.dtype.as_numpy_dtype, signals.dtype)
    self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
    # SciPy cross-check, run only when SciPy could be imported.
    # NOTE(review): DCT-IV is excluded, presumably because older
    # `scipy.fftpack` releases do not implement it. Confirm.
    if fftpack and dct_type != 4:
      scipy_dct = fftpack.dct(signals, n=n, type=dct_type, norm=norm)
      self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
      scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
      self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
    # Verify inverse(forward(s)) == s, up to a normalization factor.
    # `n` is not implemented for the IDCT, so tf_dct is recomputed without it.
    tf_dct = dct_ops.dct(signals, type=dct_type, norm=norm)
    tf_idct_dct = dct_ops.idct(tf_dct, type=dct_type, norm=norm)
    tf_dct_idct = dct_ops.dct(tf_idct, type=dct_type, norm=norm)
    if norm is None:
      # Unnormalized round trips are scaled by 2 * (N - 1) for DCT-I and by
      # 2 * N for the other types; divide that factor back out.
      period = signals.shape[-1] - 1 if dct_type == 1 else signals.shape[-1]
      tf_idct_dct *= 0.5 / period
      tf_dct_idct *= 0.5 / period
    self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
    self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)

  @parameterized.parameters(itertools.product(
      [1, 2, 3, 4],
      [None, "ortho"],
      [[2], [3], [10], [2, 20], [2, 3, 25]],
      [np.float32, np.float64]))
  def test_random(self, dct_type, norm, shape, dtype):
    """Test randomly generated batches of data."""
    if dct_type == 1 and norm == "ortho":
      # Report the unsupported combination as skipped, not as passed.
      self.skipTest("\"ortho\" normalization is not implemented for type I.")
    with self.session():
      tol = 5e-4 if dtype == np.float32 else 1e-7
      signals = np.random.rand(*shape).astype(dtype)
      # The forward length is either None or a random value in
      # [1, 2 * len - 1], exercising both trimming and zero-padding.
      n = np.random.randint(1, 2 * signals.shape[-1])
      n = np.random.choice([None, n])
      self._compare(signals, n, norm=norm, dct_type=dct_type,
                    rtol=tol, atol=tol)

  def test_error(self):
    """Invalid arguments to `dct_ops.dct` raise the expected exceptions."""
    signals = np.random.rand(10)
    # Unsupported type.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=5)
    # Invalid n.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, n=-2)
    # DCT-I normalization not implemented.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, type=1, norm="ortho")
    # DCT-I requires at least two inputs.
    with self.assertRaises(ValueError):
      dct_ops.dct(np.random.rand(1), type=1)
    # Unknown normalization.
    with self.assertRaises(ValueError):
      dct_ops.dct(signals, norm="bad")
    # Transforming along an axis other than the last is not implemented.
    with self.assertRaises(NotImplementedError):
      dct_ops.dct(signals, axis=0)
216
217
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner for the tests defined above.
  test.main()
220