# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorBlockDiag`."""

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_block_diag as block_diag
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test

linalg = linalg_lib
rng = np.random.RandomState(0)


def _block_diag_dense(expected_shape, blocks):
  """Convert a list of blocks into a dense block-diagonal matrix.

  Each block is left-padded with zeros for the columns taken by the preceding
  blocks and right-padded with zeros up to `expected_shape[-1]` columns. The
  padded row-strips are then concatenated along the row axis.

  Args:
    expected_shape: Shape (as a list) of the full dense matrix. Only the last
      entry, the total number of columns, is read.
    blocks: List of matrix `Tensor`s whose leading (batch) dimensions already
      agree. Both callers in this file broadcast them with
      `linear_operator_util.broadcast_matrix_batch_dims` first.

  Returns:
    A `Tensor` holding the dense block-diagonal matrix.
  """
  rows = []
  num_cols = 0
  for block in blocks:
    # Shape of the block without its last (column) dimension, i.e. the batch
    # dimensions plus the row dimension.
    batch_row_shape = array_ops.shape(block)[:-1]

    zeros_to_pad_before_shape = array_ops.concat(
        [batch_row_shape, [num_cols]], axis=-1)
    zeros_to_pad_before = array_ops.zeros(
        shape=zeros_to_pad_before_shape, dtype=block.dtype)
    num_cols += array_ops.shape(block)[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape, [expected_shape[-1] - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        zeros_to_pad_after_shape, dtype=block.dtype)

    rows.append(array_ops.concat(
        [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

  return array_ops.concat(rows, axis=-2)


@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorBlockDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restores the TF32 setting saved in `setUp`, then runs base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def setUp(self):
    """Disables TF32 and loosens float32/complex64 tolerances for each test."""
    # Run the base test-case setup first. It was previously skipped for this
    # class, unlike the non-square variant below.
    super().setUp()
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase from 1e-6 to 1e-4. Copy the dicts before writing:
    # NOTE(review): `_atol`/`_rtol` are defined by the base class (presumably
    # as class-level dicts), so writing into them in place would leak the
    # looser tolerances into other test classes.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  @staticmethod
  def operator_shapes_infos():
    """Shapes to test; `blocks` (when given) lists the per-block shapes."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((5, 5), blocks=[(2, 2), (3, 3)]),
        shape_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]),
        shape_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
    ]

  @staticmethod
  def use_blockwise_arg():
    return True

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a block-diagonal operator and its dense reference matrix.

    Every block is a random, well-conditioned positive-definite matrix. When
    `ensure_self_adjoint_and_pd` is True, each sub-operator is additionally
    hinted as self-adjoint and positive-definite.
    """
    expected_shape = list(shape_info.shape)
    # `blocks` is optional metadata on the shape info. Without it the operator
    # has a single block spanning the full shape.
    expected_blocks = getattr(shape_info, "blocks", [list(shape_info.shape)])
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_blocks
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      # `shape=None` hides the static shape from the operator.
      lin_op_matrices = [
          array_ops.placeholder_with_default(matrix, shape=None)
          for matrix in matrices]

    hint = True if ensure_self_adjoint_and_pd else None
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            mat,
            is_square=True,
            is_self_adjoint=hint,
            is_positive_definite=hint)
         for mat in lin_op_matrices])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the batch shapes so the blocks can be padded and stacked.
    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(
          expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

    return operator, block_diag_dense

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues, 1, and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(matrix)],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_is_x_parameters(self):
    matrix = [[1., 0.], [1., 1.]]
    sub_operator = linalg.LinearOperatorFullMatrix(matrix)
    operator = block_diag.LinearOperatorBlockDiag(
        [sub_operator],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertEqual(
        operator.parameters,
        {
            "name": None,
            "is_square": True,
            "is_positive_definite": True,
            "is_self_adjoint": False,
            "is_non_singular": True,
            "operators": [sub_operator],
        })
    self.assertEqual(
        sub_operator.parameters,
        {
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": None,
            "matrix": matrix,
            "name": "LinearOperatorFullMatrix",
        })

  def test_block_diag_adjoint_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
        ],
        is_non_singular=True,
    )
    adjoint = operator.adjoint()
    self.assertIsInstance(
        adjoint,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(adjoint.operators))

  def test_block_diag_cholesky_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
        ],
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    cholesky_factor = operator.cholesky()
    self.assertIsInstance(
        cholesky_factor,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(cholesky_factor.operators))
    self.assertIsInstance(
        cholesky_factor.operators[0],
        lower_triangular.LinearOperatorLowerTriangular)
    self.assertIsInstance(
        cholesky_factor.operators[1],
        lower_triangular.LinearOperatorLowerTriangular
    )

  def test_block_diag_inverse_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
        ],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(inverse.operators))

  def test_block_diag_matmul_type(self):
    matrices1 = []
    matrices2 = []
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [2, i], dtype=dtypes.float32)))

      matrices2.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [i, 3], dtype=dtypes.float32)))

    operator1 = block_diag.LinearOperatorBlockDiag(matrices1, is_square=False)
    operator2 = block_diag.LinearOperatorBlockDiag(matrices2, is_square=False)

    expected_matrix = math_ops.matmul(
        operator1.to_dense(), operator2.to_dense())
    actual_operator = operator1.matmul(operator2)

    self.assertIsInstance(
        actual_operator, block_diag.LinearOperatorBlockDiag)
    actual_, expected_ = self.evaluate([
        actual_operator.to_dense(), expected_matrix])
    self.assertAllClose(actual_, expected_)

  def test_block_diag_matmul_raises(self):
    matrices1 = []
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [2, i], dtype=dtypes.float32)))
    operator1 = block_diag.LinearOperatorBlockDiag(matrices1, is_square=False)
    operator2 = linalg.LinearOperatorFullMatrix(
        linear_operator_test_util.random_normal(
            [15, 3], dtype=dtypes.float32))

    with self.assertRaisesRegex(ValueError, "Operators are incompatible"):
      operator1.matmul(operator2)

  def test_block_diag_solve_type(self):
    matrices1 = []
    matrices2 = []
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_tril_matrix(
              [i, i],
              dtype=dtypes.float32,
              force_well_conditioned=True)))

      matrices2.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [i, 3], dtype=dtypes.float32)))

    operator1 = block_diag.LinearOperatorBlockDiag(matrices1)
    operator2 = block_diag.LinearOperatorBlockDiag(matrices2, is_square=False)

    expected_matrix = linalg.solve(
        operator1.to_dense(), operator2.to_dense())
    actual_operator = operator1.solve(operator2)

    self.assertIsInstance(
        actual_operator, block_diag.LinearOperatorBlockDiag)
    actual_, expected_ = self.evaluate([
        actual_operator.to_dense(), expected_matrix])
    self.assertAllClose(actual_, expected_)

  def test_block_diag_solve_raises(self):
    matrices1 = []
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [i, i], dtype=dtypes.float32)))
    operator1 = block_diag.LinearOperatorBlockDiag(matrices1)
    operator2 = linalg.LinearOperatorFullMatrix(
        linear_operator_test_util.random_normal(
            [15, 3], dtype=dtypes.float32))

    with self.assertRaisesRegex(ValueError, "Operators are incompatible"):
      operator1.solve(operator2)

  def test_tape_safe(self):
    matrices = []
    for _ in range(4):
      matrices.append(variables_module.Variable(
          linear_operator_test_util.random_positive_definite_matrix(
              [2, 2], dtype=dtypes.float32, force_well_conditioned=True)))

    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            matrix, is_self_adjoint=True,
            is_positive_definite=True) for matrix in matrices],
        is_self_adjoint=True,
        is_positive_definite=True,
    )
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    matrices = []
    for _ in range(3):
      matrices.append(variables_module.Variable(
          linear_operator_test_util.random_positive_definite_matrix(
              [3, 3], dtype=dtypes.float32, force_well_conditioned=True)))

    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            matrix, is_self_adjoint=True,
            is_positive_definite=True) for matrix in matrices],
        is_self_adjoint=True,
        is_positive_definite=True,
    )
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)

  def test_is_non_singular_auto_set(self):
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = block_diag.LinearOperatorBlockDiag(
        [operator_1, operator_2],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    with self.assertRaisesRegex(ValueError, "always non-singular"):
      block_diag.LinearOperatorBlockDiag(
          [operator_1, operator_2], is_non_singular=False)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")

    operator = block_diag.LinearOperatorBlockDiag([operator_1, operator_2])

    self.assertEqual("left_ds_right", operator.name)

  def test_different_dtypes_raises(self):
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      block_diag.LinearOperatorBlockDiag(operators)

  def test_empty_operators_raises(self):
    with self.assertRaisesRegex(ValueError, "non-empty"):
      block_diag.LinearOperatorBlockDiag([])

  def test_incompatible_input_blocks_raises(self):
    matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
    matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
    operators = [
        linalg.LinearOperatorFullMatrix(matrix_1, is_square=True),
        linalg.LinearOperatorFullMatrix(matrix_2, is_square=True)
    ]
    operator = block_diag.LinearOperatorBlockDiag(operators)
    x = np.random.rand(2, 4, 5).tolist()
    msg = ("dimension does not match" if context.executing_eagerly()
           else "input structure is ambiguous")
    with self.assertRaisesRegex(ValueError, msg):
      operator.matmul(x)


@test_util.run_all_in_graph_and_eager_modes
class NonSquareLinearOperatorBlockDiagTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restores the TF32 setting saved in `setUp`, then runs base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def setUp(self):
    """Disables TF32 and loosens float32/complex64 tolerances for each test."""
    super().setUp()
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase from 1e-6 to 1e-4. Copy the dicts before writing:
    # NOTE(review): `_atol`/`_rtol` are defined by the base class (presumably
    # as class-level dicts), so writing into them in place would leak the
    # looser tolerances into other test classes.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4

  @staticmethod
  def operator_shapes_infos():
    """Shapes to test; `blocks` (when given) lists the per-block shapes."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 0)),
        shape_info((1, 2, 3)),
        shape_info((5, 3), blocks=[(2, 1), (3, 2)]),
        shape_info((3, 6, 5), blocks=[(1, 2, 1), (3, 1, 2), (1, 3, 2)]),
        shape_info((2, 1, 5, 2), blocks=[(2, 1, 2, 1), (1, 3, 1)]),
    ]

  @staticmethod
  def skip_these_tests():
    """Base-class tests that only make sense for square operators."""
    return [
        "cholesky",
        "cond",
        "det",
        "diag_part",
        "eigvalsh",
        "inverse",
        "log_abs_det",
        "solve",
        "solve_with_broadcast",
        "trace"]

  @staticmethod
  def use_blockwise_arg():
    return True

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a non-square block-diagonal operator and its dense matrix.

    Every block is a random normal matrix. `ensure_self_adjoint_and_pd` is
    accepted for interface compatibility and ignored.
    """
    del ensure_self_adjoint_and_pd
    expected_shape = list(shape_info.shape)
    # `blocks` is optional metadata on the shape info. Without it the operator
    # has a single block spanning the full shape.
    expected_blocks = getattr(shape_info, "blocks", [list(shape_info.shape)])
    matrices = [
        linear_operator_test_util.random_normal(block_shape, dtype=dtype)
        for block_shape in expected_blocks
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      # `shape=None` hides the static shape from the operator.
      lin_op_matrices = [
          array_ops.placeholder_with_default(matrix, shape=None)
          for matrix in matrices]

    blocks = [
        linalg.LinearOperatorFullMatrix(
            mat,
            is_square=False,
            is_self_adjoint=False,
            is_positive_definite=False)
        for mat in lin_op_matrices]
    operator = block_diag.LinearOperatorBlockDiag(blocks)

    # Broadcast the batch shapes so the blocks can be padded and stacked.
    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(expected_shape)

    return operator, block_diag_dense


if __name__ == "__main__":
  linear_operator_test_util.add_tests(SquareLinearOperatorBlockDiagTest)
  linear_operator_test_util.add_tests(NonSquareLinearOperatorBlockDiagTest)
  test.main()