1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""## Functions for working with arbitrarily nested sequences of elements. 17 18This module can perform operations on nested structures. A nested structure is a 19Python collection that can contain further collections as well as other objects 20called atoms. Note that numpy arrays are considered atoms. 21 22nest recognizes the following types of collections: 23 1.tuple 24 2.namedtuple 25 3.dict 26 4.orderedDict 27 5.MutableMapping 28 6.attr.s 29 30attr.s decorated classes (http://www.attrs.org) are also supported, in the 31same way as `namedtuple`. 32 33The utilities here assume (and do not check) that the nested structures form a 34'tree', i.e., no references in the structure of the input of these functions 35should be recursive. 36 37Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), 38 (np.array([3, 4]), tf.constant([3, 4])))` 39""" 40 41from __future__ import absolute_import 42from __future__ import division 43from __future__ import print_function 44 45import collections as _collections 46 47import six as _six 48import wrapt as _wrapt 49 50from tensorflow.python.platform import tf_logging 51from tensorflow.python.util import _pywrap_nest 52from tensorflow.python.util import _pywrap_utils 53from tensorflow.python.util.compat import collections_abc as _collections_abc 54from tensorflow.python.util.tf_export import tf_export 55 56 57_SHALLOW_TREE_HAS_INVALID_KEYS = ( 58 "The shallow_tree's keys are not a subset of the input_tree's keys. The " 59 "shallow_tree has the following keys that are not in the input_tree: {}.") 60 61_STRUCTURES_HAVE_MISMATCHING_TYPES = ( 62 "The two structures don't have the same sequence type. Input structure has " 63 "type {input_type}, while shallow structure has type {shallow_type}.") 64 65_STRUCTURES_HAVE_MISMATCHING_LENGTHS = ( 66 "The two structures don't have the same sequence length. Input " 67 "structure has length {input_length}, while shallow structure has length " 68 "{shallow_length}." 69) 70 71_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = ( 72 "The input_tree has fewer elements than the shallow_tree. Input structure " 73 "has length {input_size}, while shallow structure has length " 74 "{shallow_size}.") 75 76_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = ( 77 "If shallow structure is a sequence, input must also be a sequence. " 78 "Input has type: {}.") 79 80 81def _get_attrs_items(obj): 82 """Returns a list of (name, value) pairs from an attrs instance. 83 84 The list will be sorted by name. 85 86 Args: 87 obj: an object. 88 89 Returns: 90 A list of (attr_name, attr_value) pairs, sorted by attr_name. 91 """ 92 attrs = getattr(obj.__class__, "__attrs_attrs__") 93 attr_names = (a.name for a in attrs) 94 return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names] 95 96 97def _sorted(dict_): 98 """Returns a sorted list of the dict keys, with error if keys not sortable.""" 99 try: 100 return sorted(dict_.keys()) 101 except TypeError: 102 raise TypeError("nest only supports dicts with sortable keys.") 103 104 105def _is_namedtuple(instance, strict=False): 106 """Returns True iff `instance` is a `namedtuple`. 107 108 Args: 109 instance: An instance of a Python object. 110 strict: If True, `instance` is considered to be a `namedtuple` only if 111 it is a "plain" namedtuple. For instance, a class inheriting 112 from a `namedtuple` will be considered to be a `namedtuple` 113 iff `strict=False`. 114 115 Returns: 116 True if `instance` is a `namedtuple`. 117 """ 118 return _pywrap_utils.IsNamedtuple(instance, strict) 119 120 121# See the swig file (util.i) for documentation. 122_is_mapping_view = _pywrap_utils.IsMappingView 123_is_attrs = _pywrap_utils.IsAttrs 124_is_composite_tensor = _pywrap_utils.IsCompositeTensor 125_is_type_spec = _pywrap_utils.IsTypeSpec 126_is_mutable_mapping = _pywrap_utils.IsMutableMapping 127_is_mapping = _pywrap_utils.IsMapping 128 129 130@tf_export("__internal__.nest.is_attrs", v1=[]) 131def is_attrs(obj): 132 """Returns a true if its input is an instance of an attr.s decorated class.""" 133 return _is_attrs(obj) 134 135 136@tf_export("__internal__.nest.is_mapping", v1=[]) 137def is_mapping(obj): 138 """Returns a true if its input is a collections.Mapping.""" 139 return _is_mapping(obj) 140 141 142@tf_export("__internal__.nest.sequence_like", v1=[]) 143def _sequence_like(instance, args): 144 """Converts the sequence `args` to the same type as `instance`. 145 146 Args: 147 instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, 148 `collections.OrderedDict`, or `composite_tensor.Composite_Tensor` 149 or `type_spec.TypeSpec`. 150 args: elements to be converted to the `instance` type. 151 152 Returns: 153 `args` with the type of `instance`. 154 """ 155 if _is_mutable_mapping(instance): 156 # Pack dictionaries in a deterministic order by sorting the keys. 157 # Notice this means that we ignore the original order of `OrderedDict` 158 # instances. This is intentional, to avoid potential bugs caused by mixing 159 # ordered and plain dicts (e.g., flattening a dict but using a 160 # corresponding `OrderedDict` to pack it back). 161 result = dict(zip(_sorted(instance), args)) 162 instance_type = type(instance) 163 if instance_type == _collections.defaultdict: 164 d = _collections.defaultdict(instance.default_factory) 165 else: 166 d = instance_type() 167 for key in instance: 168 d[key] = result[key] 169 return d 170 elif _is_mapping(instance): 171 result = dict(zip(_sorted(instance), args)) 172 instance_type = type(instance) 173 tf_logging.log_first_n( 174 tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer" 175 " using MutableMapping for {}".format(instance_type), 1) 176 try: 177 return instance_type((key, result[key]) for key in instance) 178 except TypeError as err: 179 raise TypeError("Error creating an object of type {} like {}. Note that " 180 "it must accept a single positional argument " 181 "representing an iterable of key-value pairs, in " 182 "addition to self. Cause: {}".format( 183 type(instance), instance, err)) 184 elif _is_mapping_view(instance): 185 # We can't directly construct mapping views, so we create a list instead 186 return list(args) 187 elif _is_namedtuple(instance) or _is_attrs(instance): 188 if isinstance(instance, _wrapt.ObjectProxy): 189 instance_type = type(instance.__wrapped__) 190 else: 191 instance_type = type(instance) 192 return instance_type(*args) 193 elif _is_composite_tensor(instance): 194 assert len(args) == 1 195 spec = instance._type_spec # pylint: disable=protected-access 196 return spec._from_components(args[0]) # pylint: disable=protected-access 197 elif _is_type_spec(instance): 198 # Pack a CompositeTensor's components according to a TypeSpec. 199 assert len(args) == 1 200 return instance._from_components(args[0]) # pylint: disable=protected-access 201 elif isinstance(instance, _six.moves.range): 202 return _sequence_like(list(instance), args) 203 elif isinstance(instance, _wrapt.ObjectProxy): 204 # For object proxies, first create the underlying type and then re-wrap it 205 # in the proxy type. 206 return type(instance)(_sequence_like(instance.__wrapped__, args)) 207 else: 208 # Not a namedtuple 209 return type(instance)(args) 210 211 212def _yield_value(iterable): 213 for _, v in _yield_sorted_items(iterable): 214 yield v 215 216 217def _yield_sorted_items(iterable): 218 """Yield (key, value) pairs for `iterable` in a deterministic order. 219 220 For Sequences, the key will be an int, the array index of a value. 221 For Mappings, the key will be the dictionary key. 222 For objects (e.g. namedtuples), the key will be the attribute name. 223 224 In all cases, the keys will be iterated in sorted order. 225 226 Args: 227 iterable: an iterable. 228 229 Yields: 230 The iterable's (key, value) pairs, in order of sorted keys. 231 """ 232 # Ordered to check common structure types (list, tuple, dict) first. 233 if isinstance(iterable, list): 234 for item in enumerate(iterable): 235 yield item 236 # namedtuples handled separately to avoid expensive namedtuple check. 237 elif type(iterable) == tuple: # pylint: disable=unidiomatic-typecheck 238 for item in enumerate(iterable): 239 yield item 240 elif isinstance(iterable, (dict, _collections_abc.Mapping)): 241 # Iterate through dictionaries in a deterministic order by sorting the 242 # keys. Notice this means that we ignore the original order of `OrderedDict` 243 # instances. This is intentional, to avoid potential bugs caused by mixing 244 # ordered and plain dicts (e.g., flattening a dict but using a 245 # corresponding `OrderedDict` to pack it back). 246 for key in _sorted(iterable): 247 yield key, iterable[key] 248 elif _is_attrs(iterable): 249 for item in _get_attrs_items(iterable): 250 yield item 251 elif _is_namedtuple(iterable): 252 for field in iterable._fields: 253 yield field, getattr(iterable, field) 254 elif _is_composite_tensor(iterable): 255 type_spec = iterable._type_spec # pylint: disable=protected-access 256 yield type_spec.value_type.__name__, type_spec._to_components(iterable) # pylint: disable=protected-access 257 elif _is_type_spec(iterable): 258 # Note: to allow CompositeTensors and their TypeSpecs to have matching 259 # structures, we need to use the same key string here. 260 yield iterable.value_type.__name__, iterable._component_specs # pylint: disable=protected-access 261 else: 262 for item in enumerate(iterable): 263 yield item 264 265 266# See the swig file (util.i) for documentation. 267is_sequence = _pywrap_utils.IsSequence 268 269 270# See the swig file (util.i) for documentation. 271is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite 272 273 274@tf_export("nest.is_nested") 275def is_nested(seq): 276 """Returns true if its input is a collections.abc.Sequence (except strings). 277 278 >>> tf.nest.is_nested("1234") 279 False 280 281 >>> tf.nest.is_nested([1, 3, [4, 5]]) 282 True 283 284 >>> tf.nest.is_nested(((7, 8), (5, 6))) 285 True 286 287 >>> tf.nest.is_nested([]) 288 True 289 290 >>> tf.nest.is_nested({"a": 1, "b": 2}) 291 True 292 293 >>> tf.nest.is_nested({"a": 1, "b": 2}.keys()) 294 True 295 296 >>> tf.nest.is_nested({"a": 1, "b": 2}.values()) 297 True 298 299 >>> tf.nest.is_nested({"a": 1, "b": 2}.items()) 300 True 301 302 >>> tf.nest.is_nested(set([1, 2])) 303 False 304 305 >>> ones = tf.ones([2, 3]) 306 >>> tf.nest.is_nested(ones) 307 False 308 309 Args: 310 seq: an input sequence. 311 312 Returns: 313 True if the sequence is a not a string and is a collections.abc.Sequence 314 or a dict. 315 """ 316 return is_sequence(seq) 317 318 319@tf_export("nest.flatten") 320def flatten(structure, expand_composites=False): 321 """Returns a flat list from a given nested structure. 322 323 If nest is not a structure , tuple (or a namedtuple), dict, or an attrs class, 324 then returns a single-element list: 325 [nest]. 326 327 This is the inverse of the `nest.pack_sequence_as` method that takes in a 328 flattened list and re-packs it into the nested structure. 329 330 In the case of dict instances, the sequence consists of the values, sorted by 331 key to ensure deterministic behavior. This is true also for OrderedDict 332 instances: their sequence order is ignored, the sorting order of keys is used 333 instead. The same convention is followed in `nest.pack_sequence_as`. This 334 correctly repacks dicts and OrderedDicts after they have been flattened, and 335 also allows flattening an OrderedDict and then repacking it back using a 336 corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys 337 cannot be flattened. 338 339 Users must not modify any collections used in nest while this function is 340 running. 341 342 Examples: 343 344 1. Python dict (ordered by key): 345 346 >>> dict = { "key3": "value3", "key1": "value1", "key2": "value2" } 347 >>> tf.nest.flatten(dict) 348 ['value1', 'value2', 'value3'] 349 350 2. For a nested python tuple: 351 352 >>> tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 353 >>> tf.nest.flatten(tuple) 354 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 355 356 3. For a nested dictionary of dictionaries: 357 358 >>> dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)}, 359 ... "key1": {"m": "val1", "g": "val2"} } 360 >>> tf.nest.flatten(dict) 361 ['val2', 'val1', 3.0, 1.0, 2.0] 362 363 4. Numpy array (will not flatten): 364 365 >>> array = np.array([[1, 2], [3, 4]]) 366 >>> tf.nest.flatten(array) 367 [array([[1, 2], 368 [3, 4]])] 369 370 5. `tf.Tensor` (will not flatten): 371 372 >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) 373 >>> tf.nest.flatten(tensor) 374 [<tf.Tensor: shape=(3, 3), dtype=float32, numpy= 375 array([[1., 2., 3.], 376 [4., 5., 6.], 377 [7., 8., 9.]], dtype=float32)>] 378 379 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 380 of a flattened list of 'values' and a list of 'row_splits' which indicate how 381 to chop up the flattened list into different rows. For more details on 382 `tf.RaggedTensor`, please visit 383 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 384 385 with `expand_composites=False`, we just return the RaggedTensor as is. 386 387 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 388 >>> tf.nest.flatten(tensor, expand_composites=False) 389 [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>] 390 391 with `expand_composites=True`, we return the component Tensors that make up 392 the RaggedTensor representation (the values and row_splits tensors) 393 394 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 395 >>> tf.nest.flatten(tensor, expand_composites=True) 396 [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2], 397 dtype=int32)>, 398 <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>] 399 400 Args: 401 structure: an arbitrarily nested structure. Note, numpy arrays are 402 considered atoms and are not flattened. 403 expand_composites: If true, then composite tensors such as 404 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 405 component tensors. 406 407 Returns: 408 A Python list, the flattened version of the input. 409 410 Raises: 411 TypeError: The nest is or contains a dict with non-sortable keys. 412 """ 413 if structure is None: 414 return [None] 415 expand_composites = bool(expand_composites) 416 return _pywrap_utils.Flatten(structure, expand_composites) 417 418 419# See the swig file (util.i) for documentation. 420_same_namedtuples = _pywrap_utils.SameNamedtuples 421 422 423class _DotString(object): 424 425 __slots__ = [] 426 427 def __str__(self): 428 return "." 429 430 def __repr__(self): 431 return "." 432 433 434_DOT = _DotString() 435 436 437@tf_export("nest.assert_same_structure") 438def assert_same_structure(nest1, nest2, check_types=True, 439 expand_composites=False): 440 """Asserts that two structures are nested in the same way. 441 442 Note the method does not check the types of data inside the structures. 443 444 Examples: 445 446 * These scalar vs. scalar comparisons will pass: 447 448 >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32)) 449 >>> tf.nest.assert_same_structure("abc", np.array([1, 2])) 450 451 * These sequence vs. sequence comparisons will pass: 452 453 >>> structure1 = (((1, 2), 3), 4, (5, 6)) 454 >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) 455 >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]] 456 >>> tf.nest.assert_same_structure(structure1, structure2) 457 >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False) 458 459 >>> import collections 460 >>> tf.nest.assert_same_structure( 461 ... collections.namedtuple("bar", "a b")(1, 2), 462 ... collections.namedtuple("foo", "a b")(2, 3), 463 ... check_types=False) 464 465 >>> tf.nest.assert_same_structure( 466 ... collections.namedtuple("bar", "a b")(1, 2), 467 ... { "a": 1, "b": 2 }, 468 ... check_types=False) 469 470 >>> tf.nest.assert_same_structure( 471 ... { "a": 1, "b": 2, "c": 3 }, 472 ... { "c": 6, "b": 5, "a": 4 }) 473 474 >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits( 475 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 476 ... row_splits=[0, 4, 4, 7, 8, 8]) 477 >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits( 478 ... values=[3, 1, 4], 479 ... row_splits=[0, 3]) 480 >>> tf.nest.assert_same_structure( 481 ... ragged_tensor1, 482 ... ragged_tensor2, 483 ... expand_composites=True) 484 485 * These examples will raise exceptions: 486 487 >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1])) 488 Traceback (most recent call last): 489 ... 490 ValueError: The two structures don't have the same nested structure 491 492 >>> tf.nest.assert_same_structure( 493 ... collections.namedtuple('bar', 'a b')(1, 2), 494 ... collections.namedtuple('foo', 'a b')(2, 3)) 495 Traceback (most recent call last): 496 ... 497 TypeError: The two structures don't have the same nested structure 498 499 Args: 500 nest1: an arbitrarily nested structure. 501 nest2: an arbitrarily nested structure. 502 check_types: if `True` (default) types of sequences are checked as well, 503 including the keys of dictionaries. If set to `False`, for example a 504 list and a tuple of objects will look the same if they have the same 505 size. Note that namedtuples with identical name and fields are always 506 considered to have the same shallow structure. Two types will also be 507 considered the same if they are both list subtypes (which allows "list" 508 and "_ListWrapper" from trackable dependency tracking to compare 509 equal). 510 expand_composites: If true, then composite tensors such as 511 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 512 component tensors. 513 514 Raises: 515 ValueError: If the two structures do not have the same number of elements or 516 if the two structures are not nested in the same way. 517 TypeError: If the two structures differ in the type of sequence in any of 518 their substructures. Only possible if `check_types` is `True`. 519 """ 520 # Convert to bool explicitly as otherwise pybind will not be able# to handle 521 # type mismatch message correctly. See GitHub issue 42329 for details. 522 check_types = bool(check_types) 523 expand_composites = bool(expand_composites) 524 try: 525 _pywrap_utils.AssertSameStructure(nest1, nest2, check_types, 526 expand_composites) 527 except (ValueError, TypeError) as e: 528 str1 = str(map_structure(lambda _: _DOT, nest1)) 529 str2 = str(map_structure(lambda _: _DOT, nest2)) 530 raise type(e)("%s\n" 531 "Entire first structure:\n%s\n" 532 "Entire second structure:\n%s" 533 % (str(e), str1, str2)) 534 535 536def flatten_dict_items(dictionary): 537 """Returns a dictionary with flattened keys and values. 538 539 This function flattens the keys and values of a dictionary, which can be 540 arbitrarily nested structures, and returns the flattened version of such 541 structures: 542 543 ```python 544 example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 545 result = {4: "a", 5: "b", 6: "c", 8: "d"} 546 flatten_dict_items(example_dictionary) == result 547 ``` 548 549 The input dictionary must satisfy two properties: 550 551 1. Its keys and values should have the same exact nested structure. 552 2. The set of all flattened keys of the dictionary must not contain repeated 553 keys. 554 555 Args: 556 dictionary: the dictionary to zip 557 558 Returns: 559 The zipped dictionary. 560 561 Raises: 562 TypeError: If the input is not a dictionary. 563 ValueError: If any key and value do not have the same structure layout, or 564 if keys are not unique. 565 """ 566 return _pywrap_nest.FlattenDictItems(dictionary) 567 568 569def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None): 570 """Helper function for pack_sequence_as. 571 572 Args: 573 structure: Substructure (list / tuple / dict) to mimic. 574 flat: Flattened values to output substructure for. 575 index: Index at which to start reading from flat. 576 is_seq: Function used to test if a value should be treated as a sequence. 577 sequence_fn: Function used to generate a new sequence instance. 578 579 Returns: 580 The tuple (new_index, child), where: 581 * new_index - the updated index into `flat` having processed `structure`. 582 * packed - the subset of `flat` corresponding to `structure`, 583 having started at `index`, and packed into the same nested 584 format. 585 586 Raises: 587 ValueError: if `structure` contains more elements than `flat` 588 (assuming indexing starts from `index`). 589 """ 590 packed = [] 591 sequence_fn = sequence_fn or _sequence_like 592 for s in _yield_value(structure): 593 if is_seq(s): 594 new_index, child = _packed_nest_with_indices(s, flat, index, is_seq, 595 sequence_fn) 596 packed.append(sequence_fn(s, child)) 597 index = new_index 598 else: 599 packed.append(flat[index]) 600 index += 1 601 return index, packed 602 603 604def _pack_sequence_as(structure, flat_sequence, expand_composites, 605 sequence_fn=None): 606 """Implements sequence packing, with the option to alter the structure.""" 607 is_seq = is_sequence_or_composite if expand_composites else is_sequence 608 sequence_fn = sequence_fn or _sequence_like 609 def truncate(value, length): 610 value_str = str(value) 611 return value_str[:length] + (value_str[length:] and "...") 612 613 if not is_seq(flat_sequence): 614 raise TypeError( 615 "Attempted to pack value:\n {}\ninto a sequence, but found " 616 "incompatible type `{}` instead." 617 .format(truncate(flat_sequence, 100), type(flat_sequence))) 618 619 if not is_seq(structure): 620 if len(flat_sequence) != 1: 621 raise ValueError( 622 "The target structure is of type `{}`\n {}\nHowever the input " 623 "structure is a sequence ({}) of length {}.\n {}\nnest cannot " 624 "guarantee that it is safe to map one to the other.".format( 625 type(structure), truncate(structure, 100), type(flat_sequence), 626 len(flat_sequence), truncate(flat_sequence, 100))) 627 return flat_sequence[0] 628 629 try: 630 final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 631 0, is_seq, sequence_fn) 632 if final_index < len(flat_sequence): 633 raise IndexError 634 except IndexError: 635 flat_structure = flatten(structure) 636 if len(flat_structure) != len(flat_sequence): 637 raise ValueError( 638 "Could not pack sequence. Structure had %d elements, but " 639 "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." % 640 (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 641 return sequence_fn(structure, packed) 642 643 644@tf_export("nest.pack_sequence_as") 645def pack_sequence_as(structure, flat_sequence, expand_composites=False): 646 """Returns a given flattened sequence packed into a given structure. 647 648 If `structure` is a scalar, `flat_sequence` must be a single-element list; 649 in this case the return value is `flat_sequence[0]`. 650 651 If `structure` is or contains a dict instance, the keys will be sorted to 652 pack the flat sequence in deterministic order. This is true also for 653 `OrderedDict` instances: their sequence order is ignored, the sorting order of 654 keys is used instead. The same convention is followed in `flatten`. 655 This correctly repacks dicts and `OrderedDict`s after they have been 656 flattened, and also allows flattening an `OrderedDict` and then repacking it 657 back using a corresponding plain dict, or vice-versa. 658 Dictionaries with non-sortable keys cannot be flattened. 659 660 Examples: 661 662 1. Python dict: 663 664 >>> structure = { "key3": "", "key1": "", "key2": "" } 665 >>> flat_sequence = ["value1", "value2", "value3"] 666 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 667 {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'} 668 669 2. For a nested python tuple: 670 671 >>> structure = (('a','b'), ('c','d','e'), 'f') 672 >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 673 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 674 ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 675 676 3. For a nested dictionary of dictionaries: 677 678 >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')}, 679 ... "key1": {"e": "val1", "d": "val2"} } 680 >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0] 681 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 682 {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} 683 684 4. Numpy array (considered a scalar): 685 686 >>> structure = ['a'] 687 >>> flat_sequence = [np.array([[1, 2], [3, 4]])] 688 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 689 [array([[1, 2], 690 [3, 4]])] 691 692 5. tf.Tensor (considered a scalar): 693 694 >>> structure = ['a'] 695 >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])] 696 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 697 [<tf.Tensor: shape=(2, 3), dtype=float32, 698 numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>] 699 700 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 701 of a flattened list of 'values' and a list of 'row_splits' which indicate how 702 to chop up the flattened list into different rows. For more details on 703 `tf.RaggedTensor`, please visit 704 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 705 706 With `expand_composites=False`, we treat RaggedTensor as a scalar. 707 708 >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]), 709 ... "bar": tf.constant([[5]]) } 710 >>> flat_sequence = [ "one", "two" ] 711 >>> tf.nest.pack_sequence_as(structure, flat_sequence, 712 ... expand_composites=False) 713 {'foo': 'two', 'bar': 'one'} 714 715 With `expand_composites=True`, we expect that the flattened input contains 716 the tensors making up the ragged tensor i.e. the values and row_splits 717 tensors. 718 719 >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]), 720 ... "bar": tf.constant([[5.]]) } 721 >>> tensors = tf.nest.flatten(structure, expand_composites=True) 722 >>> print(tensors) 723 [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 724 dtype=float32)>, 725 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], 726 dtype=float32)>, 727 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>] 728 >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ') 729 ... if t.dtype==tf.float32 else t 730 ... for t in tensors] 731 >>> tf.nest.pack_sequence_as(structure, verified_tensors, 732 ... expand_composites=True) 733 {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>, 734 'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 735 dtype=float32)>} 736 737 Args: 738 structure: Nested structure, whose structure is given by nested lists, 739 tuples, and dicts. Note: numpy arrays and strings are considered 740 scalars. 741 flat_sequence: flat sequence to pack. 742 expand_composites: If true, then composite tensors such as 743 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 744 component tensors. 745 746 Returns: 747 packed: `flat_sequence` converted to have the same recursive structure as 748 `structure`. 749 750 Raises: 751 ValueError: If `flat_sequence` and `structure` have different 752 element counts. 753 TypeError: `structure` is or contains a dict with non-sortable keys. 754 """ 755 return _pack_sequence_as(structure, flat_sequence, expand_composites) 756 757 758@tf_export("nest.map_structure") 759def map_structure(func, *structure, **kwargs): 760 """Applies `func` to each entry in `structure` and returns a new structure. 761 762 Applies `func(x[0], x[1], ...)` where x[i] is an entry in 763 `structure[i]`. All structures in `structure` must have the same arity, 764 and the return value will contain results with the same structure layout. 765 766 Examples: 767 768 * A single Python dict: 769 770 >>> a = {"hello": 24, "world": 76} 771 >>> tf.nest.map_structure(lambda p: p * 2, a) 772 {'hello': 48, 'world': 152} 773 774 * Multiple Python dictionaries: 775 776 >>> d1 = {"hello": 24, "world": 76} 777 >>> d2 = {"hello": 36, "world": 14} 778 >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2) 779 {'hello': 60, 'world': 90} 780 781 * A single Python list: 782 783 >>> a = [24, 76, "ab"] 784 >>> tf.nest.map_structure(lambda p: p * 2, a) 785 [48, 152, 'abab'] 786 787 * Scalars: 788 789 >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4) 790 7 791 792 * Empty structures: 793 794 >>> tf.nest.map_structure(lambda x: x + 1, ()) 795 () 796 797 *. Check the types of iterables: 798 799 >>> s1 = (((1, 2), 3), 4, (5, 6)) 800 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 801 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list) 802 Traceback (most recent call last): 803 ... 804 TypeError: The two structures don't have the same nested structure 805 806 * Type check is set to False: 807 808 >>> s1 = (((1, 2), 3), 4, (5, 6)) 809 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 810 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False) 811 (((None, None), None), None, (None, None)) 812 813 Args: 814 func: A callable that accepts as many arguments as there are structures. 815 *structure: scalar, or tuple or dict or list of constructed scalars and/or 816 other tuples/lists, or scalars. Note: numpy arrays are considered as 817 scalars. 818 **kwargs: Valid keyword args are: 819 820 * `check_types`: If set to `True` (default) the types of 821 iterables within the structures have to be same (e.g. 822 `map_structure(func, [1], (1,))` raises a `TypeError` 823 exception). To allow this set this argument to `False`. 824 Note that namedtuples with identical name and fields are always 825 considered to have the same shallow structure. 826 * `expand_composites`: If set to `True`, then composite tensors such 827 as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into 828 their component tensors. If `False` (the default), then composite 829 tensors are not expanded. 830 831 Returns: 832 A new structure with the same arity as `structure`, whose values correspond 833 to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding 834 location in `structure[i]`. If there are different sequence types and 835 `check_types` is `False` the sequence types of the first structure will be 836 used. 837 838 Raises: 839 TypeError: If `func` is not callable or if the structures do not match 840 each other by depth tree. 841 ValueError: If no structure is provided or if the structures do not match 842 each other by type. 843 ValueError: If wrong keyword arguments are provided. 844 """ 845 if not callable(func): 846 raise TypeError("func must be callable, got: %s" % func) 847 848 if not structure: 849 raise ValueError("Must provide at least one structure") 850 851 check_types = kwargs.pop("check_types", True) 852 expand_composites = kwargs.pop("expand_composites", False) 853 854 if kwargs: 855 raise ValueError( 856 "Only valid keyword arguments are `check_types` and " 857 "`expand_composites`, not: `%s`" % ("`, `".join(kwargs.keys()))) 858 859 for other in structure[1:]: 860 assert_same_structure(structure[0], other, check_types=check_types, 861 expand_composites=expand_composites) 862 863 flat_structure = (flatten(s, expand_composites) for s in structure) 864 entries = zip(*flat_structure) 865 866 return pack_sequence_as( 867 structure[0], [func(*x) for x in entries], 868 expand_composites=expand_composites) 869 870 871def map_structure_with_paths(func, *structure, **kwargs): 872 """Applies `func` to each entry in `structure` and returns a new structure. 873 874 Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in 875 `structure[i]` and `path` is the common path to x[i] in the structures. All 876 structures in `structure` must have the same arity, and the return value will 877 contain the results with the same structure layout. Special kwarg 878 `check_types` determines whether the types of iterables within the structure 879 must be the same-- see **kwargs definition below. 880 881 Args: 882 func: A callable with the signature func(path, *values, **kwargs) that is 883 evaluated on the leaves of the structure. 884 *structure: A variable number of compatible structures to process. 885 **kwargs: Optional kwargs to be passed through to func. Special kwarg 886 `check_types` is not passed to func, but instead determines whether the 887 types of iterables within the structures have to be same (e.g., 888 `map_structure(func, [1], (1,))` raises a `TypeError` exception). By 889 default, the types must match. To allow iteration over structures of 890 different types (but common arity), set this kwarg to `False`. 891 892 Returns: 893 A structure of the same form as the input structures whose leaves are the 894 result of evaluating func on corresponding leaves of the input structures. 895 896 Raises: 897 TypeError: If `func` is not callable or if the structures do not match 898 each other by depth tree. 899 TypeError: If `check_types` is not `False` and the two structures differ in 900 the type of sequence in any of their substructures. 901 ValueError: If no structures are provided. 902 """ 903 def wrapper_func(tuple_path, *inputs, **kwargs): 904 string_path = "/".join(str(s) for s in tuple_path) 905 return func(string_path, *inputs, **kwargs) 906 907 return map_structure_with_tuple_paths_up_to(structure[0], 908 wrapper_func, 909 *structure, 910 **kwargs) 911 912 913def map_structure_with_tuple_paths(func, *structure, **kwargs): 914 """Applies `func` to each entry in `structure` and returns a new structure. 915 916 Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry 917 in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary 918 keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the 919 common path to x[i] in the structures. All structures in `structure` must have 920 the same arity, and the return value will contain the results in the same 921 structure. Special kwarg `check_types` determines whether the types of 922 iterables within the structure must be the same-- see **kwargs definition 923 below. 924 925 Args: 926 func: A callable with the signature `func(tuple_path, *values, **kwargs)` 927 that is evaluated on the leaves of the structure. 928 *structure: A variable number of compatible structures to process. 929 **kwargs: Optional kwargs to be passed through to func. Special kwarg 930 `check_types` is not passed to func, but instead determines whether the 931 types of iterables within the structures have to be same (e.g. 932 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 933 this set this argument to `False`. 934 935 Returns: 936 A structure of the same form as the input structures whose leaves are the 937 result of evaluating func on corresponding leaves of the input structures. 938 939 Raises: 940 TypeError: If `func` is not callable or if the structures do not match 941 each other by depth tree. 942 TypeError: If `check_types` is not `False` and the two structures differ in 943 the type of sequence in any of their substructures. 944 ValueError: If no structures are provided. 945 """ 946 return map_structure_with_tuple_paths_up_to(structure[0], 947 func, 948 *structure, 949 **kwargs) 950 951 952def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()): 953 """Yields (path, value) pairs of input_tree flattened up to shallow_tree. 954 955 Args: 956 shallow_tree: Nested structure. Traverse no further than its leaf nodes. 957 input_tree: Nested structure. Return the paths and values from this tree. 958 Must have the same upper structure as shallow_tree. 959 is_seq: Function used to test if a value should be treated as a sequence. 960 path: Tuple. Optional argument, only used when recursing. The path from the 961 root of the original shallow_tree, down to the root of the shallow_tree 962 arg of this recursive call. 963 964 Yields: 965 Pairs of (path, value), where path the tuple path of a leaf node in 966 shallow_tree, and value is the value of the corresponding node in 967 input_tree. 968 """ 969 if not is_seq(shallow_tree): 970 yield (path, input_tree) 971 else: 972 input_tree = dict(_yield_sorted_items(input_tree)) 973 for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): 974 subpath = path + (shallow_key,) 975 input_subtree = input_tree[shallow_key] 976 for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, 977 input_subtree, is_seq, 978 path=subpath): 979 yield (leaf_path, leaf_value) 980 981 982def assert_shallow_structure(shallow_tree, 983 input_tree, 984 check_types=True, 985 expand_composites=False): 986 """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 987 988 That is, this function tests if the `input_tree` structure can be created from 989 the `shallow_tree` structure by replacing its leaf nodes with deeper 990 tree structures. 991 992 Examples: 993 994 The following code will raise an exception: 995 ```python 996 shallow_tree = {"a": "A", "b": "B"} 997 input_tree = {"a": 1, "c": 2} 998 assert_shallow_structure(shallow_tree, input_tree) 999 ``` 1000 1001 The following code will raise an exception: 1002 ```python 1003 shallow_tree = ["a", "b"] 1004 input_tree = ["c", ["d", "e"], "f"] 1005 assert_shallow_structure(shallow_tree, input_tree) 1006 ``` 1007 1008 Args: 1009 shallow_tree: an arbitrarily nested structure. 1010 input_tree: an arbitrarily nested structure. 1011 check_types: if `True` (default) the sequence types of `shallow_tree` and 1012 `input_tree` have to be the same. Note that even with check_types==True, 1013 this function will consider two different namedtuple classes with the same 1014 name and _fields attribute to be the same class. 1015 expand_composites: If true, then composite tensors such as 1016 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1017 component tensors. 1018 Raises: 1019 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1020 TypeError: If the sequence types of `shallow_tree` are different from 1021 `input_tree`. Only raised if `check_types` is `True`. 1022 ValueError: If the sequence lengths of `shallow_tree` are different from 1023 `input_tree`. 1024 """ 1025 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1026 if is_seq(shallow_tree): 1027 if not is_seq(input_tree): 1028 raise TypeError( 1029 "If shallow structure is a sequence, input must also be a sequence. " 1030 "Input has type: %s." % type(input_tree)) 1031 1032 if isinstance(shallow_tree, _wrapt.ObjectProxy): 1033 shallow_type = type(shallow_tree.__wrapped__) 1034 else: 1035 shallow_type = type(shallow_tree) 1036 1037 if check_types and not isinstance(input_tree, shallow_type): 1038 # Duck-typing means that nest should be fine with two different 1039 # namedtuples with identical name and fields. 1040 shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) 1041 input_is_namedtuple = _is_namedtuple(input_tree, False) 1042 if shallow_is_namedtuple and input_is_namedtuple: 1043 if not _same_namedtuples(shallow_tree, input_tree): 1044 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1045 input_type=type(input_tree), 1046 shallow_type=type(shallow_tree))) 1047 1048 elif isinstance(shallow_tree, list) and isinstance(input_tree, list): 1049 # List subclasses are considered the same, 1050 # e.g. python list vs. _ListWrapper. 1051 pass 1052 1053 elif ((_is_composite_tensor(shallow_tree) or 1054 _is_composite_tensor(input_tree)) and 1055 (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))): 1056 pass # Compatibility will be checked below. 1057 1058 elif not (isinstance(shallow_tree, _collections_abc.Mapping) and 1059 isinstance(input_tree, _collections_abc.Mapping)): 1060 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1061 input_type=type(input_tree), 1062 shallow_type=type(shallow_tree))) 1063 1064 if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree): 1065 if not ( 1066 (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and 1067 (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))): 1068 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1069 input_type=type(input_tree), 1070 shallow_type=type(shallow_tree))) 1071 type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else 1072 shallow_tree._type_spec) # pylint: disable=protected-access 1073 type_spec_2 = (input_tree if _is_type_spec(input_tree) else 1074 input_tree._type_spec) # pylint: disable=protected-access 1075 try: 1076 _ = type_spec_1.most_specific_compatible_type(type_spec_2) 1077 except (TypeError, ValueError) as e: 1078 raise ValueError( 1079 "Incompatible CompositeTensor TypeSpecs: %s vs. %s -- %s" % 1080 (type_spec_1, type_spec_2, e)) 1081 1082 elif _is_type_spec(shallow_tree): 1083 if not _is_type_spec(input_tree): 1084 raise TypeError("If shallow structure is a TypeSpec, input must also " 1085 "be a TypeSpec. Input has type: %s." 1086 % type(input_tree)) 1087 else: 1088 if len(input_tree) != len(shallow_tree): 1089 raise ValueError( 1090 _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( 1091 input_length=len(input_tree), shallow_length=len(shallow_tree))) 1092 elif len(input_tree) < len(shallow_tree): 1093 raise ValueError( 1094 _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( 1095 input_size=len(input_tree), shallow_size=len(shallow_tree))) 1096 1097 if isinstance(shallow_tree, _collections_abc.Mapping): 1098 absent_keys = set(shallow_tree) - set(input_tree) 1099 if absent_keys: 1100 raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS 1101 .format(sorted(absent_keys))) 1102 1103 for shallow_branch, input_branch in zip(_yield_value(shallow_tree), 1104 _yield_value(input_tree)): 1105 assert_shallow_structure(shallow_branch, input_branch, 1106 check_types=check_types, 1107 expand_composites=expand_composites) 1108 1109 1110@tf_export("__internal__.nest.flatten_up_to", v1=[]) 1111def flatten_up_to(shallow_tree, input_tree, check_types=True, 1112 expand_composites=False): 1113 """Flattens `input_tree` up to `shallow_tree`. 1114 1115 Any further depth in structure in `input_tree` is retained as elements in the 1116 partially flatten output. 1117 1118 If `shallow_tree` and `input_tree` are not sequences, this returns a 1119 single-element list: `[input_tree]`. 1120 1121 Use Case: 1122 1123 Sometimes we may wish to partially flatten a nested sequence, retaining some 1124 of the nested structure. We achieve this by specifying a shallow structure, 1125 `shallow_tree`, we wish to flatten up to. 1126 1127 The input, `input_tree`, can be thought of as having the same structure layout 1128 as `shallow_tree`, but with leaf nodes that are themselves tree structures. 1129 1130 Examples: 1131 1132 ```python 1133 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 1134 shallow_tree = [[True, True], [False, True]] 1135 1136 flattened_input_tree = flatten_up_to(shallow_tree, input_tree) 1137 flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) 1138 1139 # Output is: 1140 # [[2, 2], [3, 3], [4, 9], [5, 5]] 1141 # [True, True, False, True] 1142 ``` 1143 1144 ```python 1145 input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 1146 shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 1147 1148 input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 1149 input_tree_flattened = flatten(input_tree) 1150 1151 # Output is: 1152 # [('a', 1), ('b', 2), ('c', 3), ('d', 4)] 1153 # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 1154 ``` 1155 1156 Non-Sequence Edge Cases: 1157 1158 ```python 1159 flatten_up_to(0, 0) # Output: [0] 1160 flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] 1161 flatten_up_to([0, 1, 2], 0) # Output: TypeError 1162 flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] 1163 ``` 1164 1165 Args: 1166 shallow_tree: a possibly pruned structure of input_tree. 1167 input_tree: an arbitrarily nested structure or a scalar object. 1168 Note, numpy arrays are considered scalars. 1169 check_types: bool. If True, check that each node in shallow_tree has the 1170 same type as the corresponding node in input_tree. 1171 expand_composites: If true, then composite tensors such as 1172 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1173 component tensors. 1174 1175 Returns: 1176 A Python list, the partially flattened version of `input_tree` according to 1177 the structure of `shallow_tree`. 1178 1179 Raises: 1180 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1181 TypeError: If the sequence types of `shallow_tree` are different from 1182 `input_tree`. 1183 ValueError: If the sequence lengths of `shallow_tree` are different from 1184 `input_tree`. 1185 """ 1186 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1187 assert_shallow_structure(shallow_tree, 1188 input_tree, 1189 check_types=check_types, 1190 expand_composites=expand_composites) 1191 # Discard paths returned by _yield_flat_up_to. 1192 return [v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq)] 1193 1194 1195def flatten_with_tuple_paths_up_to(shallow_tree, 1196 input_tree, 1197 check_types=True, 1198 expand_composites=False): 1199 """Flattens `input_tree` up to `shallow_tree`. 1200 1201 Any further depth in structure in `input_tree` is retained as elements in the 1202 partially flattened output. 1203 1204 Returns a list of (path, value) pairs, where value a leaf node in the 1205 flattened tree, and path is the tuple path of that leaf in input_tree. 1206 1207 If `shallow_tree` and `input_tree` are not sequences, this returns a 1208 single-element list: `[((), input_tree)]`. 1209 1210 Use Case: 1211 1212 Sometimes we may wish to partially flatten a nested sequence, retaining some 1213 of the nested structure. We achieve this by specifying a shallow structure, 1214 `shallow_tree`, we wish to flatten up to. 1215 1216 The input, `input_tree`, can be thought of as having the same structure layout 1217 as `shallow_tree`, but with leaf nodes that are themselves tree structures. 1218 1219 Examples: 1220 1221 ```python 1222 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 1223 shallow_tree = [[True, True], [False, True]] 1224 1225 flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree, 1226 input_tree) 1227 flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree, 1228 shallow_tree) 1229 1230 # Output is: 1231 # [((0, 0), [2, 2]), 1232 # ((0, 1), [3, 3]), 1233 # ((1, 0), [4, 9]), 1234 # ((1, 1), [5, 5])] 1235 # 1236 # [((0, 0), True), 1237 # ((0, 1), True), 1238 # ((1, 0), False), 1239 # ((1, 1), True)] 1240 ``` 1241 1242 ```python 1243 input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 1244 shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 1245 1246 input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 1247 input_tree_flattened = flatten(input_tree) 1248 1249 # Output is: 1250 # [((0, 0), ('a', 1)), 1251 # ((0, 1, 0), ('b', 2)), 1252 # ((0, 1, 1, 0), ('c', 3)), 1253 # ((0, 1, 1, 1), ('d', 4))] 1254 # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 1255 ``` 1256 1257 Non-Sequence Edge Cases: 1258 1259 ```python 1260 flatten_with_tuple_paths_up_to(0, 0) # Output: [(), 0] 1261 1262 flatten_with_tuple_paths_up_to(0, [0, 1, 2]) # Output: [(), [0, 1, 2]] 1263 1264 flatten_with_tuple_paths_up_to([0, 1, 2], 0) # Output: TypeError 1265 1266 flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2]) 1267 # Output: [((0,) 0), ((1,), 1), ((2,), 2)] 1268 ``` 1269 1270 Args: 1271 shallow_tree: a possibly pruned structure of input_tree. 1272 input_tree: an arbitrarily nested structure or a scalar object. 1273 Note, numpy arrays are considered scalars. 1274 check_types: bool. If True, check that each node in shallow_tree has the 1275 same type as the corresponding node in input_tree. 1276 expand_composites: If true, then composite tensors such as 1277 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1278 component tensors. 1279 1280 Returns: 1281 A Python list, the partially flattened version of `input_tree` according to 1282 the structure of `shallow_tree`. 1283 1284 Raises: 1285 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1286 TypeError: If the sequence types of `shallow_tree` are different from 1287 `input_tree`. 1288 ValueError: If the sequence lengths of `shallow_tree` are different from 1289 `input_tree`. 1290 """ 1291 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1292 assert_shallow_structure(shallow_tree, 1293 input_tree, 1294 check_types=check_types, 1295 expand_composites=expand_composites) 1296 return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq)) 1297 1298 1299@tf_export("__internal__.nest.map_structure_up_to", v1=[]) 1300def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): 1301 """Applies a function or op to a number of partially flattened inputs. 1302 1303 The `inputs` are flattened up to `shallow_tree` before being mapped. 1304 1305 Use Case: 1306 1307 Sometimes we wish to apply a function to a partially flattened 1308 sequence (for example when the function itself takes sequence inputs). We 1309 achieve this by specifying a shallow structure, `shallow_tree` we wish to 1310 flatten up to. 1311 1312 The `inputs`, can be thought of as having the same structure layout as 1313 `shallow_tree`, but with leaf nodes that are themselves tree structures. 1314 1315 This function therefore will return something with the same base structure as 1316 `shallow_tree`. 1317 1318 Examples: 1319 1320 ```python 1321 shallow_tree = [None, None] 1322 inp_val = [1, 2, 3] 1323 out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val) 1324 1325 # Output is: [2, 4] 1326 ``` 1327 1328 ```python 1329 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 1330 op_tuple = collections.namedtuple("op_tuple", "add, mul") 1331 inp_val = ab_tuple(a=2, b=3) 1332 inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 1333 out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, 1334 inp_val, inp_ops) 1335 1336 # Output is: ab_tuple(a=6, b=15) 1337 ``` 1338 1339 ```python 1340 data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 1341 name_list = ['evens', ['odds', 'primes']] 1342 out = map_structure_up_to( 1343 name_list, 1344 lambda name, sec: "first_{}_{}".format(len(sec), name), 1345 name_list, data_list) 1346 1347 # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] 1348 ``` 1349 1350 Args: 1351 shallow_tree: a shallow tree, common to all the inputs. 1352 func: callable which will be applied to each input individually. 1353 *inputs: arbitrarily nested combination of objects that are compatible with 1354 shallow_tree. The function `func` is applied to corresponding 1355 partially flattened elements of each input, so the function must support 1356 arity of `len(inputs)`. 1357 **kwargs: kwargs to feed to func(). Special kwarg 1358 `check_types` is not passed to func, but instead determines whether the 1359 types of iterables within the structures have to be same (e.g. 1360 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1361 this set this argument to `False`. 1362 1363 Raises: 1364 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1365 TypeError: If the sequence types of `shallow_tree` are different from 1366 `input_tree`. 1367 ValueError: If the sequence lengths of `shallow_tree` are different from 1368 `input_tree`. 1369 1370 Returns: 1371 result of repeatedly applying `func`, with the same structure layout as 1372 `shallow_tree`. 1373 """ 1374 return map_structure_with_tuple_paths_up_to( 1375 shallow_tree, 1376 lambda _, *values: func(*values), # Discards the path arg. 1377 *inputs, 1378 **kwargs) 1379 1380 1381def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): 1382 """Applies a function or op to a number of partially flattened inputs. 1383 1384 Like map_structure_up_to(), except that the 'func' argument takes a path 1385 tuple as its first argument, followed by the corresponding values from 1386 *inputs. 1387 1388 Example: 1389 1390 ```python 1391 lowercase = {'a': 'a', 'b': ('b0', 'b1')} 1392 uppercase = {'a': 'A', 'b': ('B0', 'B1')} 1393 1394 def print_path_and_values(path, *values): 1395 print("path: {}, values: {}".format(path, values)) 1396 1397 shallow_tree = {'a': None} 1398 map_structure_with_tuple_paths_up_to(shallow_tree, 1399 print_path_and_values, 1400 lowercase, 1401 uppercase) 1402 path: ('a',), values: ('a', 'A') 1403 path: ('b', 0), values: ('b0', 'B0') 1404 path: ('b', 1), values: ('b1', 'B1') 1405 1406 shallow_tree = {'b': None} 1407 map_structure_with_tuple_paths_up_to(shallow_tree, 1408 print_path_and_values, 1409 lowercase, 1410 uppercase, 1411 check_types=False) 1412 path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1')) 1413 1414 shallow_tree = {'a': None, 'b': {1: None}} 1415 map_structure_with_tuple_paths_up_to(shallow_tree, 1416 print_path_and_values, 1417 lowercase, 1418 uppercase, 1419 check_types=False) 1420 path: ('a',), values: ('a', 'A') 1421 path: ('b', 1), values: ('b1', B1') 1422 ``` 1423 1424 Args: 1425 shallow_tree: a shallow tree, common to all the inputs. 1426 func: callable that takes args (path, inputs_0_value, ... , inputs_N_value), 1427 where path is a tuple path to a leaf node in shallow_tree, and 1428 inputs_i_value is the corresponding value from inputs[i]. 1429 *inputs: nested structures that are all structurally compatible with 1430 shallow_tree. 1431 **kwargs: kwargs to feed to func(). Special kwarg 1432 `check_types` is not passed to func, but instead determines whether the 1433 types of iterables within the structures have to be same (e.g. 1434 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1435 this set this argument to `False`. 1436 1437 Raises: 1438 TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not. 1439 TypeError: If the sequence types of `shallow_tree` are different from 1440 `input_tree`. 1441 ValueError: If the sequence lengths of `shallow_tree` are different from 1442 `input_tree`. 1443 1444 Returns: 1445 Result of repeatedly applying `func`. Has the same structure layout as 1446 `shallow_tree`. 1447 """ 1448 if not inputs: 1449 raise ValueError("Cannot map over no sequences") 1450 1451 check_types = kwargs.pop("check_types", True) 1452 expand_composites = kwargs.pop("expand_composites", False) 1453 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1454 1455 for input_tree in inputs: 1456 assert_shallow_structure( 1457 shallow_tree, 1458 input_tree, 1459 check_types=check_types, 1460 expand_composites=expand_composites) 1461 1462 # Flatten each input separately, apply the function to corresponding elements, 1463 # then repack based on the structure of the first input. 1464 flat_value_gen = ( 1465 flatten_up_to( # pylint: disable=g-complex-comprehension 1466 shallow_tree, 1467 input_tree, 1468 check_types, 1469 expand_composites=expand_composites) for input_tree in inputs) 1470 flat_path_gen = ( 1471 path for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)) 1472 results = [ 1473 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen) 1474 ] 1475 return pack_sequence_as(structure=shallow_tree, flat_sequence=results, 1476 expand_composites=expand_composites) 1477 1478 1479@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[]) 1480def get_traverse_shallow_structure(traverse_fn, structure, 1481 expand_composites=False): 1482 """Generates a shallow structure from a `traverse_fn` and `structure`. 1483 1484 `traverse_fn` must accept any possible subtree of `structure` and return 1485 a depth=1 structure containing `True` or `False` values, describing which 1486 of the top-level subtrees may be traversed. It may also 1487 return scalar `True` or `False` "traversal is OK / not OK for all subtrees." 1488 1489 Examples are available in the unit tests (nest_test.py). 1490 1491 Args: 1492 traverse_fn: Function taking a substructure and returning either a scalar 1493 `bool` (whether to traverse that substructure or not) or a depth=1 1494 shallow structure of the same type, describing which parts of the 1495 substructure to traverse. 1496 structure: The structure to traverse. 1497 expand_composites: If true, then composite tensors such as 1498 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1499 component tensors. 1500 1501 Returns: 1502 A shallow structure containing python bools, which can be passed to 1503 `map_structure_up_to` and `flatten_up_to`. 1504 1505 Raises: 1506 TypeError: if `traverse_fn` returns a sequence for a non-sequence input, 1507 or a structure with depth higher than 1 for a sequence input, 1508 or if any leaf values in the returned structure or scalar are not type 1509 `bool`. 1510 """ 1511 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1512 to_traverse = traverse_fn(structure) 1513 if not is_seq(structure): 1514 if not isinstance(to_traverse, bool): 1515 raise TypeError("traverse_fn returned structure: %s for non-structure: %s" 1516 % (to_traverse, structure)) 1517 return to_traverse 1518 level_traverse = [] 1519 if isinstance(to_traverse, bool): 1520 if not to_traverse: 1521 # Do not traverse this substructure at all. Exit early. 1522 return False 1523 else: 1524 # Traverse the entire substructure. 1525 for branch in _yield_value(structure): 1526 level_traverse.append( 1527 get_traverse_shallow_structure(traverse_fn, branch, 1528 expand_composites=expand_composites)) 1529 elif not is_seq(to_traverse): 1530 raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" 1531 % (to_traverse, structure)) 1532 else: 1533 # Traverse some subset of this substructure. 1534 assert_shallow_structure(to_traverse, structure, 1535 expand_composites=expand_composites) 1536 for t, branch in zip(_yield_value(to_traverse), 1537 _yield_value(structure)): 1538 if not isinstance(t, bool): 1539 raise TypeError( 1540 "traverse_fn didn't return a depth=1 structure of bools. saw: %s " 1541 " for structure: %s" % (to_traverse, structure)) 1542 if t: 1543 level_traverse.append( 1544 get_traverse_shallow_structure(traverse_fn, branch)) 1545 else: 1546 level_traverse.append(False) 1547 return _sequence_like(structure, level_traverse) 1548 1549 1550@tf_export("__internal__.nest.yield_flat_paths", v1=[]) 1551def yield_flat_paths(nest, expand_composites=False): 1552 """Yields paths for some nested structure. 1553 1554 Paths are lists of objects which can be str-converted, which may include 1555 integers or other types which are used as indices in a dict. 1556 1557 The flat list will be in the corresponding order as if you called 1558 `nest.flatten` on the structure. This is handy for naming Tensors such 1559 the TF scope structure matches the tuple structure. 1560 1561 E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` 1562 1563 ```shell 1564 nest.flatten(value) 1565 [3, 23, 42] 1566 list(nest.yield_flat_paths(value)) 1567 [('a',), ('b', 'c'), ('b', 'd')] 1568 ``` 1569 1570 ```shell 1571 list(nest.yield_flat_paths({'a': [3]})) 1572 [('a', 0)] 1573 list(nest.yield_flat_paths({'a': 3})) 1574 [('a',)] 1575 ``` 1576 1577 Args: 1578 nest: the value to produce a flattened paths list for. 1579 expand_composites: If true, then composite tensors such as 1580 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1581 component tensors. 1582 1583 Yields: 1584 Tuples containing index or key values which form the path to a specific 1585 leaf value in the nested structure. 1586 """ 1587 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1588 for k, _ in _yield_flat_up_to(nest, nest, is_seq): 1589 yield k 1590 1591 1592def flatten_with_joined_string_paths(structure, separator="/", 1593 expand_composites=False): 1594 """Returns a list of (string path, data element) tuples. 1595 1596 The order of tuples produced matches that of `nest.flatten`. This allows you 1597 to flatten a nested structure while keeping information about where in the 1598 structure each data element was located. See `nest.yield_flat_paths` 1599 for more information. 1600 1601 Args: 1602 structure: the nested structure to flatten. 1603 separator: string to separate levels of hierarchy in the results, defaults 1604 to '/'. 1605 expand_composites: If true, then composite tensors such as 1606 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1607 component tensors. 1608 1609 Returns: 1610 A list of (string, data element) tuples. 1611 """ 1612 flat_paths = yield_flat_paths(structure, expand_composites=expand_composites) 1613 def stringify_and_join(path_elements): 1614 return separator.join(str(path_element) for path_element in path_elements) 1615 1616 flat_string_paths = (stringify_and_join(path) for path in flat_paths) 1617 return list(zip(flat_string_paths, 1618 flatten(structure, expand_composites=expand_composites))) 1619 1620 1621def flatten_with_tuple_paths(structure, expand_composites=False): 1622 """Returns a list of `(tuple_path, leaf_element)` tuples. 1623 1624 The order of pairs produced matches that of `nest.flatten`. This allows you 1625 to flatten a nested structure while keeping information about where in the 1626 structure each data element was located. See `nest.yield_flat_paths` 1627 for more information about tuple paths. 1628 1629 Args: 1630 structure: the nested structure to flatten. 1631 expand_composites: If true, then composite tensors such as 1632 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1633 component tensors. 1634 1635 Returns: 1636 A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple 1637 of indices and/or dictionary keys that uniquely specify the path to 1638 `leaf_element` within `structure`. 1639 """ 1640 return list(zip(yield_flat_paths(structure, 1641 expand_composites=expand_composites), 1642 flatten(structure, expand_composites=expand_composites))) 1643 1644 1645@tf_export("__internal__.nest.list_to_tuple", v1=[]) 1646def list_to_tuple(structure): 1647 """Replace all lists with tuples. 1648 1649 The fork of nest that tf.data uses treats lists as single elements, while 1650 tf.nest treats them as structures to recurse into. Keras has chosen to adopt 1651 the latter convention, and must therefore deeply replace all lists with tuples 1652 before passing structures to Dataset.from_generator. 1653 1654 Args: 1655 structure: A nested structure to be remapped. 1656 1657 Returns: 1658 structure mapped to replace all lists with tuples. 1659 """ 1660 def sequence_fn(instance, args): 1661 if isinstance(instance, list): 1662 return tuple(args) 1663 return _sequence_like(instance, args) 1664 1665 return _pack_sequence_as(structure, flatten(structure), False, 1666 sequence_fn=sequence_fn) 1667 1668 1669_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) 1670_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping) 1671_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) 1672_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView) 1673_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy) 1674