• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16import numpy as np
17
18from tensorflow.python.framework import config
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import test_util
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import linalg_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import random_ops
25from tensorflow.python.ops import variables as variables_module
26from tensorflow.python.ops.linalg import linalg as linalg_lib
27from tensorflow.python.ops.linalg import linear_operator_test_util
28from tensorflow.python.platform import test
29
30
# Shared NumPy generator with a fixed seed, so the random inputs drawn by the
# tests below are identical on every run.
rng = np.random.RandomState(seed=2016)
32
33
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorIdentityTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for LinearOperatorIdentity.

  Most tests are done in the base class SquareLinearOperatorDerivedClassTest.
  This class adds identity-specific checks: static and dynamic argument
  validation, broadcasting against batched inputs, and the operator types
  returned by adjoint/cholesky/inverse.
  """

  def setUp(self):
    """Runs the base-class setup, then disables TensorFloat-32 execution."""
    super().setUp()
    # Save the global TF32 setting so tearDown can restore it. TF32 is turned
    # off presumably so float32 results are compared at full precision.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the saved TensorFloat-32 setting, then runs base teardown."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def dtypes_to_test():
    """Returns the dtypes the base-class tests are run with."""
    # TODO(langmore) Test tf.float16 once tf.linalg.solve works in
    # 16bit.
    return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds an identity operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute is the full operator shape,
        `batch_shape + [num_rows, num_rows]`.
      dtype: dtype of both the operator and the reference matrix.
      use_placeholder: Unused. The identity operator is built only from
        Python shape values, so there is no Tensor argument to feed.
      ensure_self_adjoint_and_pd: Unused. The identity matrix is already
        Hermitian positive definite.

    Returns:
      A tuple `(operator, matrix)`.
    """
    del use_placeholder
    del ensure_self_adjoint_and_pd

    shape = list(build_info.shape)
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    operator = linalg_lib.LinearOperatorIdentity(
        num_rows, batch_shape=batch_shape, dtype=dtype)
    mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)

    return operator, mat

  def test_assert_positive_definite(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_assert_non_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_float16_matmul(self):
    # float16 cannot be tested by base test class because tf.linalg.solve does
    # not work with float16.
    with self.cached_session():
      operator = linalg_lib.LinearOperatorIdentity(
          num_rows=2, dtype=dtypes.float16)
      x = rng.randn(2, 3).astype(np.float16)
      y = operator.matmul(x)
      self.assertAllClose(x, self.evaluate(y))

  def test_non_scalar_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorIdentity(num_rows=[2])

  def test_non_integer_num_rows_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorIdentity(num_rows=2.)

  def test_negative_num_rows_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorIdentity(num_rows=-2)

  def test_non_1d_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be a 1-D"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=2)

  def test_non_integer_batch_shape_raises_static(self):
    with self.assertRaisesRegex(TypeError, "must be integer"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[2.])

  def test_negative_batch_shape_raises_static(self):
    with self.assertRaisesRegex(ValueError, "must be non-negative"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])

  def test_non_scalar_num_rows_raises_dynamic(self):
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default([2], shape=None)

      with self.assertRaisesError("must be a 0-D Tensor"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_num_rows_raises_dynamic(self):
    with self.cached_session():
      num_rows = array_ops.placeholder_with_default(-2, shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_non_1d_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default(2, shape=None)
      with self.assertRaisesError("must be a 1-D"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_negative_batch_shape_raises_dynamic(self):
    with self.cached_session():
      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
      with self.assertRaisesError("must be non-negative"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
        self.evaluate(operator.to_dense())

  def test_wrong_matrix_dimensions_raises_static(self):
    operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(
        rng.rand(3, 3).astype(np.float32), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorIdentity(
            num_rows, assert_proper_shapes=True)
        self.evaluate(operator.matmul(x))

  def test_default_batch_shape_broadcasts_with_everything_static(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)

      operator_matmul = operator.matmul(x)
      expected = x

      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      x = array_ops.placeholder_with_default(rng.randn(1, 2, 3, 4), shape=None)
      operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)

      operator_matmul = operator.matmul(x)
      expected = x

      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_broadcast_matmul_static_shapes(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = random_ops.random_normal(shape=(1, 2, 3, 4))
      operator = linalg_lib.LinearOperatorIdentity(
          num_rows=3, batch_shape=(2, 1), dtype=x.dtype)

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Expected result of matmul and solve.
      expected = x + zeros

      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_broadcast_matmul_dynamic_shapes(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = array_ops.placeholder_with_default(rng.rand(1, 2, 3, 4), shape=None)
      num_rows = array_ops.placeholder_with_default(3, shape=None)
      batch_shape = array_ops.placeholder_with_default((2, 1), shape=None)

      operator = linalg_lib.LinearOperatorIdentity(
          num_rows, batch_shape=batch_shape, dtype=dtypes.float64)

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Expected result of matmul and solve.
      expected = x + zeros

      operator_matmul = operator.matmul(x)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

  def test_is_x_flags(self):
    # The is_x flags are by default all True.
    operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
    self.assertTrue(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)

    # Setting any of them to False raises, because the identity is always
    # self-adjoint, non-singular, etc.
    with self.assertRaisesRegex(ValueError, "is always non-singular"):
      linalg_lib.LinearOperatorIdentity(num_rows=2, is_non_singular=False)

  def test_identity_adjoint_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2, is_non_singular=True)
    self.assertIsInstance(
        operator.adjoint(), linalg_lib.LinearOperatorIdentity)

  def test_identity_cholesky_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(
        operator.cholesky(), linalg_lib.LinearOperatorIdentity)

  def test_identity_inverse_type(self):
    operator = linalg_lib.LinearOperatorIdentity(
        num_rows=2, is_non_singular=True)
    self.assertIsInstance(
        operator.inverse(), linalg_lib.LinearOperatorIdentity)

  def test_ref_type_shape_args_raises(self):
    with self.assertRaisesRegex(TypeError, "num_rows.*reference"):
      linalg_lib.LinearOperatorIdentity(num_rows=variables_module.Variable(2))

    with self.assertRaisesRegex(TypeError, "batch_shape.*reference"):
      linalg_lib.LinearOperatorIdentity(
          num_rows=2, batch_shape=variables_module.Variable([3]))
290
291
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorScaledIdentityTest(
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """Tests for LinearOperatorScaledIdentity.

  Most tests are done in the base class SquareLinearOperatorDerivedClassTest.
  This class adds scaled-identity-specific checks: assertion ops, broadcasting
  of matmul/solve, is_x hints, and the operator types produced when combining
  identity and scaled-identity operators.
  """

  def setUp(self):
    """Runs the base-class setup, then disables TensorFloat-32 execution."""
    super().setUp()
    # Save the global TF32 setting so tearDown can restore it. TF32 is turned
    # off presumably so float32 results are compared at full precision.
    self.tf32_keep_ = config.tensor_float_32_execution_enabled()
    config.enable_tensor_float_32_execution(False)

  def tearDown(self):
    """Restores the saved TensorFloat-32 setting, then runs base teardown."""
    config.enable_tensor_float_32_execution(self.tf32_keep_)
    super().tearDown()

  @staticmethod
  def dtypes_to_test():
    """Returns the dtypes the base-class tests are run with."""
    # TODO(langmore) Test tf.float16 once tf.linalg.solve works in
    # 16bit.
    return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]

  @staticmethod
  def optional_tests():
    """List of optional test names to run."""
    return [
        "operator_matmul_with_same_type",
        "operator_solve_with_same_type",
    ]

  def operator_and_matrix(
      self, build_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Builds a scaled-identity operator and the dense matrix it should equal.

    Args:
      build_info: Object whose `shape` attribute is the full operator shape,
        `batch_shape + [num_rows, num_rows]`.
      dtype: dtype of the multiplier, the operator and the reference matrix.
      use_placeholder: If True, the multiplier is passed to the operator
        through `placeholder_with_default(..., shape=None)`, so the operator
        only sees its shape dynamically.
      ensure_self_adjoint_and_pd: If True, the multiplier is made positive
        (via `abs`) and the operator is hinted self-adjoint and positive
        definite.

    Returns:
      A tuple `(operator, matrix)`.
    """
    shape = list(build_info.shape)
    assert shape[-1] == shape[-2]

    batch_shape = shape[:-2]
    num_rows = shape[-1]

    # Uniform values that are at least length 1 from the origin.  Allows the
    # operator to be well conditioned.
    # Shape batch_shape
    multiplier = linear_operator_test_util.random_sign_uniform(
        shape=batch_shape, minval=1., maxval=2., dtype=dtype)

    if ensure_self_adjoint_and_pd:
      # Abs on complex64 will result in a float32, so we cast back up.
      multiplier = math_ops.cast(math_ops.abs(multiplier), dtype=dtype)

    # The multiplier is the operator's only Tensor argument. When requested,
    # hide its static shape behind a placeholder to exercise dynamic shapes.
    lin_op_multiplier = multiplier

    if use_placeholder:
      lin_op_multiplier = array_ops.placeholder_with_default(
          multiplier, shape=None)

    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows,
        lin_op_multiplier,
        is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
        is_positive_definite=True if ensure_self_adjoint_and_pd else None)

    # Append two unit dims so the multiplier broadcasts over each matrix.
    multiplier_matrix = array_ops.expand_dims(
        array_ops.expand_dims(multiplier, -1), -1)
    matrix = multiplier_matrix * linalg_ops.eye(
        num_rows, batch_shape=batch_shape, dtype=dtype)

    return operator, matrix

  def test_assert_positive_definite_does_not_raise_when_positive(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=1.)
      self.evaluate(operator.assert_positive_definite())  # Should not fail

  def test_assert_positive_definite_raises_when_negative(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=-1.)
      with self.assertRaisesOpError("not positive definite"):
        self.evaluate(operator.assert_positive_definite())

  def test_assert_non_singular_does_not_raise_when_non_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1., 2., 3.])
      self.evaluate(operator.assert_non_singular())  # Should not fail

  def test_assert_non_singular_raises_when_singular(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1., 2., 0.])
      with self.assertRaisesOpError("was singular"):
        self.evaluate(operator.assert_non_singular())

  def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1. + 0J])
      self.evaluate(operator.assert_self_adjoint())  # Should not fail

  def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
    with self.cached_session():
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=[1. + 1J])
      with self.assertRaisesOpError("not self-adjoint"):
        self.evaluate(operator.assert_self_adjoint())

  def test_float16_matmul(self):
    # float16 cannot be tested by base test class because tf.linalg.solve does
    # not work with float16.
    with self.cached_session():
      multiplier = rng.rand(3).astype(np.float16)
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=2, multiplier=multiplier)
      x = rng.randn(2, 3).astype(np.float16)
      y = operator.matmul(x)
      self.assertAllClose(multiplier[..., None, None] * x, self.evaluate(y))

  def test_non_scalar_num_rows_raises_static(self):
    # Many "test_...num_rows" tests are performed in LinearOperatorIdentity.
    with self.assertRaisesRegex(ValueError, "must be a 0-D Tensor"):
      linalg_lib.LinearOperatorScaledIdentity(
          num_rows=[2], multiplier=123.)

  def test_wrong_matrix_dimensions_raises_static(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=2.2)
    x = rng.randn(3, 3).astype(np.float32)
    with self.assertRaisesRegex(ValueError, "Dimensions.*not compatible"):
      operator.matmul(x)

  def test_wrong_matrix_dimensions_raises_dynamic(self):
    num_rows = array_ops.placeholder_with_default(2, shape=None)
    x = array_ops.placeholder_with_default(
        rng.rand(3, 3).astype(np.float32), shape=None)

    with self.cached_session():
      with self.assertRaisesError("Dimensions.*not.compatible"):
        operator = linalg_lib.LinearOperatorScaledIdentity(
            num_rows,
            multiplier=[1., 2],
            assert_proper_shapes=True)
        self.evaluate(operator.matmul(x))

  def test_broadcast_matmul_and_solve(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the
      # broadcast shape of operator and 'x' is (2, 2, 3, 4)
      x = random_ops.random_normal(shape=(1, 2, 3, 4))

      # operator is 2.2 * identity (with a batch shape).
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=3, multiplier=2.2 * array_ops.ones((2, 1)))

      # Batch matrix of zeros with the broadcast shape of x and operator.
      zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

      # Test matmul
      expected = x * 2.2 + zeros
      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

      # Test solve
      expected = x / 2.2 + zeros
      operator_solve = operator.solve(x)
      self.assertAllEqual(operator_solve.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_solve, expected]))

  def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self):
    # These cannot be done in the automated (base test class) tests since they
    # test shapes that tf.batch_matmul cannot handle.
    # In particular, tf.batch_matmul does not broadcast.
    with self.cached_session():
      # Given this x and LinearOperatorScaledIdentity shape of (3, 3), the
      # broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same
      # shape as x.
      x = random_ops.random_normal(shape=(1, 2, 3, 4))

      # operator is 2.2 * identity (no batch shape).
      operator = linalg_lib.LinearOperatorScaledIdentity(
          num_rows=3, multiplier=2.2)

      # Test matmul
      expected = x * 2.2
      operator_matmul = operator.matmul(x)
      self.assertAllEqual(operator_matmul.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

      # Test solve
      expected = x / 2.2
      operator_solve = operator.solve(x)
      self.assertAllEqual(operator_solve.shape, expected.shape)
      self.assertAllClose(*self.evaluate([operator_solve, expected]))

  def test_is_x_flags(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=1.,
        is_positive_definite=False, is_non_singular=True)
    self.assertFalse(operator.is_positive_definite)
    self.assertTrue(operator.is_non_singular)
    self.assertTrue(operator.is_self_adjoint)  # Auto-set due to real multiplier

  def test_identity_matmul(self):
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    self.assertIsInstance(
        operator1.matmul(operator1),
        linalg_lib.LinearOperatorIdentity)

    self.assertIsInstance(
        operator2.matmul(operator2),
        linalg_lib.LinearOperatorScaledIdentity)

    operator_matmul = operator1.matmul(operator2)
    self.assertIsInstance(
        operator_matmul,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))

    operator_matmul = operator2.matmul(operator1)
    self.assertIsInstance(
        operator_matmul,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_matmul.multiplier))

  def test_identity_solve(self):
    operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    operator2 = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=3.)
    self.assertIsInstance(
        operator1.solve(operator1),
        linalg_lib.LinearOperatorIdentity)

    self.assertIsInstance(
        operator2.solve(operator2),
        linalg_lib.LinearOperatorScaledIdentity)

    operator_solve = operator1.solve(operator2)
    self.assertIsInstance(
        operator_solve,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(3., self.evaluate(operator_solve.multiplier))

    operator_solve = operator2.solve(operator1)
    self.assertIsInstance(
        operator_solve,
        linalg_lib.LinearOperatorScaledIdentity)
    self.assertAllClose(1. / 3., self.evaluate(operator_solve.multiplier))

  def test_scaled_identity_cholesky_type(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2,
        multiplier=3.,
        is_positive_definite=True,
        is_self_adjoint=True,
    )
    self.assertIsInstance(
        operator.cholesky(),
        linalg_lib.LinearOperatorScaledIdentity)

  def test_scaled_identity_inverse_type(self):
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2,
        multiplier=3.,
        is_non_singular=True,
    )
    self.assertIsInstance(
        operator.inverse(),
        linalg_lib.LinearOperatorScaledIdentity)

  def test_ref_type_shape_args_raises(self):
    with self.assertRaisesRegex(TypeError, "num_rows.*reference"):
      linalg_lib.LinearOperatorScaledIdentity(
          num_rows=variables_module.Variable(2), multiplier=1.23)

  def test_tape_safe(self):
    multiplier = variables_module.Variable(1.23)
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=multiplier)
    self.check_tape_safe(operator)

  def test_convert_variables_to_tensors(self):
    multiplier = variables_module.Variable(1.23)
    operator = linalg_lib.LinearOperatorScaledIdentity(
        num_rows=2, multiplier=multiplier)
    with self.cached_session() as sess:
      sess.run([multiplier.initializer])
      self.check_convert_variables_to_tensors(operator)
587
588
if __name__ == "__main__":
  # Attach the shared, generated base-class tests to each test class, then
  # hand control to the TensorFlow test runner.
  for case_class in (LinearOperatorIdentityTest,
                     LinearOperatorScaledIdentityTest):
    linear_operator_test_util.add_tests(case_class)
  test.main()
593