# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradients for operators defined in linalg_ops.py.

A useful reference for derivative formulas is "An extended collection of
matrix derivative results for forward and reverse mode algorithmic
differentiation" by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf

A detailed derivation of formulas for backpropagating through spectral layers
(SVD and Eig) is given by Ionescu, Vantzos & Sminchisescu:
https://arxiv.org/pdf/1509.07838v4.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as _linalg


@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.matmul(
      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)


@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
  """Gradient for MatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[0]
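  # A short sketch of the formula implemented here (Jacobi's formula, cf. the
  # Giles reference in the module docstring): d(det(A)) = det(A) * tr(A^{-1} dA),
  # so the backpropagated gradient is grad_A = grad * det(A) * A^{-H}.
  # For example, A = [[2, 1], [1, 2]] has det(A) = 3, and with grad = 1 this
  # gives grad_A = 3 * A^{-T} = [[2, -1], [-1, 2]].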
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(grad * c,
                                  array_ops.concat([array_ops.shape(c), [1, 1]],
                                                   0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("MatrixSquareRoot")
def _MatrixSquareRootGrad(op, grad):
  """Gradient for MatrixSquareRoot."""

  # Let A be an m x m square matrix (or batch of matrices)
  # Let R = sqrtm(A)
  # By definition, A = RR
  # Take the differential: dA = d(RR) = RdR + dRR
  # Solve the resulting Sylvester equation for dR

  # Used to find Kronecker products within the Sylvester equation
  def _KroneckerProduct(b1, b2):
    """Computes the Kronecker product of two batches of square matrices."""
    b1_shape = array_ops.shape(b1)
    b2_shape = array_ops.shape(b2)
    b1_order = b1_shape[-1]
    b2_order = b2_shape[-1]

    shape_slice_size = [math_ops.subtract(array_ops.size(b1_shape), 2)]
    shape_slice = array_ops.slice(b1_shape, [0],
                                  shape_slice_size)  # Same for both batches
    b1_reshape_shape = array_ops.concat(
        [shape_slice, [b1_order], [1], [b1_order], [1]], 0)
    b2_reshape_shape = array_ops.concat(
        [shape_slice, [1], [b2_order], [1], [b2_order]], 0)

    b1_reshape = array_ops.reshape(b1, b1_reshape_shape)
    b2_reshape = array_ops.reshape(b2, b2_reshape_shape)

    order_prod = b1_order * b2_order
    kprod_shape = array_ops.concat([shape_slice, [order_prod], [order_prod]], 0)
    return array_ops.reshape(b1_reshape * b2_reshape, kprod_shape)

  sqrtm = op.outputs[0]  # R
  shape = array_ops.shape(sqrtm)
  order = shape[-1]  # m
  matrix_count = math_ops.reduce_prod(shape[0:-2])

  # Get batch of m x m identity matrices
  eye = linalg_ops.eye(order, dtype=sqrtm.dtype)  # m x m identity matrix
  eye_flat = array_ops.reshape(eye, [-1])
  eye_tiled = array_ops.tile(eye_flat, [matrix_count])
  eye_batch = array_ops.reshape(eye_tiled, shape)

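  # The reduction to a linear system below uses the standard vec/Kronecker
  # identity vec(A @ X @ B) = (B^T kron A) @ vec(X) (column-stacking vec),
  # which turns the Sylvester equation R dR + dR R = dA into a batch of
  # (m^2 x m^2) linear solves for the vectorized unknown.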
  # The transpose of R is taken in the k1 term instead of k2 in
  # order to prevent redundant transposition of R (i.e. (R')' = R)
  sqrtm_transpose = array_ops.matrix_transpose(sqrtm)
  k1 = _KroneckerProduct(eye_batch, sqrtm_transpose)
  k2 = _KroneckerProduct(sqrtm, eye_batch)
  ksum = math_ops.add(k1, k2)

  # Vectorize dA
  shape_slice_size = [math_ops.subtract(array_ops.size(shape), 2)]
  shape_slice = array_ops.slice(shape, [0], shape_slice_size)
  shape_vec_da = array_ops.concat([shape_slice, [order * order], [1]], 0)
  vec_da = array_ops.reshape(array_ops.matrix_transpose(grad), shape_vec_da)

  # Solve for vec(dR)
  vec_dsqrtm = linalg_ops.matrix_solve(ksum, vec_da)

  # Solve for dR by inverse vectorizing vec(dR)
  dsqrtm_transpose = array_ops.reshape(vec_dsqrtm, shape)
  return array_ops.matrix_transpose(dsqrtm_transpose)


@ops.RegisterGradient("LogMatrixDeterminant")
def _LogMatrixDeterminantGrad(op, _, grad_b):
  """Gradient for LogMatrixDeterminant."""
  a = op.inputs[0]
  c = op.outputs[1]
  a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
  multipliers = array_ops.reshape(
      grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
  return multipliers * a_adj_inv


@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
  """Gradient for Cholesky."""

  # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
  l = op.outputs[0]
  num_rows = array_ops.shape(l)[-1]
  batch_shape = array_ops.shape(l)[:-2]
  l_inverse = linalg_ops.matrix_triangular_solve(l,
                                                 linalg_ops.eye(
                                                     num_rows,
                                                     batch_shape=batch_shape,
                                                     dtype=l.dtype))

  middle = math_ops.matmul(l, grad, adjoint_a=True)
  middle = array_ops.matrix_set_diag(middle,
                                     0.5 * array_ops.matrix_diag_part(middle))
  middle = array_ops.matrix_band_part(middle, -1, 0)

  grad_a = math_ops.matmul(
      math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)

  grad_a += _linalg.adjoint(grad_a)
  return grad_a * 0.5


@ops.RegisterGradient("Qr")
def _QrGrad(op, dq, dr):
  """Gradient for Qr."""
  q, r = op.outputs
  if q.dtype.is_complex:
    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
      r.shape.as_list()[-1] is None):
    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
  if r.shape.dims[-2].value != r.shape.dims[-1].value:
    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
                              "or full_matrices is true and ncols != nrows.")

  qdq = math_ops.matmul(q, dq, adjoint_a=True)
  qdq_ = qdq - _linalg.adjoint(qdq)
  rdr = math_ops.matmul(r, dr, adjoint_b=True)
  rdr_ = rdr - _linalg.adjoint(rdr)
  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)

  def _TriangularSolve(x, r):
    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
    return _linalg.adjoint(
        linalg_ops.matrix_triangular_solve(
            r, _linalg.adjoint(x), lower=False, adjoint=False))

  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
  return grad_a + grad_b


@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
  """Gradient for MatrixSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  c = op.outputs[0]
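  # Sketch of the formulas implemented below (cf. the Giles reference in the
  # module docstring): for C = A^{-1} B the reverse-mode updates are
  #   grad_B = A^{-H} grad   and   grad_A = -grad_B @ C^H,
  # and for the adjoint case C = A^{-H} B they are
  #   grad_B = A^{-1} grad   and   grad_A = -C @ grad_B^H.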
  grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  return (grad_a, grad_b)


@ops.RegisterGradient("MatrixSolveLs")
def _MatrixSolveLsGrad(op, grad):
  """Gradients for MatrixSolveLs."""

  # TODO(rmlarsen): The implementation could be more efficient:
  # a) Output the Cholesky factorization from forward op instead of
  #    recomputing it here.
  # b) Implement a symmetric rank-k update op instead of computing
  #    x*z + transpose(x*z). This pattern occurs in other places in TensorFlow.

  def _Overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
      X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solves the least squares problem
      min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    x = op.outputs[0]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=True)
    # pylint: enable=protected-access
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.matmul(x, z, adjoint_b=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True)
    grad_b = math_ops.matmul(a, z)
    return (grad_a, grad_b, None)

  def _Underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solves the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    # pylint: disable=protected-access
    chol = linalg_ops._RegularizedGramianCholesky(
        a, l2_regularizer=l2_regularizer, first_kind=False)
    # pylint: enable=protected-access
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.matmul(tmp, a, adjoint_a=True)
    a1 = -math_ops.matmul(grad_b, a1)
    a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True)
    a2 = math_ops.matmul(tmp, a2, adjoint_b=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)

  fast = op.get_attr("fast")
  if fast is False:
    raise ValueError("Gradient not defined for fast=False")
  matrix_shape = op.inputs[0].get_shape()[-2:]
  if matrix_shape.is_fully_defined():
    if matrix_shape[-2] >= matrix_shape[-1]:
      return _Overdetermined(op, grad)
    else:
      return _Underdetermined(op, grad)
  else:
    # We have to defer determining the shape to runtime and use
    # conditional execution of the appropriate graph.
    matrix_shape = array_ops.shape(op.inputs[0])[-2:]
    return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
                                 lambda: _Overdetermined(op, grad),
                                 lambda: _Underdetermined(op, grad))


@ops.RegisterGradient("MatrixTriangularSolve")
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True)
  else:
    grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)


@ops.RegisterGradient("SelfAdjointEigV2")
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  compute_v = op.get_attr("compute_v")
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,:,i]
  with ops.control_dependencies([grad_e, grad_v]):
    if compute_v:
      v = op.outputs[1]
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_j - e_i) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.reciprocal(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
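      # In matrix form, the update below is the standard expression
      # (cf. Ionescu et al., referenced in the module docstring):
      #   grad_a = v @ (diag(grad_e) + f * (v^H @ grad_v)) @ v^H,
      # where "*" is elementwise and f[i, j] = 1 / (e[j] - e[i]) for i != j.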
      grad_a = math_ops.matmul(
          v,
          math_ops.matmul(
              array_ops.matrix_diag(grad_e) +
              f * math_ops.matmul(v, grad_v, adjoint_a=True),
              v,
              adjoint_b=True))
    else:
      _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
      grad_a = math_ops.matmul(v,
                               math_ops.matmul(
                                   array_ops.matrix_diag(grad_e),
                                   v,
                                   adjoint_b=True))
  # The forward op only depends on the lower triangular part of a, so here we
  # symmetrize and take the lower triangle
  grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0)
  grad_a = array_ops.matrix_set_diag(grad_a,
                                     0.5 * array_ops.matrix_diag_part(grad_a))
  return grad_a


@ops.RegisterGradient("Svd")
def _SvdGrad(op, grad_s, grad_u, grad_v):
  """Gradient for the singular value decomposition."""

  # The derivation for the compute_uv=False case, and most of
  # the derivation for the full_matrices=True case, are in
  # Giles' paper (see reference at top of file). A derivation for
  # the full_matrices=False case is available at
  # https://j-towns.github.io/papers/svd-derivative.pdf
  a = op.inputs[0]
  a_shape = a.get_shape().with_rank_at_least(2)
  grad_s_mat = array_ops.matrix_diag(grad_s)

  if not op.get_attr("compute_uv"):
    s, u, v = linalg_ops.svd(a, compute_uv=True)
    grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
    grad_a.set_shape(a_shape)
    return grad_a

  full_matrices = op.get_attr("full_matrices")

  # TODO(rmlarsen): Make this work with complex types.
  if a.dtype.is_complex:
    raise NotImplementedError(
        "SVD gradient is not implemented for complex types and "
        "compute_uv=True.")
  grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
  grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
  m = a_shape.dims[-2].merge_with(grad_u_shape[-2])
  n = a_shape.dims[-1].merge_with(grad_v_shape[-2])
  batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
      grad_v_shape[:-2])
  a_shape = batch_shape.concatenate([m, n])

  m = a_shape.dims[-2].value
  n = a_shape.dims[-1].value
  # TODO(rmlarsen): Make this work with placeholders.
  if m is None or n is None:
    raise NotImplementedError(
        "SVD gradient has not been implemented for input with unknown "
        "inner matrix shape.")

  s = op.outputs[0]
  u = op.outputs[1]
  v = op.outputs[2]

  use_adjoint = False
  if m > n:
    # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the
    # Hermitian transpose of the gradient at the end.
    use_adjoint = True
    m, n = n, m
    u, v = v, u
    grad_u, grad_v = grad_v, grad_u

  with ops.control_dependencies([grad_s, grad_u, grad_v]):
    if full_matrices and abs(m - n) > 1:
      raise NotImplementedError(
          "svd gradient is not implemented for abs(m - n) > 1 "
          "when full_matrices is True")
    s_mat = array_ops.matrix_diag(s)
    s2 = math_ops.square(s)

    # NOTICE: Because of the term involving f, the gradient becomes
    # infinite (or NaN in practice) when singular values are not unique.
    # Mathematically this should not be surprising, since for (k-fold)
    # degenerate singular values, the corresponding singular vectors are
    # only defined up to arbitrary rotation in a (k-dimensional) subspace.
    # In practice, this can lead to numerical instability when singular
    # values are close but not exactly equal.
    # Also, even with distinct singular values, the diagonal of f can have Inf
    # values before being set to zero, which hurts when differentiating through
    # this op. To avoid that, we add eye to the matrix before taking
    # the reciprocal.
    s_shape = array_ops.shape(s)
    eye = _linalg.eye(s_shape[-1], batch_shape=s_shape[:-1], dtype=s.dtype)
    f = array_ops.matrix_set_diag(
        math_ops.reciprocal(
            array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1) +
            eye), array_ops.zeros_like(s))
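    # Explicitly, f[..., i, j] = 1 / (s[..., j]^2 - s[..., i]^2) for i != j and
    # f[..., i, i] = 0; adding `eye` only offsets the diagonal entries, which
    # are overwritten with zeros anyway, so it merely avoids Inf on the
    # diagonal before the reciprocal is taken.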
    s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s))

    v1 = v[..., :, :m]
    grad_v1 = grad_v[..., :, :m]

    u_gu = math_ops.matmul(u, grad_u, adjoint_a=True)
    v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True)

    f_u = f * u_gu
    f_v = f * v_gv

    term1_nouv = (
        grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) +
        math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v)))

    term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True))

    if m == n:
      grad_a_before_transpose = term1
    else:
      gv1t = array_ops.matrix_transpose(grad_v1)
      gv1t_v1 = math_ops.matmul(gv1t, v1)
      term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)

      if full_matrices:
        v2 = v[..., :, m:n]
        grad_v2 = grad_v[..., :, m:n]

        v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True)
        term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True)

      u_s_inv = math_ops.matmul(u, s_inv_mat)
      term2 = math_ops.matmul(u_s_inv, term2_nous)

      grad_a_before_transpose = term1 + term2

    if use_adjoint:
      grad_a = array_ops.matrix_transpose(grad_a_before_transpose)
    else:
      grad_a = grad_a_before_transpose

    grad_a.set_shape(a_shape)
    return grad_a
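

# A minimal usage sketch (illustration only, assuming eager execution with
# tf.GradientTape; not part of this module): the registrations above are
# picked up automatically by TensorFlow's gradient machinery, e.g.
#
#   import tensorflow as tf
#   a = tf.constant([[2.0, 1.0], [1.0, 2.0]])
#   with tf.GradientTape() as tape:
#     tape.watch(a)
#     loss = tf.reduce_sum(tf.linalg.inv(a))
#   grad_a = tape.gradient(loss, a)  # routed through _MatrixInverseGrad above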