# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the circulant linear operators in `linear_operator_circulant`."""
import contextlib

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.platform import test

rng = np.random.RandomState(0)
_to_complex = linear_operator_circulant._to_complex

exponential_power_convolution_kernel = (
    linear_operator_circulant.exponential_power_convolution_kernel)


def _operator_from_kernel(kernel, d, **kwargs):
  """Builds a `d`-dimensional circulant operator from a convolution kernel.

  Args:
    kernel: Real or complex `Tensor`. It is cast to `complex64` and its trailing
      `d` dimensions are transformed with the `d`-dimensional FFT.
    d: Python integer, number of circulant dimensions (1, 2 or 3).
    **kwargs: Forwarded to the operator constructor.

  Returns:
    A `LinearOperatorCirculant`, `LinearOperatorCirculant2D` or
    `LinearOperatorCirculant3D`, selected by `d`.

  Raises:
    ValueError: If `d` is not 1, 2 or 3.
  """
  operator_classes = {
      1: linear_operator_circulant.LinearOperatorCirculant,
      2: linear_operator_circulant.LinearOperatorCirculant2D,
      3: linear_operator_circulant.LinearOperatorCirculant3D,
  }
  if d not in operator_classes:
    raise ValueError(
        "Unsupported number of circulant dimensions d=%s; expected 1, 2 or 3."
        % (d,))
  spectrum = linear_operator_circulant._FFT_OP[d](
      math_ops.cast(kernel, dtypes.complex64))
  return operator_classes[d](spectrum, **kwargs)


def _spectrum_for_symmetric_circulant(
    spectrum_shape,
    d,
    ensure_self_adjoint_and_pd,
    dtype,
):
  """Spectrum for d-dimensional real/symmetric circulant.

  Args:
    spectrum_shape: Shape of the returned spectrum. The trailing `d` entries
      are the grid shape; the spectrum is broadcast over any leading entries.
    d: Number of circulant dimensions.
    ensure_self_adjoint_and_pd: If `True`, the kernel uses `power=1` and
      `zero_inflation=0.2`; otherwise `power=2` with no inflation, which
      produces some negative spectrum values (see comment below).
    dtype: `DType` of the returned spectrum.

  Returns:
    `Tensor` of shape `spectrum_shape` and dtype `dtype`.
  """
  grid_shape = spectrum_shape[-d:]

  # Normalize to a tuple so list-valued shapes also hit the empty-grid path.
  if tuple(grid_shape) == (0,) * d:
    kernel = array_ops.reshape(math_ops.cast([], dtype), grid_shape)
  else:
    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        # power=2 with this scale and no inflation will have some negative
        # spectra. It will still be real/symmetric.
        length_scale=math_ops.cast([0.2] * d, dtype.real_dtype),
        power=1 if ensure_self_adjoint_and_pd else 2,
        zero_inflation=0.2 if ensure_self_adjoint_and_pd else None,
    )
  spectrum = linear_operator_circulant._FFT_OP[d](_to_complex(kernel))
  spectrum = math_ops.cast(spectrum, dtype)
  return array_ops.broadcast_to(spectrum, spectrum_shape)


@test_util.run_all_in_graph_and_eager_modes
class ExponentialPowerConvolutionKernelTest(parameterized.TestCase,
                                            test.TestCase):
  """Tests for `exponential_power_convolution_kernel`."""

  def assert_diag_is_ones(self, matrix, rtol):
    """Asserts every diagonal entry of `matrix` is close to one."""
    self.assertAllClose(
        np.ones_like(np.diag(matrix)), np.diag(matrix), rtol=rtol)

  def assert_real_symmetric(self, matrix, tol):
    """Asserts zero imaginary part and symmetric real part, up to `tol`."""
    self.assertAllClose(np.zeros_like(matrix.imag), matrix.imag, atol=tol)
    self.assertAllClose(matrix.real, matrix.real.T, rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1Deven_power1", grid_shape=[10], power=1.),
      dict(testcase_name="2Deven_power1", grid_shape=[4, 6], power=1.),
      dict(testcase_name="3Deven_power1", grid_shape=[4, 6, 8], power=1.),
      dict(testcase_name="3Devenodd_power1", grid_shape=[4, 5, 7], power=1.),
      dict(testcase_name="1Dodd_power2", grid_shape=[9], power=2.),
      dict(testcase_name="2Deven_power2", grid_shape=[8, 4], power=2.),
      dict(testcase_name="3Devenodd_power2", grid_shape=[4, 5, 3], power=2.),
  )
  def test_makes_symmetric_and_real_circulant_with_ones_diag(
      self, grid_shape, power):
    d = len(grid_shape)
    length_scale = [0.2] * d
    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        length_scale=length_scale,
        power=power)
    operator = _operator_from_kernel(kernel, d)

    matrix = self.evaluate(operator.to_dense())

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape)
    self.assert_real_symmetric(matrix, tol)
    self.assert_diag_is_ones(matrix, rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_zero_inflation(self, grid_shape):
    d = len(grid_shape)
    length_scale = [0.2] * d

    kernel_no_inflation = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        length_scale=length_scale,
        zero_inflation=None,
    )
    matrix_no_inflation = self.evaluate(
        _operator_from_kernel(kernel_no_inflation, d).to_dense())

    kernel_inflation_one_half = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        length_scale=length_scale,
        zero_inflation=0.5,
    )
    matrix_inflation_one_half = self.evaluate(
        _operator_from_kernel(kernel_inflation_one_half, d).to_dense())

    kernel_inflation_one = exponential_power_convolution_kernel(
        grid_shape=grid_shape,
        length_scale=length_scale,
        zero_inflation=1.0,
    )
    matrix_inflation_one = self.evaluate(
        _operator_from_kernel(kernel_inflation_one, d).to_dense())

    tol = np.finfo(matrix_no_inflation.dtype).eps * np.prod(grid_shape)

    # In all cases, matrix should be real and symmetric.
    self.assert_real_symmetric(matrix_no_inflation, tol)
    self.assert_real_symmetric(matrix_inflation_one, tol)
    self.assert_real_symmetric(matrix_inflation_one_half, tol)

    # In all cases, the diagonal should be all ones.
    self.assert_diag_is_ones(matrix_no_inflation, rtol=tol)
    self.assert_diag_is_ones(matrix_inflation_one_half, rtol=tol)
    self.assert_diag_is_ones(matrix_inflation_one, rtol=tol)

    def _matrix_with_zerod_diag(matrix):
      return matrix - np.diag(np.diag(matrix))

    # Inflation = 0.5 means the off-diagonal is deflated by factor (1 - .5) = .5
    self.assertAllClose(
        _matrix_with_zerod_diag(matrix_no_inflation) * 0.5,
        _matrix_with_zerod_diag(matrix_inflation_one_half), rtol=tol)

    # Inflation = 1.0 means the off-diagonal is deflated by factor (1 - 1) = 0
    self.assertAllClose(
        np.zeros_like(matrix_inflation_one),
        _matrix_with_zerod_diag(matrix_inflation_one), rtol=tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_tiny_scale_corresponds_to_identity_matrix(self, grid_shape):
    d = len(grid_shape)

    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape, length_scale=[0.001] * d, power=2)
    matrix = self.evaluate(_operator_from_kernel(kernel, d).to_dense())

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape)
    self.assertAllClose(matrix, np.eye(np.prod(grid_shape)), atol=tol)
    self.assert_real_symmetric(matrix, tol)

  @parameterized.named_parameters(
      dict(testcase_name="1D", grid_shape=[10]),
      dict(testcase_name="2D", grid_shape=[5, 5]),
      dict(testcase_name="3D", grid_shape=[5, 4, 3]),
  )
  def test_huge_scale_corresponds_to_ones_matrix(self, grid_shape):
    d = len(grid_shape)

    kernel = exponential_power_convolution_kernel(
        grid_shape=grid_shape, length_scale=[100.] * d, power=2)
    matrix = self.evaluate(_operator_from_kernel(kernel, d).to_dense())

    tol = np.finfo(matrix.dtype).eps * np.prod(grid_shape) * 50
    self.assert_real_symmetric(matrix, tol)
    self.assertAllClose(np.ones_like(matrix), matrix, rtol=tol)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculantBaseTest(object):
  """Common class for circulant tests."""

  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Pass-through to `test.TestCase._constrain_devices_and_set_default`.

    This override adds no behavior of its own; it only delegates.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  def _shape_to_spectrum_shape(self, shape):
    # If spectrum.shape = batch_shape + [N],
    # this creates an operator of shape batch_shape + [N, N]
    return shape[:-1]

  def _spectrum_to_circulant_1d(self, spectrum, shape, dtype):
    """Creates a circulant matrix from a spectrum.

    Intentionally done in an explicit yet inefficient way.  This provides a
    cross check to the main code that uses fancy reshapes.

    Args:
      spectrum: Float or complex `Tensor`.
      shape:  Python list.  Desired shape of returned matrix.
      dtype:  Type to cast the returned matrix to.

    Returns:
      Circulant (batch) matrix of desired `dtype`.
    """
    spectrum = _to_complex(spectrum)
    spectrum_shape = self._shape_to_spectrum_shape(shape)
    domain_dimension = spectrum_shape[-1]
    if not domain_dimension:
      return array_ops.zeros(shape, dtype)

    # Explicitly compute the action of spectrum on basis vectors.  The action
    # on basis vector e_m is column m of the matrix, hence stacking on axis -1.
    columns = []
    for m in range(domain_dimension):
      x = np.zeros([domain_dimension])
      # x is a basis vector.
      x[m] = 1.0
      fft_x = fft_ops.fft(math_ops.cast(x, spectrum.dtype))
      h_convolve_x = fft_ops.ifft(spectrum * fft_x)
      columns.append(h_convolve_x)
    matrix = array_ops.stack(columns, axis=-1)
    return math_ops.cast(matrix, dtype)


class LinearOperatorCirculantTestSelfAdjointOperator(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when operator is self-adjoint.

  Real spectrum <==> Self adjoint operator.
  Note that when the spectrum is real, the operator may still be complex.
  """

  @staticmethod
  def dtypes_to_test():
    # This operator will always be complex because, although the spectrum is
    # real, the matrix will not be real.
    return [dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    # For this test class we create real spectra whose entries are bounded
    # away from zero (minval=1.), so the operator is well conditioned.
    spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
    if ensure_self_adjoint_and_pd:
      spectrum = math_ops.abs(spectrum)
    # If dtype is complex, cast spectrum to complex.  The imaginary part will be
    # zero, so the operator will still be self-adjoint.
    spectrum = math_ops.cast(spectrum, dtype)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum,
        is_self_adjoint=True,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        input_output_dtype=dtype)

    mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)

    return operator, mat

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3)

  def test_tape_safe(self):
    spectrum = variables_module.Variable(
        math_ops.cast([1. + 0j, 1. + 0j], dtypes.complex64))
    operator = linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=True)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    spectrum = variables_module.Variable(
        math_ops.cast([1. + 0j, 1. + 0j], dtypes.complex64))
    operator = linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=True)
    with self.cached_session() as sess:
      sess.run([spectrum.initializer])
      self.check_convert_variables_to_tensors(operator)


class LinearOperatorCirculantTestHermitianSpectrum(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when the spectrum is Hermitian.

  Hermitian spectrum <==> Real valued operator.  We test both real and complex
  dtypes here though.  So in some cases the matrix will be complex but with
  zero imaginary part.
  """

  def setUp(self):
    super().setUp()
    # TensorFloat-32 lowers matmul precision on supported hardware; disable it
    # for these tolerances and restore the previous setting in tearDown.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    spectrum = _spectrum_for_symmetric_circulant(
        spectrum_shape=self._shape_to_spectrum_shape(shape),
        d=1,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        dtype=dtype)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum,
        input_output_dtype=dtype,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
    )

    mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)

    return operator, mat

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3)

  def test_tape_safe(self):
    spectrum = variables_module.Variable(
        math_ops.cast([1. + 0j, 1. + 1j], dtypes.complex64))
    operator = linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=False)
    self.check_tape_safe(operator)


class LinearOperatorCirculantTestNonHermitianSpectrum(
    LinearOperatorCirculantBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant when the spectrum is not Hermitian.

  Non-Hermitian spectrum <==> Complex valued operator.
  We test only complex dtypes here.
  """

  @staticmethod
  def dtypes_to_test():
    return [dtypes.complex64, dtypes.complex128]

  # Skip Cholesky since we are explicitly testing non-hermitian
  # spectra.
  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    del ensure_self_adjoint_and_pd
    shape = shape_info.shape
    # Will be well conditioned enough to get accurate solves.
    spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape),
        dtype=dtype,
        minval=1.,
        maxval=2.)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant(
        lin_op_spectrum, input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": True,
            "name": "LinearOperatorCirculant",
            "spectrum": lin_op_spectrum,
        })

    mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)

    return operator, mat

  @test_util.disable_xla("No registered Const")
  def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      spectrum = math_ops.cast([1. + 0j, 1j, -1j], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3)

  def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
    with self.cached_session() as sess:
      spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix, matrix_h = sess.run(
          [operator.to_dense(),
           linalg.adjoint(operator.to_dense())])
      self.assertAllClose(matrix, matrix_h)
      self.evaluate(operator.assert_positive_definite())  # Should not fail
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_defining_operator_using_real_convolution_kernel(self):
    with self.cached_session():
      convolution_kernel = [1., 2., 1.]
      spectrum = fft_ops.fft(
          math_ops.cast(convolution_kernel, dtypes.complex64))

      # spectrum is shape [3] ==> operator is shape [3, 3]
      # spectrum is Hermitian ==> operator is real.
      operator = linalg.LinearOperatorCirculant(spectrum)

      # Allow for complex output so we can make sure it has zero imag part.
      self.assertEqual(operator.dtype, dtypes.complex64)

      matrix = self.evaluate(operator.to_dense())
      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)

  @test_util.run_v1_only("currently failing on v2")
  def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
    with self.cached_session():
      # Make spectrum the FFT of a real convolution kernel h.  This ensures that
      # spectrum is Hermitian.
      h = linear_operator_test_util.random_normal(shape=(3, 4))
      spectrum = fft_ops.fft(math_ops.cast(h, dtypes.complex64))
      operator = linalg.LinearOperatorCirculant(
          spectrum, input_output_dtype=dtypes.complex64)
      matrix = operator.to_dense()
      imag_matrix = math_ops.imag(matrix)
      eps = np.finfo(np.float32).eps
      np.testing.assert_allclose(
          0, self.evaluate(imag_matrix), rtol=0, atol=eps * 3 * 4)

  def test_convolution_kernel_same_as_first_row_of_to_dense(self):
    spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
    with self.cached_session():
      operator = linalg.LinearOperatorCirculant(spectrum)
      h = operator.convolution_kernel()
      c = operator.to_dense()

      self.assertAllEqual((2, 3), h.shape)
      self.assertAllEqual((2, 3, 3), c.shape)
      self.assertAllClose(self.evaluate(h), self.evaluate(c)[:, :, 0])

  def test_assert_non_singular_fails_for_singular_operator(self):
    spectrum = math_ops.cast([0 + 0j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
    spectrum = math_ops.cast([-3j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_positive_definite_fails_for_non_positive_definite(self):
    spectrum = math_ops.cast([6. + 0j, 4 + 0j, 2j], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_does_not_fail_when_pos_def(self):
    spectrum = math_ops.cast([6. + 0j, 4 + 0j, 2j + 2], dtypes.complex64)
    operator = linalg.LinearOperatorCirculant(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
    spectrum = [1., 2.]
    with self.assertRaisesRegex(ValueError, "real.*always.*self-adjoint"):
      linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=False)

  def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
    spectrum = [1., 2.]
    operator = linalg.LinearOperatorCirculant(spectrum)
    self.assertTrue(operator.is_self_adjoint)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculant2DBaseTest(object):
  """Common class for 2D circulant tests."""

  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Pass-through to `test.TestCase._constrain_devices_and_set_default`.

    This override adds no behavior of its own; it only delegates.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def _shape_to_spectrum_shape(self, shape):
    """Get a spectrum shape that will make an operator of desired shape."""
    # This 2D block circulant operator takes a spectrum of shape
    #   batch_shape + [N0, N1],
    # and creates an operator of shape
    #   batch_shape + [N0*N1, N0*N1]
    if shape == (0, 0):
      return (0, 0)
    elif shape == (1, 1):
      return (1, 1)
    elif shape == (1, 6, 6):
      return (1, 2, 3)
    elif shape == (3, 4, 4):
      return (3, 2, 2)
    elif shape == (2, 1, 3, 3):
      return (2, 1, 3, 1)
    else:
      raise ValueError("Unhandled shape: %s" % shape)

  def _spectrum_to_circulant_2d(self, spectrum, shape, dtype):
    """Creates a block circulant matrix from a spectrum.

    Intentionally done in an explicit yet inefficient way.  This provides a
    cross check to the main code that uses fancy reshapes.

    Args:
      spectrum: Float or complex `Tensor`.
      shape:  Python list.  Desired shape of returned matrix.
      dtype:  Type to cast the returned matrix to.

    Returns:
      Block circulant (batch) matrix of desired `dtype`.
    """
    spectrum = _to_complex(spectrum)
    spectrum_shape = self._shape_to_spectrum_shape(shape)
    domain_dimension = spectrum_shape[-1]
    if not domain_dimension:
      return array_ops.zeros(shape, dtype)

    block_shape = spectrum_shape[-2:]

    # Explicitly compute the action of spectrum on basis vectors.  Each action
    # is one column of the matrix, hence stacking on axis -1.
    columns = []
    for n0 in range(block_shape[0]):
      for n1 in range(block_shape[1]):
        x = np.zeros(block_shape)
        # x is a basis vector.
        x[n0, n1] = 1.0
        fft_x = fft_ops.fft2d(math_ops.cast(x, spectrum.dtype))
        h_convolve_x = fft_ops.ifft2d(spectrum * fft_x)
        # We want the flat version of the action of the operator on a basis
        # vector, not the block version.
        h_convolve_x = array_ops.reshape(h_convolve_x, shape[:-1])
        columns.append(h_convolve_x)
    matrix = array_ops.stack(columns, axis=-1)
    return math_ops.cast(matrix, dtype)


class LinearOperatorCirculant2DTestHermitianSpectrum(
    LinearOperatorCirculant2DBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant2D when the spectrum is Hermitian.

  Hermitian spectrum <==> Real valued operator.  We test both real and complex
  dtypes here though.  So in some cases the matrix will be complex but with
  zero imaginary part.
  """

  def setUp(self):
    super().setUp()
    # TensorFloat-32 lowers matmul precision on supported hardware; disable it
    # for these tolerances and restore the previous setting in tearDown.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    shape = shape_info.shape
    spectrum = _spectrum_for_symmetric_circulant(
        spectrum_shape=self._shape_to_spectrum_shape(shape),
        d=2,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        dtype=dtype)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant2D(
        lin_op_spectrum,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": (
                True if ensure_self_adjoint_and_pd else None),
            "is_self_adjoint": (
                True if ensure_self_adjoint_and_pd else None),
            "is_square": True,
            "name": "LinearOperatorCirculant2D",
            "spectrum": lin_op_spectrum,
        })

    mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)

    return operator, mat


class LinearOperatorCirculant2DTestNonHermitianSpectrum(
    LinearOperatorCirculant2DBaseTest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Test of LinearOperatorCirculant2D when the spectrum is not Hermitian.

  Non-Hermitian spectrum <==> Complex valued operator.
  We test only complex dtypes here.
  """

  @staticmethod
  def dtypes_to_test():
    return [dtypes.complex64, dtypes.complex128]

  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    del ensure_self_adjoint_and_pd
    shape = shape_info.shape
    # Will be well conditioned enough to get accurate solves.
    spectrum = linear_operator_test_util.random_sign_uniform(
        shape=self._shape_to_spectrum_shape(shape),
        dtype=dtype,
        minval=1.,
        maxval=2.)

    lin_op_spectrum = spectrum

    if use_placeholder:
      lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)

    operator = linalg.LinearOperatorCirculant2D(
        lin_op_spectrum, input_output_dtype=dtype)

    self.assertEqual(
        operator.parameters,
        {
            "input_output_dtype": dtype,
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": True,
            "name": "LinearOperatorCirculant2D",
            "spectrum": lin_op_spectrum,
        }
    )

    mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)

    return operator, mat

  def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
    with self.cached_session():  # Necessary for fft_kernel_label_map
      # This is a real and hermitian spectrum (in the 2D sense:
      # s[k0, k1] == conj(s[-k0 mod 3, -k1 mod 3])).
      spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
      # This class tests the 2D operator, so build the 2D operator here.
      operator = linalg.LinearOperatorCirculant2D(spectrum)

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_t = array_ops.matrix_transpose(matrix_tensor)
      imag_matrix = math_ops.imag(matrix_tensor)
      matrix, matrix_transpose, imag_matrix = self.evaluate(
          [matrix_tensor, matrix_t, imag_matrix])

      np.testing.assert_allclose(0, imag_matrix, atol=1e-6)
      self.assertAllClose(matrix, matrix_transpose, atol=1e-6)

  def test_real_spectrum_gives_self_adjoint_operator(self):
    with self.cached_session():
      # This is a real (not necessarily Hermitian) spectrum.
      spectrum = linear_operator_test_util.random_normal(
          shape=(3, 3), dtype=dtypes.float32)
      operator = linalg.LinearOperatorCirculant2D(spectrum)

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_h = linalg.adjoint(matrix_tensor)
      matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
      self.assertAllClose(matrix, matrix_h, atol=1e-5)

  def test_assert_non_singular_fails_for_singular_operator(self):
    spectrum = math_ops.cast([[0 + 0j, 4 + 0j], [2j + 2, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
    spectrum = math_ops.cast([[-3j, 4 + 0j], [2j + 2, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_positive_definite_fails_for_non_positive_definite(self):
    spectrum = math_ops.cast([[6. + 0j, 4 + 0j], [2j, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      with self.assertRaisesOpError("Not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_does_not_fail_when_pos_def(self):
    spectrum = math_ops.cast([[6. + 0j, 4 + 0j], [2j + 2, 3. + 0j]],
                             dtypes.complex64)
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    with self.cached_session():
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
    spectrum = [[1., 2.], [3., 4]]
    with self.assertRaisesRegex(ValueError, "real.*always.*self-adjoint"):
      linalg.LinearOperatorCirculant2D(spectrum, is_self_adjoint=False)

  def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
    spectrum = [[1., 2.], [3., 4]]
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    self.assertTrue(operator.is_self_adjoint)

  def test_invalid_rank_raises(self):
    spectrum = array_ops.constant(np.float32(rng.rand(2)))
    with self.assertRaisesRegex(ValueError, "must have at least 2 dimensions"):
      linalg.LinearOperatorCirculant2D(spectrum)

  def test_tape_safe(self):
    spectrum = variables_module.Variable(
        math_ops.cast([[1. + 0j, 1. + 0j], [1. + 1j, 2. + 2j]],
                      dtypes.complex64))
    operator = linalg.LinearOperatorCirculant2D(spectrum)
    self.check_tape_safe(operator)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorCirculant3DTest(test.TestCase):
  """Simple test of the 3D case.  See also the 1D and 2D tests."""

  _atol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }
  _rtol = {
      dtypes.float16: 1e-3,
      dtypes.float32: 1e-6,
      dtypes.float64: 1e-7,
      dtypes.complex64: 1e-6,
      dtypes.complex128: 1e-7
  }

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Pass-through to `test.TestCase._constrain_devices_and_set_default`.

    This override adds no behavior of its own; it only delegates.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  def test_real_spectrum_gives_self_adjoint_operator(self):
    with self.cached_session():
      # This is a real (not necessarily Hermitian) spectrum.
      spectrum = linear_operator_test_util.random_normal(
          shape=(2, 2, 3, 5), dtype=dtypes.float32)
      operator = linalg.LinearOperatorCirculant3D(spectrum)
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)

      self.assertEqual(
          operator.parameters,
          {
              "input_output_dtype": dtypes.complex64,
              "is_non_singular": None,
              "is_positive_definite": None,
              "is_self_adjoint": None,
              "is_square": True,
              "name": "LinearOperatorCirculant3D",
              "spectrum": spectrum,
          })

      matrix_tensor = operator.to_dense()
      self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
      matrix_h = linalg.adjoint(matrix_tensor)

      matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
      self.assertAllClose(matrix, matrix_h)

  def test_defining_operator_using_real_convolution_kernel(self):
    with self.cached_session():
      convolution_kernel = linear_operator_test_util.random_normal(
          shape=(2, 2, 3, 5), dtype=dtypes.float32)
      # Convolution kernel is real ==> spectrum is Hermitian.
      spectrum = fft_ops.fft3d(
          math_ops.cast(convolution_kernel, dtypes.complex64))

      # spectrum is Hermitian ==> operator is real.
      operator = linalg.LinearOperatorCirculant3D(spectrum)
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)

      # Allow for complex output so we can make sure it has zero imag part.
      self.assertEqual(operator.dtype, dtypes.complex64)
      matrix = self.evaluate(operator.to_dense())
      self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
      np.testing.assert_allclose(0, np.imag(matrix), atol=1e-5)

  def test_defining_spd_operator_by_taking_real_part(self):
    with self.cached_session():  # Necessary for fft_kernel_label_map
      # S is real and positive.
      s = linear_operator_test_util.random_uniform(
          shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)

      # Let S = S1 + S2, the Hermitian and anti-hermitian parts.
      # S1 = 0.5 * (S + S^H), S2 = 0.5 * (S - S^H),
      # where ^H is the Hermitian transpose of the function:
      #    f(n0, n1, n2)^H := ComplexConjugate[f(N0-n0, N1-n1, N2-n2)].
      # We want to isolate S1, since
      #   S1 is Hermitian by construction
      #   S1 is real since S is
      #   S1 is positive since it is the sum of two positive kernels

      # IDFT[S] = IDFT[S1] + IDFT[S2]
      #         =      H1  +      H2
      # where H1 is real since it is Hermitian,
      # and H2 is imaginary since it is anti-Hermitian.
      ifft_s = fft_ops.ifft3d(math_ops.cast(s, dtypes.complex64))

      # Throw away H2, keep H1.
      real_ifft_s = math_ops.real(ifft_s)

      # This is the perfect spectrum!
      # spectrum = DFT[H1]
      #          = S1,
      fft_real_ifft_s = fft_ops.fft3d(
          math_ops.cast(real_ifft_s, dtypes.complex64))

      # S1 is Hermitian ==> operator is real.
      # S1 is real ==> operator is self-adjoint.
      # S1 is positive ==> operator is positive-definite.
998 operator = linalg.LinearOperatorCirculant3D(fft_real_ifft_s) 999 1000 # Allow for complex output so we can check operator has zero imag part. 1001 self.assertEqual(operator.dtype, dtypes.complex64) 1002 matrix, matrix_t = self.evaluate([ 1003 operator.to_dense(), 1004 array_ops.matrix_transpose(operator.to_dense()) 1005 ]) 1006 self.evaluate(operator.assert_positive_definite()) # Should not fail. 1007 np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6) 1008 self.assertAllClose(matrix, matrix_t) 1009 1010 # Just to test the theory, get S2 as well. 1011 # This should create an imaginary operator. 1012 # S2 is anti-Hermitian ==> operator is imaginary. 1013 # S2 is real ==> operator is self-adjoint. 1014 imag_ifft_s = math_ops.imag(ifft_s) 1015 fft_imag_ifft_s = fft_ops.fft3d( 1016 1j * math_ops.cast(imag_ifft_s, dtypes.complex64)) 1017 operator_imag = linalg.LinearOperatorCirculant3D(fft_imag_ifft_s) 1018 1019 matrix, matrix_h = self.evaluate([ 1020 operator_imag.to_dense(), 1021 array_ops.matrix_transpose(math_ops.conj(operator_imag.to_dense())) 1022 ]) 1023 self.assertAllClose(matrix, matrix_h) 1024 np.testing.assert_allclose(0, np.real(matrix), atol=1e-7) 1025 1026 1027if __name__ == "__main__": 1028 linear_operator_test_util.add_tests( 1029 LinearOperatorCirculantTestSelfAdjointOperator) 1030 linear_operator_test_util.add_tests( 1031 LinearOperatorCirculantTestHermitianSpectrum) 1032 linear_operator_test_util.add_tests( 1033 LinearOperatorCirculantTestNonHermitianSpectrum) 1034 linear_operator_test_util.add_tests( 1035 LinearOperatorCirculant2DTestHermitianSpectrum) 1036 linear_operator_test_util.add_tests( 1037 LinearOperatorCirculant2DTestNonHermitianSpectrum) 1038 test.main() 1039