1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Registrations for LinearOperator.matmul.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.ops.linalg import linear_operator 22from tensorflow.python.ops.linalg import linear_operator_algebra 23from tensorflow.python.ops.linalg import linear_operator_circulant 24from tensorflow.python.ops.linalg import linear_operator_composition 25from tensorflow.python.ops.linalg import linear_operator_diag 26from tensorflow.python.ops.linalg import linear_operator_identity 27from tensorflow.python.ops.linalg import linear_operator_lower_triangular 28from tensorflow.python.ops.linalg import linear_operator_zeros 29from tensorflow.python.ops.linalg import registrations_util 30 31 32# By default, use a LinearOperatorComposition to delay the computation. 
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _matmul_linear_operator(linop_a, linop_b):
  """Generic matmul of two `LinearOperator`s.

  Fallback registration: the product is represented lazily as a
  `LinearOperatorComposition` of `[linop_a, linop_b]`.

  Args:
    linop_a: Left `LinearOperator`.
    linop_b: Right `LinearOperator`.

  Returns:
    A `LinearOperatorComposition` representing `linop_a @ linop_b`, with
    hints derived from the operands' hints.
  """
  is_square = registrations_util.is_square(linop_a, linop_b)
  is_non_singular = None
  is_self_adjoint = None
  is_positive_definite = None

  if is_square:
    is_non_singular = registrations_util.combined_non_singular_hint(
        linop_a, linop_b)
  elif is_square is False:  # pylint:disable=g-bool-id-comparison
    # A non-square operator cannot be non-singular, self-adjoint or positive
    # definite. Any other `is_square` value (presumably None / unknown) leaves
    # every hint as None.
    is_non_singular = False
    is_self_adjoint = False
    is_positive_definite = False

  return linear_operator_composition.LinearOperatorComposition(
      operators=[linop_a, linop_b],
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite,
      is_square=is_square,
  )

# Identity


@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _matmul_linear_operator_identity_left(identity, linop):
  """Return `linop` unchanged for `identity @ linop`.

  NOTE(review): the identity operand is discarded entirely, so any batch shape
  carried only by `identity` is not reflected in the result -- confirm intended.
  """
  del identity
  return linop


@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _matmul_linear_operator_identity_right(linop, identity):
  """Return `linop` unchanged for `linop @ identity`.

  NOTE(review): the identity operand is discarded entirely, so any batch shape
  carried only by `identity` is not reflected in the result -- confirm intended.
  """
  del identity
  return linop


@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_scaled_identity(linop_a, linop_b):
  """Matmul of two ScaledIdentity `LinearOperators`.

  `(a * I) @ (b * I) == (a * b) * I`, so the result is a scaled identity whose
  multiplier is the product of the two multipliers.
  """
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=linop_a.domain_dimension_tensor(),
      multiplier=linop_a.multiplier * linop_b.multiplier,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)


# Zeros


@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_zeros.LinearOperatorZeros)
def _matmul_linear_operator_zeros_right(linop, zeros):
  """Return `zeros` for `linop @ zeros` when both operands are square.

  Raises:
    ValueError: if `zeros.is_square` or `linop.is_square` is falsy (this also
      triggers when the hint is None/unknown).

  NOTE(review): returns `zeros` unchanged, so any batch shape carried only by
  `linop` is not reflected in the result -- confirm intended.
  """
  if not zeros.is_square or not linop.is_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros


@linear_operator_algebra.RegisterMatmul(
    linear_operator_zeros.LinearOperatorZeros,
    linear_operator.LinearOperator)
def _matmul_linear_operator_zeros_left(zeros, linop):
  """Return `zeros` for `zeros @ linop` when both operands are square.

  Raises:
    ValueError: if `zeros.is_square` or `linop.is_square` is falsy (this also
      triggers when the hint is None/unknown).

  NOTE(review): returns `zeros` unchanged, so any batch shape carried only by
  `linop` is not reflected in the result -- confirm intended.
  """
  if not zeros.is_square or not linop.is_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros


# Diag.


@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag(linop_a, linop_b):
  """Product of two diagonal operators: diagonal of elementwise products."""
  return linear_operator_diag.LinearOperatorDiag(
      diag=linop_a.diag * linop_b.diag,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)


@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Compute `D @ (c * I)` as the diagonal operator `c * D`."""
  return linear_operator_diag.LinearOperatorDiag(
      # `diag` has shape [B1,...,Bb, N], while `multiplier` has shape
      # [B1,...,Bb] (or is a scalar). The trailing axis makes the multiplier
      # broadcast over the batch dimensions. Without it, the multiplier's last
      # batch dimension would be matched against N.
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)


@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Compute `(c * I) @ D` as the diagonal operator `c * D`."""
  return linear_operator_diag.LinearOperatorDiag(
      # `diag` has shape [B1,...,Bb, N], while `multiplier` has shape
      # [B1,...,Bb] (or is a scalar). The trailing axis makes the multiplier
      # broadcast over the batch dimensions. Without it, the multiplier's last
      # batch dimension would be matched against N.
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)


@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Compute `D @ T` as a lower-triangular operator (row i scaled by d_i)."""
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      # (D @ T)[..., i, j] = d[..., i] * T[..., i, j]: `diag[..., None]` has
      # shape [..., N, 1], so it broadcasts across the columns of T.
      tril=linop_diag.diag[..., None] * linop_triangular.to_dense(),
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
      # This is safe to do since the Triangular matrix is only self-adjoint
      # when it is a diagonal matrix, and hence commutes.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
      # Positive definiteness of the product is left unknown.
      is_positive_definite=None,
      is_square=True)


@linear_operator_algebra.RegisterMatmul(
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
  """Compute `T @ D` as a lower-triangular operator (column j scaled by d_j)."""
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      # (T @ D)[..., i, j] = T[..., i, j] * d[..., j]. Inserting an axis before
      # N gives `diag` shape [..., 1, N], so it broadcasts across rows. A bare
      # `diag` of shape [..., N] would align its last batch dimension with
      # T's row axis whenever the diagonal operator is batched.
      tril=linop_triangular.to_dense() * linop_diag.diag[..., None, :],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_triangular),
      # This is safe to do since the Triangular matrix is only self-adjoint
      # when it is a diagonal matrix, and hence commutes.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_triangular),
      # Positive definiteness of the product is left unknown.
      is_positive_definite=None,
      is_square=True)

# Circulant.


@linear_operator_algebra.RegisterMatmul(
    linear_operator_circulant.LinearOperatorCirculant,
    linear_operator_circulant.LinearOperatorCirculant)
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
  """Product of two circulants: circulant with elementwise-product spectrum."""
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=linop_a.spectrum * linop_b.spectrum,
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_a, linop_b),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_a, linop_b)),
      is_square=True)