• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import variables as variables_module
22from tensorflow.python.ops.linalg import linalg as linalg_lib
23from tensorflow.python.ops.linalg import linear_operator_test_util
24from tensorflow.python.platform import test
25
26
# Shared NumPy generator with a fixed seed, so that the random matrices drawn
# by the tests in this module are identical on every run.
rng = np.random.RandomState(seed=2016)
28
29
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for square `LinearOperatorZeros` operators.

  Most tests are generated by the base class
  `SquareLinearOperatorDerivedClassTest` from the operator/matrix pairs built
  in `operator_and_matrix`. The methods defined here cover argument
  validation (static and dynamic), the `assert_*` methods, the `is_x` hint
  flags, and the operator type returned by `matmul`.
  """

  @staticmethod
  def skip_these_tests():
    """Return base-class test names that do not apply to a zeros operator."""
    # The zero matrix is singular and not positive definite (see
    # test_assert_non_singular / test_assert_positive_definite below), so
    # inverses, solves, Cholesky factors, the condition number and log|det|
    # are all undefined or infinite for it.
    return [
        "cholesky",
        "cond",
        "inverse",
        "log_abs_det",
        "solve",
        "solve_with_broadcast"
    ]

  @staticmethod
  def operator_shapes_infos():
    """Return the shapes the generated base-class tests are built with.

    Each shape is `batch_shape + [n, n]`: the last two dimensions are the
    square matrix dimensions and any leading ones are batch dimensions.
    """
    shapes_info = linear_operator_test_util.OperatorShapesInfo
    return [
        shapes_info((1, 1)),
        shapes_info((1, 3, 3)),
        shapes_info((3, 4, 4)),
        shapes_info((2, 1, 4, 4))]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Build a square zeros operator and the equivalent dense zeros tensor.

    Args:
      build_info: An `OperatorShapesInfo`; only its `shape` is used.
      dtype: dtype of both the operator and the dense matrix.
      use_placeholder: Unused.
      ensure_self_adjoint_and_pd: Unused.

    Returns:
      A tuple `(operator, matrix)` where `operator` is a `LinearOperatorZeros`
      and `matrix` is `array_ops.zeros` of the same full shape.
    """
    del ensure_self_adjoint_and_pd
    del use_placeholder
    shape = list(build_info.shape)
    # This test class only handles square operators.
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    operator = linalg_lib.LinearOperatorZeros(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    matrix = array_ops.zeros(shape=shape, dtype=dtype)

    return operator, matrix

  def test_assert_positive_definite(self):
    """A zeros operator always fails the positive-definite assertion."""
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    with self.assertRaisesOpError("non-positive definite"):
      operator.assert_positive_definite()

  def test_assert_non_singular(self):
    """A zeros operator always fails the non-singular assertion."""
    # Build outside the context so only the assertion call is expected to
    # raise (matches test_assert_positive_definite).
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    with self.assertRaisesOpError("non-invertible"):
      operator.assert_non_singular()

  def test_assert_self_adjoint(self):
    """A zeros operator passes the self-adjoint assertion."""
    with self.cached_session():
      operator = linalg_lib.LinearOperatorZeros(num_rows=2)
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_non_scalar_num_rows_raises_static(self):
    """Statically known non-scalar num_rows/num_columns are rejected."""
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorZeros(num_rows=[2])
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=[2])

  def test_non_integer_num_rows_raises_static(self):
    """Float num_rows/num_columns are rejected with a TypeError."""
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2.)
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=2.)

  def test_negative_num_rows_raises_static(self):
    """Statically known negative num_rows/num_columns are rejected."""
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=-2)
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=-2)

  def test_non_1d_batch_shape_raises_static(self):
    """A scalar batch_shape is rejected; it must be 1-D."""
    with self.assertRaisesRegex(ValueError, "must be a 1-D"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2)

  def test_non_integer_batch_shape_raises_static(self):
    """A float-valued batch_shape is rejected with a TypeError."""
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.])

  def test_negative_batch_shape_raises_static(self):
    """A batch_shape containing a negative entry is rejected."""
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])

  def test_non_scalar_num_rows_raises_dynamic(self):
    """Non-scalar num_rows of unknown static shape fails when evaluated."""
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default([2], shape=None)
      with self.assertRaisesError("must be a 0-D Tensor"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_num_rows_raises_dynamic(self):
    """Negative num_rows of unknown static shape fails when evaluated."""
    with self.cached_session():
      n = array_ops.placeholder_with_default(-2, shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=n, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_non_1d_batch_shape_raises_dynamic(self):
    """Scalar batch_shape of unknown static shape fails when evaluated."""
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default(2, shape=None)
      with self.assertRaisesError("must be a 1-D"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_batch_shape_raises_dynamic(self):
    """Negative batch_shape of unknown static shape fails when evaluated."""
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_wrong_matrix_dimensions_raises_static(self):
    """matmul with a statically incompatible argument raises ValueError."""
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    """matmul with a dynamically incompatible argument fails at runtime."""
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorZeros(
            num_rows, assert_proper_shapes=True, dtype=dtypes.float64)
        self.evaluate(operator.matmul(x))

  def test_is_x_flags(self):
    """Check the default is_x hints of a zeros operator."""
    # By default a zeros operator is not flagged positive definite or
    # non-singular, but is flagged self-adjoint.
    operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    self.assertFalse(operator.is_positive_definite)
    self.assertFalse(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)

  def test_zeros_matmul(self):
    """Multiplying by a zeros operator (either side) yields a zeros operator."""
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorZeros(num_rows=2)
    self.assertIsInstance(
        operator1.matmul(operator2),
        linalg_lib.LinearOperatorZeros)

    self.assertIsInstance(
        operator2.matmul(operator1),
        linalg_lib.LinearOperatorZeros)

  def test_ref_type_shape_args_raises(self):
    """Variables (reference types) are rejected for every shape argument."""
    with self.assertRaisesRegex(TypeError, "num_rows.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(num_rows=variables_module.Variable(2))

    with self.assertRaisesRegex(TypeError, "num_columns.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(
          num_rows=2, num_columns=variables_module.Variable(3))

    with self.assertRaisesRegex(TypeError, "batch_shape.cannot.be.reference"):
      linalg_lib.LinearOperatorZeros(
          num_rows=2, batch_shape=variables_module.Variable([2]))
202
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosNotSquareTest(
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """Runs the generated non-square operator tests on `LinearOperatorZeros`."""

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Build a (possibly rectangular) zeros operator and its dense twin.

    Args:
      build_info: An `OperatorShapesInfo`; only its `shape` is used. The last
        two entries are taken as (rows, columns), the rest as batch shape.
      dtype: dtype of both the operator and the dense matrix.
      use_placeholder: Unused.
      ensure_self_adjoint_and_pd: Unused.

    Returns:
      A tuple `(operator, matrix)`: a `LinearOperatorZeros` flagged as
      non-square and non-self-adjoint, and `array_ops.zeros` of the same shape.
    """
    del use_placeholder, ensure_self_adjoint_and_pd
    full_shape = list(build_info.shape)

    rows = full_shape[-2]
    cols = full_shape[-1]
    leading_dims = full_shape[:-2]

    zeros_op = linalg_lib.LinearOperatorZeros(
        rows,
        cols,
        batch_shape=leading_dims,
        dtype=dtype,
        is_square=False,
        is_self_adjoint=False)
    dense = array_ops.zeros(shape=full_shape, dtype=dtype)

    return zeros_op, dense
224
225
if __name__ == "__main__":
  # Attach the generated LinearOperator tests to each class, then run them.
  for _test_class in (LinearOperatorZerosTest,
                      LinearOperatorZerosNotSquareTest):
    linear_operator_test_util.add_tests(_test_class)
  test.main()
230