• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.framework import config
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import test_util
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops.linalg import linalg as linalg_lib
24from tensorflow.python.ops.linalg import linear_operator_test_util
25from tensorflow.python.platform import test
26
# Short alias for the linalg namespace used by every test below.
linalg = linalg_lib

# Fixed-seed NumPy generator shared by the whole module, so the random
# choices made in these tests (number of composed operators, NumPy test
# matrices) are reproducible from run to run.
rng = np.random.RandomState(seed=0)
29
30
class SquareLinearOperatorCompositionTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorComposition` built from square operators.

  Most tests are generated by the base class LinearOperatorDerivedClassTest
  from `operator_and_matrix`; the methods below cover hint flags, naming and
  constructor validation.
  """

  def setUp(self):
    # Let the base TestCase perform its own per-test setup.
    super().setUp()
    # Turn TensorFloat-32 off for the duration of each test and restore the
    # previous setting afterwards. Registering the restore with addCleanup
    # (instead of overriding tearDown) guarantees it runs even if the rest of
    # setUp raises, in which case tearDown would be skipped.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    self.addCleanup(config.enable_tensor_float_32_execution, self.tf32_keep_)
    # Increase from 1e-6 to 1e-4 and 2e-4. The tolerance dicts are inherited
    # from the base class; copy them first so these overrides stay local to
    # this test instance instead of mutating a dict shared with every other
    # LinearOperatorDerivedClassTest subclass.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.float32] = 2e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 2e-4
    self._rtol[dtypes.complex64] = 1e-4

  @staticmethod
  def skip_these_tests():
    # Cholesky not implemented. Extend (rather than replace) whatever the base
    # class already skips.
    base = linear_operator_test_util.SquareLinearOperatorDerivedClassTest
    return base.skip_these_tests() + ["cholesky"]

  def operator_and_matrix(self, build_info, dtype, use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Builds a composition of 1 or 2 square full-matrix operators.

    Args:
      build_info: Provides the `shape` of each (square) factor.
      dtype: dtype of the generated matrices.
      use_placeholder: If True, each factor is wrapped in a
        `placeholder_with_default` with `shape=None`, so the operator only
        sees its shape dynamically.
      ensure_self_adjoint_and_pd: If True, the same random positive-definite
        matrix is repeated as every factor and the operator is constructed
        with `is_self_adjoint=True` and `is_positive_definite=True`.

    Returns:
      `(operator, mat)`, where `mat` is the dense product of the factors in
      composition order, i.e. the matrix `operator` should represent.
    """
    shape = list(build_info.shape)

    # Either 1 or 2 matrices, depending.
    num_operators = rng.randint(low=1, high=3)
    if ensure_self_adjoint_and_pd:
      # The random PD matrices are also symmetric. Here we are computing
      # A @ A ... @ A. Since A is symmetric and PD, so are any powers of it.
      matrices = [
          linear_operator_test_util.random_positive_definite_matrix(
              shape, dtype, force_well_conditioned=True)] * num_operators
    else:
      matrices = [
          linear_operator_test_util.random_positive_definite_matrix(
              shape, dtype, force_well_conditioned=True)
          for _ in range(num_operators)
      ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in matrices]

    operator = linalg.LinearOperatorComposition(
        [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices],
        is_positive_definite=True if ensure_self_adjoint_and_pd else None,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_square=True)

    # Multiply right-to-left: A1 @ (A2 @ (... @ An)).
    matmul_order_list = list(reversed(matrices))
    mat = matmul_order_list[0]
    for other_mat in matmul_order_list[1:]:
      mat = math_ops.matmul(other_mat, mat)

    return operator, mat

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues, 1, and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = linalg.LinearOperatorComposition(
        [linalg.LinearOperatorFullMatrix(matrix)],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_is_non_singular_auto_set(self):
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = linalg.LinearOperatorComposition(
        [operator_1, operator_2],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    # A product of non-singular factors cannot be declared singular.
    with self.assertRaisesRegex(ValueError, "always non-singular"):
      linalg.LinearOperatorComposition(
          [operator_1, operator_2], is_non_singular=False)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")

    operator = linalg.LinearOperatorComposition([operator_1, operator_2])

    # The default name joins the factor names with "_o_".
    self.assertEqual("left_o_right", operator.name)

  def test_different_dtypes_raises(self):
    # float64 (NumPy default) vs. float32 factors.
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      linalg.LinearOperatorComposition(operators)

  def test_empty_operators_raises(self):
    with self.assertRaisesRegex(ValueError, "non-empty"):
      linalg.LinearOperatorComposition([])
142
143
class NonSquareLinearOperatorCompositionTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorComposition` built from non-square operators.

  Most tests are generated by the base class LinearOperatorDerivedClassTest
  from `operator_and_matrix`; the methods below cover static and dynamic
  shape inference.
  """

  def setUp(self):
    # Let the base TestCase perform its own per-test setup.
    super().setUp()
    # Turn TensorFloat-32 off for the duration of each test and restore the
    # previous setting afterwards. Registering the restore with addCleanup
    # (instead of overriding tearDown) guarantees it runs even if the rest of
    # setUp raises, in which case tearDown would be skipped.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    self.addCleanup(config.enable_tensor_float_32_execution, self.tf32_keep_)
    # Increase from 1e-6 to 1e-4. The tolerance dicts are inherited from the
    # base class; copy them first so these overrides stay local to this test
    # instance instead of mutating a dict shared with every other
    # LinearOperatorDerivedClassTest subclass.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4

  @staticmethod
  def skip_these_tests():
    # Testing the condition number fails when using XLA with cuBLASLt:
    # a slight numerical difference between different matmul algorithms
    # leads to large precision issues.
    base = linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest
    return base.skip_these_tests() + ["cond"]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds A = A1 @ A2 from two random, well-conditioned factors.

    Args:
      build_info: Provides the target `shape`, `batch_shape + [m, n]`.
      dtype: dtype of the generated matrices.
      use_placeholder: If True, each factor is wrapped in a
        `placeholder_with_default` with `shape=None`, so the operator only
        sees its shape dynamically.
      ensure_self_adjoint_and_pd: Ignored; a non-square operator cannot be
        self-adjoint or positive definite.

    Returns:
      `(operator, mat)`, where `mat` is the dense product `A1 @ A2`, i.e. the
      matrix `operator` should represent.
    """
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)

    # Create 2 matrices/operators, A1, A2, which becomes A = A1 A2.
    # Use inner dimension k = 2: A1 is [..., m, k] and A2 is [..., k, n].
    k = 2
    batch_shape = shape[:-2]
    shape_1 = batch_shape + [shape[-2], k]
    shape_2 = batch_shape + [k, shape[-1]]

    def generate_well_conditioned(shape, dtype):
      """Returns a random `shape` matrix whose singular values are near 1.

      The result is U @ D @ V, where U and V are the Q factors of QR
      decompositions of random normal matrices and D is a rectangular
      diagonal matrix with entries drawn with mean 1 and stddev 0.1. Since
      cond(AB) <= cond(A) * cond(B), keeping each factor's condition number
      close to 1 keeps the product's condition number from drifting far
      from 1.
      """
      m, n = shape[-2], shape[-1]
      min_dim = min(m, n)
      # Generate singular values that are close to 1.
      d = linear_operator_test_util.random_normal(
          shape[:-2] + [min_dim],
          mean=1.,
          stddev=0.1,
          dtype=dtype)
      # Place them on the main diagonal of an m x n zero matrix.
      zeros = array_ops.zeros(shape=shape[:-2] + [m, n], dtype=dtype)
      d = linalg_lib.set_diag(zeros, d)
      u, _ = linalg_lib.qr(linear_operator_test_util.random_normal(
          shape[:-2] + [m, m], dtype=dtype))

      v, _ = linalg_lib.qr(linear_operator_test_util.random_normal(
          shape[:-2] + [n, n], dtype=dtype))
      return math_ops.matmul(u, math_ops.matmul(d, v))

    matrices = [
        generate_well_conditioned(shape_1, dtype=dtype),
        generate_well_conditioned(shape_2, dtype=dtype),
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(
              matrix, shape=None) for matrix in matrices]

    operator = linalg.LinearOperatorComposition(
        [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices])

    # Multiply right-to-left: A1 @ (A2 @ (... @ An)).
    matmul_order_list = list(reversed(matrices))
    mat = matmul_order_list[0]
    for other_mat in matmul_order_list[1:]:
      mat = math_ops.matmul(other_mat, mat)

    return operator, mat

  @test_util.run_deprecated_v1
  def test_static_shapes(self):
    # [2, 3, 4] @ [2, 4, 5] -> [2, 3, 5], known at graph-construction time.
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 4)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 4, 5))
    ]
    operator = linalg.LinearOperatorComposition(operators)
    self.assertAllEqual((2, 3, 5), operator.shape)

  @test_util.run_deprecated_v1
  def test_shape_tensors_when_statically_available(self):
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 4)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 4, 5))
    ]
    operator = linalg.LinearOperatorComposition(operators)
    with self.cached_session():
      self.assertAllEqual((2, 3, 5), operator.shape_tensor())

  @test_util.run_deprecated_v1
  def test_shape_tensors_when_only_dynamically_available(self):
    # Shape-less placeholders: the shape is only known once values are fed.
    mat_1 = rng.rand(1, 2, 3, 4)
    mat_2 = rng.rand(1, 2, 4, 5)
    mat_ph_1 = array_ops.placeholder(dtypes.float64)
    mat_ph_2 = array_ops.placeholder(dtypes.float64)
    feed_dict = {mat_ph_1: mat_1, mat_ph_2: mat_2}

    operators = [
        linalg.LinearOperatorFullMatrix(mat_ph_1),
        linalg.LinearOperatorFullMatrix(mat_ph_2)
    ]
    operator = linalg.LinearOperatorComposition(operators)
    with self.cached_session():
      self.assertAllEqual(
          (1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
261
262
if __name__ == "__main__":
  # Have linear_operator_test_util.add_tests populate each concrete test
  # class (presumably with the shared base-class test methods), then hand
  # control to the TensorFlow test runner.
  for _composition_test_class in (SquareLinearOperatorCompositionTest,
                                  NonSquareLinearOperatorCompositionTest):
    linear_operator_test_util.add_tests(_composition_test_class)
  test.main()
267