1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import contextlib
17
18import numpy as np
19import scipy.linalg
20
21from tensorflow.python.eager import backprop
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import variables as variables_module
27from tensorflow.python.ops.linalg import linalg as linalg_lib
28from tensorflow.python.ops.linalg import linear_operator_test_util
29from tensorflow.python.ops.linalg import linear_operator_toeplitz
30from tensorflow.python.platform import test
31
# Short alias used throughout this test file.
linalg = linalg_lib

# Private helper from the module under test; presumably casts real inputs
# to a complex dtype — confirm against linear_operator_toeplitz.py.
_to_complex = linear_operator_toeplitz._to_complex

36
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorToeplitzTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """We overwrite the FFT operation mapping for testing."""
    # Delegate directly to TestCase's implementation, bypassing any
    # intermediate override in the class hierarchy.
    with test.TestCase._constrain_devices_and_set_default(
        self, sess, use_gpu, force_gpu) as sess:
      yield sess

  def setUp(self):
    """Loosens the comparison tolerances used by the base-class tests."""
    # TODO(srvasude): Lower these tolerances once specialized solve and
    # determinants are implemented.
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4
    self._atol[dtypes.complex128] = 1e-9
    self._rtol[dtypes.complex128] = 1e-9

  @staticmethod
  def skip_these_tests():
    """Names of base-class tests to skip for this operator."""
    # Skip solve tests, as these could have better stability
    # (currently exercises the base class).
    # TODO(srvasude): Enable these when solve is implemented.
    return ["cholesky", "cond", "inverse", "solve", "solve_with_broadcast"]

  @staticmethod
  def operator_shapes_infos():
    """Shapes (batched and non-batched) exercised by the base-class tests."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((1, 1)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a LinearOperatorToeplitz and the equivalent dense matrix.

    Args:
      build_info: `OperatorShapesInfo` giving the (possibly batched) square
        shape of the operator.
      dtype: TensorFlow dtype for the operator entries.
      use_placeholder: If True, the row/col are fed through
        `placeholder_with_default` with unknown static shape.
      ensure_self_adjoint_and_pd: If True, generate entries so that the
        Toeplitz matrix is symmetric and positive definite.

    Returns:
      Tuple of `(operator, matrix)` where `matrix` is a dense Tensor holding
      the same entries as `operator`, built independently via scipy.
    """
    shape = list(build_info.shape)
    row = np.random.uniform(low=1., high=5., size=shape[:-1])
    col = np.random.uniform(low=1., high=5., size=shape[:-1])

    # A Toeplitz matrix is only consistently defined when the first entry of
    # the row equals the first entry of the column.
    row[..., 0] = col[..., 0]

    if ensure_self_adjoint_and_pd:
      # Note that a Toeplitz matrix generated from a linearly decreasing
      # non-negative sequence is positive definite. See
      # https://www.math.cinvestav.mx/~grudsky/Papers/118_29062012_Albrecht.pdf
      # for details.
      # BUG FIX: previously the linspace assigned to `row` was immediately
      # overwritten by `row = col`, so the operator was built from random
      # uniform entries and positive definiteness was NOT actually
      # guaranteed. Broadcast the decreasing sequence over the batch shape
      # and use it for both the first row and first column, which also makes
      # the matrix symmetric.
      row = np.broadcast_to(
          np.linspace(start=10., stop=1., num=shape[-1]), shape[:-1]).copy()
      col = row

    lin_op_row = math_ops.cast(row, dtype=dtype)
    lin_op_col = math_ops.cast(col, dtype=dtype)

    if use_placeholder:
      # Erase static shape information to exercise dynamic-shape code paths.
      lin_op_row = array_ops.placeholder_with_default(
          lin_op_row, shape=None)
      lin_op_col = array_ops.placeholder_with_default(
          lin_op_col, shape=None)

    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        row=lin_op_row,
        col=lin_op_col,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    # Build the reference dense matrix with scipy, one batch member at a
    # time, then restore the original batch shape.
    flattened_row = np.reshape(row, (-1, shape[-1]))
    flattened_col = np.reshape(col, (-1, shape[-1]))
    flattened_toeplitz = np.zeros(
        [flattened_row.shape[0], shape[-1], shape[-1]])
    for i in range(flattened_row.shape[0]):
      flattened_toeplitz[i] = scipy.linalg.toeplitz(
          flattened_col[i],
          flattened_row[i])
    matrix = np.reshape(flattened_toeplitz, shape)
    matrix = math_ops.cast(matrix, dtype=dtype)

    return operator, matrix

  def test_scalar_row_col_raises(self):
    """Scalar (0-d) row/col arguments must be rejected at construction."""
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz(1., 1.)

    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz([1.], 1.)

    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      linear_operator_toeplitz.LinearOperatorToeplitz(1., [1.])

  def test_tape_safe(self):
    """Gradients flow to the variables backing the operator."""
    col = variables_module.Variable([1.])
    row = variables_module.Variable([1.])
    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        col, row, is_self_adjoint=True, is_positive_definite=True)
    self.check_tape_safe(
        operator,
        skip_options=[
            # .diag_part, .trace depend only on `col`, so test explicitly below.
            linear_operator_test_util.CheckTapeSafeSkipOptions.DIAG_PART,
            linear_operator_test_util.CheckTapeSafeSkipOptions.TRACE,
        ])

    with backprop.GradientTape() as tape:
      self.assertIsNotNone(tape.gradient(operator.diag_part(), col))

    with backprop.GradientTape() as tape:
      self.assertIsNotNone(tape.gradient(operator.trace(), col))

  def test_convert_variables_to_tensors(self):
    """Variable-backed operators can be converted to tensor-backed ones."""
    col = variables_module.Variable([1.])
    row = variables_module.Variable([1.])
    operator = linear_operator_toeplitz.LinearOperatorToeplitz(
        col, row, is_self_adjoint=True, is_positive_definite=True)
    with self.cached_session() as sess:
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)
166
if __name__ == "__main__":
  # `add_tests` presumably attaches the auto-generated derived-class test
  # methods to the test case — see linear_operator_test_util for details.
  linear_operator_test_util.add_tests(LinearOperatorToeplitzTest)
  test.main()
170