# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import numpy as np
import six

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import nest


def is_tensor_or_tensor_list(v):
  """Checks whether the first leaf of `v` is an `ops.Tensor`.

  Only the first element of `nest.flatten(v)` is inspected; any further
  elements of the structure are not checked.

  Args:
    v: A single object or an arbitrary nested structure.

  Returns:
    `True` if `v` flattens to a non-empty list whose first element is an
    `ops.Tensor`, `False` otherwise.
  """
  leaves = nest.flatten(v)
  if not leaves:
    # An empty structure has no tensors at all.
    return False
  return isinstance(leaves[0], ops.Tensor)
def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Stops if all targets have been found (target is optional).

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).

  Raises:
    TypeError: If a visited element is not an Operation, Variable, Tensor, or
      user-registered symbolic type.
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    # NOTE(review): `targets` is flattened without `expand_composites`, unlike
    # `inputs`, so a composite target may never match a visited tensor and the
    # early exit below would not trigger for it. Confirm this is intended.
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  queue = collections.deque(inputs)
  # The registry of user types is not modified by this traversal, so build the
  # tuple for `isinstance` once instead of on every iteration.
  user_tensor_types = tuple(_user_convertible_tensor_types)

  while queue:
    x = queue.pop()
    if isinstance(x, user_tensor_types):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tf_type(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        # `appendleft` + `pop` from the right gives breadth-first order.
        queue.appendleft(y)

    if targets and not remaining_targets:
      return reachable

  return reachable


# This function needs access to private functions of `nest`.
# pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Maps the atomic elements of a nested structure.

  Args:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  # Recursively convert.
  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))
  if nest.is_mapping(nested):
    # Visit mapping values in sorted-key order, matching `nest` conventions.
    values = [nested[k] for k in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    values = _astuple(nested)
  else:
    values = nested
  mapped_values = [
      map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
  ]
  return nest._sequence_like(nested, mapped_values)


def get_shapes(tensors):
  """Gets shapes from tensors.

  Args:
    tensors: A nested structure of objects exposing a `shape` attribute.

  Returns:
    The same structure with each leaf replaced by its `shape`.
  """
  return nest.map_structure(lambda x: x.shape, tensors)


# pylint: enable=protected-access


def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples with elements of type int or None.
  - ints
  - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, eg
      unknown tensor shape.
  """

  def _is_shape_component(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(input_shape):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(input_shape):
      return True
    if isinstance(input_shape, tensor_shape.TensorShape):
      return True
    if (isinstance(input_shape, (tuple, list)) and
        all(_is_shape_component(ele) for ele in input_shape)):
      return True
    return False

  def _convert_shape(input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if to_tuples:
      input_shape = tuple(input_shape.as_list())
    return input_shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)


class ListWrapper(object):
  """A wrapper for lists to be treated as elements for `nest`."""

  def __init__(self, list_to_wrap):
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list object (not a copy)."""
    return self._list


def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(nested):
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    if (isinstance(nested, list) and (len(nested) in [3, 4]) and
        isinstance(nested[0], six.string_types)):
      return True
    return False

  def _is_atomic_nested(nested):
    """Returns `True` if `nested` is a list representing node data."""
    if isinstance(nested, ListWrapper):
      return True
    if _is_serialized_node_data(nested):
      return True
    return not nest.is_nested(nested)

  def _convert_object_or_list(nested):
    """Convert b/t `ListWrapper` object and list representations."""
    if wrap:
      if isinstance(nested, ListWrapper):
        return nested
      if _is_serialized_node_data(nested):
        return ListWrapper(nested)
      return nested
    else:
      if isinstance(nested, ListWrapper):
        return nested.as_list()
      return nested

  return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list,
                                   nested)


def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`.

  Args:
    fn: function to wrap.

  Returns:
    Wrapped function.
  """

  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper


def are_all_symbolic_tensors(tensors):
  """Returns `True` if every element of `tensors` is a symbolic tensor.

  Args:
    tensors: An iterable of objects to test with `is_symbolic_tensor`.

  Returns:
    `True` when all elements are symbolic (also for an empty iterable).
  """
  return all(map(is_symbolic_tensor, tensors))


# Classes registered through `register_symbolic_tensor_type`; their instances
# are treated as symbolic tensors throughout this module.
_user_convertible_tensor_types = set()


def is_extension_type(tensor):
  """Returns whether a tensor is of an ExtensionType.

  github.com/tensorflow/community/pull/269
  Currently it works by checking if `tensor` is a `CompositeTensor` instance,
  but this will be changed to use an appropriate extensiontype protocol
  check once ExtensionType is made public.

  Args:
    tensor: An object to test

  Returns:
    True if the tensor is an extension type object, false if not.
  """
  return isinstance(tensor, composite_tensor.CompositeTensor)


def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors.
  """
  if isinstance(tensor, ops.Tensor):
    return hasattr(tensor, 'graph')
  elif is_extension_type(tensor):
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  elif isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    # `bool` so the `_keras_history` object itself is never returned; the
    # truthiness is unchanged.
    return bool(getattr(tensor, '_keras_history', False) or
                not context.executing_eagerly())
  elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  else:
    return False


def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.utils.register_symbolic_tensor_type(cls)` allows non-`Tensor`
  objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  # The module-level set is only mutated here, never rebound, so no `global`
  # declaration is needed.
  if cls not in _user_convertible_tensor_types:
    keras_tensor.register_keras_tensor_specialization(
        cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)


def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors."""
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  # Get a TensorSpec for array-like data without
  # converting the data to a Tensor
  if hasattr(value, 'shape') and hasattr(value, 'dtype'):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  else:
    return type_spec.type_spec_from_value(value)


def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  return isinstance(
      tensor,
      (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def is_tensor_or_variable(x):
  """Returns `True` if `x` is a TF tensor type or a `tf.Variable`."""
  return tensor_util.is_tf_type(x) or isinstance(x, variables.Variable)


def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.

  Args:
    layers: A list of layers to check

  Raises:
    TypeError: If any elements of layers are tf.layers.Layers
  """

  # isinstance check for tf.layers.Layer introduces a circular dependency.
  legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
  if legacy_layers:
    layer_str = '\n'.join(' ' + str(l) for l in legacy_layers)
    raise TypeError(
        'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
        'framework (for instance using the Network, Model, or Sequential '
        'classes), please use the tf.keras.layers implementation instead. '
        '(Or, if writing custom layers, subclass from tf.keras.layers rather '
        'than tf.layers)'.format(layer_str))


@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  if (ops.executing_eagerly_outside_functions() and
      getattr(layer, '_keras_style', True)):
    with ops.init_scope():
      yield
  else:
    yield


@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor."""
  if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
    with K.get_graph().as_default():
      yield
  else:
    yield


def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite.

  Args:
    dataset: The dataset whose cardinality is checked.

  Returns:
    When executing eagerly outside functions, the result of `math_ops.equal`
    (a tensor); otherwise the comparison of the cardinality value fetched
    through `K.get_session().run` against `cardinality.INFINITE`.
  """
  if ops.executing_eagerly_outside_functions():
    return math_ops.equal(
        cardinality.cardinality(dataset), cardinality.INFINITE)
  else:
    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
    return dataset_size == cardinality.INFINITE


def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

  Args:
    t: A `TypeSpec`, extension-type value, Keras tensor, or any object with
      `shape` and `dtype` attributes.
    dynamic_batch: If `True`, return a copy of the spec whose first dimension
      is `None`. Not applied when the spec comes from `t._keras_history`; that
      spec is returned as-is.
    name: Name given to a newly built `TensorSpec` (only used for objects with
      `shape` and `dtype`).

  Returns:
    The spec, or `None` if `t` matches none of the supported kinds so that
    non-Tensors can pass through.
  """
  # pylint: disable=protected-access
  if isinstance(t, type_spec.TypeSpec):
    spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if not dynamic_batch:
    return spec

  # Deep-copy so the caller's spec is never mutated.
  dynamic_batch_spec = copy.deepcopy(spec)
  # RaggedTensorSpec only has a private _shape.
  shape = dynamic_batch_spec._shape
  if shape.rank is not None and shape.rank > 0:
    shape_list = shape.as_list()
    shape_list[0] = None
    dynamic_batch_spec._shape = tensor_shape.TensorShape(shape_list)
  return dynamic_batch_spec
  # pylint: enable=protected-access


def to_numpy_or_python_type(tensors):
  """Converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays.
  """
  def _to_single_numpy_or_python_type(t):
    if isinstance(t, ops.Tensor):
      x = t.numpy()
      return x.item() if np.ndim(x) == 0 else x
    return t  # Don't turn ragged or sparse tensors to NumPy.

  return nest.map_structure(_to_single_numpy_or_python_type, tensors)


def _astuple(attrs):
  """Converts the given attrs to tuple non-recursively.

  Args:
    attrs: An instance of an `attrs`-decorated class.

  Returns:
    A tuple of the instance's field values, in `__attrs_attrs__` order.

  Raises:
    ValueError: If the class of `attrs` has no `__attrs_attrs__`.
  """
  cls = type(attrs)
  fields = getattr(cls, '__attrs_attrs__', None)
  if fields is None:
    raise ValueError('%r is not an attrs-decorated class.' % cls)
  values = []
  for field in fields:
    values.append(getattr(attrs, field.name))
  return tuple(values)