• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.framework import config
17from tensorflow.python.framework import test_util
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import linalg_ops
20from tensorflow.python.ops import math_ops
21from tensorflow.python.ops import variables as variables_module
22from tensorflow.python.ops.linalg import linalg as linalg_lib
23from tensorflow.python.ops.linalg import linear_operator_householder as householder
24from tensorflow.python.ops.linalg import linear_operator_test_util
25from tensorflow.python.platform import test
26
# Alias for the imported `linalg_lib` module.
# NOTE(review): `linalg` is not referenced anywhere else in the visible file;
# confirm nothing relies on it before removing.
linalg = linalg_lib
# Shorthand for the skip options handed to `check_tape_safe` via its
# `skip_options` argument (see `test_tape_safe`).
CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions
29
30
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorHouseholderTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorHouseholder`.

  Most tests are done in the base class LinearOperatorDerivedClassTest; this
  class supplies the operator/matrix pairs those tests consume and adds a few
  Householder-specific checks.
  """

  def setUp(self):
    """Runs base-class setup, then turns TensorFloat-32 execution off.

    The previous TF32 setting is saved on `self.tf32_keep_` so that `tearDown`
    can restore it.
    """
    # Chain to the parent first so its per-test fixtures are in place before
    # this class changes any global configuration.
    super().setUp()
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    # Presumably disabled so float32 results are not computed at TF32's reduced
    # precision -- TODO confirm the tolerance requirement.
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the TF32 setting saved by `setUp`, then runs base teardown."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    # Undo in reverse order of setUp: our own state first, then the parent's.
    super().tearDown()

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes the shared tests are run against.

    Returns:
      A list of `OperatorShapesInfo`, one per shape: a 1x1 matrix and batched
      3x3 / 4x4 matrices with one and two leading batch dimensions.
    """
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  @staticmethod
  def skip_these_tests():
    """Returns the names of shared tests that must not run for this operator."""
    # This linear operator is never positive definite.
    return ["cholesky"]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a Householder operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the full operator shape
        (batch dimensions followed by the two matrix dimensions).
      dtype: Dtype used to sample the reflection axis.
      use_placeholder: If true, the operator is built from a
        `placeholder_with_default` of unknown static shape instead of the
        reflection axis tensor itself.
      ensure_self_adjoint_and_pd: Accepted for interface compatibility; not
        used by this implementation.

    Returns:
      A tuple `(operator, matrix)` where `operator` is the
      `LinearOperatorHouseholder` under test and `matrix` is the dense
      `I - 2 v v^H` built from the same (unit-norm) reflection axis `v`.
    """
    shape = list(build_info.shape)
    # Dropping the last dimension turns the matrix shape [..., N, N] into the
    # shape [..., N] of a batch of reflection axes.
    reflection_axis = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    # Make sure unit norm.
    reflection_axis = reflection_axis / linalg_ops.norm(
        reflection_axis, axis=-1, keepdims=True)

    lin_op_reflection_axis = reflection_axis

    if use_placeholder:
      # shape=None hides the static shape from the operator.
      lin_op_reflection_axis = array_ops.placeholder_with_default(
          reflection_axis, shape=None)

    operator = householder.LinearOperatorHouseholder(lin_op_reflection_axis)

    # Dense reference: start from -2 v v^H, then add 1 to the diagonal, which
    # yields I - 2 v v^H.
    mat = reflection_axis[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
    matrix = array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

    return operator, matrix

  def test_scalar_reflection_axis_raises(self):
    """A scalar reflection axis is rejected with a `ValueError`."""
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      householder.LinearOperatorHouseholder(1.)

  def test_householder_adjoint_type(self):
    """`adjoint()` returns another `LinearOperatorHouseholder`."""
    reflection_axis = [1., 3., 5., 8.]
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.assertIsInstance(
        operator.adjoint(), householder.LinearOperatorHouseholder)

  def test_householder_inverse_type(self):
    """`inverse()` returns another `LinearOperatorHouseholder`."""
    reflection_axis = [1., 3., 5., 8.]
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.assertIsInstance(
        operator.inverse(), householder.LinearOperatorHouseholder)

  def test_tape_safe(self):
    """Runs the shared tape-safety check on a variable-backed operator."""
    reflection_axis = variables_module.Variable([1., 3., 5., 8.])
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.check_tape_safe(
        operator,
        skip_options=[
            # Determinant (and hence log|det|) is hard-coded, so it carries no
            # gradient with respect to the variable.
            CheckTapeSafeSkipOptions.DETERMINANT,
            CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
            # Trace hard-coded.
            CheckTapeSafeSkipOptions.TRACE,
        ])

  def test_convert_variables_to_tensors(self):
    """Runs the shared variable-to-tensor conversion check on the operator."""
    reflection_axis = variables_module.Variable([1., 3., 5., 8.])
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    with self.cached_session() as sess:
      # The variable must be initialized before the check reads it.
      sess.run([reflection_axis.initializer])
      self.check_convert_variables_to_tensors(operator)
117
118
if __name__ == "__main__":
  # Must run before test.main(). Presumably attaches the shared tests from
  # linear_operator_test_util to the class so the runner can discover them
  # (the class docstring says most tests live in the base class) -- confirm
  # against `add_tests`.
  linear_operator_test_util.add_tests(LinearOperatorHouseholderTest)
  test.main()
122