• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registrations for LinearOperator.solve."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.ops.linalg import linear_operator
22from tensorflow.python.ops.linalg import linear_operator_algebra
23from tensorflow.python.ops.linalg import linear_operator_block_diag
24from tensorflow.python.ops.linalg import linear_operator_circulant
25from tensorflow.python.ops.linalg import linear_operator_composition
26from tensorflow.python.ops.linalg import linear_operator_diag
27from tensorflow.python.ops.linalg import linear_operator_identity
28from tensorflow.python.ops.linalg import linear_operator_inversion
29from tensorflow.python.ops.linalg import linear_operator_lower_triangular
30from tensorflow.python.ops.linalg import registrations_util
31
32
# Fallback for any pair of operators: instead of computing anything, represent
# the solve as a LinearOperatorComposition of inverse(linop_a) and linop_b.
@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator, linear_operator.LinearOperator)
def _solve_linear_operator(linop_a, linop_b):
  """Generic solve of two `LinearOperator`s.

  The result stands for `linop_a^{-1} linop_b` and is expressed as the
  composition of `LinearOperatorInversion(linop_a)` with `linop_b`.

  Args:
    linop_a: `LinearOperator` whose inverse is applied.
    linop_b: `LinearOperator` acting as the right-hand side.

  Returns:
    A `LinearOperatorComposition`.
  """
  is_square = registrations_util.is_square(linop_a, linop_b)

  if is_square:
    # Only the non-singular hint is derived from the operands; a product of
    # operators is not self-adjoint or positive-definite in general.
    hints = (
        registrations_util.combined_non_singular_hint(linop_a, linop_b),
        None,
        None,
    )
  elif is_square is False:  # pylint:disable=g-bool-id-comparison
    # A non-square result cannot be non-singular, self-adjoint or
    # positive-definite.
    hints = (False, False, False)
  else:
    # Squareness could not be determined (presumably `None`), so every hint
    # is left unset.
    hints = (None, None, None)
  is_non_singular, is_self_adjoint, is_positive_definite = hints

  inverse_a = linear_operator_inversion.LinearOperatorInversion(linop_a)
  return linear_operator_composition.LinearOperatorComposition(
      operators=[inverse_a, linop_b],
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite,
      is_square=is_square,
  )
61
62
@linear_operator_algebra.RegisterSolve(
    linear_operator_inversion.LinearOperatorInversion,
    linear_operator.LinearOperator)
def _solve_inverse_linear_operator(linop_a, linop_b):
  """Solve inverse of generic `LinearOperator`s.

  Solving against `LinearOperatorInversion(A)` is the same as multiplying by
  `A`, so the wrapped operator is simply matmul'ed with `linop_b`.
  """
  wrapped = linop_a.operator
  return wrapped.matmul(linop_b)
69
70
71# Identity
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator.LinearOperator)
def _solve_linear_operator_identity_left(identity, linop):
  """Solve `I X = linop`: the solution is `linop` itself."""
  del identity  # The identity's inverse is the identity; nothing to apply.
  return linop
78
79
@linear_operator_algebra.RegisterSolve(
    linear_operator.LinearOperator,
    linear_operator_identity.LinearOperatorIdentity)
def _solve_linear_operator_identity_right(linop, identity):
  """Solve `linop X = I`: the solution is the inverse of `linop`."""
  del identity  # Only used for dispatch.
  inverse = linop.inverse()
  return inverse
86
87
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_scaled_identity(linop_a, linop_b):
  """Solve of two ScaledIdentity `LinearOperators`.

  `(a I)^{-1} (b I) = (b / a) I`, so the result is another scaled identity
  whose multiplier is `linop_b.multiplier / linop_a.multiplier`.
  """
  num_rows = linop_a.domain_dimension_tensor()
  multiplier = linop_b.multiplier / linop_a.multiplier
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  # Scaled identities commute, so the commuting-hint helpers apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_identity.LinearOperatorScaledIdentity(
      num_rows=num_rows,
      multiplier=multiplier,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
104
105
106# Diag.
107
108
@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag(linop_a, linop_b):
  """Solve of two `LinearOperatorDiag`s: an elementwise ratio of diagonals."""
  ratio = linop_b.diag / linop_a.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  # Diagonal operators commute, so the commuting-hint helpers apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_diag.LinearOperatorDiag(
      diag=ratio,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
123
124
@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_identity.LinearOperatorScaledIdentity)
def _solve_linear_operator_diag_scaled_identity_right(
    linop_diag, linop_scaled_identity):
  """Solve `D X = c I`: the result is diagonal with entries `c / d_i`."""
  ratio = linop_scaled_identity.multiplier / linop_diag.diag
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_scaled_identity)
  # A scaled identity commutes with everything, so the commuting hints apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_scaled_identity)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_diag, linop_scaled_identity))
  return linear_operator_diag.LinearOperatorDiag(
      diag=ratio,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
140
141
@linear_operator_algebra.RegisterSolve(
    linear_operator_identity.LinearOperatorScaledIdentity,
    linear_operator_diag.LinearOperatorDiag)
def _solve_linear_operator_diag_scaled_identity_left(
    linop_scaled_identity, linop_diag):
  """Solve `c I X = D`: the result is diagonal with entries `d_i / c`."""
  ratio = linop_diag.diag / linop_scaled_identity.multiplier
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_scaled_identity)
  # A scaled identity commutes with everything, so the commuting hints apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_scaled_identity)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_diag, linop_scaled_identity))
  return linear_operator_diag.LinearOperatorDiag(
      diag=ratio,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
157
158
@linear_operator_algebra.RegisterSolve(
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular)
def _solve_linear_operator_diag_tril(linop_diag, linop_triangular):
  """Solve a diagonal operator against a lower-triangular one.

  `D^{-1} L` divides row `i` of `L` by `d_i`, which keeps the result lower
  triangular.
  """
  # `[..., None]` makes the diagonal a column so it broadcasts across the
  # columns of the dense triangular matrix, scaling each row.
  scaled_tril = linop_triangular.to_dense() / linop_diag.diag[..., None]
  non_singular = registrations_util.combined_non_singular_hint(
      linop_diag, linop_triangular)
  # A lower-triangular matrix is self-adjoint only when it is diagonal, and
  # then it commutes with `linop_diag`, so the commuting hint is safe here.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_diag, linop_triangular)
  return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
      tril=scaled_tril,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=None,
      is_square=True)
173
174
175# Circulant.
176
177
@linear_operator_algebra.RegisterSolve(
    linear_operator_circulant.LinearOperatorCirculant,
    linear_operator_circulant.LinearOperatorCirculant)
def _solve_linear_operator_circulant_circulant(linop_a, linop_b):
  """Solve of two circulant operators.

  Circulant operators share the Fourier eigenbasis, so the solve is an
  elementwise division of their spectra.
  """
  spectrum = linop_b.spectrum / linop_a.spectrum
  non_singular = registrations_util.combined_non_singular_hint(
      linop_a, linop_b)
  # Circulant operators commute, so the commuting-hint helpers apply.
  self_adjoint = registrations_util.combined_commuting_self_adjoint_hint(
      linop_a, linop_b)
  positive_definite = (
      registrations_util.combined_commuting_positive_definite_hint(
          linop_a, linop_b))
  return linear_operator_circulant.LinearOperatorCirculant(
      spectrum=spectrum,
      is_non_singular=non_singular,
      is_self_adjoint=self_adjoint,
      is_positive_definite=positive_definite,
      is_square=True)
192
193
194# Block Diag
195
196
@linear_operator_algebra.RegisterSolve(
    linear_operator_block_diag.LinearOperatorBlockDiag,
    linear_operator_block_diag.LinearOperatorBlockDiag)
def _solve_linear_operator_block_diag_block_diag(linop_a, linop_b):
  """Solve of two `LinearOperatorBlockDiag`s, performed block by block.

  The result is a block diagonal operator whose i-th block is
  `linop_a.operators[i].solve(linop_b.operators[i])`.

  Args:
    linop_a: `LinearOperatorBlockDiag` whose inverse is applied.
    linop_b: `LinearOperatorBlockDiag` acting as the right-hand side.

  Returns:
    A `LinearOperatorBlockDiag` holding the block-wise solves.

  Raises:
    ValueError: If `linop_a` and `linop_b` have a different number of blocks.
  """
  num_blocks_a = len(linop_a.operators)
  num_blocks_b = len(linop_b.operators)
  # `zip` below would silently drop the trailing blocks of the longer operator
  # and return an operator of the wrong size, so reject the mismatch up front.
  if num_blocks_a != num_blocks_b:
    raise ValueError(
        "Can not efficiently solve `LinearOperatorBlockDiag` when number of "
        "blocks differ. Got {} and {} blocks.".format(
            num_blocks_a, num_blocks_b))
  return linear_operator_block_diag.LinearOperatorBlockDiag(
      operators=[
          o1.solve(o2) for o1, o2 in zip(
              linop_a.operators, linop_b.operators)],
      is_non_singular=registrations_util.combined_non_singular_hint(
          linop_a, linop_b),
      # In general, a solve of self-adjoint positive-definite block diagonal
      # matrices is not self-adjoint.
      is_self_adjoint=None,
      # In general, a solve of positive-definite block diagonal matrices is
      # not positive-definite.
      is_positive_definite=None,
      is_square=True)
214