1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.tf.Cholesky.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22from six.moves import xrange # pylint: disable=redefined-builtin 23 24from tensorflow.compiler.tests import xla_test 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import test_util 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import linalg_ops 31from tensorflow.python.platform import test 32 33 34class CholeskyOpTest(xla_test.XLATestCase): 35 36 # Cholesky defined for float64, float32, complex64, complex128 37 # (https://www.tensorflow.org/api_docs/python/tf/cholesky) 38 @property 39 def float_types(self): 40 return set(super(CholeskyOpTest, self).float_types).intersection( 41 (np.float64, np.float32, np.complex64, np.complex128)) 42 43 def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): 44 chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) 45 self.assertAllClose(x, verification_np, atol=atol) 46 self.assertShapeEqual(x, chol) 47 # Check that the cholesky is lower triangular, and has positive diagonal 48 # elements. 49 if chol_np.shape[-1] > 0: 50 chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2], 51 chol_np.shape[-1])) 52 for chol_matrix in chol_reshaped: 53 self.assertAllClose(chol_matrix, np.tril(chol_matrix), atol=atol) 54 self.assertTrue((np.diag(chol_matrix) > 0.0).all()) 55 56 def _verifyCholesky(self, x, atol=1e-6): 57 # Verify that LL^T == x. 58 with self.session() as sess: 59 placeholder = array_ops.placeholder( 60 dtypes.as_dtype(x.dtype), shape=x.shape) 61 with self.test_scope(): 62 chol = linalg_ops.cholesky(placeholder) 63 verification = test_util.matmul_without_tf32(chol, chol, adjoint_b=True) 64 self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) 65 66 def testBasic(self): 67 data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]) 68 for dtype in self.float_types: 69 self._verifyCholesky(data.astype(dtype)) 70 71 def testBatch(self): 72 for dtype in self.float_types: 73 simple_array = np.array( 74 [[[1., 0.], [0., 5.]]], dtype=dtype) # shape (1, 2, 2) 75 self._verifyCholesky(simple_array) 76 self._verifyCholesky(np.vstack((simple_array, simple_array))) 77 odd_sized_array = np.array( 78 [[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]], dtype=dtype) 79 self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array))) 80 81 # Generate random positive-definite matrices. 82 matrices = np.random.rand(10, 5, 5).astype(dtype) 83 for i in xrange(10): 84 matrices[i] = np.dot(matrices[i].T, matrices[i]) 85 self._verifyCholesky(matrices, atol=1e-4) 86 87 @test_util.run_v2_only 88 def testNonSquareMatrixV2(self): 89 for dtype in self.float_types: 90 with self.assertRaises(errors.InvalidArgumentError): 91 linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype)) 92 with self.assertRaises(errors.InvalidArgumentError): 93 linalg_ops.cholesky( 94 np.array( 95 [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]], 96 dtype=dtype)) 97 98 @test_util.run_v1_only("Different error types") 99 def testNonSquareMatrixV1(self): 100 for dtype in self.float_types: 101 with self.assertRaises(ValueError): 102 linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype)) 103 with self.assertRaises(ValueError): 104 linalg_ops.cholesky( 105 np.array( 106 [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]], 107 dtype=dtype)) 108 109 @test_util.run_v2_only 110 def testWrongDimensionsV2(self): 111 for dtype in self.float_types: 112 tensor3 = constant_op.constant([1., 2.], dtype=dtype) 113 with self.assertRaises(errors.InvalidArgumentError): 114 linalg_ops.cholesky(tensor3) 115 with self.assertRaises(errors.InvalidArgumentError): 116 linalg_ops.cholesky(tensor3) 117 118 @test_util.run_v1_only("Different error types") 119 def testWrongDimensionsV1(self): 120 for dtype in self.float_types: 121 tensor3 = constant_op.constant([1., 2.], dtype=dtype) 122 with self.assertRaises(ValueError): 123 linalg_ops.cholesky(tensor3) 124 with self.assertRaises(ValueError): 125 linalg_ops.cholesky(tensor3) 126 127 def testLarge2000x2000(self): 128 n = 2000 129 shape = (n, n) 130 data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( 131 np.ones(n).astype(np.float32)) 132 self._verifyCholesky(data, atol=1e-4) 133 134 def testMatrixConditionNumbers(self): 135 for dtype in self.float_types: 136 condition_number = 1000 137 size = 20 138 139 # Generate random positive-definite symmetric matrices, and take their 140 # Eigendecomposition. 141 matrix = np.random.rand(size, size) 142 matrix = np.dot(matrix.T, matrix) 143 _, w = np.linalg.eigh(matrix) 144 145 # Build new Eigenvalues exponentially distributed between 1 and 146 # 1/condition_number 147 v = np.exp(-np.log(condition_number) * np.linspace(0, size, size) / size) 148 matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) 149 self._verifyCholesky(matrix, atol=1e-4) 150 151if __name__ == "__main__": 152 test.main() 153