# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add one or more `LinearOperators` efficiently."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_lower_triangular

__all__ = []


def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    A list of `LinearOperator` instances.  The length of the list, and the
    class and order of addition of its members, may change as new (and better)
    addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Reverse so that, combined with the pop()-from-the-back search below, the
  # default order of addition follows the order the user supplied.
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        "Argument 'operators' must contain at least one operator. "
        "Found: %s" % operators)
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        "Argument 'operators' must contain only LinearOperator instances. "
        "Found: %s" % operators)
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  graph_parents = []
  for operator in operators:
    graph_parents.extend(operator.graph_parents)

  with ops.name_scope(name or "add_operators", values=graph_parents):

    # Additions done in one of the tiers.  Try tier 0, 1,...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier


def _pop_a_match_at_tier(op1, operator_list, tier):
  """Pop and return the first operator in `operator_list` addable to `op1`.

  Returns a `(op2, adder)` pair, or `(None, None)` if no member of `tier` can
  add `op1` to any member of `operator_list`.
  """
  # Search from the back of list to the front in order to create nice default
  # order of operations.
  for i in range(1, len(operator_list) + 1):
    op2 = operator_list[-i]
    for adder in tier:
      if adder.can_add(op1, op2):
        return operator_list.pop(-i), adder
  return None, None


def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  # BUGFIX: guard on the `is_non_singular` hint (previously this incorrectly
  # checked `hints.is_positive_definite`), so the inference applies exactly
  # when the caller did not explicitly override non-singularity.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)


def _static_check_for_same_dimensions(operators):
  """ValueError if operators determined to have different dimensions."""
  if len(operators) < 2:
    return

  domain_dimensions = [
      (op.name, tensor_shape.dimension_value(op.domain_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.domain_dimension) is not None]
  if len(set(value for name, value in domain_dimensions)) > 1:
    raise ValueError("Operators must have the same domain dimension. Found: %s"
                     % domain_dimensions)

  range_dimensions = [
      (op.name, tensor_shape.dimension_value(op.range_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.range_dimension) is not None]
  if len(set(value for name, value in range_dimensions)) > 1:
    raise ValueError("Operators must have the same range dimension. Found: %s" %
                     range_dimensions)


def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # This will fail if they cannot be broadcast together.
  batch_shape = operators[0].batch_shape
  for op in operators[1:]:
    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)


class _Hints(object):
  """Holds 'is_X' flags that every LinearOperator is initialized with."""

  def __init__(self,
               is_non_singular=None,
               is_positive_definite=None,
               is_self_adjoint=None):
    self.is_non_singular = is_non_singular
    self.is_positive_definite = is_positive_definite
    self.is_self_adjoint = is_self_adjoint


################################################################################
# Classes to add two linear operators.
################################################################################


@six.add_metaclass(abc.ABCMeta)
class _Adder(object):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no attention
  as to whether another `Adder` could have done the addition more efficiently.
  """

  @property
  def name(self):
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they have
    # the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    updated_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    values = op1.graph_parents + op2.graph_parents
    scope_name = self.name
    if scope_name.startswith("_"):
      scope_name = scope_name[1:]
    with ops.name_scope(scope_name, values=values):
      return self._add(op1, op2, operator_name, updated_hints)


class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
  is closed under addition.  This `Adder` respects that, and returns an Identity
  family member.
  """

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_IDENTITY_FAMILY)

  def _add(self, op1, op2, operator_name, hints):
    # Will build a LinearOperatorScaledIdentity.

    if _type(op1) == _SCALED_IDENTITY:
      multiplier_1 = op1.multiplier
    else:
      multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)

    if _type(op2) == _SCALED_IDENTITY:
      multiplier_2 = op2.multiplier
    else:
      multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=multiplier_1 + multiplier_2,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    return linear_operator_diag.LinearOperatorDiag(
        diag=op1.diag_part() + op2.diag_part(),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE.union({_TRIL}))

  def _add(self, op1, op2, operator_name, hints):
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
        op2, linear_operator.LinearOperator)

  def _add(self, op1, op2, operator_name, hints):
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


################################################################################
# Constants designating types of LinearOperators
################################################################################

# Type name constants for LinearOperator classes.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE

# Supported LinearOperator classes.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]


def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
    return _DIAG
  if isinstance(operator,
                linear_operator_lower_triangular.LinearOperatorLowerTriangular):
    return _TRIL
  if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
    return _MATRIX
  if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
    return _IDENTITY
  if isinstance(operator,
                linear_operator_identity.LinearOperatorScaledIdentity):
    return _SCALED_IDENTITY
  raise TypeError("Operator type unknown: %s" % operator)


################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
#  * Results of addition at tier K will be added at tier K or higher.
#  * Tiers may change, and we warn the user that it may change.
################################################################################

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]