• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for utilities working with arbitrarily nested structures."""
16
17import collections
18import numpy as np
19from absl.testing import parameterized
20
21from tensorflow.python.data.kernel_tests import test_base
22from tensorflow.python.data.util import nest
23from tensorflow.python.framework import combinations
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.platform import test
30
31
32class NestTest(test_base.DatasetTestBase, parameterized.TestCase):
33
34  @combinations.generate(test_base.default_test_combinations())
35  def testFlattenAndPack(self):
36    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
37    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
38    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
39    self.assertEqual(
40        nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
41                                                 ("d", "e", ("f", "g"), "h")))
42    point = collections.namedtuple("Point", ["x", "y"])
43    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
44    flat = [4, 2, 1, 0]
45    self.assertEqual(nest.flatten(structure), flat)
46    restructured_from_flat = nest.pack_sequence_as(structure, flat)
47    self.assertEqual(restructured_from_flat, structure)
48    self.assertEqual(restructured_from_flat[0].x, 4)
49    self.assertEqual(restructured_from_flat[0].y, 2)
50    self.assertEqual(restructured_from_flat[1][0][0].x, 1)
51    self.assertEqual(restructured_from_flat[1][0][0].y, 0)
52
53    self.assertEqual([5], nest.flatten(5))
54    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
55
56    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
57    self.assertEqual(
58        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
59
60    with self.assertRaisesRegex(ValueError, "Argument `structure` is a scalar"):
61      nest.pack_sequence_as("scalar", [4, 5])
62
63    with self.assertRaisesRegex(TypeError, "flat_sequence"):
64      nest.pack_sequence_as([4, 5], "bad_sequence")
65
66    with self.assertRaises(ValueError):
67      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
68
69  @combinations.generate(test_base.default_test_combinations())
70  def testFlattenDictOrder(self):
71    """`flatten` orders dicts by key, including OrderedDicts."""
72    ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
73    plain = {"d": 3, "b": 1, "a": 0, "c": 2}
74    ordered_flat = nest.flatten(ordered)
75    plain_flat = nest.flatten(plain)
76    self.assertEqual([0, 1, 2, 3], ordered_flat)
77    self.assertEqual([0, 1, 2, 3], plain_flat)
78
79  @combinations.generate(test_base.default_test_combinations())
80  def testPackDictOrder(self):
81    """Packing orders dicts by key, including OrderedDicts."""
82    ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
83    plain = {"d": 0, "b": 0, "a": 0, "c": 0}
84    seq = [0, 1, 2, 3]
85    ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
86    plain_reconstruction = nest.pack_sequence_as(plain, seq)
87    self.assertEqual(
88        collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
89        ordered_reconstruction)
90    self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
91
92  @combinations.generate(test_base.default_test_combinations())
93  def testFlattenAndPackWithDicts(self):
94    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
95    named_tuple = collections.namedtuple("A", ("b", "c"))
96    mess = (
97        "z",
98        named_tuple(3, 4),
99        {
100            "c": (
101                1,
102                collections.OrderedDict([
103                    ("b", 3),
104                    ("a", 2),
105                ]),
106            ),
107            "b": 5
108        },
109        17
110    )
111
112    flattened = nest.flatten(mess)
113    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
114
115    structure_of_mess = (
116        14,
117        named_tuple("a", True),
118        {
119            "c": (
120                0,
121                collections.OrderedDict([
122                    ("b", 9),
123                    ("a", 8),
124                ]),
125            ),
126            "b": 3
127        },
128        "hi everybody",
129    )
130
131    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
132    self.assertEqual(unflattened, mess)
133
134    # Check also that the OrderedDict was created, with the correct key order.
135    unflattened_ordered_dict = unflattened[2]["c"][1]
136    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
137    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
138
139  @combinations.generate(test_base.default_test_combinations())
140  def testFlattenSparseValue(self):
141    st = sparse_tensor.SparseTensorValue([[0]], [0], [1])
142    single_value = st
143    list_of_values = [st, st, st]
144    nest_of_values = ((st), ((st), (st)))
145    dict_of_values = {"foo": st, "bar": st, "baz": st}
146    self.assertEqual([st], nest.flatten(single_value))
147    self.assertEqual([[st, st, st]], nest.flatten(list_of_values))
148    self.assertEqual([st, st, st], nest.flatten(nest_of_values))
149    self.assertEqual([st, st, st], nest.flatten(dict_of_values))
150
151  @combinations.generate(test_base.default_test_combinations())
152  def testFlattenRaggedValue(self):
153    rt = ragged_factory_ops.constant_value([[[0]], [[1]]])
154    single_value = rt
155    list_of_values = [rt, rt, rt]
156    nest_of_values = ((rt), ((rt), (rt)))
157    dict_of_values = {"foo": rt, "bar": rt, "baz": rt}
158    self.assertEqual([rt], nest.flatten(single_value))
159    self.assertEqual([[rt, rt, rt]], nest.flatten(list_of_values))
160    self.assertEqual([rt, rt, rt], nest.flatten(nest_of_values))
161    self.assertEqual([rt, rt, rt], nest.flatten(dict_of_values))
162
163  @combinations.generate(test_base.default_test_combinations())
164  def testIsNested(self):
165    self.assertFalse(nest.is_nested("1234"))
166    self.assertFalse(nest.is_nested([1, 3, [4, 5]]))
167    self.assertTrue(nest.is_nested(((7, 8), (5, 6))))
168    self.assertFalse(nest.is_nested([]))
169    self.assertFalse(nest.is_nested(set([1, 2])))
170    ones = array_ops.ones([2, 3])
171    self.assertFalse(nest.is_nested(ones))
172    self.assertFalse(nest.is_nested(math_ops.tanh(ones)))
173    self.assertFalse(nest.is_nested(np.ones((4, 5))))
174    self.assertTrue(nest.is_nested({"foo": 1, "bar": 2}))
175    self.assertFalse(
176        nest.is_nested(sparse_tensor.SparseTensorValue([[0]], [0], [1])))
177    self.assertFalse(
178        nest.is_nested(ragged_factory_ops.constant_value([[[0]], [[1]]])))
179
180  @combinations.generate(test_base.default_test_combinations())
181  def testAssertSameStructure(self):
182    structure1 = (((1, 2), 3), 4, (5, 6))
183    structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
184    structure_different_num_elements = ("spam", "eggs")
185    structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
186    structure_dictionary = {"foo": 2, "bar": 4, "baz": {"foo": 5, "bar": 6}}
187    structure_dictionary_diff_nested = {
188        "foo": 2,
189        "bar": 4,
190        "baz": {
191            "foo": 5,
192            "baz": 6
193        }
194    }
195    nest.assert_same_structure(structure1, structure2)
196    nest.assert_same_structure("abc", 1.0)
197    nest.assert_same_structure("abc", np.array([0, 1]))
198    nest.assert_same_structure("abc", constant_op.constant([0, 1]))
199
200    with self.assertRaisesRegex(ValueError,
201                                "don't have the same nested structure"):
202      nest.assert_same_structure(structure1, structure_different_num_elements)
203
204    with self.assertRaisesRegex(ValueError,
205                                "don't have the same nested structure"):
206      nest.assert_same_structure((0, 1), np.array([0, 1]))
207
208    with self.assertRaisesRegex(ValueError,
209                                "don't have the same nested structure"):
210      nest.assert_same_structure(0, (0, 1))
211
212    with self.assertRaisesRegex(ValueError,
213                                "don't have the same nested structure"):
214      nest.assert_same_structure(structure1, structure_different_nesting)
215
216    named_type_0 = collections.namedtuple("named_0", ("a", "b"))
217    named_type_1 = collections.namedtuple("named_1", ("a", "b"))
218    self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
219                      named_type_0("a", "b"))
220
221    nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))
222
223    self.assertRaises(TypeError, nest.assert_same_structure,
224                      named_type_0(3, 4), named_type_1(3, 4))
225
226    with self.assertRaisesRegex(ValueError,
227                                "don't have the same nested structure"):
228      nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))
229
230    with self.assertRaisesRegex(ValueError,
231                                "don't have the same nested structure"):
232      nest.assert_same_structure(((3,), 4), (3, (4,)))
233
234    structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
235    structure2_list = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}
236    with self.assertRaisesRegex(TypeError, "don't have the same sequence type"):
237      nest.assert_same_structure(structure1, structure1_list)
238    nest.assert_same_structure(structure1, structure2, check_types=False)
239    nest.assert_same_structure(structure1, structure1_list, check_types=False)
240    with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
241      nest.assert_same_structure(structure1_list, structure2_list)
242    with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
243      nest.assert_same_structure(structure_dictionary,
244                                 structure_dictionary_diff_nested)
245    nest.assert_same_structure(
246        structure_dictionary,
247        structure_dictionary_diff_nested,
248        check_types=False)
249    nest.assert_same_structure(
250        structure1_list, structure2_list, check_types=False)
251
252  @combinations.generate(test_base.default_test_combinations())
253  def testMapStructure(self):
254    structure1 = (((1, 2), 3), 4, (5, 6))
255    structure2 = (((7, 8), 9), 10, (11, 12))
256    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
257    nest.assert_same_structure(structure1, structure1_plus1)
258    self.assertAllEqual(
259        [2, 3, 4, 5, 6, 7],
260        nest.flatten(structure1_plus1))
261    structure1_plus_structure2 = nest.map_structure(
262        lambda x, y: x + y, structure1, structure2)
263    self.assertEqual(
264        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
265        structure1_plus_structure2)
266
267    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
268
269    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
270
271    with self.assertRaisesRegex(TypeError, "callable"):
272      nest.map_structure("bad", structure1_plus1)
273
274    with self.assertRaisesRegex(ValueError, "same nested structure"):
275      nest.map_structure(lambda x, y: None, 3, (3,))
276
277    with self.assertRaisesRegex(TypeError, "same sequence type"):
278      nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})
279
280    with self.assertRaisesRegex(ValueError, "same nested structure"):
281      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
282
283    with self.assertRaisesRegex(ValueError, "same nested structure"):
284      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
285                         check_types=False)
286
287    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
288      nest.map_structure(lambda x: None, structure1, foo="a")
289
290    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
291      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
292
293  @combinations.generate(test_base.default_test_combinations())
294  def testAssertShallowStructure(self):
295    inp_ab = ("a", "b")
296    inp_abc = ("a", "b", "c")
297    expected_message = (
298        "The two structures don't have the same sequence length. Input "
299        "structure has length 2, while shallow structure has length 3.")
300    with self.assertRaisesRegex(ValueError, expected_message):
301      nest.assert_shallow_structure(inp_abc, inp_ab)
302
303    inp_ab1 = ((1, 1), (2, 2))
304    inp_ab2 = {"a": (1, 1), "b": (2, 2)}
305    expected_message = (
306        "The two structures don't have the same sequence type. Input structure "
307        "has type 'tuple', while shallow structure has type "
308        "'dict'.")
309    with self.assertRaisesRegex(TypeError, expected_message):
310      nest.assert_shallow_structure(inp_ab2, inp_ab1)
311    nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
312
313    inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
314    inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
315    expected_message = (
316        r"The two structures don't have the same keys. Input "
317        r"structure has keys \['c'\], while shallow structure has "
318        r"keys \['d'\].")
319    with self.assertRaisesRegex(ValueError, expected_message):
320      nest.assert_shallow_structure(inp_ab2, inp_ab1)
321
322    inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
323    inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
324    nest.assert_shallow_structure(inp_ab, inp_ba)
325
326  @combinations.generate(test_base.default_test_combinations())
327  def testFlattenUpTo(self):
328    input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
329    shallow_tree = ((True, True), (False, True))
330    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
331    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
332    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
333    self.assertEqual(flattened_shallow_tree, [True, True, False, True])
334
335    input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4))))))
336    shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4")))))
337    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
338                                                              input_tree)
339    input_tree_flattened = nest.flatten(input_tree)
340    self.assertEqual(input_tree_flattened_as_shallow_tree,
341                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
342    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
343
344    ## Shallow non-list edge-case.
345    # Using iterable elements.
346    input_tree = ["input_tree"]
347    shallow_tree = "shallow_tree"
348    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
349    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
350    self.assertEqual(flattened_input_tree, [input_tree])
351    self.assertEqual(flattened_shallow_tree, [shallow_tree])
352
353    input_tree = ("input_tree_0", "input_tree_1")
354    shallow_tree = "shallow_tree"
355    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
356    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
357    self.assertEqual(flattened_input_tree, [input_tree])
358    self.assertEqual(flattened_shallow_tree, [shallow_tree])
359
360    # Using non-iterable elements.
361    input_tree = (0,)
362    shallow_tree = 9
363    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
364    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
365    self.assertEqual(flattened_input_tree, [input_tree])
366    self.assertEqual(flattened_shallow_tree, [shallow_tree])
367
368    input_tree = (0, 1)
369    shallow_tree = 9
370    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
371    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
372    self.assertEqual(flattened_input_tree, [input_tree])
373    self.assertEqual(flattened_shallow_tree, [shallow_tree])
374
375    ## Both non-list edge-case.
376    # Using iterable elements.
377    input_tree = "input_tree"
378    shallow_tree = "shallow_tree"
379    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
380    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
381    self.assertEqual(flattened_input_tree, [input_tree])
382    self.assertEqual(flattened_shallow_tree, [shallow_tree])
383
384    # Using non-iterable elements.
385    input_tree = 0
386    shallow_tree = 0
387    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
388    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
389    self.assertEqual(flattened_input_tree, [input_tree])
390    self.assertEqual(flattened_shallow_tree, [shallow_tree])
391
392    ## Input non-list edge-case.
393    # Using iterable elements.
394    input_tree = "input_tree"
395    shallow_tree = ("shallow_tree",)
396    expected_message = ("If shallow structure is a sequence, input must also "
397                        "be a sequence. Input has type: 'str'.")
398    with self.assertRaisesRegex(TypeError, expected_message):
399      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
400    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
401    self.assertEqual(flattened_shallow_tree, list(shallow_tree))
402
403    input_tree = "input_tree"
404    shallow_tree = ("shallow_tree_9", "shallow_tree_8")
405    with self.assertRaisesRegex(TypeError, expected_message):
406      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
407    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
408    self.assertEqual(flattened_shallow_tree, list(shallow_tree))
409
410    # Using non-iterable elements.
411    input_tree = 0
412    shallow_tree = (9,)
413    expected_message = ("If shallow structure is a sequence, input must also "
414                        "be a sequence. Input has type: 'int'.")
415    with self.assertRaisesRegex(TypeError, expected_message):
416      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
417    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
418    self.assertEqual(flattened_shallow_tree, list(shallow_tree))
419
420    input_tree = 0
421    shallow_tree = (9, 8)
422    with self.assertRaisesRegex(TypeError, expected_message):
423      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
424    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
425    self.assertEqual(flattened_shallow_tree, list(shallow_tree))
426
427    # Using dict.
428    input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
429    shallow_tree = {"a": (True, True), "b": (False, True)}
430    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
431    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
432    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
433    self.assertEqual(flattened_shallow_tree, [True, True, False, True])
434
435  @combinations.generate(test_base.default_test_combinations())
436  def testMapStructureUpTo(self):
437    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
438    op_tuple = collections.namedtuple("op_tuple", "add, mul")
439    inp_val = ab_tuple(a=2, b=3)
440    inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
441    out = nest.map_structure_up_to(
442        inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
443    self.assertEqual(out.a, 6)
444    self.assertEqual(out.b, 15)
445
446    data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
447    name_list = ("evens", ("odds", "primes"))
448    out = nest.map_structure_up_to(
449        name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
450        name_list, data_list)
451    self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes")))
452
453
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  test.main()
456