# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the matrix square root op (gen_linalg_ops.matrix_square_root)."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import test


@test_util.run_all_without_tensor_float_32
class SquareRootOpTest(test.TestCase):
  """Tests for gen_linalg_ops.matrix_square_root.

  Valid inputs are checked by squaring the computed root and comparing with
  the original matrix. Malformed inputs (rank < 2, non-square) must raise.
  """

  def _verifySquareRoot(self, matrix, np_type):
    """Asserts that matmul(sqrtm(A), sqrtm(A)) is close to A.

    Args:
      matrix: numpy array whose last two dimensions form the matrices to
        test; any leading dimensions are treated as a batch.
      np_type: numpy dtype `matrix` is cast to before calling the op.
    """
    matrix = matrix.astype(np_type)

    # Verify that matmul(sqrtm(A), sqrtm(A)) = A
    sqrt = gen_linalg_ops.matrix_square_root(matrix)
    square = test_util.matmul_without_tf32(sqrt, sqrt)
    self.assertShapeEqual(matrix, square)
    # The same tolerances are used for every dtype checked by the callers.
    self.assertAllClose(matrix, square, rtol=1e-4, atol=1e-3)

  def _verifySquareRootReal(self, x):
    """Runs _verifySquareRoot on `x` as float32 and as float64."""
    for np_type in [np.float32, np.float64]:
      self._verifySquareRoot(x, np_type)

  def _verifySquareRootComplex(self, x):
    """Runs _verifySquareRoot on `x` as complex64 and as complex128."""
    for np_type in [np.complex64, np.complex128]:
      self._verifySquareRoot(x, np_type)

  def _makeBatch(self, matrix1, matrix2):
    """Stacks two [M, M] matrices and tiles them into a [2, 6, M, M] batch.

    np.tile prepends an axis to the stacked [2, M, M] array because the
    repetition list has four entries, giving 2 x (3 * 2) copies in total.
    """
    matrix_batch = np.concatenate(
        [np.expand_dims(matrix1, 0),
         np.expand_dims(matrix2, 0)])
    matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
    return matrix_batch

  def _testMatrices(self, matrix1, matrix2):
    """Checks both matrices singly and batched, in real and complex form.

    The complex inputs are built from each real matrix A as A + 1j * A.
    """
    # Real
    self._verifySquareRootReal(matrix1)
    self._verifySquareRootReal(matrix2)
    self._verifySquareRootReal(self._makeBatch(matrix1, matrix2))
    # Complex: astype returns a copy, so the in-place += below does not
    # modify the caller's arrays.
    matrix1 = matrix1.astype(np.complex64)
    matrix2 = matrix2.astype(np.complex64)
    matrix1 += 1j * matrix1
    matrix2 += 1j * matrix2
    self._verifySquareRootComplex(matrix1)
    self._verifySquareRootComplex(matrix2)
    self._verifySquareRootComplex(self._makeBatch(matrix1, matrix2))

  def testSymmetricPositiveDefinite(self):
    """Symmetric positive-definite 2x2 inputs."""
    matrix1 = np.array([[2., 1.], [1., 2.]])
    matrix2 = np.array([[3., -1.], [-1., 3.]])
    self._testMatrices(matrix1, matrix2)

  def testAsymmetric(self):
    """Non-symmetric 2x2 inputs."""
    matrix1 = np.array([[0., 4.], [-1., 5.]])
    matrix2 = np.array([[33., 24.], [48., 57.]])
    self._testMatrices(matrix1, matrix2)

  def testIdentityMatrix(self):
    """The identity matrix, in 2x2 and 3x3 sizes."""
    # 2x2
    identity = np.array([[1., 0], [0, 1.]])
    self._verifySquareRootReal(identity)
    # 3x3
    identity = np.array([[1., 0, 0], [0, 1., 0], [0, 0, 1.]])
    self._verifySquareRootReal(identity)

  def testEmpty(self):
    """An empty batch of 2x2 matrices and a batch of 0x0 matrices."""
    self._verifySquareRootReal(np.empty([0, 2, 2]))
    self._verifySquareRootReal(np.empty([2, 0, 0]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    """A rank-1 input must be rejected."""
    # The input to the square root should be at least a 2-dimensional tensor.
    tensor = constant_op.constant([1., 2.])
    # Evaluating the op (not just building it) means the test also passes if
    # the rejection happens at run time rather than during graph
    # construction.
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(gen_linalg_ops.matrix_square_root(tensor))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testNotSquare(self):
    """A non-square (2x3) input must be rejected."""
    tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
    # Only the op under test sits inside assertRaises, so an unrelated
    # failure while building the input cannot count as a pass.
    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
      self.evaluate(gen_linalg_ops.matrix_square_root(tensor))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    """Two identical square-root ops evaluated together agree."""
    matrix_shape = [5, 5]
    seed = [42, 24]
    # Same shape and seed, so the two random matrices are identical.
    matrix1 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    matrix2 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    self.assertAllEqual(matrix1, matrix2)
    square1 = math_ops.matmul(matrix1, matrix1)
    square2 = math_ops.matmul(matrix2, matrix2)
    sqrt1 = gen_linalg_ops.matrix_square_root(square1)
    sqrt2 = gen_linalg_ops.matrix_square_root(square2)
    # Both ops are evaluated in a single call.
    all_ops = [sqrt1, sqrt2]
    sqrt = self.evaluate(all_ops)
    self.assertAllClose(sqrt[0], sqrt[1])


if __name__ == "__main__":
  test.main()