# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorHouseholder`."""

from tensorflow.python.framework import config
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_householder as householder
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorHouseholderTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def setUp(self):
    """Remembers the TensorFloat-32 setting, then disables it for the test."""
    # Run the base-class fixture first so its per-test setup is not skipped.
    super().setUp()
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the TensorFloat-32 setting that was saved in `setUp`."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    # Undo our own change first, then let the base class clean up.
    super().tearDown()

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes (with batch dims) the shared tests use."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  @staticmethod
  def skip_these_tests():
    """Returns the names of shared tests that do not apply to this operator."""
    # This linear operator is never positive definite.
    return ["cholesky"]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a Householder operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the full operator
        shape; the reflection axis takes that shape minus its last dimension.
      dtype: Dtype used to generate the reflection axis.
      use_placeholder: If True, the operator is fed the reflection axis
        through `placeholder_with_default` with `shape=None`.
      ensure_self_adjoint_and_pd: Unused by this test class.

    Returns:
      A tuple `(operator, matrix)` where `operator` is the
      `LinearOperatorHouseholder` under test and `matrix` is the dense
      reference tensor `I - 2 * v * v^H` built from the same axis `v`.
    """
    shape = list(build_info.shape)
    reflection_axis = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    # Make sure unit norm.
    reflection_axis = reflection_axis / linalg_ops.norm(
        reflection_axis, axis=-1, keepdims=True)

    lin_op_reflection_axis = reflection_axis

    if use_placeholder:
      lin_op_reflection_axis = array_ops.placeholder_with_default(
          reflection_axis, shape=None)

    operator = householder.LinearOperatorHouseholder(lin_op_reflection_axis)

    # Dense reference: start from -2 * v v^H, then add 1 to the diagonal.
    mat = reflection_axis[..., array_ops.newaxis]
    matrix = -2 * math_ops.matmul(mat, mat, adjoint_b=True)
    matrix = array_ops.matrix_set_diag(
        matrix, 1. + array_ops.matrix_diag_part(matrix))

    return operator, matrix

  def test_scalar_reflection_axis_raises(self):
    """A rank-0 reflection axis must be rejected with a `ValueError`."""
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      householder.LinearOperatorHouseholder(1.)

  def test_householder_adjoint_type(self):
    """The adjoint of a Householder operator is a Householder operator."""
    reflection_axis = [1., 3., 5., 8.]
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.assertIsInstance(
        operator.adjoint(), householder.LinearOperatorHouseholder)

  def test_householder_inverse_type(self):
    """The inverse of a Householder operator is a Householder operator."""
    reflection_axis = [1., 3., 5., 8.]
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.assertIsInstance(
        operator.inverse(), householder.LinearOperatorHouseholder)

  def test_tape_safe(self):
    """Checks tape safety for an operator backed by a `Variable`."""
    reflection_axis = variables_module.Variable([1., 3., 5., 8.])
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    self.check_tape_safe(
        operator,
        skip_options=[
            # Determinant is hard-coded: a Householder reflection always has
            # determinant -1 (so log|det| is 0), whatever the axis is.
            CheckTapeSafeSkipOptions.DETERMINANT,
            CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
            # Trace hard-coded.
            CheckTapeSafeSkipOptions.TRACE,
        ])

  def test_convert_variables_to_tensors(self):
    """Checks variable-to-tensor conversion for a `Variable`-backed operator."""
    reflection_axis = variables_module.Variable([1., 3., 5., 8.])
    operator = householder.LinearOperatorHouseholder(reflection_axis)
    with self.cached_session() as sess:
      sess.run([reflection_axis.initializer])
      self.check_convert_variables_to_tensors(operator)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorHouseholderTest)
  test.main()