• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.math_ops.matrix_square_root."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import errors_impl
21from tensorflow.python.framework import test_util
22from tensorflow.python.ops import gen_linalg_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import stateless_random_ops
25from tensorflow.python.platform import test
26
27
@test_util.run_all_without_tensor_float_32
class SquareRootOpTest(test.TestCase):
  """Tests for `gen_linalg_ops.matrix_square_root`.

  Correctness is checked by squaring the computed root and comparing the
  product against the input. The checks therefore do not depend on which
  square root the op returns.
  """

  def _verifySquareRoot(self, matrix, np_type):
    """Checks that `sqrtm(matrix) @ sqrtm(matrix)` reproduces `matrix`.

    Args:
      matrix: numpy array of (batched) square matrices, shape `[..., M, M]`.
      np_type: numpy dtype that `matrix` is cast to before calling the op.
    """
    matrix = matrix.astype(np_type)

    # Verify that matmul(sqrtm(A), sqrtm(A)) = A
    sqrt = gen_linalg_ops.matrix_square_root(matrix)
    # Multiply without TF32 so the reconstruction is not limited by a
    # reduced-precision matmul.
    square = test_util.matmul_without_tf32(sqrt, sqrt)
    self.assertShapeEqual(matrix, square)
    self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)

  def _verifySquareRootReal(self, x):
    """Runs `_verifySquareRoot` on `x` as float32 and as float64."""
    for np_type in [np.float32, np.float64]:
      self._verifySquareRoot(x, np_type)

  def _verifySquareRootComplex(self, x):
    """Runs `_verifySquareRoot` on `x` as complex64 and as complex128."""
    for np_type in [np.complex64, np.complex128]:
      self._verifySquareRoot(x, np_type)

  def _makeBatch(self, matrix1, matrix2):
    """Stacks two `[M, M]` matrices and tiles them to shape `[2, 6, M, M]`."""
    matrix_batch = np.concatenate(
        [np.expand_dims(matrix1, 0),
         np.expand_dims(matrix2, 0)])
    # The (2, M, M) stack is treated as (1, 2, M, M) and tiled by [2, 3, 1, 1].
    matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
    return matrix_batch

  def _testMatrices(self, matrix1, matrix2):
    """Checks both matrices individually and batched, real and complex.

    The complex variants are built as `A + 1j * A`, so they have equal real
    and imaginary parts.
    """
    # Real
    self._verifySquareRootReal(matrix1)
    self._verifySquareRootReal(matrix2)
    self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
    # Complex
    matrix1 = matrix1.astype(np.complex64)
    matrix2 = matrix2.astype(np.complex64)
    matrix1 += 1j * matrix1
    matrix2 += 1j * matrix2
    self._verifySquareRootComplex(matrix1)
    self._verifySquareRootComplex(matrix2)
    self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))

  def testSymmetricPositiveDefinite(self):
    matrix1 = np.array([[2., 1.], [1., 2.]])
    matrix2 = np.array([[3., -1.], [-1., 3.]])
    self._testMatrices(matrix1, matrix2)

  def testAsymmetric(self):
    matrix1 = np.array([[0., 4.], [-1., 5.]])
    matrix2 = np.array([[33., 24.], [48., 57.]])
    self._testMatrices(matrix1, matrix2)

  def testIdentityMatrix(self):
    # 2x2
    identity = np.array([[1., 0], [0, 1.]])
    self._verifySquareRootReal(identity)
    # 3x3
    identity = np.array([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
    self._verifySquareRootReal(identity)

  def testEmpty(self):
    """Zero-sized batches and zero-sized matrices are handled."""
    self._verifySquareRootReal(np.empty([0, 2, 2]))
    self._verifySquareRootReal(np.empty([2, 0, 0]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    """A rank-1 input is rejected."""
    # The input to the square root should be at least a 2-dimensional tensor.
    tensor = constant_op.constant([1., 2.])
    # The error may surface while the op is being built (ValueError) or when
    # it runs (InvalidArgumentError). Evaluating the op makes the check hold
    # in both graph and eager mode.
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(gen_linalg_ops.matrix_square_root(tensor))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNotSquare(self):
    """A non-square (2x3) input is rejected."""
    # Build the fixture outside the assertion so that only the op under test
    # can satisfy it.
    tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(gen_linalg_ops.matrix_square_root(tensor))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    """Two square-root ops with identical inputs give identical results."""
    matrix_shape = [5, 5]
    seed = [42, 24]
    # Stateless RNG with the same seed yields identical inputs.
    matrix1 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    matrix2 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    self.assertAllEqual(matrix1, matrix2)
    square1 = math_ops.matmul(matrix1, matrix1)
    square2 = math_ops.matmul(matrix2, matrix2)
    sqrt1 = gen_linalg_ops.matrix_square_root(square1)
    sqrt2 = gen_linalg_ops.matrix_square_root(square2)
    # Evaluate both ops in a single call.
    sqrt = self.evaluate([sqrt1, sqrt2])
    self.assertAllClose(sqrt[0], sqrt[1])
120
if __name__ == "__main__":
  # Run the tests in this module through the TensorFlow test runner.
  test.main()
123