• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""`LinearOperator` acting like a tridiagonal matrix."""
16
17from tensorflow.python.framework import ops
18from tensorflow.python.ops import array_ops
19from tensorflow.python.ops import check_ops
20from tensorflow.python.ops import control_flow_ops
21from tensorflow.python.ops import gen_array_ops
22from tensorflow.python.ops import manip_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.linalg import linalg_impl as linalg
25from tensorflow.python.ops.linalg import linear_operator
26from tensorflow.python.ops.linalg import linear_operator_util
27from tensorflow.python.util.tf_export import tf_export
28
__all__ = ['LinearOperatorTridiag',]

# Accepted values for the `diagonals_format` constructor argument; see the
# `LinearOperatorTridiag.__init__` docstring for the layout each one implies.
_COMPACT = 'compact'
_MATRIX = 'matrix'
_SEQUENCE = 'sequence'
_DIAGONAL_FORMATS = frozenset({_COMPACT, _MATRIX, _SEQUENCE})
35
36
@tf_export('linalg.LinearOperatorTridiag')
@linear_operator.make_composite_tensor
class LinearOperatorTridiag(linear_operator.LinearOperator):
  """`LinearOperator` acting like a [batch] square tridiagonal matrix.

  This operator acts like a [batch] square tridiagonal matrix `A` with shape
  `[B1,...,Bb, N, N]` for some `b >= 0`.  The first `b` indices index a
  batch member.  For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
  an `N x N` matrix.  This matrix `A` is not materialized, but for
  purposes of broadcasting this shape will be relevant.

  Example usage:

  Create a 3 x 3 tridiagonal linear operator.

  >>> superdiag = [3., 4., 5.]
  >>> diag = [1., -1., 2.]
  >>> subdiag = [6., 7., 8]
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...    [superdiag, diag, subdiag],
  ...    diagonals_format='sequence')
  >>> operator.to_dense()
  <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
  array([[ 1.,  3.,  0.],
         [ 7., -1.,  4.],
         [ 0.,  8.,  2.]], dtype=float32)>
  >>> operator.shape
  TensorShape([3, 3])

  Scalar Tensor output.

  >>> operator.log_abs_determinant()
  <tf.Tensor: shape=(), dtype=float32, numpy=4.3307333>

  Create a [2, 3] batch of 4 x 4 linear operators.

  >>> diagonals = tf.random.normal(shape=[2, 3, 3, 4])
  >>> operator = tf.linalg.LinearOperatorTridiag(
  ...   diagonals,
  ...   diagonals_format='compact')

  Create a shape [2, 1, 4, 2] vector.  Note that this shape is compatible
  since the batch dimensions, [2, 1], are broadcast to
  operator.batch_shape = [2, 3].

  >>> y = tf.random.normal(shape=[2, 1, 4, 2])
  >>> x = operator.solve(y)
  >>> x
  <tf.Tensor: shape=(2, 3, 4, 2), dtype=float32, numpy=...,
  dtype=float32)>

  #### Shape compatibility

  This operator acts on [batch] matrix with compatible shape.
  `x` is a batch matrix with compatible shape for `matmul` and `solve` if

  ```
  operator.shape = [B1,...,Bb] + [N, N],  with b >= 0
  x.shape =   [C1,...,Cc] + [N, R],
  and [C1,...,Cc] broadcasts with [B1,...,Bb].
  ```

  #### Performance

  Suppose `operator` is a `LinearOperatorTridiag` of shape `[N, N]`,
  and `x.shape = [N, R]`.  Then

  * `operator.matmul(x)` will take O(N * R) time.
  * `operator.solve(x)` will take O(N * R) time.

  If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
  `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               diagonals,
               diagonals_format=_COMPACT,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name='LinearOperatorTridiag'):
    r"""Initialize a `LinearOperatorTridiag`.

    Args:
      diagonals: `Tensor` or list of `Tensor`s depending on `diagonals_format`.

        If `diagonals_format=sequence`, this is a list of three `Tensor`'s each
        with shape `[B1, ..., Bb, N]`, `b >= 0, N >= 0`, representing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        If `diagonals_format=matrix` this is a `[B1, ... Bb, N, N]` shaped
        `Tensor` representing the full tridiagonal matrix.

        If `diagonals_format=compact` this is a `[B1, ... Bb, 3, N]` shaped
        `Tensor` with the second to last dimension indexing the
        superdiagonal, diagonal and subdiagonal in that order. Note the
        superdiagonal is padded with an element in the last position, and the
        subdiagonal is padded with an element in the front.

        In every case, these `Tensor`s are all floating dtype.
      diagonals_format: one of `matrix`, `sequence`, or `compact`. Default is
        `compact`.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.  If `diag.dtype` is real, this is auto-set to `True`.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.

    Raises:
      TypeError:  If `diag.dtype` is not an allowed type.
      ValueError:  If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
    """
    # Captured so the composite-tensor machinery can reconstruct the operator.
    parameters = dict(
        diagonals=diagonals,
        diagonals_format=diagonals_format,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    with ops.name_scope(name, values=[diagonals]):
      if diagonals_format not in _DIAGONAL_FORMATS:
        raise ValueError(
            f'Argument `diagonals_format` must be one of compact, matrix, or '
            f'sequence. Received : {diagonals_format}.')
      if diagonals_format == _SEQUENCE:
        # Sequence format: keep a list of three tensors
        # [superdiag, diag, subdiag]; dtype is taken from the first one.
        self._diagonals = [linear_operator_util.convert_nonref_to_tensor(
            d, name='diag_{}'.format(i)) for i, d in enumerate(diagonals)]
        dtype = self._diagonals[0].dtype
      else:
        # Matrix / compact format: a single tensor.
        self._diagonals = linear_operator_util.convert_nonref_to_tensor(
            diagonals, name='diagonals')
        dtype = self._diagonals.dtype
      self._diagonals_format = diagonals_format

      super(LinearOperatorTridiag, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  def _shape(self):
    """Static shape `[B1,...,Bb, N, N]`, derived from the diagonals."""
    if self.diagonals_format == _MATRIX:
      return self.diagonals.shape
    if self.diagonals_format == _COMPACT:
      # Remove the second to last dimension that contains the value 3.
      d_shape = self.diagonals.shape[:-2].concatenate(
          self.diagonals.shape[-1])
    else:
      # Sequence format: the batch shape is the broadcast of the three
      # diagonals' batch shapes; N is taken from the main diagonal.
      broadcast_shape = array_ops.broadcast_static_shape(
          self.diagonals[0].shape[:-1],
          self.diagonals[1].shape[:-1])
      broadcast_shape = array_ops.broadcast_static_shape(
          broadcast_shape,
          self.diagonals[2].shape[:-1])
      d_shape = broadcast_shape.concatenate(self.diagonals[1].shape[-1])
    # The operator is square, so repeat the last dimension.
    return d_shape.concatenate(d_shape[-1])

  def _shape_tensor(self, diagonals=None):
    """Dynamic shape `[B1,...,Bb, N, N]` as a 1-D int tensor.

    Args:
      diagonals: optional override for `self.diagonals`, used by `_solve` to
        compute the shape of the (possibly adjointed) diagonals.
    """
    diagonals = diagonals if diagonals is not None else self.diagonals
    if self.diagonals_format == _MATRIX:
      return array_ops.shape(diagonals)
    if self.diagonals_format == _COMPACT:
      # Slice out one of the 3 stacked diagonals to drop the `3` dimension.
      d_shape = array_ops.shape(diagonals[..., 0, :])
    else:
      # Sequence format: broadcast the three batch shapes dynamically.
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          array_ops.shape(self.diagonals[0])[:-1],
          array_ops.shape(self.diagonals[1])[:-1])
      broadcast_shape = array_ops.broadcast_dynamic_shape(
          broadcast_shape,
          array_ops.shape(self.diagonals[2])[:-1])
      d_shape = array_ops.concat(
          [broadcast_shape, [array_ops.shape(self.diagonals[1])[-1]]], axis=0)
    # Square operator: append N once more.
    return array_ops.concat([d_shape, [d_shape[-1]]], axis=-1)

  def _assert_self_adjoint(self):
    # Check the diagonal has non-zero imaginary, and the super and subdiagonals
    # are conjugate.

    asserts = []
    diag_message = (
        'This tridiagonal operator contained non-zero '
        'imaginary values on the diagonal.')
    off_diag_message = (
        'This tridiagonal operator has non-conjugate '
        'subdiagonal and superdiagonal.')

    if self.diagonals_format == _MATRIX:
      asserts += [check_ops.assert_equal(
          self.diagonals, linalg.adjoint(self.diagonals),
          message='Matrix was not equal to its adjoint.')]
    elif self.diagonals_format == _COMPACT:
      diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
      asserts += [linear_operator_util.assert_zero_imag_part(
          diagonals[..., 1, :], message=diag_message)]
      # Roll the subdiagonal so the shifted argument is at the end.
      subdiag = manip_ops.roll(diagonals[..., 2, :], shift=-1, axis=-1)
      # Compare all but the padding element against the superdiagonal.
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          diagonals[..., 0, :-1],
          message=off_diag_message)]
    else:
      # Sequence format: same checks on the [superdiag, diag, subdiag] list.
      asserts += [linear_operator_util.assert_zero_imag_part(
          self.diagonals[1], message=diag_message)]
      subdiag = manip_ops.roll(self.diagonals[2], shift=-1, axis=-1)
      asserts += [check_ops.assert_equal(
          math_ops.conj(subdiag[..., :-1]),
          self.diagonals[0][..., :-1],
          message=off_diag_message)]
    return control_flow_ops.group(asserts)

  def _construct_adjoint_diagonals(self, diagonals):
    # Constructs adjoint tridiagonal matrix from diagonals.
    if self.diagonals_format == _SEQUENCE:
      # Conjugate and reverse: [superdiag, diag, subdiag] becomes
      # [subdiag, diag, superdiag].
      diagonals = [math_ops.conj(d) for d in reversed(diagonals)]
      # The subdiag and the superdiag swap places, so we need to shift the
      # padding argument.
      diagonals[0] = manip_ops.roll(diagonals[0], shift=-1, axis=-1)
      diagonals[2] = manip_ops.roll(diagonals[2], shift=1, axis=-1)
      return diagonals
    elif self.diagonals_format == _MATRIX:
      # Full matrix: the adjoint is just the conjugate transpose.
      return linalg.adjoint(diagonals)
    else:
      # Compact format: conjugate, then swap super- and subdiagonal rows.
      diagonals = math_ops.conj(diagonals)
      superdiag, diag, subdiag = array_ops.unstack(
          diagonals, num=3, axis=-2)
      # The subdiag and the superdiag swap places, so we need
      # to shift all arguments.
      new_superdiag = manip_ops.roll(subdiag, shift=-1, axis=-1)
      new_subdiag = manip_ops.roll(superdiag, shift=1, axis=-1)
      return array_ops.stack([new_superdiag, diag, new_subdiag], axis=-2)

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Computes `A x` (or `A^H x` / `A x^H`) via `tridiagonal_matmul`."""
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)
    x = linalg.adjoint(x) if adjoint_arg else x
    return linalg.tridiagonal_matmul(
        diagonals, x,
        diagonals_format=self.diagonals_format)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solves `A x = rhs` (or the adjoint system) via `tridiagonal_solve`."""
    diagonals = self.diagonals
    if adjoint:
      diagonals = self._construct_adjoint_diagonals(diagonals)

    # TODO(b/144860784): Remove the broadcasting code below once
    # tridiagonal_solve broadcasts.

    # Explicitly broadcast both operands to the common batch shape, since
    # tridiagonal_solve does not broadcast itself.
    rhs_shape = array_ops.shape(rhs)
    k = self._shape_tensor(diagonals)[-1]
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        self._shape_tensor(diagonals)[:-2], rhs_shape[:-2])
    rhs = array_ops.broadcast_to(
        rhs, array_ops.concat(
            [broadcast_shape, rhs_shape[-2:]], axis=-1))
    if self.diagonals_format == _MATRIX:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [k, k]], axis=-1))
    elif self.diagonals_format == _COMPACT:
      diagonals = array_ops.broadcast_to(
          diagonals, array_ops.concat(
              [broadcast_shape, [3, k]], axis=-1))
    else:
      diagonals = [
          array_ops.broadcast_to(d, array_ops.concat(
              [broadcast_shape, [k]], axis=-1)) for d in diagonals]

    y = linalg.tridiagonal_solve(
        diagonals, rhs,
        diagonals_format=self.diagonals_format,
        transpose_rhs=adjoint_arg,
        conjugate_rhs=adjoint_arg)
    return y

  def _diag_part(self):
    """Returns the main diagonal, shape `[B1,...,Bb, N]`."""
    if self.diagonals_format == _MATRIX:
      return array_ops.matrix_diag_part(self.diagonals)
    elif self.diagonals_format == _SEQUENCE:
      diagonal = self.diagonals[1]
      # Broadcast to the full batch shape, which may be wider than this
      # diagonal's own batch shape.
      return array_ops.broadcast_to(
          diagonal, self.shape_tensor()[:-1])
    else:
      # Compact format: the diagonal is the middle row of the stack.
      return self.diagonals[..., 1, :]

  def _to_dense(self):
    """Materializes the operator as a full `[B1,...,Bb, N, N]` matrix."""
    if self.diagonals_format == _MATRIX:
      return self.diagonals

    if self.diagonals_format == _COMPACT:
      # k=(-1, 1) places the stacked rows on the sub-, main and superdiagonal;
      # LEFT_RIGHT matches the compact format's padding convention.
      return gen_array_ops.matrix_diag_v3(
          self.diagonals,
          k=(-1, 1),
          num_rows=-1,
          num_cols=-1,
          align='LEFT_RIGHT',
          padding_value=0.)

    # Sequence format: stack into compact layout first, then densify.
    diagonals = [
        ops.convert_to_tensor_v2_with_dispatch(d) for d in self.diagonals
    ]
    diagonals = array_ops.stack(diagonals, axis=-2)

    return gen_array_ops.matrix_diag_v3(
        diagonals,
        k=(-1, 1),
        num_rows=-1,
        num_cols=-1,
        align='LEFT_RIGHT',
        padding_value=0.)

  @property
  def diagonals(self):
    # `Tensor` (matrix/compact) or list of three `Tensor`s (sequence).
    return self._diagonals

  @property
  def diagonals_format(self):
    # One of 'compact', 'matrix' or 'sequence'.
    return self._diagonals_format

  @property
  def _composite_tensor_fields(self):
    # Fields serialized when this operator is treated as a composite tensor.
    return ('diagonals', 'diagonals_format')

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # Number of parameter dims corresponding to the operator's matrix dims.
    diagonal_event_ndims = 2
    if self.diagonals_format == _SEQUENCE:
      # For the diagonal and the super/sub diagonals.
      diagonal_event_ndims = [1, 1, 1]
    return {
        'diagonals': diagonal_event_ndims,
    }
397