# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorBlockLowerTriangular`."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular as block_lower_triangular
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test

linalg = linalg_lib
# Seeded generator shared by the tests below so random inputs are reproducible.
rng = np.random.RandomState(0)


def _block_lower_triangular_dense(expected_shape, blocks):
  """Convert a list of blocks into a dense blockwise lower-triangular matrix.

  Each row of `blocks` is concatenated along the last axis and right-padded
  with zeros up to `expected_shape[-2]` columns; the padded rows are then
  stacked along the second-to-last axis.

  Args:
    expected_shape: Sequence of ints; only `expected_shape[-2]` is read here
      and is used as the total number of columns of the dense result.
    blocks: List of rows, where each row is a list of `Tensor` blocks. The
      rows are not modified.

  Returns:
    A `Tensor` formed by concatenating the zero-padded rows along axis `-2`.
  """
  rows = []
  num_cols = 0
  for row_blocks in blocks:

    # All leading dimensions of the first block in this row (everything but
    # its last axis); the zero padding shares them.
    batch_row_shape = array_ops.shape(row_blocks[0])[:-1]

    # Each row adds one diagonal block, so the running column count grows by
    # the width of the row's last block.
    num_cols += array_ops.shape(row_blocks[-1])[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape, [expected_shape[-2] - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        zeros_to_pad_after_shape, dtype=row_blocks[-1].dtype)

    # Build a new list instead of appending to `row_blocks`, so the caller's
    # `blocks` structure is left untouched.
    padded_row = list(row_blocks) + [zeros_to_pad_after]
    rows.append(array_ops.concat(padded_row, axis=-1))

  return array_ops.concat(rows, axis=-2)


@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorBlockLowerTriangularTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restores the TF32 setting saved in `setUp`, then runs base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    # Chain to the parent so its own per-test cleanup still runs.
    super(SquareLinearOperatorBlockLowerTriangularTest, self).tearDown()

  def setUp(self):
    """Disables TF32 execution and loosens the 32-bit tolerances."""
    # Remember the current TF32 setting so `tearDown` can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase from 1e-6 to 1e-5
    self._atol[dtypes.float32] = 1e-5
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._rtol[dtypes.complex64] = 1e-5
    super(SquareLinearOperatorBlockLowerTriangularTest, self).setUp()

  @staticmethod
  def use_blockwise_arg():
    """Returns `True`: this test class opts in to the blockwise argument."""
    return True

  @staticmethod
  def skip_these_tests():
    """Returns the names of base-class tests to skip for this operator."""
    # Skipping since `LinearOperatorBlockLowerTriangular` is in general not
    # self-adjoint.
    return ["cholesky", "eigvalsh"]

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes (and optional block shapes) to test."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((5, 5), blocks=[[(2, 2)], [(3, 2), (3, 3)]]),
        shape_info((3, 7, 7),
                   blocks=[[(1, 2, 2)], [(1, 3, 2), (3, 3, 3)],
                           [(1, 2, 2), (1, 2, 3), (1, 2, 2)]]),
        shape_info((2, 4, 6, 6),
                   blocks=[[(2, 1, 2, 2)], [(1, 4, 2), (4, 4, 4)]]),
    ]

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a block lower-triangular operator and its dense equivalent.

    Args:
      shape_info: Object with a `shape` attribute and, optionally, a `blocks`
        attribute listing the shape of every block row by row. Without
        `blocks`, a single block of shape `shape_info.shape` is used.
      dtype: Dtype used for the randomly generated blocks.
      use_placeholder: If true, the operator is built from
        `placeholder_with_default` tensors created with `shape=None`.
      ensure_self_adjoint_and_pd: If true, every block operator is hinted as
        self-adjoint and positive definite; otherwise those hints are `None`.

    Returns:
      A tuple `(operator, dense)` holding the
      `LinearOperatorBlockLowerTriangular` and the dense `Tensor` assembled
      from the same blocks.
    """

    expected_blocks = (
        shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
        else [[list(shape_info.shape)]])

    matrices = []
    for i, row_shapes in enumerate(expected_blocks):
      row = []
      for j, block_shape in enumerate(row_shapes):
        if i == j:  # operator is on the diagonal
          row.append(
              linear_operator_test_util.random_positive_definite_matrix(
                  block_shape, dtype, force_well_conditioned=True))
        else:
          row.append(
              linear_operator_test_util.random_normal(block_shape, dtype=dtype))
      matrices.append(row)

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [[
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in row] for row in matrices]

    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(  # pylint:disable=g-complex-comprehension
            l,
            is_square=True,
            is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
            is_positive_definite=True if ensure_self_adjoint_and_pd else None)
          for l in row] for row in lin_op_matrices])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(shape_info.shape)
    broadcasted_matrices = linear_operator_util.broadcast_matrix_batch_dims(
        [op for row in matrices for op in row])  # pylint: disable=g-complex-comprehension
    # Regroup the flat list into rows: row `i` holds `i + 1` blocks, so it
    # starts at the triangular-number offset `i * (i + 1) // 2`.
    matrices = [broadcasted_matrices[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
                for i in range(len(matrices))]

    block_lower_triangular_dense = _block_lower_triangular_dense(
        expected_shape, matrices)

    if not use_placeholder:
      block_lower_triangular_dense.set_shape(expected_shape)

    return operator, block_lower_triangular_dense

  def test_is_x_flags(self):
    """Hints passed to the constructor are reported back by the operator."""
    # Matrix with two positive eigenvalues, 1, and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(matrix)]],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_block_lower_triangular_inverse_type(self):
    """`inverse()` returns a block lower-triangular operator of equal layout."""
    matrix = [[1., 0.], [0., 1.]]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)],
         [linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True),
          linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)]],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        block_lower_triangular.LinearOperatorBlockLowerTriangular)
    self.assertEqual(2, len(inverse.operators))
    self.assertEqual(1, len(inverse.operators[0]))
    self.assertEqual(2, len(inverse.operators[1]))

  def test_tape_safe(self):
    """Gradients reach the variables backing the operator's blocks."""
    operator_1 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[1., 0.], [0., 1.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator_2 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[2., 0.], [1., 0.]]))
    operator_3 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[3., 1.], [1., 3.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_2, operator_3]],
        is_self_adjoint=False,
        is_positive_definite=True)

    # These methods are excluded from the generic check and are instead
    # checked below against the diagonal blocks' variables only.
    diagonal_grads_only = ["diag_part", "trace", "determinant",
                           "log_abs_determinant"]
    self.check_tape_safe(operator, skip_options=diagonal_grads_only)

    for y in diagonal_grads_only:
      for diag_block in [operator_1, operator_3]:
        with backprop.GradientTape() as tape:
          grads = tape.gradient(getattr(operator, y)(), diag_block.variables)
        for item in grads:
          self.assertIsNotNone(item)

  def test_convert_variables_to_tensors(self):
    """Runs the shared variables-to-tensors check on a variable-backed op."""
    operator_1 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[1., 0.], [0., 1.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator_2 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[2., 0.], [1., 0.]]))
    operator_3 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[3., 1.], [1., 3.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_2, operator_3]],
        is_self_adjoint=False,
        is_positive_definite=True)
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)

  def test_is_non_singular_auto_set(self):
    """`is_non_singular` is derived from, and checked against, the diagonal."""
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_3 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_2, operator_3]],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    with self.assertRaisesRegex(ValueError, "always non-singular"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(
          [[operator_1], [operator_2, operator_3]], is_non_singular=False)

    operator_4 = linalg.LinearOperatorFullMatrix(
        [[1., 0.], [2., 0.]], is_non_singular=False)

    # A singular operator off of the main diagonal shouldn't raise
    block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_4, operator_2]], is_non_singular=True)

    with self.assertRaisesRegex(ValueError, "always singular"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(
          [[operator_1], [operator_2, operator_4]], is_non_singular=True)

  def test_different_dtypes_raises(self):
    """Blocks with different dtypes are rejected with a `TypeError`."""
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))],
        [linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
         linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))]
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_non_square_operator_raises(self):
    """A non-square operator on the diagonal is rejected."""
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 4), is_square=False)],
        [linalg.LinearOperatorFullMatrix(rng.rand(4, 4)),
         linalg.LinearOperatorFullMatrix(rng.rand(4, 4))]
    ]
    with self.assertRaisesRegex(ValueError, "must be square"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_empty_operators_raises(self):
    """An empty list of operators is rejected."""
    with self.assertRaisesRegex(ValueError, "must be a list of >=1"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular([])

  def test_operators_wrong_length_raises(self):
    """A row with the wrong number of blocks is rejected."""
    with self.assertRaisesRegex(ValueError, "must contain `2` blocks"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular([
          [linalg.LinearOperatorFullMatrix(rng.rand(2, 2))],
          [linalg.LinearOperatorFullMatrix(rng.rand(2, 2))
           for _ in range(3)]])

  def test_operators_mismatched_dimension_raises(self):
    """Blocks whose dimensions do not line up are rejected."""
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 3))],
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 4)),
         linalg.LinearOperatorFullMatrix(rng.rand(3, 3))]
    ]
    with self.assertRaisesRegex(ValueError, "must be the same as"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_incompatible_input_blocks_raises(self):
    """`matmul` with an input that does not fit the blocks raises."""
    matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
    matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 4), shape=None)
    matrix_3 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
    operators = [
        [linalg.LinearOperatorFullMatrix(matrix_1, is_square=True)],
        [linalg.LinearOperatorFullMatrix(matrix_2),
         linalg.LinearOperatorFullMatrix(matrix_3, is_square=True)]
    ]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        operators)
    # Use the module's seeded generator, like the other tests, so the input
    # is reproducible.
    x = rng.rand(2, 4, 5).tolist()
    # The expected message depends on whether the test runs eagerly.
    msg = ("dimension does not match" if context.executing_eagerly()
           else "input structure is ambiguous")
    with self.assertRaisesRegex(ValueError, msg):
      operator.matmul(x)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(
      SquareLinearOperatorBlockLowerTriangularTest)
  test.main()