# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for registration mechanisms."""

17from tensorflow.python.framework import tensor_shape
18from tensorflow.python.ops.linalg import cholesky_registrations  # pylint: disable=unused-import
19from tensorflow.python.ops.linalg import linear_operator
20from tensorflow.python.ops.linalg import linear_operator_algebra
21from tensorflow.python.ops.linalg import matmul_registrations  # pylint: disable=unused-import
22from tensorflow.python.ops.linalg import solve_registrations  # pylint: disable=unused-import
23from tensorflow.python.platform import test
24
25# pylint: disable=protected-access
26_ADJOINTS = linear_operator_algebra._ADJOINTS
27_registered_adjoint = linear_operator_algebra._registered_adjoint
28_CHOLESKY_DECOMPS = linear_operator_algebra._CHOLESKY_DECOMPS
29_registered_cholesky = linear_operator_algebra._registered_cholesky
30_INVERSES = linear_operator_algebra._INVERSES
31_registered_inverse = linear_operator_algebra._registered_inverse
32_MATMUL = linear_operator_algebra._MATMUL
33_registered_matmul = linear_operator_algebra._registered_matmul
34_SOLVE = linear_operator_algebra._SOLVE
35_registered_solve = linear_operator_algebra._registered_solve
36# pylint: enable=protected-access
37
38
39class AdjointTest(test.TestCase):
40
41  def testRegistration(self):
42
43    class CustomLinOp(linear_operator.LinearOperator):
44
45      def _matmul(self, a):
46        pass
47
48      def _shape(self):
49        return tensor_shape.TensorShape([1, 1])
50
51      def _shape_tensor(self):
52        pass
53
54    # Register Adjoint to a lambda that spits out the name parameter
55    @linear_operator_algebra.RegisterAdjoint(CustomLinOp)
56    def _adjoint(a):  # pylint: disable=unused-argument,unused-variable
57      return "OK"
58
59    self.assertEqual("OK", CustomLinOp(dtype=None).adjoint())
60
61  def testRegistrationFailures(self):
62
63    class CustomLinOp(linear_operator.LinearOperator):
64      pass
65
66    with self.assertRaisesRegex(TypeError, "must be callable"):
67      linear_operator_algebra.RegisterAdjoint(CustomLinOp)("blah")
68
69    # First registration is OK
70    linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)
71
72    # Second registration fails
73    with self.assertRaisesRegex(ValueError, "has already been registered"):
74      linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)
75
76  def testExactAdjointRegistrationsAllMatch(self):
77    for (k, v) in _ADJOINTS.items():
78      self.assertEqual(v, _registered_adjoint(k[0]))
79
80
81class CholeskyTest(test.TestCase):
82
83  def testRegistration(self):
84
85    class CustomLinOp(linear_operator.LinearOperator):
86
87      def _matmul(self, a):
88        pass
89
90      def _shape(self):
91        return tensor_shape.TensorShape([1, 1])
92
93      def _shape_tensor(self):
94        pass
95
96    # Register Cholesky to a lambda that spits out the name parameter
97    @linear_operator_algebra.RegisterCholesky(CustomLinOp)
98    def _cholesky(a):  # pylint: disable=unused-argument,unused-variable
99      return "OK"
100
101    with self.assertRaisesRegex(ValueError, "positive definite"):
102      CustomLinOp(dtype=None, is_self_adjoint=True).cholesky()
103
104    with self.assertRaisesRegex(ValueError, "self adjoint"):
105      CustomLinOp(dtype=None, is_positive_definite=True).cholesky()
106
107    custom_linop = CustomLinOp(
108        dtype=None, is_self_adjoint=True, is_positive_definite=True)
109    self.assertEqual("OK", custom_linop.cholesky())
110
111  def testRegistrationFailures(self):
112
113    class CustomLinOp(linear_operator.LinearOperator):
114      pass
115
116    with self.assertRaisesRegex(TypeError, "must be callable"):
117      linear_operator_algebra.RegisterCholesky(CustomLinOp)("blah")
118
119    # First registration is OK
120    linear_operator_algebra.RegisterCholesky(CustomLinOp)(lambda a: None)
121
122    # Second registration fails
123    with self.assertRaisesRegex(ValueError, "has already been registered"):
124      linear_operator_algebra.RegisterCholesky(CustomLinOp)(lambda a: None)
125
126  def testExactCholeskyRegistrationsAllMatch(self):
127    for (k, v) in _CHOLESKY_DECOMPS.items():
128      self.assertEqual(v, _registered_cholesky(k[0]))
129
130
131class MatmulTest(test.TestCase):
132
133  def testRegistration(self):
134
135    class CustomLinOp(linear_operator.LinearOperator):
136
137      def _matmul(self, a):
138        pass
139
140      def _shape(self):
141        return tensor_shape.TensorShape([1, 1])
142
143      def _shape_tensor(self):
144        pass
145
146    # Register Matmul to a lambda that spits out the name parameter
147    @linear_operator_algebra.RegisterMatmul(CustomLinOp, CustomLinOp)
148    def _matmul(a, b):  # pylint: disable=unused-argument,unused-variable
149      return "OK"
150
151    custom_linop = CustomLinOp(
152        dtype=None, is_self_adjoint=True, is_positive_definite=True)
153    self.assertEqual("OK", custom_linop.matmul(custom_linop))
154
155  def testRegistrationFailures(self):
156
157    class CustomLinOp(linear_operator.LinearOperator):
158      pass
159
160    with self.assertRaisesRegex(TypeError, "must be callable"):
161      linear_operator_algebra.RegisterMatmul(CustomLinOp, CustomLinOp)("blah")
162
163    # First registration is OK
164    linear_operator_algebra.RegisterMatmul(
165        CustomLinOp, CustomLinOp)(lambda a: None)
166
167    # Second registration fails
168    with self.assertRaisesRegex(ValueError, "has already been registered"):
169      linear_operator_algebra.RegisterMatmul(
170          CustomLinOp, CustomLinOp)(lambda a: None)
171
172  def testExactMatmulRegistrationsAllMatch(self):
173    for (k, v) in _MATMUL.items():
174      self.assertEqual(v, _registered_matmul(k[0], k[1]))
175
176
177class SolveTest(test.TestCase):
178
179  def testRegistration(self):
180
181    class CustomLinOp(linear_operator.LinearOperator):
182
183      def _matmul(self, a):
184        pass
185
186      def _solve(self, a):
187        pass
188
189      def _shape(self):
190        return tensor_shape.TensorShape([1, 1])
191
192      def _shape_tensor(self):
193        pass
194
195    # Register Solve to a lambda that spits out the name parameter
196    @linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)
197    def _solve(a, b):  # pylint: disable=unused-argument,unused-variable
198      return "OK"
199
200    custom_linop = CustomLinOp(
201        dtype=None, is_self_adjoint=True, is_positive_definite=True)
202    self.assertEqual("OK", custom_linop.solve(custom_linop))
203
204  def testRegistrationFailures(self):
205
206    class CustomLinOp(linear_operator.LinearOperator):
207      pass
208
209    with self.assertRaisesRegex(TypeError, "must be callable"):
210      linear_operator_algebra.RegisterSolve(CustomLinOp, CustomLinOp)("blah")
211
212    # First registration is OK
213    linear_operator_algebra.RegisterSolve(
214        CustomLinOp, CustomLinOp)(lambda a: None)
215
216    # Second registration fails
217    with self.assertRaisesRegex(ValueError, "has already been registered"):
218      linear_operator_algebra.RegisterSolve(
219          CustomLinOp, CustomLinOp)(lambda a: None)
220
221  def testExactSolveRegistrationsAllMatch(self):
222    for (k, v) in _SOLVE.items():
223      self.assertEqual(v, _registered_solve(k[0], k[1]))
224
225
226class InverseTest(test.TestCase):
227
228  def testRegistration(self):
229
230    class CustomLinOp(linear_operator.LinearOperator):
231
232      def _matmul(self, a):
233        pass
234
235      def _shape(self):
236        return tensor_shape.TensorShape([1, 1])
237
238      def _shape_tensor(self):
239        pass
240
241    # Register Inverse to a lambda that spits out the name parameter
242    @linear_operator_algebra.RegisterInverse(CustomLinOp)
243    def _inverse(a):  # pylint: disable=unused-argument,unused-variable
244      return "OK"
245
246    with self.assertRaisesRegex(ValueError, "singular"):
247      CustomLinOp(dtype=None, is_non_singular=False).inverse()
248
249    self.assertEqual("OK", CustomLinOp(
250        dtype=None, is_non_singular=True).inverse())
251
252  def testRegistrationFailures(self):
253
254    class CustomLinOp(linear_operator.LinearOperator):
255      pass
256
257    with self.assertRaisesRegex(TypeError, "must be callable"):
258      linear_operator_algebra.RegisterInverse(CustomLinOp)("blah")
259
260    # First registration is OK
261    linear_operator_algebra.RegisterInverse(CustomLinOp)(lambda a: None)
262
263    # Second registration fails
264    with self.assertRaisesRegex(ValueError, "has already been registered"):
265      linear_operator_algebra.RegisterInverse(CustomLinOp)(lambda a: None)
266
267  def testExactRegistrationsAllMatch(self):
268    for (k, v) in _INVERSES.items():
269      self.assertEqual(v, _registered_inverse(k[0]))
270
271
272if __name__ == "__main__":
273  test.main()
274