# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `linear_operator_addition.add_operators` and its Adders."""

import numpy as np

from tensorflow.python.framework import test_util
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_addition
from tensorflow.python.platform import test

linalg = linalg_lib
# Fixed seed so the random operands used below are reproducible.
rng = np.random.RandomState(0)

add_operators = linear_operator_addition.add_operators


# pylint: disable=unused-argument
class _BadAdder(linear_operator_addition._Adder):
  """Adder that will fail if used."""

  def can_add(self, op1, op2):
    raise AssertionError("BadAdder.can_add called!")

  def _add(self, op1, op2, operator_name, hints):
    raise AssertionError("This line should not be reached")


# pylint: enable=unused-argument


class LinearOperatorAdditionCorrectnessTest(test.TestCase):
  """Tests correctness of addition with combinations of a few Adders.

  Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
  add_operators should reduce all operators resulting in one single operator.

  This shows that we are able to correctly combine adders using the tiered
  system.  All Adders should be tested separately, and there is no need to test
  every Adder within this class.
  """

  def test_one_operator_is_returned_unchanged(self):
    op_a = linalg.LinearOperatorDiag([1., 1.])
    op_sum = add_operators([op_a])
    self.assertEqual(1, len(op_sum))
    self.assertIs(op_sum[0], op_a)

  def test_at_least_one_operators_required(self):
    with self.assertRaisesRegex(ValueError, "must contain at least one"):
      add_operators([])

  def test_attempting_to_add_numbers_raises(self):
    with self.assertRaisesRegex(TypeError, "contain only LinearOperator"):
      add_operators([1, 2])

  @test_util.run_deprecated_v1
  def test_two_diag_operators(self):
    op_a = linalg.LinearOperatorDiag(
        [1., 1.], is_positive_definite=True, name="A")
    op_b = linalg.LinearOperatorDiag(
        [2., 2.], is_positive_definite=True, name="B")
    with self.cached_session():
      op_sum = add_operators([op_a, op_b])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
      self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense())
      # Adding positive definite operators produces positive def.
      self.assertTrue(op.is_positive_definite)
      # Real diagonal ==> self-adjoint.
      self.assertTrue(op.is_self_adjoint)
      # Positive definite ==> non-singular
      self.assertTrue(op.is_non_singular)
      # Enforce particular name for this simple case
      self.assertEqual("Add/B__A/", op.name)

  @test_util.run_deprecated_v1
  def test_three_diag_operators(self):
    op1 = linalg.LinearOperatorDiag(
        [1., 1.], is_positive_definite=True, name="op1")
    op2 = linalg.LinearOperatorDiag(
        [2., 2.], is_positive_definite=True, name="op2")
    op3 = linalg.LinearOperatorDiag(
        [3., 3.], is_positive_definite=True, name="op3")
    with self.cached_session():
      op_sum = add_operators([op1, op2, op3])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      # assertIsInstance reports the actual type on failure and matches the
      # assertion style used by the other tests in this class.
      self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense())
      # Adding positive definite operators produces positive def.
      self.assertTrue(op.is_positive_definite)
      # Real diagonal ==> self-adjoint.
      self.assertTrue(op.is_self_adjoint)
      # Positive definite ==> non-singular
      self.assertTrue(op.is_non_singular)

  @test_util.run_deprecated_v1
  def test_diag_tril_diag(self):
    op1 = linalg.LinearOperatorDiag(
        [1., 1.], is_non_singular=True, name="diag_a")
    op2 = linalg.LinearOperatorLowerTriangular(
        [[2., 0.], [0., 2.]],
        is_self_adjoint=True,
        is_non_singular=True,
        name="tril")
    op3 = linalg.LinearOperatorDiag(
        [3., 3.], is_non_singular=True, name="diag_b")
    with self.cached_session():
      op_sum = add_operators([op1, op2, op3])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense())

      # The diag operators will be self-adjoint (because real and diagonal).
      # The TriL operator has the self-adjoint hint set.
      self.assertTrue(op.is_self_adjoint)

      # Even though op1/2/3 are non-singular, this does not imply op is.
      # Since no custom hint was provided, we default to None (unknown).
      self.assertIsNone(op.is_non_singular)

  @test_util.run_deprecated_v1
  def test_matrix_diag_tril_diag_uses_custom_name(self):
    op0 = linalg.LinearOperatorFullMatrix(
        [[-1., -1.], [-1., -1.]], name="matrix")
    op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
    op2 = linalg.LinearOperatorLowerTriangular(
        [[2., 0.], [1.5, 2.]], name="tril")
    op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
    with self.cached_session():
      op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
      self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense())
      self.assertEqual("my_operator", op.name)

  def test_incompatible_domain_dimensions_raises(self):
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
    with self.assertRaisesRegex(ValueError, "must.*same `domain_dimension`"):
      add_operators([op1, op2])

  def test_incompatible_range_dimensions_raises(self):
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
    with self.assertRaisesRegex(ValueError, "must.*same `range_dimension`"):
      add_operators([op1, op2])

  def test_non_broadcastable_batch_shape_raises(self):
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      add_operators([op1, op2])


class LinearOperatorOrderOfAdditionTest(test.TestCase):
  """Test that the order of addition is done as specified by tiers."""

  def test_tier_0_additions_done_in_tier_0(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    diag3 = linalg.LinearOperatorDiag([1.])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Should not raise since all were added in tier 0, and tier 1 (with the
    # _BadAdder) was never reached.
    op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)

  def test_tier_1_additions_done_by_tier_1(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [linear_operator_addition._AddAndReturnTriL()],
        [_BadAdder()],
    ]
    # Should not raise since all were added by tier 1, and tier 2 (with the
    # _BadAdder) was never reached.
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)

  def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnTriL()],
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Tier 0 could convert to TriL, and this converted everything to TriL,
    # including the Diags.
    # Tier 1 was never used.
    # Tier 2 was never used (therefore, _BadAdder didn't raise).
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)

  @test_util.run_deprecated_v1
  def test_cannot_add_everything_so_return_more_than_one_operator(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([2.])
    tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
    ]
    # Tier 0 (the only tier) can only convert to Diag, so it combines the two
    # diags, but the TriL is unchanged.
    # Result should contain two operators, one Diag, one TriL.
    op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
    self.assertEqual(2, len(op_sum))
    found_diag = False
    found_tril = False
    with self.cached_session():
      for op in op_sum:
        if isinstance(op, linalg.LinearOperatorDiag):
          found_diag = True
          self.assertAllClose([[3.]], op.to_dense())
        if isinstance(op, linalg.LinearOperatorLowerTriangular):
          found_tril = True
          self.assertAllClose([[5.]], op.to_dense())
      self.assertTrue(found_diag and found_tril)

  def test_intermediate_tier_is_not_skipped(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
        [linear_operator_addition._AddAndReturnTriL()],
    ]
    # tril cannot be added in tier 0, and the intermediate tier 1 with the
    # BadAdder will catch it and raise.
    with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
      add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)


class AddAndReturnScaledIdentityTest(test.TestCase):
  """Tests `_AddAndReturnScaledIdentity` on pairs of identity-like operators."""

  def setUp(self):
    # Run the base-class fixture setup before building the adder under test.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnScaledIdentity()

  @test_util.run_deprecated_v1
  def test_identity_plus_identity(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2)
    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      self.assertAllClose(2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_identity_plus_scaled_identity(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      self.assertAllClose(3.2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_scaled_identity_plus_scaled_identity(self):
    id1 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=[2.2, 2.2, 2.2])
    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      self.assertAllClose(1.2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)


class AddAndReturnDiagTest(test.TestCase):
  """Tests `_AddAndReturnDiag` on identity and diagonal operator pairs."""

  def setUp(self):
    # Run the base-class fixture setup before building the adder under test.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnDiag()

  @test_util.run_deprecated_v1
  def test_identity_plus_identity_returns_diag(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2)
    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorDiag)

    with self.cached_session():
      self.assertAllClose(2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_diag_plus_diag(self):
    diag1 = rng.rand(2, 3, 4)
    diag2 = rng.rand(4)
    op1 = linalg.LinearOperatorDiag(diag1)
    op2 = linalg.LinearOperatorDiag(diag2)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(op1, op2))
    operator = self._adder.add(op1, op2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorDiag)

    with self.cached_session():
      self.assertAllClose(
          linalg.LinearOperatorDiag(diag1 + diag2).to_dense(),
          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)


class AddAndReturnTriLTest(test.TestCase):
  """Tests `_AddAndReturnTriL` on a diagonal plus lower-triangular pair."""

  def setUp(self):
    # Run the base-class fixture setup before building the adder under test.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnTriL()

  @test_util.run_deprecated_v1
  def test_diag_plus_tril(self):
    diag = linalg.LinearOperatorDiag([1., 2.])
    tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(diag, diag))
    self.assertTrue(self._adder.can_add(diag, tril))
    operator = self._adder.add(diag, tril, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)

    with self.cached_session():
      self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)


class AddAndReturnMatrixTest(test.TestCase):
  """Tests `_AddAndReturnMatrix` on a pair of diagonal operators."""

  def setUp(self):
    # Run the base-class fixture setup before building the adder under test.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnMatrix()

  @test_util.run_deprecated_v1
  def test_diag_plus_diag(self):
    diag1 = linalg.LinearOperatorDiag([1., 2.])
    diag2 = linalg.LinearOperatorDiag([-1., 3.])
    hints = linear_operator_addition._Hints(
        is_positive_definite=False, is_non_singular=False)

    self.assertTrue(self._adder.can_add(diag1, diag2))
    operator = self._adder.add(diag1, diag2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)

    with self.cached_session():
      self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense())
    self.assertFalse(operator.is_positive_definite)
    self.assertFalse(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)


if __name__ == "__main__":
  test.main()