# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.linalg_grad."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test as test_lib


def _AddTest(test, op_name, testcase_name, fn):
  test_name = '_'.join(['test', op_name, testcase_name])
  if hasattr(test, test_name):
    raise RuntimeError('Test %s defined more than once' % test_name)
  setattr(test, test_name, fn)


class ShapeTest(test_lib.TestCase):

  @test_util.run_deprecated_v1
  def testBatchGradientUnknownSize(self):
    with self.cached_session():
      batch_size = constant_op.constant(3)
      matrix_size = constant_op.constant(4)
      batch_identity = array_ops.tile(
          array_ops.expand_dims(
              array_ops.diag(array_ops.ones([matrix_size])), 0),
          [batch_size, 1, 1])
      determinants = linalg_ops.matrix_determinant(batch_identity)
      reduced = math_ops.reduce_sum(determinants)
      sum_grad = gradients_impl.gradients(reduced, batch_identity)[0]
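      # d(det(A))/dA = det(A) * inv(A)^T, which equals the identity matrix when
      # A is the identity, so the summed gradient should equal batch_identity.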
      self.assertAllClose(batch_identity, self.evaluate(sum_grad))


class MatrixUnaryFunctorGradientTest(test_lib.TestCase):
  pass  # Filled in below


def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_):

  @test_util.enable_control_flow_v2
  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  @test_util.run_without_tensor_float_32(
      'Tests `tf.linalg.expm`, which calls matmul. Additionally, calls ops '
      'which do matmul in their gradient, such as MatrixSolve.')
  def Test(self):

    def RandomInput():
      np.random.seed(1)
      return np.random.uniform(
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)

    if functor_.__name__ == 'matrix_square_root':
      # Square the input matrix to ensure that its matrix square root exists.
      f = lambda x: functor_(math_ops.matmul(x, x), **kwargs_)
    else:
      f = functor_

    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    epsilon = np.finfo(dtype_).eps
    delta = epsilon**(1.0 / 3.0)
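    # For a central difference (f(x + delta) - f(x - delta)) / (2 * delta), the
    # truncation error is O(delta**2) and the roundoff error is
    # O(epsilon / delta), so the combined error is smallest when delta is on
    # the order of epsilon**(1/3).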
    # tolerance obtained by looking at actual differences using
    # np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
    tol = 1e-6 if dtype_ == np.float64 else 0.05

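    # compute_gradient returns the analytic (backprop) and numeric
    # (finite-difference) Jacobians of f, one entry per input; the test passes
    # when they agree to within the tolerance chosen above.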
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        f, [RandomInput()], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test


class MatrixBinaryFunctorGradientTest(test_lib.TestCase):
  pass  # Filled in below


def _GetMatrixBinaryFunctorGradientTest(functor_,
                                        dtype_,
                                        shape_,
                                        float32_tol_fudge=1.0,
                                        **kwargs_):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  @test_util.run_without_tensor_float_32(
      'Tests `tf.linalg.lstsq`, which calls matmul. Additionally, calls ops '
      'which do matmul in their gradient, such as MatrixSolveLs.')
  # TODO(b/164254522): With TensorFloat-32, some tests fail with extremely high
  # absolute and relative differences when calling assertAllClose. For example,
  # the test test_MatrixSolveLsGradient_float32_10_10_1e-06 of class
  # MatrixBinaryFunctorGradientTest fails with a max absolute difference of
  # 0.883 and a max relative difference of 736892. We should consider disabling
  # TensorFloat-32 within `tf.linalg.lstsq` and perhaps other linear algebra
  # functions, even if TensorFloat-32 is allowed globally.
  def Test(self):

    def RandomInput():
      np.random.seed(1)
      return np.random.uniform(
          low=-1.0, high=1.0,
          size=np.prod(shape_)).reshape(shape_).astype(dtype_)

    fixed = RandomInput()
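    # One operand is held fixed while the gradient with respect to the other
    # operand is checked below.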

    # Optimal stepsize for central difference is O(epsilon^{1/3}).
    epsilon = np.finfo(dtype_).eps
    delta = epsilon**(1.0 / 3.0)
    # tolerance obtained by looking at actual differences using
    # np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
    tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05

    # check gradient w.r.t. left argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda x: functor_(x, fixed, **kwargs_), [RandomInput()], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

    # check gradient w.r.t. right argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda y: functor_(fixed, y, **kwargs_), [RandomInput()], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test


def _GetBandedTriangularSolveGradientTest(
    functor_,
    dtype_,
    shape_,
    float32_tol_fudge=1.0,  # pylint: disable=redefined-outer-name
    **kwargs_):

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def Test(self):
    n = shape_[-1]
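    # shape_ describes the matrix in banded form: its last two dimensions are
    # (num_bands, n), with each row holding one diagonal of the triangular
    # matrix (this mirrors how band_shape is constructed in the __main__
    # block below).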

    np.random.seed(1)
    # Draw entries in [1, 2) to make sure the matrix is invertible.
    a_np = np.random.uniform(low=1.0, high=2.0, size=shape_).astype(dtype_)
    a = constant_op.constant(a_np)

    b_np = np.random.uniform(low=-1.0, high=1.0, size=[n, n]).astype(dtype_)
    b = constant_op.constant(b_np)

    epsilon = np.finfo(dtype_).eps
    delta = epsilon**(1.0 / 3.0)
    # tolerance obtained by looking at actual differences using
    # np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
    tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05

    # check gradient w.r.t. left argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda x: functor_(x, b, **kwargs_), [a], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

    # check gradient w.r.t. right argument.
    theoretical, numerical = gradient_checker_v2.compute_gradient(
        lambda y: functor_(a, y, **kwargs_), [b], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  return Test


if __name__ == '__main__':
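  # Test methods are generated dynamically: for each combination of functor,
  # dtype, and shape, one of the _Get*GradientTest factories above builds a
  # test function, and _AddTest attaches it to the appropriate TestCase class.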
  # Tests for gradients of binary matrix operations.
  for dtype in np.float32, np.float64:
    for size in 2, 5, 10:
      # We skip the rank 4, size 10 case: it is slow and conceptually covered
      # by the other cases.
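      # Multiplying the batch-shape list by the bool (size < 10) keeps the
      # rank-4 case only when size < 10.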
      for extra in [(), (2,), (3,)] + [(3, 2)] * (size < 10):
        for adjoint in False, True:
          shape = extra + (size, size)
          name = '%s_%s_adj_%s' % (dtype.__name__, '_'.join(map(
              str, shape)), str(adjoint))
          _AddTest(
              MatrixBinaryFunctorGradientTest, 'MatrixSolveGradient', name,
              _GetMatrixBinaryFunctorGradientTest(
                  linalg_ops.matrix_solve, dtype, shape, adjoint=adjoint))

          for lower in True, False:
            name = '%s_low_%s' % (name, lower)
            _AddTest(
                MatrixBinaryFunctorGradientTest,
                'MatrixTriangularSolveGradient', name,
                _GetMatrixBinaryFunctorGradientTest(
                    linalg_ops.matrix_triangular_solve,
                    dtype,
                    shape,
                    float32_tol_fudge=4.0,
                    adjoint=adjoint,
                    lower=lower))

            band_shape = extra + (size // 2 + 1, size)
            name = '%s_%s_adj_%s_low_%s' % (dtype.__name__, '_'.join(
                map(str, band_shape)), str(adjoint), lower)
            _AddTest(
                MatrixBinaryFunctorGradientTest,
                'BandedTriangularSolveGradient', name,
                _GetBandedTriangularSolveGradientTest(
                    linalg_ops.banded_triangular_solve,
                    dtype,
                    band_shape,
                    float32_tol_fudge=4.0,
                    adjoint=adjoint,
                    lower=lower))

  # Tests for gradients of unary matrix operations.
  for dtype in np.float32, np.float64:
    for size in 2, 5, 10:
      # We skip the rank 4, size 10 case: it is slow and conceptually covered
      # by the other cases.
      for extra in [(), (2,), (3,)] + [(3, 2)] * (size < 10):
        shape = extra + (size, size)
        name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'MatrixInverseGradient', name,
            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_inverse, dtype,
                                               shape))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'MatrixAdjointInverseGradient',
            name, _GetMatrixUnaryFunctorGradientTest(
                lambda x: linalg_ops.matrix_inverse(x, adjoint=True),
                dtype, shape))

        if not test_lib.is_built_with_rocm():
          # TODO(rocm): Re-enable this test when the upstream issues are
          # resolved; see the commit message for details.
          _AddTest(
              MatrixUnaryFunctorGradientTest, 'MatrixExponentialGradient', name,
              _GetMatrixUnaryFunctorGradientTest(linalg_impl.matrix_exponential,
                                                 dtype, shape))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name,
            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant,
                                               dtype, shape))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'LogMatrixDeterminantGradient',
            name,
            _GetMatrixUnaryFunctorGradientTest(
                lambda x: linalg_ops.log_matrix_determinant(x)[1], dtype,
                shape))

        # The numerical Jacobian is consistently invalid for these four shapes
        # because the matrix square root of the perturbed input doesn't exist.
        if shape in {(2, 5, 5), (3, 5, 5), (3, 10, 10), (3, 2, 5, 5)}:
          # Alternative shape that consistently produces a valid numerical
          # Jacobian.
          shape = extra + (size + 1, size + 1)
          name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
        _AddTest(
            MatrixUnaryFunctorGradientTest, 'MatrixSquareRootGradient', name,
            _GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_square_root,
                                               dtype, shape))

  # Tests for gradients of matrix_solve_ls.
  for dtype in np.float32, np.float64:
    for rows in 2, 5, 10:
      for cols in 2, 5, 10:
        for l2_regularization in 1e-6, 0.001, 1.0:
          shape = (rows, cols)
          name = '%s_%s_%s' % (dtype.__name__, '_'.join(map(
              str, shape)), l2_regularization)
          float32_tol_fudge = 5.1 if l2_regularization == 1e-6 else 4.0
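          # The nearly-unregularized case is presumably worse conditioned,
          # hence the slightly looser float32 tolerance.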
          _AddTest(
              MatrixBinaryFunctorGradientTest,
              'MatrixSolveLsGradient',
              name,
              # pylint: disable=long-lambda,g-long-lambda
              _GetMatrixBinaryFunctorGradientTest(
                  (lambda a, b, l=l2_regularization: linalg_ops.matrix_solve_ls(
                      a, b, l)), dtype, shape, float32_tol_fudge))

  test_lib.main()