# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add one or more `LinearOperators` efficiently."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_lower_triangular

__all__ = []
37
def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    List of `LinearOperator` objects (the `Bk` above).  Class and order of
      addition may change as new (and better) addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
    TypeError:  If any member of `operators` is not a `LinearOperator`.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Reversed so that pop() (which removes from the END of a list) visits the
  # operators in their original order — presumably to produce the
  # "Add/A1__A2/" naming shown in the docstring example; confirm before
  # changing.
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        "Argument 'operators' must contain at least one operator.  "
        "Found: %s" % operators)
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        "Argument 'operators' must contain only LinearOperator instances.  "
        "Found: %s" % operators)
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  # Collect graph parents from every operator so the name scope below can
  # reference all of their inputs.
  graph_parents = []
  for operator in operators:
    graph_parents.extend(operator.graph_parents)

  with ops.name_scope(name or "add_operators", values=graph_parents):

    # Additions done in one of the tiers.  Try tier 0, 1,...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          # No partner for op1 at this tier; it gets another chance at the
          # next (more general, typically less efficient) tier.
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier
139
140
141def _pop_a_match_at_tier(op1, operator_list, tier):
142  # Search from the back of list to the front in order to create nice default
143  # order of operations.
144  for i in range(1, len(operator_list) + 1):
145    op2 = operator_list[-i]
146    for adder in tier:
147      if adder.can_add(op1, op2):
148        return operator_list.pop(-i), adder
149  return None, None
150
151
def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()

  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  # Bug fix: the guard previously tested `hints.is_positive_definite is None`,
  # which meant an *explicit* `is_non_singular` hint was silently clobbered
  # whenever positive-definiteness was inferred.  The hint being overridden
  # here is `is_non_singular`, so that is the one that must be unset.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)
191
192
def _static_check_for_same_dimensions(operators):
  """ValueError if operators determined to have different dimensions."""
  if len(operators) < 2:
    return

  def _known_dims(dim_attr):
    # Collect (operator name, static dimension) pairs, skipping operators
    # whose dimension is not statically known.
    pairs = []
    for op in operators:
      value = tensor_shape.dimension_value(getattr(op, dim_attr))
      if value is not None:
        pairs.append((op.name, value))
    return pairs

  domain_dimensions = _known_dims("domain_dimension")
  if len({value for _, value in domain_dimensions}) > 1:
    raise ValueError("Operators must have the same domain dimension. Found: %s"
                     % domain_dimensions)

  range_dimensions = _known_dims("range_dimension")
  if len({value for _, value in range_dimensions}) > 1:
    raise ValueError("Operators must have the same range dimension. Found: %s" %
                     range_dimensions)
213
214
def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # Fold the static batch shapes together pairwise; the broadcast helper
  # fails if any pair cannot be broadcast.
  operator_iter = iter(operators)
  accumulated_shape = next(operator_iter).batch_shape
  for op in operator_iter:
    accumulated_shape = array_ops.broadcast_static_shape(
        accumulated_shape, op.batch_shape)
224
225
226class _Hints(object):
227  """Holds 'is_X' flags that every LinearOperator is initialized with."""
228
229  def __init__(self,
230               is_non_singular=None,
231               is_positive_definite=None,
232               is_self_adjoint=None):
233    self.is_non_singular = is_non_singular
234    self.is_positive_definite = is_positive_definite
235    self.is_self_adjoint = is_self_adjoint
236
237
################################################################################
# Classes to add two linear operators.
################################################################################
241
242
@six.add_metaclass(abc.ABCMeta)
class _Adder(object):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no attention
  as to whether another `Adder` could have done the addition more efficiently.
  """

  @property
  def name(self):
    # E.g. an instance of _AddAndReturnDiag has name "_AddAndReturnDiag".
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they have
    # the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    updated_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    values = op1.graph_parents + op2.graph_parents
    # Strip one leading underscore so private adder class names still make
    # presentable scope names.
    scope_name = self.name[1:] if self.name.startswith("_") else self.name
    with ops.name_scope(scope_name, values=values):
      return self._add(op1, op2, operator_name, updated_hints)
291
292
class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
  is closed under addition.  This `Adder` respects that, and returns an Identity
  """

  def can_add(self, op1, op2):
    # Both operands must belong to the identity family.
    return {_type(op1), _type(op2)}.issubset(_IDENTITY_FAMILY)

  def _add(self, op1, op2, operator_name, hints):
    # Will build a LinearOperatorScaledIdentity.  The sum's multiplier is the
    # sum of the operands' multipliers; a plain identity acts as multiplier 1.

    def _multiplier_of(op):
      if _type(op) == _SCALED_IDENTITY:
        return op.multiplier
      return array_ops.ones(op.batch_shape_tensor(), dtype=op.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=_multiplier_of(op1) + _multiplier_of(op2),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
324
325
class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    # Diagonal, identity, and scaled-identity operands all qualify.
    return {_type(op1), _type(op2)}.issubset(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    summed_diag = op1.diag_part() + op2.diag_part()
    return linear_operator_diag.LinearOperatorDiag(
        diag=summed_diag,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
340
341
class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    allowed_types = _DIAG_LIKE.union({_TRIL})
    return {_type(op1), _type(op2)}.issubset(allowed_types)

  def _add(self, op1, op2, operator_name, hints):
    # Call add_to_tensor on whichever operand has the efficient implementation,
    # densifying the other one.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
361
362
class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    # Fallback adder: any pair of LinearOperators can be added densely.
    return all(
        isinstance(op, linear_operator.LinearOperator) for op in (op1, op2))

  def _add(self, op1, op2, operator_name, hints):
    # Call add_to_tensor on whichever operand has the efficient implementation,
    # densifying the other one.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)
381
382
################################################################################
# Constants designating types of LinearOperators
################################################################################
386
# Type name constants for LinearOperator classes.  These tags are produced by
# _type() and consumed by the Adders' can_add checks.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
# Operator types whose dense representation is diagonal.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
# The identity and scaled-identity types.
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
399
# Supported LinearOperator classes.
# NOTE: this list mirrors the classes recognized by _type() below; any other
# class makes _type() raise TypeError.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]
408
409
def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  # (class, tag) pairs checked in order via isinstance.
  dispatch = (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY),
  )
  for operator_class, type_name in dispatch:
    if isinstance(operator, operator_class):
      return type_name
  raise TypeError("Operator type unknown: %s" % operator)
425
426
################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
#  * Results of addition at tier K will be added at tier K or higher.
#  * Tiers may change, and we warn the user that it may change.
################################################################################
438
# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
# Tiers are tried in order, from the most structured result (scaled identity)
# down to the dense-matrix fallback.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]
447