• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import contextlib
16
17from absl.testing import parameterized
18import numpy as np
19
20from tensorflow.python.framework import config
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import variables as variables_module
26from tensorflow.python.ops.linalg import linalg
27from tensorflow.python.ops.linalg import linear_operator_circulant
28from tensorflow.python.ops.linalg import linear_operator_test_util
29from tensorflow.python.ops.signal import fft_ops
30from tensorflow.python.platform import test
31
# Short aliases for helpers from `linear_operator_circulant` used throughout
# these tests.
_to_complex = linear_operator_circulant._to_complex
exponential_power_convolution_kernel = (
    linear_operator_circulant.exponential_power_convolution_kernel)

# Module-level NumPy random state with a fixed seed.
rng = np.random.RandomState(seed=0)
37
38
def _operator_from_kernel(kernel, d, **kwargs):
  """Builds a d-dimensional circulant operator from a convolution kernel.

  The kernel is cast to `complex64` and transformed with the d-dimensional
  FFT from `linear_operator_circulant._FFT_OP`; the result is used as the
  operator's spectrum.

  Args:
    kernel: Convolution kernel; cast to `complex64` before the FFT.
    d: Number of circulant dimensions. Must be 1, 2, or 3.
    **kwargs: Forwarded unchanged to the operator constructor.

  Returns:
    A `LinearOperatorCirculant` (d=1), `LinearOperatorCirculant2D` (d=2) or
    `LinearOperatorCirculant3D` (d=3).

  Raises:
    ValueError: If `d` is not 1, 2, or 3.
  """
  operator_classes = {
      1: linear_operator_circulant.LinearOperatorCirculant,
      2: linear_operator_circulant.LinearOperatorCirculant2D,
      3: linear_operator_circulant.LinearOperatorCirculant3D,
  }
  # Validate before touching _FFT_OP so an unsupported `d` fails loudly
  # instead of raising a bare KeyError or silently returning None.
  if d not in operator_classes:
    raise ValueError(
        "Unsupported number of dimensions d=%s; expected 1, 2, or 3." % (d,))
  spectrum = linear_operator_circulant._FFT_OP[d](
      math_ops.cast(kernel, dtypes.complex64))
  return operator_classes[d](spectrum, **kwargs)
50
51
def _spectrum_for_symmetric_circulant(
    spectrum_shape,
    d,
    ensure_self_adjoint_and_pd,
    dtype,
):
  """Spectrum for d-dimensional real/symmetric circulant.

  The trailing `d` entries of `spectrum_shape` form the circulant grid. A
  real kernel is built on that grid with `exponential_power_convolution_kernel`,
  FFT'd, cast to `dtype`, and broadcast to the full `spectrum_shape`.

  Args:
    spectrum_shape: Desired spectrum shape, batch_shape + grid_shape.
    d: Number of grid dimensions.
    ensure_self_adjoint_and_pd: If True, use power=1 and zero_inflation=0.2;
      otherwise power=2 and no zero inflation.
    dtype: Output dtype; its `real_dtype` is used for the length scale.

  Returns:
    The spectrum, broadcast to `spectrum_shape`.
  """
  grid_shape = spectrum_shape[-d:]

  if grid_shape != (0,) * d:
    if ensure_self_adjoint_and_pd:
      power, zero_inflation = 1, 0.2
    else:
      # power=2 with this scale and no inflation will have some negative
      # spectra. It will still be real/symmetric.
      power, zero_inflation = 2, None
    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        length_scale=math_ops.cast([0.2] * d, dtype.real_dtype),
        power=power,
        zero_inflation=zero_inflation,
    )
  else:
    # Empty grid: an empty kernel reshaped to the all-zero grid shape.
    kernel = array_ops.reshape(math_ops.cast([], dtype), grid_shape)

  fft_op = linear_operator_circulant._FFT_OP[d]
  spectrum = math_ops.cast(fft_op(_to_complex(kernel)), dtype)
  return array_ops.broadcast_to(spectrum, spectrum_shape)
75
76
@test_util.run_all_in_graph_and_eager_modes
class ExponentialPowerConvolutionKernelTest(parameterized.TestCase,
                                            test.TestCase):
  """Checks dense circulants built from `exponential_power_convolution_kernel`."""

  def _dense_from_kernel_args(self, grid_shape, **kernel_kwargs):
    """Builds a kernel on `grid_shape`, wraps it, and returns the dense matrix.

    Args:
      grid_shape: Python list; its length is the circulant dimension.
      **kernel_kwargs: Extra keyword arguments for
        `exponential_power_convolution_kernel`.

    Returns:
      The evaluated (NumPy) dense matrix of the circulant operator.
    """
    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape, **kernel_kwargs)
    operator = _operator_from_kernel(kernel, len(grid_shape))
    return self.evaluate(operator.to_dense())

  def assert_diag_is_ones(self, matrix, rtol):
    diagonal = np.diag(matrix)
    self.assertAllClose(np.ones_like(diagonal), diagonal, rtol=rtol)

  def assert_real_symmetric(self, matrix, tol):
    imag_part, real_part = matrix.imag, matrix.real
    self.assertAllClose(np.zeros_like(imag_part), imag_part, atol=tol)
    self.assertAllClose(real_part, real_part.T, rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1Deven_power1", grid_shape=[10], power=1.),
      dict(testcase_name="2Deven_power1", grid_shape=[4, 6], power=1.),
      dict(testcase_name="3Deven_power1", grid_shape=[4, 6, 8], power=1.),
      dict(testcase_name="3Devenodd_power1", grid_shape=[4, 5, 7], power=1.),
      dict(testcase_name="1Dodd_power2", grid_shape=[9], power=2.),
      dict(testcase_name="2Deven_power2", grid_shape=[8, 4], power=2.),
      dict(testcase_name="3Devenodd_power2", grid_shape=[4, 5, 3], power=2.),
  )
  def test_makes_symmetric_and_real_circulant_with_ones_diag(
      self, grid_shape, power):
    matrix = self._dense_from_kernel_args(
        grid_shape, length_scale=[0.2] * len(grid_shape), power=power)

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape)
    self.assert_real_symmetric(matrix, tol)
    self.assert_diag_is_ones(matrix, rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_zero_inflation(self, grid_shape):
    length_scale = [0.2] * len(grid_shape)

    # Built and evaluated in order: no inflation, 0.5, then 1.0.
    matrix_no_inflation, matrix_inflation_one_half, matrix_inflation_one = [
        self._dense_from_kernel_args(
            grid_shape, length_scale=length_scale, zero_inflation=inflation)
        for inflation in (None, 0.5, 1.0)
    ]

    tol = np.finfo(matrix_no_inflation.dtype).eps * np.prod(grid_shape)

    # In all cases, matrix should be real and symmetric.
    for m in (matrix_no_inflation, matrix_inflation_one,
              matrix_inflation_one_half):
      self.assert_real_symmetric(m, tol)

    # In all cases, the diagonal should be all ones.
    for m in (matrix_no_inflation, matrix_inflation_one_half,
              matrix_inflation_one):
      self.assert_diag_is_ones(m, rtol=tol)

    def off_diagonal(m):
      """Returns `m` with its diagonal set to zero."""
      return m - np.diag(np.diag(m))

    # Inflation = 0.5 means the off-diagonal is deflated by factor (1 - .5) = .5
    self.assertAllClose(
        off_diagonal(matrix_no_inflation) * 0.5,
        off_diagonal(matrix_inflation_one_half), rtol=tol)

    # Inflation = 1.0 means the off-diagonal is deflated by factor (1 - 1) = 0
    self.assertAllClose(
        np.zeros_like(matrix_inflation_one),
        off_diagonal(matrix_inflation_one), rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_tiny_scale_corresponds_to_identity_matrix(self, grid_shape):
    matrix = self._dense_from_kernel_args(
        grid_shape, length_scale=[0.001] * len(grid_shape), power=2)

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape)
    self.assertAllClose(matrix, np.eye(np.prod(grid_shape)), atol=tol)
    self.assert_real_symmetric(matrix, tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_huge_scale_corresponds_to_ones_matrix(self, grid_shape):
    matrix = self._dense_from_kernel_args(
        grid_shape, length_scale=[100.] * len(grid_shape), power=2)

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape) * 50
    self.assert_real_symmetric(matrix, tol)
    self.assertAllClose(np.ones_like(matrix), matrix, rtol=tol)
203
204
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculantBaseTest(object):
  """Common class for circulant tests."""

  # Per-dtype absolute/relative tolerances; presumably read by the shared
  # linear_operator_test_util harness (not visible here) — TODO confirm.
  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Delegates to `test.TestCase._constrain_devices_and_set_default`.

    No additional device constraints or op overrides are applied here; the
    base `test.TestCase` implementation is called explicitly and its session
    is yielded unchanged.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  def _shape_to_spectrum_shape(self, shape):
    """Returns the spectrum shape for an operator of shape `shape`.

    If spectrum.shape = batch_shape + [N], the operator has shape
    batch_shape + [N, N], so the spectrum shape is `shape[:-1]`.
    """
    return shape[:-1]

  def _spectrum_to_circulant_1d(self, spectrum, shape, dtype):
    """Creates a circulant matrix from a spectrum.

    Intentionally done in an explicit yet inefficient way.  This provides a
    cross check to the main code that uses fancy reshapes.

    Args:
      spectrum: Float or complex `Tensor`.
      shape:  Python list.  Desired shape of returned matrix.
      dtype:  Type to cast the returned matrix to.

    Returns:
      Circulant (batch) matrix of desired `dtype`.
    """
    spectrum = _to_complex(spectrum)
    spectrum_shape = self._shape_to_spectrum_shape(shape)
    domain_dimension = spectrum_shape[-1]
    if not domain_dimension:
      return array_ops.zeros(shape, dtype)

    # Explicitly compute the action of spectrum on basis vectors. The result
    # for basis vector e_m becomes column m of the matrix (stacked on the
    # last axis below).
    matrix_columns = []
    for m in range(domain_dimension):
      x = np.zeros([domain_dimension])
      # x is a basis vector.
      x[m] = 1.0
      fft_x = fft_ops.fft(math_ops.cast(x, spectrum.dtype))
      h_convolve_x = fft_ops.ifft(spectrum * fft_x)
      matrix_columns.append(h_convolve_x)
    matrix = array_ops.stack(matrix_columns, axis=-1)
    return math_ops.cast(matrix, dtype)
267
268
class LinearOperatorCirculantTestSelfAdjointOperator(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when operator is self-adjoint.

  Real spectrum <==> Self adjoint operator.
  Note that when the spectrum is real, the operator may still be complex.
  """

  @staticmethod
  def dtypes_to_test():
    # A real spectrum still generally yields a complex-valued matrix, so only
    # complex dtypes are exercised.
    return [dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    pd_hint = True if ensure_self_adjoint_and_pd else None

    # Real spectrum drawn with minval=1, maxval=2 so that it (and hence the
    # eigenvalues) stays bounded away from zero.
    real_spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
    if ensure_self_adjoint_and_pd:
      real_spectrum = math_ops.abs(real_spectrum)
    # After the cast the imaginary part is zero, so the operator remains
    # self-adjoint.
    spectrum = math_ops.cast(real_spectrum, dtype)

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
    else:
      lin_op_spectrum = spectrum

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum,
        is_self_adjoint=True,
        is_positive_definite=pd_hint,
        input_output_dtype=dtype)
    expected = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
    return operator, expected

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      # For N = 3, Hermitian means spectrum[1] == conj(spectrum[2]).
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      imag_part = self.evaluate(math_ops.imag(operator.to_dense()))
      atol = np.finfo(np.float32).eps * 3
      np.testing.assert_allclose(0, imag_part, rtol=0, atol=atol)

  def test_tape_safe(self):
    operator = linalg.LinearOperatorCirculant(
        variables_module.Variable(
            math_ops.cast([1. + 0j, 1. + 0j], dtypes.complex64)),
        is_self_adjoint=True)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    spectrum_var = variables_module.Variable(
        math_ops.cast([1. + 0j, 1. + 0j], dtypes.complex64))
    operator = linalg.LinearOperatorCirculant(
        spectrum_var, is_self_adjoint=True)
    with self.cached_session() as sess:
      sess.run([spectrum_var.initializer])
      self.check_convert_variables_to_tensors(operator)
350
351
class LinearOperatorCirculantTestHermitianSpectrum(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when the spectrum is Hermitian.

  Hermitian spectrum <==> Real valued operator.  We test both real and complex
  dtypes here though.  So in some cases the matrix will be complex but with
  zero imaginary part.
  """

  def setUp(self):
    # Run the base TestCase fixture first; overriding setUp without calling
    # super() skips it.
    super().setUp()
    # Disable TensorFloat-32 for this test and remember the previous setting
    # so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    spectrum = _spectrum_for_symmetric_circulant(
        spectrum_shape=self._shape_to_spectrum_shape(shape),
        d=1,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        dtype=dtype)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum,
        input_output_dtype=dtype,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
    )

    mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)

    return operator, mat

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3)

  def test_tape_safe(self):
    spectrum = variables_module.Variable(
        math_ops.cast([1. + 0j, 1. + 1j], dtypes.complex64))
    operator = linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=False)
    self.check_tape_safe(operator)
422
423
class LinearOperatorCirculantTestNonHermitianSpectrum(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when the spectrum is not Hermitian.

  Non-Hermitian spectrum <==> Complex valued operator.
  We test only complex dtypes here.
  """

  @staticmethod
  def dtypes_to_test():
    return [dtypes.complex64, dtypes.complex128]

  # Skip the Cholesky and eigvalsh tests: both assume a Hermitian
  # (self-adjoint) operator, and this class explicitly uses non-Hermitian
  # spectra.
  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    del ensure_self_adjoint_and_pd
    shape = shape_info.shape
    # Will be well conditioned enough to get accurate solves.
    spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape),
        dtype=dtype,
        minval=1.,
        maxval=2.)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum, input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": True,
            "name": "LinearOperatorCirculant",
            "spectrum": lin_op_spectrum,
        })

    mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)

    return operator, mat

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3)

  def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
    with self.cached_session():
      spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix, matrix_h = self.evaluate(
          [operator.to_dense(),
           linalg.adjoint(operator.to_dense())])
      self.assertAllClose(matrix, matrix_h)
      self.evaluate(operator.assert_positive_definite())  # Should not fail
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_defining_operator_using_real_convolution_kernel(self):
    with self.cached_session():
      convolution_kernel = [1., 2., 1.]
      spectrum = fft_ops.fft(
          math_ops.cast(convolution_kernel, dtypes.complex64))

      # spectrum is shape [3] ==> operator is shape [3, 3]
      # spectrum is Hermitian ==> operator is real.
      operator = linalg.LinearOperatorCirculant(spectrum)

      # Allow for complex output so we can make sure it has zero imag part.
      self.assertEqual(operator.dtype, dtypes.complex64)

      matrix = self.evaluate(operator.to_dense())
      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)

  @test_util.run_v1_only("currently failing on v2")
  def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      # Make spectrum the FFT of a real convolution kernel h.  This ensures that
      # spectrum is Hermitian.
      h = linear_operator_test_util.random_normal(shape=(3, 4))
      spectrum = fft_ops.fft(math_ops.cast(h, dtypes.complex64))
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3 * 4)

  def test_convolution_kernel_same_as_first_row_of_to_dense(self):
    # NOTE: despite the test name, the kernel is compared against column 0 of
    # the dense matrix (`c[:, :, 0]`), i.e. the operator applied to e_0.
    spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
    with self.cached_session():
      operator = linalg.LinearOperatorCirculant(spectrum)
      h = operator.convolution_kernel()
      c = operator.to_dense()

      self.assertAllEqual((2, 3), h.shape)
      self.assertAllEqual((2, 3, 3), c.shape)
      self.assertAllClose(self.evaluate(h), self.evaluate(c)[:, :, 0])

  def test_assert_non_singular_fails_for_singular_operator(self):
    spectrum = math_ops.cast([0 + 0j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
    spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_positive_definite_fails_for_non_positive_definite(self):
    # The last eigenvalue 2j has zero real part.
    spectrum = math_ops.cast([6. + 0j, 4 + 0j, 2j], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_does_not_fail_when_pos_def(self):
    spectrum = math_ops.cast([6. + 0j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
    spectrum = [1., 2.]
    with self.assertRaisesRegex(ValueError, "real.*always.*self-adjoint"):
      linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=False)

  def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
    spectrum = [1., 2.]
    operator = linalg.LinearOperatorCirculant(spectrum)
    self.assertTrue(operator.is_self_adjoint)
590
591
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculant2DBaseTest(object):
  """Common class for 2D circulant tests."""

  # Per-dtype absolute/relative tolerances; presumably read by the shared
  # linear_operator_test_util harness (not visible here) — TODO confirm.
  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  # Operator shape -> spectrum shape for every entry of
  # `operator_shapes_infos`. A spectrum of shape batch_shape + [N0, N1]
  # yields an operator of shape batch_shape + [N0*N1, N0*N1].
  _SPECTRUM_SHAPE_FOR_OPERATOR_SHAPE = {
      (0, 0): (0, 0),
      (1, 1): (1, 1),
      (1, 6, 6): (1, 2, 3),
      (3, 4, 4): (3, 2, 2),
      (2, 1, 3, 3): (2, 1, 3, 1),
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Delegates to `test.TestCase._constrain_devices_and_set_default`.

    No additional device constraints or op overrides are applied here; the
    base `test.TestCase` implementation is called explicitly and its session
    is yielded unchanged.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def _shape_to_spectrum_shape(self, shape):
    """Get a spectrum shape that will make an operator of desired shape.

    Args:
      shape: Operator shape; one of the shapes in `operator_shapes_infos`.

    Returns:
      The spectrum shape tuple, batch_shape + [N0, N1].

    Raises:
      ValueError: If `shape` is not one of the known operator shapes.
    """
    try:
      return self._SPECTRUM_SHAPE_FOR_OPERATOR_SHAPE[tuple(shape)]
    except KeyError:
      # Wrap `shape` in a 1-tuple: formatting a bare tuple with `%` would
      # raise TypeError instead of producing this ValueError.
      raise ValueError("Unhandled shape: %s" % (shape,)) from None

  def _spectrum_to_circulant_2d(self, spectrum, shape, dtype):
    """Creates a block circulant matrix from a spectrum.

    Intentionally done in an explicit yet inefficient way.  This provides a
    cross check to the main code that uses fancy reshapes.

    Args:
      spectrum: Float or complex `Tensor` of shape batch_shape + [N0, N1].
      shape:  Python sequence.  Desired shape of returned matrix,
        batch_shape + [N0*N1, N0*N1].
      dtype:  Type to cast the returned matrix to.

    Returns:
      Block circulant (batch) matrix of desired `dtype`.
    """
    spectrum = _to_complex(spectrum)
    spectrum_shape = self._shape_to_spectrum_shape(shape)
    block_shape = spectrum_shape[-2:]
    # An empty grid (either dimension zero) means an empty operator.
    if not np.prod(block_shape):
      return array_ops.zeros(shape, dtype)

    # Explicitly compute the action of spectrum on basis vectors. The result
    # for each basis vector (enumerated in row-major order over the grid)
    # becomes one column of the matrix (stacked on the last axis below).
    matrix_columns = []
    for n0 in range(block_shape[0]):
      for n1 in range(block_shape[1]):
        x = np.zeros(block_shape)
        # x is a basis vector.
        x[n0, n1] = 1.0
        fft_x = fft_ops.fft2d(math_ops.cast(x, spectrum.dtype))
        h_convolve_x = fft_ops.ifft2d(spectrum * fft_x)
        # We want the flat version of the action of the operator on a basis
        # vector, not the block version.
        h_convolve_x = array_ops.reshape(h_convolve_x, shape[:-1])
        matrix_columns.append(h_convolve_x)
    matrix = array_ops.stack(matrix_columns, axis=-1)
    return math_ops.cast(matrix, dtype)
694
695
class LinearOperatorCirculant2DTestHermitianSpectrum(
    LinearOperatorCirculant2DBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant2D when the spectrum is Hermitian.

  Hermitian spectrum <==> Real valued operator.  We test both real and complex
  dtypes here though.  So in some cases the matrix will be complex but with
  zero imaginary part.
  """

  def setUp(self):
    # Run the base TestCase fixture first; overriding setUp without calling
    # super() skips it.
    super().setUp()
    # Disable TensorFloat-32 for this test and remember the previous setting
    # so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    spectrum = _spectrum_for_symmetric_circulant(
        spectrum_shape=self._shape_to_spectrum_shape(shape),
        d=2,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        dtype=dtype)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant2D(
        lin_op_spectrum,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": (
                True if ensure_self_adjoint_and_pd else None),
            "is_self_adjoint": (
                True if ensure_self_adjoint_and_pd else None),
            "is_square": True,
            "name": "LinearOperatorCirculant2D",
            "spectrum": lin_op_spectrum,
        })

    mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)

    return operator, mat
753
754
class LinearOperatorCirculant2DTestNonHermitianSpectrum(
    LinearOperatorCirculant2DBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant2D when the spectrum is not Hermitian.

  Non-Hermitian spectrum <==> Complex valued operator.
  We test only complex dtypes here.
  """

  @staticmethod
  def dtypes_to_test():
    return [dtypes.complex64, dtypes.complex128]

  # Skip the Cholesky and eigvalsh tests: both assume a Hermitian
  # (self-adjoint) operator, and this class explicitly uses non-Hermitian
  # spectra.
  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    del ensure_self_adjoint_and_pd
    shape = shape_info.shape
    # Will be well conditioned enough to get accurate solves.
    spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape),
        dtype=dtype,
        minval=1.,
        maxval=2.)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant2D(
        lin_op_spectrum, input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": True,
            "name": "LinearOperatorCirculant2D",
            "spectrum": lin_op_spectrum,
        }
    )

    mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)

    return operator, mat

  def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
    with self.cached_session():
      # This spectrum is real and Hermitian in the 2D sense:
      # spectrum[k0, k1] == conj(spectrum[-k0 % 3, -k1 % 3]). The 2D block
      # circulant operator it defines should therefore be real and symmetric.
      # (Previously this built a 1D LinearOperatorCirculant, which left the
      # 2D operator untested here.)
      spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
      operator = linalg.LinearOperatorCirculant2D(spectrum)

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_t = array_ops.matrix_transpose(matrix_tensor)
      imag_matrix = math_ops.imag(matrix_tensor)
      matrix, matrix_transpose, imag_matrix = self.evaluate(
          [matrix_tensor, matrix_t, imag_matrix])

      # Same 1e-5 tolerance as the 2D self-adjoint test below; presumably the
      # 2D FFT round-off in float32 needs more slack than the 1D tests.
      np.testing.assert_allclose(0, imag_matrix, atol=1e-5)
      self.assertAllClose(matrix, matrix_transpose, atol=1e-5)

  def test_real_spectrum_gives_self_adjoint_operator(self):
    with self.cached_session():
      # A real (but generally not Hermitian) spectrum: the operator is
      # self-adjoint, though not necessarily real-valued.
      spectrum = linear_operator_test_util.random_normal(
          shape=(3, 3), dtype=dtypes.float32)
      operator = linalg.LinearOperatorCirculant2D(spectrum)

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_h = linalg.adjoint(matrix_tensor)
      matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
      self.assertAllClose(matrix, matrix_h, atol=1e-5)

  def test_assert_non_singular_fails_for_singular_operator(self):
    spectrum = math_ops.cast([[0 + 0j, 4 + 0j], [2j + 2, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
    spectrum = math_ops.cast([[-3j, 4 + 0j], [2j + 2, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_non_singular())  # Should not fail
854
855  def test_assert_positive_definite_fails_for_non_positive_definite(self):
856    spectrum = math_ops.cast([[6. + 0j, 4 + 0j], [2j, 3. + 0j]],
857                             dtypes.complex64)
858    operator = linalg.LinearOperatorCirculant2D(spectrum)
859    with self.cached_session():
860      with self.assertRaisesOpError("Not positive definite"):
861        self.evaluate(operator.assert_positive_definite())
862
863  def test_assert_positive_definite_does_not_fail_when_pos_def(self):
864    spectrum = math_ops.cast([[6. + 0j, 4 + 0j], [2j + 2, 3. + 0j]],
865                             dtypes.complex64)
866    operator = linalg.LinearOperatorCirculant2D(spectrum)
867    with self.cached_session():
868      self.evaluate(operator.assert_positive_definite())  # Should not fail
869
870  def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
871    spectrum = [[1., 2.], [3., 4]]
872    with self.assertRaisesRegex(ValueError, "real.*always.*self-adjoint"):
873      linalg.LinearOperatorCirculant2D(spectrum, is_self_adjoint=False)
874
875  def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
876    spectrum = [[1., 2.], [3., 4]]
877    operator = linalg.LinearOperatorCirculant2D(spectrum)
878    self.assertTrue(operator.is_self_adjoint)
879
880  def test_invalid_rank_raises(self):
881    spectrum = array_ops.constant(np.float32(rng.rand(2)))
882    with self.assertRaisesRegex(ValueError, "must have at least 2 dimensions"):
883      linalg.LinearOperatorCirculant2D(spectrum)
884
885  def test_tape_safe(self):
886    spectrum = variables_module.Variable(
887        math_ops.cast([[1. + 0j, 1. + 0j], [1. + 1j, 2. + 2j]],
888                      dtypes.complex64))
889    operator = linalg.LinearOperatorCirculant2D(spectrum)
890    self.check_tape_safe(operator)
891
892
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculant3DTest(test.TestCase):
  """Simple test of the 3D case.  See also the 1D and 2D tests."""

  # Per-dtype tolerances.
  # NOTE(review): no test in this class references _atol/_rtol as written;
  # confirm whether they are still needed.
  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Pass-through override of the TestCase device/session context.

    Delegates unchanged to `test.TestCase._constrain_devices_and_set_default`
    and yields the session it produces; nothing else is changed.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  def test_real_spectrum_gives_self_adjoint_operator(self):
    with self.cached_session():
      # This is a real (but not, in general, Hermitian) spectrum.
      # Real spectrum ==> self-adjoint operator.
      spectrum = linear_operator_test_util.random_normal(
          shape=(2, 2, 3, 5), dtype=dtypes.float32)
      operator = linalg.LinearOperatorCirculant3D(spectrum)
      # Leading dim is batch; the operator acts on 2 * 3 * 5 = 30 points.
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)

      self.assertEqual(
          operator.parameters,
          {
              "input_output_dtype": dtypes.complex64,
              "is_non_singular": None,
              "is_positive_definite": None,
              "is_self_adjoint": None,
              "is_square": True,
              "name": "LinearOperatorCirculant3D",
              "spectrum": spectrum,
          })

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_h = linalg.adjoint(matrix_tensor)

      matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
      self.assertAllClose(matrix, matrix_h)

  def test_defining_operator_using_real_convolution_kernel(self):
    with self.cached_session():
      convolution_kernel = linear_operator_test_util.random_normal(
          shape=(2, 2, 3, 5), dtype=dtypes.float32)
      # Convolution kernel is real ==> spectrum is Hermitian.
      spectrum = fft_ops.fft3d(
          math_ops.cast(convolution_kernel, dtypes.complex64))

      # spectrum is Hermitian ==> operator is real.
      operator = linalg.LinearOperatorCirculant3D(spectrum)
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)

      # Allow for complex output so we can make sure it has zero imag part.
      self.assertEqual(operator.dtype, dtypes.complex64)
      matrix = self.evaluate(operator.to_dense())
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)

  def test_defining_spd_operator_by_taking_real_part(self):
    with self.cached_session():
      # S is real and positive.
      s = linear_operator_test_util.random_uniform(
          shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)

      # Let S = S1 + S2, the Hermitian and anti-hermitian parts.
      # S1 = 0.5 * (S + S^H), S2 = 0.5 * (S - S^H),
      # where ^H is the Hermitian transpose of the function:
      #    f(n0, n1, n2)^H := ComplexConjugate[f(N0-n0, N1-n1, N2-n2)].
      # We want to isolate S1, since
      #   S1 is Hermitian by construction
      #   S1 is real since S is
      #   S1 is positive since it is the sum of two positive kernels

      # IDFT[S] = IDFT[S1] + IDFT[S2]
      #         =      H1  +      H2
      # where H1 is real since it is Hermitian,
      # and H2 is imaginary since it is anti-Hermitian.
      ifft_s = fft_ops.ifft3d(math_ops.cast(s, dtypes.complex64))

      # Throw away H2, keep H1.
      real_ifft_s = math_ops.real(ifft_s)

      # This is the perfect spectrum!
      # spectrum = DFT[H1]
      #          = S1,
      fft_real_ifft_s = fft_ops.fft3d(
          math_ops.cast(real_ifft_s, dtypes.complex64))

      # S1 is Hermitian ==> operator is real.
      # S1 is real ==> operator is self-adjoint.
      # S1 is positive ==> operator is positive-definite.
      operator = linalg.LinearOperatorCirculant3D(fft_real_ifft_s)

      # Allow for complex output so we can check operator has zero imag part.
      self.assertEqual(operator.dtype, dtypes.complex64)
      # Build the dense matrix once and derive its transpose from it.
      dense = operator.to_dense()
      matrix, matrix_t = self.evaluate(
          [dense, array_ops.matrix_transpose(dense)])
      self.evaluate(operator.assert_positive_definite())  # Should not fail.
      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
      self.assertAllClose(matrix, matrix_t)

      # Just to test the theory, get S2 as well.
      # This should create an imaginary operator.
      # S2 is anti-Hermitian ==> operator is imaginary.
      # S2 is real ==> operator is self-adjoint.
      imag_ifft_s = math_ops.imag(ifft_s)
      fft_imag_ifft_s = fft_ops.fft3d(
          1j * math_ops.cast(imag_ifft_s, dtypes.complex64))
      operator_imag = linalg.LinearOperatorCirculant3D(fft_imag_ifft_s)

      dense_imag = operator_imag.to_dense()
      matrix, matrix_h = self.evaluate(
          [dense_imag, linalg.adjoint(dense_imag)])
      self.assertAllClose(matrix, matrix_h)
      np.testing.assert_allclose(0, np.real(matrix), atol=1e-7)
1025
1026
if __name__ == "__main__":
  # Call linear_operator_test_util.add_tests on each derived-class operator
  # test (presumably this attaches the shared generated test methods), in the
  # same order as before, then hand control to the test runner.
  # LinearOperatorCirculant3DTest is a plain test.TestCase and needs no
  # add_tests call.
  for _operator_test_class in (
      LinearOperatorCirculantTestSelfAdjointOperator,
      LinearOperatorCirculantTestHermitianSpectrum,
      LinearOperatorCirculantTestNonHermitianSpectrum,
      LinearOperatorCirculant2DTestHermitianSpectrum,
      LinearOperatorCirculant2DTestNonHermitianSpectrum,
  ):
    linear_operator_test_util.add_tests(_operator_test_class)
  test.main()
1039