• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.framework import constant_op
17from tensorflow.python.framework import config
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import linalg_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.ops import random_ops
24from tensorflow.python.ops.linalg import linalg as linalg_lib
25from tensorflow.python.ops.linalg import linear_operator_permutation as permutation
26from tensorflow.python.ops.linalg import linear_operator_test_util
27from tensorflow.python.platform import test
28
29linalg = linalg_lib
30CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions
31
32
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorPermutationTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorPermutation`.

  Most tests are done in the base class LinearOperatorDerivedClassTest; this
  class supplies the operator/matrix factory those tests consume, plus a few
  permutation-specific checks.
  """

  def setUp(self):
    """Runs the base-class setup, then disables TensorFloat-32 execution."""
    # Chain to the parent first so its per-test fixture logic is not skipped;
    # if it raises, the global TF32 flag has not been modified yet.
    super().setUp()
    # Remember the current setting so `tearDown` can restore it.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the saved TensorFloat-32 setting, then runs base teardown."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def operator_shapes_infos():
    """Returns the `OperatorShapesInfo` entries describing shapes to test."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    shapes = [
        (1, 1),
        (1, 3, 3),
        (3, 4, 4),
        (2, 1, 4, 4),
    ]
    return [shape_info(shape) for shape in shapes]

  @staticmethod
  def skip_these_tests():
    """Returns the names of base-class tests that are skipped here."""
    # This linear operator is almost never positive definite.
    return ["cholesky", "eigvalsh"]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a permutation operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute gives the full operator shape.
      dtype: Dtype passed to the operator and used for the dense matrix.
      use_placeholder: If true, `perm` is fed through
        `placeholder_with_default` with `shape=None`.
      ensure_self_adjoint_and_pd: Unused by this implementation.

    Returns:
      A tuple `(operator, matrix)` where `operator` is a
      `LinearOperatorPermutation` built from `perm`, and `matrix` has a one at
      `[..., i, j]` exactly where `j == perm[..., i]`.
    """
    shape = list(build_info.shape)
    num_columns = shape[-1]

    # Start from the identity permutation, broadcast over the batch dims.
    perm = array_ops.broadcast_to(
        math_ops.range(0, num_columns), shape[:-1])
    # NOTE(review): `random_shuffle` is documented to shuffle along the first
    # dimension only, so for batched shapes this appears to reorder identical
    # rows rather than shuffle within each row -- confirm this is intended.
    perm = random_ops.random_shuffle(perm)

    if use_placeholder:
      perm = array_ops.placeholder_with_default(perm, shape=None)

    operator = permutation.LinearOperatorPermutation(perm, dtype=dtype)

    # Compare each permutation entry against every column index to obtain the
    # dense 0/1 matrix.
    is_selected = math_ops.equal(
        math_ops.range(0, num_columns), perm[..., array_ops.newaxis])
    matrix = math_ops.cast(is_selected, dtype)
    return operator, matrix

  def test_permutation_raises(self):
    """Checks that invalid `perm` arguments are rejected at construction."""
    # A scalar has no dimension to permute.
    scalar_perm = constant_op.constant(0, dtype=dtypes.int32)
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      permutation.LinearOperatorPermutation(scalar_perm)

    # Floating point entries are not accepted.
    float_perm = [0., 1., 2.]
    with self.assertRaisesRegex(TypeError, "must be integer dtype"):
      permutation.LinearOperatorPermutation(float_perm)

    # Entries that do not form a valid permutation are rejected.
    invalid_perm = [-1, 2, 3]
    with self.assertRaisesRegex(ValueError,
                                "must be a vector of unique integers"):
      permutation.LinearOperatorPermutation(invalid_perm)

  def test_to_dense_4x4(self):
    """Checks `to_dense()` against hand-written 4x4 permutation matrices."""
    # The identity permutation densifies to the identity matrix.
    self.assertAllClose(
        permutation.LinearOperatorPermutation([0, 1, 2, 3]).to_dense(),
        linalg_ops.eye(4))

    # In each expected matrix, row i has its single one in column perm[i].
    self.assertAllClose(
        permutation.LinearOperatorPermutation([1, 0, 3, 2]).to_dense(),
        [[0., 1, 0, 0], [1., 0, 0, 0], [0., 0, 0, 1], [0., 0, 1, 0]])

    self.assertAllClose(
        permutation.LinearOperatorPermutation([3, 2, 0, 1]).to_dense(),
        [[0., 0, 0, 1], [0., 0, 1, 0], [1., 0, 0, 0], [0., 1, 0, 0]])
105
106
if __name__ == "__main__":
  # Attach the shared tests from linear_operator_test_util to the concrete
  # test class; this must happen before the test runner is started.
  linear_operator_test_util.add_tests(LinearOperatorPermutationTest)
  # Hand control to the test runner.
  test.main()
110