• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Construct the Kronecker product of one or more `LinearOperators`."""
16
17from tensorflow.python.framework import common_shapes
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import errors
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import check_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.linalg import linalg_impl as linalg
28from tensorflow.python.ops.linalg import linear_operator
29from tensorflow.python.util.tf_export import tf_export
30
31__all__ = ["LinearOperatorKronecker"]
32
33
def _prefer_static_shape(x):
  """Return the shape of `x`, statically when possible.

  Args:
    x: Object exposing a `shape` attribute with `is_fully_defined()`
      (presumably a `Tensor`).

  Returns:
    `x.shape` itself when it is fully defined, otherwise the result of
    `array_ops.shape(x)`.
  """
  static_shape = x.shape
  return (static_shape if static_shape.is_fully_defined()
          else array_ops.shape(x))
38
39
def _prefer_static_concat_shape(first_shape, second_shape_int_list):
  """Append a list of integer dimensions to a shape, statically if possible.

  Args:
    first_shape: `TensorShape` or `Tensor` instance. If a `TensorShape`,
      `first_shape.is_fully_defined()` must return `True`.
    second_shape_int_list: `list` of scalar integer `Tensor`s.

  Returns:
    The concatenation of `first_shape` and `second_shape_int_list`. This is
    produced by `first_shape.concatenate` when `first_shape` is a
    `TensorShape` and every trailing entry has a known constant value;
    otherwise it is produced by `array_ops.concat`.
  """
  trailing_static = [
      tensor_util.constant_value(dim) for dim in second_shape_int_list]
  head_is_static = isinstance(first_shape, tensor_shape.TensorShape)

  # Any dynamic piece forces the whole concatenation to be dynamic.
  if not head_is_static or any(dim is None for dim in trailing_static):
    return array_ops.concat([first_shape, second_shape_int_list], axis=0)

  return first_shape.concatenate(trailing_static)
58
59
@tf_export("linalg.LinearOperatorKronecker")
@linear_operator.make_composite_tensor
class LinearOperatorKronecker(linear_operator.LinearOperator):
  """Kronecker product between one or more `LinearOperators`.

  This operator composes one or more linear operators `[op1,...,opJ]`,
  building a new `LinearOperator` representing the Kronecker product:
  `op1 x op2 x .. opJ` (we omit parentheses as the Kronecker product is
  associative).

  If `opj` has shape `batch_shape_j + [M_j, N_j]`, then the composed operator
  will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`,
  where the product is over all operators.

  ```python
  # Create a 4 x 4 linear operator composed of two 2 x 2 operators.
  operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
  operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]])
  operator = LinearOperatorKronecker([operator_1, operator_2])

  operator.to_dense()
  ==> [[1., 0., 2., 0.],
       [2., 1., 4., 2.],
       [3., 0., 4., 0.],
       [6., 3., 8., 4.]]

  operator.shape
  ==> [4, 4]

  operator.log_abs_determinant()
  ==> scalar Tensor

  x = ... Shape [4, 2] Tensor
  operator.matmul(x)
  ==> Shape [4, 2] Tensor

  # Create a [2, 3] batch of 4 x 5 linear operators.
  matrix_45 = tf.random.normal(shape=[2, 3, 4, 5])
  operator_45 = LinearOperatorFullMatrix(matrix_45)

  # Create a [2, 3] batch of 5 x 6 linear operators.
  matrix_56 = tf.random.normal(shape=[2, 3, 5, 6])
  operator_56 = LinearOperatorFullMatrix(matrix_56)

  # Compose to create a [2, 3] batch of 20 x 30 operators.
  operator_large = LinearOperatorKronecker([operator_45, operator_56])

  # Create a shape [2, 3, 30, 2] batch of matrices to multiply.
  x = tf.random.normal(shape=[2, 3, 30, 2])
  operator_large.matmul(x)
  ==> Shape [2, 3, 20, 2] Tensor
  ```

  #### Performance

  The performance of `LinearOperatorKronecker` on any operation is equal to
  the sum of the individual operators' operations.

  #### Matrix property hints

  This `LinearOperator` is initialized with boolean flags of the form `is_X`,
  for `X = non_singular, self_adjoint, positive_definite, square`.
  These have the following meaning:

  * If `is_X == True`, callers should expect the operator to have the
    property `X`.  This is a promise that should be fulfilled, but is *not* a
    runtime assert.  For example, finite floating point precision may result
    in these promises being violated.
  * If `is_X == False`, callers should expect the operator to not have `X`.
  * If `is_X == None` (the default), callers should have no expectation either
    way.
  """

  def __init__(self,
               operators,
               is_non_singular=None,
               is_self_adjoint=None,
               is_positive_definite=None,
               is_square=None,
               name=None):
    r"""Initialize a `LinearOperatorKronecker`.

    `LinearOperatorKronecker` is initialized with a list of operators
    `[op_1,...,op_J]`.

    Args:
      operators:  Iterable of `LinearOperator` objects, each with
        the same `dtype` and composable shape, representing the Kronecker
        factors.
      is_non_singular:  Expect that this operator is non-singular.
      is_self_adjoint:  Expect that this operator is equal to its hermitian
        transpose.
      is_positive_definite:  Expect that this operator is positive definite,
        meaning the quadratic form `x^H A x` has positive real part for all
        nonzero `x`.  Note that we do not require the operator to be
        self-adjoint to be positive-definite.  See:
        https://en.wikipedia.org/wiki/Positive-definite_matrix\
            #Extension_for_non_symmetric_matrices
      is_square:  Expect that this operator acts like square [batch] matrices.
      name: A name for this `LinearOperator`.  Default is the individual
        operators names joined with `_x_`.

    Raises:
      TypeError:  If all operators do not have the same `dtype`.
      ValueError:  If `operators` is empty, or if a hint is explicitly
        `False` while every factor has the corresponding hint set to `True`.
    """
    # Captured before any hint is auto-set below, so `parameters` reflects
    # exactly what the caller passed in.
    parameters = dict(
        operators=operators,
        is_non_singular=is_non_singular,
        is_self_adjoint=is_self_adjoint,
        is_positive_definite=is_positive_definite,
        is_square=is_square,
        name=name
    )

    # Validate operators.
    check_ops.assert_proper_iterable(operators)
    operators = list(operators)
    if not operators:
      raise ValueError(f"Argument `operators` must be a list of >=1 operators. "
                       f"Received: {operators}.")
    self._operators = operators

    # Validate dtype: every factor must match the first one.
    dtype = operators[0].dtype
    if any(operator.dtype != dtype for operator in operators):
      name_type = [str((o.name, o.dtype)) for o in operators]
      raise TypeError(
          f"Expected every operation in argument `operators` to have the "
          f"same dtype. Received {name_type}.")

    # Auto-set and check hints.
    # A Kronecker product is invertible, if and only if all factors are
    # invertible.
    if all(operator.is_non_singular for operator in operators):
      if is_non_singular is False:
        raise ValueError(
            f"The Kronecker product of non-singular operators is always "
            f"non-singular. Expected argument `is_non_singular` to be True. "
            f"Received: {is_non_singular}.")
      is_non_singular = True

    if all(operator.is_self_adjoint for operator in operators):
      if is_self_adjoint is False:
        raise ValueError(
            f"The Kronecker product of self-adjoint operators is always "
            f"self-adjoint. Expected argument `is_self_adjoint` to be True. "
            f"Received: {is_self_adjoint}.")
      is_self_adjoint = True

    # The eigenvalues of a Kronecker product are equal to the products of eigen
    # values of the corresponding factors.
    if all(operator.is_positive_definite for operator in operators):
      if is_positive_definite is False:
        raise ValueError(
            f"The Kronecker product of positive-definite operators is always "
            f"positive-definite. Expected argument `is_positive_definite` to "
            f"be True. Received: {is_positive_definite}.")
      is_positive_definite = True

    if name is None:
      name = "_x_".join(operator.name for operator in operators)
    with ops.name_scope(name):
      super(LinearOperatorKronecker, self).__init__(
          dtype=dtype,
          is_non_singular=is_non_singular,
          is_self_adjoint=is_self_adjoint,
          is_positive_definite=is_positive_definite,
          is_square=is_square,
          parameters=parameters,
          name=name)

  @property
  def operators(self):
    """The list of Kronecker factors this operator was built from."""
    return self._operators

  def _shape(self):
    """Static shape: broadcast batch shape + [prod range, prod domain]."""
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension

    range_dimension = self.operators[0].range_dimension
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension

    matrix_shape = tensor_shape.TensorShape([
        range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)

  def _shape_tensor(self):
    """Dynamic counterpart of `_shape`, computed from the `*_tensor` APIs."""
    domain_dimension = self.operators[0].domain_dimension_tensor()
    for operator in self.operators[1:]:
      domain_dimension = domain_dimension * operator.domain_dimension_tensor()

    range_dimension = self.operators[0].range_dimension_tensor()
    for operator in self.operators[1:]:
      range_dimension = range_dimension * operator.range_dimension_tensor()

    matrix_shape = [range_dimension, domain_dimension]

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape_tensor()
    for operator in self.operators[1:]:
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

    return array_ops.concat((batch_shape, matrix_shape), 0)

  def _solve_matmul_internal(
      self,
      x,
      solve_matmul_fn,
      adjoint=False,
      adjoint_arg=False):
    """Shared implementation of `_matmul` and `_solve`.

    Args:
      x: The argument `Tensor` to multiply or solve against.
      solve_matmul_fn: Callable `(operator, x, adjoint, adjoint_arg)` applied
        once per Kronecker factor (either a matmul or a solve).
      adjoint: Whether to apply the adjoint of each factor.
      adjoint_arg: Whether `x` is supplied as its hermitian transpose.

    Returns:
      A `Tensor` whose trailing matrix dimensions are `[row_dim, col_dim]`,
      where `row_dim` is this operator's domain dimension when `adjoint` is
      set (range dimension otherwise) and `col_dim` is the column count of
      the (un-adjointed) argument.
    """
    # We heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T
    # where vec stacks all the columns of the matrix under each other.
    # In our case, we use a variant of the lemma that is row-major
    # friendly: (A x B) * vec' X = vec' AXB^T
    # Where vec' reshapes a matrix into a vector. We can repeatedly apply this
    # for a collection of kronecker products.
    # Given that (A x B)^-1 = A^-1 x B^-1 and (A x B)^T = A^T x B^T, we can
    # use the above to compute multiplications, solves with any composition of
    # transposes.
    output = x

    # The loop below works on the plain transpose of the argument. When the
    # caller already passed the hermitian transpose, only the conjugation has
    # to be undone.
    if adjoint_arg:
      if self.dtype.is_complex:
        output = math_ops.conj(output)
    else:
      output = linalg.transpose(output)

    for o in reversed(self.operators):
      if adjoint:
        operator_dimension = o.range_dimension_tensor()
      else:
        operator_dimension = o.domain_dimension_tensor()
      output_shape = _prefer_static_shape(output)

      # Compute the reshape statically only when the operator dimension and
      # both trailing dimensions of `output` are known. In every other case
      # fall back to the dynamic computation, so that `dim` is always bound
      # for this iteration (previously it could be left unset or stale).
      static_operator_dimension = tensor_util.constant_value(
          operator_dimension)
      static_rows = tensor_shape.dimension_value(output.shape[-2])
      static_cols = tensor_shape.dimension_value(output.shape[-1])
      if (static_operator_dimension is not None and
          static_rows is not None and
          static_cols is not None):
        operator_dimension = static_operator_dimension
        dim = int(static_rows * static_cols // operator_dimension)
      else:
        dim = math_ops.cast(
            output_shape[-2] * output_shape[-1] // operator_dimension,
            dtype=dtypes.int32)

      output_shape = _prefer_static_concat_shape(
          output_shape[:-2], [dim, operator_dimension])
      output = array_ops.reshape(output, shape=output_shape)

      # Conjugate because we are trying to compute A @ B^T, but
      # `LinearOperator` only supports `adjoint_arg`.
      if self.dtype.is_complex:
        output = math_ops.conj(output)

      output = solve_matmul_fn(
          o, output, adjoint=adjoint, adjoint_arg=True)

    if adjoint_arg:
      col_dim = _prefer_static_shape(x)[-2]
    else:
      col_dim = _prefer_static_shape(x)[-1]

    if adjoint:
      row_dim = self.domain_dimension_tensor()
    else:
      row_dim = self.range_dimension_tensor()

    matrix_shape = [row_dim, col_dim]

    output = array_ops.reshape(
        output,
        _prefer_static_concat_shape(
            _prefer_static_shape(output)[:-2], matrix_shape))

    # Attach as much static shape information as is available.
    if x.shape.is_fully_defined():
      if adjoint_arg:
        column_dim = x.shape[-2]
      else:
        column_dim = x.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          x.shape[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    """Multiply by applying each factor's `matmul` in turn."""
    def matmul_fn(o, x, adjoint, adjoint_arg):
      return o.matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=x,
        solve_matmul_fn=matmul_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve by applying each factor's `solve` in turn."""
    def solve_fn(o, rhs, adjoint, adjoint_arg):
      return o.solve(rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
    return self._solve_matmul_internal(
        x=rhs,
        solve_matmul_fn=solve_fn,
        adjoint=adjoint,
        adjoint_arg=adjoint_arg)

  def _determinant(self):
    """Product over factors of `det(X_i) ** (T / dim(X_i))`."""
    # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m
    # matrix, and X2 is an n x n matrix. We can iteratively apply this property
    # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the
    # domain dimension of all operators, then we have:
    # |X1 x X2 x X3 ...| =
    #    |X1| ** (T / m) * |X2 x X3 ... | ** m =
    #    |X1| ** (T / m) * |X2| ** (m * (T / m) / n) *  ... =
    #    |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n)
    #    And by doing induction we have product(|X_i| ** (T / dim(X_i))).
    total = self.domain_dimension_tensor()
    determinant = 1.
    for operator in self.operators:
      determinant = determinant * operator.determinant() ** math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return determinant

  def _log_abs_determinant(self):
    """Sum over factors of `(T / dim(X_i)) * log|det(X_i)|`."""
    # This will be sum((total / dim(x_i)) * log |X_i|)
    total = self.domain_dimension_tensor()
    log_abs_det = 0.
    for operator in self.operators:
      log_abs_det += operator.log_abs_determinant() * math_ops.cast(
          total / operator.domain_dimension_tensor(),
          dtype=operator.dtype)
    return log_abs_det

  def _trace(self):
    """Product of the traces of all factors."""
    # tr(A x B) = tr(A) * tr(B)
    trace = 1.
    for operator in self.operators:
      trace = trace * operator.trace()
    return trace

  def _diag_part(self):
    """Diagonal, built as the flattened outer product of factor diagonals."""
    diag_part = self.operators[0].diag_part()
    for operator in self.operators[1:]:
      # Outer product along the last axis: [..., D1, 1] * [..., 1, D2].
      diag_part = diag_part[..., :, array_ops.newaxis]
      op_diag_part = operator.diag_part()[..., array_ops.newaxis, :]
      diag_part = diag_part * op_diag_part
      # Flatten the two trailing axes back into a single diagonal axis.
      diag_part = array_ops.reshape(
          diag_part,
          shape=array_ops.concat(
              [array_ops.shape(diag_part)[:-2], [-1]], axis=0))
    # The static diagonal length is the smaller of the two matrix dimensions.
    if self.range_dimension > self.domain_dimension:
      diag_dimension = self.domain_dimension
    else:
      diag_dimension = self.range_dimension
    diag_part.set_shape(
        self.batch_shape.concatenate(diag_dimension))
    return diag_part

  def _to_dense(self):
    """Dense [batch] matrix, accumulated one Kronecker factor at a time."""
    product = self.operators[0].to_dense()
    for operator in self.operators[1:]:
      # Product has shape [B, R1, 1, C1, 1].
      product = product[
          ..., :, array_ops.newaxis, :, array_ops.newaxis]
      # Operator has shape [B, 1, R2, 1, C2].
      op_to_mul = operator.to_dense()[
          ..., array_ops.newaxis, :, array_ops.newaxis, :]
      # This is now [B, R1, R2, C1, C2].
      product = product * op_to_mul
      # Now merge together dimensions to get [B, R1 * R2, C1 * C2].
      product_shape = _prefer_static_shape(product)
      shape = _prefer_static_concat_shape(
          product_shape[:-4],
          [product_shape[-4] * product_shape[-3],
           product_shape[-2] * product_shape[-1]])

      product = array_ops.reshape(product, shape=shape)
    product.set_shape(self.shape)
    return product

  def _eigvals(self):
    """Eigenvalues, as the Kronecker product of the factors' eigenvalues."""
    # This will be the kronecker product of all the eigenvalues.
    # Note: It doesn't matter which kronecker product it is, since every
    # kronecker product of the same matrices are similar.
    eigvals = [operator.eigvals() for operator in self.operators]
    # Now compute the kronecker product
    product = eigvals[0]
    for eigval in eigvals[1:]:
      # Product has shape [B, R1, 1].
      product = product[..., array_ops.newaxis]
      # Eigval has shape [B, 1, R2]. Produces shape [B, R1, R2].
      product = product * eigval[..., array_ops.newaxis, :]
      # Reshape to [B, R1 * R2]
      product = array_ops.reshape(
          product,
          shape=array_ops.concat([array_ops.shape(product)[:-2], [-1]], axis=0))
    product.set_shape(self.shape[:-1])
    return product

  def _assert_non_singular(self):
    """Group the non-singularity asserts of every factor.

    Raises:
      InvalidArgumentError: If any factor is not hinted as square.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_non_singular() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "invertible. Expected hint `is_square` to be True for every operator "
          "in argument `operators`.")

  def _assert_self_adjoint(self):
    """Group the self-adjointness asserts of every factor.

    Raises:
      InvalidArgumentError: If any factor is not hinted as square.
    """
    if all(operator.is_square for operator in self.operators):
      asserts = [operator.assert_self_adjoint() for operator in self.operators]
      return control_flow_ops.group(asserts)
    else:
      # The message previously said "invertible", copied from
      # `_assert_non_singular`; this check is about self-adjointness.
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="All Kronecker factors must be square for the product to be "
          "self-adjoint. Expected hint `is_square` to be True for every "
          "operator in argument `operators`.")

  @property
  def _composite_tensor_fields(self):
    # NOTE(review): presumably the constructor arguments that
    # `make_composite_tensor` treats as tensor-bearing fields -- confirm
    # against `linear_operator.make_composite_tensor`.
    return ("operators",)

  @property
  def _experimental_parameter_ndims_to_matrix_ndims(self):
    # One zero entry per factor under the "operators" key.
    return {"operators": [0] * len(self.operators)}
511