# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from tensorflow.python.framework import config
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test


class _LinearOperatorTriDiagBase(object):
  """Shared fixture logic for the LinearOperatorTridiag test cases.

  Mixed in ahead of `SquareLinearOperatorDerivedClassTest`. Provides:
    * `setUp`/`tearDown`, which disable TensorFloat-32 execution for each
      test and restore the previous setting afterwards, while still running
      the base TestCase setup/teardown.
    * `build_operator_and_matrix`, which builds a random tridiagonal operator
      in any supported `diagonals_format` together with its dense matrix.
    * `operator_shapes_infos`, the operator shapes the generated tests use.
  """

  def setUp(self):
    # Run the base TestCase setup (seeding, graph/session reset) first.
    super().setUp()
    # TF32 trades float32 matmul precision for speed; presumably disabled so
    # dense-vs-operator comparisons stay within tolerance. Remember the prior
    # setting so tearDown can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    try:
      config.enable_tensor_float_32_execution(self.tf32_keep_)
    finally:
      # Always run the base TestCase teardown, even if restoring TF32 fails.
      super().tearDown()

  def build_operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False,
      diagonals_format='sequence'):
    """Builds a random `LinearOperatorTridiag` and its dense equivalent.

    Args:
      build_info: `OperatorShapesInfo` whose `shape` is the full operator
        shape; the diagonals are generated with shape `shape[:-1]`.
      dtype: dtype of the generated diagonals and matrix.
      use_placeholder: If True, pass the diagonals through
        `placeholder_with_default(..., shape=None)` so the operator sees no
        static shape information.
      ensure_self_adjoint_and_pd: If True, build a Hermitian, diagonally
        dominant (hence positive definite) matrix and pass the matching
        `is_self_adjoint`/`is_positive_definite` hints to the operator.
      diagonals_format: One of 'sequence', 'compact' or 'matrix'; selects the
        representation handed to the operator.

    Returns:
      A tuple `(operator, matrix)` where `matrix` is the dense tensor the
      operator is expected to equal.

    Raises:
      ValueError: If `diagonals_format` is not one of the supported formats.
    """
    if diagonals_format not in ('sequence', 'compact', 'matrix'):
      raise ValueError(
          'Unrecognized diagonals_format: {!r}; expected one of '
          "'sequence', 'compact' or 'matrix'.".format(diagonals_format))

    shape = list(build_info.shape)

    # Make the main diagonal large relative to the off-diagonals. When a
    # self-adjoint PD matrix is requested, this diagonal dominance is what
    # guarantees positive definiteness.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=4., maxval=6., dtype=dtype)
    # Off-diagonals are generated at full length; one entry of each ends up in
    # a padding slot that matrix_diag_v3 ignores.
    subdiag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
      # With 'LEFT_RIGHT' alignment the first element of subdiag and the last
      # element of superdiag are padding. Rolling conj(subdiag) left by one
      # gives superdiag[i] = conj(subdiag[i + 1]) (the Hermitian mirror) and
      # moves the unused subdiag[0] into superdiag's padding slot.
      superdiag = manip_ops.roll(math_ops.conj(subdiag), shift=-1, axis=-1)
    else:
      superdiag = linear_operator_test_util.random_sign_uniform(
          shape[:-1], minval=1., maxval=2., dtype=dtype)

    # Superdiagonal first: for k=(-1, 1) the diagonals are listed from the
    # highest offset down to the lowest. This is also the 'compact' format.
    compact_diagonals = array_ops.stack([superdiag, diag, subdiag], axis=-2)
    matrix = gen_array_ops.matrix_diag_v3(
        compact_diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

    if diagonals_format == 'sequence':
      lin_op_diagonals = [superdiag, diag, subdiag]
    elif diagonals_format == 'compact':
      lin_op_diagonals = compact_diagonals
    else:  # 'matrix'
      lin_op_diagonals = matrix

    if use_placeholder:
      if diagonals_format == 'sequence':
        lin_op_diagonals = [
            array_ops.placeholder_with_default(d, shape=None)
            for d in lin_op_diagonals]
      else:
        lin_op_diagonals = array_ops.placeholder_with_default(
            lin_op_diagonals, shape=None)

    operator = linalg_lib.LinearOperatorTridiag(
        diagonals=lin_op_diagonals,
        diagonals_format=diagonals_format,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)
    return operator, matrix

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes exercised by the generated tests."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((3, 3)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagCompactTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='compact')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
    operator = linalg_lib.LinearOperatorTridiag(
        diag, diagonals_format='compact')
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
    operator = linalg_lib.LinearOperatorTridiag(
        diag, diagonals_format='compact')
    with self.cached_session() as sess:
      sess.run([diag.initializer])
      self.check_convert_variables_to_tensors(operator)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagSequenceTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='sequence')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    diagonals = [
        variables_module.Variable([3., 6., 2.]),
        variables_module.Variable([2., 4., 2.]),
        variables_module.Variable([5., 1., 2.])]
    operator = linalg_lib.LinearOperatorTridiag(
        diagonals, diagonals_format='sequence')
    # Skip the diagonal part and trace since these only depend on the
    # middle variable. We test them below.
    self.check_tape_safe(operator, skip_options=['diag_part', 'trace'])

    # Only the main diagonal is a variable here, so every checked quantity
    # (including diag_part and trace) must have a gradient w.r.t. it.
    diagonals = [
        [3., 6., 2.],
        variables_module.Variable([2., 4., 2.]),
        [5., 1., 2.]
    ]
    operator = linalg_lib.LinearOperatorTridiag(
        diagonals, diagonals_format='sequence')
    self.check_tape_safe(operator)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagMatrixTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='matrix')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    matrix = variables_module.Variable([[3., 2., 0.], [1., 6., 4.], [0., 2, 2]])
    operator = linalg_lib.LinearOperatorTridiag(
        matrix, diagonals_format='matrix')
    self.check_tape_safe(operator)


if __name__ == '__main__':
  # The generated suites are only added when XLA is disabled, presumably for
  # the same missing-pivoting reason noted on the test_tape_safe methods.
  if not test_util.is_xla_enabled():
    linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
    linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
    linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
  test.main()