• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for tensorflow.ops.linalg.linalg_impl.tridiagonal_matmul."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.client import session
22from tensorflow.python.eager import context
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import gradient_checker_v2
30from tensorflow.python.ops import linalg_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.ops.linalg import linalg_impl
34from tensorflow.python.platform import benchmark
35from tensorflow.python.platform import test
36
37
class TridiagonalMulOpTest(test.TestCase):
  """Tests for `linalg_impl.tridiagonal_matmul` and its gradient."""

  def _testAllFormats(self,
                      superdiag,
                      maindiag,
                      subdiag,
                      rhs,
                      expected,
                      dtype=dtypes.float64):
    """Checks tridiagonal_matmul against `expected` in every diagonals format.

    The same tridiagonal matrix is fed in 'sequence', 'compact' and 'matrix'
    formats. It is tested unbatched and as a batch of two. In the batch, the
    second element has both the diagonals and the right-hand side doubled, so
    its expected result is `4 * expected`.

    Args:
      superdiag: Superdiagonal values (one shorter than `maindiag`).
      maindiag: Main diagonal values.
      subdiag: Subdiagonal values (one shorter than `maindiag`).
      rhs: Right-hand side matrix.
      expected: Expected product of the tridiagonal matrix and `rhs`.
      dtype: TensorFlow dtype used for all inputs.
    """
    # 'sequence' and 'compact' formats carry all three diagonals at full
    # length: the superdiagonal is zero-padded at the end and the subdiagonal
    # at the start.
    superdiag_extended = np.pad(superdiag, [0, 1], 'constant')
    subdiag_extended = np.pad(subdiag, [1, 0], 'constant')
    diags_compact = np.stack([superdiag_extended, maindiag, subdiag_extended])
    diags_matrix = (np.diag(superdiag, 1) + np.diag(maindiag, 0) +
                    np.diag(subdiag, -1))

    diags_sequence = (constant_op.constant(superdiag_extended, dtype),
                      constant_op.constant(maindiag, dtype),
                      constant_op.constant(subdiag_extended, dtype))
    diags_compact = constant_op.constant(diags_compact, dtype)
    diags_matrix = constant_op.constant(diags_matrix, dtype)
    rhs = constant_op.constant(rhs, dtype)

    rhs_batch = array_ops.stack([rhs, 2 * rhs])
    diags_compact_batch = array_ops.stack([diags_compact, 2 * diags_compact])
    diags_matrix_batch = array_ops.stack([diags_matrix, 2 * diags_matrix])
    diags_sequence_batch = [array_ops.stack([x, 2 * x]) for x in diags_sequence]

    results = [
        linalg_impl.tridiagonal_matmul(
            diags_sequence, rhs, diagonals_format='sequence'),
        linalg_impl.tridiagonal_matmul(
            diags_compact, rhs, diagonals_format='compact'),
        linalg_impl.tridiagonal_matmul(
            diags_matrix, rhs, diagonals_format='matrix')
    ]
    results_batch = [
        linalg_impl.tridiagonal_matmul(
            diags_sequence_batch, rhs_batch, diagonals_format='sequence'),
        linalg_impl.tridiagonal_matmul(
            diags_compact_batch, rhs_batch, diagonals_format='compact'),
        linalg_impl.tridiagonal_matmul(
            diags_matrix_batch, rhs_batch, diagonals_format='matrix')
    ]

    with self.cached_session():
      results = self.evaluate(results)
      results_batch = self.evaluate(results_batch)

    expected = np.array(expected)
    # Doubling both the diagonals and the rhs scales the product by 4.
    expected_batch = np.stack([expected, 4 * expected])
    for result in results:
      self.assertAllClose(result, expected)
    for result in results_batch:
      self.assertAllClose(result, expected_batch)

  def _makeTridiagonalMatrix(self, superdiag, maindiag, subdiag):
    """Builds dense tridiagonal matrices from batched diagonals.

    The padding specs below have three entries, so the diagonals must have
    exactly one batch dimension: [batch, m - 1], [batch, m] and
    [batch, m - 1] respectively.
    """
    # Shift the superdiagonal up/right and the subdiagonal down/left by one.
    super_pad = [[0, 0], [0, 1], [1, 0]]
    sub_pad = [[0, 0], [1, 0], [0, 1]]

    super_part = array_ops.pad(array_ops.matrix_diag(superdiag), super_pad)
    main_part = array_ops.matrix_diag(maindiag)
    sub_part = array_ops.pad(array_ops.matrix_diag(subdiag), sub_pad)
    return super_part + main_part + sub_part

  def _randomComplexArray(self, shape, rng=None):
    """Returns a complex array with real/imag parts drawn uniformly in [-10, 10).

    Args:
      shape: Shape of the returned array.
      rng: Optional `np.random.RandomState` to draw from. Pass the same
        generator to successive calls to get independent arrays. If omitted,
        a fresh generator seeded with 43 is used, so the result depends only
        on `shape`. The global NumPy RNG state is never modified.
    """
    if rng is None:
      rng = np.random.RandomState(43)
    real = rng.uniform(-10, 10, shape)
    imag = rng.uniform(-10, 10, shape)
    return real + imag * 1j

  def _gradientTest(self, diags, rhs, dtype=dtypes.float64):
    """Checks the gradient of tridiagonal_matmul.

    The theoretical gradient is compared with the numerical gradient, and with
    the theoretical gradient of a dense-matmul reference. `diags` is passed to
    tridiagonal_matmul without a `diagonals_format`. The reference reads it as
    [..., 3, m], with superdiagonal, main diagonal and subdiagonal rows.
    """

    def reference_matmul(diags, rhs):
      # Drop the padding entries: last of the superdiagonal, first of the
      # subdiagonal.
      matrix = self._makeTridiagonalMatrix(diags[..., 0, :-1], diags[..., 1, :],
                                           diags[..., 2, 1:])
      return math_ops.matmul(matrix, rhs)

    diags = constant_op.constant(diags, dtype=dtype)
    rhs = constant_op.constant(rhs, dtype=dtype)
    with self.cached_session():
      grad_reference, _ = gradient_checker_v2.compute_gradient(
          reference_matmul, [diags, rhs])
      grad_theoretical, grad_numerical = gradient_checker_v2.compute_gradient(
          linalg_impl.tridiagonal_matmul, [diags, rhs])
    self.assertAllClose(grad_theoretical, grad_numerical)
    self.assertAllClose(grad_theoretical, grad_reference)

  def test2x2(self):
    self._testAllFormats([1], [2, 3], [4], [[2, 1], [4, 3]], [[8, 5], [20, 13]])

  def test3x3(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      self._testAllFormats([1, 2], [1, 2, 1], [2, 1], [[1, 1], [2, 2], [3, 3]],
                           [[3, 3], [12, 12], [5, 5]],
                           dtype=dtype)

  def testComplex(self):
    for dtype in [dtypes.complex64, dtypes.complex128]:
      self._testAllFormats([1j, 1j], [1, -1, 0], [1j, 1j],
                           np.array([[1, 1j], [1, 1j], [1, 1j]]),
                           [[1 + 1j, -1 + 1j], [-1 + 2j, -2 - 1j], [1j, -1]],
                           dtype=dtype)

  def testBatch(self):
    b = 20
    m = 10
    n = 15
    # Draw every input from one generator so the diagonals and rhs are
    # independent. Reseeding per call would make the super- and subdiagonals
    # identical, which would hide transposition bugs.
    rng = np.random.RandomState(43)
    superdiag = self._randomComplexArray((b, m - 1), rng)
    maindiag = self._randomComplexArray((b, m), rng)
    subdiag = self._randomComplexArray((b, m - 1), rng)
    rhs = self._randomComplexArray((b, m, n), rng)
    matrix = np.stack([
        np.diag(sup, 1) + np.diag(main, 0) + np.diag(sub, -1)
        for sup, main, sub in zip(superdiag, maindiag, subdiag)
    ])
    expected_result = np.matmul(matrix, rhs)
    result = linalg_impl.tridiagonal_matmul(
        constant_op.constant(matrix, dtype=dtypes.complex128),
        constant_op.constant(rhs, dtype=dtypes.complex128),
        diagonals_format='matrix')

    with self.cached_session():
      result = self.evaluate(result)

    self.assertAllClose(result, expected_result)

  def testGradientSmall(self):
    self._gradientTest([[[1, 2, 0], [1, 2, 3], [0, 1, 2]]],
                       [[[1, 2], [3, 4], [5, 6]]],
                       dtype=dtypes.float64)

  def testGradientComplexSmall(self):
    self._gradientTest(
        np.array([[[1 + 1j, 2j, 0], [1 + 2j, 2j, 3 + 0j], [0, 1j, 2 + 0j]]]),
        np.array([[[1j, 2 + 0j], [3 + 1j, 4j], [5j, 6 + 3j]]]),
        dtype=dtypes.complex128)

  def testGradientComplexWithBatches(self):
    b = 5
    m = 10
    n = 15
    # A shared generator keeps diags and rhs independent of each other.
    rng = np.random.RandomState(43)
    diags = self._randomComplexArray((b, 3, m), rng)
    rhs = self._randomComplexArray((b, m, n), rng)
    self._gradientTest(diags, rhs, dtype=dtypes.complex128)

  def _testErrorWithShapesEager(self, exception_regex, superdiag_shape,
                                maindiag_shape, subdiag_shape, rhs_shape):
    """Asserts the raw op rejects the given shapes.

    In eager mode, ones-tensors with the given shapes are passed to
    `linalg_ops.tridiagonal_mat_mul`. The call must raise InvalidArgumentError
    with a message matching `exception_regex`.
    """
    with context.eager_mode():
      superdiag = array_ops.ones(superdiag_shape)
      maindiag = array_ops.ones(maindiag_shape)
      subdiag = array_ops.ones(subdiag_shape)
      rhs = array_ops.ones(rhs_shape)
      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  exception_regex):
        linalg_ops.tridiagonal_mat_mul(superdiag, maindiag, subdiag, rhs)

  def testInvalidShapesEagerGpu(self):
    if test.is_built_with_rocm():
      self.skipTest('Incorrect Regex on rocm')
    if not test.is_gpu_available():
      self.skipTest('Test requires GPU')
    self._testErrorWithShapesEager('Input must have rank >= 2, but got ',
                                   [2], [2], [2], [2])
    self._testErrorWithShapesEager(
        'superdiag must have same rank as rhs, but got 3 and 2',
        [2, 1, 2], [2, 1], [2, 1], [2, 2])
    self._testErrorWithShapesEager(
        'maindiag must have same outer dimensions as rhs, but for index 0, got '
        '3 and 2',
        [2, 1, 2], [3, 1, 2], [2, 1, 2], [2, 2, 2])
    self._testErrorWithShapesEager(
        "subdiag's second-to-last dimension must be 1, but got 3",
        [2, 1, 2], [2, 1, 2], [2, 3, 2], [2, 2, 2])
    self._testErrorWithShapesEager(
        "subdiag's last dimension size must be rhs's second-to-last dimension "
        "size, but got 3 and 2",
        [2, 1, 2], [2, 1, 2], [2, 1, 3], [2, 2, 2])

  # Benchmark
  class TridiagonalMatMulBenchmark(test.Benchmark):
    """Compares tridiagonal_matmul against an elementwise multiply-and-pad baseline."""

    # Each entry is unpacked as (m, batch_size, n) in benchmarkTridiagonalMulOp.
    sizes = [(100000, 1, 1), (1000000, 1, 1), (10000000, 1, 1), (100000, 10, 1),
             (100000, 100, 1), (10000, 1, 100), (10000, 1, 1000),
             (10000, 1, 10000)]

    def baseline(self, upper, diag, lower, vec):
      """Computes the tridiagonal product with elementwise ops.

      The diagonals follow the padded 'sequence' layout: `upper` ignores its
      last entry and `lower` ignores its first.
      """
      diag_part = array_ops.expand_dims(diag, -1) * vec
      # Row i gets lower[i] * vec[i - 1]; row 0 has no subdiagonal term.
      lower_part = array_ops.pad(
          array_ops.expand_dims(lower[:, 1:], -1) * vec[:, :-1, :],
          [[0, 0], [1, 0], [0, 0]])
      # Row i gets upper[i] * vec[i + 1]; the last row has no superdiagonal term.
      upper_part = array_ops.pad(
          array_ops.expand_dims(upper[:, :-1], -1) * vec[:, 1:, :],
          [[0, 0], [0, 1], [0, 0]])
      return lower_part + diag_part + upper_part

    def _generateData(self, batch_size, m, n, seed=42):
      """Returns float64 Variables (upper, diag, lower, vec) of normal samples.

      The diagonals have shape [batch_size, m] and `vec` has shape
      [batch_size, m, n]. A local RandomState is used so the global NumPy RNG
      is left untouched; it yields the same samples as seeding the global RNG.
      """
      rng = np.random.RandomState(seed)
      data = rng.normal(size=(batch_size, m, 3 + n))
      return (variables.Variable(data[:, :, 0], dtype=dtypes.float64),
              variables.Variable(data[:, :, 1], dtype=dtypes.float64),
              variables.Variable(data[:, :, 2], dtype=dtypes.float64),
              variables.Variable(data[:, :, 3:], dtype=dtypes.float64))

    def benchmarkTridiagonalMulOp(self):
      devices = [('/cpu:0', 'cpu')]
      if test.is_gpu_available(cuda_only=True):
        devices += [('/gpu:0', 'gpu')]

      for device_option, size_option in itertools.product(devices, self.sizes):
        device_id, device_name = device_option
        m, batch_size, n = size_option

        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device(device_id):
          upper, diag, lower, vec = self._generateData(batch_size, m, n)
          x1 = self.baseline(upper, diag, lower, vec)
          x2 = linalg_impl.tridiagonal_matmul((upper, diag, lower),
                                              vec,
                                              diagonals_format='sequence')

          self.evaluate(variables.global_variables_initializer())
          self.run_op_benchmark(
              sess,
              control_flow_ops.group(x1),
              min_iters=10,
              store_memory_usage=False,
              name=('tridiagonal_matmul_baseline_%s'
                    '_batch_size_%d_m_%d_n_%d' %
                    (device_name, batch_size, m, n)))

          self.run_op_benchmark(
              sess,
              control_flow_ops.group(x2),
              min_iters=10,
              store_memory_usage=False,
              name=('tridiagonal_matmul_%s_batch_size_%d_m_%d_n_%d' %
                    (device_name, batch_size, m, n)))
273
274
if __name__ == '__main__':
  # Entry point: hand control to the TensorFlow test runner.
  test.main()