# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
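"""Tests for tensorflow.python.ops.linalg.linear_operator_util."""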

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.platform import test

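# Use a fixed seed so the randomly generated test data is reproducible.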
rng = np.random.RandomState(0)


class AssertZeroImagPartTest(test.TestCase):

  def test_real_tensor_doesnt_raise(self):
    x = ops.convert_to_tensor([0., 2, 3])
    # Should not raise.
    self.evaluate(
        linear_operator_util.assert_zero_imag_part(x, message="ABC123"))

  def test_complex_tensor_with_imag_zero_doesnt_raise(self):
    x = ops.convert_to_tensor([1., 0, 3])
    y = ops.convert_to_tensor([0., 0, 0])
    z = math_ops.complex(x, y)
    # Should not raise.
    self.evaluate(
        linear_operator_util.assert_zero_imag_part(z, message="ABC123"))

  def test_complex_tensor_with_nonzero_imag_raises(self):
    x = ops.convert_to_tensor([1., 2, 0])
    y = ops.convert_to_tensor([1., 2, 0])
    z = math_ops.complex(x, y)
    with self.assertRaisesOpError("ABC123"):
      self.evaluate(
          linear_operator_util.assert_zero_imag_part(z, message="ABC123"))


class AssertNoEntriesWithModulusZeroTest(test.TestCase):

  def test_nonzero_real_tensor_doesnt_raise(self):
    x = ops.convert_to_tensor([1., 2, 3])
    # Should not raise.
    self.evaluate(
        linear_operator_util.assert_no_entries_with_modulus_zero(
            x, message="ABC123"))

  def test_nonzero_complex_tensor_doesnt_raise(self):
    x = ops.convert_to_tensor([1., 0, 3])
    y = ops.convert_to_tensor([1., 2, 0])
    z = math_ops.complex(x, y)
    # Should not raise.
    self.evaluate(
        linear_operator_util.assert_no_entries_with_modulus_zero(
            z, message="ABC123"))

  def test_zero_real_tensor_raises(self):
    x = ops.convert_to_tensor([1., 0, 3])
    with self.assertRaisesOpError("ABC123"):
      self.evaluate(
          linear_operator_util.assert_no_entries_with_modulus_zero(
              x, message="ABC123"))

  def test_zero_complex_tensor_raises(self):
    x = ops.convert_to_tensor([1., 2, 0])
    y = ops.convert_to_tensor([1., 2, 0])
    z = math_ops.complex(x, y)
    with self.assertRaisesOpError("ABC123"):
      self.evaluate(
          linear_operator_util.assert_no_entries_with_modulus_zero(
              z, message="ABC123"))


class BroadcastMatrixBatchDimsTest(test.TestCase):

  def test_zero_batch_matrices_returned_as_empty_list(self):
    self.assertAllEqual([],
                        linear_operator_util.broadcast_matrix_batch_dims([]))

  def test_one_batch_matrix_returned_after_tensor_conversion(self):
    arr = rng.rand(2, 3, 4)
    tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
    self.assertTrue(isinstance(tensor, ops.Tensor))

    self.assertAllClose(arr, self.evaluate(tensor))

  def test_static_dims_broadcast(self):
    # x.batch_shape = [3, 1, 2]
    # y.batch_shape = [4, 1]
    # broadcast batch shape = [3, 4, 2]
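    # Only the leading "batch" dimensions are broadcast; the trailing two
    # (matrix) dimensions of each argument are left unchanged.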
    x = rng.rand(3, 1, 2, 1, 5)
    y = rng.rand(4, 1, 3, 7)
    batch_of_zeros = np.zeros((3, 4, 2, 1, 1))
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

    self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
    self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)

  def test_static_dims_broadcast_second_arg_higher_rank(self):
    # x.batch_shape =    [1, 2]
    # y.batch_shape = [1, 3, 1]
    # broadcast batch shape = [1, 3, 2]
    x = rng.rand(1, 2, 1, 5)
    y = rng.rand(1, 3, 2, 3, 7)
    batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

    self.assertAllEqual(x_bc_expected.shape, x_bc.shape)
    self.assertAllEqual(y_bc_expected.shape, y_bc.shape)
    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)

  def test_dynamic_dims_broadcast_32bit(self):
    # x.batch_shape = [3, 1, 2]
    # y.batch_shape = [4, 1]
    # broadcast batch shape = [3, 4, 2]
    x = rng.rand(3, 1, 2, 1, 5).astype(np.float32)
    y = rng.rand(4, 1, 3, 7).astype(np.float32)
    batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_ph = array_ops.placeholder_with_default(x, shape=None)
    y_ph = array_ops.placeholder_with_default(y, shape=None)

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)

  def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
    # x.batch_shape =    [1, 2]
    # y.batch_shape = [3, 4, 1]
    # broadcast batch shape = [3, 4, 2]
    x = rng.rand(1, 2, 1, 5).astype(np.float32)
    y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
    batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
    x_bc_expected = x + batch_of_zeros
    y_bc_expected = y + batch_of_zeros

    x_ph = array_ops.placeholder_with_default(x, shape=None)
    y_ph = array_ops.placeholder_with_default(y, shape=None)

    x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
    self.assertAllClose(x_bc_expected, x_bc_)
    self.assertAllClose(y_bc_expected, y_bc_)

  def test_less_than_two_dims_raises_static(self):
    x = rng.rand(3)
    y = rng.rand(1, 1)

    with self.assertRaisesRegex(ValueError, "at least two dimensions"):
      linear_operator_util.broadcast_matrix_batch_dims([x, y])

    with self.assertRaisesRegex(ValueError, "at least two dimensions"):
      linear_operator_util.broadcast_matrix_batch_dims([y, x])


class MatrixSolveWithBroadcastTest(test.TestCase):
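  """Tests that `matrix_solve_with_broadcast` matches `matrix_solve` after the
  batch dimensions of its two arguments are broadcast against each other."""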

  def test_static_dims_broadcast_matrix_has_extra_dims(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
    self.assertAllEqual((2, 3, 7), result.shape)
    expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
    self.assertAllClose(*self.evaluate([expected, result]))

  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the first arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
    self.assertAllEqual((2, 3, 2), result.shape)
    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
    self.assertAllClose(*self.evaluate([expected, result]))

  def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the first arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder_with_default(matrix, shape=[None, None])
    rhs_ph = array_ops.placeholder_with_default(rhs, shape=[None, None, None])

    result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph)
    self.assertAllEqual(3, result.shape.ndims)
    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
    self.assertAllClose(*self.evaluate([expected, result]))

  def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the first arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens.  How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_solve_with_broadcast(
        matrix, rhs, adjoint=True)
    self.assertAllEqual((2, 3, 2), result.shape)
    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
    self.assertAllClose(*self.evaluate([expected, result]))

  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2, 2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(2, 1, 3, 7)
    matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
    rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

    result, expected = self.evaluate([
        linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph),
        linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
    ])
    self.assertAllClose(expected, result)


class DomainDimensionStubOperator(object):
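  """Minimal stub exposing only `domain_dimension_tensor`, as used below."""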

  def __init__(self, domain_dimension):
    self._domain_dimension = ops.convert_to_tensor(domain_dimension)

  def domain_dimension_tensor(self):
    return self._domain_dimension


class AssertCompatibleMatrixDimensionsTest(test.TestCase):

  def test_compatible_dimensions_do_not_raise(self):
    x = ops.convert_to_tensor(rng.rand(2, 3, 4))
    operator = DomainDimensionStubOperator(3)
    # Should not raise
    self.evaluate(
        linear_operator_util.assert_compatible_matrix_dimensions(operator, x))

  def test_incompatible_dimensions_raise(self):
    x = ops.convert_to_tensor(rng.rand(2, 4, 4))
    operator = DomainDimensionStubOperator(3)
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaisesOpError("Dimensions are not compatible"):
      self.evaluate(
          linear_operator_util.assert_compatible_matrix_dimensions(operator, x))
    # pylint: enable=g-error-prone-assert-raises


class DummyOperatorWithHint(object):
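  """Stand-in operator whose hint attributes are set via keyword arguments."""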

  def __init__(self, **kwargs):
    self.__dict__.update(kwargs)


class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase,
                                                       parameterized.TestCase):
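  """Tests that the operator's hint and the provided hint combine via "or"
  when they do not contradict, and that contradicting hints raise."""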

  @parameterized.named_parameters(
      ("none_none", None, None, None),
      ("none_true", None, True, True),
      ("true_none", True, None, True),
      ("true_true", True, True, True),
      ("none_false", None, False, False),
      ("false_none", False, None, False),
      ("false_false", False, False, False),
  )
  def test_computes_an_or_if_non_contradicting(self, operator_hint_value,
                                               provided_hint_value,
                                               expected_result):
    self.assertEqual(
        expected_result,
        linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
            operator=DummyOperatorWithHint(my_hint=operator_hint_value),
            hint_attr_name="my_hint",
            provided_hint_value=provided_hint_value,
            message="should not be needed here"))

  @parameterized.named_parameters(
      ("true_false", True, False),
      ("false_true", False, True),
  )
  def test_raises_if_contradicting(self, operator_hint_value,
                                   provided_hint_value):
    with self.assertRaisesRegex(ValueError, "my error message"):
      linear_operator_util.use_operator_or_provided_hint_unless_contradicting(
          operator=DummyOperatorWithHint(my_hint=operator_hint_value),
          hint_attr_name="my_hint",
          provided_hint_value=provided_hint_value,
          message="my error message")


class BlockwiseTest(test.TestCase, parameterized.TestCase):
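  """Tests for `arg_is_blockwise`, which decides whether an argument should be
  treated as a list of per-operator blocks or as a single argument."""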

  @parameterized.named_parameters(
      ("split_dim_1", [3, 3, 4], -1),
      ("split_dim_2", [2, 5], -2),
      )
  def test_blockwise_input(self, op_dimension_values, split_dim):

    op_dimensions = [
        tensor_shape.Dimension(v) for v in op_dimension_values]
    unknown_op_dimensions = [
        tensor_shape.Dimension(None) for _ in op_dimension_values]

    batch_shape = [2, 1]
    arg_dim = 5
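    # With split_dim == -1 the blocks differ in size along the last (column)
    # dimension; with split_dim == -2 they differ along the row dimension.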
    if split_dim == -1:
      blockwise_arrays = [np.zeros(batch_shape + [arg_dim, d])
                          for d in op_dimension_values]
    else:
      blockwise_arrays = [np.zeros(batch_shape + [d, arg_dim])
                          for d in op_dimension_values]

    blockwise_list = [block.tolist() for block in blockwise_arrays]
    blockwise_tensors = [ops.convert_to_tensor(block)
                         for block in blockwise_arrays]
    blockwise_placeholders = [
        array_ops.placeholder_with_default(block, shape=None)
        for block in blockwise_arrays]

    # Iterables of non-nested structures are always interpreted as blockwise.
    # The list of lists is interpreted as blockwise as well, regardless of
    # whether the operator dimensions are known, since the sizes of its elements
    # along `split_dim` are non-identical.
    for op_dims in [op_dimensions, unknown_op_dimensions]:
      for blockwise_inputs in [
          blockwise_arrays, blockwise_list,
          blockwise_tensors, blockwise_placeholders]:
        self.assertTrue(linear_operator_util.arg_is_blockwise(
            op_dims, blockwise_inputs, split_dim))

  def test_non_blockwise_input(self):
    x = np.zeros((2, 3, 4, 6))
    x_tensor = ops.convert_to_tensor(x)
    x_placeholder = array_ops.placeholder_with_default(x, shape=None)
    x_list = x.tolist()

    # For known and matching operator dimensions, interpret all inputs as
    # non-blockwise.
    op_dimension_values = [2, 1, 3]
    op_dimensions = [tensor_shape.Dimension(d) for d in op_dimension_values]
    for inputs in [x, x_tensor, x_placeholder, x_list]:
      self.assertFalse(linear_operator_util.arg_is_blockwise(
          op_dimensions, inputs, -1))

    # The input is still interpreted as non-blockwise for unknown operator
    # dimensions (`x_list` has an outermost dimension that does not match the
    # number of blocks, and the other inputs are not iterables).
    unknown_op_dimensions = [
        tensor_shape.Dimension(None) for _ in op_dimension_values]
    for inputs in [x, x_tensor, x_placeholder, x_list]:
      self.assertFalse(linear_operator_util.arg_is_blockwise(
          unknown_op_dimensions, inputs, -1))

  def test_ambiguous_input_raises(self):
    x = np.zeros((3, 4, 2)).tolist()
    op_dimensions = [tensor_shape.Dimension(None) for _ in range(3)]

    # Since the leftmost dimension of `x` is equal to the number of blocks, and
    # the operators have unknown dimension, the input is ambiguous.
    with self.assertRaisesRegex(ValueError, "structure is ambiguous"):
      linear_operator_util.arg_is_blockwise(op_dimensions, x, -2)

  def test_mismatched_input_raises(self):
    x = np.zeros((2, 3, 4, 6)).tolist()
    op_dimension_values = [4, 3]
    op_dimensions = [tensor_shape.Dimension(v) for v in op_dimension_values]

    # The dimensions of the two operator-blocks sum to 7. `x` is a
    # two-element list; if interpreted blockwise, its corresponding dimensions
    # sum to 12 (=6*2). If not interpreted blockwise, its corresponding
    # dimension is 6. This is a mismatch.
    with self.assertRaisesRegex(ValueError, "dimension does not match"):
      linear_operator_util.arg_is_blockwise(op_dimensions, x, -1)


if __name__ == "__main__":
  test.main()