# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import contextlib

import numpy as np
import scipy.linalg

from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_toeplitz
from tensorflow.python.platform import test

linalg = linalg_lib

_to_complex = linear_operator_toeplitz._to_complex


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorToeplitzTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Delegates device constraints directly to `test.TestCase`.

    The call names `test.TestCase` explicitly instead of using `super()`, so
    any override defined by an intermediate base class is bypassed.

    NOTE(review): a previous docstring said this override remaps the FFT
    operations for testing, but the body only delegates. Confirm whether the
    override is still required.

    Yields:
      The session yielded by `test.TestCase`'s implementation.
    """
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as constrained_sess:
      yield constrained_sess

  def setUp(self):
    """Runs the base-class setup, then loosens per-dtype tolerances."""
    super().setUp()
    # TODO(srvasude): Lower these tolerances once specialized solve and
    # determinants are implemented.
    overrides = {
        dtypes.float32: 1e-4,
        dtypes.float64: 1e-9,
        dtypes.complex64: 1e-4,
        dtypes.complex128: 1e-9,
    }
    # Build per-instance copies instead of assigning into the existing dicts:
    # they are not created in this class (presumably a base-class attribute),
    # so in-place mutation could leak these looser tolerances into other
    # test classes sharing the same dict.
    self._atol = {**self._atol, **overrides}
    self._rtol = {**self._rtol, **overrides}

  @staticmethod
  def skip_these_tests():
    """Returns the names of base-class tests skipped for this operator."""
    # Skip solve tests, as these could have better stability
    # (currently exercises the base class).
    # TODO(srvasude): Enable these when solve is implemented.
    return ["cholesky", "cond", "inverse", "solve", "solve_with_broadcast"]

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes exercised by the base-class tests."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((1, 1)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a Toeplitz operator together with its dense equivalent.

    Args:
      build_info: Object whose `shape` is the full operator shape, i.e. the
        batch shape followed by `(n, n)`.
      dtype: Dtype the operator's `row`/`col` and the dense matrix are cast to.
      use_placeholder: If True, `row`/`col` are wrapped in
        `placeholder_with_default` with an unknown static shape.
      ensure_self_adjoint_and_pd: If True, the first row and column are the
        same linearly decreasing positive sequence (giving a symmetric,
        positive definite matrix) and the operator is flagged accordingly.

    Returns:
      A tuple `(operator, matrix)` where `matrix` is the dense Toeplitz matrix
      represented by `operator`, cast to `dtype`.
    """
    shape = list(build_info.shape)
    n = shape[-1]
    row = np.random.uniform(low=1., high=5., size=shape[:-1])
    col = np.random.uniform(low=1., high=5., size=shape[:-1])

    # Make sure first entry is the same
    row[..., 0] = col[..., 0]

    if ensure_self_adjoint_and_pd:
      # Note that a Toeplitz matrix generated from a linearly decreasing
      # non-negative sequence is positive definite. See
      # https://www.math.cinvestav.mx/~grudsky/Papers/118_29062012_Albrecht.pdf
      # for details. The sequence is broadcast over the batch dimensions so
      # that `row` keeps the shape `shape[:-1]` relied on below.
      row = np.broadcast_to(
          np.linspace(start=10., stop=1., num=n), shape[:-1]).copy()

      # The entries for the first row and column should be the same to
      # guarantee symmetry; `col` must take the decreasing sequence too, or
      # positive definiteness is lost.
      col = row

    lin_op_row = math_ops.cast(row, dtype=dtype)
    lin_op_col = math_ops.cast(col, dtype=dtype)

    if use_placeholder:
      lin_op_row = array_ops.placeholder_with_default(
          lin_op_row, shape=None)
      lin_op_col = array_ops.placeholder_with_default(
          lin_op_col, shape=None)

    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        row=lin_op_row,
        col=lin_op_col,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    # Reference matrix: build each (n, n) Toeplitz block with scipy from the
    # flattened batch, then restore the full batch shape.
    flattened_row = np.reshape(row, (-1, n))
    flattened_col = np.reshape(col, (-1, n))
    flattened_toeplitz = np.stack([
        scipy.linalg.toeplitz(c, r)
        for c, r in zip(flattened_col, flattened_row)
    ])
    matrix = np.reshape(flattened_toeplitz, shape)
    matrix = math_ops.cast(matrix, dtype=dtype)

    return operator, matrix

  def test_scalar_row_col_raises(self):
    """Scalar `row` and/or `col` arguments must be rejected."""
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz(1., 1.)

    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz([1.], 1.)

    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz(1., [1.])

  def test_tape_safe(self):
    """Gradients flow from operator methods back to the variable inputs."""
    col = variables_module.Variable([1.])
    row = variables_module.Variable([1.])
    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        col, row, is_self_adjoint=True, is_positive_definite=True)
    self.check_tape_safe(
        operator,
        skip_options=[
            # .diag_part, .trace depend only on `col`, so test explicitly below.
            linear_operator_test_util.CheckTapeSafeSkipOptions.DIAG_PART,
            linear_operator_test_util.CheckTapeSafeSkipOptions.TRACE,
        ])

    with backprop.GradientTape() as tape:
      self.assertIsNotNone(tape.gradient(operator.diag_part(), col))

    with backprop.GradientTape() as tape:
      self.assertIsNotNone(tape.gradient(operator.trace(), col))

  def test_convert_variables_to_tensors(self):
    """Runs the base-class variable-to-tensor conversion check."""
    col = variables_module.Variable([1.])
    row = variables_module.Variable([1.])
    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        col, row, is_self_adjoint=True, is_positive_definite=True)
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorToeplitzTest)
  test.main()