• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.framework import config
17from tensorflow.python.framework import test_util
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import math_ops
20from tensorflow.python.ops import variables as variables_module
21from tensorflow.python.ops.linalg import linalg as linalg_lib
22from tensorflow.python.ops.linalg import linear_operator_test_util
23from tensorflow.python.platform import test
24
# Module alias: `linalg` and `linalg_lib` name the same module in this file.
linalg = linalg_lib
26
27
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorLowerTriangularTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorLowerTriangular`.

  Most tests are done in the shared base-class suite
  (`SquareLinearOperatorDerivedClassTest`). This class builds the operator
  under test in `operator_and_matrix` and adds a few operator-specific tests.
  """

  @staticmethod
  def skip_these_tests():
    """Names of shared tests to skip for this operator.

    Presumably read by the base-class test generator.
    """
    # Cholesky does not make sense for triangular matrices.
    return ["cholesky"]

  def operator_and_matrix(self, build_info, dtype, use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Build a `LinearOperatorLowerTriangular` and its dense equivalent.

    Args:
      build_info: Object whose `shape` attribute gives the shape of the
        operator to build.
      dtype: dtype passed to the random matrix generator.
      use_placeholder: If True, feed the matrix through
        `placeholder_with_default` with `shape=None`, hiding its static shape.
      ensure_self_adjoint_and_pd: If True, replace the matrix by a diagonal
        matrix with entries `abs(d) + 0.1` (d = original diagonal), and pass
        `is_self_adjoint=True` / `is_positive_definite=True` hints.

    Returns:
      A tuple `(operator, matrix)`. `matrix` is the lower triangle (diagonal
      included) of the matrix backing `operator`.
    """
    shape = list(build_info.shape)
    # Upper triangle will be nonzero, but ignored.
    # Use a diagonal that ensures this matrix is well conditioned.
    tril = linear_operator_test_util.random_tril_matrix(
        shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)
    if ensure_self_adjoint_and_pd:
      # Keep only the diagonal and shift its magnitudes to be >= 0.1. That
      # gives a diagonal (hence self-adjoint) matrix with positive entries.
      # NOTE(review): for complex dtypes `math_ops.abs` returns a real
      # tensor, so this branch changes the dtype -- confirm that is intended.
      tril = array_ops.matrix_diag_part(tril)
      tril = math_ops.abs(tril) + 1e-1
      tril = array_ops.matrix_diag(tril)

    lin_op_tril = tril
    if use_placeholder:
      # Hide the static shape (presumably to exercise dynamic-shape paths).
      lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None)

    # Hints are only set when guaranteed by construction; otherwise None.
    hint = True if ensure_self_adjoint_and_pd else None
    operator = linalg.LinearOperatorLowerTriangular(
        lin_op_tril,
        is_self_adjoint=hint,
        is_positive_definite=hint)

    # The operator ignores the upper triangle, so the reference matrix keeps
    # only the lower band (all subdiagonals plus the diagonal) of `tril`.
    matrix = array_ops.matrix_band_part(tril, -1, 0)

    return operator, matrix

  def test_assert_non_singular(self):
    """`assert_non_singular` fails when a diagonal entry is zero."""
    # Singular matrix with one positive eigenvalue and one zero eigenvalue
    # (a triangular matrix's eigenvalues are its diagonal entries: 1 and 0).
    with self.cached_session():
      tril = [[1., 0.], [1., 0.]]
      operator = linalg.LinearOperatorLowerTriangular(tril)
      with self.assertRaisesOpError("Singular operator"):
        operator.assert_non_singular().run()

  def test_is_x_flags(self):
    """Hint flags given to the constructor are exposed as properties."""
    # Matrix with two positive eigenvalues.
    tril = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorLowerTriangular(
        tril,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_tril_must_have_at_least_two_dims_or_raises(self):
    """A rank-1 input is rejected with a ValueError at construction."""
    with self.assertRaisesRegex(ValueError, "at least 2 dimensions"):
      linalg.LinearOperatorLowerTriangular([1.])

  def test_triangular_diag_matmul(self):
    """Triangular x diagonal products, in both orders, stay triangular.

    Each product must be a `LinearOperatorLowerTriangular` whose dense form
    matches the dense matrix product.
    """
    operator1 = linalg.LinearOperatorLowerTriangular(
        [[1., 0., 0.], [2., 1., 0.], [2., 3., 3.]])
    operator2 = linalg.LinearOperatorDiag([2., 2., 3.])

    # Check both triangular @ diag and diag @ triangular.
    for left, right in ((operator1, operator2), (operator2, operator1)):
      operator_matmul = left.matmul(right)
      self.assertIsInstance(
          operator_matmul, linalg.LinearOperatorLowerTriangular)
      self.assertAllClose(
          math_ops.matmul(left.to_dense(), right.to_dense()),
          self.evaluate(operator_matmul.to_dense()))

  def test_tape_safe(self):
    """Run the `check_tape_safe` helper on a Variable-backed operator."""
    tril = variables_module.Variable([[1., 0.], [0., 1.]])
    operator = linalg.LinearOperatorLowerTriangular(
        tril, is_non_singular=True)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    """Run `check_convert_variables_to_tensors` on a Variable-backed operator."""
    tril = variables_module.Variable([[1., 0.], [0., 1.]])
    operator = linalg.LinearOperatorLowerTriangular(
        tril, is_non_singular=True)
    with self.cached_session() as sess:
      # Initialize the variable explicitly before running the check.
      sess.run([tril.initializer])
      self.check_convert_variables_to_tensors(operator)
126
127
if __name__ == "__main__":
  # Turn off TensorFloat-32 so float32 matmuls run at full precision;
  # presumably the shared tests' tolerances assume that -- TODO confirm.
  config.enable_tensor_float_32_execution(False)
  # Presumably generates the parameterized base-class tests (except those
  # named by `skip_these_tests`) and attaches them to this test class.
  linear_operator_test_util.add_tests(LinearOperatorLowerTriangularTest)
  test.main()
132