• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for tensorflow.python.ops.linalg_ops.matrix_solve_ls."""
16
17import numpy as np
18
19from tensorflow.python.client import session
20from tensorflow.python.eager import context
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors_impl
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import test_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import linalg_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import benchmark
32from tensorflow.python.platform import test as test_lib
33
34
def _AddTest(test, op_name, testcase_name, fn):
  """Attaches `fn` to class `test` as `test_<op_name>_<testcase_name>`.

  Args:
    test: The class that receives the new attribute.
    op_name: Middle component of the attribute name.
    testcase_name: Final component of the attribute name.
    fn: The function to attach.

  Raises:
    RuntimeError: If `test` already has an attribute with that name.
  """
  attr_name = "_".join(("test", op_name, testcase_name))
  if not hasattr(test, attr_name):
    setattr(test, attr_name, fn)
    return
  raise RuntimeError("Test %s defined more than once" % attr_name)
40
41
def _GenerateTestData(matrix_shape, num_rhs):
  """Builds non-trainable Variables holding a batch of identical systems.

  Args:
    matrix_shape: Tuple `batch_shape + (rows, cols)`.
    num_rhs: Number of right-hand-side columns.

  Returns:
    A `(matrix, rhs)` pair of Variables. One float32 matrix (uniform in
    [-1, 1], drawn right after reseeding NumPy's global RNG with 1) and one
    all-ones float32 rhs of shape (rows, num_rhs) are tiled over
    `batch_shape`.
  """
  batch_shape = matrix_shape[:-2]
  core_shape = matrix_shape[-2:]
  num_rows = core_shape[-2]
  # Reseed so that the generated data is deterministic.
  np.random.seed(1)
  draws = np.random.uniform(low=-1.0, high=1.0, size=np.prod(core_shape))
  single_matrix = draws.reshape(core_shape).astype(np.float32)
  single_rhs = np.ones([num_rows, num_rhs]).astype(np.float32)
  reps = batch_shape + (1, 1)
  matrix = variables.Variable(np.tile(single_matrix, reps), trainable=False)
  rhs = variables.Variable(np.tile(single_rhs, reps), trainable=False)
  return matrix, rhs
55
56
def _SolveWithNumpy(matrix, rhs, l2_regularizer=0):
  """Reference (optionally L2-regularized) least-squares solve in NumPy.

  Args:
    matrix: 2-D array A of shape [rows, cols].
    rhs: 2-D array B with `rows` rows.
    l2_regularizer: Scalar lambda; 0 selects plain `np.linalg.lstsq`.

  Returns:
    If lambda == 0, the `np.linalg.lstsq` solution. Otherwise
    (A^H A + lambda I)^{-1} A^H B when rows >= cols, or
    A^H (A A^H + lambda I)^{-1} B when rows < cols.
  """
  if l2_regularizer == 0:
    solution = np.linalg.lstsq(matrix, rhs)[0]
    return solution
  num_rows, num_cols = matrix.shape[-2], matrix.shape[-1]
  matrix_h = np.conj(matrix.T)
  if num_rows < num_cols:
    # Underdetermined: regularize the smaller (rows x rows) Gram matrix.
    shift = l2_regularizer * np.identity(num_rows)
    z = np.linalg.solve(np.dot(matrix, matrix_h) + shift, rhs)
    return np.dot(matrix_h, z)
  # Square or overdetermined: regularized normal equations.
  shift = l2_regularizer * np.identity(num_cols)
  return np.linalg.solve(np.dot(matrix_h, matrix) + shift,
                         np.dot(matrix_h, rhs))
74
75
@test_util.with_eager_op_as_function
class MatrixSolveLsOpTest(test_lib.TestCase):
  """Tests for `linalg_ops.matrix_solve_ls`.

  Parameterized `test_MatrixSolveLsOpTest_*` methods are attached to this
  class by `_AddTest` when the module is run as a script.
  """

  def _verifySolve(self,
                   x,
                   y,
                   dtype,
                   use_placeholder,
                   fast,
                   l2_regularizer,
                   batch_shape=()):
    """Checks `matrix_solve_ls` on one (optionally batched) system vs NumPy.

    Args:
      x: 2-D NumPy array, the coefficient matrix of a single system.
      y: 2-D NumPy array, right-hand sides with as many rows as `x`.
      dtype: NumPy dtype the inputs are cast to. For complex dtypes the
        imaginary part is set equal to the real part.
      use_placeholder: If True, feed the inputs through placeholders. These
        cases are skipped when executing eagerly.
      fast: Forwarded to `matrix_solve_ls`.
      l2_regularizer: Forwarded to `matrix_solve_ls`. Nonzero values are
        skipped when `fast` is False.
      batch_shape: Leading batch dimensions; the single system is tiled to
        this shape before solving.
    """
    if not fast and l2_regularizer != 0:
      # The slow path does not support regularization.
      return
    if use_placeholder and context.executing_eagerly():
      # Placeholder feeding is only exercised in graph mode.
      return
    maxdim = np.max(x.shape)
    if dtype == np.float32 or dtype == np.complex64:
      tol = maxdim * 5e-4
    else:
      tol = maxdim * 5e-7
    # Everything below applies to every dtype; only the tolerance above
    # depends on precision.
    a = x.astype(dtype)
    b = y.astype(dtype)
    if dtype in [np.complex64, np.complex128]:
      # Use a = (1+i)x and b = (1+i)y so the imaginary parts are nonzero.
      a.imag = a.real
      b.imag = b.real
    # numpy.linalg.lstsq does not support batching, so solve a single system
    # and replicate the solution and residual norm. The reference is built
    # from `a`/`b` (the exact inputs given to TensorFlow): for complex dtypes
    # the regularized solution for (1+i)x, (1+i)y differs from that for x, y.
    np_ans = _SolveWithNumpy(a, b, l2_regularizer=l2_regularizer)
    np_r = np.dot(np.conj(a.T), b - np.dot(a, np_ans))
    np_r_norm = np.sqrt(np.sum(np.conj(np_r) * np_r))
    if batch_shape != ():
      a = np.tile(a, batch_shape + (1, 1))
      b = np.tile(b, batch_shape + (1, 1))
      np_ans = np.tile(np_ans, batch_shape + (1, 1))
      np_r_norm = np.tile(np_r_norm, batch_shape)
    if use_placeholder:
      a_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
      b_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
      feed_dict = {a_ph: a, b_ph: b}
      tf_ans = linalg_ops.matrix_solve_ls(
          a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer)
    else:
      tf_ans = linalg_ops.matrix_solve_ls(
          a, b, fast=fast, l2_regularizer=l2_regularizer)
      feed_dict = None
      # With concrete inputs the static shape should already be known.
      self.assertEqual(np_ans.shape, tf_ans.get_shape())
    if feed_dict:
      with self.session() as sess:
        tf_ans_val = sess.run(tf_ans, feed_dict=feed_dict)
    else:
      tf_ans_val = self.evaluate(tf_ans)
    self.assertEqual(np_ans.shape, tf_ans_val.shape)
    self.assertAllClose(np_ans, tf_ans_val, atol=2 * tol, rtol=2 * tol)

    if l2_regularizer == 0:
      # The least squares solution should satisfy A^H * (b - A*x) = 0.
      tf_r = b - math_ops.matmul(a, tf_ans)
      tf_r = math_ops.matmul(a, tf_r, adjoint_a=True)
      tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1])
      if feed_dict:
        with self.session() as sess:
          tf_r_norm_val = sess.run(tf_r_norm, feed_dict=feed_dict)
      else:
        tf_r_norm_val = self.evaluate(tf_r_norm)
      self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testWrongDimensions(self):
    # The matrix and right-hand sides should have the same number of rows.
    with self.session():
      matrix = constant_op.constant([[1., 0.], [0., 1.]])
      rhs = constant_op.constant([[1., 0.]])
      with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
        linalg_ops.matrix_solve_ls(matrix, rhs)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testEmpty(self):
    # Empty inputs still yield an answer of shape (cols, num_rhs).
    full = np.array([[1., 2.], [3., 4.], [5., 6.]])
    empty0 = np.empty([3, 0])
    empty1 = np.empty([0, 2])
    for fast in [True, False]:
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty0, empty0, fast=fast))
      self.assertEqual(tf_ans.shape, (0, 0))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty0, full, fast=fast))
      self.assertEqual(tf_ans.shape, (0, 2))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(full, empty0, fast=fast))
      self.assertEqual(tf_ans.shape, (2, 0))
      tf_ans = self.evaluate(
          linalg_ops.matrix_solve_ls(empty1, empty1, fast=fast))
      self.assertEqual(tf_ans.shape, (2, 2))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testBatchResultSize(self):
    # 3x3x3 matrices, 3x3x1 right-hand sides.
    matrix = np.array([1., 0., 0., 0., 1., 0., 0., 0., 1.] * 3).reshape(3, 3, 3)  # pylint: disable=too-many-function-args
    rhs = np.array([1., 2., 3.] * 3).reshape(3, 3, 1)  # pylint: disable=too-many-function-args
    answer = linalg_ops.matrix_solve(matrix, rhs)
    ls_answer = linalg_ops.matrix_solve_ls(matrix, rhs)
    self.assertEqual(ls_answer.get_shape(), [3, 3, 1])
    self.assertEqual(answer.get_shape(), [3, 3, 1])
181
182
def _GetSmallMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
  """Builds the small square / overdetermined / underdetermined test methods.

  Args:
    dtype: NumPy dtype forwarded to `_verifySolve`.
    use_placeholder: Forwarded to `_verifySolve`.
    fast: Forwarded to `_verifySolve`.
    l2_regularizer: Forwarded to `_verifySolve`.

  Returns:
    A tuple `(Square, Overdetermined, Underdetermined)` of functions taking a
    test-case instance. Their `__name__`s are used to build test names.
  """

  def _VerifyUnbatchedAndBatched(test, matrix, rhs):
    # Check the single system, then a (2, 3) batch of copies of it.
    for batch_shape in (), (2, 3):
      test._verifySolve(  # pylint: disable=protected-access
          matrix,
          rhs,
          dtype,
          use_placeholder,
          fast,
          l2_regularizer,
          batch_shape=batch_shape)

  def Square(self):
    # 2x2 matrices, 2x3 right-hand sides.
    matrix = np.array([[1., 2.], [3., 4.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    _VerifyUnbatchedAndBatched(self, matrix, rhs)

  def Overdetermined(self):
    # 3x2 matrices, 3x3 right-hand sides.
    matrix = np.array([[1., 2.], [3., 4.], [5., 6.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]])
    _VerifyUnbatchedAndBatched(self, matrix, rhs)

  def Underdetermined(self):
    # 2x3 matrices, 2x3 right-hand sides.
    matrix = np.array([[1., 2., 3.], [4., 5., 6.]])
    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
    _VerifyUnbatchedAndBatched(self, matrix, rhs)

  return (Square, Overdetermined, Underdetermined)
228
229
def _GetLargeMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
  """Builds batched test methods for large square, tall and wide systems.

  Args:
    dtype: NumPy dtype forwarded to `_verifySolve`.
    use_placeholder: Forwarded to `_verifySolve`.
    fast: Forwarded to `_verifySolve`.
    l2_regularizer: Forwarded to `_verifySolve`.

  Returns:
    A tuple `(LargeBatchSquare, LargeBatchOverdetermined,
    LargeBatchUnderdetermined)` of functions taking a test-case instance.
  """

  def _CheckRandomSystem(test, matrix_shape):
    # Reseed NumPy's global RNG so each case draws a deterministic matrix;
    # the single system is verified with a (16, 8) batch of copies.
    np.random.seed(1)
    num_rhs = 1
    flat = np.random.uniform(low=-1.0, high=1.0, size=np.prod(matrix_shape))
    matrix = flat.reshape(matrix_shape).astype(np.float32)
    rhs = np.ones([matrix_shape[0], num_rhs]).astype(np.float32)
    test._verifySolve(  # pylint: disable=protected-access
        matrix,
        rhs,
        dtype,
        use_placeholder,
        fast,
        l2_regularizer,
        batch_shape=(16, 8))

  def LargeBatchSquare(self):
    _CheckRandomSystem(self, (127, 127))

  def LargeBatchOverdetermined(self):
    _CheckRandomSystem(self, (127, 64))

  def LargeBatchUnderdetermined(self):
    _CheckRandomSystem(self, (64, 127))

  return (LargeBatchSquare, LargeBatchOverdetermined, LargeBatchUnderdetermined)
284
285
class MatrixSolveLsBenchmark(test_lib.Benchmark):
  """Benchmarks `linalg_ops.matrix_solve_ls` on CPU and, if present, GPU."""

  matrix_shapes = [
      (4, 4),
      (8, 4),
      (4, 8),
      (10, 10),
      (10, 8),
      (8, 10),
      (16, 16),
      (16, 10),
      (10, 16),
      (101, 101),
      (101, 31),
      (31, 101),
      (256, 256),
      (256, 200),
      (200, 256),
      (1001, 1001),
      (1001, 501),
      (501, 1001),
      (1024, 1024),
      (1024, 128),
      (128, 1024),
      (2048, 2048),
      (2048, 64),
      (64, 2048),
      (513, 4, 4),
      (513, 4, 2),
      (513, 2, 4),
      (513, 16, 16),
      (513, 16, 10),
      (513, 10, 16),
      (513, 256, 256),
      (513, 256, 128),
      (513, 128, 256),
  ]

  def _RunSolveBenchmark(self, device, name, matrix_shape, num_rhs,
                         regularizer):
    """Times one `matrix_solve_ls` call on `device` in a fresh graph/session."""
    with ops.Graph().as_default(), \
        session.Session(config=benchmark.benchmark_config()) as sess, \
        ops.device(device):
      matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
      x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
      self.evaluate(variables.global_variables_initializer())
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(x),
          min_iters=25,
          store_memory_usage=False,
          name=name)

  def benchmarkMatrixSolveLsOp(self):
    run_gpu_test = test_lib.is_gpu_available(True)
    regularizer = 1.0
    for matrix_shape in self.matrix_shapes:
      # When the last dimension is 1 or 2 it repeats an earlier value (e.g.
      # (513, 4, 2)); dict.fromkeys drops such duplicates while keeping
      # order, so each benchmark name is run and reported only once.
      for num_rhs in dict.fromkeys((1, 2, matrix_shape[-1])):
        self._RunSolveBenchmark(
            "/cpu:0",
            ("matrix_solve_ls_cpu_shape_{matrix_shape}_num_rhs_{num_rhs}"
            ).format(matrix_shape=matrix_shape, num_rhs=num_rhs),
            matrix_shape, num_rhs, regularizer)

        # GPU runs skip batched (3-D) shapes whose batch dimension is >= 513.
        if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
          self._RunSolveBenchmark(
              "/gpu:0",
              ("matrix_solve_ls_gpu_shape_{matrix_shape}_num_rhs_"
               "{num_rhs}").format(
                   matrix_shape=matrix_shape, num_rhs=num_rhs),
              matrix_shape, num_rhs, regularizer)
359
360
if __name__ == "__main__":
  dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
  # Small systems: every combination of dtype, placeholder feeding,
  # fast/slow path and regularization (complex128 is run unregularized only).
  for dtype_ in dtypes_to_test:
    l2_regularizers = [0] if dtype_ == np.complex128 else [0, 0.1]
    for use_placeholder_ in (False, True):
      for fast_ in (True, False):
        for l2_regularizer_ in l2_regularizers:
          small_cases = _GetSmallMatrixSolveLsOpTests(
              dtype_, use_placeholder_, fast_, l2_regularizer_)
          for test_case in small_cases:
            name = "%s_%s_placeholder_%s_fast_%s_regu_%s" % (
                test_case.__name__, dtype_.__name__, use_placeholder_, fast_,
                l2_regularizer_)
            _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name,
                     test_case)
  # Large batched systems: fast path only, no placeholders, no regularizer.
  for dtype_ in dtypes_to_test:
    large_cases = _GetLargeMatrixSolveLsOpTests(dtype_, False, True, 0.0)
    for test_case in large_cases:
      name = "%s_%s" % (test_case.__name__, dtype_.__name__)
      _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name, test_case)

  test_lib.main()
383