# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorDiag`."""

from tensorflow.python.framework import config
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests are done in the base class SquareLinearOperatorDerivedClassTest.

  This class supplies `operator_and_matrix` for the generated tests and adds
  `LinearOperatorDiag`-specific checks (assertions, broadcasting, and the
  result types of matmul/solve/adjoint/cholesky/inverse).
  """

  def setUp(self):
    """Runs base-class setup, then disables TensorFloat-32 for this test.

    TF32 lowers the precision of float32 matmuls on GPUs that support it; the
    caller's setting is saved in `self.tf32_keep_` and restored in `tearDown`.
    """
    # The base TensorFlowTestCase.setUp resets per-test state (default graph,
    # seeds, cached session); skipping it leaks state between tests.
    super().setUp()
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the TensorFloat-32 setting saved in `setUp`."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a random `LinearOperatorDiag` and the equivalent dense matrix.

    Args:
      build_info: Object whose `shape` attribute is the full operator shape;
        its last dimension is dropped to obtain the diagonal's shape.
      dtype: dtype of the diagonal entries.
      use_placeholder: If True, the operator is built from
        `placeholder_with_default(diag, shape=None)`, hiding static shape
        information from it.
      ensure_self_adjoint_and_pd: If True, the diagonal is replaced by its
        absolute value and the operator is flagged as self-adjoint and
        positive definite.

    Returns:
      A tuple `(operator, matrix)`, where `matrix` is `matrix_diag` of the same
      diagonal used to build `operator`.
    """
    shape = list(build_info.shape)
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)

    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)

    lin_op_diag = diag

    if use_placeholder:
      lin_op_diag = array_ops.placeholder_with_default(diag, shape=None)

    # None (not False) leaves the hints unset when no guarantee is requested.
    hint = True if ensure_self_adjoint_and_pd else None
    operator = linalg.LinearOperatorDiag(
        lin_op_diag,
        is_self_adjoint=hint,
        is_positive_definite=hint)

    matrix = array_ops.matrix_diag(diag)

    return operator, matrix

  def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
    # Matrix with one positive eigenvalue and one zero eigenvalue.
    with self.cached_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should be auto-set for real diag.
      self.assertTrue(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive.*not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
    with self.cached_session():
      diag_x = [1.0, -2.0]
      diag_y = [0., 0.]  # Imaginary eigenvalues should not matter.
      diag = math_ops.complex(diag_x, diag_y)
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should not be auto-set for complex diag.
      self.assertIsNone(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive real.*not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
    with self.cached_session():
      x = [1., 2.]
      y = [1., 0.]
      diag = math_ops.complex(x, y)  # Re[diag] > 0.
      # Should not fail
      self.evaluate(linalg.LinearOperatorDiag(diag).assert_positive_definite())

  def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.cached_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      # Should not raise.
      self.evaluate(linalg.LinearOperatorDiag(diag).assert_non_singular())

  def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      with self.assertRaisesOpError("imaginary.*not self-adjoint"):
        self.evaluate(operator.assert_self_adjoint())

  def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 0.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      # Should not raise
      self.evaluate(operator.assert_self_adjoint())

  def test_scalar_diag_raises(self):
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linalg.LinearOperatorDiag(1.)

  def test_broadcast_matmul_and_solve(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.matmul cannot handle.
    # In particular, tf.matmul does not broadcast.
    with self.cached_session():
      x = random_ops.random_normal(shape=(2, 2, 3, 4))

      # This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
      # and matmul with 'x' as the argument.
      diag = random_ops.random_uniform(shape=(2, 1, 3))
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      self.assertAllEqual((2, 1, 3, 3), operator.shape)

      # Create a batch matrix with the broadcast shape of operator.
      diag_broadcast = array_ops.concat((diag, diag), 1)
      mat = array_ops.matrix_diag(diag_broadcast)
      self.assertAllEqual((2, 2, 3, 3), mat.shape)  # being pedantic.

      operator_matmul = operator.matmul(x)
      mat_matmul = math_ops.matmul(mat, x)
      self.assertAllEqual(operator_matmul.shape, mat_matmul.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, mat_matmul]))

      operator_solve = operator.solve(x)
      mat_solve = linalg_ops.matrix_solve(mat, x)
      self.assertAllEqual(operator_solve.shape, mat_solve.shape)
      self.assertAllClose(*self.evaluate([operator_solve, mat_solve]))

  def test_diag_matmul(self):
    """Diag @ Diag and Diag @ ScaledIdentity (either order) stay diagonal."""
    operator1 = linalg.LinearOperatorDiag([2., 3.])
    operator2 = linalg.LinearOperatorDiag([1., 2.])
    operator3 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)

    operator_matmul = operator1.matmul(operator2)
    self.assertIsInstance(operator_matmul, linalg.LinearOperatorDiag)
    self.assertAllClose([2., 6.], self.evaluate(operator_matmul.diag))

    operator_matmul = operator2.matmul(operator1)
    self.assertIsInstance(operator_matmul, linalg.LinearOperatorDiag)
    self.assertAllClose([2., 6.], self.evaluate(operator_matmul.diag))

    operator_matmul = operator1.matmul(operator3)
    self.assertIsInstance(operator_matmul, linalg.LinearOperatorDiag)
    self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag))

    operator_matmul = operator3.matmul(operator1)
    self.assertIsInstance(operator_matmul, linalg.LinearOperatorDiag)
    self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag))

  def test_diag_solve(self):
    """Diag \\ Diag and Diag \\ ScaledIdentity (either order) stay diagonal."""
    operator1 = linalg.LinearOperatorDiag([2., 3.], is_non_singular=True)
    operator2 = linalg.LinearOperatorDiag([1., 2.], is_non_singular=True)
    operator3 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3., is_non_singular=True)

    operator_solve = operator1.solve(operator2)
    self.assertIsInstance(operator_solve, linalg.LinearOperatorDiag)
    self.assertAllClose([0.5, 2 / 3.], self.evaluate(operator_solve.diag))

    operator_solve = operator2.solve(operator1)
    self.assertIsInstance(operator_solve, linalg.LinearOperatorDiag)
    self.assertAllClose([2., 3 / 2.], self.evaluate(operator_solve.diag))

    operator_solve = operator1.solve(operator3)
    self.assertIsInstance(operator_solve, linalg.LinearOperatorDiag)
    self.assertAllClose([3 / 2., 1.], self.evaluate(operator_solve.diag))

    operator_solve = operator3.solve(operator1)
    self.assertIsInstance(operator_solve, linalg.LinearOperatorDiag)
    self.assertAllClose([2 / 3., 1.], self.evaluate(operator_solve.diag))

  def test_diag_adjoint_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
    self.assertIsInstance(operator.adjoint(), linalg.LinearOperatorDiag)

  def test_diag_cholesky_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(
        diag,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(operator.cholesky(), linalg.LinearOperatorDiag)

  def test_diag_inverse_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
    self.assertIsInstance(operator.inverse(), linalg.LinearOperatorDiag)

  def test_tape_safe(self):
    diag = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorDiag(diag)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    diag = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorDiag(diag)
    with self.cached_session() as sess:
      sess.run([diag.initializer])
      self.check_convert_variables_to_tensors(operator)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorDiagTest)
  test.main()