1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""## Functions for working with arbitrarily nested sequences of elements. 17 18This module can perform operations on nested structures. A nested structure is a 19Python collection that can contain further collections as well as other objects 20called atoms. Note that numpy arrays are considered atoms. 21 22nest recognizes the following types of collections: 23 1.tuple 24 2.namedtuple 25 3.dict 26 4.orderedDict 27 5.MutableMapping 28 6.attr.s 29 30attr.s decorated classes (http://www.attrs.org) are also supported, in the 31same way as `namedtuple`. 32 33The utilities here assume (and do not check) that the nested structures form a 34'tree', i.e., no references in the structure of the input of these functions 35should be recursive. 
36 37Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), 38 (np.array([3, 4]), tf.constant([3, 4])))` 39""" 40 41from __future__ import absolute_import 42from __future__ import division 43from __future__ import print_function 44 45import collections as _collections 46 47import six as _six 48import wrapt as _wrapt 49 50from tensorflow.python.platform import tf_logging 51from tensorflow.python.util import _pywrap_nest 52from tensorflow.python.util import _pywrap_utils 53from tensorflow.python.util.compat import collections_abc as _collections_abc 54from tensorflow.python.util.tf_export import tf_export 55 56 57_SHALLOW_TREE_HAS_INVALID_KEYS = ( 58 "The shallow_tree's keys are not a subset of the input_tree's keys. The " 59 "shallow_tree has the following keys that are not in the input_tree: {}.") 60 61_STRUCTURES_HAVE_MISMATCHING_TYPES = ( 62 "The two structures don't have the same sequence type. Input structure has " 63 "type {input_type}, while shallow structure has type {shallow_type}.") 64 65_STRUCTURES_HAVE_MISMATCHING_LENGTHS = ( 66 "The two structures don't have the same sequence length. Input " 67 "structure has length {input_length}, while shallow structure has length " 68 "{shallow_length}." 69) 70 71_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = ( 72 "The input_tree has fewer elements than the shallow_tree. Input structure " 73 "has length {input_size}, while shallow structure has length " 74 "{shallow_size}.") 75 76_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = ( 77 "If shallow structure is a sequence, input must also be a sequence. " 78 "Input has type: {}.") 79 80 81def _get_attrs_items(obj): 82 """Returns a list of (name, value) pairs from an attrs instance. 83 84 The list will be sorted by name. 85 86 Args: 87 obj: an object. 88 89 Returns: 90 A list of (attr_name, attr_value) pairs, sorted by attr_name. 
91 """ 92 attrs = getattr(obj.__class__, "__attrs_attrs__") 93 attr_names = (a.name for a in attrs) 94 return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names] 95 96 97def _sorted(dict_): 98 """Returns a sorted list of the dict keys, with error if keys not sortable.""" 99 try: 100 return sorted(dict_.keys()) 101 except TypeError: 102 raise TypeError("nest only supports dicts with sortable keys.") 103 104 105def is_namedtuple(instance, strict=False): 106 """Returns True iff `instance` is a `namedtuple`. 107 108 Args: 109 instance: An instance of a Python object. 110 strict: If True, `instance` is considered to be a `namedtuple` only if 111 it is a "plain" namedtuple. For instance, a class inheriting 112 from a `namedtuple` will be considered to be a `namedtuple` 113 iff `strict=False`. 114 115 Returns: 116 True if `instance` is a `namedtuple`. 117 """ 118 return _pywrap_utils.IsNamedtuple(instance, strict) 119 120_is_namedtuple = is_namedtuple # This function was private up to TF2.5. 121 122# See the swig file (util.i) for documentation. 123_is_mapping_view = _pywrap_utils.IsMappingView 124_is_attrs = _pywrap_utils.IsAttrs 125_is_composite_tensor = _pywrap_utils.IsCompositeTensor 126_is_type_spec = _pywrap_utils.IsTypeSpec 127_is_mutable_mapping = _pywrap_utils.IsMutableMapping 128_is_mapping = _pywrap_utils.IsMapping 129 130 131@tf_export("__internal__.nest.is_attrs", v1=[]) 132def is_attrs(obj): 133 """Returns a true if its input is an instance of an attr.s decorated class.""" 134 return _is_attrs(obj) 135 136 137@tf_export("__internal__.nest.is_mapping", v1=[]) 138def is_mapping(obj): 139 """Returns a true if its input is a collections.Mapping.""" 140 return _is_mapping(obj) 141 142 143@tf_export("__internal__.nest.sequence_like", v1=[]) 144def _sequence_like(instance, args): 145 """Converts the sequence `args` to the same type as `instance`. 
146 147 Args: 148 instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, 149 `collections.OrderedDict`, or `composite_tensor.Composite_Tensor` 150 or `type_spec.TypeSpec`. 151 args: elements to be converted to the `instance` type. 152 153 Returns: 154 `args` with the type of `instance`. 155 """ 156 if _is_mutable_mapping(instance): 157 # Pack dictionaries in a deterministic order by sorting the keys. 158 # Notice this means that we ignore the original order of `OrderedDict` 159 # instances. This is intentional, to avoid potential bugs caused by mixing 160 # ordered and plain dicts (e.g., flattening a dict but using a 161 # corresponding `OrderedDict` to pack it back). 162 result = dict(zip(_sorted(instance), args)) 163 instance_type = type(instance) 164 if instance_type == _collections.defaultdict: 165 d = _collections.defaultdict(instance.default_factory) 166 else: 167 d = instance_type() 168 for key in instance: 169 d[key] = result[key] 170 return d 171 elif _is_mapping(instance): 172 result = dict(zip(_sorted(instance), args)) 173 instance_type = type(instance) 174 tf_logging.log_first_n( 175 tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer" 176 " using MutableMapping for {}".format(instance_type), 1) 177 try: 178 return instance_type((key, result[key]) for key in instance) 179 except TypeError as err: 180 raise TypeError("Error creating an object of type {} like {}. Note that " 181 "it must accept a single positional argument " 182 "representing an iterable of key-value pairs, in " 183 "addition to self. 
Cause: {}".format( 184 type(instance), instance, err)) 185 elif _is_mapping_view(instance): 186 # We can't directly construct mapping views, so we create a list instead 187 return list(args) 188 elif is_namedtuple(instance) or _is_attrs(instance): 189 if isinstance(instance, _wrapt.ObjectProxy): 190 instance_type = type(instance.__wrapped__) 191 else: 192 instance_type = type(instance) 193 return instance_type(*args) 194 elif _is_composite_tensor(instance): 195 assert len(args) == 1 196 spec = instance._type_spec # pylint: disable=protected-access 197 return spec._from_components(args[0]) # pylint: disable=protected-access 198 elif _is_type_spec(instance): 199 # Pack a CompositeTensor's components according to a TypeSpec. 200 assert len(args) == 1 201 return instance._from_components(args[0]) # pylint: disable=protected-access 202 elif isinstance(instance, _six.moves.range): 203 return _sequence_like(list(instance), args) 204 elif isinstance(instance, _wrapt.ObjectProxy): 205 # For object proxies, first create the underlying type and then re-wrap it 206 # in the proxy type. 207 return type(instance)(_sequence_like(instance.__wrapped__, args)) 208 else: 209 # Not a namedtuple 210 return type(instance)(args) 211 212 213def _yield_value(iterable): 214 for _, v in _yield_sorted_items(iterable): 215 yield v 216 217 218def _yield_sorted_items(iterable): 219 """Yield (key, value) pairs for `iterable` in a deterministic order. 220 221 For Sequences, the key will be an int, the array index of a value. 222 For Mappings, the key will be the dictionary key. 223 For objects (e.g. namedtuples), the key will be the attribute name. 224 225 In all cases, the keys will be iterated in sorted order. 226 227 Args: 228 iterable: an iterable. 229 230 Yields: 231 The iterable's (key, value) pairs, in order of sorted keys. 232 """ 233 # Ordered to check common structure types (list, tuple, dict) first. 
234 if isinstance(iterable, list): 235 for item in enumerate(iterable): 236 yield item 237 # namedtuples handled separately to avoid expensive namedtuple check. 238 elif type(iterable) == tuple: # pylint: disable=unidiomatic-typecheck 239 for item in enumerate(iterable): 240 yield item 241 elif isinstance(iterable, (dict, _collections_abc.Mapping)): 242 # Iterate through dictionaries in a deterministic order by sorting the 243 # keys. Notice this means that we ignore the original order of `OrderedDict` 244 # instances. This is intentional, to avoid potential bugs caused by mixing 245 # ordered and plain dicts (e.g., flattening a dict but using a 246 # corresponding `OrderedDict` to pack it back). 247 for key in _sorted(iterable): 248 yield key, iterable[key] 249 elif _is_attrs(iterable): 250 for item in _get_attrs_items(iterable): 251 yield item 252 elif is_namedtuple(iterable): 253 for field in iterable._fields: 254 yield field, getattr(iterable, field) 255 elif _is_composite_tensor(iterable): 256 type_spec = iterable._type_spec # pylint: disable=protected-access 257 yield type_spec.value_type.__name__, type_spec._to_components(iterable) # pylint: disable=protected-access 258 elif _is_type_spec(iterable): 259 # Note: to allow CompositeTensors and their TypeSpecs to have matching 260 # structures, we need to use the same key string here. 261 yield iterable.value_type.__name__, iterable._component_specs # pylint: disable=protected-access 262 else: 263 for item in enumerate(iterable): 264 yield item 265 266 267# See the swig file (util.i) for documentation. 268is_sequence = _pywrap_utils.IsSequence 269 270 271# See the swig file (util.i) for documentation. 272is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite 273 274 275@tf_export("nest.is_nested") 276def is_nested(seq): 277 """Returns true if its input is a collections.abc.Sequence (except strings). 
278 279 >>> tf.nest.is_nested("1234") 280 False 281 282 >>> tf.nest.is_nested([1, 3, [4, 5]]) 283 True 284 285 >>> tf.nest.is_nested(((7, 8), (5, 6))) 286 True 287 288 >>> tf.nest.is_nested([]) 289 True 290 291 >>> tf.nest.is_nested({"a": 1, "b": 2}) 292 True 293 294 >>> tf.nest.is_nested({"a": 1, "b": 2}.keys()) 295 True 296 297 >>> tf.nest.is_nested({"a": 1, "b": 2}.values()) 298 True 299 300 >>> tf.nest.is_nested({"a": 1, "b": 2}.items()) 301 True 302 303 >>> tf.nest.is_nested(set([1, 2])) 304 False 305 306 >>> ones = tf.ones([2, 3]) 307 >>> tf.nest.is_nested(ones) 308 False 309 310 Args: 311 seq: an input sequence. 312 313 Returns: 314 True if the sequence is a not a string and is a collections.abc.Sequence 315 or a dict. 316 """ 317 return is_sequence(seq) 318 319 320@tf_export("nest.flatten") 321def flatten(structure, expand_composites=False): 322 """Returns a flat list from a given nested structure. 323 324 If nest is not a structure , tuple (or a namedtuple), dict, or an attrs class, 325 then returns a single-element list: 326 [nest]. 327 328 This is the inverse of the `nest.pack_sequence_as` method that takes in a 329 flattened list and re-packs it into the nested structure. 330 331 In the case of dict instances, the sequence consists of the values, sorted by 332 key to ensure deterministic behavior. This is true also for OrderedDict 333 instances: their sequence order is ignored, the sorting order of keys is used 334 instead. The same convention is followed in `nest.pack_sequence_as`. This 335 correctly repacks dicts and OrderedDicts after they have been flattened, and 336 also allows flattening an OrderedDict and then repacking it back using a 337 corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys 338 cannot be flattened. 339 340 Users must not modify any collections used in nest while this function is 341 running. 342 343 Examples: 344 345 1. 
Python dict (ordered by key): 346 347 >>> dict = { "key3": "value3", "key1": "value1", "key2": "value2" } 348 >>> tf.nest.flatten(dict) 349 ['value1', 'value2', 'value3'] 350 351 2. For a nested python tuple: 352 353 >>> tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 354 >>> tf.nest.flatten(tuple) 355 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 356 357 3. For a nested dictionary of dictionaries: 358 359 >>> dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)}, 360 ... "key1": {"m": "val1", "g": "val2"} } 361 >>> tf.nest.flatten(dict) 362 ['val2', 'val1', 3.0, 1.0, 2.0] 363 364 4. Numpy array (will not flatten): 365 366 >>> array = np.array([[1, 2], [3, 4]]) 367 >>> tf.nest.flatten(array) 368 [array([[1, 2], 369 [3, 4]])] 370 371 5. `tf.Tensor` (will not flatten): 372 373 >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) 374 >>> tf.nest.flatten(tensor) 375 [<tf.Tensor: shape=(3, 3), dtype=float32, numpy= 376 array([[1., 2., 3.], 377 [4., 5., 6.], 378 [7., 8., 9.]], dtype=float32)>] 379 380 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 381 of a flattened list of 'values' and a list of 'row_splits' which indicate how 382 to chop up the flattened list into different rows. For more details on 383 `tf.RaggedTensor`, please visit 384 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 385 386 with `expand_composites=False`, we just return the RaggedTensor as is. 
387 388 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 389 >>> tf.nest.flatten(tensor, expand_composites=False) 390 [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>] 391 392 with `expand_composites=True`, we return the component Tensors that make up 393 the RaggedTensor representation (the values and row_splits tensors) 394 395 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 396 >>> tf.nest.flatten(tensor, expand_composites=True) 397 [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2], 398 dtype=int32)>, 399 <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>] 400 401 Args: 402 structure: an arbitrarily nested structure. Note, numpy arrays are 403 considered atoms and are not flattened. 404 expand_composites: If true, then composite tensors such as 405 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 406 component tensors. 407 408 Returns: 409 A Python list, the flattened version of the input. 410 411 Raises: 412 TypeError: The nest is or contains a dict with non-sortable keys. 413 """ 414 if structure is None: 415 return [None] 416 expand_composites = bool(expand_composites) 417 return _pywrap_utils.Flatten(structure, expand_composites) 418 419 420# See the swig file (util.i) for documentation. 421same_namedtuples = _pywrap_utils.SameNamedtuples 422_same_namedtuples = same_namedtuples # This function was private up to TF2.5. 423 424 425class _DotString(object): 426 427 __slots__ = [] 428 429 def __str__(self): 430 return "." 431 432 def __repr__(self): 433 return "." 434 435 436_DOT = _DotString() 437 438 439@tf_export("nest.assert_same_structure") 440def assert_same_structure(nest1, nest2, check_types=True, 441 expand_composites=False): 442 """Asserts that two structures are nested in the same way. 443 444 Note the method does not check the types of data inside the structures. 445 446 Examples: 447 448 * These scalar vs. 
scalar comparisons will pass: 449 450 >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32)) 451 >>> tf.nest.assert_same_structure("abc", np.array([1, 2])) 452 453 * These sequence vs. sequence comparisons will pass: 454 455 >>> structure1 = (((1, 2), 3), 4, (5, 6)) 456 >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) 457 >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]] 458 >>> tf.nest.assert_same_structure(structure1, structure2) 459 >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False) 460 461 >>> import collections 462 >>> tf.nest.assert_same_structure( 463 ... collections.namedtuple("bar", "a b")(1, 2), 464 ... collections.namedtuple("foo", "a b")(2, 3), 465 ... check_types=False) 466 467 >>> tf.nest.assert_same_structure( 468 ... collections.namedtuple("bar", "a b")(1, 2), 469 ... { "a": 1, "b": 2 }, 470 ... check_types=False) 471 472 >>> tf.nest.assert_same_structure( 473 ... { "a": 1, "b": 2, "c": 3 }, 474 ... { "c": 6, "b": 5, "a": 4 }) 475 476 >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits( 477 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 478 ... row_splits=[0, 4, 4, 7, 8, 8]) 479 >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits( 480 ... values=[3, 1, 4], 481 ... row_splits=[0, 3]) 482 >>> tf.nest.assert_same_structure( 483 ... ragged_tensor1, 484 ... ragged_tensor2, 485 ... expand_composites=True) 486 487 * These examples will raise exceptions: 488 489 >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1])) 490 Traceback (most recent call last): 491 ... 492 ValueError: The two structures don't have the same nested structure 493 494 >>> tf.nest.assert_same_structure( 495 ... collections.namedtuple('bar', 'a b')(1, 2), 496 ... collections.namedtuple('foo', 'a b')(2, 3)) 497 Traceback (most recent call last): 498 ... 499 TypeError: The two structures don't have the same nested structure 500 501 Args: 502 nest1: an arbitrarily nested structure. 503 nest2: an arbitrarily nested structure. 
504 check_types: if `True` (default) types of sequences are checked as well, 505 including the keys of dictionaries. If set to `False`, for example a 506 list and a tuple of objects will look the same if they have the same 507 size. Note that namedtuples with identical name and fields are always 508 considered to have the same shallow structure. Two types will also be 509 considered the same if they are both list subtypes (which allows "list" 510 and "_ListWrapper" from trackable dependency tracking to compare 511 equal). 512 expand_composites: If true, then composite tensors such as 513 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 514 component tensors. 515 516 Raises: 517 ValueError: If the two structures do not have the same number of elements or 518 if the two structures are not nested in the same way. 519 TypeError: If the two structures differ in the type of sequence in any of 520 their substructures. Only possible if `check_types` is `True`. 521 """ 522 # Convert to bool explicitly as otherwise pybind will not be able# to handle 523 # type mismatch message correctly. See GitHub issue 42329 for details. 524 check_types = bool(check_types) 525 expand_composites = bool(expand_composites) 526 try: 527 _pywrap_utils.AssertSameStructure(nest1, nest2, check_types, 528 expand_composites) 529 except (ValueError, TypeError) as e: 530 str1 = str(map_structure(lambda _: _DOT, nest1)) 531 str2 = str(map_structure(lambda _: _DOT, nest2)) 532 raise type(e)("%s\n" 533 "Entire first structure:\n%s\n" 534 "Entire second structure:\n%s" 535 % (str(e), str1, str2)) 536 537 538def flatten_dict_items(dictionary): 539 """Returns a dictionary with flattened keys and values. 
540 541 This function flattens the keys and values of a dictionary, which can be 542 arbitrarily nested structures, and returns the flattened version of such 543 structures: 544 545 ```python 546 example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 547 result = {4: "a", 5: "b", 6: "c", 8: "d"} 548 flatten_dict_items(example_dictionary) == result 549 ``` 550 551 The input dictionary must satisfy two properties: 552 553 1. Its keys and values should have the same exact nested structure. 554 2. The set of all flattened keys of the dictionary must not contain repeated 555 keys. 556 557 Args: 558 dictionary: the dictionary to zip 559 560 Returns: 561 The zipped dictionary. 562 563 Raises: 564 TypeError: If the input is not a dictionary. 565 ValueError: If any key and value do not have the same structure layout, or 566 if keys are not unique. 567 """ 568 return _pywrap_nest.FlattenDictItems(dictionary) 569 570 571def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None): 572 """Helper function for pack_sequence_as. 573 574 Args: 575 structure: Substructure (list / tuple / dict) to mimic. 576 flat: Flattened values to output substructure for. 577 index: Index at which to start reading from flat. 578 is_seq: Function used to test if a value should be treated as a sequence. 579 sequence_fn: Function used to generate a new sequence instance. 580 581 Returns: 582 The tuple (new_index, child), where: 583 * new_index - the updated index into `flat` having processed `structure`. 584 * packed - the subset of `flat` corresponding to `structure`, 585 having started at `index`, and packed into the same nested 586 format. 587 588 Raises: 589 ValueError: if `structure` contains more elements than `flat` 590 (assuming indexing starts from `index`). 
591 """ 592 packed = [] 593 sequence_fn = sequence_fn or _sequence_like 594 for s in _yield_value(structure): 595 if is_seq(s): 596 new_index, child = _packed_nest_with_indices(s, flat, index, is_seq, 597 sequence_fn) 598 packed.append(sequence_fn(s, child)) 599 index = new_index 600 else: 601 packed.append(flat[index]) 602 index += 1 603 return index, packed 604 605 606def _pack_sequence_as(structure, flat_sequence, expand_composites, 607 sequence_fn=None): 608 """Implements sequence packing, with the option to alter the structure.""" 609 is_seq = is_sequence_or_composite if expand_composites else is_sequence 610 sequence_fn = sequence_fn or _sequence_like 611 def truncate(value, length): 612 value_str = str(value) 613 return value_str[:length] + (value_str[length:] and "...") 614 615 if not is_seq(flat_sequence): 616 raise TypeError( 617 "Attempted to pack value:\n {}\ninto a sequence, but found " 618 "incompatible type `{}` instead." 619 .format(truncate(flat_sequence, 100), type(flat_sequence))) 620 621 if not is_seq(structure): 622 if len(flat_sequence) != 1: 623 raise ValueError( 624 "The target structure is of type `{}`\n {}\nHowever the input " 625 "structure is a sequence ({}) of length {}.\n {}\nnest cannot " 626 "guarantee that it is safe to map one to the other.".format( 627 type(structure), truncate(structure, 100), type(flat_sequence), 628 len(flat_sequence), truncate(flat_sequence, 100))) 629 return flat_sequence[0] 630 631 try: 632 final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 633 0, is_seq, sequence_fn) 634 if final_index < len(flat_sequence): 635 raise IndexError 636 except IndexError: 637 flat_structure = flatten(structure, expand_composites=expand_composites) 638 if len(flat_structure) != len(flat_sequence): 639 raise ValueError( 640 "Could not pack sequence. Structure had %d elements, but " 641 "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." 
% 642 (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 643 return sequence_fn(structure, packed) 644 645 646@tf_export("nest.pack_sequence_as") 647def pack_sequence_as(structure, flat_sequence, expand_composites=False): 648 """Returns a given flattened sequence packed into a given structure. 649 650 If `structure` is a scalar, `flat_sequence` must be a single-element list; 651 in this case the return value is `flat_sequence[0]`. 652 653 If `structure` is or contains a dict instance, the keys will be sorted to 654 pack the flat sequence in deterministic order. This is true also for 655 `OrderedDict` instances: their sequence order is ignored, the sorting order of 656 keys is used instead. The same convention is followed in `flatten`. 657 This correctly repacks dicts and `OrderedDict`s after they have been 658 flattened, and also allows flattening an `OrderedDict` and then repacking it 659 back using a corresponding plain dict, or vice-versa. 660 Dictionaries with non-sortable keys cannot be flattened. 661 662 Examples: 663 664 1. Python dict: 665 666 >>> structure = { "key3": "", "key1": "", "key2": "" } 667 >>> flat_sequence = ["value1", "value2", "value3"] 668 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 669 {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'} 670 671 2. For a nested python tuple: 672 673 >>> structure = (('a','b'), ('c','d','e'), 'f') 674 >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 675 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 676 ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 677 678 3. For a nested dictionary of dictionaries: 679 680 >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')}, 681 ... "key1": {"e": "val1", "d": "val2"} } 682 >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0] 683 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 684 {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} 685 686 4. 
Numpy array (considered a scalar): 687 688 >>> structure = ['a'] 689 >>> flat_sequence = [np.array([[1, 2], [3, 4]])] 690 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 691 [array([[1, 2], 692 [3, 4]])] 693 694 5. tf.Tensor (considered a scalar): 695 696 >>> structure = ['a'] 697 >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])] 698 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 699 [<tf.Tensor: shape=(2, 3), dtype=float32, 700 numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>] 701 702 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 703 of a flattened list of 'values' and a list of 'row_splits' which indicate how 704 to chop up the flattened list into different rows. For more details on 705 `tf.RaggedTensor`, please visit 706 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 707 708 With `expand_composites=False`, we treat RaggedTensor as a scalar. 709 710 >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]), 711 ... "bar": tf.constant([[5]]) } 712 >>> flat_sequence = [ "one", "two" ] 713 >>> tf.nest.pack_sequence_as(structure, flat_sequence, 714 ... expand_composites=False) 715 {'foo': 'two', 'bar': 'one'} 716 717 With `expand_composites=True`, we expect that the flattened input contains 718 the tensors making up the ragged tensor i.e. the values and row_splits 719 tensors. 720 721 >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]), 722 ... "bar": tf.constant([[5.]]) } 723 >>> tensors = tf.nest.flatten(structure, expand_composites=True) 724 >>> print(tensors) 725 [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 726 dtype=float32)>, 727 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], 728 dtype=float32)>, 729 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>] 730 >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ') 731 ... if t.dtype==tf.float32 else t 732 ... 
for t in tensors] 733 >>> tf.nest.pack_sequence_as(structure, verified_tensors, 734 ... expand_composites=True) 735 {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>, 736 'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 737 dtype=float32)>} 738 739 Args: 740 structure: Nested structure, whose structure is given by nested lists, 741 tuples, and dicts. Note: numpy arrays and strings are considered 742 scalars. 743 flat_sequence: flat sequence to pack. 744 expand_composites: If true, then composite tensors such as 745 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 746 component tensors. 747 748 Returns: 749 packed: `flat_sequence` converted to have the same recursive structure as 750 `structure`. 751 752 Raises: 753 ValueError: If `flat_sequence` and `structure` have different 754 element counts. 755 TypeError: `structure` is or contains a dict with non-sortable keys. 756 """ 757 return _pack_sequence_as(structure, flat_sequence, expand_composites) 758 759 760@tf_export("nest.map_structure") 761def map_structure(func, *structure, **kwargs): 762 """Applies `func` to each entry in `structure` and returns a new structure. 763 764 Applies `func(x[0], x[1], ...)` where x[i] is an entry in 765 `structure[i]`. All structures in `structure` must have the same arity, 766 and the return value will contain results with the same structure layout. 
767 768 Examples: 769 770 * A single Python dict: 771 772 >>> a = {"hello": 24, "world": 76} 773 >>> tf.nest.map_structure(lambda p: p * 2, a) 774 {'hello': 48, 'world': 152} 775 776 * Multiple Python dictionaries: 777 778 >>> d1 = {"hello": 24, "world": 76} 779 >>> d2 = {"hello": 36, "world": 14} 780 >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2) 781 {'hello': 60, 'world': 90} 782 783 * A single Python list: 784 785 >>> a = [24, 76, "ab"] 786 >>> tf.nest.map_structure(lambda p: p * 2, a) 787 [48, 152, 'abab'] 788 789 * Scalars: 790 791 >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4) 792 7 793 794 * Empty structures: 795 796 >>> tf.nest.map_structure(lambda x: x + 1, ()) 797 () 798 799 *. Check the types of iterables: 800 801 >>> s1 = (((1, 2), 3), 4, (5, 6)) 802 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 803 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list) 804 Traceback (most recent call last): 805 ... 806 TypeError: The two structures don't have the same nested structure 807 808 * Type check is set to False: 809 810 >>> s1 = (((1, 2), 3), 4, (5, 6)) 811 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 812 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False) 813 (((None, None), None), None, (None, None)) 814 815 Args: 816 func: A callable that accepts as many arguments as there are structures. 817 *structure: scalar, or tuple or dict or list of constructed scalars and/or 818 other tuples/lists, or scalars. Note: numpy arrays are considered as 819 scalars. 820 **kwargs: Valid keyword args are: 821 822 * `check_types`: If set to `True` (default) the types of 823 iterables within the structures have to be same (e.g. 824 `map_structure(func, [1], (1,))` raises a `TypeError` 825 exception). To allow this set this argument to `False`. 826 Note that namedtuples with identical name and fields are always 827 considered to have the same shallow structure. 
828 * `expand_composites`: If set to `True`, then composite tensors such 829 as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into 830 their component tensors. If `False` (the default), then composite 831 tensors are not expanded. 832 833 Returns: 834 A new structure with the same arity as `structure`, whose values correspond 835 to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding 836 location in `structure[i]`. If there are different sequence types and 837 `check_types` is `False` the sequence types of the first structure will be 838 used. 839 840 Raises: 841 TypeError: If `func` is not callable or if the structures do not match 842 each other by depth tree. 843 ValueError: If no structure is provided or if the structures do not match 844 each other by type. 845 ValueError: If wrong keyword arguments are provided. 846 """ 847 if not callable(func): 848 raise TypeError("func must be callable, got: %s" % func) 849 850 if not structure: 851 raise ValueError("Must provide at least one structure") 852 853 check_types = kwargs.pop("check_types", True) 854 expand_composites = kwargs.pop("expand_composites", False) 855 856 if kwargs: 857 raise ValueError( 858 "Only valid keyword arguments are `check_types` and " 859 "`expand_composites`, not: `%s`" % ("`, `".join(kwargs.keys()))) 860 861 for other in structure[1:]: 862 assert_same_structure(structure[0], other, check_types=check_types, 863 expand_composites=expand_composites) 864 865 flat_structure = (flatten(s, expand_composites) for s in structure) 866 entries = zip(*flat_structure) 867 868 return pack_sequence_as( 869 structure[0], [func(*x) for x in entries], 870 expand_composites=expand_composites) 871 872 873def map_structure_with_paths(func, *structure, **kwargs): 874 """Applies `func` to each entry in `structure` and returns a new structure. 
875 876 Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in 877 `structure[i]` and `path` is the common path to x[i] in the structures. All 878 structures in `structure` must have the same arity, and the return value will 879 contain the results with the same structure layout. Special kwarg 880 `check_types` determines whether the types of iterables within the structure 881 must be the same-- see **kwargs definition below. 882 883 Args: 884 func: A callable with the signature func(path, *values, **kwargs) that is 885 evaluated on the leaves of the structure. 886 *structure: A variable number of compatible structures to process. 887 **kwargs: Optional kwargs to be passed through to func. Special kwarg 888 `check_types` is not passed to func, but instead determines whether the 889 types of iterables within the structures have to be same (e.g., 890 `map_structure(func, [1], (1,))` raises a `TypeError` exception). By 891 default, the types must match. To allow iteration over structures of 892 different types (but common arity), set this kwarg to `False`. 893 894 Returns: 895 A structure of the same form as the input structures whose leaves are the 896 result of evaluating func on corresponding leaves of the input structures. 897 898 Raises: 899 TypeError: If `func` is not callable or if the structures do not match 900 each other by depth tree. 901 TypeError: If `check_types` is not `False` and the two structures differ in 902 the type of sequence in any of their substructures. 903 ValueError: If no structures are provided. 904 """ 905 def wrapper_func(tuple_path, *inputs, **kwargs): 906 string_path = "/".join(str(s) for s in tuple_path) 907 return func(string_path, *inputs, **kwargs) 908 909 return map_structure_with_tuple_paths_up_to(structure[0], 910 wrapper_func, 911 *structure, 912 **kwargs) 913 914 915def map_structure_with_tuple_paths(func, *structure, **kwargs): 916 """Applies `func` to each entry in `structure` and returns a new structure. 
917 918 Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry 919 in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary 920 keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the 921 common path to x[i] in the structures. All structures in `structure` must have 922 the same arity, and the return value will contain the results in the same 923 structure. Special kwarg `check_types` determines whether the types of 924 iterables within the structure must be the same-- see **kwargs definition 925 below. 926 927 Args: 928 func: A callable with the signature `func(tuple_path, *values, **kwargs)` 929 that is evaluated on the leaves of the structure. 930 *structure: A variable number of compatible structures to process. 931 **kwargs: Optional kwargs to be passed through to func. Special kwarg 932 `check_types` is not passed to func, but instead determines whether the 933 types of iterables within the structures have to be same (e.g. 934 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 935 this set this argument to `False`. 936 937 Returns: 938 A structure of the same form as the input structures whose leaves are the 939 result of evaluating func on corresponding leaves of the input structures. 940 941 Raises: 942 TypeError: If `func` is not callable or if the structures do not match 943 each other by depth tree. 944 TypeError: If `check_types` is not `False` and the two structures differ in 945 the type of sequence in any of their substructures. 946 ValueError: If no structures are provided. 947 """ 948 return map_structure_with_tuple_paths_up_to(structure[0], 949 func, 950 *structure, 951 **kwargs) 952 953 954def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()): 955 """Yields (path, value) pairs of input_tree flattened up to shallow_tree. 956 957 Args: 958 shallow_tree: Nested structure. Traverse no further than its leaf nodes. 959 input_tree: Nested structure. 
Return the paths and values from this tree. 960 Must have the same upper structure as shallow_tree. 961 is_seq: Function used to test if a value should be treated as a sequence. 962 path: Tuple. Optional argument, only used when recursing. The path from the 963 root of the original shallow_tree, down to the root of the shallow_tree 964 arg of this recursive call. 965 966 Yields: 967 Pairs of (path, value), where path the tuple path of a leaf node in 968 shallow_tree, and value is the value of the corresponding node in 969 input_tree. 970 """ 971 if not is_seq(shallow_tree): 972 yield (path, input_tree) 973 else: 974 input_tree = dict(_yield_sorted_items(input_tree)) 975 for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): 976 subpath = path + (shallow_key,) 977 input_subtree = input_tree[shallow_key] 978 for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree, 979 input_subtree, is_seq, 980 path=subpath): 981 yield (leaf_path, leaf_value) 982 983 984def assert_shallow_structure(shallow_tree, 985 input_tree, 986 check_types=True, 987 expand_composites=False): 988 """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 989 990 That is, this function tests if the `input_tree` structure can be created from 991 the `shallow_tree` structure by replacing its leaf nodes with deeper 992 tree structures. 993 994 Examples: 995 996 The following code will raise an exception: 997 ```python 998 shallow_tree = {"a": "A", "b": "B"} 999 input_tree = {"a": 1, "c": 2} 1000 assert_shallow_structure(shallow_tree, input_tree) 1001 ``` 1002 1003 The following code will raise an exception: 1004 ```python 1005 shallow_tree = ["a", "b"] 1006 input_tree = ["c", ["d", "e"], "f"] 1007 assert_shallow_structure(shallow_tree, input_tree) 1008 ``` 1009 1010 Args: 1011 shallow_tree: an arbitrarily nested structure. 1012 input_tree: an arbitrarily nested structure. 
1013 check_types: if `True` (default) the sequence types of `shallow_tree` and 1014 `input_tree` have to be the same. Note that even with check_types==True, 1015 this function will consider two different namedtuple classes with the same 1016 name and _fields attribute to be the same class. 1017 expand_composites: If true, then composite tensors such as 1018 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1019 component tensors. 1020 Raises: 1021 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1022 TypeError: If the sequence types of `shallow_tree` are different from 1023 `input_tree`. Only raised if `check_types` is `True`. 1024 ValueError: If the sequence lengths of `shallow_tree` are different from 1025 `input_tree`. 1026 """ 1027 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1028 if is_seq(shallow_tree): 1029 if not is_seq(input_tree): 1030 raise TypeError( 1031 "If shallow structure is a sequence, input must also be a sequence. " 1032 "Input has type: %s." % type(input_tree)) 1033 1034 if isinstance(shallow_tree, _wrapt.ObjectProxy): 1035 shallow_type = type(shallow_tree.__wrapped__) 1036 else: 1037 shallow_type = type(shallow_tree) 1038 1039 if check_types and not isinstance(input_tree, shallow_type): 1040 # Duck-typing means that nest should be fine with two different 1041 # namedtuples with identical name and fields. 1042 shallow_is_namedtuple = is_namedtuple(shallow_tree, False) 1043 input_is_namedtuple = is_namedtuple(input_tree, False) 1044 if shallow_is_namedtuple and input_is_namedtuple: 1045 if not same_namedtuples(shallow_tree, input_tree): 1046 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1047 input_type=type(input_tree), 1048 shallow_type=type(shallow_tree))) 1049 1050 elif isinstance(shallow_tree, list) and isinstance(input_tree, list): 1051 # List subclasses are considered the same, 1052 # e.g. python list vs. _ListWrapper. 
1053 pass 1054 1055 elif ((_is_composite_tensor(shallow_tree) or 1056 _is_composite_tensor(input_tree)) and 1057 (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))): 1058 pass # Compatibility will be checked below. 1059 1060 elif not (isinstance(shallow_tree, _collections_abc.Mapping) and 1061 isinstance(input_tree, _collections_abc.Mapping)): 1062 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1063 input_type=type(input_tree), 1064 shallow_type=type(shallow_tree))) 1065 1066 if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree): 1067 if not ( 1068 (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and 1069 (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))): 1070 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1071 input_type=type(input_tree), 1072 shallow_type=type(shallow_tree))) 1073 type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else 1074 shallow_tree._type_spec) # pylint: disable=protected-access 1075 type_spec_2 = (input_tree if _is_type_spec(input_tree) else 1076 input_tree._type_spec) # pylint: disable=protected-access 1077 try: 1078 _ = type_spec_1.most_specific_compatible_type(type_spec_2) 1079 except (TypeError, ValueError) as e: 1080 raise ValueError( 1081 "Incompatible CompositeTensor TypeSpecs: %s vs. %s -- %s" % 1082 (type_spec_1, type_spec_2, e)) 1083 1084 elif _is_type_spec(shallow_tree): 1085 if not _is_type_spec(input_tree): 1086 raise TypeError("If shallow structure is a TypeSpec, input must also " 1087 "be a TypeSpec. Input has type: %s." 
1088 % type(input_tree)) 1089 else: 1090 if len(input_tree) != len(shallow_tree): 1091 raise ValueError( 1092 _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( 1093 input_length=len(input_tree), shallow_length=len(shallow_tree))) 1094 elif len(input_tree) < len(shallow_tree): 1095 raise ValueError( 1096 _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( 1097 input_size=len(input_tree), shallow_size=len(shallow_tree))) 1098 1099 if isinstance(shallow_tree, _collections_abc.Mapping): 1100 absent_keys = set(shallow_tree) - set(input_tree) 1101 if absent_keys: 1102 raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS 1103 .format(sorted(absent_keys))) 1104 1105 for shallow_branch, input_branch in zip(_yield_value(shallow_tree), 1106 _yield_value(input_tree)): 1107 assert_shallow_structure(shallow_branch, input_branch, 1108 check_types=check_types, 1109 expand_composites=expand_composites) 1110 1111 1112@tf_export("__internal__.nest.flatten_up_to", v1=[]) 1113def flatten_up_to(shallow_tree, input_tree, check_types=True, 1114 expand_composites=False): 1115 """Flattens `input_tree` up to `shallow_tree`. 1116 1117 Any further depth in structure in `input_tree` is retained as elements in the 1118 partially flatten output. 1119 1120 If `shallow_tree` and `input_tree` are not sequences, this returns a 1121 single-element list: `[input_tree]`. 1122 1123 Use Case: 1124 1125 Sometimes we may wish to partially flatten a nested sequence, retaining some 1126 of the nested structure. We achieve this by specifying a shallow structure, 1127 `shallow_tree`, we wish to flatten up to. 1128 1129 The input, `input_tree`, can be thought of as having the same structure layout 1130 as `shallow_tree`, but with leaf nodes that are themselves tree structures. 

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  # Discard paths returned by _yield_flat_up_to.
  return [v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_seq)]


def flatten_with_tuple_paths_up_to(shallow_tree,
                                   input_tree,
                                   check_types=True,
                                   expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  Returns a list of (path, value) pairs, where value is a leaf node in the
  flattened tree, and path is the tuple path of that leaf in input_tree.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[((), input_tree)]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                        input_tree)
  flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                          shallow_tree)

  # Output is:
  # [((0, 0), [2, 2]),
  #  ((0, 1), [3, 3]),
  #  ((1, 0), [4, 9]),
  #  ((1, 1), [5, 5])]
  #
  # [((0, 0), True),
  #  ((0, 1), True),
  #  ((1, 0), False),
  #  ((1, 1), True)]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_with_tuple_paths_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [((0, 0), ('a', 1)),
  #  ((0, 1, 0), ('b', 2)),
  #  ((0, 1, 1, 0), ('c', 3)),
  #  ((0, 1, 1, 1, 0), ('d', 4))]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_with_tuple_paths_up_to(0, 0)  # Output: [((), 0)]

  flatten_with_tuple_paths_up_to(0, [0, 1, 2])  # Output: [((), [0, 1, 2])]

  flatten_with_tuple_paths_up_to([0, 1, 2], 0)  # Output: TypeError

  flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
  # Output: [((0,), 0), ((1,), 1), ((2,), 2)]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list of `(path, value)` pairs: the partially flattened version of
    `input_tree` according to the structure of `shallow_tree`, where each
    `path` is the tuple path of the corresponding leaf of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))


@tf_export("__internal__.nest.map_structure_up_to", v1=[])
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs` can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.
1319 1320 Examples: 1321 1322 ```python 1323 shallow_tree = [None, None] 1324 inp_val = [1, 2, 3] 1325 out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val) 1326 1327 # Output is: [2, 4] 1328 ``` 1329 1330 ```python 1331 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 1332 op_tuple = collections.namedtuple("op_tuple", "add, mul") 1333 inp_val = ab_tuple(a=2, b=3) 1334 inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 1335 out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, 1336 inp_val, inp_ops) 1337 1338 # Output is: ab_tuple(a=6, b=15) 1339 ``` 1340 1341 ```python 1342 data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 1343 name_list = ['evens', ['odds', 'primes']] 1344 out = map_structure_up_to( 1345 name_list, 1346 lambda name, sec: "first_{}_{}".format(len(sec), name), 1347 name_list, data_list) 1348 1349 # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] 1350 ``` 1351 1352 Args: 1353 shallow_tree: a shallow tree, common to all the inputs. 1354 func: callable which will be applied to each input individually. 1355 *inputs: arbitrarily nested combination of objects that are compatible with 1356 shallow_tree. The function `func` is applied to corresponding 1357 partially flattened elements of each input, so the function must support 1358 arity of `len(inputs)`. 1359 **kwargs: kwargs to feed to func(). Special kwarg 1360 `check_types` is not passed to func, but instead determines whether the 1361 types of iterables within the structures have to be same (e.g. 1362 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1363 this set this argument to `False`. 1364 1365 Raises: 1366 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 1367 TypeError: If the sequence types of `shallow_tree` are different from 1368 `input_tree`. 1369 ValueError: If the sequence lengths of `shallow_tree` are different from 1370 `input_tree`. 
1371 1372 Returns: 1373 result of repeatedly applying `func`, with the same structure layout as 1374 `shallow_tree`. 1375 """ 1376 return map_structure_with_tuple_paths_up_to( 1377 shallow_tree, 1378 lambda _, *values: func(*values), # Discards the path arg. 1379 *inputs, 1380 **kwargs) 1381 1382 1383def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): 1384 """Applies a function or op to a number of partially flattened inputs. 1385 1386 Like map_structure_up_to(), except that the 'func' argument takes a path 1387 tuple as its first argument, followed by the corresponding values from 1388 *inputs. 1389 1390 Example: 1391 1392 ```python 1393 lowercase = {'a': 'a', 'b': ('b0', 'b1')} 1394 uppercase = {'a': 'A', 'b': ('B0', 'B1')} 1395 1396 def print_path_and_values(path, *values): 1397 print("path: {}, values: {}".format(path, values)) 1398 1399 shallow_tree = {'a': None} 1400 map_structure_with_tuple_paths_up_to(shallow_tree, 1401 print_path_and_values, 1402 lowercase, 1403 uppercase) 1404 path: ('a',), values: ('a', 'A') 1405 path: ('b', 0), values: ('b0', 'B0') 1406 path: ('b', 1), values: ('b1', 'B1') 1407 1408 shallow_tree = {'b': None} 1409 map_structure_with_tuple_paths_up_to(shallow_tree, 1410 print_path_and_values, 1411 lowercase, 1412 uppercase, 1413 check_types=False) 1414 path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1')) 1415 1416 shallow_tree = {'a': None, 'b': {1: None}} 1417 map_structure_with_tuple_paths_up_to(shallow_tree, 1418 print_path_and_values, 1419 lowercase, 1420 uppercase, 1421 check_types=False) 1422 path: ('a',), values: ('a', 'A') 1423 path: ('b', 1), values: ('b1', B1') 1424 ``` 1425 1426 Args: 1427 shallow_tree: a shallow tree, common to all the inputs. 1428 func: callable that takes args (path, inputs_0_value, ... , inputs_N_value), 1429 where path is a tuple path to a leaf node in shallow_tree, and 1430 inputs_i_value is the corresponding value from inputs[i]. 
1431 *inputs: nested structures that are all structurally compatible with 1432 shallow_tree. 1433 **kwargs: kwargs to feed to func(). Special kwarg 1434 `check_types` is not passed to func, but instead determines whether the 1435 types of iterables within the structures have to be same (e.g. 1436 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1437 this set this argument to `False`. 1438 1439 Raises: 1440 TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not. 1441 TypeError: If the sequence types of `shallow_tree` are different from 1442 `input_tree`. 1443 ValueError: If the sequence lengths of `shallow_tree` are different from 1444 `input_tree`. 1445 1446 Returns: 1447 Result of repeatedly applying `func`. Has the same structure layout as 1448 `shallow_tree`. 1449 """ 1450 if not inputs: 1451 raise ValueError("Cannot map over no sequences") 1452 1453 check_types = kwargs.pop("check_types", True) 1454 expand_composites = kwargs.pop("expand_composites", False) 1455 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1456 1457 for input_tree in inputs: 1458 assert_shallow_structure( 1459 shallow_tree, 1460 input_tree, 1461 check_types=check_types, 1462 expand_composites=expand_composites) 1463 1464 # Flatten each input separately, apply the function to corresponding elements, 1465 # then repack based on the structure of the first input. 
1466 flat_value_gen = ( 1467 flatten_up_to( # pylint: disable=g-complex-comprehension 1468 shallow_tree, 1469 input_tree, 1470 check_types, 1471 expand_composites=expand_composites) for input_tree in inputs) 1472 flat_path_gen = ( 1473 path for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)) 1474 results = [ 1475 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen) 1476 ] 1477 return pack_sequence_as(structure=shallow_tree, flat_sequence=results, 1478 expand_composites=expand_composites) 1479 1480 1481@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[]) 1482def get_traverse_shallow_structure(traverse_fn, structure, 1483 expand_composites=False): 1484 """Generates a shallow structure from a `traverse_fn` and `structure`. 1485 1486 `traverse_fn` must accept any possible subtree of `structure` and return 1487 a depth=1 structure containing `True` or `False` values, describing which 1488 of the top-level subtrees may be traversed. It may also 1489 return scalar `True` or `False` "traversal is OK / not OK for all subtrees." 1490 1491 Examples are available in the unit tests (nest_test.py). 1492 1493 Args: 1494 traverse_fn: Function taking a substructure and returning either a scalar 1495 `bool` (whether to traverse that substructure or not) or a depth=1 1496 shallow structure of the same type, describing which parts of the 1497 substructure to traverse. 1498 structure: The structure to traverse. 1499 expand_composites: If true, then composite tensors such as 1500 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1501 component tensors. 1502 1503 Returns: 1504 A shallow structure containing python bools, which can be passed to 1505 `map_structure_up_to` and `flatten_up_to`. 
1506 1507 Raises: 1508 TypeError: if `traverse_fn` returns a sequence for a non-sequence input, 1509 or a structure with depth higher than 1 for a sequence input, 1510 or if any leaf values in the returned structure or scalar are not type 1511 `bool`. 1512 """ 1513 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1514 to_traverse = traverse_fn(structure) 1515 if not is_seq(structure): 1516 if not isinstance(to_traverse, bool): 1517 raise TypeError("traverse_fn returned structure: %s for non-structure: %s" 1518 % (to_traverse, structure)) 1519 return to_traverse 1520 level_traverse = [] 1521 if isinstance(to_traverse, bool): 1522 if not to_traverse: 1523 # Do not traverse this substructure at all. Exit early. 1524 return False 1525 else: 1526 # Traverse the entire substructure. 1527 for branch in _yield_value(structure): 1528 level_traverse.append( 1529 get_traverse_shallow_structure(traverse_fn, branch, 1530 expand_composites=expand_composites)) 1531 elif not is_seq(to_traverse): 1532 raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" 1533 % (to_traverse, structure)) 1534 else: 1535 # Traverse some subset of this substructure. 1536 assert_shallow_structure(to_traverse, structure, 1537 expand_composites=expand_composites) 1538 for t, branch in zip(_yield_value(to_traverse), 1539 _yield_value(structure)): 1540 if not isinstance(t, bool): 1541 raise TypeError( 1542 "traverse_fn didn't return a depth=1 structure of bools. saw: %s " 1543 " for structure: %s" % (to_traverse, structure)) 1544 if t: 1545 level_traverse.append( 1546 get_traverse_shallow_structure(traverse_fn, branch)) 1547 else: 1548 level_traverse.append(False) 1549 return _sequence_like(structure, level_traverse) 1550 1551 1552@tf_export("__internal__.nest.yield_flat_paths", v1=[]) 1553def yield_flat_paths(nest, expand_composites=False): 1554 """Yields paths for some nested structure. 
1555 1556 Paths are lists of objects which can be str-converted, which may include 1557 integers or other types which are used as indices in a dict. 1558 1559 The flat list will be in the corresponding order as if you called 1560 `nest.flatten` on the structure. This is handy for naming Tensors such 1561 the TF scope structure matches the tuple structure. 1562 1563 E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` 1564 1565 ```shell 1566 nest.flatten(value) 1567 [3, 23, 42] 1568 list(nest.yield_flat_paths(value)) 1569 [('a',), ('b', 'c'), ('b', 'd')] 1570 ``` 1571 1572 ```shell 1573 list(nest.yield_flat_paths({'a': [3]})) 1574 [('a', 0)] 1575 list(nest.yield_flat_paths({'a': 3})) 1576 [('a',)] 1577 ``` 1578 1579 Args: 1580 nest: the value to produce a flattened paths list for. 1581 expand_composites: If true, then composite tensors such as 1582 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1583 component tensors. 1584 1585 Yields: 1586 Tuples containing index or key values which form the path to a specific 1587 leaf value in the nested structure. 1588 """ 1589 is_seq = is_sequence_or_composite if expand_composites else is_sequence 1590 for k, _ in _yield_flat_up_to(nest, nest, is_seq): 1591 yield k 1592 1593 1594def flatten_with_joined_string_paths(structure, separator="/", 1595 expand_composites=False): 1596 """Returns a list of (string path, data element) tuples. 1597 1598 The order of tuples produced matches that of `nest.flatten`. This allows you 1599 to flatten a nested structure while keeping information about where in the 1600 structure each data element was located. See `nest.yield_flat_paths` 1601 for more information. 1602 1603 Args: 1604 structure: the nested structure to flatten. 1605 separator: string to separate levels of hierarchy in the results, defaults 1606 to '/'. 
1607 expand_composites: If true, then composite tensors such as 1608 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1609 component tensors. 1610 1611 Returns: 1612 A list of (string, data element) tuples. 1613 """ 1614 flat_paths = yield_flat_paths(structure, expand_composites=expand_composites) 1615 def stringify_and_join(path_elements): 1616 return separator.join(str(path_element) for path_element in path_elements) 1617 1618 flat_string_paths = (stringify_and_join(path) for path in flat_paths) 1619 return list(zip(flat_string_paths, 1620 flatten(structure, expand_composites=expand_composites))) 1621 1622 1623def flatten_with_tuple_paths(structure, expand_composites=False): 1624 """Returns a list of `(tuple_path, leaf_element)` tuples. 1625 1626 The order of pairs produced matches that of `nest.flatten`. This allows you 1627 to flatten a nested structure while keeping information about where in the 1628 structure each data element was located. See `nest.yield_flat_paths` 1629 for more information about tuple paths. 1630 1631 Args: 1632 structure: the nested structure to flatten. 1633 expand_composites: If true, then composite tensors such as 1634 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1635 component tensors. 1636 1637 Returns: 1638 A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple 1639 of indices and/or dictionary keys that uniquely specify the path to 1640 `leaf_element` within `structure`. 1641 """ 1642 return list(zip(yield_flat_paths(structure, 1643 expand_composites=expand_composites), 1644 flatten(structure, expand_composites=expand_composites))) 1645 1646 1647@tf_export("__internal__.nest.list_to_tuple", v1=[]) 1648def list_to_tuple(structure): 1649 """Replace all lists with tuples. 1650 1651 The fork of nest that tf.data uses treats lists as single elements, while 1652 tf.nest treats them as structures to recurse into. 
Keras has chosen to adopt 1653 the latter convention, and must therefore deeply replace all lists with tuples 1654 before passing structures to Dataset.from_generator. 1655 1656 Args: 1657 structure: A nested structure to be remapped. 1658 1659 Returns: 1660 structure mapped to replace all lists with tuples. 1661 """ 1662 def sequence_fn(instance, args): 1663 if isinstance(instance, list): 1664 return tuple(args) 1665 return _sequence_like(instance, args) 1666 1667 return _pack_sequence_as(structure, flatten(structure), False, 1668 sequence_fn=sequence_fn) 1669 1670 1671_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) 1672_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping) 1673_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) 1674_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView) 1675_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy) 1676