1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from absl.testing import parameterized 17import numpy as np 18 19from tensorflow.python.framework import ops 20from tensorflow.python.framework import tensor_shape 21from tensorflow.python.ops import array_ops 22from tensorflow.python.ops import linalg_ops 23from tensorflow.python.ops import math_ops 24from tensorflow.python.ops.linalg import linear_operator_util 25from tensorflow.python.platform import test 26 27rng = np.random.RandomState(0) 28 29 30class AssertZeroImagPartTest(test.TestCase): 31 32 def test_real_tensor_doesnt_raise(self): 33 x = ops.convert_to_tensor([0., 2, 3]) 34 # Should not raise. 35 self.evaluate( 36 linear_operator_util.assert_zero_imag_part(x, message="ABC123")) 37 38 def test_complex_tensor_with_imag_zero_doesnt_raise(self): 39 x = ops.convert_to_tensor([1., 0, 3]) 40 y = ops.convert_to_tensor([0., 0, 0]) 41 z = math_ops.complex(x, y) 42 # Should not raise. 43 self.evaluate( 44 linear_operator_util.assert_zero_imag_part(z, message="ABC123")) 45 46 def test_complex_tensor_with_nonzero_imag_raises(self): 47 x = ops.convert_to_tensor([1., 2, 0]) 48 y = ops.convert_to_tensor([1., 2, 0]) 49 z = math_ops.complex(x, y) 50 with self.assertRaisesOpError("ABC123"): 51 self.evaluate( 52 linear_operator_util.assert_zero_imag_part(z, message="ABC123")) 53 54 55class AssertNoEntriesWithModulusZeroTest(test.TestCase): 56 57 def test_nonzero_real_tensor_doesnt_raise(self): 58 x = ops.convert_to_tensor([1., 2, 3]) 59 # Should not raise. 60 self.evaluate( 61 linear_operator_util.assert_no_entries_with_modulus_zero( 62 x, message="ABC123")) 63 64 def test_nonzero_complex_tensor_doesnt_raise(self): 65 x = ops.convert_to_tensor([1., 0, 3]) 66 y = ops.convert_to_tensor([1., 2, 0]) 67 z = math_ops.complex(x, y) 68 # Should not raise. 69 self.evaluate( 70 linear_operator_util.assert_no_entries_with_modulus_zero( 71 z, message="ABC123")) 72 73 def test_zero_real_tensor_raises(self): 74 x = ops.convert_to_tensor([1., 0, 3]) 75 with self.assertRaisesOpError("ABC123"): 76 self.evaluate( 77 linear_operator_util.assert_no_entries_with_modulus_zero( 78 x, message="ABC123")) 79 80 def test_zero_complex_tensor_raises(self): 81 x = ops.convert_to_tensor([1., 2, 0]) 82 y = ops.convert_to_tensor([1., 2, 0]) 83 z = math_ops.complex(x, y) 84 with self.assertRaisesOpError("ABC123"): 85 self.evaluate( 86 linear_operator_util.assert_no_entries_with_modulus_zero( 87 z, message="ABC123")) 88 89 90class BroadcastMatrixBatchDimsTest(test.TestCase): 91 92 def test_zero_batch_matrices_returned_as_empty_list(self): 93 self.assertAllEqual([], 94 linear_operator_util.broadcast_matrix_batch_dims([])) 95 96 def test_one_batch_matrix_returned_after_tensor_conversion(self): 97 arr = rng.rand(2, 3, 4) 98 tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr]) 99 self.assertTrue(isinstance(tensor, ops.Tensor)) 100 101 self.assertAllClose(arr, self.evaluate(tensor)) 102 103 def test_static_dims_broadcast(self): 104 # x.batch_shape = [3, 1, 2] 105 # y.batch_shape = [4, 1] 106 # broadcast batch shape = [3, 4, 2] 107 x = rng.rand(3, 1, 2, 1, 5) 108 y = rng.rand(4, 1, 3, 7) 109 batch_of_zeros = np.zeros((3, 4, 2, 1, 1)) 110 x_bc_expected = x + batch_of_zeros 111 y_bc_expected = y + batch_of_zeros 112 113 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y]) 114 115 self.assertAllEqual(x_bc_expected.shape, x_bc.shape) 116 self.assertAllEqual(y_bc_expected.shape, y_bc.shape) 117 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc]) 118 self.assertAllClose(x_bc_expected, x_bc_) 119 self.assertAllClose(y_bc_expected, y_bc_) 120 121 def test_static_dims_broadcast_second_arg_higher_rank(self): 122 # x.batch_shape = [1, 2] 123 # y.batch_shape = [1, 3, 1] 124 # broadcast batch shape = [1, 3, 2] 125 x = rng.rand(1, 2, 1, 5) 126 y = rng.rand(1, 3, 2, 3, 7) 127 batch_of_zeros = np.zeros((1, 3, 2, 1, 1)) 128 x_bc_expected = x + batch_of_zeros 129 y_bc_expected = y + batch_of_zeros 130 131 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y]) 132 133 self.assertAllEqual(x_bc_expected.shape, x_bc.shape) 134 self.assertAllEqual(y_bc_expected.shape, y_bc.shape) 135 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc]) 136 self.assertAllClose(x_bc_expected, x_bc_) 137 self.assertAllClose(y_bc_expected, y_bc_) 138 139 def test_dynamic_dims_broadcast_32bit(self): 140 # x.batch_shape = [3, 1, 2] 141 # y.batch_shape = [4, 1] 142 # broadcast batch shape = [3, 4, 2] 143 x = rng.rand(3, 1, 2, 1, 5).astype(np.float32) 144 y = rng.rand(4, 1, 3, 7).astype(np.float32) 145 batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32) 146 x_bc_expected = x + batch_of_zeros 147 y_bc_expected = y + batch_of_zeros 148 149 x_ph = array_ops.placeholder_with_default(x, shape=None) 150 y_ph = array_ops.placeholder_with_default(y, shape=None) 151 152 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph]) 153 154 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc]) 155 self.assertAllClose(x_bc_expected, x_bc_) 156 self.assertAllClose(y_bc_expected, y_bc_) 157 158 def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self): 159 # x.batch_shape = [1, 2] 160 # y.batch_shape = [3, 4, 1] 161 # broadcast batch shape = [3, 4, 2] 162 x = rng.rand(1, 2, 1, 5).astype(np.float32) 163 y = rng.rand(3, 4, 1, 3, 7).astype(np.float32) 164 batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32) 165 x_bc_expected = x + batch_of_zeros 166 y_bc_expected = y + batch_of_zeros 167 168 x_ph = array_ops.placeholder_with_default(x, shape=None) 169 y_ph = array_ops.placeholder_with_default(y, shape=None) 170 171 x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph]) 172 173 x_bc_, y_bc_ = self.evaluate([x_bc, y_bc]) 174 self.assertAllClose(x_bc_expected, x_bc_) 175 self.assertAllClose(y_bc_expected, y_bc_) 176 177 def test_less_than_two_dims_raises_static(self): 178 x = rng.rand(3) 179 y = rng.rand(1, 1) 180 181 with self.assertRaisesRegex(ValueError, "at least two dimensions"): 182 linear_operator_util.broadcast_matrix_batch_dims([x, y]) 183 184 with self.assertRaisesRegex(ValueError, "at least two dimensions"): 185 linear_operator_util.broadcast_matrix_batch_dims([y, x]) 186 187 188class MatrixSolveWithBroadcastTest(test.TestCase): 189 190 def test_static_dims_broadcast_matrix_has_extra_dims(self): 191 # batch_shape = [2] 192 matrix = rng.rand(2, 3, 3) 193 rhs = rng.rand(3, 7) 194 rhs_broadcast = rhs + np.zeros((2, 1, 1)) 195 196 result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs) 197 self.assertAllEqual((2, 3, 7), result.shape) 198 expected = linalg_ops.matrix_solve(matrix, rhs_broadcast) 199 self.assertAllClose(*self.evaluate([expected, result])) 200 201 def test_static_dims_broadcast_rhs_has_extra_dims(self): 202 # Since the second arg has extra dims, and the domain dim of the first arg 203 # is larger than the number of linear equations, code will "flip" the extra 204 # dims of the first arg to the far right, making extra linear equations 205 # (then call the matrix function, then flip back). 206 # We have verified that this optimization indeed happens. How? We stepped 207 # through with a debugger. 208 # batch_shape = [2] 209 matrix = rng.rand(3, 3) 210 rhs = rng.rand(2, 3, 2) 211 matrix_broadcast = matrix + np.zeros((2, 1, 1)) 212 213 result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs) 214 self.assertAllEqual((2, 3, 2), result.shape) 215 expected = linalg_ops.matrix_solve(matrix_broadcast, rhs) 216 self.assertAllClose(*self.evaluate([expected, result])) 217 218 def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self): 219 # Since the second arg has extra dims, and the domain dim of the first arg 220 # is larger than the number of linear equations, code will "flip" the extra 221 # dims of the first arg to the far right, making extra linear equations 222 # (then call the matrix function, then flip back). 223 # We have verified that this optimization indeed happens. How? We stepped 224 # through with a debugger. 225 # batch_shape = [2] 226 matrix = rng.rand(3, 3) 227 rhs = rng.rand(2, 3, 2) 228 matrix_broadcast = matrix + np.zeros((2, 1, 1)) 229 230 matrix_ph = array_ops.placeholder_with_default(matrix, shape=[None, None]) 231 rhs_ph = array_ops.placeholder_with_default(rhs, shape=[None, None, None]) 232 233 result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph) 234 self.assertAllEqual(3, result.shape.ndims) 235 expected = linalg_ops.matrix_solve(matrix_broadcast, rhs) 236 self.assertAllClose(*self.evaluate([expected, result])) 237 238 def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self): 239 # Since the second arg has extra dims, and the domain dim of the first arg 240 # is larger than the number of linear equations, code will "flip" the extra 241 # dims of the first arg to the far right, making extra linear equations 242 # (then call the matrix function, then flip back). 243 # We have verified that this optimization indeed happens. How? We stepped 244 # through with a debugger. 245 # batch_shape = [2] 246 matrix = rng.rand(3, 3) 247 rhs = rng.rand(2, 3, 2) 248 matrix_broadcast = matrix + np.zeros((2, 1, 1)) 249 250 result = linear_operator_util.matrix_solve_with_broadcast( 251 matrix, rhs, adjoint=True) 252 self.assertAllEqual((2, 3, 2), result.shape) 253 expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True) 254 self.assertAllClose(*self.evaluate([expected, result])) 255 256 def test_dynamic_dims_broadcast_64bit(self): 257 # batch_shape = [2, 2] 258 matrix = rng.rand(2, 3, 3) 259 rhs = rng.rand(2, 1, 3, 7) 260 matrix_broadcast = matrix + np.zeros((2, 2, 1, 1)) 261 rhs_broadcast = rhs + np.zeros((2, 2, 1, 1)) 262 263 matrix_ph = array_ops.placeholder_with_default(matrix, shape=None) 264 rhs_ph = array_ops.placeholder_with_default(rhs, shape=None) 265 266 result, expected = self.evaluate([ 267 linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph), 268 linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast) 269 ]) 270 self.assertAllClose(expected, result) 271 272 273class DomainDimensionStubOperator(object): 274 275 def __init__(self, domain_dimension): 276 self._domain_dimension = ops.convert_to_tensor(domain_dimension) 277 278 def domain_dimension_tensor(self): 279 return self._domain_dimension 280 281 282class AssertCompatibleMatrixDimensionsTest(test.TestCase): 283 284 def test_compatible_dimensions_do_not_raise(self): 285 x = ops.convert_to_tensor(rng.rand(2, 3, 4)) 286 operator = DomainDimensionStubOperator(3) 287 # Should not raise 288 self.evaluate( 289 linear_operator_util.assert_compatible_matrix_dimensions(operator, x)) 290 291 def test_incompatible_dimensions_raise(self): 292 x = ops.convert_to_tensor(rng.rand(2, 4, 4)) 293 operator = DomainDimensionStubOperator(3) 294 # pylint: disable=g-error-prone-assert-raises 295 with self.assertRaisesOpError("Dimensions are not compatible"): 296 self.evaluate( 297 linear_operator_util.assert_compatible_matrix_dimensions(operator, x)) 298 # pylint: enable=g-error-prone-assert-raises 299 300 301class DummyOperatorWithHint(object): 302 303 def __init__(self, **kwargs): 304 self.__dict__.update(kwargs) 305 306 307class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase, 308 parameterized.TestCase): 309 310 @parameterized.named_parameters( 311 ("none_none", None, None, None), 312 ("none_true", None, True, True), 313 ("true_none", True, None, True), 314 ("true_true", True, True, True), 315 ("none_false", None, False, False), 316 ("false_none", False, None, False), 317 ("false_false", False, False, False), 318 ) 319 def test_computes_an_or_if_non_contradicting(self, operator_hint_value, 320 provided_hint_value, 321 expected_result): 322 self.assertEqual( 323 expected_result, 324 linear_operator_util.use_operator_or_provided_hint_unless_contradicting( 325 operator=DummyOperatorWithHint(my_hint=operator_hint_value), 326 hint_attr_name="my_hint", 327 provided_hint_value=provided_hint_value, 328 message="should not be needed here")) 329 330 @parameterized.named_parameters( 331 ("true_false", True, False), 332 ("false_true", False, True), 333 ) 334 def test_raises_if_contradicting(self, operator_hint_value, 335 provided_hint_value): 336 with self.assertRaisesRegex(ValueError, "my error message"): 337 linear_operator_util.use_operator_or_provided_hint_unless_contradicting( 338 operator=DummyOperatorWithHint(my_hint=operator_hint_value), 339 hint_attr_name="my_hint", 340 provided_hint_value=provided_hint_value, 341 message="my error message") 342 343 344class BlockwiseTest(test.TestCase, parameterized.TestCase): 345 346 @parameterized.named_parameters( 347 ("split_dim_1", [3, 3, 4], -1), 348 ("split_dim_2", [2, 5], -2), 349 ) 350 def test_blockwise_input(self, op_dimension_values, split_dim): 351 352 op_dimensions = [ 353 tensor_shape.Dimension(v) for v in op_dimension_values] 354 unknown_op_dimensions = [ 355 tensor_shape.Dimension(None) for _ in op_dimension_values] 356 357 batch_shape = [2, 1] 358 arg_dim = 5 359 if split_dim == -1: 360 blockwise_arrays = [np.zeros(batch_shape + [arg_dim, d]) 361 for d in op_dimension_values] 362 else: 363 blockwise_arrays = [np.zeros(batch_shape + [d, arg_dim]) 364 for d in op_dimension_values] 365 366 blockwise_list = [block.tolist() for block in blockwise_arrays] 367 blockwise_tensors = [ops.convert_to_tensor(block) 368 for block in blockwise_arrays] 369 blockwise_placeholders = [ 370 array_ops.placeholder_with_default(block, shape=None) 371 for block in blockwise_arrays] 372 373 # Iterables of non-nested structures are always interpreted as blockwise. 374 # The list of lists is interpreted as blockwise as well, regardless of 375 # whether the operator dimensions are known, since the sizes of its elements 376 # along `split_dim` are non-identical. 377 for op_dims in [op_dimensions, unknown_op_dimensions]: 378 for blockwise_inputs in [ 379 blockwise_arrays, blockwise_list, 380 blockwise_tensors, blockwise_placeholders]: 381 self.assertTrue(linear_operator_util.arg_is_blockwise( 382 op_dims, blockwise_inputs, split_dim)) 383 384 def test_non_blockwise_input(self): 385 x = np.zeros((2, 3, 4, 6)) 386 x_tensor = ops.convert_to_tensor(x) 387 x_placeholder = array_ops.placeholder_with_default(x, shape=None) 388 x_list = x.tolist() 389 390 # For known and matching operator dimensions, interpret all as non-blockwise 391 op_dimension_values = [2, 1, 3] 392 op_dimensions = [tensor_shape.Dimension(d) for d in op_dimension_values] 393 for inputs in [x, x_tensor, x_placeholder, x_list]: 394 self.assertFalse(linear_operator_util.arg_is_blockwise( 395 op_dimensions, inputs, -1)) 396 397 # The input is still interpreted as non-blockwise for unknown operator 398 # dimensions (`x_list` has an outermost dimension that does not matcn the 399 # number of blocks, and the other inputs are not iterables). 400 unknown_op_dimensions = [ 401 tensor_shape.Dimension(None) for _ in op_dimension_values] 402 for inputs in [x, x_tensor, x_placeholder, x_list]: 403 self.assertFalse(linear_operator_util.arg_is_blockwise( 404 unknown_op_dimensions, inputs, -1)) 405 406 def test_ambiguous_input_raises(self): 407 x = np.zeros((3, 4, 2)).tolist() 408 op_dimensions = [tensor_shape.Dimension(None) for _ in range(3)] 409 410 # Since the leftmost dimension of `x` is equal to the number of blocks, and 411 # the operators have unknown dimension, the input is ambiguous. 412 with self.assertRaisesRegex(ValueError, "structure is ambiguous"): 413 linear_operator_util.arg_is_blockwise(op_dimensions, x, -2) 414 415 def test_mismatched_input_raises(self): 416 x = np.zeros((2, 3, 4, 6)).tolist() 417 op_dimension_values = [4, 3] 418 op_dimensions = [tensor_shape.Dimension(v) for v in op_dimension_values] 419 420 # The dimensions of the two operator-blocks sum to 7. `x` is a 421 # two-element list; if interpreted blockwise, its corresponding dimensions 422 # sum to 12 (=6*2). If not interpreted blockwise, its corresponding 423 # dimension is 6. This is a mismatch. 424 with self.assertRaisesRegex(ValueError, "dimension does not match"): 425 linear_operator_util.arg_is_blockwise(op_dimensions, x, -1) 426 427if __name__ == "__main__": 428 test.main() 429