# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorPermutation`."""

from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_permutation as permutation
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorPermutationTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for `LinearOperatorPermutation`.

  Most tests are generated from the base class
  `SquareLinearOperatorDerivedClassTest` (via `add_tests` in `__main__`); this
  class supplies the operator/reference-matrix pairs plus a few targeted tests.
  """

  def setUp(self):
    # Chain to the base class first so its per-test setup is not skipped.
    super().setUp()
    # Remember the current TensorFloat-32 setting and disable TF32 for the
    # duration of the test (presumably so float32 results are compared against
    # the dense reference at full precision). Restored in `tearDown`.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    try:
      # Restore the TensorFloat-32 setting captured in `setUp`.
      config.enable_tensor_float_32_execution(self.tf32_keep_)
    finally:
      # Always run the base-class teardown, even if the restore fails.
      super().tearDown()

  @staticmethod
  def operator_shapes_infos():
    """Returns the (batch) square matrix shapes exercised by the base tests."""
    shape_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shape_info((1, 1)),
        shape_info((1, 3, 3)),
        shape_info((3, 4, 4)),
        shape_info((2, 1, 4, 4))]

  @staticmethod
  def skip_these_tests():
    """Returns names of base-class tests that do not apply to this operator."""
    # This linear operator is almost never positive definite.
    return ["cholesky", "eigvalsh"]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a random permutation operator and its dense reference matrix.

    Args:
      build_info: `OperatorShapesInfo`; `build_info.shape` is the full square
        (batch) matrix shape `[..., N, N]`.
      dtype: dtype of both the operator and the reference matrix.
      use_placeholder: If `True`, `perm` is wrapped in a
        `placeholder_with_default` with unknown shape, hiding static shape
        information from the operator.
      ensure_self_adjoint_and_pd: Ignored. A permutation matrix is generally
        neither self-adjoint nor positive definite; the tests that need that
        property are listed in `skip_these_tests`.

    Returns:
      Tuple `(operator, matrix)` where `matrix[..., i, j]` is 1 if
      `j == perm[..., i]` and 0 otherwise.
    """
    del ensure_self_adjoint_and_pd  # Unused; see docstring.
    shape = list(build_info.shape)
    # `random_shuffle` only permutes along axis 0, so shuffling a broadcast
    # `range(N)` merely reorders identical rows and always produces the
    # identity permutation. Instead, argsort i.i.d. uniform noise along the
    # last axis: each batch member gets its own random permutation of 0..N-1.
    perm = sort_ops.argsort(
        random_ops.random_uniform(shape[:-1]), axis=-1)

    if use_placeholder:
      perm = array_ops.placeholder_with_default(
          perm, shape=None)

    operator = permutation.LinearOperatorPermutation(
        perm, dtype=dtype)
    # Row i of the reference matrix is the one-hot vector e_{perm[i]}, the
    # same convention checked explicitly in `test_to_dense_4x4`.
    matrix = math_ops.cast(
        math_ops.equal(
            math_ops.range(0, shape[-1]),
            perm[..., array_ops.newaxis]),
        dtype)
    return operator, matrix

  def test_permutation_raises(self):
    """Invalid `perm` arguments are rejected at construction time."""
    # A scalar is rejected: `perm` needs at least one dimension.
    perm = constant_op.constant(0, dtype=dtypes.int32)
    with self.assertRaisesRegex(ValueError, "must have at least 1 dimension"):
      permutation.LinearOperatorPermutation(perm)
    # Floating-point entries are rejected.
    perm = [0., 1., 2.]
    with self.assertRaisesRegex(TypeError, "must be integer dtype"):
      permutation.LinearOperatorPermutation(perm)
    # Entries that do not form a permutation of 0..N-1 are rejected.
    perm = [-1, 2, 3]
    with self.assertRaisesRegex(ValueError,
                                "must be a vector of unique integers"):
      permutation.LinearOperatorPermutation(perm)

  def test_to_dense_4x4(self):
    """`to_dense` places a 1 at column `perm[i]` of each row `i`."""
    # The identity permutation gives the identity matrix.
    perm = [0, 1, 2, 3]
    self.assertAllClose(
        permutation.LinearOperatorPermutation(perm).to_dense(),
        linalg_ops.eye(4))
    perm = [1, 0, 3, 2]
    self.assertAllClose(
        permutation.LinearOperatorPermutation(perm).to_dense(),
        [[0., 1, 0, 0], [1., 0, 0, 0], [0., 0, 0, 1], [0., 0, 1, 0]])
    perm = [3, 2, 0, 1]
    self.assertAllClose(
        permutation.LinearOperatorPermutation(perm).to_dense(),
        [[0., 0, 0, 1], [0., 0, 1, 0], [1., 0, 0, 0], [0., 1, 0, 0]])


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorPermutationTest)
  test.main()