# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LinearOperatorIdentity and LinearOperatorScaledIdentity."""

import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test


rng = np.random.RandomState(2016)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorIdentityTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restores the TF32 setting saved by setUp, then runs base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def setUp(self):
    """Runs base setup, then disables TF32 execution for this test."""
    super().setUp()
    # Remember the current TF32 setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  @staticmethod
  def dtypes_to_test():
    # TODO(langmore) Test tf.float16 once tf.linalg.solve works in
    # 16bit.
    return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Returns an identity operator and the equivalent dense matrix.

    Args:
      build_info: Object whose `shape` is `batch_shape + [n, n]`.
      dtype: DType of the operator and the dense matrix.
      use_placeholder: Unused here; the identity has no Tensor arguments to
        hide behind a placeholder.
      ensure_self_adjoint_and_pd: Ignored; the identity already satisfies it.

    Returns:
      `(operator, mat)` where `mat` is `eye(n)` broadcast over `batch_shape`.
    """
    # Identity matrix is already Hermitian Positive Definite.
    del ensure_self_adjoint_and_pd

    shape = list(build_info.shape)
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    operator = linalg_lib.LinearOperatorIdentity(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)

    return operator, mat

  def test_assert_positive_definite(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_assert_non_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_float16_matmul(self):
    # float16 cannot be tested by base test class because tf.linalg.solve does
    # not work with float16.
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(
          num_rows=2, dtype=dtypes.float16)
      x = rng.randn(2, 3).astype(np.float16)
      y = operator.matmul(x)
      self.assertAllClose(x, self.evaluate(y))

  def test_non_scalar_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorIdentity(num_rows=[2])

  def test_non_integer_num_rows_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorIdentity(num_rows=2.)

  def test_negative_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorIdentity(num_rows=-2)

  def test_non_1d_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 1-D"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=2)

  def test_non_integer_batch_shape_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[2.])

  def test_negative_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])

  def test_non_scalar_num_rows_raises_dynamic(self):
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default([2], shape=None)

      with self.assertRaisesError("must be a 0-D Tensor"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_num_rows_raises_dynamic(self):
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default(-2, shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_non_1d_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default(2, shape=None)
      with self.assertRaisesError("must be a 1-D"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_wrong_matrix_dimensions_raises_static(self):
    operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(
        rng.rand(3, 3).astype(np.float32), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.matmul(x))

  def test_default_batch_shape_broadcasts_with_everything_static(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)

      operator_matmul = operator.matmul(x)
      expected = x

      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      x = array_ops.placeholder_with_default(rng.randn(1, 2, 3, 4), shape=None)
      operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)

      operator_matmul = operator.matmul(x)
      expected = x

      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_broadcast_matmul_static_shapes(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      operator = linalg_lib.LinearOperatorIdentity(
          num_rows=3, batch_shape=(2, 1), dtype=x.dtype)

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Expected result of matmul: x broadcast to the joint batch shape.
      expected = x + zeros

      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_broadcast_matmul_dynamic_shapes(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = array_ops.placeholder_with_default(rng.rand(1, 2, 3, 4), shape=None)
      num_rows = array_ops.placeholder_with_default(3, shape=None)
      batch_shape = array_ops.placeholder_with_default((2, 1), shape=None)

      operator = linalg_lib.LinearOperatorIdentity(
          num_rows, batch_shape=batch_shape, dtype=dtypes.float64)

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Expected result of matmul: x broadcast to the joint batch shape.
      expected = x + zeros

      operator_matmul = operator.matmul(x)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_is_x_flags(self):
    # The is_x flags are by default all True.
    operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)

    # Passing a flag that is not True (here, None) raises, because the
    # identity is always self-adjoint, non-singular, etc.
    with self.assertRaisesRegex(ValueError, "is always non-singular"):
      linalg_lib.LinearOperatorIdentity(
          num_rows=2,
          is_non_singular=None,
      )

  def test_identity_adjoint_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2, is_non_singular=True)
    self.assertIsInstance(
        operator.adjoint(), linalg_lib.LinearOperatorIdentity)

  def test_identity_cholesky_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(
        operator.cholesky(), linalg_lib.LinearOperatorIdentity)

  def test_identity_inverse_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2, is_non_singular=True)
    self.assertIsInstance(
        operator.inverse(), linalg_lib.LinearOperatorIdentity)

  def test_ref_type_shape_args_raises(self):
    with self.assertRaisesRegex(TypeError, "num_rows.*reference"):
      linalg_lib.LinearOperatorIdentity(num_rows=variables_module.Variable(2))

    with self.assertRaisesRegex(TypeError, "batch_shape.*reference"):
      linalg_lib.LinearOperatorIdentity(
          num_rows=2, batch_shape=variables_module.Variable([3]))


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorScaledIdentityTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restores the TF32 setting saved by setUp, then runs base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def setUp(self):
    """Runs base setup, then disables TF32 execution for this test."""
    super().setUp()
    # Remember the current TF32 setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  @staticmethod
  def dtypes_to_test():
    # TODO(langmore) Test tf.float16 once tf.linalg.solve works in
    # 16bit.
    return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Returns a scaled-identity operator and the equivalent dense matrix.

    Args:
      build_info: Object whose `shape` is `batch_shape + [n, n]`.
      dtype: DType of the multiplier, operator and dense matrix.
      use_placeholder: If True, the multiplier is passed to the operator
        through `placeholder_with_default(..., shape=None)`, hiding its
        static shape.
      ensure_self_adjoint_and_pd: If True, the multiplier is replaced by its
        absolute value and the operator is flagged self-adjoint and PD.

    Returns:
      `(operator, matrix)` where `matrix[..., :, :]` is
      `multiplier[...] * eye(n)`.
    """
    shape = list(build_info.shape)
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    # Uniform values that are at least length 1 from the origin.  Allows the
    # operator to be well conditioned.
    # Shape batch_shape
    multiplier = linear_operator_test_util.random_sign_uniform(
        shape=batch_shape, minval=1., maxval=2., dtype=dtype)

    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      multiplier = math_ops.cast(math_ops.abs(multiplier), dtype=dtype)

    # The multiplier is the operator's only Tensor argument; optionally hide
    # its static shape so the dynamic-shape code paths are exercised.
    lin_op_multiplier = multiplier

    if use_placeholder:
      lin_op_multiplier = array_ops.placeholder_with_default(
          multiplier, shape=None)

    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows,
        lin_op_multiplier,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    # Shape batch_shape + [1, 1], so it broadcasts against eye(num_rows).
    multiplier_matrix = array_ops.expand_dims(
        array_ops.expand_dims(multiplier, -1), -1)
    matrix = multiplier_matrix * linalg_ops.eye(
        num_rows, batch_shape=batch_shape, dtype=dtype)

    return operator, matrix

  def test_assert_positive_definite_does_not_raise_when_positive(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=1.)
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_assert_positive_definite_raises_when_negative(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=-1.)
      with self.assertRaisesOpError("not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_non_singular_does_not_raise_when_non_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1., 2., 3.])
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_non_singular_raises_when_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1., 2., 0.])
      with self.assertRaisesOpError("was singular"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1. + 0J])
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1. + 1J])
      with self.assertRaisesOpError("not self-adjoint"):
        self.evaluate(operator.assert_self_adjoint())

  def test_float16_matmul(self):
    # float16 cannot be tested by base test class because tf.linalg.solve does
    # not work with float16.
    with self.cached_session():
      multiplier = rng.rand(3).astype(np.float16)
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=multiplier)
      x = rng.randn(2, 3).astype(np.float16)
      y = operator.matmul(x)
      self.assertAllClose(multiplier[..., None, None] * x, self.evaluate(y))

  def test_non_scalar_num_rows_raises_static(self):
    # Many "test_...num_rows" tests are performed in LinearOperatorIdentity.
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorScaledIdentity(
          num_rows=[2], multiplier=123.)

  def test_wrong_matrix_dimensions_raises_static(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=2.2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(
        rng.rand(3, 3).astype(np.float32), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorScaledIdentity(
            num_rows,
            multiplier=[1., 2],
            assert_proper_shapes=True)
        self.evaluate(operator.matmul(x))

  def test_broadcast_matmul_and_solve(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = random_ops.random_normal(shape=(1, 2, 3, 4))

      # operator is 2.2 * identity (with a batch shape).
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=3, multiplier=2.2 * array_ops.ones((2, 1)))

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Test matmul
      expected = x * 2.2 + zeros
      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

      # Test solve
      expected = x / 2.2 + zeros
      operator_solve = operator.solve(x)
      self.assertAllEqual(operator_solve.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_solve, expected]))

  def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorScaledIdentity shape of (3, 3), the
      # broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same
      # shape as x.
      x = random_ops.random_normal(shape=(1, 2, 3, 4))

      # operator is 2.2 * identity (no batch shape).
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=3, multiplier=2.2)

      # Test matmul
      expected = x * 2.2
      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

      # Test solve
      expected = x / 2.2
      operator_solve = operator.solve(x)
      self.assertAllEqual(operator_solve.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_solve, expected]))

  def test_is_x_flags(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=1.,
        is_positive_definite=False, is_non_singular=True)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)  # Auto-set due to real multiplier

  def test_identity_matmul(self):
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    self.assertIsInstance(
        operator1.matmul(operator1),
        linalg_lib.LinearOperatorIdentity)

    self.assertIsInstance(
        operator2.matmul(operator2),
        linalg_lib.LinearOperatorScaledIdentity)

    operator_matmul = operator1.matmul(operator2)
    self.assertIsInstance(
        operator_matmul,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))

    operator_matmul = operator2.matmul(operator1)
    self.assertIsInstance(
        operator_matmul,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))

  def test_identity_solve(self):
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    self.assertIsInstance(
        operator1.solve(operator1),
        linalg_lib.LinearOperatorIdentity)

    self.assertIsInstance(
        operator2.solve(operator2),
        linalg_lib.LinearOperatorScaledIdentity)

    # I^{-1} (3 I) = 3 I.
    operator_solve = operator1.solve(operator2)
    self.assertIsInstance(
        operator_solve,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_solve.multiplier))

    # (3 I)^{-1} I = (1/3) I.
    operator_solve = operator2.solve(operator1)
    self.assertIsInstance(
        operator_solve,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(1. / 3., self.evaluate(operator_solve.multiplier))

  def test_scaled_identity_cholesky_type(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2,
        multiplier=3.,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(
        operator.cholesky(),
        linalg_lib.LinearOperatorScaledIdentity)

  def test_scaled_identity_inverse_type(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2,
        multiplier=3.,
        is_non_singular=True,
    )
    self.assertIsInstance(
        operator.inverse(),
        linalg_lib.LinearOperatorScaledIdentity)

  def test_ref_type_shape_args_raises(self):
    with self.assertRaisesRegex(TypeError, "num_rows.*reference"):
      linalg_lib.LinearOperatorScaledIdentity(
          num_rows=variables_module.Variable(2), multiplier=1.23)

  def test_tape_safe(self):
    multiplier = variables_module.Variable(1.23)
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=multiplier)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    multiplier = variables_module.Variable(1.23)
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=multiplier)
    with self.cached_session() as sess:
      sess.run([multiplier.initializer])
      self.check_convert_variables_to_tensors(operator)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorIdentityTest)
  linear_operator_test_util.add_tests(LinearOperatorScaledIdentityTest)
  test.main()