# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.Lu."""

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test


@test_util.with_eager_op_as_function
class LuOpTest(test.TestCase):
  """Tests for the pivoted LU factorization op (`tf.linalg.lu`)."""

  @property
  def float_types(self):
    """Floating-point (real and complex) dtypes exercised by the tests."""
    return set((np.float64, np.float32, np.complex64, np.complex128))

  def _verifyLuBase(self, x, lower, upper, perm, verification,
                    output_idx_type):
    """Checks shapes, dtypes, the P.x == L.U reconstruction, and that each
    permutation vector is a valid permutation of 0..n-1.

    Args:
      x: The original numpy input matrix (or batch of matrices).
      lower: Tensor with the unit-lower-triangular factor L.
      upper: Tensor with the upper-triangular factor U.
      perm: Tensor of permutation indices returned by the op.
      verification: Tensor holding the reconstructed input (P^-1 . L . U).
      output_idx_type: Expected TF dtype of the permutation indices.
    """
    lower_np, upper_np, perm_np, verification_np = self.evaluate(
        [lower, upper, perm, verification])

    self.assertAllClose(x, verification_np)
    self.assertShapeEqual(x, lower)
    self.assertShapeEqual(x, upper)

    self.assertAllEqual(x.shape[:-1], perm.shape.as_list())

    # Check dtypes are as expected.
    self.assertEqual(x.dtype, lower_np.dtype)
    self.assertEqual(x.dtype, upper_np.dtype)
    self.assertEqual(output_idx_type.as_numpy_dtype, perm_np.dtype)

    # Check that the permutation is valid: sorted, it must be 0..n-1.
    if perm_np.shape[-1] > 0:
      perm_reshaped = np.reshape(perm_np, (-1, perm_np.shape[-1]))
      for perm_vector in perm_reshaped:
        self.assertAllClose(np.arange(len(perm_vector)), np.sort(perm_vector))

  def _verifyLu(self, x, output_idx_type=dtypes.int64):
    """Runs `lu(x)`, rebuilds L, U and P, and verifies that Px = LU."""
    lu, perm = linalg_ops.lu(x, output_idx_type=output_idx_type)

    # Prepare the lower factor of shape num_rows x num_rows.
    lu_shape = np.array(lu.shape.as_list())
    batch_shape = lu_shape[:-2]
    num_rows = lu_shape[-2]
    num_cols = lu_shape[-1]

    lower = array_ops.matrix_band_part(lu, -1, 0)

    # For non-square inputs pad (or trim) L so it is num_rows x num_rows.
    if num_rows > num_cols:
      eye = linalg_ops.eye(
          num_rows, batch_shape=batch_shape, dtype=lower.dtype)
      lower = array_ops.concat([lower, eye[..., num_cols:]], axis=-1)
    elif num_rows < num_cols:
      lower = lower[..., :num_rows]

    # L is unit-lower-triangular: fill the diagonal with ones.
    ones_diag = array_ops.ones(
        np.append(batch_shape, num_rows), dtype=lower.dtype)
    lower = array_ops.matrix_set_diag(lower, ones_diag)

    # Prepare the upper factor.
    upper = array_ops.matrix_band_part(lu, 0, -1)

    verification = test_util.matmul_without_tf32(lower, upper)

    # Permute the rows of the product of the triangular factors.
    if num_rows > 0:
      # Reshape the product of the triangular factors and permutation indices
      # to a single batch dimension. This makes it easy to apply
      # invert_permutation and gather_nd ops.
      perm_reshaped = array_ops.reshape(perm, [-1, num_rows])
      verification_reshaped = array_ops.reshape(verification,
                                                [-1, num_rows, num_cols])
      # Invert the permutation in each batch.
      inv_perm_reshaped = map_fn.map_fn(array_ops.invert_permutation,
                                        perm_reshaped)
      batch_size = perm_reshaped.shape.as_list()[0]
      # Prepare the batch indices with the same shape as the permutation.
      # The corresponding batch index is paired with each of the `num_rows`
      # permutation indices.
      batch_indices = math_ops.cast(
          array_ops.broadcast_to(
              math_ops.range(batch_size)[:, None], perm_reshaped.shape),
          dtype=output_idx_type)
      # Degenerate (empty) permutations: keep shapes consistent for gather_nd.
      if inv_perm_reshaped.shape == [0]:
        inv_perm_reshaped = array_ops.zeros_like(batch_indices)
      permuted_verification_reshaped = array_ops.gather_nd(
          verification_reshaped,
          array_ops.stack([batch_indices, inv_perm_reshaped], axis=-1))

      # Reshape the verification matrix back to the original shape.
      verification = array_ops.reshape(permuted_verification_reshaped,
                                       lu_shape)

    self._verifyLuBase(x, lower, upper, perm, verification,
                       output_idx_type)

  def testBasic(self):
    """Verifies the factorization for real and complex dense matrices."""
    data = np.array([[4., -1., 2.], [-1., 6., 0], [10., 0., 5.]])

    for dtype in (np.float32, np.float64):
      for output_idx_type in (dtypes.int32, dtypes.int64):
        with self.subTest(dtype=dtype, output_idx_type=output_idx_type):
          self._verifyLu(data.astype(dtype), output_idx_type=output_idx_type)

    for dtype in (np.complex64, np.complex128):
      for output_idx_type in (dtypes.int32, dtypes.int64):
        with self.subTest(dtype=dtype, output_idx_type=output_idx_type):
          # Build a Hermitian-style complex matrix from the real data.
          complex_data = np.tril(1j * data, -1).astype(dtype)
          complex_data += np.triu(-1j * data, 1).astype(dtype)
          complex_data += data
          self._verifyLu(complex_data, output_idx_type=output_idx_type)

  def testPivoting(self):
    """Verifies that partial pivoting actually permutes the rows."""
    # This matrix triggers partial pivoting because the first diagonal entry
    # is small.
    data = np.array([[1e-9, 1., 0.], [1., 0., 0], [0., 1., 5]])
    self._verifyLu(data.astype(np.float32))

    for dtype in (np.float32, np.float64):
      with self.subTest(dtype=dtype):
        self._verifyLu(data.astype(dtype))
        _, p = linalg_ops.lu(data)
        p_val = self.evaluate([p])
        # Make sure p_val is not the identity permutation.
        self.assertNotAllClose(np.arange(3), p_val)

    for dtype in (np.complex64, np.complex128):
      with self.subTest(dtype=dtype):
        complex_data = np.tril(1j * data, -1).astype(dtype)
        complex_data += np.triu(-1j * data, 1).astype(dtype)
        complex_data += data
        self._verifyLu(complex_data)
        # BUG FIX: previously factored the real `data` here, so the complex
        # pivoting path was never exercised by this assertion.
        _, p = linalg_ops.lu(complex_data)
        p_val = self.evaluate([p])
        # Make sure p_val is not the identity permutation.
        self.assertNotAllClose(np.arange(3), p_val)

  def testInvalidMatrix(self):
    """LU factorization gives an error when the input is singular."""
    # Note: A singular matrix may return without error but it won't be a valid
    # factorization.
    for dtype in self.float_types:
      with self.subTest(dtype=dtype):
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(
              linalg_ops.lu(
                  np.array([[1., 2., 3.], [2., 4., 6.], [2., 3., 4.]],
                           dtype=dtype)))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(
              linalg_ops.lu(
                  np.array([[[1., 2., 3.], [2., 4., 6.], [1., 2., 3.]],
                            [[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]],
                           dtype=dtype)))

  def testBatch(self):
    """Verifies batched inputs, including odd sizes and random batches."""
    simple_array = np.array([[[1., -1.], [2., 5.]]])  # shape (1, 2, 2)
    self._verifyLu(simple_array)
    self._verifyLu(np.vstack((simple_array, simple_array)))
    odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
    self._verifyLu(np.vstack((odd_sized_array, odd_sized_array)))

    batch_size = 200

    # Generate random matrices.
    np.random.seed(42)
    matrices = np.random.rand(batch_size, 5, 5)
    self._verifyLu(matrices)

    # Generate random complex valued matrices.
    np.random.seed(52)
    matrices = np.random.rand(batch_size, 5,
                              5) + 1j * np.random.rand(batch_size, 5, 5)
    self._verifyLu(matrices)

  def testLargeMatrix(self):
    """Verifies a single large (500 x 500) real and complex matrix."""
    # Generate random matrices.
    n = 500
    np.random.seed(64)
    data = np.random.rand(n, n)
    self._verifyLu(data)

    # Generate random complex valued matrices.
    np.random.seed(129)
    data = np.random.rand(n, n) + 1j * np.random.rand(n, n)
    self._verifyLu(data)

  @test_util.disable_xla("b/206106619")
  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    """Verifies degenerate shapes: empty batch and 0 x 0 matrices."""
    self._verifyLu(np.empty([0, 2, 2]))
    self._verifyLu(np.empty([2, 0, 0]))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testConcurrentExecutesWithoutError(self):
    """Two concurrent factorizations of the same input must agree."""
    matrix_shape = [5, 5]
    seed = [42, 24]
    matrix1 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    matrix2 = stateless_random_ops.stateless_random_normal(
        shape=matrix_shape, seed=seed)
    self.assertAllEqual(matrix1, matrix2)
    lu1, p1 = linalg_ops.lu(matrix1)
    lu2, p2 = linalg_ops.lu(matrix2)
    lu1_val, p1_val, lu2_val, p2_val = self.evaluate([lu1, p1, lu2, p2])
    self.assertAllEqual(lu1_val, lu2_val)
    self.assertAllEqual(p1_val, p2_val)
236
237
class LuBenchmark(test.Benchmark):
  """Benchmarks the LU op on CPU and, when available, GPU."""

  shapes = [
      (4, 4),
      (10, 10),
      (16, 16),
      (101, 101),
      (256, 256),
      (1000, 1000),
      (1024, 1024),
      (2048, 2048),
      (4096, 4096),
      (513, 2, 2),
      (513, 8, 8),
      (513, 256, 256),
      (4, 513, 2, 2),
  ]

  def _GenerateMatrix(self, shape):
    """Builds a well-conditioned square float32 matrix, tiled over any
    leading batch dimensions of `shape`."""
    batch_shape = shape[:-2]
    matrix_shape = shape[-2:]
    assert matrix_shape[0] == matrix_shape[1]
    n = matrix_shape[0]
    # Identity plus a small constant everywhere keeps the matrix
    # comfortably non-singular.
    base = np.diag(np.ones(n).astype(np.float32))
    base = base + np.ones(matrix_shape).astype(np.float32) / (2.0 * n)
    return np.tile(base, batch_shape + (1, 1))

  def _BenchmarkOnDevice(self, shape, device, name_template):
    """Builds a fresh graph on `device` and benchmarks one LU factorization."""
    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device(device):
      matrix = variables.Variable(self._GenerateMatrix(shape))
      lu, p = linalg_ops.lu(matrix)
      self.evaluate(variables.global_variables_initializer())
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(lu, p),
          min_iters=25,
          name=name_template.format(shape=shape))

  def benchmarkLuOp(self):
    """Runs every shape on CPU, and on GPU when one is available."""
    for shape in self.shapes:
      self._BenchmarkOnDevice(shape, "/cpu:0", "lu_cpu_{shape}")
      if test.is_gpu_available(True):
        self._BenchmarkOnDevice(shape, "/device:GPU:0", "lu_gpu_{shape}")
290
291
# Standard TensorFlow test runner entry point.
if __name__ == "__main__":
  test.main()
294