# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LinearOperatorAdjoint."""

import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_adjoint
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib

LinearOperatorAdjoint = linear_operator_adjoint.LinearOperatorAdjoint  # pylint: disable=invalid-name


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorAdjointTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for LinearOperatorAdjoint wrapping square operators.

  Most tests are generated from the base class
  SquareLinearOperatorDerivedClassTest (registered with `add_tests` in the
  `__main__` block); the methods below cover adjoint-specific behavior.
  """

  def tearDown(self):
    # Restore the TensorFloat-32 setting captured in setUp.
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def setUp(self):
    super().setUp()
    # Disable TensorFloat-32 execution for this test, remembering the previous
    # setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Copy the tolerance dicts before loosening them. Writing into the
    # inherited dicts directly would (if they are class attributes of the
    # base test class, as appears to be the case) change the tolerances for
    # every other derived test class in the same process.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.complex64] = 1e-5

  def operator_and_matrix(self,
                          build_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds an adjoint operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the operator shape.
      dtype: dtype of the generated matrix.
      use_placeholder: If True, feed the matrix through a placeholder with
        unknown static shape.
      ensure_self_adjoint_and_pd: If True, wrap a positive-definite full
        matrix flagged as self-adjoint and PD; otherwise wrap a
        well-conditioned lower-triangular operator.

    Returns:
      A tuple `(operator, expected)` where `expected` is the adjoint of the
      generated matrix.
    """
    shape = list(build_info.shape)

    if ensure_self_adjoint_and_pd:
      matrix = linear_operator_test_util.random_positive_definite_matrix(
          shape, dtype, force_well_conditioned=True)
    else:
      matrix = linear_operator_test_util.random_tril_matrix(
          shape, dtype, force_well_conditioned=True, remove_upper=True)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    if ensure_self_adjoint_and_pd:
      operator = LinearOperatorAdjoint(
          linalg.LinearOperatorFullMatrix(
              lin_op_matrix, is_positive_definite=True, is_self_adjoint=True))
    else:
      operator = LinearOperatorAdjoint(
          linalg.LinearOperatorLowerTriangular(lin_op_matrix))

    return operator, linalg.adjoint(matrix)

  def test_base_operator_hint_used(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    operator_adjoint = LinearOperatorAdjoint(operator)
    self.assertTrue(operator_adjoint.is_positive_definite)
    self.assertTrue(operator_adjoint.is_non_singular)
    self.assertFalse(operator_adjoint.is_self_adjoint)

  def test_supplied_hint_used(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(matrix)
    operator_adjoint = LinearOperatorAdjoint(
        operator,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator_adjoint.is_positive_definite)
    self.assertTrue(operator_adjoint.is_non_singular)
    self.assertFalse(operator_adjoint.is_self_adjoint)

  def test_contradicting_hints_raise(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_positive_definite=False)
    with self.assertRaisesRegex(ValueError, "positive-definite"):
      LinearOperatorAdjoint(operator, is_positive_definite=True)

    operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=False)
    with self.assertRaisesRegex(ValueError, "self-adjoint"):
      LinearOperatorAdjoint(operator, is_self_adjoint=True)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, name="my_operator", is_non_singular=True)

    operator = LinearOperatorAdjoint(operator)

    self.assertEqual("my_operator_adjoint", operator.name)

  def test_matmul_adjoint_operator(self):
    matrix1 = np.random.randn(4, 4)
    matrix2 = np.random.randn(4, 4)
    full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)

    self.assertAllClose(
        np.matmul(matrix1, matrix2.T),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))

    self.assertAllClose(
        np.matmul(matrix1.T, matrix2),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))

    self.assertAllClose(
        np.matmul(matrix1.T, matrix2.T),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint=True,
                                adjoint_arg=True).to_dense()))

  def test_matmul_adjoint_complex_operator(self):
    matrix1 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
    matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)
    full_matrix1 = linalg.LinearOperatorFullMatrix(matrix1)
    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)

    self.assertAllClose(
        np.matmul(matrix1, matrix2.conj().T),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint_arg=True).to_dense()))

    self.assertAllClose(
        np.matmul(matrix1.conj().T, matrix2),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint=True).to_dense()))

    self.assertAllClose(
        np.matmul(matrix1.conj().T, matrix2.conj().T),
        self.evaluate(
            full_matrix1.matmul(full_matrix2, adjoint=True,
                                adjoint_arg=True).to_dense()))

  def test_matvec(self):
    matrix = np.array([[1., 2.], [3., 4.]])
    x = np.array([1., 2.])
    operator = linalg.LinearOperatorFullMatrix(matrix)
    self.assertAllClose(matrix.dot(x), self.evaluate(operator.matvec(x)))
    # `.H` exposes the adjoint of the operator.
    self.assertAllClose(matrix.T.dot(x), self.evaluate(operator.H.matvec(x)))

  def test_solve_adjoint_operator(self):
    matrix1 = self.evaluate(
        linear_operator_test_util.random_tril_matrix(
            [4, 4], dtype=dtypes.float64, force_well_conditioned=True))
    matrix2 = np.random.randn(4, 4)
    full_matrix1 = linalg.LinearOperatorLowerTriangular(
        matrix1, is_non_singular=True)
    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)

    self.assertAllClose(
        self.evaluate(linalg.triangular_solve(matrix1, matrix2.T)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))

    # The transpose of a lower-triangular matrix is upper-triangular, hence
    # `lower=False` for the reference solves below.
    self.assertAllClose(
        self.evaluate(linalg.triangular_solve(matrix1.T, matrix2, lower=False)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))

    self.assertAllClose(
        self.evaluate(
            linalg.triangular_solve(matrix1.T, matrix2.T, lower=False)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint=True,
                               adjoint_arg=True).to_dense()))

  def test_solve_adjoint_complex_operator(self):
    matrix1 = self.evaluate(
        linear_operator_test_util.random_tril_matrix(
            [4, 4], dtype=dtypes.complex128, force_well_conditioned=True) +
        1j * linear_operator_test_util.random_tril_matrix(
            [4, 4], dtype=dtypes.complex128, force_well_conditioned=True))
    matrix2 = np.random.randn(4, 4) + 1j * np.random.randn(4, 4)

    full_matrix1 = linalg.LinearOperatorLowerTriangular(
        matrix1, is_non_singular=True)
    full_matrix2 = linalg.LinearOperatorFullMatrix(matrix2)

    self.assertAllClose(
        self.evaluate(linalg.triangular_solve(matrix1, matrix2.conj().T)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint_arg=True).to_dense()))

    # The conjugate transpose of a lower-triangular matrix is
    # upper-triangular, hence `lower=False` for the reference solves below.
    self.assertAllClose(
        self.evaluate(
            linalg.triangular_solve(matrix1.conj().T, matrix2, lower=False)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint=True).to_dense()))

    self.assertAllClose(
        self.evaluate(
            linalg.triangular_solve(
                matrix1.conj().T, matrix2.conj().T, lower=False)),
        self.evaluate(
            full_matrix1.solve(full_matrix2, adjoint=True,
                               adjoint_arg=True).to_dense()))

  def test_solvevec(self):
    matrix = np.array([[1., 2.], [3., 4.]])
    inv_matrix = np.linalg.inv(matrix)
    x = np.array([1., 2.])
    operator = linalg.LinearOperatorFullMatrix(matrix)
    self.assertAllClose(inv_matrix.dot(x), self.evaluate(operator.solvevec(x)))
    self.assertAllClose(
        inv_matrix.T.dot(x), self.evaluate(operator.H.solvevec(x)))

  def test_tape_safe(self):
    matrix = variables_module.Variable([[1., 2.], [3., 4.]])
    operator = LinearOperatorAdjoint(linalg.LinearOperatorFullMatrix(matrix))
    self.check_tape_safe(operator)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorAdjointNonSquareTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Tests done in the base class NonSquareLinearOperatorDerivedClassTest."""

  def operator_and_matrix(self,
                          build_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds an adjoint of a non-square full matrix and its dense adjoint.

    Args:
      build_info: Object whose `shape` attribute gives the shape of the
        resulting (adjoint) operator.
      dtype: dtype of the generated matrix.
      use_placeholder: If True, feed the matrix through a placeholder with
        unknown static shape.
      ensure_self_adjoint_and_pd: Ignored; accepted for signature parity with
        LinearOperatorAdjointTest.operator_and_matrix. A non-square operator
        can be neither self-adjoint nor positive definite.

    Returns:
      A tuple `(operator, expected)` where `expected` is the adjoint of the
      generated matrix.
    """
    del ensure_self_adjoint_and_pd
    shape_before_adjoint = list(build_info.shape)
    # Swap the last two dimensions: the wrapped matrix is the adjoint of the
    # operator under test, so its trailing dimensions are transposed.
    shape_before_adjoint[-1], shape_before_adjoint[-2] = (
        shape_before_adjoint[-2], shape_before_adjoint[-1])
    matrix = linear_operator_test_util.random_normal(
        shape_before_adjoint, dtype=dtype)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    operator = LinearOperatorAdjoint(
        linalg.LinearOperatorFullMatrix(lin_op_matrix))

    return operator, linalg.adjoint(matrix)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorAdjointTest)
  # Register the generated base-class tests for the non-square operator too,
  # mirroring the registration of LinearOperatorAdjointTest above.
  linear_operator_test_util.add_tests(LinearOperatorAdjointNonSquareTest)
  test.main()