1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorFlow-related utilities.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import six 21 22from tensorflow.python.eager import context 23from tensorflow.python.framework import composite_tensor 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import smart_cond as smart_module 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import variables 30from tensorflow.python.util import nest 31 32 33def smart_cond(pred, true_fn=None, false_fn=None, name=None): 34 """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. 35 36 If `pred` is a bool or has a constant value, we return either `true_fn()` 37 or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. 38 39 Arguments: 40 pred: A scalar determining whether to return the result of `true_fn` or 41 `false_fn`. 42 true_fn: The callable to be performed if pred is true. 43 false_fn: The callable to be performed if pred is false. 44 name: Optional name prefix when using `tf.cond`. 45 46 Returns: 47 Tensors returned by the call to either `true_fn` or `false_fn`. 48 49 Raises: 50 TypeError: If `true_fn` or `false_fn` is not callable. 51 """ 52 if isinstance(pred, variables.Variable): 53 return control_flow_ops.cond( 54 pred, true_fn=true_fn, false_fn=false_fn, name=name) 55 return smart_module.smart_cond( 56 pred, true_fn=true_fn, false_fn=false_fn, name=name) 57 58 59def constant_value(pred): 60 """Return the bool value for `pred`, or None if `pred` had a dynamic value. 61 62 Arguments: 63 pred: A scalar, either a Python bool or a TensorFlow boolean variable 64 or tensor, or the Python integer 1 or 0. 65 66 Returns: 67 True or False if `pred` has a constant boolean value, None otherwise. 68 69 Raises: 70 TypeError: If `pred` is not a Variable, Tensor or bool, or Python 71 integer 1 or 0. 72 """ 73 # Allow integer booleans. 74 if isinstance(pred, int): 75 if pred == 1: 76 pred = True 77 elif pred == 0: 78 pred = False 79 80 if isinstance(pred, variables.Variable): 81 return None 82 return smart_module.smart_constant_value(pred) 83 84 85def is_tensor_or_tensor_list(v): 86 v = nest.flatten(v) 87 if v and isinstance(v[0], ops.Tensor): 88 return True 89 else: 90 return False 91 92 93def get_reachable_from_inputs(inputs, targets=None): 94 """Returns the set of tensors/ops reachable from `inputs`. 95 96 Stops if all targets have been found (target is optional). 97 98 Only valid in Symbolic mode, not Eager mode. 99 100 Args: 101 inputs: List of tensors. 102 targets: List of tensors. 103 104 Returns: 105 A set of tensors reachable from the inputs (includes the inputs themselves). 106 """ 107 inputs = nest.flatten(inputs) 108 reachable = set(inputs) 109 if targets: 110 targets = set(targets) 111 queue = inputs[:] 112 113 while queue: 114 x = queue.pop() 115 if isinstance(x, tuple(_user_convertible_tensor_types)): 116 # Can't find consumers of user-specific types. 117 continue 118 119 if isinstance(x, ops.Operation): 120 outputs = x.outputs[:] or [] 121 outputs += x._control_outputs # pylint: disable=protected-access 122 elif isinstance(x, variables.Variable): 123 try: 124 outputs = [x.op] 125 except AttributeError: 126 # Variables can be created in an Eager context. 127 outputs = [] 128 elif tensor_util.is_tensor(x): 129 outputs = x.consumers() 130 else: 131 raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x)) 132 133 for y in outputs: 134 if y not in reachable: 135 reachable.add(y) 136 queue.insert(0, y) 137 138 if targets and targets.issubset(reachable): 139 return reachable 140 return reachable 141 142 143# This function needs access to private functions of `nest`. 144# pylint: disable=protected-access 145def map_structure_with_atomic(is_atomic_fn, map_fn, nested): 146 """Maps the atomic elements of a nested structure. 147 148 Arguments: 149 is_atomic_fn: A function that determines if an element of `nested` is 150 atomic. 151 map_fn: The function to apply to atomic elements of `nested`. 152 nested: A nested structure. 153 154 Returns: 155 The nested structure, with atomic elements mapped according to `map_fn`. 156 157 Raises: 158 ValueError: If an element that is neither atomic nor a sequence is 159 encountered. 160 """ 161 if is_atomic_fn(nested): 162 return map_fn(nested) 163 164 # Recursively convert. 165 if not nest.is_sequence(nested): 166 raise ValueError( 167 'Received non-atomic and non-sequence element: {}'.format(nested)) 168 if nest._is_mapping(nested): 169 values = [nested[k] for k in nest._sorted(nested)] 170 else: 171 values = nested 172 mapped_values = [ 173 map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values 174 ] 175 return nest._sequence_like(nested, mapped_values) 176 177 178# pylint: enable=protected-access 179 180 181def convert_shapes(input_shape, to_tuples=True): 182 """Converts nested shape representations to desired format. 183 184 Performs: 185 186 TensorShapes -> tuples if `to_tuples=True`. 187 tuples of int or None -> TensorShapes if `to_tuples=False`. 188 189 Valid objects to be converted are: 190 - TensorShapes 191 - tuples with elements of type int or None. 192 - ints 193 - None 194 195 Arguments: 196 input_shape: A nested structure of objects to be converted to TensorShapes. 197 to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts 198 all tuples representing shapes to TensorShapes. 199 200 Returns: 201 Nested structure of shapes in desired format. 202 """ 203 204 def _is_shape_component(value): 205 return value is None or isinstance(value, (int, tensor_shape.Dimension)) 206 207 def _is_atomic_shape(input_shape): 208 # Ex: TensorShape or (None, 10, 32) or 5 or `None` 209 if _is_shape_component(input_shape): 210 return True 211 if isinstance(input_shape, tensor_shape.TensorShape): 212 return True 213 if (isinstance(input_shape, (tuple, list)) and 214 all(_is_shape_component(ele) for ele in input_shape)): 215 return True 216 return False 217 218 def _convert_shape(input_shape): 219 input_shape = tensor_shape.TensorShape(input_shape) 220 if to_tuples: 221 input_shape = tuple(input_shape.as_list()) 222 return input_shape 223 224 return map_structure_with_atomic(_is_atomic_shape, _convert_shape, 225 input_shape) 226 227 228class ListWrapper(object): 229 """A wrapper for lists to be treated as elements for `nest`.""" 230 231 def __init__(self, list_to_wrap): 232 self._list = list_to_wrap 233 234 def as_list(self): 235 return self._list 236 237 238def convert_inner_node_data(nested, wrap=False): 239 """Either wraps or unwraps innermost node data lists in `ListWrapper` objects. 240 241 Arguments: 242 nested: A nested data structure. 243 wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`, 244 unwraps `ListWrapper` objects into lists. 245 246 Returns: 247 Strucutre of same type as nested, with lists wrapped/unwrapped. 248 """ 249 250 def _is_atomic_nested(nested): 251 """Returns `True` if `nested` is a list representing node data.""" 252 if isinstance(nested, ListWrapper): 253 return True 254 # Node data can be of form `[layer_name, node_id, tensor_id]` or 255 # `[layer_name, node_id, tensor_id, kwargs]`. 256 if (isinstance(nested, list) and (len(nested) in [3, 4]) and 257 isinstance(nested[0], six.string_types)): 258 return True 259 return False 260 261 def _convert_object_or_list(nested): 262 """Convert b/t `ListWrapper` object and list representations.""" 263 if wrap: 264 if isinstance(nested, ListWrapper): 265 return nested 266 return ListWrapper(nested) 267 else: 268 if isinstance(nested, ListWrapper): 269 return nested.as_list() 270 return nested 271 272 return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list, 273 nested) 274 275 276def shape_type_conversion(fn): 277 """Decorator that handles tuple/TensorShape conversion. 278 279 Used in `compute_output_shape` and `build`. 280 281 Arguments: 282 fn: function to wrap. 283 284 Returns: 285 Wrapped function. 286 """ 287 288 def wrapper(instance, input_shape): 289 # Pass shapes as tuples to `fn` 290 # This preserves compatibility with external Keras. 291 if input_shape is not None: 292 input_shape = convert_shapes(input_shape, to_tuples=True) 293 output_shape = fn(instance, input_shape) 294 # Return shapes from `fn` as TensorShapes. 295 if output_shape is not None: 296 output_shape = convert_shapes(output_shape, to_tuples=False) 297 return output_shape 298 299 return wrapper 300 301 302def are_all_symbolic_tensors(tensors): 303 return all(is_symbolic_tensor(tensor) for tensor in tensors) 304 305 306_user_convertible_tensor_types = set() 307 308 309def is_symbolic_tensor(tensor): 310 """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor. 311 312 A Variable can be seen as either: it is considered symbolic 313 when we are in a graph scope, and eager when we are in an eager scope. 314 315 Arguments: 316 tensor: A tensor instance to test. 317 318 Returns: 319 True for symbolic tensors, False for eager tensors. 320 """ 321 if isinstance(tensor, variables.Variable): 322 # Variables that are output of a Keras Layer in Functional API mode 323 # should be considered symbolic. 324 # TODO(omalleyt): We need a better way to check this in order to 325 # enable `run_eagerly=True` for Models containing Layers that 326 # return Variables as outputs. 327 return (getattr(tensor, '_keras_history', False) or 328 not context.executing_eagerly()) 329 if isinstance(tensor, composite_tensor.CompositeTensor): 330 return tensor._is_graph_tensor # pylint: disable=protected-access 331 if isinstance(tensor, ops.Tensor): 332 return hasattr(tensor, 'graph') 333 if isinstance(tensor, tuple(_user_convertible_tensor_types)): 334 return hasattr(ops.convert_to_tensor(tensor), 'graph') 335 return False 336 337 338def register_symbolic_tensor_type(cls): 339 """Allows users to specify types regarded as symbolic `Tensor`s. 340 341 Used in conjunction with `tf.register_tensor_conversion_function`, calling 342 `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor` 343 objects to be plumbed through Keras layers. 344 345 Example: 346 347 ```python 348 # One-time setup. 349 class Foo(object): 350 def __init__(self, input_): 351 self._input = input_ 352 def value(self): 353 return tf.constant(42.) 354 355 tf.register_tensor_conversion_function( 356 Foo, lambda x, *args, **kwargs: x.value()) 357 358 tf.keras.utils.register_symbolic_tensor_type(Foo) 359 360 # User-land. 361 layer = tf.keras.layers.Lambda(lambda input_: Foo(input_)) 362 ``` 363 364 Arguments: 365 cls: A `class` type which shall be regarded as a symbolic `Tensor`. 366 """ 367 global _user_convertible_tensor_types 368 _user_convertible_tensor_types.add(cls) 369 370 371def is_tensor_or_variable(x): 372 return tensor_util.is_tensor(x) or isinstance(x, variables.Variable) 373