1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Registrations for LinearOperator.solve.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.ops.linalg import linear_operator 22from tensorflow.python.ops.linalg import linear_operator_algebra 23from tensorflow.python.ops.linalg import linear_operator_block_diag 24from tensorflow.python.ops.linalg import linear_operator_circulant 25from tensorflow.python.ops.linalg import linear_operator_composition 26from tensorflow.python.ops.linalg import linear_operator_diag 27from tensorflow.python.ops.linalg import linear_operator_identity 28from tensorflow.python.ops.linalg import linear_operator_inversion 29from tensorflow.python.ops.linalg import linear_operator_lower_triangular 30from tensorflow.python.ops.linalg import registrations_util 31 32 33# By default, use a LinearOperatorComposition to delay the computation. 
@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _solve_linear_operator(linop_a, linop_b):
  """Generic solve of two `LinearOperator`s.

  No structure is exploited: the result lazily represents
  `linop_a^{-1} @ linop_b` as a `LinearOperatorComposition` of
  `LinearOperatorInversion(linop_a)` followed by `linop_b`.

  Args:
    linop_a: `LinearOperator` on the left of the solve (the one inverted).
    linop_b: `LinearOperator` right-hand side.

  Returns:
    A `LinearOperatorComposition` whose hints are derived from the operands.
  """
  square_hint = registrations_util.is_square(linop_a, linop_b)
  non_singular_hint = None
  self_adjoint_hint = None
  positive_definite_hint = None

  if square_hint is False:  # pylint:disable=g-bool-id-comparison
    # A result known to be non-square cannot carry any of these properties.
    non_singular_hint = self_adjoint_hint = positive_definite_hint = False
  elif square_hint:
    # Only when squareness is known can the non-singularity hints be combined.
    non_singular_hint = registrations_util.combined_non_singular_hint(
        linop_a, linop_b)

  inverted_a = linear_operator_inversion.LinearOperatorInversion(linop_a)
  return linear_operator_composition.LinearOperatorComposition(
      operators=[inverted_a, linop_b],
      is_non_singular=non_singular_hint,
      is_self_adjoint=self_adjoint_hint,
      is_positive_definite=positive_definite_hint,
      is_square=square_hint,
  )


@linear_operator_algebra.RegisterSolve(
    linear_operator_inversion.LinearOperatorInversion,
    linear_operator.LinearOperator)
def _solve_inverse_linear_operator(linop_a, linop_b):
  """Solve inverse of generic `LinearOperator`s.

  Solving against `inv(A)` is multiplication by `A`: `(A^{-1})^{-1} B = A B`.
  """
  wrapped_operator = linop_a.operator
  return wrapped_operator.matmul(linop_b)


def _commuting_solve_hints(first, second):
  """Build the hint kwargs shared by solves between commuting square operators.

  Args:
    first: One operand of the solve.
    second: The other operand of the solve.

  Returns:
    A dict with `is_non_singular`, `is_self_adjoint`, `is_positive_definite`
    (combined via `registrations_util`, in that order) and `is_square=True`,
    ready to be splatted into a `LinearOperator` constructor.
  """
  return dict(
      is_non_singular=registrations_util.combined_non_singular_hint(
          first, second),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          first, second),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              first, second)),
      is_square=True)


# Identity
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _solve_linear_operator_identity_left(identity, linop):
  """Solve with an identity on the left: `I^{-1} @ linop` is `linop`."""
  del identity  # Unused: inverting the identity is a no-op.
  return linop


@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _solve_linear_operator_identity_right(linop, identity):
  """Solve against an identity: `linop^{-1} @ I` is `linop.inverse()`."""
  del identity  # Unused: multiplying by the identity is a no-op.
  return linop.inverse()


@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_scaled_identity(linop_a, linop_b):
  """Solve of two ScaledIdentity `LinearOperators`.

  `(a I)^{-1} (b I) = (b / a) I`, so only the multipliers are divided.
  """
  num_rows = linop_a.domain_dimension_tensor()
  multiplier = linop_b.multiplier / linop_a.multiplier
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=multiplier,
      **_commuting_solve_hints(linop_a, linop_b))


# Diag.


@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag(linop_a, linop_b):
  """Solve of two `LinearOperatorDiag`s: divide the diagonals elementwise."""
  quotient = linop_b.diag / linop_a.diag
  return linear_operator_diag.LinearOperatorDiag(
      diag=quotient,
      **_commuting_solve_hints(linop_a, linop_b))


@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Solve `D^{-1} (c I)`: a diagonal with entries `c / d`."""
  quotient = linop_scaled_identity.multiplier / linop_diag.diag
  return linear_operator_diag.LinearOperatorDiag(
      diag=quotient,
      **_commuting_solve_hints(linop_diag, linop_scaled_identity))


@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Solve `(c I)^{-1} D`: a diagonal with entries `d / c`."""
  quotient = linop_diag.diag / linop_scaled_identity.multiplier
  return linear_operator_diag.LinearOperatorDiag(
      diag=quotient,
      **_commuting_solve_hints(linop_diag, linop_scaled_identity))


@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Solve `D^{-1} L` for diagonal `D` and lower-triangular `L`.

  Appending a trailing axis to the diagonal divides row `i` of the dense
  triangular matrix by `d_i`, which keeps the result lower-triangular.
  """
  scaled_rows = linop_triangular.to_dense() / linop_diag.diag[..., None]
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_rows,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
      # Safe: a triangular matrix is self-adjoint only when it is diagonal,
      # in which case it commutes with `linop_diag`.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
      is_positive_definite=None,
      is_square=True)


# Circulant.


@linear_operator_algebra.RegisterSolve(
    linear_operator_circulant.LinearOperatorCirculant,
    linear_operator_circulant.LinearOperatorCirculant)
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
  """Solve of two `LinearOperatorCirculant`s.

  The result is circulant with spectrum `spectrum_b / spectrum_a`.

  Args:
    linop_a: Circulant operator on the left of the solve (the one inverted).
    linop_b: Circulant right-hand side.

  Returns:
    A `LinearOperatorCirculant` representing `linop_a^{-1} @ linop_b`.
  """
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=linop_b.spectrum / linop_a.spectrum,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)


# Block Diag


@linear_operator_algebra.RegisterSolve(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Solve of two `LinearOperatorBlockDiag`s, block by block.

  `blockdiag(A_1, ..., A_k)^{-1} blockdiag(B_1, ..., B_k)` is computed as
  `blockdiag(A_1.solve(B_1), ..., A_k.solve(B_k))`.

  Args:
    linop_a: Block-diagonal operator on the left of the solve.
    linop_b: Block-diagonal right-hand side.

  Returns:
    A `LinearOperatorBlockDiag` of the per-block solves.

  Raises:
    ValueError: If `linop_a` and `linop_b` have a different number of blocks.
  """
  if len(linop_a.operators) != len(linop_b.operators):
    # Without this check `zip` would silently drop the trailing blocks of the
    # longer operator and return an operator of the wrong shape.
    raise ValueError(
        "Can not efficiently solve `LinearOperatorBlockDiag` when "
        "number of blocks differ. Got {} and {} blocks.".format(
            len(linop_a.operators), len(linop_b.operators)))

  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          o1.solve(o2) for o1, o2 in zip(
              linop_a.operators, linop_b.operators)],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # In general, a solve of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a solve of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)