• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.matmul."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.ops.linalg import linear_operator
22from tensorflow.python.ops.linalg import linear_operator_algebra
23from tensorflow.python.ops.linalg import linear_operator_circulant
24from tensorflow.python.ops.linalg import linear_operator_composition
25from tensorflow.python.ops.linalg import linear_operator_diag
26from tensorflow.python.ops.linalg import linear_operator_identity
27from tensorflow.python.ops.linalg import linear_operator_lower_triangular
28from tensorflow.python.ops.linalg import linear_operator_zeros
29from tensorflow.python.ops.linalg import registrations_util
30
31
32# By default, use a LinearOperatorComposition to delay the computation.
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _matmul_linear_operator(linop_a, linop_b):
  """Generic matmul of two `LinearOperator`s, represented lazily.

  No structure-specific product is registered for this pair of operator
  types, so the product is wrapped in a `LinearOperatorComposition`, which
  defers the actual computation.

  Args:
    linop_a: Left-hand `LinearOperator` factor.
    linop_b: Right-hand `LinearOperator` factor.

  Returns:
    A `LinearOperatorComposition` of `[linop_a, linop_b]` whose hints are
    derived from whether the product is known to be square.
  """
  square = registrations_util.is_square(linop_a, linop_b)

  if square is False:  # pylint:disable=g-bool-id-comparison
    # A product known to be non-square cannot be non-singular, self-adjoint
    # or positive definite.
    hints = dict(
        is_non_singular=False,
        is_self_adjoint=False,
        is_positive_definite=False)
  elif square:
    hints = dict(
        is_non_singular=registrations_util.combined_non_singular_hint(
            linop_a, linop_b),
        is_self_adjoint=None,
        is_positive_definite=None)
  else:
    # Squareness could not be determined: make no claims at all.
    hints = dict(
        is_non_singular=None,
        is_self_adjoint=None,
        is_positive_definite=None)

  return linear_operator_composition.LinearOperatorComposition(
      operators=[linop_a, linop_b],
      is_square=square,
      **hints)
57
58# Identity
59
60
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _matmul_linear_operator_identity_left(identity, linop):
  """`I @ A == A`: left-multiplying by the identity is a no-op.

  Args:
    identity: Left-hand `LinearOperatorIdentity`; only its type matters
      (it selects this registration), its value is not read.
    linop: Right-hand `LinearOperator`.

  Returns:
    `linop` itself.
  """
  # NOTE(review): nothing from `identity` (e.g. a batch shape it may carry)
  # is reflected in the returned operator.
  del identity  # Unused.
  return linop
67
68
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _matmul_linear_operator_identity_right(linop, identity):
  """`A @ I == A`: right-multiplying by the identity is a no-op.

  Args:
    linop: Left-hand `LinearOperator`.
    identity: Right-hand `LinearOperatorIdentity`; only its type matters
      (it selects this registration), its value is not read.

  Returns:
    `linop` itself.
  """
  # NOTE(review): nothing from `identity` (e.g. a batch shape it may carry)
  # is reflected in the returned operator.
  del identity  # Unused.
  return linop
75
76
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_scaled_identity(linop_a, linop_b):
  """Matmul of two ScaledIdentity `LinearOperators`.

  `(a * I) @ (b * I) == (a * b) * I`, so the product is again a scaled
  identity whose multiplier is the product of the two multipliers.

  Args:
    linop_a: Left-hand `LinearOperatorScaledIdentity`.
    linop_b: Right-hand `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorScaledIdentity` sized by `linop_a`'s domain
    dimension, with hints combined from both operands.
  """
  num_rows = linop_a.domain_dimension_tensor()
  multiplier = linop_a.multiplier * linop_b.multiplier
  # Scaled identities commute with each other, so the "commuting" hint
  # combiners are applicable.
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=multiplier,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
93
94
95# Zeros
96
97
@linear_operator_algebra.RegisterMatmul(
    linear_operator.LinearOperator,
    linear_operator_zeros.LinearOperatorZeros)
def _matmul_linear_operator_zeros_right(linop, zeros):
  """Matmul of any `LinearOperator` by a `LinearOperatorZeros` on the right.

  `A @ 0 == 0`, so the zeros operator itself is returned. Only the case in
  which both operands report a truthy `is_square` is supported.

  Args:
    linop: Left-hand `LinearOperator`.
    zeros: Right-hand `LinearOperatorZeros`.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If `zeros.is_square` or `linop.is_square` is falsy (e.g.
      `False`, or `None` when squareness is unknown).
  """
  both_square = zeros.is_square and linop.is_square
  if not both_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros
106
107
@linear_operator_algebra.RegisterMatmul(
    linear_operator_zeros.LinearOperatorZeros,
    linear_operator.LinearOperator)
def _matmul_linear_operator_zeros_left(zeros, linop):
  """Matmul of a `LinearOperatorZeros` by any `LinearOperator` on the right.

  `0 @ A == 0`, so the zeros operator itself is returned. Only the case in
  which both operands report a truthy `is_square` is supported.

  Args:
    zeros: Left-hand `LinearOperatorZeros`.
    linop: Right-hand `LinearOperator`.

  Returns:
    `zeros`, unchanged.

  Raises:
    ValueError: If `zeros.is_square` or `linop.is_square` is falsy (e.g.
      `False`, or `None` when squareness is unknown).
  """
  both_square = zeros.is_square and linop.is_square
  if not both_square:
    raise ValueError("Matmul with non-square `LinearOperator`s or non-square "
                     "`LinearOperatorZeros` not supported at this time.")
  return zeros
116
117
118# Diag.
119
120
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag(linop_a, linop_b):
  """Matmul of two `LinearOperatorDiag`s.

  `diag(a) @ diag(b) == diag(a * b)`: the product of two diagonal operators
  is diagonal, with the elementwise product of the two diagonals.

  Args:
    linop_a: Left-hand `LinearOperatorDiag`.
    linop_b: Right-hand `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorDiag` with hints combined from both operands.
  """
  product_diag = linop_a.diag * linop_b.diag
  # Diagonal operators commute, so the "commuting" hint combiners apply.
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_diag.LinearOperatorDiag(
      diag=product_diag,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
135
136
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _matmul_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Matmul of a `LinearOperatorDiag` by a `LinearOperatorScaledIdentity`.

  `diag(d) @ (m * I) == diag(d * m)`, so the product stays diagonal.

  Args:
    linop_diag: Left-hand `LinearOperatorDiag`.
    linop_scaled_identity: Right-hand `LinearOperatorScaledIdentity`.

  Returns:
    A square `LinearOperatorDiag` whose diagonal is `linop_diag.diag` scaled
    by the (per-batch) multiplier, with hints combined from both operands.
  """
  return linear_operator_diag.LinearOperatorDiag(
      # Per the `LinearOperatorScaledIdentity` docs, `multiplier` has only the
      # batch shape `[B1,...,Bb]`, whereas `diag` is `[B1,...,Bb, N]`. The
      # trailing axis aligns the multiplier with the batch dimensions; without
      # it the last batch axis would be broadcast against `N`, which either
      # raises or silently pairs the wrong multiplier with each entry.
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # Diagonal and scaled-identity operators commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
152
153
@linear_operator_algebra.RegisterMatmul(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Matmul of a `LinearOperatorScaledIdentity` by a `LinearOperatorDiag`.

  `(m * I) @ diag(d) == diag(m * d)`, so the product stays diagonal.

  Args:
    linop_scaled_identity: Left-hand `LinearOperatorScaledIdentity`.
    linop_diag: Right-hand `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorDiag` whose diagonal is `linop_diag.diag` scaled
    by the (per-batch) multiplier, with hints combined from both operands.
  """
  return linear_operator_diag.LinearOperatorDiag(
      # Per the `LinearOperatorScaledIdentity` docs, `multiplier` has only the
      # batch shape `[B1,...,Bb]`, whereas `diag` is `[B1,...,Bb, N]`. The
      # trailing axis aligns the multiplier with the batch dimensions; without
      # it the last batch axis would be broadcast against `N`, which either
      # raises or silently pairs the wrong multiplier with each entry.
      diag=linop_diag.diag * linop_scaled_identity.multiplier[..., None],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_diag, linop_scaled_identity),
      # Diagonal and scaled-identity operators commute.
      is_self_adjoint=registrations_util.combined_commuting_self_adjoint_hint(
          linop_diag, linop_scaled_identity),
      is_positive_definite=(
          registrations_util.combined_commuting_positive_definite_hint(
              linop_diag, linop_scaled_identity)),
      is_square=True)
169
170
@linear_operator_algebra.RegisterMatmul(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _matmul_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Matmul of a `LinearOperatorDiag` by a `LinearOperatorLowerTriangular`.

  Left-multiplying by a diagonal matrix scales each row of the triangular
  matrix, so the product is still lower triangular.

  Args:
    linop_diag: Left-hand `LinearOperatorDiag`.
    linop_triangular: Right-hand `LinearOperatorLowerTriangular`.

  Returns:
    A square `LinearOperatorLowerTriangular` built from the densified, scaled
    matrix. Positive-definiteness is left unspecified (`None`).
  """
  # The trailing axis turns the diagonal into a column, so entry `i` scales
  # row `i` of the dense triangular matrix.
  scaled = linop_diag.diag[..., None] * linop_triangular.to_dense()
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # Using the commuting hint is safe: a lower-triangular matrix is only
  # self-adjoint when it is diagonal, and then it commutes with `linop_diag`.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
185
186
@linear_operator_algebra.RegisterMatmul(
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_diag.LinearOperatorDiag)
def _matmul_linear_operator_tril_diag(linop_triangular, linop_diag):
  """Matmul of a `LinearOperatorLowerTriangular` by a `LinearOperatorDiag`.

  Right-multiplying by a diagonal matrix scales each column of the triangular
  matrix, so the product is still lower triangular.

  Args:
    linop_triangular: Left-hand `LinearOperatorLowerTriangular`.
    linop_diag: Right-hand `LinearOperatorDiag`.

  Returns:
    A square `LinearOperatorLowerTriangular` built from the densified, scaled
    matrix. Positive-definiteness is left unspecified (`None`).
  """
  # Broadcasting the diagonal against the last axis scales column `j` of the
  # dense triangular matrix by entry `j`.
  scaled = linop_triangular.to_dense() * linop_diag.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # Using the commuting hint is safe: a lower-triangular matrix is only
  # self-adjoint when it is diagonal, and then it commutes with `linop_diag`.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
201
202# Circulant.
203
204
@linear_operator_algebra.RegisterMatmul(
    linear_operator_circulant.LinearOperatorCirculant,
    linear_operator_circulant.LinearOperatorCirculant)
def _matmul_linear_operator_circulant_circulant(linop_a, linop_b):
  """Matmul of two `LinearOperatorCirculant`s.

  Circulant operators share the same Fourier eigenbasis, so their product is
  circulant with a spectrum equal to the elementwise product of the spectra.

  Args:
    linop_a: Left-hand `LinearOperatorCirculant`.
    linop_b: Right-hand `LinearOperatorCirculant`.

  Returns:
    A square `LinearOperatorCirculant` with hints combined from both operands.
  """
  product_spectrum = linop_a.spectrum * linop_b.spectrum
  # Circulant operators commute, so the "commuting" hint combiners apply.
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=product_spectrum,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
219