• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.framework import config
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import test_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import variables as variables_module
24from tensorflow.python.ops.linalg import linalg as linalg_lib
25from tensorflow.python.ops.linalg import linear_operator_kronecker as kronecker
26from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
27from tensorflow.python.ops.linalg import linear_operator_test_util
28from tensorflow.python.ops.linalg import linear_operator_util
29from tensorflow.python.platform import test
30
# Short alias for the public linalg namespace used by the tests below.
linalg = linalg_lib
# Module-wide NumPy random state with a fixed seed, so any data drawn from it
# is the same on every run.
rng = np.random.RandomState(seed=0)
33
34
def _kronecker_dense(factors):
  """Convert a list of factors into a dense Kronecker product.

  The factors are folded left to right. At each step the running product,
  of shape `[..., m, n]`, is expanded to `[..., m, 1, n, 1]`. The next factor,
  of shape `[..., p, q]`, is expanded to `[..., 1, p, 1, q]`. The two are
  multiplied with broadcasting into `[..., m, p, n, q]`, which is reshaped to
  `[..., m * p, n * q]`. Leading (batch) dimensions broadcast as usual.

  Args:
    factors: Non-empty list of matrices (or batches of matrices), in the
      order `factors[0] (x) factors[1] (x) ...`.

  Returns:
    The dense Kronecker product of all factors.

  Raises:
    ValueError: If `factors` is empty.
  """
  # `len` rather than truthiness so an array-like `factors` is also accepted.
  if len(factors) == 0:
    raise ValueError("Need at least one factor to form a Kronecker product.")

  product = factors[0]
  for factor in factors[1:]:
    lhs = product[..., array_ops.newaxis, :, array_ops.newaxis]
    rhs = factor[..., array_ops.newaxis, :, array_ops.newaxis, :]
    # Multiply out of place. An in-place `*=` only happens to work for
    # immutable tensors. On a mutable array (e.g. NumPy) it would try to write
    # the larger broadcast result into `lhs`, a view of the caller's factor.
    blocks = lhs * rhs
    dims = array_ops.shape(blocks)  # Computed once, reused below.
    product = array_ops.reshape(
        blocks,
        shape=array_ops.concat(
            [dims[:-4], [dims[-4] * dims[-3], dims[-2] * dims[-1]]], axis=0))

  return product
51
52
class KroneckerDenseTest(test.TestCase):
  """Checks `_kronecker_dense` against hand-expanded Kronecker products."""

  def test_kronecker_dense_matrix(self):
    def to_f32(values):
      return ops.convert_to_tensor(values, dtype=dtypes.float32)

    x = to_f32([[2., 3.], [1., 2.]])
    y = to_f32([[1., 2.], [5., -1.]])
    # kron(x, y) written out by hand: block (i, j) equals x[i, j] * y.
    x_kron_y = to_f32([
        [2., 4., 3., 6.],
        [10., -2., 15., -3.],
        [1., 2., 2., 4.],
        [5., -1., 10., -2.]])
    # kron(y, x) written out by hand: block (i, j) equals y[i, j] * x.
    y_kron_x = to_f32([
        [2., 3., 4., 6.],
        [1., 2., 2., 4.],
        [10., 15., -2., -3.],
        [5., 10., -1., -2.]])

    # The Kronecker product is not commutative, so check both orders.
    cases = (([x, y], x_kron_y), ([y, x], y_kron_x))
    for factors, expected in cases:
      self.assertAllClose(
          self.evaluate(_kronecker_dense(factors)), self.evaluate(expected))
76
77
@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorKroneckerTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorKronecker`.

  Most tests are done in the base class LinearOperatorDerivedClassTest. This
  class supplies `operator_shapes_infos` and `operator_and_matrix`
  (presumably consumed by the base-class tests), plus a few Kronecker-specific
  checks.
  """

  def setUp(self):
    # Run the TF base test case's per-test setup first.
    super().setUp()
    # Disable TF32 for these tests; tearDown restores the saved setting.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)
    # `_atol`/`_rtol` are not defined in this class, so they are presumably
    # shared (class-level) tables on the base test class. Take per-instance
    # copies before loosening them so the change cannot leak into other test
    # classes that run in the same process.
    self._atol = dict(self._atol)
    self._rtol = dict(self._rtol)
    # Increase from 1e-6 to 1e-4
    self._atol[dtypes.float32] = 1e-4
    self._atol[dtypes.complex64] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._rtol[dtypes.complex64] = 1e-4

  def tearDown(self):
    # Restore the TF32 setting saved in setUp, then let the base class clean up.
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def operator_shapes_infos():
    """Returns the operator shapes to test, each with its factor shapes."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 1), factors=[(1, 1), (1, 1)]),
        shape_info((8, 8), factors=[(2, 2), (2, 2), (2, 2)]),
        shape_info((12, 12), factors=[(2, 2), (3, 3), (2, 2)]),
        shape_info((1, 3, 3), factors=[(1, 1), (1, 3, 3)]),
        shape_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]),
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a Kronecker operator and the dense matrix it should equal.

    Args:
      build_info: Shape info carrying the full `shape` and a `factors` list
        with one shape per Kronecker factor.
      dtype: dtype of the generated factor matrices.
      use_placeholder: If True, each factor is passed to the operator through
        `placeholder_with_default` with unknown shape, and the dense result's
        static shape is left unset.
      ensure_self_adjoint_and_pd: Ignored; every factor is generated
        symmetric positive-definite regardless.

    Returns:
      A `(operator, dense_matrix)` pair.
    """
    # Kronecker products constructed below will be from symmetric
    # positive-definite matrices.
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)
    expected_factors = build_info.__dict__["factors"]
    matrices = [
        linear_operator_test_util.random_positive_definite_matrix(
            block_shape, dtype, force_well_conditioned=True)
        for block_shape in expected_factors
    ]

    lin_op_matrices = matrices

    if use_placeholder:
      lin_op_matrices = [
          array_ops.placeholder_with_default(m, shape=None) for m in matrices]

    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(
            l,
            is_square=True,
            is_self_adjoint=True,
            is_positive_definite=True)
         for l in lin_op_matrices])

    # Dense reference matrix, built from the same factor values.
    matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)

    kronecker_dense = _kronecker_dense(matrices)

    if not use_placeholder:
      kronecker_dense.set_shape(shape)

    return operator, kronecker_dense

  def test_is_x_flags(self):
    # Matrix with two positive eigenvalues, 1 and 1.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[1., 0.], [1., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [linalg.LinearOperatorFullMatrix(matrix),
         linalg.LinearOperatorFullMatrix(matrix)],
        is_positive_definite=True,
        is_non_singular=True,
        is_self_adjoint=False)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertFalse(operator.is_self_adjoint)

  def test_is_non_singular_auto_set(self):
    # Matrix with two positive eigenvalues, 11 and 8.
    # The matrix values do not affect auto-setting of the flags.
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True)

    operator = kronecker.LinearOperatorKronecker(
        [operator_1, operator_2],
        is_positive_definite=False,  # No reason it HAS to be False...
        is_non_singular=None)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)

    with self.assertRaisesRegex(ValueError, "always non-singular"):
      kronecker.LinearOperatorKronecker(
          [operator_1, operator_2], is_non_singular=False)

  def test_name(self):
    matrix = [[11., 0.], [1., 8.]]
    operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left")
    operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right")

    operator = kronecker.LinearOperatorKronecker([operator_1, operator_2])

    self.assertEqual("left_x_right", operator.name)

  def test_different_dtypes_raises(self):
    operators = [
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)),
        linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32))
    ]
    with self.assertRaisesRegex(TypeError, "same dtype"):
      kronecker.LinearOperatorKronecker(operators)

  def test_empty_or_one_operators_raises(self):
    # Despite the name, only the empty list is checked here. The expected
    # message asks for ">=1 operators", so a single operator is presumably
    # accepted.
    with self.assertRaisesRegex(ValueError, ">=1 operators"):
      kronecker.LinearOperatorKronecker([])

  def test_kronecker_adjoint_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    adjoint = operator.adjoint()
    self.assertIsInstance(
        adjoint,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(adjoint.operators))

  def test_kronecker_cholesky_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
            linalg.LinearOperatorFullMatrix(
                matrix,
                is_positive_definite=True,
                is_self_adjoint=True,
            ),
        ],
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    cholesky_factor = operator.cholesky()
    self.assertIsInstance(
        cholesky_factor,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(cholesky_factor.operators))
    self.assertIsInstance(
        cholesky_factor.operators[0],
        lower_triangular.LinearOperatorLowerTriangular)
    self.assertIsInstance(
        cholesky_factor.operators[1],
        lower_triangular.LinearOperatorLowerTriangular)

  def test_kronecker_inverse_type(self):
    matrix = [[1., 0.], [0., 1.]]
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    inverse = operator.inverse()
    self.assertIsInstance(
        inverse,
        kronecker.LinearOperatorKronecker)
    self.assertEqual(2, len(inverse.operators))

  def test_tape_safe(self):
    matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
    matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix_1, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix_2, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    matrix_1 = variables_module.Variable([[1., 0.], [0., 1.]])
    matrix_2 = variables_module.Variable([[2., 0.], [0., 2.]])
    operator = kronecker.LinearOperatorKronecker(
        [
            linalg.LinearOperatorFullMatrix(
                matrix_1, is_non_singular=True),
            linalg.LinearOperatorFullMatrix(
                matrix_2, is_non_singular=True),
        ],
        is_non_singular=True,
    )
    with self.cached_session() as sess:
      # Initialize every variable the operator tracks before checking it.
      sess.run([x.initializer for x in operator.variables])
      self.check_convert_variables_to_tensors(operator)
289
290
if __name__ == "__main__":
  # Attach the generated base-class tests to each concrete test class before
  # handing control to the TF test runner.
  for _test_class in (SquareLinearOperatorKroneckerTest,):
    linear_operator_test_util.add_tests(_test_class)
  test.main()
294