# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
rng = np.random.RandomState(0)


class BaseLinearOperatorLowRankUpdatetest(object):
  """Base test for this type of operator."""

  # Subclasses should set these attributes to either True or False.

  # If True, A = L + UDU^H or A = L + UDV^H, depending on _use_v.
  # If False, A = L + UU^H or A = L + UV^H, depending on _use_v.
  _use_diag_update = None

  # If True, diag is > 0, which means D is symmetric positive definite.
  _is_diag_update_positive = None

  # If True, A = L + UDV^H or A = L + UV^H, depending on _use_diag_update.
  # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update.
  _use_v = None

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # Previously we had a (2, 10, 10) shape at the end. We did this to test the
    # inversion and determinant lemmas on not-tiny matrices, since these are
    # known to have stability issues. This resulted in test timeouts, so this
    # shape has been removed, but rest assured, the tests did pass.
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  def _gen_positive_diag(self, dtype, diag_shape):
    if dtype.is_complex:
      diag = linear_operator_test_util.random_uniform(
          diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
      return math_ops.cast(diag, dtype=dtype)

    return linear_operator_test_util.random_uniform(
        diag_shape, minval=1e-4, maxval=1., dtype=dtype)

  def operator_and_matrix(self, shape_info, dtype, use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    # Recall A = L + UDV^H
    shape = list(shape_info.shape)
    diag_shape = shape[:-1]
    k = shape[-2] // 2 + 1
    u_perturbation_shape = shape[:-1] + [k]
    diag_update_shape = shape[:-2] + [k]

    # base_operator L will be a symmetric positive definite diagonal linear
    # operator, with condition number as high as 1e4.
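    # For reference, LinearOperatorLowRankUpdate computes solves and
    # determinants via the Woodbury identity and matrix determinant lemma:
    #   (L + UDV^H)^{-1}
    #       = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L)
    # Stability therefore hinges on L and on the small capacitance matrix
    # D^{-1} + V^H L^{-1} U, which is why a well-conditioned L matters here.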
    base_diag = self._gen_positive_diag(dtype, diag_shape)
    lin_op_base_diag = base_diag

    # U
    u = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_u = u

    # V
    v = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_v = v

    # D
    if self._is_diag_update_positive or ensure_self_adjoint_and_pd:
      diag_update = self._gen_positive_diag(dtype, diag_update_shape)
    else:
      diag_update = linear_operator_test_util.random_normal(
          diag_update_shape, stddev=1e-4, dtype=dtype)
    lin_op_diag_update = diag_update

    if use_placeholder:
      lin_op_base_diag = array_ops.placeholder_with_default(
          base_diag, shape=None)
      lin_op_u = array_ops.placeholder_with_default(u, shape=None)
      lin_op_v = array_ops.placeholder_with_default(v, shape=None)
      lin_op_diag_update = array_ops.placeholder_with_default(
          diag_update, shape=None)

    base_operator = linalg.LinearOperatorDiag(
        lin_op_base_diag,
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        lin_op_u,
        v=lin_op_v if self._use_v else None,
        diag_update=lin_op_diag_update if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)

    # The matrix representing L
    base_diag_mat = array_ops.matrix_diag(base_diag)

    # The matrix representing D
    diag_update_mat = array_ops.matrix_diag(diag_update)

    # Set up matrix as some variant of A = L + UDV^H
    if self._use_v and self._use_diag_update:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
    elif self._use_v:
      # In this case, we have L + UV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
    elif self._use_diag_update:
      # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
      expect_use_cholesky = self._is_diag_update_positive
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
    else:
      # In this case, we have L + UU^H, which is PD since L > 0.
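      # (UU^H is positive semi-definite for any U, so adding it to the
      # positive definite L keeps A positive definite.)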
      expect_use_cholesky = True
      matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)

    if expect_use_cholesky:
      self.assertTrue(operator._use_cholesky)
    else:
      self.assertFalse(operator._use_cholesky)

    return operator, matrix

  def test_tape_safe(self):
    base_operator = linalg.LinearOperatorDiag(
        variables_module.Variable([1.], name="diag"),
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        u=variables_module.Variable([[2.]], name="u"),
        v=variables_module.Variable([[1.25]], name="v")
        if self._use_v else None,
        diag_update=variables_module.Variable([1.25], name="diag_update")
        if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    base_operator = linalg.LinearOperatorDiag(
        variables_module.Variable([1.], name="diag"),
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        u=variables_module.Variable([[2.]], name="u"),
        v=variables_module.Variable([[1.25]], name="v")
        if self._use_v else None,
        diag_update=variables_module.Variable([1.25], name="diag_update")
        if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)


class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances since we are testing with condition numbers as high
    # as 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10
    self._rtol[dtypes.complex64] = 1e-4


class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""

  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  _use_diag_update = True
  _is_diag_update_positive = False
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances since we are testing with condition numbers as high
    # as 1e4. This class does not use Cholesky, and thus needs even looser
    # tolerances.
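    # Without the PD guarantee, the capacitance system is solved with a
    # general (non-Cholesky) solve, which in practice costs roughly a digit
    # of accuracy, hence tolerances about 10x looser than the Cholesky class
    # above.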
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9
    self._rtol[dtypes.complex64] = 2e-4


class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UU^H, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances since we are testing with condition numbers as high
    # as 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10
    self._rtol[dtypes.complex64] = 1e-4


class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""

  @staticmethod
  def skip_these_tests():
    return ["cholesky", "eigvalsh"]

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = True

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances since we are testing with condition numbers as high
    # as 1e4. This class does not use Cholesky, and thus needs even looser
    # tolerances.
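    # Unlike the WithDiag variant above, this class also pins the complex64
    # atol, since the non-symmetric L + UV^H path appears to lose absolute
    # accuracy in complex arithmetic as well.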
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.complex64] = 2e-4


class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """A = L + UDV^H, which is not symmetric, so we cannot use a Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = True

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)


@test_util.run_all_without_tensor_float_32(
    "Linear op calls matmul which uses TensorFloat-32.")
class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
  """Test that the operator's shape is the broadcast of arguments."""

  def test_static_shape_broadcasts_up_from_operator_to_other_args(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3)
    u = array_ops.ones(shape=[2, 3, 2])
    diag = array_ops.ones(shape=[2, 2])

    operator = linalg.LinearOperatorLowRankUpdate(base_operator, u, diag)

    # domain_dimension is 3
    self.assertAllEqual([2, 3, 3], operator.shape)
    self.assertAllEqual([2, 3, 3], self.evaluate(operator.to_dense()).shape)

  @test_util.run_deprecated_v1
  def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
    num_rows_ph = array_ops.placeholder(dtypes.int32)
    base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)

    u_shape_ph = array_ops.placeholder(dtypes.int32)
    u = array_ops.ones(shape=u_shape_ph)

    v_shape_ph = array_ops.placeholder(dtypes.int32)
    v = array_ops.ones(shape=v_shape_ph)

    diag_shape_ph = array_ops.placeholder(dtypes.int32)
    diag_update = array_ops.ones(shape=diag_shape_ph)

    operator = linalg.LinearOperatorLowRankUpdate(base_operator,
                                                  u=u,
                                                  diag_update=diag_update,
                                                  v=v)

    feed_dict = {
        num_rows_ph: 3,
        u_shape_ph: [1, 1, 2, 3, 2],  # batch_shape = [1, 1, 2]
        v_shape_ph: [1, 2, 1, 3, 2],  # batch_shape = [1, 2, 1]
        diag_shape_ph: [2, 1, 1, 2]  # batch_shape = [2, 1, 1]
    }

    with self.cached_session():
      shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor)
      dense = operator.to_dense().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 2, 2, 3, 3], dense.shape)

  def test_u_and_v_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    v = rng.rand(4, 3, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, v=v)

  def test_u_and_base_operator_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(
        num_rows=3, batch_shape=[4], dtype=np.float64)
    u = rng.rand(5, 3, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_base_operator_incompatible_domain_dimension(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 4, 2)
    with self.assertRaisesRegex(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_diag_incompatible_low_rank_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(5, 4)  # Last dimension should be 2
    with self.assertRaisesRegex(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)

  def test_diag_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(4, 2)  # First dimension should be 5
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestNoDiagUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagNotSquare)
  test.main()