# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `LinearOperatorZeros` (square and non-square)."""

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test


rng = np.random.RandomState(2016)


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for square `LinearOperatorZeros`.

  Most tests are inherited from
  `linear_operator_test_util.SquareLinearOperatorDerivedClassTest` and are
  attached via `linear_operator_test_util.add_tests` in `__main__`; the
  methods below add checks specific to the all-zeros operator.
  """

  @staticmethod
  def skip_these_tests():
    # The all-zeros operator is singular and not positive definite (see
    # test_assert_non_singular and test_assert_positive_definite), so the
    # base-class tests that factor, invert, solve with, or take determinants
    # or condition numbers of the operator are skipped.
    return [
        "cholesky",
        "cond",
        "inverse",
        "log_abs_det",
        "solve",
        "solve_with_broadcast"
    ]

  @staticmethod
  def operator_shapes_infos():
    shapes_info = linear_operator_test_util.OperatorShapesInfo
    # Square shapes only: a single matrix plus a few batch shapes, including
    # a batch dimension of size 1.
    return [
        shapes_info((1, 1)),
        shapes_info((1, 3, 3)),
        shapes_info((3, 4, 4)),
        shapes_info((2, 1, 4, 4))]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a square `LinearOperatorZeros` and its dense equivalent.

    Args:
      build_info: Shape info (as produced by `operator_shapes_infos`) whose
        `shape` is the full operator shape `batch_shape + [n, n]`.
      dtype: dtype of both the operator and the returned dense matrix.
      use_placeholder: Unused.
      ensure_self_adjoint_and_pd: Unused.

    Returns:
      A pair `(operator, matrix)`, where `matrix` is an all-zeros Tensor with
      the same shape and dtype as `operator`.
    """
    del ensure_self_adjoint_and_pd
    del use_placeholder
    shape = list(build_info.shape)
    # This test class only builds square operators.
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    operator = linalg_lib.LinearOperatorZeros(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    matrix = array_ops.zeros(shape=shape, dtype=dtype)

    return operator, matrix

  def test_assert_positive_definite(self):
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    with self.assertRaisesOpError("non-positive definite"):
      operator.assert_positive_definite()

  def test_assert_non_singular(self):
    # Build the operator outside the context so that only the assertion call
    # is expected to raise.
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    with self.assertRaisesOpError("non-invertible"):
      operator.assert_non_singular()

  def test_assert_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorZeros(num_rows=2)
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_non_scalar_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorZeros(num_rows=[2])
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=[2])

  def test_non_integer_num_rows_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2.)
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=2.)

  def test_negative_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=-2)
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=-2)

  def test_non_1d_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 1-D"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2)

  def test_non_integer_batch_shape_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.])

  def test_negative_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])

  # The *_dynamic tests feed arguments through placeholder_with_default with
  # shape=None, so the checks must happen via assert_proper_shapes=True when
  # the dense matrix is evaluated rather than at construction time.

  def test_non_scalar_num_rows_raises_dynamic(self):
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default([2], shape=None)
      with self.assertRaisesError("must be a 0-D Tensor"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_num_rows_raises_dynamic(self):
    with self.cached_session():
      n = array_ops.placeholder_with_default(-2, shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=n, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_non_1d_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default(2, shape=None)
      with self.assertRaisesError("must be a 1-D"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_wrong_matrix_dimensions_raises_static(self):
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows, assert_proper_shapes=True, dtype=dtypes.float64)
        self.evaluate(operator.matmul(x))

  def test_is_x_flags(self):
    # With no hints passed, an all-zeros operator reports that it is neither
    # positive definite nor non-singular, but that it is self-adjoint.
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    self.assertFalse(operator.is_positive_definite)
    self.assertFalse(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)

  def test_zeros_matmul(self):
    # Multiplying by a zeros operator on either side must yield a zeros
    # operator, not a generic composition.
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorZeros(num_rows=2)
    self.assertIsInstance(
        operator1.matmul(operator2), linalg_lib.LinearOperatorZeros)
    self.assertIsInstance(
        operator2.matmul(operator1), linalg_lib.LinearOperatorZeros)

  def test_ref_type_shape_args_raises(self):
    with self.assertRaisesRegex(TypeError, "num_rows.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(num_rows=variables_module.Variable(2))

    with self.assertRaisesRegex(TypeError, "num_columns.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(
          num_rows=2, num_columns=variables_module.Variable(3))

    with self.assertRaisesRegex(TypeError, "batch_shape.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(
          num_rows=2, batch_shape=variables_module.Variable([2]))


@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosNotSquareTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Tests for non-square `LinearOperatorZeros`.

  All tests are inherited from
  `linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest` and are
  attached via `linear_operator_test_util.add_tests` in `__main__`.
  """

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a non-square `LinearOperatorZeros` and its dense equivalent.

    Args:
      build_info: Shape info whose `shape` is the full operator shape
        `batch_shape + [num_rows, num_columns]`.
      dtype: dtype of both the operator and the returned dense matrix.
      use_placeholder: Unused.
      ensure_self_adjoint_and_pd: Unused.

    Returns:
      A pair `(operator, matrix)`, where `matrix` is an all-zeros Tensor with
      the same shape and dtype as `operator`.
    """
    del use_placeholder
    del ensure_self_adjoint_and_pd
    shape = list(build_info.shape)

    batch_shape = shape[:-2]
    num_rows = shape[-2]
    num_columns = shape[-1]

    # A non-square operator cannot be self-adjoint, so those hints are
    # passed explicitly as False.
    operator = linalg_lib.LinearOperatorZeros(
        num_rows, num_columns, is_square=False, is_self_adjoint=False,
        batch_shape=batch_shape, dtype=dtype)
    matrix = array_ops.zeros(shape=shape, dtype=dtype)

    return operator, matrix


if __name__ == "__main__":
  linear_operator_test_util.add_tests(LinearOperatorZerosTest)
  linear_operator_test_util.add_tests(LinearOperatorZerosNotSquareTest)
  test.main()