• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a lower triangular matrix."""
16
17from tensorflow.python.framework import ops
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import math_ops
20from tensorflow.python.ops.linalg import linalg_impl as linalg
21from tensorflow.python.ops.linalg import linear_operator
22from tensorflow.python.ops.linalg import linear_operator_util
23from tensorflow.python.util.tf_export import tf_export
24
# Names exported by `from <this module> import *`.
__all__ = ["LinearOperatorLowerTriangular"]
28
29
@tf_export("linalg.LinearOperatorLowerTriangular")
@linear_operator.make_composite_tensor
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`. The upper triangle of the last two
  dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.]
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random.normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `R` size `N` triangular substitutions
    (forward substitution for `A`, back substitution for its adjoint), i.e.
    on the order of `N^2 * R` operations.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The strictly
        upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
        Must be `None` or `True`; the operator is always treated as square.
      name: A name for this `LinearOperator`.

    Raises:
      ValueError:  If `is_square` is `False`, or if the static rank of `tril`
        is known and is less than 2.
    """
    # Record the constructor arguments exactly as given (before `is_square`
    # is forced to True below) so the operator can be rebuilt from them.
    parameters = dict(
        tril=tril,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    is_square = True

    with ops.name_scope(name, values=[tril]):
      self._tril = linear_operator_util.convert_nonref_to_tensor(tril,
                                                                 name="tril")
      self._check_tril(self._tril)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=self._tril.dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def tril(self):
    """The lower triangular matrix defining this operator.

    This is the converted constructor argument as given; its strictly upper
    triangle is NOT zeroed out here.
    """
    return self._tril

  def _check_tril(self, tril):
    """Statically checks that `tril` has rank at least 2.

    Args:
      tril: The converted `tril` tensor passed to the constructor.

    Raises:
      ValueError: If the static rank of `tril` is known and is less than 2.
        No check is done when the rank is not known statically.
    """
    ndims = tril.shape.ndims
    if ndims is not None and ndims < 2:
      # Report the static shape rather than formatting the whole tensor: the
      # shape is what the check is about, and a tensor's repr can be large.
      raise ValueError(
          "Argument tril must have at least 2 dimensions.  Found shape: %s"
          % tril.shape)

  def _get_tril(self):
    """Gets the `tril` kwarg, with upper part zero-d out."""
    # matrix_band_part(x, -1, 0) keeps every subdiagonal band and zero
    # superdiagonal bands, i.e. the lower triangle including the diagonal.
    return array_ops.matrix_band_part(self._tril, -1, 0)

  def _get_diag(self):
    """Gets the diagonal part of `tril` kwarg."""
    # The diagonal is unaffected by masking the upper triangle, so the raw
    # tensor can be read directly.
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    return self._tril.shape

  def _shape_tensor(self):
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    # A triangular matrix is singular iff some diagonal entry is zero.
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._get_diag(),
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Use the masked matrix so ignored upper-triangle entries never leak in.
    return math_ops.matmul(
        self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    # det of a triangular matrix is the product of its diagonal entries.
    return math_ops.reduce_prod(self._get_diag(), axis=[-1])

  def _log_abs_determinant(self):
    # log|det A| = sum_i log|A_ii| for triangular A; summing logs avoids the
    # overflow/underflow of forming the product first.
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._get_diag())), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    # triangular_solve has no flag for adjointing the right-hand side, so
    # apply it explicitly before solving.
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linalg.triangular_solve(
        self._get_tril(), rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    return self._get_tril()

  def _eigvals(self):
    # Eigenvalues of a triangular matrix are its diagonal entries.
    return self._get_diag()

  @property
  def _composite_tensor_fields(self):
    return ("tril",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # `tril` contributes two trailing matrix dimensions; the rest are batch.
    return {"tril": 2}
226