• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a lower triangular matrix."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops.linalg import linalg_impl as linalg
26from tensorflow.python.ops.linalg import linear_operator
27from tensorflow.python.ops.linalg import linear_operator_util
28from tensorflow.python.util.tf_export import tf_export
29
# Public names exported by this module via `from ... import *`.
__all__ = ["LinearOperatorLowerTriangular"]
33
34
@tf_export("linalg.LinearOperatorLowerTriangular")
class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square lower triangular matrix.

  This operator acts like a [batch] lower triangular matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
  an `N x N` matrix.

  `LinearOperatorLowerTriangular` is initialized with a `Tensor` having
  dimensions `[B1,...,Bb, N, N]`. The strictly upper triangle of the last two
  dimensions is ignored.

  ```python
  # Create a 2 x 2 lower-triangular linear operator.
  tril = [[1., 2.], [3., 4.]]
  operator = LinearOperatorLowerTriangular(tril)

  # The upper triangle is ignored.
  operator.to_dense()
  ==> [[1., 0.],
       [3., 4.]]

  operator.shape
  ==> [2, 2]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [2, 4] Tensor
  operator.matmul(x)
  ==> Shape [2, 4] Tensor

  # Create a [2, 3] batch of 4 x 4 linear operators.
  tril = tf.random_normal(shape=[2, 3, 4, 4])
  operator = LinearOperatorLowerTriangular(tril)
  ```

  #### Shape compatibility

  This operator acts on a [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =        [B1,...,Bb] + [N, R],  with R >= 0.
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorLowerTriangular` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` involves `N^2 * R` multiplications.
  * `operator.solve(x)` involves `R` size `N` forward-substitutions, i.e. on
    the order of `N^2 * R` multiplications.
  * `operator.determinant()` involves a size `N` `reduce_prod`.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               tril,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name="LinearOperatorLowerTriangular"):
    r"""Initialize a `LinearOperatorLowerTriangular`.

    Args:
      tril:  Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
        The lower triangular part of `tril` defines this operator.  The strictly
        upper triangle is ignored.
      is_non_singular:  Expect that this operator is non-singular.
        This operator is non-singular if and only if its diagonal elements are
        all non-zero.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  This operator is self-adjoint only if it is diagonal with
        real-valued diagonal entries.  In this case it is advised to use
        `LinearOperatorDiag`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `tril.dtype` is not one of float16, float32, float64,
        complex64 or complex128.
      ValueError:  If `is_square` is `False`, or if `tril` is statically known
        to have fewer than 2 dimensions.
    """
    # Only square operators are supported; an unspecified (`None`) hint is
    # promoted to `True`.
    if is_square is False:
      raise ValueError(
          "Only square lower triangular operators supported at this time.")
    is_square = True

    with ops.name_scope(name, values=[tril]):
      tril = ops.convert_to_tensor(tril, name="tril")
      self._check_tril(tril)
      # Build the operator from the converted and validated tensor (not the
      # raw argument).  `matrix_band_part(x, -1, 0)` keeps every sub-diagonal
      # and the main diagonal, zeroing the strictly upper triangle.
      self._tril = array_ops.matrix_band_part(tril, -1, 0)
      self._diag = array_ops.matrix_diag_part(self._tril)

      super(LinearOperatorLowerTriangular, self).__init__(
          dtype=self._tril.dtype,
          graph_parents=[self._tril],
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          name=name)

  def _check_tril(self, tril):
    """Static check of the `tril` argument.

    Args:
      tril:  A `Tensor` (already converted) to validate.

    Raises:
      TypeError:  If `tril.dtype` is not a supported float or complex dtype.
      ValueError:  If the static rank of `tril` is known and is less than 2.
    """
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]
    dtype = tril.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument tril must have dtype in %s.  Found: %s"
          % (allowed_dtypes, dtype))

    # The rank may be unknown at graph-construction time; only reject it when
    # it is statically known to be too small.
    ndims = tril.get_shape().ndims
    if ndims is not None and ndims < 2:
      raise ValueError(
          "Argument tril must have at least 2 dimensions.  Found: %s"
          % tril)

  def _shape(self):
    """Static shape of the operator, taken from the stored lower triangle."""
    return self._tril.get_shape()

  def _shape_tensor(self):
    """Dynamic shape of the operator as a 1-D `Tensor`."""
    return array_ops.shape(self._tril)

  def _assert_non_singular(self):
    """Returns an op asserting that no diagonal entry has modulus zero."""
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        message="Singular operator:  Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Computes `op(A) @ op(x)` with batch broadcasting."""
    return linear_operator_util.matmul_with_broadcast(
        self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    """Determinant of a triangular matrix: the product of its diagonal."""
    return math_ops.reduce_prod(self._diag, axis=[-1])

  def _log_abs_determinant(self):
    """Sum of `log|d_i|` over the diagonal entries `d_i`."""
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `op(A) X = op(rhs)` by triangular substitution."""
    # The solve helper is only given an `adjoint` flag for the operator, so
    # the right-hand side is adjointed explicitly when requested.
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        self._tril, rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    """Dense form: the stored (already band-limited) lower triangle."""
    return self._tril

  def _add_to_tensor(self, x):
    """Returns `A + x` using the dense lower triangle."""
    return self._tril + x
214