• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.eager import backprop
19from tensorflow.python.eager import context
20from tensorflow.python.framework import config
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import variables as variables_module
25from tensorflow.python.ops.linalg import linalg as linalg_lib
26from tensorflow.python.ops.linalg import linear_operator_block_lower_triangular as block_lower_triangular
27from tensorflow.python.ops.linalg import linear_operator_test_util
28from tensorflow.python.ops.linalg import linear_operator_util
29from tensorflow.python.platform import test
30
# Short alias for the public `linalg` namespace used throughout this file.
linalg = linalg_lib
# Module-wide random generator with a fixed seed, so inputs built from it are
# reproducible from run to run.
rng = np.random.RandomState(seed=0)
33
34
def _block_lower_triangular_dense(expected_shape, blocks):
  """Convert a list of blocks into a dense blockwise lower-triangular matrix.

  Row `i` of `blocks` holds the non-zero blocks of row partition `i`, left to
  right. The last block in each row is the diagonal block. Each row is
  right-padded with zeros out to the full column count, and the padded rows
  are stacked.

  Args:
    expected_shape: Full shape of the dense result. Only its last entry (the
      total number of columns) is read here.
    blocks: List of lists of `Tensor`s. All blocks are assumed to share the
      same batch shape (the caller in this file broadcasts them first). The
      input lists are not modified.

  Returns:
    A `Tensor` holding the dense block lower-triangular matrix.
  """
  rows = []
  num_cols = 0
  for row_blocks in blocks:
    # All dims but the last (batch dims + row count) of this row partition.
    # The padding block must match them to be concatenated on the last axis.
    batch_row_shape = array_ops.shape(row_blocks[0])[:-1]

    # The last block of each row is the diagonal block, so the running sum of
    # its widths is the number of columns covered so far. Pad the remainder
    # up to the total column count, `expected_shape[-1]`.
    num_cols += array_ops.shape(row_blocks[-1])[-1]
    zeros_to_pad_after_shape = array_ops.concat(
        [batch_row_shape, [expected_shape[-1] - num_cols]], axis=-1)
    zeros_to_pad_after = array_ops.zeros(
        zeros_to_pad_after_shape, dtype=row_blocks[-1].dtype)

    # Build a new list rather than appending, so the caller's lists are left
    # untouched.
    padded_row = list(row_blocks) + [zeros_to_pad_after]
    rows.append(array_ops.concat(padded_row, axis=-1))

  return array_ops.concat(rows, axis=-2)
54
55
@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorBlockLowerTriangularTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorBlockLowerTriangular`.

  Most tests are done in the base class LinearOperatorDerivedClassTest. This
  class supplies the shapes and the `operator_and_matrix` factory those tests
  use, plus tests specific to the blockwise lower-triangular structure.
  """

  def setUp(self):
    # Turn off TensorFloat-32 execution for the duration of the test. The
    # previous setting is saved here and restored in tearDown.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase tolerances from 1e-6 to 1e-5.
    # NOTE(review): `_atol`/`_rtol` are inherited from the base class. If they
    # are class-level dicts, these writes are also seen by other test classes
    # in the same process -- confirm.
    self._atol[dtypes.float32] = 1e-5
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._rtol[dtypes.complex64] = 1e-5
    super(SquareLinearOperatorBlockLowerTriangularTest, self).setUp()

  def tearDown(self):
    # Restore the TF32 setting saved in setUp. Always chain to the base class
    # so its cleanup runs even if the restore raises.
    try:
      config.enable_tensor_float_32_execution(self.tf32_keep_)
    finally:
      super(SquareLinearOperatorBlockLowerTriangularTest, self).tearDown()

  @staticmethod
  def use_blockwise_arg():
    return True

  @staticmethod
  def skip_these_tests():
    # Skipping since `LinearOperatorBlockLowerTriangular` is in general not
    # self-adjoint.
    return ["cholesky", "eigvalsh"]

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((5, 5), blocks=[[(2, 2)], [(3, 2), (3, 3)]]),
        shape_info((3, 7, 7),
                   blocks=[[(1, 2, 2)], [(1, 3, 2), (3, 3, 3)],
                           [(1, 2, 2), (1, 2, 3), (1, 2, 2)]]),
        shape_info((2, 4, 6, 6),
                   blocks=[[(2, 1, 2, 2)], [(1, 4, 2), (4, 4, 4)]]),
    ]

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a random block lower-triangular operator and its dense matrix.

    Args:
      shape_info: `OperatorShapesInfo`. If it carries a `blocks` attribute,
        `blocks[i][j]` is the shape of block `(i, j)`, and the last entry of
        each row is the diagonal block. Otherwise a single block of shape
        `shape_info.shape` is used.
      dtype: dtype of the generated blocks.
      use_placeholder: If True, feed each block through
        `placeholder_with_default` with no static shape, and leave the
        dense result's static shape unset.
      ensure_self_adjoint_and_pd: If True, hint the diagonal blocks as
        self-adjoint and positive definite.

    Returns:
      A tuple `(operator, dense_matrix)`.
    """
    expected_blocks = vars(shape_info).get(
        "blocks", [[list(shape_info.shape)]])

    matrices = []
    for i, row_shapes in enumerate(expected_blocks):
      row = []
      for j, block_shape in enumerate(row_shapes):
        if i == j:  # Diagonal block: well-conditioned positive definite.
          row.append(
              linear_operator_test_util.random_positive_definite_matrix(
                  block_shape, dtype, force_well_conditioned=True))
        else:
          row.append(
              linear_operator_test_util.random_normal(block_shape, dtype=dtype))
      matrices.append(row)

    lin_op_matrices = matrices
    if use_placeholder:
      lin_op_matrices = [[
          array_ops.placeholder_with_default(matrix, shape=None)
          for matrix in row] for row in matrices]

    def _make_block(matrix, on_diagonal):
      # Only diagonal blocks are guaranteed square (and, when requested,
      # self-adjoint and PD). Off-diagonal blocks such as (3, 2) are not, so
      # they get no structural hints.
      hint = True if (on_diagonal and ensure_self_adjoint_and_pd) else None
      return linalg.LinearOperatorFullMatrix(
          matrix,
          is_square=True if on_diagonal else None,
          is_self_adjoint=hint,
          is_positive_definite=hint)

    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[_make_block(matrix, i == j) for j, matrix in enumerate(row)]  # pylint:disable=g-complex-comprehension
         for i, row in enumerate(lin_op_matrices)])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the blocks' batch dimensions against each other.
    expected_shape = list(shape_info.shape)
    broadcasted_matrices = linear_operator_util.broadcast_matrix_batch_dims(
        [matrix for row in matrices for matrix in row])  # pylint: disable=g-complex-comprehension
    # Row i holds i + 1 blocks, so in the flattened list it occupies the
    # slice [i * (i + 1) // 2, (i + 1) * (i + 2) // 2).
    matrices = [broadcasted_matrices[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
                for i in range(len(matrices))]

    block_lower_triangular_dense = _block_lower_triangular_dense(
        expected_shape, matrices)

    if not use_placeholder:
      block_lower_triangular_dense.set_shape(expected_shape)

    return operator, block_lower_triangular_dense

  def _variable_backed_operator(self):
    """Returns `(operator, diagonal_blocks)` built from `Variable` blocks.

    The operator is two-by-two blockwise. Each of its three
    `LinearOperatorFullMatrix` blocks wraps its own `Variable`.
    `diagonal_blocks` holds the two blocks on the main diagonal.
    """
    operator_1 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[1., 0.], [0., 1.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator_2 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[2., 0.], [1., 0.]]))
    operator_3 = linalg.LinearOperatorFullMatrix(
        variables_module.Variable([[3., 1.], [1., 3.]]),
        is_self_adjoint=True,
        is_positive_definite=True)
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_2, operator_3]],
        is_self_adjoint=False,
        is_positive_definite=True)
    return operator, (operator_1, operator_3)

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues, 1, and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(matrix)]],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_block_lower_triangular_inverse_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)],
         [linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True),
          linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)]],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        block_lower_triangular.LinearOperatorBlockLowerTriangular)
    self.assertEqual(2, len(inverse.operators))
    self.assertEqual(1, len(inverse.operators[0]))
    self.assertEqual(2, len(inverse.operators[1]))

  def test_tape_safe(self):
    operator, diagonal_blocks = self._variable_backed_operator()

    # For a block lower-triangular operator these only depend on the diagonal
    # blocks, so the generic check (which expects a gradient for every
    # variable) is skipped for them.
    diagonal_grads_only = ["diag_part", "trace", "determinant",
                           "log_abs_determinant"]
    self.check_tape_safe(operator, skip_options=diagonal_grads_only)

    # Instead, check that each diagonal block's variables get a gradient.
    # Only the forward computation is recorded; the gradient is taken after
    # the tape stops recording.
    for method_name in diagonal_grads_only:
      for diag_block in diagonal_blocks:
        with backprop.GradientTape() as tape:
          value = getattr(operator, method_name)()
        grads = tape.gradient(value, diag_block.variables)
        for grad in grads:
          self.assertIsNotNone(grad)

  def test_convert_variables_to_tensors(self):
    operator, _ = self._variable_backed_operator()
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)

  def test_is_non_singular_auto_set(self):
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_3 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_2, operator_3]],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    with self.assertRaisesRegex(ValueError, "always non-singular"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(
          [[operator_1], [operator_2, operator_3]], is_non_singular=False)

    operator_4 = linalg.LinearOperatorFullMatrix(
        [[1., 0.], [2., 0.]], is_non_singular=False)

    # A singular operator off of the main diagonal shouldn't raise
    block_lower_triangular.LinearOperatorBlockLowerTriangular(
        [[operator_1], [operator_4, operator_2]], is_non_singular=True)

    with self.assertRaisesRegex(ValueError, "always singular"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(
          [[operator_1], [operator_2, operator_4]], is_non_singular=True)

  def test_different_dtypes_raises(self):
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))],
        [linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
         linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))]
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_non_square_operator_raises(self):
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 4), is_square=False)],
        [linalg.LinearOperatorFullMatrix(rng.rand(4, 4)),
         linalg.LinearOperatorFullMatrix(rng.rand(4, 4))]
    ]
    with self.assertRaisesRegex(ValueError, "must be square"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_empty_operators_raises(self):
    with self.assertRaisesRegex(ValueError, "must be a list of >=1"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular([])

  def test_operators_wrong_length_raises(self):
    with self.assertRaisesRegex(ValueError, "must contain `2` blocks"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular([
          [linalg.LinearOperatorFullMatrix(rng.rand(2, 2))],
          [linalg.LinearOperatorFullMatrix(rng.rand(2, 2))
           for _ in range(3)]])

  def test_operators_mismatched_dimension_raises(self):
    operators = [
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 3))],
        [linalg.LinearOperatorFullMatrix(rng.rand(3, 4)),
         linalg.LinearOperatorFullMatrix(rng.rand(3, 3))]
    ]
    with self.assertRaisesRegex(ValueError, "must be the same as"):
      block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)

  def test_incompatible_input_blocks_raises(self):
    matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
    matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 4), shape=None)
    matrix_3 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
    operators = [
        [linalg.LinearOperatorFullMatrix(matrix_1, is_square=True)],
        [linalg.LinearOperatorFullMatrix(matrix_2),
         linalg.LinearOperatorFullMatrix(matrix_3, is_square=True)]
    ]
    operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
        operators)
    # Use the seeded module RNG (like the other tests) for reproducibility.
    x = rng.rand(2, 4, 5).tolist()
    msg = ("dimension does not match" if context.executing_eagerly()
           else "input structure is ambiguous")
    with self.assertRaisesRegex(ValueError, msg):
      operator.matmul(x)
312
313
if __name__ == "__main__":
  # `add_tests` presumably attaches the shared, generated base-class tests to
  # the test class; `test.main()` then runs the suite.
  _test_class = SquareLinearOperatorBlockLowerTriangularTest
  linear_operator_test_util.add_tests(_test_class)
  test.main()
318