# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the XLA implementation of the QR decomposition op."""

import itertools
import unittest

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


@test_util.run_all_without_tensor_float_32(
    "XLA QR op calls matmul. Also, matmul used for verification. Also with "
    'TensorFloat-32, mysterious "Unable to launch cuBLAS gemm" error '
    "occasionally occurs")
# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error
class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
  """Compares `linalg_ops.qr` against `np.linalg.qr` on random matrices."""

  def AdjustedNorm(self, x):
    """Computes the norm of matrices in 'x', adjusted for dimension and type.

    Args:
      x: A NumPy array whose last two axes are the matrix dimensions.

    Returns:
      The norm over the last two axes of `x`, divided by the larger of the two
      matrix dimensions times the machine epsilon of `x.dtype`.
    """
    norm = np.linalg.norm(x, axis=(-2, -1))
    return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps)

  def CompareOrthogonal(self, x, y, rank):
    """Asserts that the first `rank` columns of `x` and `y` agree up to phase.

    Args:
      x: Reference orthogonal factor (NumPy array). Not modified.
      y: Orthogonal factor under test (NumPy array).
      rank: Number of leading columns to compare.
    """
    # We only compare the first 'rank' orthogonal vectors since the
    # remainder form an arbitrary orthonormal basis for the
    # (row- or column-) null space, whose exact value depends on
    # implementation details. Notice that since we check that Q is
    # unitary elsewhere (see CheckUnitary), we do implicitly test that the
    # trailing vectors of x and y span the same space.
    x = x[..., 0:rank]
    y = y[..., 0:rank]
    # Q is only unique up to sign (complex phase factor for complex matrices),
    # so we normalize the sign first.
    sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
    phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
    # Build a new array rather than scaling in place: `x` is a view into the
    # caller's array, and an in-place multiply would silently modify it.
    x = x * phases
    self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0))

  def CheckApproximation(self, a, q, r):
    """Asserts that `a ~= q * r` within an adjusted-norm bound of 11."""
    precision = self.AdjustedNorm(a - np.matmul(q, r))
    self.assertTrue(np.all(precision < 11.0))

  def CheckUnitary(self, x):
    """Asserts that x[...,:,:]^H * x[...,:,:] is close to the identity."""
    xx = math_ops.matmul(x, x, adjoint_a=True)
    # Keeping only the main diagonal of an all-ones tensor yields the identity.
    identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
    tol = 100 * np.finfo(x.dtype).eps
    self.assertAllClose(xx, identity, atol=tol)

  def _random_matrix(self, dtype, shape):
    """Returns a deterministic pseudo-random array of `shape` and `dtype`.

    Entries are drawn uniformly from [-1, 1); for complex dtypes the imaginary
    part is drawn the same way. The NumPy global seed is reset on every call,
    so equal arguments always produce equal arrays.

    Args:
      dtype: NumPy scalar type of the result, e.g. `np.float32`.
      shape: Tuple giving the shape of the result.

    Returns:
      A NumPy array of the requested shape and dtype.
    """
    np.random.seed(1)

    def rng():
      return np.random.uniform(
          low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)

    x_np = rng()
    if np.issubdtype(dtype, np.complexfloating):
      x_np += rng() * dtype(1j)
    return x_np

  def _test(self, x_np, full_matrices, full_rank=True):
    """Runs the QR op on `x_np` and validates the result against NumPy.

    Args:
      x_np: Input NumPy array; the last two axes are the matrix dimensions and
        any leading axes are treated as batch dimensions.
      full_matrices: Passed to `linalg_ops.qr`; also selects NumPy's
        "complete" (True) or "reduced" (False) mode for the reference.
      full_rank: If True, additionally compares Q against NumPy's Q, which is
        only meaningful when Q is unique up to sign/phase.
    """
    dtype = x_np.dtype
    shape = x_np.shape
    with self.session() as sess:
      x_tf = array_ops.placeholder(dtype)
      with self.device_scope():
        q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
      q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

      # Compute the NumPy reference Q one matrix at a time by flattening all
      # batch dimensions into a single leading axis.
      q_dims = q_tf_val.shape
      np_q = np.empty(q_dims, dtype)
      np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
      new_first_dim = np_q_reshape.shape[0]

      x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
      mode = "complete" if full_matrices else "reduced"
      for i in range(new_first_dim):
        np_q_reshape[i, :, :], _ = np.linalg.qr(x_reshape[i, :, :], mode=mode)
      np_q = np.reshape(np_q_reshape, q_dims)
      if full_rank:
        # Q is unique up to sign/phase if the matrix is full-rank.
        self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
      self.CheckApproximation(x_np, q_tf_val, r_tf_val)
      self.CheckUnitary(q_tf_val)

  SIZES = [1, 2, 5, 10, 32, 100, 300, 603]
  DTYPES = [np.float32, np.complex64]
  # Materialized so the parameter set can be iterated more than once; a bare
  # itertools.product object is exhausted after its first use.
  PARAMS = tuple(itertools.product(SIZES, SIZES, DTYPES))

  @parameterized.parameters(*PARAMS)
  def testQR(self, rows, cols, dtype):
    """Tests QR for every (rows, cols, dtype) combination in PARAMS."""
    batch_dims_options = [(), (3,)]
    # Only tests the (3, 2) case for small numbers of rows/columns.
    if max(rows, cols) < 10:
      batch_dims_options.append((3, 2))
    for full_matrices in [True, False]:
      for batch_dims in batch_dims_options:
        x_np = self._random_matrix(dtype, batch_dims + (rows, cols))
        self._test(x_np, full_matrices)

  def testLarge2000x2000(self):
    """Tests QR on a single large square float32 matrix."""
    x_np = self._random_matrix(np.float32, (2000, 2000))
    self._test(x_np, full_matrices=True)

  @unittest.skip("Test times out on CI")
  def testLarge17500x128(self):
    """Tests QR on a single large tall float32 matrix (currently skipped)."""
    x_np = self._random_matrix(np.float32, (17500, 128))
    self._test(x_np, full_matrices=True)

  @parameterized.parameters((23, 25), (513, 23))
  def testZeroColumn(self, rows, cols):
    """Tests QR on a complex matrix with one column set to zero."""
    x_np = self._random_matrix(np.complex64, (rows, cols))
    x_np[:, 7] = 0.
    self._test(x_np, full_matrices=True)

  @parameterized.parameters((4, 4), (514, 20))
  def testRepeatedColumn(self, rows, cols):
    """Tests QR on a complex matrix in which two columns are identical."""
    x_np = self._random_matrix(np.complex64, (rows, cols))
    x_np[:, 1] = x_np[:, 2]
    # The duplicated column makes the matrix rank-deficient, so Q is not
    # unique and the direct comparison against NumPy's Q is skipped.
    self._test(x_np, full_matrices=True, full_rank=False)


if __name__ == "__main__":
  test.main()