1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.svd.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22 23from absl.testing import parameterized 24import numpy as np 25 26from tensorflow.compiler.tests import xla_test 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gen_linalg_ops 30from tensorflow.python.ops import linalg_ops 31from tensorflow.python.platform import test 32 33 34class SvdOpTest(xla_test.XLATestCase, parameterized.TestCase): 35 36 def _compute_usvt(self, s, u, v): 37 m = u.shape[-1] 38 n = v.shape[-1] 39 if m <= n: 40 v = v[..., :m] 41 else: 42 u = u[..., :n] 43 44 return np.matmul(u * s[..., None, :], np.swapaxes(v, -1, -2)) 45 46 def _testSvdCorrectness(self, dtype, shape): 47 np.random.seed(1) 48 x_np = np.random.uniform(low=-1.0, high=1.0, size=shape).astype(dtype) 49 m, n = shape[-2], shape[-1] 50 _, s_np, _ = np.linalg.svd(x_np) 51 with self.session() as sess: 52 x_tf = array_ops.placeholder(dtype) 53 with self.test_scope(): 54 s, u, v = linalg_ops.svd(x_tf, full_matrices=True) 55 s_val, u_val, v_val = sess.run([s, u, v], feed_dict={x_tf: x_np}) 56 u_diff = np.matmul(u_val, np.swapaxes(u_val, -1, -2)) - np.eye(m) 57 v_diff = np.matmul(v_val, np.swapaxes(v_val, -1, -2)) - np.eye(n) 58 # Check u_val and v_val are orthogonal matrices. 59 self.assertLess(np.linalg.norm(u_diff), 1e-2) 60 self.assertLess(np.linalg.norm(v_diff), 1e-2) 61 # Check that the singular values are correct, i.e., close to the ones from 62 # numpy.lingal.svd. 63 self.assertLess(np.linalg.norm(s_val - s_np), 1e-2) 64 # The tolerance is set based on our tests on numpy's svd. As our tests 65 # have batch dimensions and all our operations are on float32, we set the 66 # tolerance a bit larger. Numpy's svd calls LAPACK's svd, which operates 67 # on double precision. 68 self.assertLess( 69 np.linalg.norm(self._compute_usvt(s_val, u_val, v_val) - x_np), 2e-2) 70 71 # Check behavior with compute_uv=False. We expect to still see 3 outputs, 72 # with a sentinel scalar 0 in the last two outputs. 73 with self.test_scope(): 74 no_uv_s, no_uv_u, no_uv_v = gen_linalg_ops.svd( 75 x_tf, full_matrices=True, compute_uv=False) 76 no_uv_s_val, no_uv_u_val, no_uv_v_val = sess.run( 77 [no_uv_s, no_uv_u, no_uv_v], feed_dict={x_tf: x_np}) 78 self.assertAllClose(no_uv_s_val, s_val, atol=1e-4, rtol=1e-4) 79 self.assertEqual(no_uv_u_val.shape, tensor_shape.TensorShape([0])) 80 self.assertEqual(no_uv_v_val.shape, tensor_shape.TensorShape([0])) 81 82 SIZES = [1, 2, 5, 10, 32, 64] 83 DTYPES = [np.float32] 84 PARAMS = itertools.product(SIZES, DTYPES) 85 86 @parameterized.parameters(*PARAMS) 87 def testSvd(self, n, dtype): 88 for batch_dims in [(), (3,)] + [(3, 2)] * (n < 10): 89 self._testSvdCorrectness(dtype, batch_dims + (n, n)) 90 self._testSvdCorrectness(dtype, batch_dims + (2 * n, n)) 91 self._testSvdCorrectness(dtype, batch_dims + (n, 2 * n)) 92 93 94if __name__ == "__main__": 95 test.main() 96