• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import numpy as np
16
17from tensorflow.python.eager import context
18from tensorflow.python.framework import constant_op
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import linalg_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops.linalg import linalg as linalg_lib
27from tensorflow.python.ops.parallel_for import control_flow_ops
28from tensorflow.python.platform import test
29
# Short alias so the code below can refer to the linalg package as `linalg`.
linalg = linalg_lib
# Module-level generator with a fixed seed, shared by the random fixtures below.
rng = np.random.RandomState(seed=123)
32
33
class LinearOperatorShape(linalg.LinearOperator):
  """LinearOperator that implements only `_shape` and `_shape_tensor`.

  The operator stores the shape it was constructed with and reports it both
  statically (`_shape`) and as a tensor (`_shape_tensor`).  Its dtype is always
  `float32`.  `_matmul` is a stub that raises `NotImplementedError`.

  Args:
    shape: The shape to report; stored as given and wrapped in a `TensorShape`
      or an `int32` constant on demand.
    is_non_singular: Hint passed through to `LinearOperator.__init__`.
    is_self_adjoint: Hint passed through to `LinearOperator.__init__`.
    is_positive_definite: Hint passed through to `LinearOperator.__init__`.
    is_square: Hint passed through to `LinearOperator.__init__`.
  """

  def __init__(self,
               shape,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None):
    # Record the constructor arguments exactly as received.
    parameters = {
        "shape": shape,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
    }

    # Must be set before the base constructor runs, in case it queries shape.
    self._stored_shape = shape
    super().__init__(
        dtype=dtypes.float32,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters)

  def _shape(self):
    """Returns the stored shape as a static `TensorShape`."""
    return tensor_shape.TensorShape(self._stored_shape)

  def _shape_tensor(self):
    """Returns the stored shape as an `int32` constant tensor."""
    return constant_op.constant(self._stored_shape, dtype=dtypes.int32)

  def _matmul(self, x=None, adjoint=False, adjoint_arg=False):
    """Stub: matmul is intentionally unsupported by this shape-only operator.

    The arguments are accepted (and ignored) so that a call routed through the
    public `matmul` reaches this stub and fails with the `NotImplementedError`
    below rather than with a `TypeError` about unexpected arguments.  `x`
    defaults to `None` only to keep a bare `_matmul()` call working.

    Raises:
      NotImplementedError: Always.
    """
    del x, adjoint, adjoint_arg  # Unused.
    raise NotImplementedError("Not needed for this test.")
68
69
class LinearOperatorMatmulSolve(linalg.LinearOperator):
  """LinearOperator that wraps a [batch] matrix and implements matmul/solve.

  The wrapped matrix is converted to a tensor at construction time; the
  operator's dtype and shape are taken from that tensor.

  Args:
    matrix: Value convertible to a tensor; the matrix this operator represents.
    is_non_singular: Hint passed through to `LinearOperator.__init__`.
    is_self_adjoint: Hint passed through to `LinearOperator.__init__`.
    is_positive_definite: Hint passed through to `LinearOperator.__init__`.
    is_square: Hint passed through to `LinearOperator.__init__`.
  """

  def __init__(self,
               matrix,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None):
    # Record the constructor arguments exactly as received (i.e. `matrix`
    # before tensor conversion).
    parameters = {
        "matrix": matrix,
        "is_non_singular": is_non_singular,
        "is_self_adjoint": is_self_adjoint,
        "is_positive_definite": is_positive_definite,
        "is_square": is_square,
    }

    self._matrix = ops.convert_to_tensor(matrix, name="matrix")
    super().__init__(
        dtype=self._matrix.dtype,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        parameters=parameters)

  def _shape(self):
    """Returns the static shape of the wrapped matrix."""
    return self._matrix.shape

  def _shape_tensor(self):
    """Returns the shape of the wrapped matrix as a tensor."""
    return array_ops.shape(self._matrix)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiplies the wrapped matrix with `x` via `math_ops.matmul`.

    `adjoint` and `adjoint_arg` are forwarded as `adjoint_a` and `adjoint_b`.
    """
    x = ops.convert_to_tensor(x, name="x")
    return math_ops.matmul(
        self._matrix, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves against the wrapped matrix via `linalg_ops.matrix_solve`.

    Raises:
      NotImplementedError: If `adjoint_arg` is true; this test class does not
        support it.
    """
    rhs = ops.convert_to_tensor(rhs, name="rhs")
    # An explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently solve against the un-adjointed rhs.
    if adjoint_arg:
      raise NotImplementedError("Not implemented for this test class.")
    return linalg_ops.matrix_solve(self._matrix, rhs, adjoint=adjoint)
111
112
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorTest(test.TestCase):
  """Tests generic `LinearOperator` behavior using small helper operators.

  Random fixtures are drawn from the module-level seeded `rng`, so inputs do
  not depend on numpy's unseeded global generator.
  """

  def test_all_shape_properties_defined_by_the_one_property_shape(self):
    shape = (1, 2, 3, 4)
    operator = LinearOperatorShape(shape)

    self.assertAllEqual(shape, operator.shape)
    self.assertAllEqual(4, operator.tensor_rank)
    self.assertAllEqual((1, 2), operator.batch_shape)
    self.assertAllEqual(4, operator.domain_dimension)
    self.assertAllEqual(3, operator.range_dimension)
    expected_parameters = {
        "is_non_singular": None,
        "is_positive_definite": None,
        "is_self_adjoint": None,
        "is_square": None,
        "shape": (1, 2, 3, 4),
    }
    self.assertEqual(expected_parameters, operator.parameters)

  def test_all_shape_methods_defined_by_the_one_method_shape(self):
    with self.cached_session():
      shape = (1, 2, 3, 4)
      operator = LinearOperatorShape(shape)

      self.assertAllEqual(shape, self.evaluate(operator.shape_tensor()))
      self.assertAllEqual(4, self.evaluate(operator.tensor_rank_tensor()))
      self.assertAllEqual((1, 2), self.evaluate(operator.batch_shape_tensor()))
      self.assertAllEqual(4, self.evaluate(operator.domain_dimension_tensor()))
      self.assertAllEqual(3, self.evaluate(operator.range_dimension_tensor()))

  def test_is_x_properties(self):
    operator = LinearOperatorShape(
        shape=(2, 2),
        is_non_singular=False,
        is_self_adjoint=True,
        is_positive_definite=False)
    self.assertFalse(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)
    self.assertFalse(operator.is_positive_definite)

  def test_nontrivial_parameters(self):
    matrix = rng.randn(2, 3, 4)
    matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
    operator = LinearOperatorMatmulSolve(matrix_ph)
    expected_parameters = {
        "is_non_singular": None,
        "is_positive_definite": None,
        "is_self_adjoint": None,
        "is_square": None,
        "matrix": matrix_ph,
    }
    self.assertEqual(expected_parameters, operator.parameters)

  def test_generic_to_dense_method_non_square_matrix_static(self):
    matrix = rng.randn(2, 3, 4)
    operator = LinearOperatorMatmulSolve(matrix)
    with self.cached_session():
      operator_dense = operator.to_dense()
      self.assertAllEqual((2, 3, 4), operator_dense.shape)
      self.assertAllClose(matrix, self.evaluate(operator_dense))

  def test_generic_to_dense_method_non_square_matrix_tensor(self):
    matrix = rng.randn(2, 3, 4)
    matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
    operator = LinearOperatorMatmulSolve(matrix_ph)
    operator_dense = operator.to_dense()
    self.assertAllClose(matrix, self.evaluate(operator_dense))

  def test_matvec(self):
    matrix = [[1., 0], [0., 2.]]
    operator = LinearOperatorMatmulSolve(matrix)
    x = [1., 1.]
    with self.cached_session():
      y = operator.matvec(x)
      self.assertAllEqual((2,), y.shape)
      self.assertAllClose([1., 2.], self.evaluate(y))

  def test_solvevec(self):
    matrix = [[1., 0], [0., 2.]]
    operator = LinearOperatorMatmulSolve(matrix)
    y = [1., 1.]
    with self.cached_session():
      x = operator.solvevec(y)
      self.assertAllEqual((2,), x.shape)
      self.assertAllClose([1., 1 / 2.], self.evaluate(x))

  def test_is_square_set_to_true_for_square_static_shapes(self):
    operator = LinearOperatorShape(shape=(2, 4, 4))
    self.assertTrue(operator.is_square)

  def test_is_square_set_to_false_for_square_static_shapes(self):
    # Despite the method name, the static shape used here is NON-square
    # (3 x 4 matrices), which is why `is_square` is expected to be False.
    operator = LinearOperatorShape(shape=(2, 3, 4))
    self.assertFalse(operator.is_square)

  def test_is_square_set_incorrectly_to_false_raises(self):
    with self.assertRaisesRegex(ValueError, "but.*was square"):
      _ = LinearOperatorShape(shape=(2, 4, 4), is_square=False).is_square

  def test_is_square_set_inconsistent_with_other_hints_raises(self):
    # Build the placeholder outside the `assertRaisesRegex` bodies so that only
    # the operator constructor is expected to raise.
    matrix = array_ops.placeholder_with_default(input=(), shape=None)

    with self.assertRaisesRegex(ValueError, "is always square"):
      LinearOperatorMatmulSolve(matrix, is_non_singular=True, is_square=False)

    with self.assertRaisesRegex(ValueError, "is always square"):
      LinearOperatorMatmulSolve(
          matrix, is_positive_definite=True, is_square=False)

  def test_non_square_operators_raise_on_determinant_and_solve(self):
    operator = LinearOperatorShape((2, 3))
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.determinant()
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.log_abs_determinant()
    with self.assertRaisesRegex(NotImplementedError, "not be square"):
      operator.solve(rng.rand(2, 2))

    # NOTE(review): this check repeats the one in
    # test_is_square_set_inconsistent_with_other_hints_raises; kept so coverage
    # is unchanged, but it is a candidate for removal.
    matrix = array_ops.placeholder_with_default(input=(), shape=None)
    with self.assertRaisesRegex(ValueError, "is always square"):
      LinearOperatorMatmulSolve(
          matrix, is_positive_definite=True, is_square=False)

  def test_is_square_manual_set_works(self):
    matrix = array_ops.placeholder_with_default(
        input=np.ones((2, 2)), shape=None)
    operator = LinearOperatorMatmulSolve(matrix)
    if not context.executing_eagerly():
      # Eager mode will read in the default value, and discover the answer is
      # True.  Graph mode must rely on the hint, since the placeholder has
      # shape=None...the hint is, by default, None.
      self.assertIsNone(operator.is_square)

    # Set to True
    operator = LinearOperatorMatmulSolve(matrix, is_square=True)
    self.assertTrue(operator.is_square)

  def test_linear_operator_matmul_hints_closed(self):
    matrix = array_ops.placeholder_with_default(
        input=np.ones((2, 2)), shape=None)
    operator1 = LinearOperatorMatmulSolve(matrix)

    operator_matmul = operator1.matmul(operator1)

    if not context.executing_eagerly():
      # Eager mode will read in the input and discover matrix is square.
      self.assertIsNone(operator_matmul.is_square)
    self.assertIsNone(operator_matmul.is_non_singular)
    self.assertIsNone(operator_matmul.is_self_adjoint)
    self.assertIsNone(operator_matmul.is_positive_definite)

    operator2 = LinearOperatorMatmulSolve(
        matrix,
        is_non_singular=True,
        is_self_adjoint=True,
        is_positive_definite=True,
        is_square=True,
    )

    operator_matmul = operator2.matmul(operator2)

    self.assertTrue(operator_matmul.is_square)
    self.assertTrue(operator_matmul.is_non_singular)
    self.assertIsNone(operator_matmul.is_self_adjoint)
    self.assertIsNone(operator_matmul.is_positive_definite)

  def test_linear_operator_matmul_hints_false(self):
    matrix1 = array_ops.placeholder_with_default(
        input=rng.rand(2, 2), shape=None)
    operator1 = LinearOperatorMatmulSolve(
        matrix1,
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False,
        is_square=True,
    )

    operator_matmul = operator1.matmul(operator1)

    self.assertTrue(operator_matmul.is_square)
    self.assertFalse(operator_matmul.is_non_singular)
    self.assertIsNone(operator_matmul.is_self_adjoint)
    self.assertIsNone(operator_matmul.is_positive_definite)

    matrix2 = array_ops.placeholder_with_default(
        input=rng.rand(2, 3), shape=None)
    operator2 = LinearOperatorMatmulSolve(
        matrix2,
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False,
        is_square=False,
    )

    operator_matmul = operator2.matmul(operator2, adjoint_arg=True)

    if context.executing_eagerly():
      self.assertTrue(operator_matmul.is_square)
      # False since we specified is_non_singular=False.
      self.assertFalse(operator_matmul.is_non_singular)
    else:
      self.assertIsNone(operator_matmul.is_square)
      # May be non-singular, since it's the composition of two non-square.
      # TODO(b/136162840) This is a bit inconsistent, and should probably be
      # False since we specified operator2.is_non_singular == False.
      self.assertIsNone(operator_matmul.is_non_singular)

    # No way to deduce these, even in Eager mode.
    self.assertIsNone(operator_matmul.is_self_adjoint)
    self.assertIsNone(operator_matmul.is_positive_definite)

  def test_linear_operator_matmul_hint_infer_square(self):
    matrix1 = array_ops.placeholder_with_default(
        input=rng.rand(2, 3), shape=(2, 3))
    matrix2 = array_ops.placeholder_with_default(
        input=rng.rand(3, 2), shape=(3, 2))
    matrix3 = array_ops.placeholder_with_default(
        input=rng.rand(3, 4), shape=(3, 4))

    operator1 = LinearOperatorMatmulSolve(matrix1, is_square=False)
    operator2 = LinearOperatorMatmulSolve(matrix2, is_square=False)
    operator3 = LinearOperatorMatmulSolve(matrix3, is_square=False)

    self.assertTrue(operator1.matmul(operator2).is_square)
    self.assertTrue(operator2.matmul(operator1).is_square)
    self.assertFalse(operator1.matmul(operator3).is_square)

  def testDispatchedMethods(self):
    operator = linalg.LinearOperatorFullMatrix(
        [[1., 0.5], [0.5, 1.]],
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    # Maps each operator method name to the `linalg` function that should give
    # the same value when applied to the operator.
    methods = {
        "trace": linalg.trace,
        "diag_part": linalg.diag_part,
        "log_abs_determinant": linalg.logdet,
        "determinant": linalg.det
    }
    for method_name, linalg_fn in methods.items():
      op_val = getattr(operator, method_name)()
      linalg_val = linalg_fn(operator)
      self.assertAllClose(
          self.evaluate(op_val),
          self.evaluate(linalg_val))
    # Solve and Matmul go here.

    adjoint = linalg.adjoint(operator)
    self.assertIsInstance(adjoint, linalg.LinearOperator)
    cholesky = linalg.cholesky(operator)
    self.assertIsInstance(cholesky, linalg.LinearOperator)
    inverse = linalg.inv(operator)
    self.assertIsInstance(inverse, linalg.LinearOperator)

  def testDispatchMatmulSolve(self):
    operator = linalg.LinearOperatorFullMatrix(
        np.float64([[1., 0.5], [0.5, 1.]]),
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    # Seeded module-level generator keeps this input reproducible.
    rhs = rng.uniform(-1., 1., size=[3, 2, 2])
    for adjoint in [False, True]:
      for adjoint_arg in [False, True]:
        op_val = operator.matmul(
            rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
        matmul_val = math_ops.matmul(
            operator, rhs, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        self.assertAllClose(
            self.evaluate(op_val), self.evaluate(matmul_val))

      op_val = operator.solve(rhs, adjoint=adjoint)
      solve_val = linalg.solve(operator, rhs, adjoint=adjoint)
      self.assertAllClose(
          self.evaluate(op_val), self.evaluate(solve_val))

  def testDispatchMatmulLeftOperatorIsTensor(self):
    mat = np.float64([[1., 0.5], [0.5, 1.]])
    right_operator = linalg.LinearOperatorFullMatrix(
        mat,
        is_square=True,
        is_self_adjoint=True,
        is_non_singular=True,
        is_positive_definite=True)
    # Seeded module-level generator keeps this input reproducible.
    lhs = rng.uniform(-1., 1., size=[3, 2, 2])

    for adjoint in [False, True]:
      for adjoint_arg in [False, True]:
        op_val = math_ops.matmul(
            lhs, mat, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        matmul_val = math_ops.matmul(
            lhs, right_operator, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        self.assertAllClose(
            self.evaluate(op_val), self.evaluate(matmul_val))

  def testVectorizedMap(self):

    def fn(x):
      y = constant_op.constant([3., 4.])
      # Make a [2, N, N] shaped operator.
      x = x * y[..., array_ops.newaxis, array_ops.newaxis]
      operator = linalg.LinearOperatorFullMatrix(
          x, is_square=True)
      return operator

    # Seeded module-level generator keeps this input reproducible.
    x = rng.uniform(-1., 1., size=[3, 5, 5]).astype(np.float32)
    batched_operator = control_flow_ops.vectorized_map(
        fn, ops.convert_to_tensor(x))
    self.assertIsInstance(batched_operator, linalg.LinearOperator)
    self.assertAllEqual(batched_operator.batch_shape, [3, 2])
427
# Script entry point: hand control to the test runner's main().
if __name__ == "__main__":
  test.main()
430