# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add one or more `LinearOperators` efficiently."""

import abc

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_lower_triangular

__all__ = []


def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators: Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name: String name for returned `LinearOperator`. Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects. This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name: A name for this `Op`. Defaults to `add_operators`.

  Returns:
    Subclass of `LinearOperator`. Class and order of addition may change as new
      (and better) addition strategies emerge.

  Raises:
    ValueError: If `operators` argument is empty.
    ValueError: If shapes are incompatible.
  """
  # Default setting
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Reversed so that `pop()` (which takes from the back) processes operators
  # in the order the caller supplied them, producing a nicer default name.
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        f"Argument `operators` must contain at least one operator. "
        f"Received: {operators}.")
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        f"Argument `operators` must contain only LinearOperator instances. "
        f"Received: {operators}.")
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  with ops.name_scope(name or "add_operators"):

    # Additions done in one of the tiers.  Try tier 0, 1,...
    ops_to_try_at_next_tier = list(operators)
    for tier in addition_tiers:
      ops_to_try_at_this_tier = ops_to_try_at_next_tier
      ops_to_try_at_next_tier = []
      while ops_to_try_at_this_tier:
        op1 = ops_to_try_at_this_tier.pop()
        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
        if op2 is not None:
          # Will try to add the result of this again at this same tier.
          new_operator = adder.add(op1, op2, operator_name)
          ops_to_try_at_this_tier.append(new_operator)
        else:
          # No partner for op1 at this tier; it moves on to the next tier.
          ops_to_try_at_next_tier.append(op1)

    return ops_to_try_at_next_tier


def _pop_a_match_at_tier(op1, operator_list, tier):
  """Pop and return the first operator in `operator_list` (searched back to
  front) that some `Adder` in `tier` can add to `op1`, along with that adder.

  Returns `(None, None)` if no match is found.
  """
  # Search from the back of list to the front in order to create nice default
  # order of operations.
  for i in range(1, len(operator_list) + 1):
    op2 = operator_list[-i]
    for adder in tier:
      if adder.can_add(op1, op2):
        return operator_list.pop(-i), adder
  return None, None


def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.  Apply this inference
  # only when the caller did not explicitly override `is_non_singular`.
  # (Previously this guarded on `hints.is_positive_definite is None`, which
  # skipped the inference exactly when the caller asserted positive
  # definiteness via hints.)
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)


def _static_check_for_same_dimensions(operators):
  """ValueError if operators determined to have different dimensions."""
  if len(operators) < 2:
    return

  # Only statically-known dimensions participate in the check; unknown
  # (None) dimensions cannot be proven incompatible here.
  domain_dimensions = [
      (op.name, tensor_shape.dimension_value(op.domain_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.domain_dimension) is not None]
  if len(set(value for name, value in domain_dimensions)) > 1:
    raise ValueError(f"All `operators` must have the same `domain_dimension`. "
                     f"Received: {domain_dimensions}.")

  range_dimensions = [
      (op.name, tensor_shape.dimension_value(op.range_dimension))
      for op in operators
      if tensor_shape.dimension_value(op.range_dimension) is not None]
  if len(set(value for name, value in range_dimensions)) > 1:
    raise ValueError(f"All operators must have the same `range_dimension`. "
                     f"Received: {range_dimensions}.")


def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # This will fail if they cannot be broadcast together.
  batch_shape = operators[0].batch_shape
  for op in operators[1:]:
    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)


class _Hints:
  """Holds 'is_X' flags that every LinearOperator is initialized with."""

  def __init__(self,
               is_non_singular=None,
               is_positive_definite=None,
               is_self_adjoint=None):
    self.is_non_singular = is_non_singular
    self.is_positive_definite = is_positive_definite
    self.is_self_adjoint = is_self_adjoint


################################################################################
# Classes to add two linear operators.
################################################################################


class _Adder(metaclass=abc.ABCMeta):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no attention
  as to whether another `Adder` could have done the addition more efficiently.
  """

  @property
  def name(self):
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they have
    # the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    updated_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    # Strip a leading underscore from the class name so the name scope is
    # valid and readable (e.g. "_AddAndReturnDiag" -> "AddAndReturnDiag").
    scope_name = self.name
    if scope_name.startswith("_"):
      scope_name = scope_name[1:]
    with ops.name_scope(scope_name):
      return self._add(op1, op2, operator_name, updated_hints)


class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
  is closed under addition.  This `Adder` respects that, and returns an Identity
  family member.
  """

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_IDENTITY_FAMILY)

  def _add(self, op1, op2, operator_name, hints):
    # Will build a LinearOperatorScaledIdentity.

    if _type(op1) == _SCALED_IDENTITY:
      multiplier_1 = op1.multiplier
    else:
      # Plain identity: multiplier is 1, broadcast over the batch shape.
      multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)

    if _type(op2) == _SCALED_IDENTITY:
      multiplier_2 = op2.multiplier
    else:
      multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=multiplier_1 + multiplier_2,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    return linear_operator_diag.LinearOperatorDiag(
        diag=op1.diag_part() + op2.diag_part(),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    types = {_type(op1), _type(op2)}
    return not types.difference(_DIAG_LIKE.union({_TRIL}))

  def _add(self, op1, op2, operator_name, hints):
    # Prefer calling `add_to_tensor` on the operator with the efficient
    # implementation (diag-like), densifying the other operand.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
        op2, linear_operator.LinearOperator)

  def _add(self, op1, op2, operator_name, hints):
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      op_add_to_tensor, op_other = op1, op2
    else:
      op_add_to_tensor, op_other = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)


################################################################################
# Constants designating types of LinearOperators
################################################################################

# Type name constants for LinearOperator classes.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE

# Supported LinearOperator classes.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]


def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
    return _DIAG
  if isinstance(operator,
                linear_operator_lower_triangular.LinearOperatorLowerTriangular):
    return _TRIL
  if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
    return _MATRIX
  if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
    return _IDENTITY
  if isinstance(operator,
                linear_operator_identity.LinearOperatorScaledIdentity):
    return _SCALED_IDENTITY
  raise TypeError(f"Expected operator to be one of [LinearOperatorDiag, "
                  f"LinearOperatorLowerTriangular, LinearOperatorFullMatrix, "
                  f"LinearOperatorIdentity, LinearOperatorScaledIdentity]. "
                  f"Received: {operator}")


################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
#   (i) reduce O(..) complexity of forming final operator, and
#   (ii) produce the "most efficient" final operator.
# Dev notes:
#  * Results of addition at tier K will be added at tier K or higher.
#  * Tiers may change, and we warn the user that it may change.
################################################################################

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]