• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16from tensorflow.python.framework import config
17from tensorflow.python.framework import test_util
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import linalg_ops
20from tensorflow.python.ops import math_ops
21from tensorflow.python.ops import random_ops
22from tensorflow.python.ops import variables as variables_module
23from tensorflow.python.ops.linalg import linalg as linalg_lib
24from tensorflow.python.ops.linalg import linear_operator_test_util
25from tensorflow.python.platform import test
26
# Shorter alias for the linalg module; `linalg` and `linalg_lib` name the same
# module object.
linalg = linalg_lib
28
29
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorDiagTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorDiag`.

  Most tests are done in the base class LinearOperatorDerivedClassTest; this
  class supplies the operator/dense-matrix pairs those tests consume and adds
  diag-specific tests: assertion ops, batch broadcasting, and the operator
  types returned by matmul/solve/adjoint/cholesky/inverse.
  """

  def setUp(self):
    """Runs the base-class setup, then disables TensorFloat-32 execution."""
    super().setUp()
    # Remember the process-wide TF32 setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    # TF32 lowers matmul precision; keep dense-vs-operator comparisons at full
    # float32 precision.
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the TF32 setting saved by setUp, then runs base teardown."""
    try:
      config.enable_tensor_float_32_execution(self.tf32_keep_)
    finally:
      # Always run the base cleanup, even if restoring the setting fails.
      super().tearDown()

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a `LinearOperatorDiag` and its equivalent dense matrix.

    Args:
      build_info: Object with a `shape` attribute; the trailing dimension is
        dropped to obtain the shape of the diagonal.
      dtype: dtype of the diagonal entries.
      use_placeholder: If True, the operator is built from a
        `placeholder_with_default` whose static shape is unknown.
      ensure_self_adjoint_and_pd: If True, `abs` is applied to the sampled
        diagonal and the operator is flagged self-adjoint and positive
        definite.

    Returns:
      A tuple `(operator, matrix)` where `matrix` is `matrix_diag(diag)`.
    """
    shape = list(build_info.shape)
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)

    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)

    lin_op_diag = diag
    if use_placeholder:
      # Hide the static shape so the operator is exercised with dynamic shapes.
      lin_op_diag = array_ops.placeholder_with_default(diag, shape=None)

    operator = linalg.LinearOperatorDiag(
        lin_op_diag,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    # The dense reference is built from the original (non-placeholder) diag.
    matrix = array_ops.matrix_diag(diag)

    return operator, matrix

  def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
    # Matrix with one positive eigenvalue and one zero eigenvalue.
    with self.cached_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should be auto-set for real diag.
      self.assertTrue(operator.is_self_adjoint)
      # self.evaluate (not Operation.run) works in both graph and eager mode.
      with self.assertRaisesOpError("non-positive.*not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
    with self.cached_session():
      diag_x = [1.0, -2.0]
      diag_y = [0., 0.]  # Imaginary eigenvalues should not matter.
      diag = math_ops.complex(diag_x, diag_y)
      operator = linalg.LinearOperatorDiag(diag)

      # is_self_adjoint should not be auto-set for complex diag.
      self.assertIsNone(operator.is_self_adjoint)
      with self.assertRaisesOpError("non-positive real.*not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
    with self.cached_session():
      x = [1., 2.]
      y = [1., 0.]
      diag = math_ops.complex(x, y)  # Re[diag] > 0.
      # Should not fail
      self.evaluate(linalg.LinearOperatorDiag(diag).assert_positive_definite())

  def test_assert_non_singular_raises_if_zero_eigenvalue(self):
    # Singular matrix with one positive eigenvalue and one zero eigenvalue.
    with self.cached_session():
      diag = [1.0, 0.0]
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      with self.assertRaisesOpError("Singular operator"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      # Should not raise.
      self.evaluate(linalg.LinearOperatorDiag(diag).assert_non_singular())

  def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 1.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      with self.assertRaisesOpError("imaginary.*not self-adjoint"):
        self.evaluate(operator.assert_self_adjoint())

  def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
    with self.cached_session():
      x = [1., 0.]
      y = [0., 0.]
      diag = math_ops.complex(x, y)
      operator = linalg.LinearOperatorDiag(diag)
      # Should not raise
      self.evaluate(operator.assert_self_adjoint())

  def test_scalar_diag_raises(self):
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linalg.LinearOperatorDiag(1.)

  def test_broadcast_matmul_and_solve(self):
    # The operator's batch shape (2, 1) must broadcast against the batch shape
    # (2, 2) of the right-hand side `x`. The dense reference matrix is
    # broadcast explicitly below so the comparison does not rely on the dense
    # matmul/matrix_solve ops broadcasting batch dimensions.
    with self.cached_session():
      x = random_ops.random_normal(shape=(2, 2, 3, 4))

      # This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
      # and matmul with 'x' as the argument.
      diag = random_ops.random_uniform(shape=(2, 1, 3))
      operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
      self.assertAllEqual((2, 1, 3, 3), operator.shape)

      # Create a batch matrix with the broadcast shape of operator.
      diag_broadcast = array_ops.concat((diag, diag), 1)
      mat = array_ops.matrix_diag(diag_broadcast)
      self.assertAllEqual((2, 2, 3, 3), mat.shape)  # being pedantic.

      operator_matmul = operator.matmul(x)
      mat_matmul = math_ops.matmul(mat, x)
      self.assertAllEqual(operator_matmul.shape, mat_matmul.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, mat_matmul]))

      operator_solve = operator.solve(x)
      mat_solve = linalg_ops.matrix_solve(mat, x)
      self.assertAllEqual(operator_solve.shape, mat_solve.shape)
      self.assertAllClose(*self.evaluate([operator_solve, mat_solve]))

  def test_diag_matmul(self):
    operator1 = linalg.LinearOperatorDiag([2., 3.])
    operator2 = linalg.LinearOperatorDiag([1., 2.])
    operator3 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    # Products of a diag with a diag or a scaled identity (in either order)
    # should stay a LinearOperatorDiag holding the elementwise product.
    cases = [
        ("operator1 @ operator2", operator1.matmul(operator2), [2., 6.]),
        ("operator2 @ operator1", operator2.matmul(operator1), [2., 6.]),
        ("operator1 @ operator3", operator1.matmul(operator3), [6., 9.]),
        ("operator3 @ operator1", operator3.matmul(operator1), [6., 9.]),
    ]
    for label, product, expected_diag in cases:
      self.assertIsInstance(product, linalg.LinearOperatorDiag, msg=label)
      self.assertAllClose(
          expected_diag, self.evaluate(product.diag), msg=label)

  def test_diag_solve(self):
    operator1 = linalg.LinearOperatorDiag([2., 3.], is_non_singular=True)
    operator2 = linalg.LinearOperatorDiag([1., 2.], is_non_singular=True)
    operator3 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3., is_non_singular=True)
    # Solving a diag against a diag or a scaled identity (in either order)
    # should stay a LinearOperatorDiag holding the elementwise quotient.
    cases = [
        ("operator1 \\ operator2", operator1.solve(operator2), [0.5, 2 / 3.]),
        ("operator2 \\ operator1", operator2.solve(operator1), [2., 3 / 2.]),
        ("operator1 \\ operator3", operator1.solve(operator3), [3 / 2., 1.]),
        ("operator3 \\ operator1", operator3.solve(operator1), [2 / 3., 1.]),
    ]
    for label, quotient, expected_diag in cases:
      self.assertIsInstance(quotient, linalg.LinearOperatorDiag, msg=label)
      self.assertAllClose(
          expected_diag, self.evaluate(quotient.diag), msg=label)

  def test_diag_adjoint_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
    self.assertIsInstance(operator.adjoint(), linalg.LinearOperatorDiag)

  def test_diag_cholesky_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(
        diag,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(operator.cholesky(), linalg.LinearOperatorDiag)

  def test_diag_inverse_type(self):
    diag = [1., 3., 5., 8.]
    operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
    self.assertIsInstance(operator.inverse(), linalg.LinearOperatorDiag)

  def test_tape_safe(self):
    diag = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorDiag(diag)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    diag = variables_module.Variable([[2.]])
    operator = linalg.LinearOperatorDiag(diag)
    with self.cached_session():
      # Initialize the variable (graph mode); evaluate is a no-op in eager.
      self.evaluate(diag.initializer)
      self.check_convert_variables_to_tensors(operator)
255    diag = variables_module.Variable([[2.]])
256    operator = linalg.LinearOperatorDiag(diag)
257    with self.cached_session() as sess:
258      sess.run([diag.initializer])
259      self.check_convert_variables_to_tensors(operator)
260
261
if __name__ == "__main__":
  # Presumably generates the base-class tests on LinearOperatorDiagTest; it is
  # called before test.main() so those tests exist when tests are collected.
  linear_operator_test_util.add_tests(LinearOperatorDiagTest)
  test.main()
265