• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.eager import context
19from tensorflow.python.framework import config
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import test_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import variables as variables_module
25from tensorflow.python.ops.linalg import linalg as linalg_lib
26from tensorflow.python.ops.linalg import linear_operator_block_diag as block_diag
27from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
28from tensorflow.python.ops.linalg import linear_operator_test_util
29from tensorflow.python.ops.linalg import linear_operator_util
30from tensorflow.python.platform import test
31
# Short module-level alias for the imported linalg package.
linalg = linalg_lib
# NumPy generator with a fixed seed, shared by the tests below.
rng = np.random.RandomState(seed=0)
34
35
def _block_diag_dense(expected_shape, blocks):
  """Assemble `blocks` into a single dense block-diagonal matrix.

  Every block is turned into one horizontal strip: zeros covering the columns
  already used by earlier blocks, then the block itself, then zeros out to
  `expected_shape[-1]` columns.  The strips are stacked along the row axis.

  Args:
    expected_shape: Sequence whose last entry is the total column count of
      the dense result.
    blocks: Iterable of tensors, one per diagonal block.  The callers in this
      file pass matrices whose batch dimensions were already broadcast.

  Returns:
    The strips concatenated along axis -2.
  """
  strips = []
  col_offset = 0
  for blk in blocks:
    # All dimensions except the trailing (column) one.
    leading_dims = array_ops.shape(blk)[:-1]

    left_pad = array_ops.zeros(
        shape=array_ops.concat([leading_dims, [col_offset]], axis=-1),
        dtype=blk.dtype)

    # Step past this block's columns before sizing the right-hand padding.
    col_offset += array_ops.shape(blk)[-1]

    right_pad = array_ops.zeros(
        array_ops.concat(
            [leading_dims, [expected_shape[-1] - col_offset]], axis=-1),
        dtype=blk.dtype)

    strips.append(
        array_ops.concat([left_pad, blk, right_pad], axis=-1))

  return array_ops.concat(strips, axis=-2)
58
59
@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorBlockDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorBlockDiag` assembled from square blocks.

  Most tests done in the base class LinearOperatorDerivedClassTest.
  """

  def tearDown(self):
    """Restores the TensorFloat-32 flag saved by `setUp`, then chains up."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super(SquareLinearOperatorBlockDiagTest, self).tearDown()

  def setUp(self):
    """Turns TensorFloat-32 off and relaxes the 32-bit tolerances."""
    # Remember the current flag so tearDown can put it back.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase from 1e-6 to 1e-4
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4
    # Run the base-class fixture setup as well.
    super(SquareLinearOperatorBlockDiagTest, self).setUp()

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes to test.

    Entries with a `blocks` keyword list the shape of each diagonal block;
    entries without it are built as a single block of the full shape (see
    `operator_and_matrix`).
    """
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((5, 5), blocks=[(2, 2), (3, 3)]),
        shape_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]),
        shape_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
    ]

  @staticmethod
  def use_blockwise_arg():
    """Returns True; presumably read by the shared base-class tests."""
    return True

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a block-diagonal operator and its dense equivalent.

    Args:
      shape_info: Entry from `operator_shapes_infos`.  Its `shape` is the full
        operator shape; an optional `blocks` entry lists per-block shapes and
        defaults to a single block of the full shape.
      dtype: Dtype of the random positive-definite block matrices.
      use_placeholder: If True, the operator is built from
        `placeholder_with_default` tensors created with `shape=None`.
      ensure_self_adjoint_and_pd: If True, every block operator is hinted as
        self-adjoint and positive definite; otherwise the hints are `None`.

    Returns:
      Tuple `(operator, block_diag_dense)`.
    """
    shape = list(shape_info.shape)
    expected_blocks = shape_info.__dict__.get("blocks", [shape])
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_blocks
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in matrices]

    # Hints are either asserted (True) or left unspecified (None).
    hint = True if ensure_self_adjoint_and_pd else None
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            l,
            is_square=True,
            is_self_adjoint=hint,
            is_positive_definite=hint)
         for l in lin_op_matrices])

    # Should be auto-set.
    self.assertTrue(operator.is_square)

    # Broadcast the shapes.
    expected_shape = list(shape_info.shape)

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(
          expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])

    return operator, block_diag_dense

  def test_is_x_flags(self):
    """Hints passed to the constructor are reported back unchanged."""
    # Matrix with two positive eigenvalues, 1, and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(matrix)],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_is_x_parameters(self):
    """`parameters` records the constructor arguments of both operators."""
    matrix = [[1., 0.], [1., 1.]]
    sub_operator = linalg.LinearOperatorFullMatrix(matrix)
    operator = block_diag.LinearOperatorBlockDiag(
        [sub_operator],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertEqual(
        operator.parameters,
        {
            "name": None,
            "is_square": True,
            "is_positive_definite": True,
            "is_self_adjoint": False,
            "is_non_singular": True,
            "operators": [sub_operator],
        })
    self.assertEqual(
        sub_operator.parameters,
        {
            "is_non_singular": None,
            "is_positive_definite": None,
            "is_self_adjoint": None,
            "is_square": None,
            "matrix": matrix,
            "name": "LinearOperatorFullMatrix",
        })

  def test_block_diag_adjoint_type(self):
    """`adjoint()` returns a block-diag operator with the same block count."""
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
        ],
        is_non_singular=True,
    )
    adjoint = operator.adjoint()
    self.assertIsInstance(
        adjoint,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(adjoint.operators))

  def test_block_diag_cholesky_type(self):
    """`cholesky()` returns block-diag with lower-triangular blocks."""
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
        ],
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    cholesky_factor = operator.cholesky()
    self.assertIsInstance(
        cholesky_factor,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(cholesky_factor.operators))
    self.assertIsInstance(
        cholesky_factor.operators[0],
        lower_triangular.LinearOperatorLowerTriangular)
    self.assertIsInstance(
        cholesky_factor.operators[1],
        lower_triangular.LinearOperatorLowerTriangular
    )

  def test_block_diag_inverse_type(self):
    """`inverse()` returns a block-diag operator with the same block count."""
    matrix = [[1., 0.], [0., 1.]]
    operator = block_diag.LinearOperatorBlockDiag(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_non_singular=True,
            ),
        ],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        block_diag.LinearOperatorBlockDiag)
    self.assertEqual(2, len(inverse.operators))

  def test_block_diag_matmul_type(self):
    """Block-diag times block-diag stays block-diag and matches dense."""
    matrices1 = []
    matrices2 = []
    # Blocks are created pairwise so that block i of the left operand has as
    # many columns as block i of the right operand has rows.
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [2, i], dtype=dtypes.float32)))

      matrices2.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [i, 3], dtype=dtypes.float32)))

    operator1 = block_diag.LinearOperatorBlockDiag(matrices1, is_square=False)
    operator2 = block_diag.LinearOperatorBlockDiag(matrices2, is_square=False)

    expected_matrix = math_ops.matmul(
        operator1.to_dense(), operator2.to_dense())
    actual_operator = operator1.matmul(operator2)

    self.assertIsInstance(
        actual_operator, block_diag.LinearOperatorBlockDiag)
    actual_, expected_ = self.evaluate([
        actual_operator.to_dense(), expected_matrix])
    self.assertAllClose(actual_, expected_)

  def test_block_diag_matmul_raises(self):
    """`matmul` with a shape-incompatible operator raises ValueError."""
    matrices1 = [
        linalg.LinearOperatorFullMatrix(
            linear_operator_test_util.random_normal(
                [2, i], dtype=dtypes.float32))
        for i in range(1, 5)
    ]
    operator1 = block_diag.LinearOperatorBlockDiag(matrices1, is_square=False)
    operator2 = linalg.LinearOperatorFullMatrix(
        linear_operator_test_util.random_normal(
            [15, 3], dtype=dtypes.float32))

    with self.assertRaisesRegex(ValueError, "Operators are incompatible"):
      operator1.matmul(operator2)

  def test_block_diag_solve_type(self):
    """Block-diag solve of block-diag stays block-diag and matches dense."""
    matrices1 = []
    matrices2 = []
    # Blocks are created pairwise so that block i of the system matrix and
    # block i of the right-hand side have the same number of rows.
    for i in range(1, 5):
      matrices1.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_tril_matrix(
              [i, i],
              dtype=dtypes.float32,
              force_well_conditioned=True)))

      matrices2.append(linalg.LinearOperatorFullMatrix(
          linear_operator_test_util.random_normal(
              [i, 3], dtype=dtypes.float32)))

    operator1 = block_diag.LinearOperatorBlockDiag(matrices1)
    operator2 = block_diag.LinearOperatorBlockDiag(matrices2, is_square=False)

    expected_matrix = linalg.solve(
        operator1.to_dense(), operator2.to_dense())
    actual_operator = operator1.solve(operator2)

    self.assertIsInstance(
        actual_operator, block_diag.LinearOperatorBlockDiag)
    actual_, expected_ = self.evaluate([
        actual_operator.to_dense(), expected_matrix])
    self.assertAllClose(actual_, expected_)

  def test_block_diag_solve_raises(self):
    """`solve` with a shape-incompatible operator raises ValueError."""
    matrices1 = [
        linalg.LinearOperatorFullMatrix(
            linear_operator_test_util.random_normal(
                [i, i], dtype=dtypes.float32))
        for i in range(1, 5)
    ]
    operator1 = block_diag.LinearOperatorBlockDiag(matrices1)
    operator2 = linalg.LinearOperatorFullMatrix(
        linear_operator_test_util.random_normal(
            [15, 3], dtype=dtypes.float32))

    with self.assertRaisesRegex(ValueError, "Operators are incompatible"):
      operator1.solve(operator2)

  def test_tape_safe(self):
    """Runs the base-class tape-safety check on a variable-backed operator."""
    matrices = [
        variables_module.Variable(
            linear_operator_test_util.random_positive_definite_matrix(
                [2, 2], dtype=dtypes.float32, force_well_conditioned=True))
        for _ in range(4)
    ]

    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            matrix, is_self_adjoint=True,
            is_positive_definite=True) for matrix in matrices],
        is_self_adjoint=True,
        is_positive_definite=True,
    )
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    """Runs the base-class variable-to-tensor conversion check."""
    matrices = [
        variables_module.Variable(
            linear_operator_test_util.random_positive_definite_matrix(
                [3, 3], dtype=dtypes.float32, force_well_conditioned=True))
        for _ in range(3)
    ]

    operator = block_diag.LinearOperatorBlockDiag(
        [linalg.LinearOperatorFullMatrix(
            matrix, is_self_adjoint=True,
            is_positive_definite=True) for matrix in matrices],
        is_self_adjoint=True,
        is_positive_definite=True,
    )
    with self.cached_session() as sess:
      # Variables must be initialized before the operator is inspected.
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)

  def test_is_non_singular_auto_set(self):
    """Non-singular blocks make the block-diag operator non-singular."""
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = block_diag.LinearOperatorBlockDiag(
        [operator_1, operator_2],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    # Contradicting the inferred hint must be rejected.
    with self.assertRaisesRegex(ValueError, "always non-singular"):
      block_diag.LinearOperatorBlockDiag(
          [operator_1, operator_2], is_non_singular=False)

  def test_name(self):
    """The default name joins the block names with `_ds_`."""
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")

    operator = block_diag.LinearOperatorBlockDiag([operator_1, operator_2])

    self.assertEqual("left_ds_right", operator.name)

  def test_different_dtypes_raises(self):
    """Blocks with mismatched dtypes are rejected with TypeError."""
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      block_diag.LinearOperatorBlockDiag(operators)

  def test_empty_operators_raises(self):
    """An empty operator list is rejected with ValueError."""
    with self.assertRaisesRegex(ValueError, "non-empty"):
      block_diag.LinearOperatorBlockDiag([])

  def test_incompatible_input_blocks_raises(self):
    """`matmul` with a wrongly sized argument raises ValueError."""
    matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
    matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
    operators = [
        linalg.LinearOperatorFullMatrix(matrix_1, is_square=True),
        linalg.LinearOperatorFullMatrix(matrix_2, is_square=True)
    ]
    operator = block_diag.LinearOperatorBlockDiag(operators)
    # Use the seeded module-level generator like the other inputs above; only
    # the shape of `x` matters for this test.
    x = rng.rand(2, 4, 5).tolist()
    # The error text differs between eager and graph execution.
    msg = ("dimension does not match" if context.executing_eagerly()
           else "input structure is ambiguous")
    with self.assertRaisesRegex(ValueError, msg):
      operator.matmul(x)
426
427
@test_util.run_all_in_graph_and_eager_modes
class NonSquareLinearOperatorBlockDiagTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorBlockDiag` assembled from non-square blocks.

  Most tests done in the base class LinearOperatorDerivedClassTest.
  """

  def tearDown(self):
    """Restores the TensorFloat-32 flag saved by `setUp`, then chains up."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    # setUp chains to the base class, so tearDown does too.
    super(NonSquareLinearOperatorBlockDiagTest, self).tearDown()

  def setUp(self):
    """Turns TensorFloat-32 off and relaxes the 32-bit tolerances."""
    # Remember the current flag so tearDown can put it back.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Increase from 1e-6 to 1e-4
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4
    super(NonSquareLinearOperatorBlockDiagTest, self).setUp()

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes to test.

    Entries with a `blocks` keyword list the shape of each diagonal block;
    entries without it are built as a single block of the full shape (see
    `operator_and_matrix`).
    """
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 0)),
        shape_info((1, 2, 3)),
        shape_info((5, 3), blocks=[(2, 1), (3, 2)]),
        shape_info((3, 6, 5), blocks=[(1, 2, 1), (3, 1, 2), (1, 3, 2)]),
        shape_info((2, 1, 5, 2), blocks=[(2, 1, 2, 1), (1, 3, 1)]),
    ]

  @staticmethod
  def skip_these_tests():
    """List of test names to skip."""
    return [
        "cholesky",
        "cond",
        "det",
        "diag_part",
        "eigvalsh",
        "inverse",
        "log_abs_det",
        "solve",
        "solve_with_broadcast",
        "trace"]

  @staticmethod
  def use_blockwise_arg():
    """Returns True; presumably read by the shared base-class tests."""
    return True

  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a non-square block-diagonal operator and its dense equivalent.

    Args:
      shape_info: Entry from `operator_shapes_infos`.  Its `shape` is the full
        operator shape; an optional `blocks` entry lists per-block shapes and
        defaults to a single block of the full shape.
      dtype: Dtype of the random normal block matrices.
      use_placeholder: If True, the operator is built from
        `placeholder_with_default` tensors created with `shape=None`.
      ensure_self_adjoint_and_pd: Ignored.

    Returns:
      Tuple `(operator, block_diag_dense)`.
    """
    del ensure_self_adjoint_and_pd
    shape = list(shape_info.shape)
    expected_blocks = shape_info.__dict__.get("blocks", [shape])
    matrices = [
        linear_operator_test_util.random_normal(block_shape, dtype=dtype)
        for block_shape in expected_blocks
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in matrices]

    blocks = [
        linalg.LinearOperatorFullMatrix(
            l,
            is_square=False,
            is_self_adjoint=False,
            is_positive_definite=False)
        for l in lin_op_matrices
    ]
    operator = block_diag.LinearOperatorBlockDiag(blocks)

    # Broadcast the shapes.
    expected_shape = list(shape_info.shape)

    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    block_diag_dense = _block_diag_dense(expected_shape, matrices)

    if not use_placeholder:
      block_diag_dense.set_shape(expected_shape)

    return operator, block_diag_dense
516
517
if __name__ == "__main__":
  # Pass each test class to `add_tests` before starting the test runner.
  for _test_cls in (SquareLinearOperatorBlockDiagTest,
                    NonSquareLinearOperatorBlockDiagTest):
    linear_operator_test_util.add_tests(_test_cls)
  test.main()
522