• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.framework import config
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import test_util
19from tensorflow.python.ops import array_ops
20from tensorflow.python.ops import variables as variables_module
21from tensorflow.python.ops.linalg import linalg as linalg_lib
22from tensorflow.python.ops.linalg import linear_operator_inversion
23from tensorflow.python.ops.linalg import linear_operator_test_util
24from tensorflow.python.platform import test
25
# Short module-level aliases used by the test class below.
linalg = linalg_lib

# pylint: disable=invalid-name
LinearOperatorInversion = linear_operator_inversion.LinearOperatorInversion
# pylint: enable=invalid-name
29
30
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorInversionTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def setUp(self):
    """Disables TF32 execution and loosens complex64 tolerances."""
    # Run the base-class setUp so whatever per-test initialization it performs
    # is not silently skipped by this override.
    super().setUp()
    # Remember the process-wide TF32 setting so tearDown can restore it.
    # Presumably disabled because TF32's reduced-precision matmuls would not
    # meet the tolerances checked by the shared tests.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Copy before overriding: writing into the inherited dicts directly would
    # mutate state that may be shared with other test classes in the process.
    # NOTE(review): assumes _atol/_rtol are class-level dicts on the base test
    # class; the copies shadow them on this instance only.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.complex64] = 1e-5

  def tearDown(self):
    """Restores the TF32 setting captured in setUp, then base cleanup."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  def operator_and_matrix(self,
                          build_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds an inverted operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the matrix shape.
      dtype: Dtype of the generated matrix.
      use_placeholder: If True, the wrapped operator receives the matrix via
        `placeholder_with_default(..., shape=None)`, i.e. without static shape
        information.
      ensure_self_adjoint_and_pd: If True, wrap a self-adjoint positive-definite
        `LinearOperatorFullMatrix`; otherwise wrap a well-conditioned
        `LinearOperatorLowerTriangular`.

    Returns:
      A `(operator, expected)` pair where `operator` is a
      `LinearOperatorInversion` and `expected` is `linalg.inv(matrix)`.
    """
    shape = list(build_info.shape)

    if ensure_self_adjoint_and_pd:
      matrix = linear_operator_test_util.random_positive_definite_matrix(
          shape, dtype, force_well_conditioned=True)
    else:
      matrix = linear_operator_test_util.random_tril_matrix(
          shape, dtype, force_well_conditioned=True, remove_upper=True)

    lin_op_matrix = matrix

    if use_placeholder:
      lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)

    if ensure_self_adjoint_and_pd:
      operator = LinearOperatorInversion(
          linalg.LinearOperatorFullMatrix(
              lin_op_matrix, is_positive_definite=True, is_self_adjoint=True))
    else:
      operator = LinearOperatorInversion(
          linalg.LinearOperatorLowerTriangular(lin_op_matrix))

    # The expected value is computed from the concrete matrix, not the
    # placeholder, so it is independent of use_placeholder.
    return operator, linalg.inv(matrix)

  def test_base_operator_hint_used(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    operator_inv = LinearOperatorInversion(operator)
    self.assertTrue(operator_inv.is_positive_definite)
    self.assertTrue(operator_inv.is_non_singular)
    self.assertFalse(operator_inv.is_self_adjoint)

  def test_supplied_hint_used(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(matrix)
    operator_inv = LinearOperatorInversion(
        operator,
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator_inv.is_positive_definite)
    self.assertTrue(operator_inv.is_non_singular)
    self.assertFalse(operator_inv.is_self_adjoint)

  def test_contradicting_hints_raise(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, is_positive_definite=False)
    with self.assertRaisesRegex(ValueError, "positive-definite"):
      LinearOperatorInversion(operator, is_positive_definite=True)

    operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=False)
    with self.assertRaisesRegex(ValueError, "self-adjoint"):
      LinearOperatorInversion(operator, is_self_adjoint=True)

  def test_singular_raises(self):
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 1.], [1., 1.]]

    # Singularity declared on the wrapped operator.
    operator = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=False)
    with self.assertRaisesRegex(ValueError, "is_non_singular"):
      LinearOperatorInversion(operator)

    # Singularity declared on the inversion wrapper itself.
    operator = linalg.LinearOperatorFullMatrix(matrix)
    with self.assertRaisesRegex(ValueError, "is_non_singular"):
      LinearOperatorInversion(operator, is_non_singular=False)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator = linalg.LinearOperatorFullMatrix(
        matrix, name="my_operator", is_non_singular=True)

    operator = LinearOperatorInversion(operator)

    self.assertEqual("my_operator_inv", operator.name)

  def test_tape_safe(self):
    matrix = variables_module.Variable([[1., 2.], [3., 4.]])
    operator = LinearOperatorInversion(linalg.LinearOperatorFullMatrix(matrix))
    self.check_tape_safe(operator)
137
138
if __name__ == "__main__":
  # Register the shared tests on the class (presumably the parameterized
  # operator tests generated by linear_operator_test_util) before the runner
  # performs test discovery.
  linear_operator_test_util.add_tests(LinearOperatorInversionTest)
  test.main()
142