• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.framework import test_util
19from tensorflow.python.ops import linalg_ops
20from tensorflow.python.ops.linalg import linalg as linalg_lib
21from tensorflow.python.ops.linalg import linear_operator_addition
22from tensorflow.python.platform import test
23
# Short aliases used throughout the tests below.
linalg = linalg_lib
add_operators = linear_operator_addition.add_operators

# Fixed seed so randomly generated operator entries are reproducible.
rng = np.random.RandomState(0)
28
29
30# pylint: disable=unused-argument
class _BadAdder(linear_operator_addition._Adder):
  """Sentinel Adder whose entry points always fail.

  Placing it in an addition tier lets a test prove that the tier is never
  consulted: calling `can_add` or `_add` raises AssertionError.
  """

  def can_add(self, op1, op2):
    message = "BadAdder.can_add called!"
    raise AssertionError(message)

  def _add(self, op1, op2, operator_name, hints):
    unreachable = "This line should not be reached"
    raise AssertionError(unreachable)
39
40
41# pylint: enable=unused-argument
42
43
class LinearOperatorAdditionCorrectnessTest(test.TestCase):
  """Tests correctness of addition with combinations of a few Adders.

  Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
  add_operators should reduce all operators resulting in one single operator.

  This shows that we are able to correctly combine adders using the tiered
  system.  All Adders should be tested separately, and there is no need to test
  every Adder within this class.
  """

  def test_one_operator_is_returned_unchanged(self):
    # With a single operator there is nothing to add; the same object returns.
    op_a = linalg.LinearOperatorDiag([1., 1.])
    op_sum = add_operators([op_a])
    self.assertEqual(1, len(op_sum))
    self.assertIs(op_sum[0], op_a)

  def test_at_least_one_operators_required(self):
    with self.assertRaisesRegex(ValueError, "must contain at least one"):
      add_operators([])

  def test_attempting_to_add_numbers_raises(self):
    with self.assertRaisesRegex(TypeError, "contain only LinearOperator"):
      add_operators([1, 2])

  @test_util.run_deprecated_v1
  def test_two_diag_operators(self):
    op_a = linalg.LinearOperatorDiag(
        [1., 1.], is_positive_definite=True, name="A")
    op_b = linalg.LinearOperatorDiag(
        [2., 2.], is_positive_definite=True, name="B")
    with self.cached_session():
      op_sum = add_operators([op_a, op_b])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
      self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense())
      # Adding positive definite operators produces positive def.
      self.assertTrue(op.is_positive_definite)
      # Real diagonal ==> self-adjoint.
      self.assertTrue(op.is_self_adjoint)
      # Positive definite ==> non-singular
      self.assertTrue(op.is_non_singular)
      # Enforce particular name for this simple case
      self.assertEqual("Add/B__A/", op.name)

  @test_util.run_deprecated_v1
  def test_three_diag_operators(self):
    op1 = linalg.LinearOperatorDiag(
        [1., 1.], is_positive_definite=True, name="op1")
    op2 = linalg.LinearOperatorDiag(
        [2., 2.], is_positive_definite=True, name="op2")
    op3 = linalg.LinearOperatorDiag(
        [3., 3.], is_positive_definite=True, name="op3")
    with self.cached_session():
      op_sum = add_operators([op1, op2, op3])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense())
      # Adding positive definite operators produces positive def.
      self.assertTrue(op.is_positive_definite)
      # Real diagonal ==> self-adjoint.
      self.assertTrue(op.is_self_adjoint)
      # Positive definite ==> non-singular
      self.assertTrue(op.is_non_singular)

  @test_util.run_deprecated_v1
  def test_diag_tril_diag(self):
    op1 = linalg.LinearOperatorDiag(
        [1., 1.], is_non_singular=True, name="diag_a")
    op2 = linalg.LinearOperatorLowerTriangular(
        [[2., 0.], [0., 2.]],
        is_self_adjoint=True,
        is_non_singular=True,
        name="tril")
    op3 = linalg.LinearOperatorDiag(
        [3., 3.], is_non_singular=True, name="diag_b")
    with self.cached_session():
      op_sum = add_operators([op1, op2, op3])
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense())

      # The diag operators will be self-adjoint (because real and diagonal).
      # The TriL operator has the self-adjoint hint set.
      self.assertTrue(op.is_self_adjoint)

      # Even though op1/2/3 are non-singular, this does not imply op is.
      # Since no custom hint was provided, we default to None (unknown).
      self.assertIsNone(op.is_non_singular)

  @test_util.run_deprecated_v1
  def test_matrix_diag_tril_diag_uses_custom_name(self):
    op0 = linalg.LinearOperatorFullMatrix(
        [[-1., -1.], [-1., -1.]], name="matrix")
    op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
    op2 = linalg.LinearOperatorLowerTriangular(
        [[2., 0.], [1.5, 2.]], name="tril")
    op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
    with self.cached_session():
      op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
      self.assertEqual(1, len(op_sum))
      op = op_sum[0]
      self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
      self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense())
      self.assertEqual("my_operator", op.name)

  def test_incompatible_domain_dimensions_raises(self):
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
    with self.assertRaisesRegex(ValueError, "must.*same `domain_dimension`"):
      add_operators([op1, op2])

  def test_incompatible_range_dimensions_raises(self):
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
    with self.assertRaisesRegex(ValueError, "must.*same `range_dimension`"):
      add_operators([op1, op2])

  def test_non_broadcastable_batch_shape_raises(self):
    # Batch shapes [2] and [4] cannot be broadcast together.
    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
    op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      add_operators([op1, op2])
164
165  def test_non_broadcastable_batch_shape_raises(self):
166    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
167    op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
168    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
169      add_operators([op1, op2])
170
171
class LinearOperatorOrderOfAdditionTest(test.TestCase):
  """Test that the order of addition is done as specified by tiers."""

  def test_tier_0_additions_done_in_tier_0(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    diag3 = linalg.LinearOperatorDiag([1.])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Should not raise since all were added in tier 0, and tier 1 (with the
    # _BadAdder) was never reached.
    op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)

  def test_tier_1_additions_done_by_tier_1(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [linear_operator_addition._AddAndReturnTriL()],
        [_BadAdder()],
    ]
    # Should not raise since all were added by tier 1, and tier 2 (with the
    # _BadAdder) was never reached.
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)

  def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnTriL()],
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
    ]
    # Tier 0 could convert to TriL, and this converted everything to TriL,
    # including the Diags.
    # Tier 1 was never used.
    # Tier 2 was never used (therefore, _BadAdder didn't raise).
    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
    self.assertEqual(1, len(op_sum))
    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)

  @test_util.run_deprecated_v1
  def test_cannot_add_everything_so_return_more_than_one_operator(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([2.])
    tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
    ]
    # Tier 0 (the only tier) can only convert to Diag, so it combines the two
    # diags, but the TriL is unchanged.
    # Result should contain two operators, one Diag, one TriL.
    op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
    self.assertEqual(2, len(op_sum))
    found_diag = False
    found_tril = False
    with self.cached_session():
      for op in op_sum:
        if isinstance(op, linalg.LinearOperatorDiag):
          found_diag = True
          self.assertAllClose([[3.]], op.to_dense())
        if isinstance(op, linalg.LinearOperatorLowerTriangular):
          found_tril = True
          self.assertAllClose([[5.]], op.to_dense())
      # Separate asserts so a failure says which operator type was missing.
      self.assertTrue(found_diag,
                      "Expected a LinearOperatorDiag in the result.")
      self.assertTrue(found_tril,
                      "Expected a LinearOperatorLowerTriangular in the result.")

  def test_intermediate_tier_is_not_skipped(self):
    diag1 = linalg.LinearOperatorDiag([1.])
    diag2 = linalg.LinearOperatorDiag([1.])
    tril = linalg.LinearOperatorLowerTriangular([[1.]])
    addition_tiers = [
        [linear_operator_addition._AddAndReturnDiag()],
        [_BadAdder()],
        [linear_operator_addition._AddAndReturnTriL()],
    ]
    # tril cannot be added in tier 0, and the intermediate tier 1 with the
    # BadAdder will catch it and raise.
    with self.assertRaisesRegex(AssertionError, "BadAdder.can_add called"):
      add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
259
260
class AddAndReturnScaledIdentityTest(test.TestCase):
  """Tests `_AddAndReturnScaledIdentity` on identity-like operator pairs."""

  def setUp(self):
    # Chain to the base class so the test framework's per-test setup is not
    # skipped by this override.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnScaledIdentity()

  @test_util.run_deprecated_v1
  def test_identity_plus_identity(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2)
    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      self.assertAllClose(2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_identity_plus_scaled_identity(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      self.assertAllClose(3.2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_scaled_identity_plus_scaled_identity(self):
    id1 = linalg.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=[2.2, 2.2, 2.2])
    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)

    with self.cached_session():
      # Multipliers 2.2 + (-1.0) = 1.2, broadcast over batch_shape [3].
      self.assertAllClose(1.2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)
320
321
class AddAndReturnDiagTest(test.TestCase):
  """Tests `_AddAndReturnDiag` on identity and diagonal operator pairs."""

  def setUp(self):
    # Chain to the base class so the test framework's per-test setup is not
    # skipped by this override.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnDiag()

  @test_util.run_deprecated_v1
  def test_identity_plus_identity_returns_diag(self):
    id1 = linalg.LinearOperatorIdentity(num_rows=2)
    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(id1, id2))
    operator = self._adder.add(id1, id2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorDiag)

    with self.cached_session():
      self.assertAllClose(2 * linalg_ops.eye(num_rows=2, batch_shape=[3]),
                          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)

  @test_util.run_deprecated_v1
  def test_diag_plus_diag(self):
    # diag2 (shape [4]) broadcasts against diag1 (shape [2, 3, 4]).
    diag1 = rng.rand(2, 3, 4)
    diag2 = rng.rand(4)
    op1 = linalg.LinearOperatorDiag(diag1)
    op2 = linalg.LinearOperatorDiag(diag2)
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(op1, op2))
    operator = self._adder.add(op1, op2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorDiag)

    with self.cached_session():
      self.assertAllClose(
          linalg.LinearOperatorDiag(diag1 + diag2).to_dense(),
          operator.to_dense())
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)
365
366
class AddAndReturnTriLTest(test.TestCase):
  """Tests `_AddAndReturnTriL` on diagonal plus lower-triangular operators."""

  def setUp(self):
    # Chain to the base class so the test framework's per-test setup is not
    # skipped by this override.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnTriL()

  @test_util.run_deprecated_v1
  def test_diag_plus_tril(self):
    diag = linalg.LinearOperatorDiag([1., 2.])
    tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
    hints = linear_operator_addition._Hints(
        is_positive_definite=True, is_non_singular=True)

    self.assertTrue(self._adder.can_add(diag, diag))
    self.assertTrue(self._adder.can_add(diag, tril))
    operator = self._adder.add(diag, tril, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)

    with self.cached_session():
      self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense())
    # Hint values are taken from `hints`, not derived from the inputs.
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)
389
390
class AddAndReturnMatrixTest(test.TestCase):
  """Tests `_AddAndReturnMatrix`, which yields a full-matrix operator."""

  def setUp(self):
    # Chain to the base class so the test framework's per-test setup is not
    # skipped by this override.
    super().setUp()
    self._adder = linear_operator_addition._AddAndReturnMatrix()

  @test_util.run_deprecated_v1
  def test_diag_plus_diag(self):
    diag1 = linalg.LinearOperatorDiag([1., 2.])
    diag2 = linalg.LinearOperatorDiag([-1., 3.])
    hints = linear_operator_addition._Hints(
        is_positive_definite=False, is_non_singular=False)

    self.assertTrue(self._adder.can_add(diag1, diag2))
    operator = self._adder.add(diag1, diag2, "my_operator", hints)
    self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)

    with self.cached_session():
      # The sum has a zero on the diagonal, consistent with the False hints.
      self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense())
    self.assertFalse(operator.is_positive_definite)
    self.assertFalse(operator.is_non_singular)
    self.assertEqual("my_operator", operator.name)
412
413
if __name__ == "__main__":
  # Run every test case in this module via TensorFlow's test entry point.
  test.main()
416