# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg.linalg_impl.tridiagonal_matmul."""

import itertools

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test


class TridiagonalMulOpTest(test.TestCase):

  def _testAllFormats(self,
                      superdiag,
                      maindiag,
                      subdiag,
                      rhs,
                      expected,
                      dtype=dtypes.float64):
    """Checks tridiagonal_matmul against `expected` in every diagonals format.

    The same system is fed in 'sequence', 'compact' and 'matrix' formats, both
    unbatched and as a batch of two whose second element doubles both the
    diagonals and `rhs` (so its expected product is 4 * `expected`).

    Args:
      superdiag: Superdiagonal values; one shorter than `maindiag`.
      maindiag: Main-diagonal values.
      subdiag: Subdiagonal values; one shorter than `maindiag`.
      rhs: Right-hand side of the product.
      expected: Expected value of the tridiagonal matrix times `rhs`.
      dtype: dtype used for every constant fed to the op.
    """
    # In the sequence/compact formats the last superdiagonal entry and the
    # first subdiagonal entry are unused, so pad on those sides.
    superdiag_extended = np.pad(superdiag, [0, 1], 'constant')
    subdiag_extended = np.pad(subdiag, [1, 0], 'constant')
    diags_compact = np.stack([superdiag_extended, maindiag, subdiag_extended])
    diags_matrix = np.diag(superdiag, 1) + np.diag(maindiag, 0) + np.diag(
        subdiag, -1)

    diags_sequence = (constant_op.constant(superdiag_extended, dtype),
                      constant_op.constant(maindiag, dtype),
                      constant_op.constant(subdiag_extended, dtype))
    diags_compact = constant_op.constant(diags_compact, dtype)
    diags_matrix = constant_op.constant(diags_matrix, dtype)
    rhs = constant_op.constant(rhs, dtype)

    rhs_batch = array_ops.stack([rhs, 2 * rhs])
    diags_compact_batch = array_ops.stack([diags_compact, 2 * diags_compact])
    diags_matrix_batch = array_ops.stack([diags_matrix, 2 * diags_matrix])
    diags_sequence_batch = [array_ops.stack([x, 2 * x]) for x in diags_sequence]

    results = [
        linalg_impl.tridiagonal_matmul(
            diags_sequence, rhs, diagonals_format='sequence'),
        linalg_impl.tridiagonal_matmul(
            diags_compact, rhs, diagonals_format='compact'),
        linalg_impl.tridiagonal_matmul(
            diags_matrix, rhs, diagonals_format='matrix')
    ]
    results_batch = [
        linalg_impl.tridiagonal_matmul(
            diags_sequence_batch, rhs_batch, diagonals_format='sequence'),
        linalg_impl.tridiagonal_matmul(
            diags_compact_batch, rhs_batch, diagonals_format='compact'),
        linalg_impl.tridiagonal_matmul(
            diags_matrix_batch, rhs_batch, diagonals_format='matrix')
    ]

    with self.cached_session():
      results = self.evaluate(results)
      results_batch = self.evaluate(results_batch)

    expected = np.array(expected)
    expected_batch = np.stack([expected, 4 * expected])
    for result in results:
      self.assertAllClose(result, expected)
    for result in results_batch:
      self.assertAllClose(result, expected_batch)

  def _makeTridiagonalMatrix(self, superdiag, maindiag, subdiag):
    """Builds dense tridiagonal matrices from their three diagonals.

    The padding specs below have three entries, so this assumes exactly one
    leading batch dimension on the diagonals.
    """
    # Superdiagonal element i lands at (i, i + 1): pad a row below, a column
    # on the left.
    super_pad = [[0, 0], [0, 1], [1, 0]]
    # Subdiagonal element i lands at (i + 1, i): pad a row above, a column
    # on the right.
    sub_pad = [[0, 0], [1, 0], [0, 1]]

    super_part = array_ops.pad(array_ops.matrix_diag(superdiag), super_pad)
    main_part = array_ops.matrix_diag(maindiag)
    sub_part = array_ops.pad(array_ops.matrix_diag(subdiag), sub_pad)
    return super_part + main_part + sub_part

  def _randomComplexArray(self, shape, rng=None):
    """Returns a complex array whose real/imag parts are uniform in [-10, 10).

    Args:
      shape: Shape of the returned array.
      rng: Optional `np.random.RandomState`. Pass one shared generator when a
        test draws several arrays so that they differ from one another. When
        omitted, a fresh generator seeded with 43 is used, so every call with
        the same shape returns the same values.
    """
    if rng is None:
      rng = np.random.RandomState(43)
    return (rng.uniform(-10, 10, shape) + rng.uniform(-10, 10, shape) * 1j)

  def _gradientTest(self, diags, rhs, dtype=dtypes.float64):
    """Checks tridiagonal_matmul gradients numerically and against a reference.

    Args:
      diags: Diagonals in 'compact' layout: index 0 of the second-to-last axis
        is the superdiagonal (last entry unused), 1 the main diagonal, and 2
        the subdiagonal (first entry unused).
      rhs: Right-hand side of the product.
      dtype: dtype used for both constants.
    """

    def reference_matmul(diags, rhs):
      # Dense reference: materialize the matrix and use an ordinary matmul.
      matrix = self._makeTridiagonalMatrix(diags[..., 0, :-1], diags[..., 1, :],
                                           diags[..., 2, 1:])
      return math_ops.matmul(matrix, rhs)

    diags = constant_op.constant(diags, dtype=dtype)
    rhs = constant_op.constant(rhs, dtype=dtype)
    with self.cached_session():
      grad_reference, _ = gradient_checker_v2.compute_gradient(
          reference_matmul, [diags, rhs])
      grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
          linalg_impl.tridiagonal_matmul, [diags, rhs])
    self.assertAllClose(grad_theoretical, grad_numerical)
    self.assertAllClose(grad_theoretical, grad_reference)

  def test2x2(self):
    self._testAllFormats([1], [2, 3], [4], [[2, 1], [4, 3]], [[8, 5], [20, 13]])

  def test3x3(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      self._testAllFormats([1, 2], [1, 2, 1], [2, 1], [[1, 1], [2, 2], [3, 3]],
                           [[3, 3], [12, 12], [5, 5]],
                           dtype=dtype)

  def testComplex(self):
    for dtype in [dtypes.complex64, dtypes.complex128]:
      self._testAllFormats([1j, 1j], [1, -1, 0], [1j, 1j],
                           np.array([[1, 1j], [1, 1j], [1, 1j]]),
                           [[1 + 1j, -1 + 1j], [-1 + 2j, -2 - 1j], [1j, -1]],
                           dtype=dtype)

  def testBatch(self):
    b = 20
    m = 10
    n = 15
    # Draw every array from one generator. Reseeding per array would make the
    # super- and subdiagonals identical, producing symmetric matrices that
    # cannot detect a swap of the two diagonals.
    rng = np.random.RandomState(43)
    superdiag = self._randomComplexArray((b, m - 1), rng)
    maindiag = self._randomComplexArray((b, m), rng)
    subdiag = self._randomComplexArray((b, m - 1), rng)
    rhs = self._randomComplexArray((b, m, n), rng)
    matrix = np.stack([
        np.diag(upper, 1) + np.diag(main, 0) + np.diag(lower, -1)
        for upper, main, lower in zip(superdiag, maindiag, subdiag)
    ])
    expected_result = np.matmul(matrix, rhs)
    result = linalg_impl.tridiagonal_matmul(
        constant_op.constant(matrix, dtype=dtypes.complex128),
        constant_op.constant(rhs, dtype=dtypes.complex128),
        diagonals_format='matrix')

    with self.cached_session():
      result = self.evaluate(result)

    self.assertAllClose(result, expected_result)

  def testGradientSmall(self):
    self._gradientTest([[[1, 2, 0], [1, 2, 3], [0, 1, 2]]],
                       [[[1, 2], [3, 4], [5, 6]]],
                       dtype=dtypes.float64)

  def testGradientComplexSmall(self):
    self._gradientTest(
        np.array([[[1 + 1j, 2j, 0], [1 + 2j, 2j, 3 + 0j], [0, 1j, 2 + 0j]]]),
        np.array([[[1j, 2 + 0j], [3 + 1j, 4j], [5j, 6 + 3j]]]),
        dtype=dtypes.complex128)

  def testGradientComplexWithBatches(self):
    b = 5
    m = 10
    n = 15
    # Shared generator so that diags and rhs do not share the same values.
    rng = np.random.RandomState(43)
    diags = self._randomComplexArray((b, 3, m), rng)
    rhs = self._randomComplexArray((b, m, n), rng)
    self._gradientTest(diags, rhs, dtype=dtypes.complex128)

  def _testErrorWithShapesEager(self, exception_regex, superdiag_shape,
                                maindiag_shape, subdiag_shape, rhs_shape):
    """Asserts the raw op rejects the given input shapes in eager mode.

    Args:
      exception_regex: Regex the InvalidArgumentError message must match.
      superdiag_shape: Shape of the all-ones superdiagonal input.
      maindiag_shape: Shape of the all-ones main-diagonal input.
      subdiag_shape: Shape of the all-ones subdiagonal input.
      rhs_shape: Shape of the all-ones right-hand side.
    """
    with context.eager_mode():
      superdiag = array_ops.ones(superdiag_shape)
      maindiag = array_ops.ones(maindiag_shape)
      subdiag = array_ops.ones(subdiag_shape)
      rhs = array_ops.ones(rhs_shape)
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  exception_regex):
        linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs)

  def testInvalidShapesEagerGpu(self):
    if test.is_built_with_rocm():
      self.skipTest('Incorrect Regex on rocm')
    if not test.is_gpu_available():
      self.skipTest('Test requires GPU')
    self._testErrorWithShapesEager('Input must have rank >= 2, but got ',
                                   [2], [2], [2], [2])
    self._testErrorWithShapesEager(
        'superdiag must have same rank as rhs, but got 3 and 2',
        [2, 1, 2], [2, 1], [2, 1], [2, 2])
    self._testErrorWithShapesEager(
        'maindiag must have same outer dimensions as rhs, but for index 0, got '
        '3 and 2',
        [2, 1, 2], [3, 1, 2], [2, 1, 2], [2, 2, 2])
    self._testErrorWithShapesEager(
        "subdiag's second-to-last dimension must be 1, but got 3",
        [2, 1, 2], [2, 1, 2], [2, 3, 2], [2, 2, 2])
    self._testErrorWithShapesEager(
        "subdiag's last dimension size must be rhs's second-to-last dimension "
        "size, but got 3 and 2",
        [2, 1, 2], [2, 1, 2], [2, 1, 3], [2, 2, 2])


class TridiagonalMatMulBenchmark(test.Benchmark):
  """Benchmarks tridiagonal_matmul against an elementwise baseline."""

  # Each entry is unpacked as (m, batch_size, n) in benchmarkTridiagonalMulOp.
  sizes = [(100000, 1, 1), (1000000, 1, 1), (10000000, 1, 1), (100000, 10, 1),
           (100000, 100, 1), (10000, 1, 100), (10000, 1, 1000),
           (10000, 1, 10000)]

  def baseline(self, upper, diag, lower, vec):
    """Computes the product from 'sequence'-layout diagonals via pad + multiply.

    Assumes `upper`, `diag` and `lower` are [batch, m] and `vec` is
    [batch, m, n], as produced by `_generateData`. The last entry of `upper`
    and the first entry of `lower` are ignored.
    """
    diag_part = array_ops.expand_dims(diag, -1) * vec
    # Row i (i >= 1) receives lower[i] * vec[i - 1].
    lower_part = array_ops.pad(
        array_ops.expand_dims(lower[:, 1:], -1) * vec[:, :-1, :],
        [[0, 0], [1, 0], [0, 0]])
    # Row i (i < m - 1) receives upper[i] * vec[i + 1].
    upper_part = array_ops.pad(
        array_ops.expand_dims(upper[:, :-1], -1) * vec[:, 1:, :],
        [[0, 0], [0, 1], [0, 0]])
    return lower_part + diag_part + upper_part

  def _generateData(self, batch_size, m, n, seed=42):
    """Returns float64 Variables (upper, diag, lower, vec) of normal samples.

    upper, diag and lower have shape [batch_size, m]; vec has shape
    [batch_size, m, n].
    """
    np.random.seed(seed)
    data = np.random.normal(size=(batch_size, m, 3 + n))
    return (variables.Variable(data[:, :, 0], dtype=dtypes.float64),
            variables.Variable(data[:, :, 1], dtype=dtypes.float64),
            variables.Variable(data[:, :, 2], dtype=dtypes.float64),
            variables.Variable(data[:, :, 3:], dtype=dtypes.float64))

  def benchmarkTridiagonalMulOp(self):
    devices = [('/cpu:0', 'cpu')]
    if test.is_gpu_available(cuda_only=True):
      devices += [('/gpu:0', 'gpu')]

    for device_option, size_option in itertools.product(devices, self.sizes):
      device_id, device_name = device_option
      m, batch_size, n = size_option

      with ops.Graph().as_default(), \
          session.Session(config=benchmark.benchmark_config()) as sess, \
          ops.device(device_id):
        upper, diag, lower, vec = self._generateData(batch_size, m, n)
        x1 = self.baseline(upper, diag, lower, vec)
        x2 = linalg_impl.tridiagonal_matmul((upper, diag, lower),
                                            vec,
                                            diagonals_format='sequence')

        # Initialize through the session we hold rather than relying on a
        # default-session lookup.
        sess.run(variables.global_variables_initializer())
        self.run_op_benchmark(
            sess,
            control_flow_ops.group(x1),
            min_iters=10,
            store_memory_usage=False,
            name=('tridiagonal_matmul_baseline_%s'
                  '_batch_size_%d_m_%d_n_%d' %
                  (device_name, batch_size, m, n)))

        self.run_op_benchmark(
            sess,
            control_flow_ops.group(x2),
            min_iters=10,
            store_memory_usage=False,
            name=('tridiagonal_matmul_%s_batch_size_%d_m_%d_n_%d' %
                  (device_name, batch_size, m, n)))


if __name__ == '__main__':
  test.main()