# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
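"""Tests for tf.linalg.LinearOperatorLowRankUpdate."""
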
import numpy as np

from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
rng = np.random.RandomState(0)
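
# The operator under test represents A = L + U D V^H for a base operator L.
# A minimal behavioral sketch (values worked out by hand here, not taken from
# the tests below; with `v` and `diag_update` unset, the operator uses V = U
# and D = I):
#
#   base = linalg.LinearOperatorDiag([2., 3.])    # L = diag(2, 3)
#   op = linalg.LinearOperatorLowRankUpdate(base, u=[[1.], [1.]])
#   op.to_dense()  # ==> [[3., 1.], [1., 4.]], i.e. A = L + u u^H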


class BaseLinearOperatorLowRankUpdatetest(object):
  """Base test for this type of operator."""

  # Subclasses should set these attributes to either True or False.

  # If True, A = L + UDV^H or A = L + UDU^H, depending on _use_v.
  # If False, A = L + UV^H or A = L + UU^H, depending on _use_v.
  _use_diag_update = None

  # If True, diag_update is > 0, which means D is positive definite.
  _is_diag_update_positive = None

  # If True, A = L + UDV^H or A = L + UV^H, depending on _use_diag_update.
  # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update.
  _use_v = None
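
  # These flags select which variant of A = L + UDV^H is built.  For
  # reference, the "inversion and determinant lemmas" mentioned below are the
  # Woodbury matrix identity and the matrix determinant lemma (stated here
  # assuming the relevant inverses exist):
  #   A^{-1} = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1}
  #   det(A) = det(D^{-1} + V^H L^{-1} U) * det(D) * det(L)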

  @staticmethod
  def operator_shapes_infos():
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # Previously we had a (2, 10, 10) shape at the end.  We did this to test the
    # inversion and determinant lemmas on not-tiny matrices, since these are
    # known to have stability issues.  This resulted in test timeouts, so this
    # shape has been removed, but rest assured, the tests did pass.
    return [
        shape_info((0, 0)),
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  def _gen_positive_diag(self, dtype, diag_shape):
    if dtype.is_complex:
      diag = linear_operator_test_util.random_uniform(
          diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
      return math_ops.cast(diag, dtype=dtype)

    return linear_operator_test_util.random_uniform(
        diag_shape, minval=1e-4, maxval=1., dtype=dtype)

  def operator_and_matrix(self, shape_info, dtype, use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    # Recall A = L + UDV^H.
    shape = list(shape_info.shape)
    diag_shape = shape[:-1]
    k = shape[-2] // 2 + 1
    u_perturbation_shape = shape[:-1] + [k]
    diag_update_shape = shape[:-2] + [k]

    # base_operator L will be a symmetric positive definite diagonal linear
    # operator.  Its entries are drawn from [1e-4, 1], so its condition number
    # (max entry / min entry) can be as high as 1e4.
    base_diag = self._gen_positive_diag(dtype, diag_shape)
    lin_op_base_diag = base_diag

    # U
    u = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_u = u

    # V
    v = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    lin_op_v = v

    # D
    if self._is_diag_update_positive or ensure_self_adjoint_and_pd:
      diag_update = self._gen_positive_diag(dtype, diag_update_shape)
    else:
      diag_update = linear_operator_test_util.random_normal(
          diag_update_shape, stddev=1e-4, dtype=dtype)
    lin_op_diag_update = diag_update

    if use_placeholder:
      lin_op_base_diag = array_ops.placeholder_with_default(
          base_diag, shape=None)
      lin_op_u = array_ops.placeholder_with_default(u, shape=None)
      lin_op_v = array_ops.placeholder_with_default(v, shape=None)
      lin_op_diag_update = array_ops.placeholder_with_default(
          diag_update, shape=None)

    base_operator = linalg.LinearOperatorDiag(
        lin_op_base_diag,
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        lin_op_u,
        v=lin_op_v if self._use_v else None,
        diag_update=lin_op_diag_update if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)

    # The matrix representing L.
    base_diag_mat = array_ops.matrix_diag(base_diag)

    # The matrix representing D.
    diag_update_mat = array_ops.matrix_diag(diag_update)

    # Set up `matrix` as some variant of A = L + UDV^H.
    if self._use_v and self._use_diag_update:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
    elif self._use_v:
      # In this case, we have L + UV^H and it isn't symmetric.
      expect_use_cholesky = False
      matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
    elif self._use_diag_update:
      # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
      expect_use_cholesky = self._is_diag_update_positive
      matrix = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
    else:
      # In this case, we have L + UU^H, which is PD since L > 0.
      expect_use_cholesky = True
      matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)

    if expect_use_cholesky:
      self.assertTrue(operator._use_cholesky)
    else:
      self.assertFalse(operator._use_cholesky)

    return operator, matrix

  def test_tape_safe(self):
    base_operator = linalg.LinearOperatorDiag(
        variables_module.Variable([1.], name="diag"),
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        u=variables_module.Variable([[2.]], name="u"),
        v=variables_module.Variable([[1.25]], name="v")
        if self._use_v else None,
        diag_update=variables_module.Variable([1.25], name="diag_update")
        if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)
    # check_tape_safe verifies that gradients flow back to the underlying
    # variables through the operator's methods.
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    base_operator = linalg.LinearOperatorDiag(
        variables_module.Variable([1.], name="diag"),
        is_positive_definite=True,
        is_self_adjoint=True)

    operator = linalg.LinearOperatorLowRankUpdate(
        base_operator,
        u=variables_module.Variable([[2.]], name="u"),
        v=variables_module.Variable([[1.25]], name="v")
        if self._use_v else None,
        diag_update=variables_module.Variable([1.25], name="diag_update")
        if self._use_diag_update else None,
        is_diag_update_positive=self._is_diag_update_positive)
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)


class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances, since we are testing with condition numbers as high
    # as 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10
    self._rtol[dtypes.complex64] = 1e-4


class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""

  @staticmethod
  def skip_these_tests():
    # These tests require a self-adjoint positive definite operator, which is
    # not guaranteed without D > 0.
    return ["cholesky", "eigvalsh"]

  _use_diag_update = True
  _is_diag_update_positive = False
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances, since we are testing with condition numbers as high
    # as 1e4.  This class does not use Cholesky, and thus needs even looser
    # tolerances.
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9
    self._rtol[dtypes.complex64] = 2e-4


class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UU^H, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = False

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances, since we are testing with condition numbers as high
    # as 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10
    self._rtol[dtypes.complex64] = 1e-4


class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""

  @staticmethod
  def skip_these_tests():
    # A = L + UV^H need not be self-adjoint, so skip tests that require it.
    return ["cholesky", "eigvalsh"]

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = True

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # Loosen tolerances, since we are testing with condition numbers as high
    # as 1e4.  This class does not use Cholesky, and thus needs even looser
    # tolerances.
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9
    self._atol[dtypes.complex64] = 1e-5
    self._rtol[dtypes.complex64] = 2e-4


class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """A = L + UDV^H, D > 0, L > 0, run through the non-square test suite."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = True

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)


@test_util.run_all_without_tensor_float_32(
    "Linear op calls matmul which uses TensorFloat-32.")
class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
  """Test that the operator's shape is the broadcast of arguments."""

  def test_static_shape_broadcasts_up_from_operator_to_other_args(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3)
    u = array_ops.ones(shape=[2, 3, 2])
    diag = array_ops.ones(shape=[2, 2])

    operator = linalg.LinearOperatorLowRankUpdate(base_operator, u, diag)

    # The batch shape [2] broadcasts up from u and diag; domain_dimension is
    # 3, from the base operator.
    self.assertAllEqual([2, 3, 3], operator.shape)
    self.assertAllEqual([2, 3, 3], self.evaluate(operator.to_dense()).shape)

  @test_util.run_deprecated_v1
  def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
    num_rows_ph = array_ops.placeholder(dtypes.int32)
    base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)

    u_shape_ph = array_ops.placeholder(dtypes.int32)
    u = array_ops.ones(shape=u_shape_ph)

    v_shape_ph = array_ops.placeholder(dtypes.int32)
    v = array_ops.ones(shape=v_shape_ph)

    diag_shape_ph = array_ops.placeholder(dtypes.int32)
    diag_update = array_ops.ones(shape=diag_shape_ph)

    operator = linalg.LinearOperatorLowRankUpdate(base_operator,
                                                  u=u,
                                                  diag_update=diag_update,
                                                  v=v)

    feed_dict = {
        num_rows_ph: 3,
        u_shape_ph: [1, 1, 2, 3, 2],  # batch_shape = [1, 1, 2]
        v_shape_ph: [1, 2, 1, 3, 2],  # batch_shape = [1, 2, 1]
        diag_shape_ph: [2, 1, 1, 2],  # batch_shape = [2, 1, 1]
    }

    with self.cached_session():
      # The batch shapes [1, 1, 2], [1, 2, 1] and [2, 1, 1] broadcast to
      # [2, 2, 2].
      shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 2, 2, 3, 3], shape_tensor)
      dense = operator.to_dense().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 2, 2, 3, 3], dense.shape)

  def test_u_and_v_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    v = rng.rand(4, 3, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, v=v)

  def test_u_and_base_operator_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(
        num_rows=3, batch_shape=[4], dtype=np.float64)
    u = rng.rand(5, 3, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_base_operator_incompatible_domain_dimension(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 4, 2)
    with self.assertRaisesRegex(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_diag_incompatible_low_rank_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(5, 4)  # Last dimension should be 2.
    with self.assertRaisesRegex(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)

  def test_diag_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(4, 2)  # First dimension should be 5.
    with self.assertRaisesRegex(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)


if __name__ == "__main__":
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestNoDiagUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky)
  linear_operator_test_util.add_tests(
      LinearOperatorLowRankUpdatetestWithDiagNotSquare)
  test.main()