# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""

import collections
import numpy as np
from absl.testing import parameterized

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.util import nest
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


class NestTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tensorflow.python.data.util.nest`."""

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenAndPack(self):
    """`flatten` and `pack_sequence_as` round-trip tuples and namedtuples."""
    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                                 ("d", "e", ("f", "g"), "h")))
    point = collections.namedtuple("Point", ["x", "y"])
    structure = (point(x=4, y=2), ((point(x=1, y=0),),))
    flat = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(structure), flat)
    restructured_from_flat = nest.pack_sequence_as(structure, flat)
    self.assertEqual(restructured_from_flat, structure)
    self.assertEqual(restructured_from_flat[0].x, 4)
    self.assertEqual(restructured_from_flat[0].y, 2)
    self.assertEqual(restructured_from_flat[1][0][0].x, 1)
    self.assertEqual(restructured_from_flat[1][0][0].y, 0)

    self.assertEqual([5], nest.flatten(5))
    # `flatten` treats an ndarray as a single leaf. Compare that leaf with
    # `assertAllEqual`: `assertEqual` would rely on the truthiness of an
    # elementwise `==`, which is only well-defined for single-element arrays.
    flat_array = nest.flatten(np.array([5]))
    self.assertIsInstance(flat_array, list)
    self.assertLen(flat_array, 1)
    self.assertAllEqual(np.array([5]), flat_array[0])

    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertAllEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    with self.assertRaisesRegex(ValueError, "Argument `structure` is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegex(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenDictOrder(self):
    """`flatten` orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
    plain = {"d": 3, "b": 1, "a": 0, "c": 2}
    ordered_flat = nest.flatten(ordered)
    plain_flat = nest.flatten(plain)
    self.assertEqual([0, 1, 2, 3], ordered_flat)
    self.assertEqual([0, 1, 2, 3], plain_flat)

  @combinations.generate(test_base.default_test_combinations())
  def testPackDictOrder(self):
    """Packing orders dicts by key, including OrderedDicts."""
    ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
    plain = {"d": 0, "b": 0, "a": 0, "c": 0}
    seq = [0, 1, 2, 3]
    ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
    plain_reconstruction = nest.pack_sequence_as(plain, seq)
    self.assertEqual(
        collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
        ordered_reconstruction)
    self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenAndPackWithDicts(self):
    """Round-trips a mix of tuples, namedtuples, dicts and OrderedDicts."""
    # A nice messy mix of tuples, namedtuples, dicts, and `OrderedDict`s.
    named_tuple = collections.namedtuple("A", ("b", "c"))
    mess = (
        "z",
        named_tuple(3, 4),
        {
            "c": (
                1,
                collections.OrderedDict([
                    ("b", 3),
                    ("a", 2),
                ]),
            ),
            "b": 5
        },
        17
    )

    # Dict values are emitted in sorted-key order ("b" before "c", and "a"
    # before "b" inside the OrderedDict).
    flattened = nest.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])

    structure_of_mess = (
        14,
        named_tuple("a", True),
        {
            "c": (
                0,
                collections.OrderedDict([
                    ("b", 9),
                    ("a", 8),
                ]),
            ),
            "b": 3
        },
        "hi everybody",
    )

    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    self.assertEqual(unflattened, mess)

    # Check also that the OrderedDict was created, with the correct key order.
    unflattened_ordered_dict = unflattened[2]["c"][1]
    self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
    self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenSparseValue(self):
    """A `SparseTensorValue` is a leaf; a list of them is also one leaf."""
    st = sparse_tensor.SparseTensorValue([[0]], [0], [1])
    single_value = st
    list_of_values = [st, st, st]
    # Parentheses around a single name do not create a tuple, so this is
    # the two-level structure `(st, (st, st))`.
    nest_of_values = (st, (st, st))
    dict_of_values = {"foo": st, "bar": st, "baz": st}
    self.assertEqual([st], nest.flatten(single_value))
    self.assertEqual([[st, st, st]], nest.flatten(list_of_values))
    self.assertEqual([st, st, st], nest.flatten(nest_of_values))
    self.assertEqual([st, st, st], nest.flatten(dict_of_values))

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenRaggedValue(self):
    """A ragged tensor value is a leaf; a list of them is also one leaf."""
    rt = ragged_factory_ops.constant_value([[[0]], [[1]]])
    single_value = rt
    list_of_values = [rt, rt, rt]
    # Parentheses around a single name do not create a tuple, so this is
    # the two-level structure `(rt, (rt, rt))`.
    nest_of_values = (rt, (rt, rt))
    dict_of_values = {"foo": rt, "bar": rt, "baz": rt}
    self.assertEqual([rt], nest.flatten(single_value))
    self.assertEqual([[rt, rt, rt]], nest.flatten(list_of_values))
    self.assertEqual([rt, rt, rt], nest.flatten(nest_of_values))
    self.assertEqual([rt, rt, rt], nest.flatten(dict_of_values))

  @combinations.generate(test_base.default_test_combinations())
  def testIsNested(self):
    """Only tuples and dicts are nested; lists, sets and tensors are leaves."""
    self.assertFalse(nest.is_nested("1234"))
    self.assertFalse(nest.is_nested([1, 3, [4, 5]]))
    self.assertTrue(nest.is_nested(((7, 8), (5, 6))))
    self.assertFalse(nest.is_nested([]))
    self.assertFalse(nest.is_nested({1, 2}))
    ones = array_ops.ones([2, 3])
    self.assertFalse(nest.is_nested(ones))
    self.assertFalse(nest.is_nested(math_ops.tanh(ones)))
    self.assertFalse(nest.is_nested(np.ones((4, 5))))
    self.assertTrue(nest.is_nested({"foo": 1, "bar": 2}))
    self.assertFalse(
        nest.is_nested(sparse_tensor.SparseTensorValue([[0]], [0], [1])))
    self.assertFalse(
        nest.is_nested(ragged_factory_ops.constant_value([[[0]], [[1]]])))

  @combinations.generate(test_base.default_test_combinations())
  def testAssertSameStructure(self):
    """`assert_same_structure` checks nesting, sequence types and dict keys."""
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    structure_different_num_elements = ("spam", "eggs")
    structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
    structure_dictionary = {"foo": 2, "bar": 4, "baz": {"foo": 5, "bar": 6}}
    structure_dictionary_diff_nested = {
        "foo": 2,
        "bar": 4,
        "baz": {
            "foo": 5,
            "baz": 6
        }
    }
    nest.assert_same_structure(structure1, structure2)
    # Leaves of any kind (strings, floats, arrays, tensors) match each other.
    nest.assert_same_structure("abc", 1.0)
    nest.assert_same_structure("abc", np.array([0, 1]))
    nest.assert_same_structure("abc", constant_op.constant([0, 1]))

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure(structure1, structure_different_num_elements)

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure((0, 1), np.array([0, 1]))

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure(0, (0, 1))

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure(structure1, structure_different_nesting)

    named_type_0 = collections.namedtuple("named_0", ("a", "b"))
    named_type_1 = collections.namedtuple("named_1", ("a", "b"))
    self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                      named_type_0("a", "b"))

    nest.assert_same_structure(named_type_0(3, 4), named_type_0("a", "b"))

    self.assertRaises(TypeError, nest.assert_same_structure,
                      named_type_0(3, 4), named_type_1(3, 4))

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure(named_type_0(3, 4), named_type_0((3,), 4))

    with self.assertRaisesRegex(ValueError,
                                "don't have the same nested structure"):
      nest.assert_same_structure(((3,), 4), (3, (4,)))

    # Dict counterparts of `structure1`; they differ only in the last key.
    structure1_dict = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
    structure2_dict = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}
    with self.assertRaisesRegex(TypeError, "don't have the same sequence type"):
      nest.assert_same_structure(structure1, structure1_dict)
    nest.assert_same_structure(structure1, structure2, check_types=False)
    nest.assert_same_structure(structure1, structure1_dict, check_types=False)
    with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
      nest.assert_same_structure(structure1_dict, structure2_dict)
    with self.assertRaisesRegex(ValueError, "don't have the same set of keys"):
      nest.assert_same_structure(structure_dictionary,
                                 structure_dictionary_diff_nested)
    # With `check_types=False` the key mismatches above are accepted.
    nest.assert_same_structure(
        structure_dictionary,
        structure_dictionary_diff_nested,
        check_types=False)
    nest.assert_same_structure(
        structure1_dict, structure2_dict, check_types=False)

  @combinations.generate(test_base.default_test_combinations())
  def testMapStructure(self):
    """`map_structure` applies a function leafwise and validates its inputs."""
    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))
    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(structure1_plus1))
    structure1_plus_structure2 = nest.map_structure(
        lambda x, y: x + y, structure1, structure2)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        structure1_plus_structure2)

    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

    self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

    with self.assertRaisesRegex(TypeError, "callable"):
      nest.map_structure("bad", structure1_plus1)

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, 3, (3,))

    with self.assertRaisesRegex(TypeError, "same sequence type"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), {"a": (3, 4), "b": 5})

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

    with self.assertRaisesRegex(ValueError, "same nested structure"):
      nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                         check_types=False)

    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, foo="a")

    with self.assertRaisesRegex(ValueError, "Only valid keyword argument"):
      nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")

  @combinations.generate(test_base.default_test_combinations())
  def testAssertShallowStructure(self):
    """`assert_shallow_structure` reports length, type and key mismatches."""
    inp_ab = ("a", "b")
    inp_abc = ("a", "b", "c")
    expected_message = (
        "The two structures don't have the same sequence length. Input "
        "structure has length 2, while shallow structure has length 3.")
    with self.assertRaisesRegex(ValueError, expected_message):
      nest.assert_shallow_structure(inp_abc, inp_ab)

    inp_ab1 = ((1, 1), (2, 2))
    inp_ab2 = {"a": (1, 1), "b": (2, 2)}
    expected_message = (
        "The two structures don't have the same sequence type. Input structure "
        "has type 'tuple', while shallow structure has type "
        "'dict'.")
    with self.assertRaisesRegex(TypeError, expected_message):
      nest.assert_shallow_structure(inp_ab2, inp_ab1)
    nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)

    inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
    inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
    expected_message = (
        r"The two structures don't have the same keys. Input "
        r"structure has keys \['c'\], while shallow structure has "
        r"keys \['d'\].")
    with self.assertRaisesRegex(ValueError, expected_message):
      nest.assert_shallow_structure(inp_ab2, inp_ab1)

    # OrderedDicts with the same keys in a different order are compatible.
    inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    nest.assert_shallow_structure(inp_ab, inp_ba)

  @combinations.generate(test_base.default_test_combinations())
  def testFlattenUpTo(self):
    """`flatten_up_to` flattens the input only as deep as the shallow tree."""
    input_tree = (((2, 2), (3, 3)), ((4, 9), (5, 5)))
    shallow_tree = ((True, True), (False, True))
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

    input_tree = ((("a", 1), (("b", 2), (("c", 3), (("d", 4))))))
    shallow_tree = (("level_1", ("level_2", ("level_3", ("level_4")))))
    input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                              input_tree)
    input_tree_flattened = nest.flatten(input_tree)
    self.assertEqual(input_tree_flattened_as_shallow_tree,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Shallow tree is a leaf (non-sequence): the whole input, whatever its
    # own structure, becomes a single flattened element. Covers iterable and
    # non-iterable leaves, with and without a nested input.
    leaf_shallow_cases = [
        ("shallow_tree", ["input_tree"]),
        ("shallow_tree", ("input_tree_0", "input_tree_1")),
        (9, (0,)),
        (9, (0, 1)),
        ("shallow_tree", "input_tree"),
        (0, 0),
    ]
    for shallow_tree, input_tree in leaf_shallow_cases:
      with self.subTest(shallow_tree=shallow_tree, input_tree=input_tree):
        self.assertEqual(
            nest.flatten_up_to(shallow_tree, input_tree), [input_tree])
        self.assertEqual(
            nest.flatten_up_to(shallow_tree, shallow_tree), [shallow_tree])

    # Shallow tree is a sequence but the input is a leaf: flattening the input
    # is an error, while flattening the shallow tree against itself yields its
    # elements as a list.
    leaf_input_cases = [
        (("shallow_tree",), "input_tree", "str"),
        (("shallow_tree_9", "shallow_tree_8"), "input_tree", "str"),
        ((9,), 0, "int"),
        ((9, 8), 0, "int"),
    ]
    for shallow_tree, input_tree, input_type_name in leaf_input_cases:
      with self.subTest(shallow_tree=shallow_tree, input_tree=input_tree):
        expected_message = (
            "If shallow structure is a sequence, input must also "
            "be a sequence. Input has type: '{}'.".format(input_type_name))
        with self.assertRaisesRegex(TypeError, expected_message):
          nest.flatten_up_to(shallow_tree, input_tree)
        self.assertEqual(
            nest.flatten_up_to(shallow_tree, shallow_tree), list(shallow_tree))

    # Using dict.
    input_tree = {"a": ((2, 2), (3, 3)), "b": ((4, 9), (5, 5))}
    shallow_tree = {"a": (True, True), "b": (False, True)}
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
    flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
    self.assertEqual(flattened_input_tree, [(2, 2), (3, 3), (4, 9), (5, 5)])
    self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  @combinations.generate(test_base.default_test_combinations())
  def testMapStructureUpTo(self):
    """`map_structure_up_to` maps only down to the shallow structure."""
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    inp_val = ab_tuple(a=2, b=3)
    inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    out = nest.map_structure_up_to(
        inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
    self.assertEqual(out.a, 6)
    self.assertEqual(out.b, 15)

    data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
    name_list = ("evens", ("odds", "primes"))
    out = nest.map_structure_up_to(
        name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
        name_list, data_list)
    self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes")))


if __name__ == "__main__":
  test.main()