# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.Lu."""

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test


@test_util.with_eager_op_as_function
class LuOpTest(test.TestCase):
  """Tests for `linalg_ops.lu`: reconstruction, pivoting, batching, errors."""

  @property
  def float_types(self):
    return set((np.float64, np.float32, np.complex64, np.complex128))

  def _verifyLuBase(self, x, lower, upper, perm, verification,
                    output_idx_type):
    """Checks the reconstruction, factor shapes/dtypes and the permutation.

    Args:
      x: The numpy input that was factorized.
      lower: Tensor holding the unit lower-triangular factor.
      upper: Tensor holding the upper-triangular factor.
      perm: Tensor of row-permutation indices returned by `linalg_ops.lu`.
      verification: Tensor holding the product of the factors with the row
        permutation undone; it is expected to reproduce `x`.
      output_idx_type: Expected dtype of `perm`.
    """
    lower_np, upper_np, perm_np, verification_np = self.evaluate(
        [lower, upper, perm, verification])

    self.assertAllClose(x, verification_np)
    self.assertShapeEqual(x, lower)
    self.assertShapeEqual(x, upper)

    self.assertAllEqual(x.shape[:-1], perm.shape.as_list())

    # Check dtypes are as expected.
    self.assertEqual(x.dtype, lower_np.dtype)
    self.assertEqual(x.dtype, upper_np.dtype)
    self.assertEqual(output_idx_type.as_numpy_dtype, perm_np.dtype)

    # Check that the permutation is valid: each batch's index vector must be a
    # rearrangement of 0..n-1. Indices are integers, so compare exactly.
    if perm_np.shape[-1] > 0:
      perm_reshaped = np.reshape(perm_np, (-1, perm_np.shape[-1]))
      for perm_vector in perm_reshaped:
        self.assertAllEqual(np.arange(len(perm_vector)), np.sort(perm_vector))

  def _verifyLu(self, x, output_idx_type=dtypes.int64):
    """Factorizes `x` with `linalg_ops.lu` and verifies that Px = LU.

    The packed `lu` output is split into a unit lower-triangular factor and an
    upper-triangular factor. Their product has its rows reordered by the
    inverse of `perm` and is then compared against `x` by `_verifyLuBase`.

    Args:
      x: numpy array whose last two dimensions are the matrix rows/columns.
      output_idx_type: dtype requested for the permutation output.
    """
    lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)

    # Prepare the lower factor of shape num_rows x num_rows
    lu_shape = np.array(lu.shape.as_list())
    batch_shape = lu_shape[:-2]
    num_rows = lu_shape[-2]
    num_cols = lu_shape[-1]

    lower = array_ops.matrix_band_part(lu, -1, 0)

    if num_rows > num_cols:
      eye = linalg_ops.eye(
          num_rows, batch_shape=batch_shape, dtype=lower.dtype)
      lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
    elif num_rows < num_cols:
      lower = lower[..., :num_rows]

    # Fill the diagonal with ones.
    ones_diag = array_ops.ones(
        np.append(batch_shape, num_rows), dtype=lower.dtype)
    lower = array_ops.matrix_set_diag(lower, ones_diag)

    # Prepare the upper factor.
    upper = array_ops.matrix_band_part(lu, 0, -1)

    verification = test_util.matmul_without_tf32(lower, upper)

    # Permute the rows of the product of the LU factors.
    if num_rows > 0:
      # Reshape the product of the triangular factors and permutation indices
      # to a single batch dimension. This makes it easy to apply
      # invert_permutation and gather_nd ops.
      perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
      verification_reshaped = array_ops.reshape(verification,
                                                [-1, num_rows, num_cols])
      # Invert the permutation in each batch.
      inv_perm_reshaped = map_fn.map_fn(array_ops.invert_permutation,
                                        perm_reshaped)
      batch_size = perm_reshaped.shape.as_list()[0]
      # Prepare the batch indices with the same shape as the permutation.
      # The corresponding batch index is paired with each of the `num_rows`
      # permutation indices.
      batch_indices = math_ops.cast(
          array_ops.broadcast_to(
              math_ops.range(batch_size)[:, None], perm_reshaped.shape),
          dtype=output_idx_type)
      if inv_perm_reshaped.shape == [0]:
        inv_perm_reshaped = array_ops.zeros_like(batch_indices)
      permuted_verification_reshaped = array_ops.gather_nd(
          verification_reshaped,
          array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))

      # Reshape the verification matrix back to the original shape.
      verification = array_ops.reshape(permuted_verification_reshaped,
                                       lu_shape)

    self._verifyLuBase(x, lower, upper, perm, verification,
                       output_idx_type)

  def testBasic(self):
    data = np.array([[4., -1., 2.], [-1., 6., 0], [10., 0., 5.]])

    for dtype in (np.float32, np.float64):
      for output_idx_type in (dtypes.int32, dtypes.int64):
        with self.subTest(dtype=dtype, output_idx_type=output_idx_type):
          self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)

    for dtype in (np.complex64, np.complex128):
      for output_idx_type in (dtypes.int32, dtypes.int64):
        with self.subTest(dtype=dtype, output_idx_type=output_idx_type):
          complex_data = np.tril(1j * data, -1).astype(dtype)
          complex_data += np.triu(-1j * data, 1).astype(dtype)
          complex_data += data
          self._verifyLu(complex_data, output_idx_type=output_idx_type)

  def testPivoting(self):
    # This matrix triggers partial pivoting because the first diagonal entry
    # is small.
    data = np.array([[1e-9, 1., 0.], [1., 0., 0], [0., 1., 5]])

    for dtype in (np.float32, np.float64):
      with self.subTest(dtype=dtype):
        real_data = data.astype(dtype)
        self._verifyLu(real_data)
        # Factorize the same dtype-specific input that was just verified.
        _, p = linalg_ops.lu(real_data)
        p_val = self.evaluate(p)
        # Make sure p_val is not the identity permutation.
        self.assertNotAllClose(np.arange(3), p_val)

    for dtype in (np.complex64, np.complex128):
      with self.subTest(dtype=dtype):
        complex_data = np.tril(1j * data, -1).astype(dtype)
        complex_data += np.triu(-1j * data, 1).astype(dtype)
        complex_data += data
        self._verifyLu(complex_data)
        # Check pivoting on the complex input itself, not the real `data`.
        _, p = linalg_ops.lu(complex_data)
        p_val = self.evaluate(p)
        # Make sure p_val is not the identity permutation.
        self.assertNotAllClose(np.arange(3), p_val)

  def testInvalidMatrix(self):
    # LU factorization gives an error when the input is singular.
    # Note: A singular matrix may return without error but it won't be a valid
    # factorization.
    for dtype in self.float_types:
      with self.subTest(dtype=dtype):
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(
              linalg_ops.lu(
                  np.array([[1., 2., 3.], [2., 4., 6.], [2., 3., 4.]],
                           dtype=dtype)))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(
              linalg_ops.lu(
                  np.array([[[1., 2., 3.], [2., 4., 6.], [1., 2., 3.]],
                            [[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]],
                           dtype=dtype)))

  def testBatch(self):
    simple_array = np.array([[[1., -1.], [2., 5.]]])  # shape (1, 2, 2)
    self._verifyLu(simple_array)
    self._verifyLu(np.vstack((simple_array, simple_array)))
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyLu(np.vstack((odd_sized_array, odd_sized_array)))

    batch_size = 200

    # Generate random matrices. A local RandomState yields the same stream as
    # np.random.seed(42) without mutating the global numpy RNG state.
    rng = np.random.RandomState(42)
    matrices = rng.rand(batch_size, 5, 5)
    self._verifyLu(matrices)

    # Generate random complex valued matrices.
    rng = np.random.RandomState(52)
    matrices = rng.rand(batch_size, 5, 5) + 1j * rng.rand(batch_size, 5, 5)
    self._verifyLu(matrices)

  def testLargeMatrix(self):
    # Generate random matrices.
    n = 500
    rng = np.random.RandomState(64)
    data = rng.rand(n, n)
    self._verifyLu(data)

    # Generate random complex valued matrices.
    rng = np.random.RandomState(129)
    data = rng.rand(n, n) + 1j * rng.rand(n, n)
    self._verifyLu(data)

  @test_util.disable_xla("b/206106619")
  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    self._verifyLu(np.empty([0, 2, 2]))
    self._verifyLu(np.empty([2, 0, 0]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    matrix_shape = [5, 5]
    seed = [42, 24]
    matrix1 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    matrix2 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    self.assertAllEqual(matrix1, matrix2)
    lu1, p1 = linalg_ops.lu(matrix1)
    lu2, p2 = linalg_ops.lu(matrix2)
    lu1_val, p1_val, lu2_val, p2_val = self.evaluate([lu1, p1, lu2, p2])
    self.assertAllEqual(lu1_val, lu2_val)
    self.assertAllEqual(p1_val, p2_val)


class LuBenchmark(test.Benchmark):
  """Benchmarks `linalg_ops.lu` on CPU and, when available, GPU."""

  shapes = [
      (4, 4),
      (10, 10),
      (16, 16),
      (101, 101),
      (256, 256),
      (1000, 1000),
      (1024, 1024),
      (2048, 2048),
      (4096, 4096),
      (513, 2, 2),
      (513, 8, 8),
      (513, 256, 256),
      (4, 513, 2, 2),
  ]

  def _GenerateMatrix(self, shape):
    """Returns a float32 batch of identical n x n matrices of the given shape.

    Each matrix is ones / (2n) plus the identity. The diagonal is 1 + 1/(2n)
    and each row's off-diagonal sum is (n - 1) / (2n). The matrix is therefore
    strictly diagonally dominant and nonsingular.

    Args:
      shape: Tuple `batch_shape + (n, n)`; the last two entries must match.
    """
    batch_shape = shape[:-2]
    shape = shape[-2:]
    assert shape[0] == shape[1]
    n = shape[0]
    matrix = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
        np.ones(n).astype(np.float32))
    return np.tile(matrix, batch_shape + (1, 1))

  def benchmarkLuOp(self):
    for shape in self.shapes:
      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device("/cpu:0"):
        matrix = variables.Variable(self._GenerateMatrix(shape))
        lu, p = linalg_ops.lu(matrix)
        self.evaluate(variables.global_variables_initializer())
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(lu, p),
            min_iters=25,
            name="lu_cpu_{shape}".format(shape=shape))

      if test.is_gpu_available(True):
        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device("/device:GPU:0"):
          matrix = variables.Variable(self._GenerateMatrix(shape))
          lu, p = linalg_ops.lu(matrix)
          self.evaluate(variables.global_variables_initializer())
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(lu, p),
              min_iters=25,
              name="lu_gpu_{shape}".format(shape=shape))


if __name__ == "__main__":
  test.main()