• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""## Functions for working with arbitrarily nested sequences of elements.
17
18This module can perform operations on nested structures. A nested structure is a
19Python collection that can contain further collections as well as other objects
20called atoms. Note that numpy arrays are considered atoms.
21
22nest recognizes the following types of collections:
  1. list
  2. tuple
  3. namedtuple
  4. dict
  5. OrderedDict
  6. MutableMapping
  7. attr.s
29
30attr.s decorated classes (http://www.attrs.org) are also supported, in the
31same way as `namedtuple`.
32
33The utilities here assume (and do not check) that the nested structures form a
34'tree', i.e., no references in the structure of the input of these functions
35should be recursive.
36
37Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
38  (np.array([3, 4]), tf.constant([3, 4])))`
39"""
40
41from __future__ import absolute_import
42from __future__ import division
43from __future__ import print_function
44
45import collections as _collections
46
47import six as _six
48import wrapt as _wrapt
49
50from tensorflow.python.platform import tf_logging
51from tensorflow.python.util import _pywrap_nest
52from tensorflow.python.util import _pywrap_utils
53from tensorflow.python.util.compat import collections_abc as _collections_abc
54from tensorflow.python.util.tf_export import tf_export
55
56
57_SHALLOW_TREE_HAS_INVALID_KEYS = (
58    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
59    "shallow_tree has the following keys that are not in the input_tree: {}.")
60
61_STRUCTURES_HAVE_MISMATCHING_TYPES = (
62    "The two structures don't have the same sequence type. Input structure has "
63    "type {input_type}, while shallow structure has type {shallow_type}.")
64
65_STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
66    "The two structures don't have the same sequence length. Input "
67    "structure has length {input_length}, while shallow structure has length "
68    "{shallow_length}."
69)
70
71_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
72    "The input_tree has fewer elements than the shallow_tree. Input structure "
73    "has length {input_size}, while shallow structure has length "
74    "{shallow_size}.")
75
76_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
77    "If shallow structure is a sequence, input must also be a sequence. "
78    "Input has type: {}.")
79
80
81def _get_attrs_items(obj):
82  """Returns a list of (name, value) pairs from an attrs instance.
83
84  The list will be sorted by name.
85
86  Args:
87    obj: an object.
88
89  Returns:
90    A list of (attr_name, attr_value) pairs, sorted by attr_name.
91  """
92  attrs = getattr(obj.__class__, "__attrs_attrs__")
93  attr_names = (a.name for a in attrs)
94  return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names]
95
96
97def _sorted(dict_):
98  """Returns a sorted list of the dict keys, with error if keys not sortable."""
99  try:
100    return sorted(dict_.keys())
101  except TypeError:
102    raise TypeError("nest only supports dicts with sortable keys.")
103
104
def is_namedtuple(instance, strict=False):
  """Checks whether `instance` is a `namedtuple`.

  Args:
    instance: any Python object.
    strict: when True, only a "plain" namedtuple qualifies, so an instance of
        a class that inherits from a `namedtuple` is rejected. When False
        (the default), such subclass instances are accepted too.

  Returns:
    True if `instance` is a `namedtuple` under the requested strictness.
  """
  native_is_namedtuple = _pywrap_utils.IsNamedtuple
  return native_is_namedtuple(instance, strict)
119
# Private alias kept because this function was private up to TF2.5.
_is_namedtuple = is_namedtuple

# Type predicates provided by the native `_pywrap_utils` extension module;
# consult its implementation for the exact set of types each one accepts.
_is_mapping = _pywrap_utils.IsMapping
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
129
130
@tf_export("__internal__.nest.is_attrs", v1=[])
def is_attrs(obj):
  """Returns True if `obj` is an instance of an `attr.s`-decorated class.

  Args:
    obj: any Python object.

  Returns:
    The result of the native `_is_attrs` check for `obj`.
  """
  answer = _is_attrs(obj)
  return answer
135
136
@tf_export("__internal__.nest.is_mapping", v1=[])
def is_mapping(obj):
  """Returns True if `obj` is a `collections.Mapping`.

  Args:
    obj: any Python object.

  Returns:
    The result of the native `_is_mapping` check for `obj`.
  """
  answer = _is_mapping(obj)
  return answer
141
142
@tf_export("__internal__.nest.sequence_like", v1=[])
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  The cases below are tested in this order, and the first match wins. An
  object recognized by an earlier check (e.g. a `wrapt.ObjectProxy` that is
  detected as a mapping, namedtuple or attrs instance) is therefore handled by
  that branch rather than by the generic proxy branch.

  * Mutable mappings: `args` are matched to the *sorted* keys of `instance`,
    then a new mapping of the same type is filled in the key order of
    `instance`. An exact `collections.defaultdict` keeps its
    `default_factory`.
  * Other mappings: same key matching, then
    `type(instance)(iterable_of_key_value_pairs)` is called; a warning is
    logged (once).
  * Mapping views (e.g. `dict.keys()`): a plain `list` is returned, because
    views cannot be constructed directly.
  * Namedtuples and attrs instances: the type is called with `*args`
    positionally; for an `ObjectProxy`, the wrapped object's type is used.
  * Composite tensors and type specs: `args` must hold exactly one element
    (checked with `assert`), the components, passed to `_from_components`.
  * `range`: rebuilt as a `list`.
  * Other `wrapt.ObjectProxy` instances: the wrapped object is rebuilt and
    then re-wrapped in the proxy type.
  * Anything else (e.g. `list`, `tuple`): `type(instance)(args)`.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
        `collections.OrderedDict`, another `Mapping`, a mapping view, an attrs
        class, `range`, `wrapt.ObjectProxy`,
        `composite_tensor.Composite_Tensor` or `type_spec.TypeSpec`.
    args: elements to be converted to the `instance` type.

  Returns:
    `args` with the type of `instance`.

  Raises:
    TypeError: if a non-mutable `Mapping` type cannot be constructed from an
      iterable of key-value pairs.
  """
  if _is_mutable_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if instance_type is _collections.defaultdict:
      # A bare `defaultdict()` would lose the default factory.
      d = _collections.defaultdict(instance.default_factory)
    else:
      d = instance_type()
    # Insert in the key order of `instance`; values were matched by sorted key.
    for key in instance:
      d[key] = result[key]
    return d
  elif _is_mapping(instance):
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    tf_logging.log_first_n(
        tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer"
        " using MutableMapping for {}".format(instance_type), 1)
    try:
      return instance_type((key, result[key]) for key in instance)
    except TypeError as err:
      raise TypeError("Error creating an object of type {} like {}. Note that "
                      "it must accept a single positional argument "
                      "representing an iterable of key-value pairs, in "
                      "addition to self. Cause: {}".format(
                          type(instance), instance, err))
  elif _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead
    return list(args)
  elif is_namedtuple(instance) or _is_attrs(instance):
    if isinstance(instance, _wrapt.ObjectProxy):
      # Construct the underlying type, not the proxy type.
      instance_type = type(instance.__wrapped__)
    else:
      instance_type = type(instance)
    return instance_type(*args)
  elif _is_composite_tensor(instance):
    assert len(args) == 1
    spec = instance._type_spec  # pylint: disable=protected-access
    return spec._from_components(args[0])  # pylint: disable=protected-access
  elif _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access
  elif isinstance(instance, _six.moves.range):
    return _sequence_like(list(instance), args)
  elif isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))
  else:
    # Any other sequence type (e.g. list or plain tuple), built from `args`.
    return type(instance)(args)
211
212
def _yield_value(iterable):
  """Yields just the values produced by `_yield_sorted_items(iterable)`.

  The keys are discarded; the values come out in the same order.
  """
  items = _yield_sorted_items(iterable)
  for _unused_key, value in items:
    yield value
216
217
218def _yield_sorted_items(iterable):
219  """Yield (key, value) pairs for `iterable` in a deterministic order.
220
221  For Sequences, the key will be an int, the array index of a value.
222  For Mappings, the key will be the dictionary key.
223  For objects (e.g. namedtuples), the key will be the attribute name.
224
225  In all cases, the keys will be iterated in sorted order.
226
227  Args:
228    iterable: an iterable.
229
230  Yields:
231    The iterable's (key, value) pairs, in order of sorted keys.
232  """
233  # Ordered to check common structure types (list, tuple, dict) first.
234  if isinstance(iterable, list):
235    for item in enumerate(iterable):
236      yield item
237  # namedtuples handled separately to avoid expensive namedtuple check.
238  elif type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
239    for item in enumerate(iterable):
240      yield item
241  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
242    # Iterate through dictionaries in a deterministic order by sorting the
243    # keys. Notice this means that we ignore the original order of `OrderedDict`
244    # instances. This is intentional, to avoid potential bugs caused by mixing
245    # ordered and plain dicts (e.g., flattening a dict but using a
246    # corresponding `OrderedDict` to pack it back).
247    for key in _sorted(iterable):
248      yield key, iterable[key]
249  elif _is_attrs(iterable):
250    for item in _get_attrs_items(iterable):
251      yield item
252  elif is_namedtuple(iterable):
253    for field in iterable._fields:
254      yield field, getattr(iterable, field)
255  elif _is_composite_tensor(iterable):
256    type_spec = iterable._type_spec  # pylint: disable=protected-access
257    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
258  elif _is_type_spec(iterable):
259    # Note: to allow CompositeTensors and their TypeSpecs to have matching
260    # structures, we need to use the same key string here.
261    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
262  else:
263    for item in enumerate(iterable):
264      yield item
265
266
# Both predicates are provided by the native `_pywrap_utils` extension module;
# consult its implementation for the exact set of types each one accepts.
is_sequence = _pywrap_utils.IsSequence
is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite
273
274
@tf_export("nest.is_nested")
def is_nested(seq):
  """Returns true if its input is a nested structure.

  Nested structures include `collections.abc.Sequence` instances other than
  strings, and also mappings such as dicts and their keys/values/items views.
  Atoms such as strings, sets and tensors are not nested:

    >>> tf.nest.is_nested("1234")
    False

    >>> tf.nest.is_nested([1, 3, [4, 5]])
    True

    >>> tf.nest.is_nested(((7, 8), (5, 6)))
    True

    >>> tf.nest.is_nested([])
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2})
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.keys())
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.values())
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.items())
    True

    >>> tf.nest.is_nested(set([1, 2]))
    False

    >>> ones = tf.ones([2, 3])
    >>> tf.nest.is_nested(ones)
    False

  Args:
    seq: an input sequence.

  Returns:
    True if `seq` is nested (for example a non-string sequence, a dict or a
    dict view), False if it is an atom.
  """
  return is_sequence(seq)
318
319
@tf_export("nest.flatten")
def flatten(structure, expand_composites=False):
  """Returns a flat list from a given nested structure.

  If `structure` is not nested (it is not a sequence such as a list or a
  tuple/namedtuple, a dict, or an attrs class instance), a single-element
  list is returned: `[structure]`. In particular, `flatten(None)` returns
  `[None]`.

  This is the inverse of the `nest.pack_sequence_as` method that takes in a
  flattened list and re-packs it into the nested structure.

  In the case of dict instances, the sequence consists of the values, sorted by
  key to ensure deterministic behavior. This is true also for OrderedDict
  instances: their sequence order is ignored, the sorting order of keys is used
  instead. The same convention is followed in `nest.pack_sequence_as`. This
  correctly repacks dicts and OrderedDicts after they have been flattened, and
  also allows flattening an OrderedDict and then repacking it back using a
  corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys
  cannot be flattened.

  Users must not modify any collections used in nest while this function is
  running.

  Examples:

  1. Python dict (ordered by key):

    >>> my_dict = { "key3": "value3", "key1": "value1", "key2": "value2" }
    >>> tf.nest.flatten(my_dict)
    ['value1', 'value2', 'value3']

  2. For a nested python tuple:

    >>> my_tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
    >>> tf.nest.flatten(my_tuple)
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

  3. For a nested dictionary of dictionaries:

    >>> my_dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)},
    ... "key1": {"m": "val1", "g": "val2"} }
    >>> tf.nest.flatten(my_dict)
    ['val2', 'val1', 3.0, 1.0, 2.0]

  4. Numpy array (will not flatten):

    >>> array = np.array([[1, 2], [3, 4]])
    >>> tf.nest.flatten(array)
    [array([[1, 2],
            [3, 4]])]

  5. `tf.Tensor` (will not flatten):

    >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    >>> tf.nest.flatten(tensor)
    [<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
      array([[1., 2., 3.],
             [4., 5., 6.],
             [7., 8., 9.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation
  consists of a flattened list of 'values' and a list of 'row_splits' which
  indicate how to chop up the flattened list into different rows. For more
  details on `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we just return the RaggedTensor as is.

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=False)
    [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

  With `expand_composites=True`, we return the component Tensors that make up
  the RaggedTensor representation (the values and row_splits tensors)

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=True)
    [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                      dtype=int32)>,
     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

  Args:
    structure: an arbitrarily nested structure. Note, numpy arrays are
      considered atoms and are not flattened.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
  """
  if structure is None:
    # `None` is an atom.
    return [None]
  # Hand the native binding a real bool.
  expand_composites = bool(expand_composites)
  return _pywrap_utils.Flatten(structure, expand_composites)
418
419
# Native helper from the `_pywrap_utils` extension module. The underscored
# name is kept as an alias because this function was private up to TF2.5.
_same_namedtuples = same_namedtuples = _pywrap_utils.SameNamedtuples
423
424
425class _DotString(object):
426
427  __slots__ = []
428
429  def __str__(self):
430    return "."
431
432  def __repr__(self):
433    return "."
434
435
436_DOT = _DotString()
437
438
@tf_export("nest.assert_same_structure")
def assert_same_structure(nest1, nest2, check_types=True,
                          expand_composites=False):
  """Asserts that two structures are nested in the same way.

  Note the method does not check the types of the atoms (leaf values) inside
  the structures; only the nesting is compared.

  Examples:

  * These scalar vs. scalar comparisons will pass:

    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))

  * These sequence vs. sequence comparisons will pass:

    >>> structure1 = (((1, 2), 3), 4, (5, 6))
    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
    >>> tf.nest.assert_same_structure(structure1, structure2)
    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)

    >>> import collections
    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     collections.namedtuple("foo", "a b")(2, 3),
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     { "a": 1, "b": 2 },
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     { "a": 1, "b": 2, "c": 3 },
    ...     { "c": 6, "b": 5, "a": 4 })

    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...       row_splits=[0, 4, 4, 7, 8, 8])
    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4],
    ...       row_splits=[0, 3])
    >>> tf.nest.assert_same_structure(
    ...       ragged_tensor1,
    ...       ragged_tensor2,
    ...       expand_composites=True)

  * These examples will raise exceptions:

    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure

    >>> tf.nest.assert_same_structure(
    ...       collections.namedtuple('bar', 'a b')(1, 2),
    ...       collections.namedtuple('foo', 'a b')(2, 3))
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if `True` (default) types of sequences are checked as well,
      including the keys of dictionaries. If set to `False`, for example a
      list and a tuple of objects will look the same if they have the same
      size. Note that namedtuples with identical name and fields are always
      considered to have the same shallow structure. Two types will also be
      considered the same if they are both list subtypes (which allows "list"
      and "_ListWrapper" from trackable dependency tracking to compare
      equal).
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Raises:
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  # Convert to bool explicitly as otherwise pybind will not be able to handle
  # type mismatch message correctly. See GitHub issue 42329 for details.
  check_types = bool(check_types)
  expand_composites = bool(expand_composites)
  try:
    _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
                                      expand_composites)
  except (ValueError, TypeError) as e:
    # Re-raise the same exception type, appending both structures with every
    # leaf rendered as "." so that only their layout is shown.
    str1 = str(map_structure(lambda _: _DOT, nest1))
    str2 = str(map_structure(lambda _: _DOT, nest2))
    raise type(e)("%s\n"
                  "Entire first structure:\n%s\n"
                  "Entire second structure:\n%s"
                  % (str(e), str1, str2))
536
537
def flatten_dict_items(dictionary):
  """Flattens the keys and values of `dictionary` in lockstep.

  The keys and the values of `dictionary` may be arbitrarily nested
  structures. Each key is flattened alongside its value, and the resulting
  leaves are paired up into a dictionary:

  ```python
  example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
  result = {4: "a", 5: "b", 6: "c", 8: "d"}
  flatten_dict_items(example_dictionary) == result
  ```

  Two requirements apply to the input:

  1. Every key must have exactly the same nested structure as its value.
  2. No flattened key may occur more than once across the whole dictionary.

  Args:
    dictionary: the dictionary whose keys and values should be flattened.

  Returns:
    A dictionary mapping each flattened key to its flattened value.

  Raises:
    TypeError: If the input is not a dictionary.
    ValueError: If any key and value do not have the same structure layout, or
    if keys are not unique.
  """
  native_flatten_dict_items = _pywrap_nest.FlattenDictItems
  return native_flatten_dict_items(dictionary)
569
570
def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):
  """Helper function for pack_sequence_as.

  Walks the children of `structure` in the order given by `_yield_value`.
  Each child that `is_seq` accepts is packed recursively and rebuilt with
  `sequence_fn`. Every other child consumes the next element of `flat`.

  Args:
    structure: Substructure (list / tuple / dict) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_seq: Function used to test if a value should be treated as a sequence.
    sequence_fn: Function used to generate a new sequence instance. Defaults
      to `_sequence_like`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a plain list holding the children of `structure`, built from
                 `flat` starting at `index`, with nested children packed into
                 their original format. The caller converts `packed` itself
                 to the type of `structure`.

  Raises:
    IndexError: (from indexing `flat`) if `structure` contains more leaves
      than `flat` has elements (assuming indexing starts from `index`).
  """
  packed = []
  sequence_fn = sequence_fn or _sequence_like
  for s in _yield_value(structure):
    if is_seq(s):
      new_index, child = _packed_nest_with_indices(s, flat, index, is_seq,
                                                   sequence_fn)
      packed.append(sequence_fn(s, child))
      index = new_index
    else:
      packed.append(flat[index])
      index += 1
  return index, packed
604
605
def _pack_sequence_as(structure, flat_sequence, expand_composites,
                      sequence_fn=None):
  """Implements sequence packing, with the option to alter the structure.

  Args:
    structure: the nested structure to mimic.
    flat_sequence: a flat sequence of values to pack into `structure`.
    expand_composites: if true, composite tensors are treated as sequences
      (via `is_sequence_or_composite`); otherwise `is_sequence` is used.
    sequence_fn: optional function used to build each new sequence instance;
      defaults to `_sequence_like`.

  Returns:
    `flat_sequence` packed into the nested layout of `structure`. If
    `structure` is not a sequence, the single element of `flat_sequence` is
    returned.

  Raises:
    TypeError: if `flat_sequence` is not a sequence.
    ValueError: if `structure` is not a sequence but `flat_sequence` does not
      have exactly one element, or if the number of leaves in `structure`
      differs from the length of `flat_sequence`.
    IndexError: if packing raises an `IndexError` although the element
      counts agree (e.g. one coming from `sequence_fn`). The original error
      is propagated rather than masked.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  sequence_fn = sequence_fn or _sequence_like

  def truncate(value, length):
    # Abbreviate long reprs in error messages, marking the cut with "...".
    value_str = str(value)
    return value_str[:length] + (value_str[length:] and "...")

  if not is_seq(flat_sequence):
    raise TypeError(
        "Attempted to pack value:\n  {}\ninto a sequence, but found "
        "incompatible type `{}` instead."
        .format(truncate(flat_sequence, 100), type(flat_sequence)))

  if not is_seq(structure):
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type `{}`\n  {}\nHowever the input "
          "structure is a sequence ({}) of length {}.\n  {}\nnest cannot "
          "guarantee that it is safe to map one to the other.".format(
              type(structure), truncate(structure, 100), type(flat_sequence),
              len(flat_sequence), truncate(flat_sequence, 100)))
    return flat_sequence[0]

  try:
    final_index, packed = _packed_nest_with_indices(structure, flat_sequence,
                                                    0, is_seq, sequence_fn)
    if final_index < len(flat_sequence):
      # Not every element was consumed; report it like a size mismatch.
      raise IndexError
  except IndexError:
    flat_structure = flatten(structure, expand_composites=expand_composites)
    if len(flat_structure) != len(flat_sequence):
      raise ValueError(
          "Could not pack sequence. Structure had %d elements, but "
          "flat_sequence had %d elements.  Structure: %s, flat_sequence: %s." %
          (len(flat_structure), len(flat_sequence), structure, flat_sequence))
    # The element counts agree, so this IndexError is not a plain size
    # mismatch. It may come from `sequence_fn`, or from a traversal that
    # disagrees with `flatten`. Falling through would fail on an unbound
    # `packed` or return a truncated result, so propagate the original error.
    raise
  return sequence_fn(structure, packed)
644
645
@tf_export("nest.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, expand_composites=False):
  """Returns a given flattened sequence packed into a given structure.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  If `structure` is or contains a dict instance, the keys will be sorted to
  pack the flat sequence in deterministic order. This is true also for
  `OrderedDict` instances: their sequence order is ignored, the sorting order of
  keys is used instead. The same convention is followed in `flatten`.
  This correctly repacks dicts and `OrderedDict`s after they have been
  flattened, and also allows flattening an `OrderedDict` and then repacking it
  back using a corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys are not supported.

  Examples:

  1. Python dict:

    >>> structure = { "key3": "", "key1": "", "key2": "" }
    >>> flat_sequence = ["value1", "value2", "value3"]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}

  2. For a nested python tuple:

    >>> structure = (('a','b'), ('c','d','e'), 'f')
    >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)

  3. For a nested dictionary of dictionaries:

    >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')},
    ...               "key1": {"e": "val1", "d": "val2"} }
    >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}

  4. Numpy array (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [np.array([[1, 2], [3, 4]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [array([[1, 2],
           [3, 4]])]

  5. tf.Tensor (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [<tf.Tensor: shape=(2, 3), dtype=float32,
     numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation
  consists of a flattened list of 'values' and a list of 'row_splits' which
  indicate how to chop up the flattened list into different rows. For more
  details on `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we treat RaggedTensor as a scalar.

    >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]),
    ...               "bar": tf.constant([[5]]) }
    >>> flat_sequence = [ "one", "two" ]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence,
    ... expand_composites=False)
    {'foo': 'two', 'bar': 'one'}

  With `expand_composites=True`, we expect that the flattened input contains
  the tensors making up the ragged tensor i.e. the values and row_splits
  tensors.

    >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]),
    ...               "bar": tf.constant([[5.]]) }
    >>> tensors = tf.nest.flatten(structure, expand_composites=True)
    >>> print(tensors)
    [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>]
    >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ')
    ...                     if t.dtype==tf.float32 else t
    ...                     for t in tensors]
    >>> tf.nest.pack_sequence_as(structure, verified_tensors,
    ...                          expand_composites=True)
    {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>,
     'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>}

  Args:
    structure: Nested structure, whose structure is given by nested lists,
      tuples, and dicts. Note: numpy arrays and strings are considered
      scalars.
    flat_sequence: flat sequence to pack.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different
      element counts.
    TypeError: `structure` is or contains a dict with non-sortable keys.
  """
  return _pack_sequence_as(structure, flat_sequence, expand_composites)
758
759
@tf_export("nest.map_structure")
def map_structure(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`.  All structures in `structure` must have the same arity,
  and the return value will contain results with the same structure layout.

  Examples:

  * A single Python dict:

  >>> a = {"hello": 24, "world": 76}
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  {'hello': 48, 'world': 152}

  * Multiple Python dictionaries:

  >>> d1 = {"hello": 24, "world": 76}
  >>> d2 = {"hello": 36, "world": 14}
  >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2)
  {'hello': 60, 'world': 90}

  * A single Python list:

  >>> a = [24, 76, "ab"]
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  [48, 152, 'abab']

  * Scalars:

  >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4)
  7

  * Empty structures:

  >>> tf.nest.map_structure(lambda x: x + 1, ())
  ()

  * Check the types of iterables:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list)
  Traceback (most recent call last):
  ...
  TypeError: The two structures don't have the same nested structure

  * Type check is set to False:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False)
  (((None, None), None), None, (None, None))

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalar, or tuple or dict or list of constructed scalars and/or
      other tuples/lists, or scalars.  Note: numpy arrays are considered as
      scalars.
    **kwargs: Valid keyword args are:

      * `check_types`: If set to `True` (default) the types of
        iterables within the structures have to be same (e.g.
        `map_structure(func, [1], (1,))` raises a `TypeError`
        exception). To allow this set this argument to `False`.
        Note that namedtuples with identical name and fields are always
        considered to have the same shallow structure.
      * `expand_composites`: If set to `True`, then composite tensors such
        as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into
        their component tensors.  If `False` (the default), then composite
        tensors are not expanded.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable, or if `assert_shallow_structure`-style
      validation in `assert_same_structure` rejects the structures because of
      mismatched sequence types (see the `check_types` example above).
    ValueError: If no structure is provided, or if wrong keyword arguments are
      provided. `assert_same_structure` may also raise `ValueError` when the
      structures do not match.
  """
  if not callable(func):
    # Wrap `func` in a 1-tuple: if `func` were itself a tuple, `%` would treat
    # it as the argument tuple and produce a wrong (or failing) message.
    raise TypeError("func must be callable, got: %s" % (func,))

  if not structure:
    raise ValueError("Must provide at least one structure")

  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)

  if kwargs:
    raise ValueError(
        "Only valid keyword arguments are `check_types` and "
        "`expand_composites`, not: `%s`" % ("`, `".join(kwargs.keys())))

  # Validate every structure against the first one up front; `zip` below would
  # otherwise silently truncate to the shortest flattened structure.
  for other in structure[1:]:
    assert_same_structure(structure[0], other, check_types=check_types,
                          expand_composites=expand_composites)

  flat_structure = (flatten(s, expand_composites) for s in structure)
  entries = zip(*flat_structure)

  # Results are repacked using the layout of the first structure.
  return pack_sequence_as(
      structure[0], [func(*x) for x in entries],
      expand_composites=expand_composites)
871
872
def map_structure_with_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  `structure[i]` and `path` is the common path to x[i] in the structures,
  rendered as a string by joining the `str()` of each tuple-path component
  with "/".  All structures in `structure` must have the same arity, and the
  return value will contain the results with the same structure layout.
  Special kwarg `check_types` determines whether the types of iterables within
  the structure must be the same -- see **kwargs definition below.

  Args:
    func: A callable with the signature func(path, *values, **kwargs) that is
      evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwargs
      `check_types` and `expand_composites` are not passed to func.
      `check_types` determines whether the types of iterables within the
      structures have to be same (e.g., `map_structure(func, [1], (1,))`
      raises a `TypeError` exception). By default, the types must match. To
      allow iteration over structures of different types (but common arity),
      set this kwarg to `False`.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable, or if a structure is not a sequence
      where the first structure has one.
    TypeError: If `check_types` is not `False` and the structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided, or if sequence lengths (or dict
      keys) differ from those of the first structure.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % (func,))
  # Without this check, `structure[0]` below raises an undocumented IndexError.
  if not structure:
    raise ValueError("Must provide at least one structure")

  def wrapper_func(tuple_path, *inputs, **func_kwargs):
    # Convert the tuple path into the "/"-joined string path `func` expects.
    string_path = "/".join(str(s) for s in tuple_path)
    return func(string_path, *inputs, **func_kwargs)

  return map_structure_with_tuple_paths_up_to(structure[0],
                                              wrapper_func,
                                              *structure,
                                              **kwargs)
913
914
def map_structure_with_tuple_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry
  in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary
  keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the
  common path to x[i] in the structures. All structures in `structure` must have
  the same arity, and the return value will contain the results in the same
  structure. Special kwarg `check_types` determines whether the types of
  iterables within the structure must be the same -- see **kwargs definition
  below.

  Args:
    func: A callable with the signature `func(tuple_path, *values, **kwargs)`
      that is evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwargs
      `check_types` and `expand_composites` are not passed to func.
      `check_types` determines whether the types of iterables within the
      structures have to be same (e.g. `map_structure(func, [1], (1,))` raises
      a `TypeError` exception). To allow this set this argument to `False`.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable, or if a structure is not a sequence
      where the first structure has one.
    TypeError: If `check_types` is not `False` and the structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided, or if sequence lengths (or dict
      keys) differ from those of the first structure.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % (func,))
  # Without this check, `structure[0]` below raises an undocumented IndexError.
  if not structure:
    raise ValueError("Must provide at least one structure")
  # The first structure doubles as the shallow tree, so every leaf is mapped.
  return map_structure_with_tuple_paths_up_to(structure[0],
                                              func,
                                              *structure,
                                              **kwargs)
952
953
954def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
955  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
956
957  Args:
958    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
959    input_tree: Nested structure. Return the paths and values from this tree.
960      Must have the same upper structure as shallow_tree.
961    is_seq: Function used to test if a value should be treated as a sequence.
962    path: Tuple. Optional argument, only used when recursing. The path from the
963      root of the original shallow_tree, down to the root of the shallow_tree
964      arg of this recursive call.
965
966  Yields:
967    Pairs of (path, value), where path the tuple path of a leaf node in
968    shallow_tree, and value is the value of the corresponding node in
969    input_tree.
970  """
971  if not is_seq(shallow_tree):
972    yield (path, input_tree)
973  else:
974    input_tree = dict(_yield_sorted_items(input_tree))
975    for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
976      subpath = path + (shallow_key,)
977      input_subtree = input_tree[shallow_key]
978      for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
979                                                     input_subtree, is_seq,
980                                                     path=subpath):
981        yield (leaf_path, leaf_value)
982
983
def assert_shallow_structure(shallow_tree,
                             input_tree,
                             check_types=True,
                             expand_composites=False):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will raise an exception:
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. Note that even with check_types==True,
      this function will consider two different namedtuple classes with the same
      name and _fields attribute to be the same class.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.
  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    TypeError: If one of the two is a composite tensor and the other is neither
      a composite tensor nor a `TypeSpec`, or if `shallow_tree` is a `TypeSpec`
      but `input_tree` is not.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If `shallow_tree` is a mapping with keys that are absent from
      `input_tree`.
    ValueError: If corresponding composite tensors / `TypeSpec`s have
      incompatible `TypeSpec`s.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  if is_seq(shallow_tree):
    if not is_seq(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s." % type(input_tree))

    # For wrapt proxies, compare against the type of the proxied object.
    if isinstance(shallow_tree, _wrapt.ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = is_namedtuple(shallow_tree, False)
      input_is_namedtuple = is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        if not same_namedtuples(shallow_tree, input_tree):
          raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
              input_type=type(input_tree),
              shallow_type=type(shallow_tree)))

      elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
        # List subclasses are considered the same,
        # e.g. python list vs. _ListWrapper.
        pass

      elif ((_is_composite_tensor(shallow_tree) or
             _is_composite_tensor(input_tree)) and
            (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))):
        pass  # Compatibility will be checked below.

      elif not (isinstance(shallow_tree, _collections_abc.Mapping) and
                isinstance(input_tree, _collections_abc.Mapping)):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))

    if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree):
      if not (
          (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and
          (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))
      # Compare TypeSpecs: use the value itself if it is already a TypeSpec.
      type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else
                     shallow_tree._type_spec)  # pylint: disable=protected-access
      type_spec_2 = (input_tree if _is_type_spec(input_tree) else
                     input_tree._type_spec)  # pylint: disable=protected-access
      try:
        _ = type_spec_1.most_specific_compatible_type(type_spec_2)
      except (TypeError, ValueError) as e:
        raise ValueError(
            "Incompatible CompositeTensor TypeSpecs: %s vs. %s -- %s" %
            (type_spec_1, type_spec_2, e))

    elif _is_type_spec(shallow_tree):
      if not _is_type_spec(input_tree):
        raise TypeError("If shallow structure is a TypeSpec, input must also "
                        "be a TypeSpec.  Input has type: %s."
                        % type(input_tree))
    elif len(input_tree) != len(shallow_tree):
      # This single check also covers an input shorter than `shallow_tree`;
      # a separate "input smaller" check after it could never be reached.
      raise ValueError(
          _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
              input_length=len(input_tree), shallow_length=len(shallow_tree)))

    if isinstance(shallow_tree, _collections_abc.Mapping):
      absent_keys = set(shallow_tree) - set(input_tree)
      if absent_keys:
        raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
                         .format(sorted(absent_keys)))

    # Recurse pairwise into children; `_yield_value` supplies the traversal
    # order for both trees.
    for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
                                            _yield_value(input_tree)):
      assert_shallow_structure(shallow_branch, input_branch,
                               check_types=check_types,
                               expand_composites=expand_composites)
1110
1111
@tf_export("__internal__.nest.flatten_up_to", v1=[])
def flatten_up_to(shallow_tree, input_tree, check_types=True,
                  expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  # Validate first so the traversal below can assume matching upper structure.
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  if expand_composites:
    is_seq = is_sequence_or_composite
  else:
    is_seq = is_sequence
  # Keep only the values; the tuple paths yielded alongside them are dropped.
  leaf_pairs = _yield_flat_up_to(shallow_tree, input_tree, is_seq)
  return [leaf for _, leaf in leaf_pairs]
1195
1196
def flatten_with_tuple_paths_up_to(shallow_tree,
                                   input_tree,
                                   check_types=True,
                                   expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`, pairing each value with a path.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  Returns a list of (path, value) pairs, where `value` is a leaf node of the
  partially flattened tree and `path` is the tuple path of that leaf (the same
  position in both `shallow_tree` and `input_tree`).

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[((), input_tree)]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                        input_tree)
  flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                          shallow_tree)

  # Output is:
  # [((0, 0), [2, 2]),
  #  ((0, 1), [3, 3]),
  #  ((1, 0), [4, 9]),
  #  ((1, 1), [5, 5])]
  #
  # [((0, 0), True),
  #  ((0, 1), True),
  #  ((1, 0), False),
  #  ((1, 1), True)]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_with_tuple_paths_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [((0, 0), ('a', 1)),
  #  ((0, 1, 0), ('b', 2)),
  #  ((0, 1, 1, 0), ('c', 3)),
  #  ((0, 1, 1, 1, 0), ('d', 4))]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_with_tuple_paths_up_to(0, 0)  # Output: [((), 0)]

  flatten_with_tuple_paths_up_to(0, [0, 1, 2])  # Output: [((), [0, 1, 2])]

  flatten_with_tuple_paths_up_to([0, 1, 2], 0)  # Output: TypeError

  flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
  # Output: [((0,), 0), ((1,), 1), ((2,), 2)]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list of `(path, value)` pairs: the partially flattened version of
    `input_tree` according to the structure of `shallow_tree`, with each value
    paired with its tuple path.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  return list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
1299
1300
@tf_export("__internal__.nest.map_structure_up_to", v1=[])
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Examples:

  ```python
  shallow_tree = [None, None]
  inp_val = [1, 2]
  out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val)

  # Output is: [2, 4]
  # An input such as [1, 2, 3] raises a ValueError (length mismatch).
  ```

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      name_list,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
        shallow_tree. The function `func` is applied to corresponding
        partially flattened elements of each input, so the function must support
        arity of `len(inputs)`.
    **kwargs: kwargs to feed to func(). Special kwargs `check_types` and
      `expand_composites` are not passed to func. `check_types` determines
      whether the types of iterables within the structures have to be same
      (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` exception);
      to allow this set it to `False`. If `expand_composites` is `True`,
      composite tensors such as `tf.sparse.SparseTensor` and `tf.RaggedTensor`
      are expanded into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """
  return map_structure_with_tuple_paths_up_to(
      shallow_tree,
      # Drop the path argument, but forward the remaining kwargs to `func` as
      # documented; the path-based mapper passes them to this callable.
      lambda _, *values, **func_kwargs: func(*values, **func_kwargs),
      *inputs,
      **kwargs)
1381
1382
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  Like map_structure_up_to(), except that the 'func' argument takes a path
  tuple as its first argument, followed by the corresponding values from
  *inputs.

  Example:

  ```python
  lowercase = {'a': 'a', 'b': ('b0', 'b1')}
  uppercase = {'a': 'A', 'b': ('B0', 'B1')}

  def print_path_and_values(path, *values):
    print("path: {}, values: {}".format(path, values))

  shallow_tree = {'a': None, 'b': None}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  path: ('a',), values: ('a', 'A')
  path: ('b',), values: (('b0', 'b1'), ('B0', 'B1'))

  shallow_tree = {'a': None, 'b': (None, None)}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  path: ('a',), values: ('a', 'A')
  path: ('b', 0), values: ('b0', 'B0')
  path: ('b', 1), values: ('b1', 'B1')

  # A list where the inputs have tuples requires check_types=False.
  shallow_tree = {'a': None, 'b': [None, None]}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase,
                                       check_types=False)
  path: ('a',), values: ('a', 'A')
  path: ('b', 0), values: ('b0', 'B0')
  path: ('b', 1), values: ('b1', 'B1')

  # `shallow_tree` must cover every key of the inputs: {'a': None} alone
  # raises a ValueError (length mismatch).
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
      where path is a tuple path to a leaf node in shallow_tree, and
      inputs_i_value is the corresponding value from inputs[i].
    *inputs: nested structures that are all structurally compatible with
        shallow_tree.
    **kwargs: kwargs to feed to func(). Special kwargs `check_types` and
      `expand_composites` are not passed to func. `check_types` determines
      whether the types of iterables within the structures have to be same
      (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` exception);
      to allow this set it to `False`. If `expand_composites` is `True`,
      composite tensors are expanded into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If no inputs are provided.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.

  Returns:
    Result of repeatedly applying `func`. Has the same structure layout as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  is_seq = is_sequence_or_composite if expand_composites else is_sequence

  # Validate every input exactly once. The flattening below relies on this and
  # does not re-validate (going through `flatten_up_to` would run
  # `assert_shallow_structure` a second time per input).
  for input_tree in inputs:
    assert_shallow_structure(
        shallow_tree,
        input_tree,
        check_types=check_types,
        expand_composites=expand_composites)

  # Flatten each input into (path, value) pairs in a single traversal. Paths
  # are derived from `shallow_tree` only, so they agree across inputs; the
  # path of the first input is used.
  flat_inputs = [
      list(_yield_flat_up_to(shallow_tree, input_tree, is_seq))
      for input_tree in inputs
  ]
  results = []
  for leaves in zip(*flat_inputs):
    path = leaves[0][0]
    values = [value for _, value in leaves]
    results.append(func(path, *values, **kwargs))
  return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
                          expand_composites=expand_composites)
1479
1480
@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[])
def get_traverse_shallow_structure(traverse_fn, structure,
                                   expand_composites=False):
  """Generates a shallow structure from a `traverse_fn` and `structure`.

  `traverse_fn` must accept any possible subtree of `structure` and return
  a depth=1 structure containing `True` or `False` values, describing which
  of the top-level subtrees may be traversed.  It may also
  return scalar `True` or `False` "traversal is OK / not OK for all subtrees."

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` (whether to traverse that substructure or not) or a depth=1
      shallow structure of the same type, describing which parts of the
      substructure to traverse.
    structure: The structure to traverse.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors. The setting applies at every level of the traversal.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_structure_up_to` and `flatten_up_to`. For a non-sequence `structure`
    this is the bool returned by `traverse_fn`; a sequence for which
    `traverse_fn` returns `False` collapses to a single `False`.

  Raises:
    TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
      or a structure with depth higher than 1 for a sequence input,
      or if any leaf values in the returned structure or scalar are not type
      `bool`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  to_traverse = traverse_fn(structure)
  if not is_seq(structure):
    if not isinstance(to_traverse, bool):
      raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
                      % (to_traverse, structure))
    return to_traverse
  level_traverse = []
  if isinstance(to_traverse, bool):
    if not to_traverse:
      # Do not traverse this substructure at all.  Exit early.
      return False
    else:
      # Traverse the entire substructure.
      for branch in _yield_value(structure):
        level_traverse.append(
            get_traverse_shallow_structure(traverse_fn, branch,
                                           expand_composites=expand_composites))
  elif not is_seq(to_traverse):
    raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
                    % (to_traverse, structure))
  else:
    # Traverse some subset of this substructure.
    assert_shallow_structure(to_traverse, structure,
                             expand_composites=expand_composites)
    for t, branch in zip(_yield_value(to_traverse),
                         _yield_value(structure)):
      if not isinstance(t, bool):
        raise TypeError(
            "traverse_fn didn't return a depth=1 structure of bools.  saw: %s "
            " for structure: %s" % (to_traverse, structure))
      if t:
        # Propagate `expand_composites` so nested levels are traversed with
        # the same sequence predicate as this one.
        level_traverse.append(
            get_traverse_shallow_structure(
                traverse_fn, branch, expand_composites=expand_composites))
      else:
        level_traverse.append(False)
  # Rebuild a container of the same type as `structure` holding the bools.
  return _sequence_like(structure, level_traverse)
1550
1551
@tf_export("__internal__.nest.yield_flat_paths", v1=[])
def yield_flat_paths(nest, expand_composites=False):
  """Yields the path to every leaf of a nested structure.

  Each path is a tuple of the indices and/or dictionary keys that lead from the
  top of `nest` down to one leaf. Path elements may be integers or any other
  object used as a key in a dict, and can be str-converted when a textual name
  is needed.

  Paths are produced in the same order as the leaves returned by
  `nest.flatten`. This is handy for naming Tensors such that the TF scope
  structure matches the tuple structure.

  E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`

  ```python
  nest.flatten(value)
  [3, 23, 42]
  list(nest.yield_flat_paths(value))
  [('a',), ('b', 'c'), ('b', 'd')]
  ```

  ```python
  list(nest.yield_flat_paths({'a': [3]}))
  [('a', 0)]
  list(nest.yield_flat_paths({'a': 3}))
  [('a',)]
  ```

  Args:
    nest: the value to produce a flattened paths list for.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Yields:
    Tuples containing index or key values which form the path to a specific
    leaf value in the nested structure.
  """
  if expand_composites:
    is_structure = is_sequence_or_composite
  else:
    is_structure = is_sequence
  # `_yield_flat_up_to` produces (path, value) pairs; only the path is wanted.
  for path, _leaf in _yield_flat_up_to(nest, nest, is_structure):
    yield path
1592
1593
def flatten_with_joined_string_paths(structure, separator="/",
                                     expand_composites=False):
  """Flattens `structure` into (joined string path, leaf) pairs.

  Pairs are returned in the same order as `nest.flatten`, so each leaf keeps a
  record of where in the structure it came from. Every path element is
  converted with `str` and the elements are joined with `separator`. See
  `nest.yield_flat_paths` for how the paths themselves are built.

  Args:
    structure: the nested structure to flatten.
    separator: string placed between levels of hierarchy in each path; defaults
      to '/'.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of (string, data element) tuples.
  """
  path_iter = yield_flat_paths(structure, expand_composites=expand_composites)
  leaves = flatten(structure, expand_composites=expand_composites)
  return [(separator.join(str(element) for element in path), leaf)
          for path, leaf in zip(path_iter, leaves)]
1621
1622
def flatten_with_tuple_paths(structure, expand_composites=False):
  """Flattens `structure` into `(tuple_path, leaf_element)` pairs.

  The pairs follow the same order as `nest.flatten`, so each leaf is returned
  together with its location inside `structure`. See `nest.yield_flat_paths`
  for more information about tuple paths.

  Args:
    structure: the nested structure to flatten.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
    of indices and/or dictionary keys that uniquely specify the path to
    `leaf_element` within `structure`.
  """
  paths = yield_flat_paths(structure, expand_composites=expand_composites)
  leaves = flatten(structure, expand_composites=expand_composites)
  return list(zip(paths, leaves))
1645
1646
@tf_export("__internal__.nest.list_to_tuple", v1=[])
def list_to_tuple(structure):
  """Returns `structure` rebuilt with every list replaced by a tuple.

  The fork of nest that tf.data uses treats lists as single elements, while
  tf.nest treats them as structures to recurse into. Keras has chosen to adopt
  the latter convention, and must therefore deeply replace all lists with tuples
  before passing structures to Dataset.from_generator.

  Args:
    structure: A nested structure to be remapped.

  Returns:
    structure mapped to replace all lists with tuples.
  """

  def _rebuild(instance, args):
    # Anything that is not a list is delegated to `_sequence_like`; lists are
    # turned into plain tuples.
    if not isinstance(instance, list):
      return _sequence_like(instance, args)
    return tuple(args)

  leaves = flatten(structure)
  # The third positional argument is passed as False, as before (presumably
  # `expand_composites`; confirm against `_pack_sequence_as`).
  return _pack_sequence_as(structure, leaves, False, sequence_fn=_rebuild)
1669
1670
# Register the collection ABCs, plus wrapt's `ObjectProxy`, with the native
# `_pywrap_utils` module under string names. These calls run once, at import
# time. NOTE(review): presumably the C++ side looks these types up by name for
# its type checks (e.g. deciding whether a value is a Mapping or a Sequence).
# Confirm in `_pywrap_utils`.
_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping)
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy)
1676