# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorFullMatrix`."""

import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib


@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorFullMatrixTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a square `LinearOperatorFullMatrix` and the matrix it wraps.

    Args:
      build_info: Object whose `shape` attribute gives the operator shape.
      dtype: Dtype passed to the random positive-definite matrix generator.
      use_placeholder: If `True`, the operator wraps a
        `placeholder_with_default` with `shape=None` instead of the matrix.
      ensure_self_adjoint_and_pd: If `True`, pass `is_self_adjoint=True` and
        `is_positive_definite=True`; otherwise leave both hints as `None`.

    Returns:
      A pair `(operator, matrix)` where `matrix` is the dense reference value.
    """
    shape = list(build_info.shape)

    matrix = linear_operator_test_util.random_positive_definite_matrix(
        shape, dtype)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    # Set the hints to None (unless the caller requires self-adjoint PD) to
    # test the non-symmetric PD code paths.
    operator = linalg.LinearOperatorFullMatrix(
        lin_op_matrix,
        is_square=True,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    return operator, matrix

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues.
    matrix = [[1., 0.], [1., 11.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)
    # Auto-detected.
    self.assertTrue(operator.is_square)

  def test_assert_non_singular_raises_if_cond_too_big_but_finite(self):
    with self.cached_session():
      tril = linear_operator_test_util.random_tril_matrix(
          shape=(50, 50), dtype=np.float32)
      diag = np.logspace(-2, 2, 50).astype(np.float32)
      tril = array_ops.matrix_set_diag(tril, diag)
      matrix = self.evaluate(math_ops.matmul(tril, tril, transpose_b=True))
      operator = linalg.LinearOperatorFullMatrix(matrix)
      # Check the premise of the test before expecting the failure: the
      # condition number must be finite...just HUGE. These are preconditions,
      # so they live outside the expected-failure region.
      cond = np.linalg.cond(matrix)
      self.assertTrue(np.isfinite(cond))
      self.assertGreater(cond, 1e12)
      with self.assertRaisesOpError("Singular matrix"):
        operator.assert_non_singular().run()

  def test_assert_non_singular_raises_if_cond_infinite(self):
    with self.cached_session():
      matrix = [[1., 1.], [1., 1.]]
      # We don't pass the is_self_adjoint hint here, which means we take the
      # generic code path.
      operator = linalg.LinearOperatorFullMatrix(matrix)
      with self.assertRaisesOpError("Singular matrix"):
        operator.assert_non_singular().run()

  def test_assert_self_adjoint(self):
    matrix = [[0., 1.], [0., 1.]]
    operator = linalg.LinearOperatorFullMatrix(matrix)
    with self.cached_session():
      with self.assertRaisesOpError("not equal to its adjoint"):
        operator.assert_self_adjoint().run()

  @test_util.disable_xla("Assert statements in kernels not supported in XLA")
  def test_assert_positive_definite(self):
    matrix = [[1., 1.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
    with self.cached_session():
      with self.assertRaises(errors.InvalidArgumentError):
        operator.assert_positive_definite().run()

  def test_tape_safe(self):
    matrix = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorFullMatrix(matrix)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    matrix = variables_module.Variable([[3.]])
    operator = linalg.LinearOperatorFullMatrix(matrix)
    with self.cached_session() as sess:
      sess.run([matrix.initializer])
      self.check_convert_variables_to_tensors(operator)


@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest.

  In this test, the operator is constructed with hints that invoke the use of
  a Cholesky decomposition for solves/determinant.
  """

  def setUp(self):
    """Runs the base-class setup, then loosens this class's tolerances."""
    super().setUp()
    # Increase from 1e-6 to 1e-5. This reduction in tolerance happens,
    # presumably, because we are taking a different code path in the operator
    # and the matrix. The operator uses a Cholesky, the matrix uses standard
    # solve.
    # `_atol`/`_rtol` are inherited from the base test class and may be dict
    # objects shared with other derived test classes. Copy them before
    # overriding entries so the looser tolerances stay local to this test
    # instance instead of leaking into unrelated tests in the same process.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10

  @staticmethod
  def dtypes_to_test():
    return [dtypes.float32, dtypes.float64]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a symmetric positive-definite operator and its dense matrix.

    Args:
      build_info: Object whose `shape` attribute gives the operator shape.
      dtype: Dtype passed to the random positive-definite matrix generator.
      use_placeholder: If `True`, the operator wraps a
        `placeholder_with_default` with `shape=None` instead of the matrix.
      ensure_self_adjoint_and_pd: Ignored; the self-adjoint and
        positive-definite hints are always set in this class.

    Returns:
      A pair `(operator, matrix)` where `matrix` is the dense reference value.
    """
    # Matrix is always symmetric and positive definite in this class.
    del ensure_self_adjoint_and_pd

    shape = list(build_info.shape)

    matrix = linear_operator_test_util.random_positive_definite_matrix(
        shape, dtype, force_well_conditioned=True)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    operator = linalg.LinearOperatorFullMatrix(
        lin_op_matrix,
        is_square=True,
        is_self_adjoint=True,
        is_positive_definite=True)

    return operator, matrix

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues.
    matrix = [[1., 0.], [0., 7.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_positive_definite=True, is_self_adjoint=True)

    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_self_adjoint)

    # Should be auto-set
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator._can_use_cholesky)
    self.assertTrue(operator.is_square)

  @test_util.disable_xla("Assert statements in kernels not supported in XLA")
  def test_assert_non_singular(self):
    matrix = [[1., 1.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_self_adjoint=True, is_positive_definite=True)
    with self.cached_session():
      # Cholesky decomposition may fail, so the error is not specific to
      # non-singular.
      with self.assertRaisesOpError(""):
        operator.assert_non_singular().run()

  def test_assert_self_adjoint(self):
    matrix = [[0., 1.], [0., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_self_adjoint=True, is_positive_definite=True)
    with self.cached_session():
      with self.assertRaisesOpError("not equal to its adjoint"):
        operator.assert_self_adjoint().run()

  @test_util.disable_xla("Assert statements in kernels not supported in XLA")
  def test_assert_positive_definite(self):
    matrix = [[1., 1.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_self_adjoint=True, is_positive_definite=True)
    with self.cached_session():
      # Cholesky decomposition may fail, so the error is not specific to
      # positive definiteness.
      with self.assertRaisesOpError(""):
        operator.assert_positive_definite().run()

  def test_tape_safe(self):
    matrix = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_self_adjoint=True, is_positive_definite=True)
    self.check_tape_safe(operator)


@test_util.run_all_in_graph_and_eager_modes
class NonSquareLinearOperatorFullMatrixTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a (generally rectangular) operator and its dense matrix.

    Args:
      build_info: Object whose `shape` attribute gives the operator shape.
      dtype: Dtype passed to the random normal matrix generator.
      use_placeholder: If `True`, the operator wraps a
        `placeholder_with_default` with `shape=None` instead of the matrix.
      ensure_self_adjoint_and_pd: Ignored by this class.

    Returns:
      A pair `(operator, matrix)` where `matrix` is the dense reference value.
    """
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)
    matrix = linear_operator_test_util.random_normal(shape, dtype=dtype)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    # Do not claim `is_square=True` for a rectangular matrix: that hint would
    # contradict the shape (see `test_is_x_flags`, which expects
    # `is_square` to be False). Let the operator infer squareness instead.
    operator = linalg.LinearOperatorFullMatrix(lin_op_matrix)

    return operator, matrix

  def test_is_x_flags(self):
    matrix = [[3., 2., 1.], [1., 1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix,
        is_self_adjoint=False)
    self.assertIsNone(operator.is_positive_definite)
    self.assertIsNone(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)
    self.assertFalse(operator.is_square)

  def test_matrix_must_have_at_least_two_dims_or_raises(self):
    with self.assertRaisesRegex(ValueError, "at least 2 dimensions"):
      linalg.LinearOperatorFullMatrix([1.])

  def test_tape_safe(self):
    matrix = variables_module.Variable([[2., 1.]])
    operator = linalg.LinearOperatorFullMatrix(matrix)
    self.check_tape_safe(operator)


if __name__ == "__main__":
  config.enable_tensor_float_32_execution(False)
  linear_operator_test_util.add_tests(SquareLinearOperatorFullMatrixTest)
  linear_operator_test_util.add_tests(NonSquareLinearOperatorFullMatrixTest)
  linear_operator_test_util.add_tests(
      SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest)
  test.main()