# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorKronecker`."""

import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test

linalg = linalg_lib
# Fixed-seed RNG used to build test inputs (e.g. in test_different_dtypes_raises).
rng = np.random.RandomState(0)


def _kronecker_dense(factors):
  """Convert a list of factors into a dense Kronecker product.

  Factors are combined left to right. At each step the running product of
  shape [..., m, n] is expanded to [..., m, 1, n, 1], and the next factor of
  shape [..., p, q] is expanded to [..., 1, p, 1, q]. Their broadcast product
  has shape [..., m, p, n, q], which is reshaped to [..., m * p, n * q].

  Args:
    factors: Non-empty list of matrix-like tensors. Leading (batch) dimensions
      must broadcast against each other under elementwise multiplication.

  Returns:
    Dense tensor holding the Kronecker product of `factors`.
  """
  product = factors[0]
  for factor in factors[1:]:
    product = product[..., array_ops.newaxis, :, array_ops.newaxis]
    factor_to_mul = factor[..., array_ops.newaxis, :, array_ops.newaxis, :]
    product *= factor_to_mul
    # Merge the (m, p) and (n, q) axis pairs into the Kronecker row/column axes.
    product = array_ops.reshape(
        product,
        shape=array_ops.concat(
            [array_ops.shape(product)[:-4],
             [array_ops.shape(product)[-4] * array_ops.shape(product)[-3],
              array_ops.shape(product)[-2] * array_ops.shape(product)[-1]]
            ], axis=0))

  return product


class KroneckerDenseTest(test.TestCase):
  """Test of `_kronecker_dense` function."""

  def test_kronecker_dense_matrix(self):
    x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32)
    y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32)
    # From explicitly writing out the kronecker product of x and y.
    z = ops.convert_to_tensor([
        [2., 4., 3., 6.],
        [10., -2., 15., -3.],
        [1., 2., 2., 4.],
        [5., -1., 10., -2.]], dtype=dtypes.float32)
    # From explicitly writing out the kronecker product of y and x.
    w = ops.convert_to_tensor([
        [2., 3., 4., 6.],
        [1., 2., 2., 4.],
        [10., 15., -2., -3.],
        [5., 10., -1., -2.]], dtype=dtypes.float32)

    self.assertAllClose(
        self.evaluate(_kronecker_dense([x, y])), self.evaluate(z))
    self.assertAllClose(
        self.evaluate(_kronecker_dense([y, x])), self.evaluate(w))


@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorKroneckerTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    """Restore the TensorFloat-32 setting that setUp saved."""
    # NOTE(review): the base-class tearDown is not invoked here; confirm
    # that is intended.
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    """Disable TensorFloat-32 and loosen float32/complex64 tolerances."""
    # NOTE(review): the base-class setUp is not invoked here; confirm that is
    # intended.
    # Save the current TensorFloat-32 setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # `_atol`/`_rtol` are inherited (never assigned in this class). Copy them
    # before writing so the overrides below live on this instance rather than
    # mutating a dict that may be shared with the base class and its other
    # subclasses.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    # Increase from 1e-6 to 1e-4
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 1), factors=[(1, 1), (1, 1)]),
        shape_info((8, 8), factors=[(2, 2), (2, 2), (2, 2)]),
        shape_info((12, 12), factors=[(2, 2), (3, 3), (2, 2)]),
        shape_info((1, 3, 3), factors=[(1, 1), (1, 3, 3)]),
        shape_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]),
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    # Kronecker products constructed below will be from symmetric
    # positive-definite matrices.
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)
    expected_factors = build_info.__dict__["factors"]
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_factors
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(m, shape=None) for m in matrices]

    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(
            l,
            is_square=True,
            is_self_adjoint=True,
            is_positive_definite=True)
         for l in lin_op_matrices])

    # Give all factors common batch dimensions before densifying.
    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    kronecker_dense = _kronecker_dense(matrices)

    if not use_placeholder:
      kronecker_dense.set_shape(shape)

    return operator, kronecker_dense

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues, 1 and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(matrix),
         linalg.LinearOperatorFullMatrix(matrix)],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_is_non_singular_auto_set(self):
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = kronecker.LinearOperatorKronecker(
        [operator_1, operator_2],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    with self.assertRaisesRegex(ValueError, "always non-singular"):
      kronecker.LinearOperatorKronecker(
          [operator_1, operator_2], is_non_singular=False)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")

    operator = kronecker.LinearOperatorKronecker([operator_1, operator_2])

    self.assertEqual("left_x_right", operator.name)

  def test_different_dtypes_raises(self):
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      kronecker.LinearOperatorKronecker(operators)

  def test_empty_or_one_operators_raises(self):
    # Only the empty list is exercised here. The ">=1 operators" message
    # suggests a single operator is accepted; confirm before adding that case.
    with self.assertRaisesRegex(ValueError, ">=1 operators"):
      kronecker.LinearOperatorKronecker([])

  def test_kronecker_adjoint_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    adjoint = operator.adjoint()
    self.assertIsInstance(
        adjoint,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(adjoint.operators))

  def test_kronecker_cholesky_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
        ],
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    cholesky_factor = operator.cholesky()
    self.assertIsInstance(
        cholesky_factor,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(cholesky_factor.operators))
    self.assertIsInstance(
        cholesky_factor.operators[0],
        lower_triangular.LinearOperatorLowerTriangular)
    self.assertIsInstance(
        cholesky_factor.operators[1],
        lower_triangular.LinearOperatorLowerTriangular)

  def test_kronecker_inverse_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(inverse.operators))

  def test_tape_safe(self):
    matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
    matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix_1, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix_2, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
    matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix_1, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix_2, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    with self.cached_session() as sess:
      # Initialize the variables before converting them to tensors.
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(SquareLinearOperatorKroneckerTest)
  test.main()