# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for registration mechanisms."""

from tensorflow.python.framework import tensor_shape
# NOTE(review): the `*_registrations` modules are not referenced by name;
# presumably they are imported so their registrations populate the registries
# checked by the `testExact*RegistrationsAllMatch` tests below.
from tensorflow.python.ops.linalg import cholesky_registrations  # pylint: disable=unused-import
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_algebra
from tensorflow.python.ops.linalg import matmul_registrations  # pylint: disable=unused-import
from tensorflow.python.ops.linalg import solve_registrations  # pylint: disable=unused-import
from tensorflow.python.platform import test

# Module-private registries and lookup functions, aliased for brevity.
# pylint: disable=protected-access
_ADJOINTS = linear_operator_algebra._ADJOINTS
_registered_adjoint = linear_operator_algebra._registered_adjoint
_CHOLESKY_DECOMPS = linear_operator_algebra._CHOLESKY_DECOMPS
_registered_cholesky = linear_operator_algebra._registered_cholesky
_INVERSES = linear_operator_algebra._INVERSES
_registered_inverse = linear_operator_algebra._registered_inverse
_MATMUL = linear_operator_algebra._MATMUL
_registered_matmul = linear_operator_algebra._registered_matmul
_SOLVE = linear_operator_algebra._SOLVE
_registered_solve = linear_operator_algebra._registered_solve
# pylint: enable=protected-access


class AdjointTest(test.TestCase):
  """Tests for `RegisterAdjoint` and `LinearOperator.adjoint` dispatch."""

  def testRegistration(self):
    """`adjoint()` invokes the function registered for the operator type."""

    # A fresh class per test, so registrations made here cannot collide with
    # registrations made by other tests.
    class CustomLinOp(linear_operator.LinearOperator):
      # Minimal stubs; only dispatch to the registered function is checked.

      def _matmul(self, a):
        pass

      def _shape(self):
        return tensor_shape.TensorShape([1, 1])

      def _shape_tensor(self):
        pass

    # Register a function that ignores its argument and returns a sentinel,
    # so the assertion below proves this function is the one that ran.
    @linear_operator_algebra.RegisterAdjoint(CustomLinOp)
    def _adjoint(a):  # pylint: disable=unused-argument,unused-variable
      return "OK"

    self.assertEqual("OK", CustomLinOp(dtype=None).adjoint())

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class CustomLinOp(linear_operator.LinearOperator):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      linear_operator_algebra.RegisterAdjoint(CustomLinOp)("blah")

    # First registration is OK
    linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)

    # Second registration for the same type fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)

  def testExactAdjointRegistrationsAllMatch(self):
    """Looking up each registered key returns the registered function."""
    # NOTE(review): keys appear to be tuples whose first element is the
    # operator type; confirm against linear_operator_algebra.
    for (k, v) in _ADJOINTS.items():
      self.assertEqual(v, _registered_adjoint(k[0]))


class CholeskyTest(test.TestCase):
  """Tests for `RegisterCholesky` and `LinearOperator.cholesky` dispatch."""

  def testRegistration(self):
    """`cholesky()` checks its hints, then invokes the registered function.

    The operator must be hinted both self-adjoint and positive definite;
    missing either hint raises a `ValueError` naming the missing property.
    """

    class CustomLinOp(linear_operator.LinearOperator):
      # Minimal stubs; only hint checks and dispatch are exercised.

      def _matmul(self, a):
        pass

      def _shape(self):
        return tensor_shape.TensorShape([1, 1])

      def _shape_tensor(self):
        pass

    # Register a function that ignores its argument and returns a sentinel,
    # so the final assertion proves this function is the one that ran.
    @linear_operator_algebra.RegisterCholesky(CustomLinOp)
    def _cholesky(a):  # pylint: disable=unused-argument,unused-variable
      return "OK"

    with self.assertRaisesRegex(ValueError, "positive definite"):
      CustomLinOp(dtype=None, is_self_adjoint=True).cholesky()

    with self.assertRaisesRegex(ValueError, "self adjoint"):
      CustomLinOp(dtype=None, is_positive_definite=True).cholesky()

    custom_linop = CustomLinOp(
        dtype=None, is_self_adjoint=True, is_positive_definite=True)
    self.assertEqual("OK", custom_linop.cholesky())

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class CustomLinOp(linear_operator.LinearOperator):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      linear_operator_algebra.RegisterCholesky(CustomLinOp)("blah")

    # First registration is OK
    linear_operator_algebra.RegisterCholesky(CustomLinOp)(lambda a: None)

    # Second registration for the same type fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      linear_operator_algebra.RegisterCholesky(CustomLinOp)(lambda a: None)

  def testExactCholeskyRegistrationsAllMatch(self):
    """Looking up each registered key returns the registered function."""
    for (k, v) in _CHOLESKY_DECOMPS.items():
      self.assertEqual(v, _registered_cholesky(k[0]))


class MatmulTest(test.TestCase):
  """Tests for `RegisterMatmul` and `LinearOperator.matmul` dispatch."""

  def testRegistration(self):
    """`matmul()` of two operators invokes the function for that type pair."""

    class CustomLinOp(linear_operator.LinearOperator):
      # Minimal stubs; only dispatch to the registered function is checked.

      def _matmul(self, a):
        pass

      def _shape(self):
        return tensor_shape.TensorShape([1, 1])

      def _shape_tensor(self):
        pass

    # Register a binary function for the (CustomLinOp, CustomLinOp) pair that
    # ignores its arguments and returns a sentinel.
    @linear_operator_algebra.RegisterMatmul(CustomLinOp, CustomLinOp)
    def _matmul(a, b):  # pylint: disable=unused-argument,unused-variable
      return "OK"

    custom_linop = CustomLinOp(
        dtype=None, is_self_adjoint=True, is_positive_definite=True)
    self.assertEqual("OK", custom_linop.matmul(custom_linop))

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class CustomLinOp(linear_operator.LinearOperator):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      linear_operator_algebra.RegisterMatmul(CustomLinOp, CustomLinOp)("blah")

    # First registration is OK. Matmul functions take two operators, so the
    # placeholder takes two arguments as well.
    linear_operator_algebra.RegisterMatmul(
        CustomLinOp, CustomLinOp)(lambda a, b: None)

    # Second registration for the same type pair fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      linear_operator_algebra.RegisterMatmul(
          CustomLinOp, CustomLinOp)(lambda a, b: None)

  def testExactMatmulRegistrationsAllMatch(self):
    """Looking up each registered key pair returns the registered function."""
    # NOTE(review): keys appear to be (type_a, type_b) pairs; confirm against
    # linear_operator_algebra.
    for (k, v) in _MATMUL.items():
      self.assertEqual(v, _registered_matmul(k[0], k[1]))


class SolveTest(test.TestCase):
  """Tests for `RegisterSolve` and `LinearOperator.solve` dispatch."""

  def testRegistration(self):
    """`solve()` with an operator invokes the function for that type pair."""

    class CustomLinOp(linear_operator.LinearOperator):
      # Minimal stubs; only dispatch to the registered function is checked.

      def _matmul(self, a):
        pass

      def _solve(self, a):
        pass

      def _shape(self):
        return tensor_shape.TensorShape([1, 1])

      def _shape_tensor(self):
        pass

    # Register a binary function for the (CustomLinOp, CustomLinOp) pair that
    # ignores its arguments and returns a sentinel.
    @linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)
    def _solve(a, b):  # pylint: disable=unused-argument,unused-variable
      return "OK"

    custom_linop = CustomLinOp(
        dtype=None, is_self_adjoint=True, is_positive_definite=True)
    self.assertEqual("OK", custom_linop.solve(custom_linop))

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class CustomLinOp(linear_operator.LinearOperator):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah")

    # First registration is OK. Solve functions take two operators, so the
    # placeholder takes two arguments as well.
    linear_operator_algebra.RegisterSolve(
        CustomLinOp, CustomLinOp)(lambda a, b: None)

    # Second registration for the same type pair fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      linear_operator_algebra.RegisterSolve(
          CustomLinOp, CustomLinOp)(lambda a, b: None)

  def testExactSolveRegistrationsAllMatch(self):
    """Looking up each registered key pair returns the registered function."""
    for (k, v) in _SOLVE.items():
      self.assertEqual(v, _registered_solve(k[0], k[1]))


class InverseTest(test.TestCase):
  """Tests for `RegisterInverse` and `LinearOperator.inverse` dispatch."""

  def testRegistration(self):
    """`inverse()` rejects singular operators, else uses the registration."""

    class CustomLinOp(linear_operator.LinearOperator):
      # Minimal stubs; only the hint check and dispatch are exercised.

      def _matmul(self, a):
        pass

      def _shape(self):
        return tensor_shape.TensorShape([1, 1])

      def _shape_tensor(self):
        pass

    # Register a function that ignores its argument and returns a sentinel,
    # so the final assertion proves this function is the one that ran.
    @linear_operator_algebra.RegisterInverse(CustomLinOp)
    def _inverse(a):  # pylint: disable=unused-argument,unused-variable
      return "OK"

    with self.assertRaisesRegex(ValueError, "singular"):
      CustomLinOp(dtype=None, is_non_singular=False).inverse()

    self.assertEqual("OK", CustomLinOp(
        dtype=None, is_non_singular=True).inverse())

  def testRegistrationFailures(self):
    """Non-callables and duplicate registrations are rejected."""

    class CustomLinOp(linear_operator.LinearOperator):
      pass

    with self.assertRaisesRegex(TypeError, "must be callable"):
      linear_operator_algebra.RegisterInverse(CustomLinOp)("blah")

    # First registration is OK
    linear_operator_algebra.RegisterInverse(CustomLinOp)(lambda a: None)

    # Second registration for the same type fails
    with self.assertRaisesRegex(ValueError, "has already been registered"):
      linear_operator_algebra.RegisterInverse(CustomLinOp)(lambda a: None)

  def testExactRegistrationsAllMatch(self):
    """Looking up each registered key returns the registered function."""
    for (k, v) in _INVERSES.items():
      self.assertEqual(v, _registered_inverse(k[0]))


if __name__ == "__main__":
  test.main()