# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `linalg.LinearOperator` base class and its op dispatch."""

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.platform import test

linalg = linalg_lib
# Seeded generator shared by all tests so random inputs are reproducible.
rng = np.random.RandomState(123)


class LinearOperatorShape(linalg.LinearOperator):
  """LinearOperator that implements only the methods `_shape` and `_shape_tensor`.

  Used to check that the base class derives every other shape-related
  property/method from these two. `_matmul` deliberately raises.
  """

  def __init__(self,
               shape,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None):
    """Creates the operator.

    Args:
      shape: Shape to report, e.g. a tuple of ints. Returned by `_shape` as a
        `TensorShape` and by `_shape_tensor` as an int32 constant.
      is_non_singular: Hint forwarded to `LinearOperator`.
      is_self_adjoint: Hint forwarded to `LinearOperator`.
      is_positive_definite: Hint forwarded to `LinearOperator`.
      is_square: Hint forwarded to `LinearOperator`.
    """
    parameters = dict(
        shape=shape,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square
    )

    # Stored before the base constructor runs, in case it queries the shape.
    self._stored_shape = shape
    super(LinearOperatorShape, self).__init__(
        dtype=dtypes.float32,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters)

  def _shape(self):
    return tensor_shape.TensorShape(self._stored_shape)

  def _shape_tensor(self):
    return constant_op.constant(self._stored_shape, dtype=dtypes.int32)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # The signature matches LinearOperatorMatmulSolve._matmul below, so a call
    # with arguments reaches this NotImplementedError rather than a TypeError.
    del x, adjoint, adjoint_arg  # Unused.
    raise NotImplementedError("Not needed for this test.")


class LinearOperatorMatmulSolve(linalg.LinearOperator):
  """LinearOperator that wraps a [batch] matrix and implements matmul/solve."""

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None):
    """Creates the operator.

    Args:
      matrix: Tensor-like value. It is converted with `ops.convert_to_tensor`,
        and the result determines the operator's dtype and shape.
      is_non_singular: Hint forwarded to `LinearOperator`.
      is_self_adjoint: Hint forwarded to `LinearOperator`.
      is_positive_definite: Hint forwarded to `LinearOperator`.
      is_square: Hint forwarded to `LinearOperator`.
    """
    parameters = dict(
        matrix=matrix,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square
    )

    self._matrix = ops.convert_to_tensor(matrix, name="matrix")
    super(LinearOperatorMatmulSolve, self).__init__(
        dtype=self._matrix.dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters)

  def _shape(self):
    return self._matrix.shape

  def _shape_tensor(self):
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    x = ops.convert_to_tensor(x, name="x")
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # matrix_solve has no option for adjointing the right-hand side.
    assert not adjoint_arg, "Not implemented for this test class."
    return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTest(test.TestCase):
  """Exercises generic LinearOperator behavior via the toy operators above."""

  def test_all_shape_properties_defined_by_the_one_property_shape(self):

    shape = (1, 2, 3, 4)
    operator = LinearOperatorShape(shape)

    self.assertAllEqual(shape, operator.shape)
    self.assertAllEqual(4, operator.tensor_rank)
    self.assertAllEqual((1, 2), operator.batch_shape)
    self.assertAllEqual(4, operator.domain_dimension)
    self.assertAllEqual(3, operator.range_dimension)
    expected_parameters = {
        "is_non_singular": None,
        "is_positive_definite": None,
        "is_self_adjoint": None,
        "is_square": None,
        "shape": (1, 2, 3, 4),
    }
    self.assertEqual(expected_parameters, operator.parameters)

  def test_all_shape_methods_defined_by_the_one_method_shape(self):
    with self.cached_session():
      shape = (1, 2, 3, 4)
      operator = LinearOperatorShape(shape)

      self.assertAllEqual(shape, self.evaluate(operator.shape_tensor()))
      self.assertAllEqual(4, self.evaluate(operator.tensor_rank_tensor()))
      self.assertAllEqual((1, 2), self.evaluate(operator.batch_shape_tensor()))
      self.assertAllEqual(4, self.evaluate(operator.domain_dimension_tensor()))
      self.assertAllEqual(3, self.evaluate(operator.range_dimension_tensor()))

  def test_is_x_properties(self):
    operator = LinearOperatorShape(
        shape=(2, 2),
        is_non_singular=False,
        is_self_adjoint=True,
        is_positive_definite=False)
    self.assertFalse(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)
    self.assertFalse(operator.is_positive_definite)

  def test_nontrivial_parameters(self):
    matrix = rng.randn(2, 3, 4)
    matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
    operator = LinearOperatorMatmulSolve(matrix_ph)
    expected_parameters = {
        "is_non_singular": None,
        "is_positive_definite": None,
        "is_self_adjoint": None,
        "is_square": None,
        "matrix": matrix_ph,
    }
    self.assertEqual(expected_parameters, operator.parameters)

  def test_generic_to_dense_method_non_square_matrix_static(self):
    matrix = rng.randn(2, 3, 4)
    operator = LinearOperatorMatmulSolve(matrix)
    with self.cached_session():
      operator_dense = operator.to_dense()
      self.assertAllEqual((2, 3, 4), operator_dense.shape)
      self.assertAllClose(matrix, self.evaluate(operator_dense))

  def test_generic_to_dense_method_non_square_matrix_tensor(self):
    matrix = rng.randn(2, 3, 4)
    matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
    operator = LinearOperatorMatmulSolve(matrix_ph)
    operator_dense = operator.to_dense()
    self.assertAllClose(matrix, self.evaluate(operator_dense))

  def test_matvec(self):
    matrix = [[1., 0], [0., 2.]]
    operator = LinearOperatorMatmulSolve(matrix)
    x = [1., 1.]
    with self.cached_session():
      y = operator.matvec(x)
      self.assertAllEqual((2,), y.shape)
      self.assertAllClose([1., 2.], self.evaluate(y))

  def test_solvevec(self):
    matrix = [[1., 0], [0., 2.]]
    operator = LinearOperatorMatmulSolve(matrix)
    y = [1., 1.]
    with self.cached_session():
      x = operator.solvevec(y)
      self.assertAllEqual((2,), x.shape)
      self.assertAllClose([1., 1 / 2.], self.evaluate(x))

  def test_is_square_set_to_true_for_square_static_shapes(self):
    operator = LinearOperatorShape(shape=(2, 4, 4))
    self.assertTrue(operator.is_square)

  def test_is_square_set_to_false_for_non_square_static_shapes(self):
    # Inner dimensions 3 x 4 are not square.
    operator = LinearOperatorShape(shape=(2, 3, 4))
    self.assertFalse(operator.is_square)

  def test_is_square_set_incorrectly_to_false_raises(self):
    with self.assertRaisesRegex(ValueError, "but.*was square"):
      _ = LinearOperatorShape(shape=(2, 4, 4), is_square=False).is_square

  def test_is_square_set_inconsistent_with_other_hints_raises(self):
    with self.assertRaisesRegex(ValueError, "is always square"):
      matrix = array_ops.placeholder_with_default(input=(), shape=None)
      LinearOperatorMatmulSolve(matrix, is_non_singular=True, is_square=False)

    with self.assertRaisesRegex(ValueError, "is always square"):
      matrix = array_ops.placeholder_with_default(input=(), shape=None)
      LinearOperatorMatmulSolve(
          matrix, is_positive_definite=True, is_square=False)

  def test_non_square_operators_raise_on_determinant_and_solve(self):
    operator = LinearOperatorShape((2, 3))
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.determinant()
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.log_abs_determinant()
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.solve(rng.rand(2, 2))

  def test_is_square_manual_set_works(self):
    matrix = array_ops.placeholder_with_default(
        input=np.ones((2, 2)), shape=None)
    operator = LinearOperatorMatmulSolve(matrix)
    if not context.executing_eagerly():
      # Eager mode will read in the default value, and discover the answer is
      # True.  Graph mode must rely on the hint, since the placeholder has
      # shape=None...the hint is, by default, None.
      self.assertEqual(None, operator.is_square)

    # Set to True
    operator = LinearOperatorMatmulSolve(matrix, is_square=True)
    self.assertTrue(operator.is_square)

  def test_linear_operator_matmul_hints_closed(self):
    matrix = array_ops.placeholder_with_default(input=np.ones((2, 2)),
                                                shape=None)
    operator1 = LinearOperatorMatmulSolve(matrix)

    operator_matmul = operator1.matmul(operator1)

    if not context.executing_eagerly():
      # Eager mode will read in the input and discover matrix is square.
      self.assertEqual(None, operator_matmul.is_square)
    self.assertEqual(None, operator_matmul.is_non_singular)
    self.assertEqual(None, operator_matmul.is_self_adjoint)
    self.assertEqual(None, operator_matmul.is_positive_definite)

    operator2 = LinearOperatorMatmulSolve(
        matrix,
        is_non_singular=True,
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True,
    )

    operator_matmul = operator2.matmul(operator2)

    self.assertTrue(operator_matmul.is_square)
    self.assertTrue(operator_matmul.is_non_singular)
    self.assertEqual(None, operator_matmul.is_self_adjoint)
    self.assertEqual(None, operator_matmul.is_positive_definite)

  def test_linear_operator_matmul_hints_false(self):
    matrix1 = array_ops.placeholder_with_default(
        input=rng.rand(2, 2), shape=None)
    operator1 = LinearOperatorMatmulSolve(
        matrix1,
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False,
        is_square=True,
    )

    operator_matmul = operator1.matmul(operator1)

    self.assertTrue(operator_matmul.is_square)
    self.assertFalse(operator_matmul.is_non_singular)
    self.assertEqual(None, operator_matmul.is_self_adjoint)
    self.assertEqual(None, operator_matmul.is_positive_definite)

    matrix2 = array_ops.placeholder_with_default(
        input=rng.rand(2, 3), shape=None)
    operator2 = LinearOperatorMatmulSolve(
        matrix2,
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False,
        is_square=False,
    )

    operator_matmul = operator2.matmul(operator2, adjoint_arg=True)

    if context.executing_eagerly():
      self.assertTrue(operator_matmul.is_square)
      # False since we specified is_non_singular=False.
      self.assertFalse(operator_matmul.is_non_singular)
    else:
      self.assertIsNone(operator_matmul.is_square)
      # May be non-singular, since it's the composition of two non-square.
      # TODO(b/136162840) This is a bit inconsistent, and should probably be
      # False since we specified operator2.is_non_singular == False.
      self.assertIsNone(operator_matmul.is_non_singular)

    # No way to deduce these, even in Eager mode.
    self.assertIsNone(operator_matmul.is_self_adjoint)
    self.assertIsNone(operator_matmul.is_positive_definite)

  def test_linear_operator_matmul_hint_infer_square(self):
    matrix1 = array_ops.placeholder_with_default(
        input=rng.rand(2, 3), shape=(2, 3))
    matrix2 = array_ops.placeholder_with_default(
        input=rng.rand(3, 2), shape=(3, 2))
    matrix3 = array_ops.placeholder_with_default(
        input=rng.rand(3, 4), shape=(3, 4))

    operator1 = LinearOperatorMatmulSolve(matrix1, is_square=False)
    operator2 = LinearOperatorMatmulSolve(matrix2, is_square=False)
    operator3 = LinearOperatorMatmulSolve(matrix3, is_square=False)

    # (2x3)(3x2) -> 2x2 and (3x2)(2x3) -> 3x3 are square; (2x3)(3x4) is not.
    self.assertTrue(operator1.matmul(operator2).is_square)
    self.assertTrue(operator2.matmul(operator1).is_square)
    self.assertFalse(operator1.matmul(operator3).is_square)

  def testDispatchedMethods(self):
    operator = linalg.LinearOperatorFullMatrix(
        [[1., 0.5], [0.5, 1.]],
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    methods = {
        "trace": linalg.trace,
        "diag_part": linalg.diag_part,
        "log_abs_determinant": linalg.logdet,
        "determinant": linalg.det
    }
    # Each operator method must agree with the corresponding linalg function
    # applied to the operator.
    for method, linalg_fn in methods.items():
      op_val = getattr(operator, method)()
      linalg_val = linalg_fn(operator)
      self.assertAllClose(
          self.evaluate(op_val),
          self.evaluate(linalg_val))
    # Solve and Matmul go here.

    adjoint = linalg.adjoint(operator)
    self.assertIsInstance(adjoint, linalg.LinearOperator)
    cholesky = linalg.cholesky(operator)
    self.assertIsInstance(cholesky, linalg.LinearOperator)
    inverse = linalg.inv(operator)
    self.assertIsInstance(inverse, linalg.LinearOperator)

  def testDispatchMatmulSolve(self):
    operator = linalg.LinearOperatorFullMatrix(
        np.float64([[1., 0.5], [0.5, 1.]]),
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    # Use the seeded module generator so failures are reproducible.
    rhs = rng.uniform(-1., 1., size=[3, 2, 2])
    for adjoint in [False, True]:
      for adjoint_arg in [False, True]:
        op_val = operator.matmul(
            rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
        matmul_val = math_ops.matmul(
            operator, rhs, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        self.assertAllClose(
            self.evaluate(op_val), self.evaluate(matmul_val))

      op_val = operator.solve(rhs, adjoint=adjoint)
      solve_val = linalg.solve(operator, rhs, adjoint=adjoint)
      self.assertAllClose(
          self.evaluate(op_val), self.evaluate(solve_val))

  def testDispatchMatmulLeftOperatorIsTensor(self):
    mat = np.float64([[1., 0.5], [0.5, 1.]])
    right_operator = linalg.LinearOperatorFullMatrix(
        mat,
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    # Use the seeded module generator so failures are reproducible.
    lhs = rng.uniform(-1., 1., size=[3, 2, 2])

    for adjoint in [False, True]:
      for adjoint_arg in [False, True]:
        op_val = math_ops.matmul(
            lhs, mat, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        matmul_val = math_ops.matmul(
            lhs, right_operator, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        self.assertAllClose(
            self.evaluate(op_val), self.evaluate(matmul_val))

  def testVectorizedMap(self):

    def fn(x):
      y = constant_op.constant([3., 4.])
      # Make a [2, N, N] shaped operator.
      x = x * y[..., array_ops.newaxis, array_ops.newaxis]
      operator = linalg.LinearOperatorFullMatrix(
          x, is_square=True)
      return operator

    # Use the seeded module generator so failures are reproducible.
    x = rng.uniform(-1., 1., size=[3, 5, 5]).astype(np.float32)
    batched_operator = control_flow_ops.vectorized_map(
        fn, ops.convert_to_tensor(x))
    self.assertIsInstance(batched_operator, linalg.LinearOperator)
    self.assertAllEqual(batched_operator.batch_shape, [3, 2])


if __name__ == "__main__":
  test.main()