# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LinearOperatorInversion."""

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_inversion
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib

LinearOperatorInversion = linear_operator_inversion.LinearOperatorInversion  # pylint: disable=invalid-name


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorInversionTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Exercises LinearOperatorInversion.

  The bulk of the coverage comes from the shared base class; the methods
  defined here supply the operator under test and check the hint/name
  handling that is specific to the inversion wrapper.
  """

  def setUp(self):
    """Turns TensorFloat-32 off and sets the complex64 tolerances."""
    # NOTE(review): neither setUp nor tearDown chains to the parent
    # implementation; confirm that is intentional.
    # Remember the current setting so tearDown can put it back.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    for tolerances in (self._atol, self._rtol):
      tolerances[dtypes.complex64] = 1e-5

  def tearDown(self):
    """Restores the TensorFloat-32 setting recorded in setUp."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def operator_and_matrix(self,
                          build_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds an inversion operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the matrix shape.
      dtype: Dtype passed to the random matrix generator.
      use_placeholder: If true, the operator is built on a
        `placeholder_with_default` (created with `shape=None`) wrapping the
        matrix rather than on the matrix itself.
      ensure_self_adjoint_and_pd: If true, a random positive definite matrix
        wrapped in `LinearOperatorFullMatrix` is inverted; otherwise a random
        lower triangular matrix wrapped in `LinearOperatorLowerTriangular`.

    Returns:
      A pair `(operator, expected)`: the `LinearOperatorInversion` under test
      and `linalg.inv` applied to the generated matrix.
    """
    dims = list(build_info.shape)

    if ensure_self_adjoint_and_pd:
      dense = linear_operator_test_util.random_positive_definite_matrix(
          dims, dtype, force_well_conditioned=True)
    else:
      dense = linear_operator_test_util.random_tril_matrix(
          dims, dtype, force_well_conditioned=True, remove_upper=True)

    # The expected value is always computed from `dense`; only the operator
    # sees the placeholder.
    operand = (
        array_ops.placeholder_with_default(dense, shape=None)
        if use_placeholder else dense)

    if ensure_self_adjoint_and_pd:
      base = linalg.LinearOperatorFullMatrix(
          operand, is_positive_definite=True, is_self_adjoint=True)
    else:
      base = linalg.LinearOperatorLowerTriangular(operand)

    inverse_operator = LinearOperatorInversion(base)
    expected = linalg.inv(dense)
    return inverse_operator, expected

  def test_base_operator_hint_used(self):
    # Hints given to the wrapped operator should show up on the inversion.
    # The matrix values do not affect auto-setting of the flags.
    hints = dict(
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    base = linalg.LinearOperatorFullMatrix([[1., 0.], [1., 1.]], **hints)

    inverse = LinearOperatorInversion(base)

    self.assertTrue(inverse.is_positive_definite)
    self.assertTrue(inverse.is_non_singular)
    self.assertFalse(inverse.is_self_adjoint)

  def test_supplied_hint_used(self):
    # Hints given directly to the inversion should be reported back.
    # The matrix values do not affect auto-setting of the flags.
    hints = dict(
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    base = linalg.LinearOperatorFullMatrix([[1., 0.], [1., 1.]])

    inverse = LinearOperatorInversion(base, **hints)

    self.assertTrue(inverse.is_positive_definite)
    self.assertTrue(inverse.is_non_singular)
    self.assertFalse(inverse.is_self_adjoint)

  def test_contradicting_hints_raise(self):
    # A hint on the inversion that disagrees with the wrapped operator's hint
    # is expected to raise ValueError.
    # The matrix values do not affect auto-setting of the flags.
    entries = [[1., 0.], [1., 1.]]

    not_pd = linalg.LinearOperatorFullMatrix(
        entries, is_positive_definite=False)
    with self.assertRaisesRegex(ValueError, "positive-definite"):
      LinearOperatorInversion(not_pd, is_positive_definite=True)

    not_self_adjoint = linalg.LinearOperatorFullMatrix(
        entries, is_self_adjoint=False)
    with self.assertRaisesRegex(ValueError, "self-adjoint"):
      LinearOperatorInversion(not_self_adjoint, is_self_adjoint=True)

  def test_singular_raises(self):
    # is_non_singular=False, whether on the wrapped operator or on the
    # inversion itself, is expected to raise ValueError.
    # The matrix values do not affect auto-setting of the flags.
    entries = [[1., 1.], [1., 1.]]

    flagged_singular = linalg.LinearOperatorFullMatrix(
        entries, is_non_singular=False)
    with self.assertRaisesRegex(ValueError, "is_non_singular"):
      LinearOperatorInversion(flagged_singular)

    unflagged = linalg.LinearOperatorFullMatrix(entries)
    with self.assertRaisesRegex(ValueError, "is_non_singular"):
      LinearOperatorInversion(unflagged, is_non_singular=False)

  def test_name(self):
    # The inversion's name is the wrapped operator's name plus "_inv".
    base = linalg.LinearOperatorFullMatrix(
        [[11., 0.], [1., 8.]], name="my_operator", is_non_singular=True)

    inverse = LinearOperatorInversion(base)

    self.assertEqual("my_operator_inv", inverse.name)

  def test_tape_safe(self):
    # Runs the base class tape-safety check on an inversion of a
    # Variable-backed operator.
    weights = variables_module.Variable([[1., 2.], [3., 4.]])
    base = linalg.LinearOperatorFullMatrix(weights)
    self.check_tape_safe(LinearOperatorInversion(base))


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorInversionTest)
  test.main()