# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg_ops.qr."""
16
import itertools
import unittest

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
29
30
@test_util.run_all_without_tensor_float_32(
    "XLA QR op calls matmul. Also, matmul used for verification. Also with "
    'TensorFloat-32, mysterious "Unable to launch cuBLAS gemm" error '
    "occasionally occurs")
# TODO(b/165435566): Fix "Unable to launch cuBLAS gemm" error
class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
  """Tests the XLA-compiled QR decomposition against numpy's np.linalg.qr."""

  def AdjustedNorm(self, x):
    """Computes the norm of matrices in 'x', adjusted for dimension and type."""
    # Frobenius norm per matrix, scaled so tolerances below are roughly
    # independent of matrix size and of float32 vs. complex64 precision.
    norm = np.linalg.norm(x, axis=(-2, -1))
    return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps)

  def CompareOrthogonal(self, x, y, rank):
    """Asserts the leading 'rank' columns of x and y match up to sign/phase."""
    # We only compare the first 'rank' orthogonal vectors since the
    # remainder form an arbitrary orthonormal basis for the
    # (row- or column-) null space, whose exact value depends on
    # implementation details. Notice that since we check that the
    # matrices of singular vectors are unitary elsewhere, we do
    # implicitly test that the trailing vectors of x and y span the
    # same space.
    x = x[..., 0:rank]
    y = y[..., 0:rank]
    # Q is only unique up to sign (complex phase factor for complex matrices),
    # so we normalize the sign first.
    sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
    phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
    # Multiply out of place: the basic slicing above produces a *view* into
    # the caller's array, so an in-place `x *= phases` would silently mutate
    # the caller's data.
    x = x * phases
    self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0))

  def CheckApproximation(self, a, q, r):
    # Tests that a ~= q*r.
    precision = self.AdjustedNorm(a - np.matmul(q, r))
    self.assertTrue(np.all(precision < 11.0))

  def CheckUnitary(self, x):
    # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
    xx = math_ops.matmul(x, x, adjoint_a=True)
    identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
    tol = 100 * np.finfo(x.dtype).eps
    self.assertAllClose(xx, identity, atol=tol)

  def _random_matrix(self, dtype, shape):
    """Returns a deterministic uniform random matrix of 'shape' and 'dtype'."""
    # Fixed seed keeps test inputs reproducible across runs.
    np.random.seed(1)

    def rng():
      return np.random.uniform(
          low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)

    x_np = rng()
    if np.issubdtype(dtype, np.complexfloating):
      # Draw real and imaginary parts independently.
      x_np += rng() * dtype(1j)
    return x_np

  def _test(self, x_np, full_matrices, full_rank=True):
    """Runs linalg_ops.qr on x_np and verifies Q and R against numpy.

    Args:
      x_np: numpy array of matrices (possibly batched) to decompose.
      full_matrices: whether to request the complete (square Q) decomposition.
      full_rank: if False, skip the Q comparison — Q is only unique up to
        sign/phase when the input is full-rank.
    """
    dtype = x_np.dtype
    shape = x_np.shape
    with self.session() as sess:
      x_tf = array_ops.placeholder(dtype)
      with self.device_scope():
        q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
      q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

      # Compute the reference Q with numpy, one matrix at a time: flatten any
      # batch dimensions, decompose each 2-D slice, then restore the shape.
      q_dims = q_tf_val.shape
      np_q = np.ndarray(q_dims, dtype)
      np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
      new_first_dim = np_q_reshape.shape[0]

      x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
      for i in range(new_first_dim):
        if full_matrices:
          np_q_reshape[i, :, :], _ = np.linalg.qr(
              x_reshape[i, :, :], mode="complete")
        else:
          np_q_reshape[i, :, :], _ = np.linalg.qr(
              x_reshape[i, :, :], mode="reduced")
      np_q = np.reshape(np_q_reshape, q_dims)
      if full_rank:
        # Q is unique up to sign/phase if the matrix is full-rank.
        self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
      self.CheckApproximation(x_np, q_tf_val, r_tf_val)
      self.CheckUnitary(q_tf_val)

  SIZES = [1, 2, 5, 10, 32, 100, 300, 603]
  DTYPES = [np.float32, np.complex64]
  PARAMS = itertools.product(SIZES, SIZES, DTYPES)

  @parameterized.parameters(*PARAMS)
  def testQR(self, rows, cols, dtype):
    for full_matrices in [True, False]:
      # Only tests the (3, 2) case for small numbers of rows/columns.
      for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
        x_np = self._random_matrix(dtype, batch_dims + (rows, cols))
        self._test(x_np, full_matrices)

  def testLarge2000x2000(self):
    x_np = self._random_matrix(np.float32, (2000, 2000))
    self._test(x_np, full_matrices=True)

  @unittest.skip("Test times out on CI")
  def testLarge17500x128(self):
    x_np = self._random_matrix(np.float32, (17500, 128))
    self._test(x_np, full_matrices=True)

  @parameterized.parameters((23, 25), (513, 23))
  def testZeroColumn(self, rows, cols):
    # A zero column is still full-rank-comparable here because the remaining
    # columns of the random complex matrix stay independent.
    x_np = self._random_matrix(np.complex64, (rows, cols))
    x_np[:, 7] = 0.
    self._test(x_np, full_matrices=True)

  @parameterized.parameters((4, 4), (514, 20))
  def testRepeatedColumn(self, rows, cols):
    # Duplicated columns make the input rank-deficient, so skip the Q
    # comparison (full_rank=False) and only check a ~= q*r and unitarity.
    x_np = self._random_matrix(np.complex64, (rows, cols))
    x_np[:, 1] = x_np[:, 2]
    self._test(x_np, full_matrices=True, full_rank=False)
145
146
# Standard TensorFlow test entry point (restored from the scraped text,
# which had viewer line numbers fused into the code).
if __name__ == "__main__":
  test.main()
149