# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from tensorflow.python.framework import config
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test


class _LinearOperatorTriDiagBase(object):
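  """Mixin that builds a LinearOperatorTridiag and a matching dense matrix.

  Each test class below picks one of the supported `diagonals_format` values
  and forwards to `build_operator_and_matrix` from its `operator_and_matrix`
  hook.
  """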

  def build_operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False,
      diagonals_format='sequence'):
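    """Returns a LinearOperatorTridiag and the equivalent dense matrix.

    Args:
      build_info: OperatorShapesInfo describing the desired operator shape.
      dtype: TensorFlow dtype of the operator.
      use_placeholder: If True, hide the static shapes behind placeholders.
      ensure_self_adjoint_and_pd: If True, build a Hermitian, diagonally
        dominant (hence positive definite) operator.
      diagonals_format: One of 'sequence', 'compact' or 'matrix'.

    Returns:
      A (operator, matrix) pair consumed by the shared test harness.
    """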
    shape = list(build_info.shape)

    # Ensure that the diagonal has large enough values. If we generate a
    # self-adjoint PD matrix, then the diagonal will be dominant, guaranteeing
    # positive definiteness.
    diag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=4., maxval=6., dtype=dtype)
    # We'll truncate these depending on the format.
    subdiag = linear_operator_test_util.random_sign_uniform(
        shape[:-1], minval=1., maxval=2., dtype=dtype)
    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      diag = math_ops.cast(math_ops.abs(diag), dtype=dtype)
      # The first element of subdiag is ignored. We'll add a dummy element
      # to superdiag to pad it.
      superdiag = math_ops.conj(subdiag)
      superdiag = manip_ops.roll(superdiag, shift=-1, axis=-1)
    else:
      superdiag = linear_operator_test_util.random_sign_uniform(
          shape[:-1], minval=1., maxval=2., dtype=dtype)

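    # Place the three diagonals on bands k = -1, 0, 1 to build the dense
    # reference matrix that the operator is compared against.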
    matrix_diagonals = array_ops.stack(
        [superdiag, diag, subdiag], axis=-2)
    matrix = gen_array_ops.matrix_diag_v3(
        matrix_diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

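    # LinearOperatorTridiag accepts its diagonals as a sequence of three
    # tensors, as a single tensor stacked along axis -2 ('compact'), or as the
    # dense tridiagonal matrix itself.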
    if diagonals_format == 'sequence':
      diagonals = [superdiag, diag, subdiag]
    elif diagonals_format == 'compact':
      diagonals = array_ops.stack([superdiag, diag, subdiag], axis=-2)
    elif diagonals_format == 'matrix':
      diagonals = matrix

    lin_op_diagonals = diagonals

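    # When requested, route the diagonals through placeholders so the operator
    # only sees dynamic (unknown) shapes.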
    if use_placeholder:
      if diagonals_format == 'sequence':
        lin_op_diagonals = [array_ops.placeholder_with_default(
            d, shape=None) for d in lin_op_diagonals]
      else:
        lin_op_diagonals = array_ops.placeholder_with_default(
            lin_op_diagonals, shape=None)

    operator = linalg_lib.LinearOperatorTridiag(
        diagonals=lin_op_diagonals,
        diagonals_format=diagonals_format,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)
    return operator, matrix

  @staticmethod
  def operator_shapes_infos():
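    """Shapes (with and without batch dimensions) exercised by the tests."""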
    shape_info = linear_operator_test_util.OperatorShapesInfo
    # non-batch operators (n, n) and batch operators.
    return [
        shape_info((3, 3)),
        shape_info((1, 6, 6)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 3, 3))
    ]


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagCompactTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
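    # TensorFloat-32 reduces matmul precision; keep it off so comparisons
    # against the dense reference matrix stay within tolerance.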
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='compact')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
    operator = linalg_lib.LinearOperatorTridiag(
        diag, diagonals_format='compact')
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    diag = variables_module.Variable([[3., 6., 2.], [2., 4., 2.], [5., 1., 2.]])
    operator = linalg_lib.LinearOperatorTridiag(
        diag, diagonals_format='compact')
    with self.cached_session() as sess:
      sess.run([diag.initializer])
      self.check_convert_variables_to_tensors(operator)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagSequenceTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='sequence')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    diagonals = [
        variables_module.Variable([3., 6., 2.]),
        variables_module.Variable([2., 4., 2.]),
        variables_module.Variable([5., 1., 2.])]
    operator = linalg_lib.LinearOperatorTridiag(
        diagonals, diagonals_format='sequence')
    # Skip the diagonal part and trace since these depend only on the
    # middle variable. We test them below.
    self.check_tape_safe(operator, skip_options=['diag_part', 'trace'])

    diagonals = [
        [3., 6., 2.],
        variables_module.Variable([2., 4., 2.]),
        [5., 1., 2.]
    ]
    operator = linalg_lib.LinearOperatorTridiag(
        diagonals, diagonals_format='sequence')
    self.check_tape_safe(operator)


@test_util.with_eager_op_as_function
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTriDiagMatrixTest(
    _LinearOperatorTriDiagBase,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Most tests done in the base class LinearOperatorDerivedClassTest."""

  def tearDown(self):
    config.enable_tensor_float_32_execution(self.tf32_keep_)

  def setUp(self):
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    return self.build_operator_and_matrix(
        build_info, dtype, use_placeholder,
        ensure_self_adjoint_and_pd=ensure_self_adjoint_and_pd,
        diagonals_format='matrix')

  @test_util.disable_xla('Current implementation does not yet support pivoting')
  def test_tape_safe(self):
    matrix = variables_module.Variable([[3., 2., 0.], [1., 6., 4.], [0., 2, 2]])
    operator = linalg_lib.LinearOperatorTridiag(
        matrix, diagonals_format='matrix')
    self.check_tape_safe(operator)


if __name__ == '__main__':
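  # The generated tests exercise tridiagonal solves that are not supported
  # under XLA (see the disable_xla annotations above), so only register them
  # for non-XLA runs.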
  if not test_util.is_xla_enabled():
    linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest)
    linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest)
    linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest)
  test.main()