• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Add one or more `LinearOperators` efficiently."""
16
17import abc
18
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops.linalg import linear_operator
24from tensorflow.python.ops.linalg import linear_operator_diag
25from tensorflow.python.ops.linalg import linear_operator_full_matrix
26from tensorflow.python.ops.linalg import linear_operator_identity
27from tensorflow.python.ops.linalg import linear_operator_lower_triangular
28
29__all__ = []
30
31
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    List of `LinearOperator` instances whose sum equals the sum of `operators`.
      Length of the list, classes, and order of addition may change as new
      (and better) addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Reverse so that pop() in the loop below consumes the user's operators
  # front-to-back, producing the natural "Add/A1__A2/" default naming order.
  operators = list(reversed(operators))
  if not operators:
    raise ValueError(
        f"Argument `operators` must contain at least one operator. "
        f"Received: {operators}.")
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        f"Argument `operators` must contain only LinearOperator instances. "
        f"Received: {operators}.")
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  with ops.name_scope(name or "add_operators"):

    # Additions done in one of the tiers.  Try tier 0, 1,...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          # No partner found at this tier; promote to the next tier.
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier
129
130
131def _pop_a_match_at_tier(op1, operator_list, tier):
132  # Search from the back of list to the front in order to create nice default
133  # order of operations.
134  for i in range(1, len(operator_list) + 1):
135    op2 = operator_list[-i]
136    for adder in tier:
137      if adder.can_add(op1, op2):
138        return operator_list.pop(-i), adder
139  return None, None
140
141
def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.  Apply this inference
  # only when the caller did not explicitly override is_non_singular (i.e. the
  # non-singular hint itself is None) -- checking any other hint here would
  # silently drop the inference.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)
181
182
def _static_check_for_same_dimensions(operators):
  """ValueError if operators determined to have different dimensions.

  Args:
    operators: List of `LinearOperator` objects.

  Raises:
    ValueError: If two operators have statically-known but different domain
      (or range) dimensions.  Operators whose dimension is statically unknown
      (None) are skipped, since they cannot be checked here.
  """
  if len(operators) < 2:
    return

  # (name, dimension) pairs for operators with statically-known domain dims.
  domain_dimensions = [
      (op.name, tensor_shape.dimension_value(op.domain_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.domain_dimension) is not None]
  if len(set(value for name, value in domain_dimensions)) > 1:
    raise ValueError(f"All `operators` must have the same `domain_dimension`. "
                     f"Received: {domain_dimensions}.")

  # (name, dimension) pairs for operators with statically-known range dims.
  range_dimensions = [
      (op.name, tensor_shape.dimension_value(op.range_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.range_dimension) is not None]
  # Message formatted consistently with the domain_dimension error above.
  if len(set(value for name, value in range_dimensions)) > 1:
    raise ValueError(f"All `operators` must have the same `range_dimension`. "
                     f"Received: {range_dimensions}.")
203
204
def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # Folding all static batch shapes together raises (inside
  # broadcast_static_shape) if any pair cannot be broadcast.
  accumulated = operators[0].batch_shape
  for operator in operators[1:]:
    accumulated = array_ops.broadcast_static_shape(
        accumulated, operator.batch_shape)
214
215
216class _Hints:
217  """Holds 'is_X' flags that every LinearOperator is initialized with."""
218
219  def __init__(self,
220               is_non_singular=None,
221               is_positive_definite=None,
222               is_self_adjoint=None):
223    self.is_non_singular = is_non_singular
224    self.is_positive_definite = is_positive_definite
225    self.is_self_adjoint = is_self_adjoint
226
227
228################################################################################
229# Classes to add two linear operators.
230################################################################################
231
232
class _Adder(metaclass=abc.ABCMeta):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no attention
  as to whether another `Adder` could have done the addition more efficiently.
  """

  @property
  def name(self):
    """Name of this `Adder` (its class name)."""
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they have
    # the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    merged_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    # Drop a single leading underscore so private Adder class names produce
    # public-looking scope names.
    scope = self.name
    if scope.startswith("_"):
      scope = scope[1:]
    with ops.name_scope(scope):
      return self._add(op1, op2, operator_name, merged_hints)
279
280
class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
  is closed under addition.  This `Adder` respects that, and returns an Identity
  family member.
  """

  def can_add(self, op1, op2):
    # Addable only when both operands belong to the Identity family.
    return _IDENTITY_FAMILY.issuperset({_type(op1), _type(op2)})

  def _add(self, op1, op2, operator_name, hints):
    # Will build a LinearOperatorScaledIdentity.  A plain Identity operand
    # contributes a multiplier of ones over its batch shape.
    def _multiplier(op):
      if _type(op) == _SCALED_IDENTITY:
        return op.multiplier
      return array_ops.ones(op.batch_shape_tensor(), dtype=op.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=_multiplier(op1) + _multiplier(op2),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
312
313
class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    # Both operands must be diag-like (Diag, Identity, or ScaledIdentity).
    return {_type(op1), _type(op2)}.issubset(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    # Sum of diagonals is the diagonal of the sum.
    summed_diag = op1.diag_part() + op2.diag_part()
    return linear_operator_diag.LinearOperatorDiag(
        diag=summed_diag,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
328
329
class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    # Operands may be diag-like or lower-triangular.
    allowed = _DIAG_LIKE.union({_TRIL})
    return {_type(op1), _type(op2)}.issubset(allowed)

  def _add(self, op1, op2, operator_name, hints):
    # Prefer calling add_to_tensor on the operand with an efficient
    # implementation; densify the other one.
    op1_is_efficient = _type(op1) in _EFFICIENT_ADD_TO_TENSOR
    efficient = op1 if op1_is_efficient else op2
    other = op2 if op1_is_efficient else op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=efficient.add_to_tensor(other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
349
350
class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    # Fallback adder: any pair of LinearOperators can be added densely.
    return all(
        isinstance(op, linear_operator.LinearOperator) for op in (op1, op2))

  def _add(self, op1, op2, operator_name, hints):
    # Prefer calling add_to_tensor on the operand with an efficient
    # implementation; densify the other one.
    op1_is_efficient = _type(op1) in _EFFICIENT_ADD_TO_TENSOR
    efficient = op1 if op1_is_efficient else op2
    other = op2 if op1_is_efficient else op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=efficient.add_to_tensor(other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
369
370
371################################################################################
372# Constants designating types of LinearOperators
373################################################################################
374
# Type name constants for LinearOperator classes.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
# Operators whose dense form is diagonal (includes the identity family).
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
# The identity family is closed under addition.
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE

# Supported LinearOperator classes.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]
396
397
def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  # Ordered (class, tag) pairs; the first isinstance match wins, preserving
  # the original check order.
  for cls, tag in (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY)):
    if isinstance(operator, cls):
      return tag
  raise TypeError(f"Expected operator to be one of [LinearOperatorDiag, "
                  f"LinearOperatorLowerTriangular, LinearOperatorFullMatrix, "
                  f"LinearOperatorIdentity, LinearOperatorScaledIdentity]. "
                  f"Received: {operator}")
416
417
418################################################################################
419# Addition tiers:
420# We attempt to use Adders in tier K before K+1.
421#
422# Organize tiers to
423#   (i) reduce O(..) complexity of forming final operator, and
424#   (ii) produce the "most efficient" final operator.
425# Dev notes:
426#  * Results of addition at tier K will be added at tier K or higher.
427#  * Tiers may change, and we warn the user that it may change.
428################################################################################
429
# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
# Tiers are ordered cheapest/most-structured result first.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]
438