# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the trackable containers exposed by `data_structures`."""
import collections
import copy
import json
import os
import pickle

from absl.testing import parameterized
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest
from tensorflow.python.util import serialization


class ListTests(test.TestCase):
  """Exercises `data_structures.List`."""

  def testJSONSerialization(self):
    # A list attribute of an AutoTrackable must survive `json.dumps` when
    # `serialization.get_json_type` is supplied as the fallback encoder.
    root = autotrackable.AutoTrackable()
    root.l = [1]
    json.dumps(root.l, default=serialization.get_json_type)

  def testNotTrackable(self):
    # Constructing a List around a plain Python object raises ValueError.
    class NotTrackable:
      pass

    with self.assertRaises(ValueError):
      data_structures.List([NotTrackable()])

  def testCallNotImplemented(self):
    # Calling a List instance raises a "not callable" TypeError.
    with self.assertRaisesRegex(TypeError, "not callable"):
      data_structures.List()(1.)  # pylint: disable=not-callable

  def testNoPop(self):
    # List exposes no `pop` attribute.
    with self.assertRaises(AttributeError):
      data_structures.List().pop()

  def testNesting(self):
    # Variables created by a layer appended to a nested List are visible
    # through the enclosing List.
    with context.graph_mode():
      child = data_structures.List()
      parent = data_structures.List([child])
      child.append(non_keras_core.Dense(1))
      child[0](array_ops.ones([2, 3]))
      self.assertEqual(2, len(parent.variables))
      self.assertIsInstance(
          parent.variables[0], resource_variable_ops.ResourceVariable)

  def testNonLayerVariables(self):
    # Checks the variable/weight collections of a List holding bare variables,
    # before and after toggling `trainable` and appending a non-trainable one.
    first_var = resource_variable_ops.ResourceVariable([1.])
    tracked = data_structures.List([first_var])
    self.assertTrue(tracked.trainable)
    self.assertEqual([], tracked.layers)
    self.assertEqual([first_var], tracked.variables)
    self.assertEqual([first_var], tracked.trainable_weights)
    self.assertEqual([], tracked.non_trainable_variables)

    tracked.trainable = False
    self.assertEqual([first_var], tracked.variables)
    self.assertEqual([], tracked.trainable_variables)
    self.assertEqual([first_var], tracked.non_trainable_variables)

    tracked.trainable = True
    frozen_var = resource_variable_ops.ResourceVariable(1., trainable=False)
    tracked.append(frozen_var)
    self.assertEqual([first_var, frozen_var], tracked.weights)
    self.assertEqual([first_var], tracked.trainable_weights)
    self.assertEqual([frozen_var], tracked.non_trainable_weights)

  def testCopy(self):
    # Appending to a `copy()` of a List leaves the source List untouched.
    v1, v2, v3 = [resource_variable_ops.ResourceVariable(1.) for _ in range(3)]

    source = data_structures.List([v1, v2])
    duplicate = source.copy()
    duplicate.append(v3)
    self.assertEqual(list(source), [v1, v2])
    self.assertEqual(list(duplicate), [v1, v2, v3])

  def testSlicing(self):
    # Slices of a List compare equal to the matching built-in list.
    v1, v2, v3, v4 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(4)]

    tracked = data_structures.List([v1, v2, v3, v4])
    self.assertEqual(tracked[1:], [v2, v3, v4])
    self.assertEqual(tracked[1:-1], [v2, v3])
    self.assertEqual(tracked[:-1], [v1, v2, v3])

  def testHash(self):
    # Two empty Lists are distinct set members, and a third empty List is not
    # found in that set.
    has_sequences = {data_structures.List(), data_structures.List()}
    self.assertEqual(2, len(has_sequences))
    self.assertNotIn(data_structures.List(), has_sequences)

  def testIMul_zero(self):
    # In-place multiplication by zero is rejected.
    tracked = data_structures.List([])
    with self.assertRaisesRegex(ValueError, "List only supports append"):
      tracked *= 0

  def testIMul(self):
    # In-place multiplication by a positive count repeats the contents.
    var = resource_variable_ops.ResourceVariable(1.)
    tracked = data_structures.List([var])
    tracked *= 2
    self.assertEqual(list(tracked), [var] * 2)

  def testMul(self):
    # `List * int` matches `list * int`.
    var = resource_variable_ops.ResourceVariable(1.)
    tracked = data_structures.List([var, var, var])
    self.assertEqual(list(tracked * 2), [var, var, var] * 2)

  def testRMul(self):
    # `int * List` matches `list * int`.
    var = resource_variable_ops.ResourceVariable(1.)
    tracked = data_structures.List([var, var, var])
    self.assertEqual(list(2 * tracked), [var, var, var] * 2)


class ListWrapperTest(test.TestCase):
  """Exercises `data_structures.ListWrapper`."""

  # Attribute names skipped by `test_overrides_all_list_methods`.
  IGNORED = ("__new__", "__init__", "__subclasshook__", "__getattribute__")

  def test_overrides_all_list_methods(self):
    # Collects every callable attribute of `list` (other than the IGNORED ones
    # and those identical to `object`'s) that ListWrapper leaves unchanged.
    unchanged_names = []

    for attr_name in dir(list):
      if attr_name in self.IGNORED:
        continue

      from_list = getattr(list, attr_name)
      if not callable(from_list):
        continue

      # Skip methods that aren't overridden from object.
      from_object = getattr(object, attr_name, None)
      if from_object is not None and from_object == from_list:
        continue

      if from_list == getattr(data_structures.ListWrapper, attr_name):
        unchanged_names.append(attr_name)

    if unchanged_names:
      self.fail("ListWrapper does not override %s" % (unchanged_names))

  def testPickle(self):
    # A pickled ListWrapper unpickles to something equal to the plain list.
    original = data_structures.ListWrapper([1, 2])
    payload = pickle.dumps(original)
    del original
    restored = pickle.loads(payload)
    self.assertEqual([1, 2], restored)

  def testSameStructure(self):
    # `nest` sees a list and a ListWrapper of a copy as the same structure.
    plain = [1]
    nest.assert_same_structure(
        plain, data_structures.ListWrapper(copy.copy(plain)))

  def testMutateWithoutTrackableComponents(self):
    # `insert` on a list attribute without trackable elements still leaves
    # `_trackable_children()` callable, returning an empty dict.
    mod = module.Module()
    mod.l = [1, 2]
    mod.l.insert(0, 0)
    self.assertEqual(mod.l, [0, 1, 2])
    self.assertEqual(mod.l._trackable_children(), {})

  def testFunctionCaching(self):
    # A list argument and a ListWrapper argument yield the same concrete
    # function object.
    @def_function.function
    def f(list_input):
      return list_input[0] + constant_op.constant(1.)

    from_list = f.get_concrete_function([constant_op.constant(2.)])
    from_wrapper = f.get_concrete_function(
        data_structures.ListWrapper([constant_op.constant(3.)]))
    self.assertIs(from_list, from_wrapper)

  def testListWrapperBasic(self):
    # ListWrapper, unlike List, compares like the built-in list type (since it
    # is used to automatically replace lists).
    wrap = data_structures.ListWrapper
    a = autotrackable.AutoTrackable()
    b = autotrackable.AutoTrackable()

    # Equality and inequality, with the wrapper on either side.
    self.assertEqual([a, a], [a, a])
    self.assertEqual(wrap([a, a]), wrap([a, a]))
    self.assertEqual([a, a], wrap([a, a]))
    self.assertEqual(wrap([a, a]), [a, a])
    self.assertNotEqual([a, a], [b, a])
    self.assertNotEqual(wrap([a, a]), wrap([b, a]))
    self.assertNotEqual([a, a], wrap([b, a]))

    # Ordering comparisons.
    self.assertLess([a], [a, b])
    self.assertLess(wrap([a]), wrap([a, b]))
    self.assertLessEqual([a], [a, b])
    self.assertLessEqual(wrap([a]), wrap([a, b]))
    self.assertGreater([a, b], [a])
    self.assertGreater(wrap([a, b]), wrap([a]))
    self.assertGreaterEqual([a, b], [a])
    self.assertGreaterEqual(wrap([a, b]), wrap([a]))

    # Conversion, concatenation and type identity.
    self.assertEqual([a], wrap([a]))
    self.assertEqual([a], list(data_structures.List([a])))
    self.assertEqual([a, a], wrap([a]) + [a])
    self.assertEqual([a, a], [a] + wrap([a]))
    self.assertIsInstance(wrap([a]), list)

    # Adding a TensorShape on the right produces something with `as_list()`.
    self.assertEqual(
        tensor_shape.TensorShape([None, 2]).as_list(),
        (wrap([None]) + tensor_shape.TensorShape([2])).as_list())

  def testAcceptsNonTrackableContent(self):
    # Plain integers are accepted as ListWrapper contents.
    wrapped = data_structures.ListWrapper([1, 2, 3])
    self.assertEqual(wrapped, [1, 2, 3])

  def testWrapperChangesList(self):
    # Appending through the wrapper is visible in the wrapped list.
    backing = []
    wrapper = data_structures.ListWrapper(backing)
    wrapper.append(1)
    self.assertEqual([1], backing)

  def testListChangesWrapper(self):
    # Appending to the wrapped list is visible through the wrapper.
    backing = []
    wrapper = data_structures.ListWrapper(backing)
    backing.append(1)
    self.assertEqual([1], wrapper)

  def testNotHashable(self):
    # The construct-and-hash expression below raises TypeError.
    with self.assertRaises(TypeError):
      hash(data_structures.ListWrapper())  # pylint: disable=no-value-for-parameter

  def testDelItem(self):
    # Deleting by index works but makes the wrapper unsaveable.
    wrapped = data_structures.ListWrapper([1, 2, 3, [4]])
    del wrapped[0]
    self.assertEqual(wrapped, [2, 3, [4]])
    self.assertUnableToSave(wrapped, "Unable to save .*__delitem__")

  def testDelSlice(self):
    # Deleting a slice works but makes the wrapper unsaveable.
    wrapped = data_structures.ListWrapper([1, 2, 3, [4]])
    del wrapped[2:3]
    self.assertEqual(wrapped, [1, 2, [4]])
    self.assertUnableToSave(wrapped, "Unable to save .*__delslice__")

  def testSetSlice_canSaveForNonTrackableItems(self):
    # Slice assignment among non-trackable items keeps
    # `_trackable_children()` usable (and empty).
    wrapped = data_structures.ListWrapper([1, 2, 3, 4])
    wrapped[:] = (2, 8, 9, 0)
    self.assertEqual(wrapped, [2, 8, 9, 0])
    wrapped._maybe_initialize_trackable()  # pylint: disable=protected-access
    self.assertEqual(len(wrapped._trackable_children()), 0)  # pylint: disable=protected-access

  def testSetSlice_cannotSaveIfTrackableModified(self):
    # Slice assignment that replaces a variable makes the wrapper unsaveable.
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)
    wrapped = data_structures.ListWrapper([1, 2, v1, v2])
    wrapped[:] = (2, 8, 9, v2)
    self.assertEqual(wrapped, [2, 8, 9, v2])
    self.assertUnableToSave(wrapped, "Unable to save .*__setslice__")

  def testSetSlice_truncate(self):
    # Assigning an empty list to the full slice empties the wrapper.
    wrapped = data_structures.ListWrapper([1, 2, 3, 4])
    wrapped[:] = []
    self.assertEqual(wrapped, [])

  def testSetSlice_extend(self):
    # Assigning a longer sequence to a tail slice grows the wrapper.
    wrapped = data_structures.ListWrapper([1, 2, 3, 4])
    wrapped[2:] = (1, 2, 3, 4)
    self.assertEqual(wrapped, [1, 2, 1, 2, 3, 4])

  def testIMulNegative(self):
    # In-place multiplication by a negative count matches list semantics and
    # makes the wrapper unsaveable.
    wrapped = data_structures.ListWrapper([1, 2, 3, [4]])
    wrapped *= -1
    self.assertEqual(wrapped, [1, 2, 3, [4]] * -1)
    self.assertUnableToSave(wrapped, "Unable to save")

  def testIMulPositive(self):
    # In-place multiplication by a positive count keeps the wrapper trackable:
    # the repeated variable gains a second child key, and a checkpoint saved
    # before the multiplication still restores the variable's value.
    var = variables.Variable(1.)
    wrapped = data_structures.ListWrapper([1, 2, 3, 4, var])
    self.assertDictEqual({"4": var}, wrapped._trackable_children())
    root = util.Checkpoint(l=wrapped)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = root.save(prefix)
    var.assign(5.)
    wrapped *= 2
    self.assertEqual(wrapped, [1, 2, 3, 4, var, 1, 2, 3, 4, var])
    self.assertDictEqual({"4": var, "9": var}, wrapped._trackable_children())
    root.restore(save_path)
    self.assertAllClose(1., var.numpy())

  def testSort(self):
    wrapped = data_structures.ListWrapper([[1], [2], [3], [4]])
    wrapped.sort()
    self.assertAllEqual(wrapped, [[1], [2], [3], [4]])
    # Regardless of being a no-op for the input list, we still refuse to save.
    # This is intentional since otherwise we would end up with a hard to debug
    # case for users (e.g. sometimes sort on a ListWrapper is trackable and
    # other times it is not).
    self.assertUnableToSave(wrapped, "Unable to save .*sort")

  def assertUnableToSave(self, l, msg):
    # Helper: asserts that `l._trackable_children()` raises a ValueError whose
    # message matches the regex `msg`.
    l._maybe_initialize_trackable()  # pylint: disable=protected-access
    with self.assertRaisesRegex(ValueError, msg):
      l._trackable_children()  # pylint: disable=protected-access


class MappingTests(test.TestCase):
  """Exercises `Mapping`, `_DictWrapper`, and copying of wrapped containers."""

  def testJSONSerialization(self):
    # A dict attribute of an AutoTrackable must survive `json.dumps` when
    # `serialization.get_json_type` is supplied as the fallback encoder.
    root = autotrackable.AutoTrackable()
    root.d = {"a": 2}
    json.dumps(root.d, default=serialization.get_json_type)

  def testNoOverwrite(self):
    # Existing keys of a Mapping can be neither reassigned nor deleted, whether
    # through item assignment, `del`, or `update`.
    mapping = data_structures.Mapping()
    first_value = data_structures.List()
    mapping["a"] = first_value
    with self.assertRaises(ValueError):
      mapping["a"] = data_structures.List()
    self.assertIs(first_value, mapping["a"])
    with self.assertRaises(AttributeError):
      del mapping["a"]  # pylint: disable=unsupported-delete-operation
    mapping.update(b=data_structures.Mapping())
    with self.assertRaises(ValueError):
      mapping.update({"b": data_structures.Mapping()})

  def testNonStringKeys(self):
    # Integer keys are rejected with TypeError.
    mapping = data_structures.Mapping()
    with self.assertRaises(TypeError):
      mapping[1] = data_structures.List()

  def testHashing(self):
    # Two empty Mappings are distinct set members, and a third empty Mapping
    # is not found in that set.
    has_mappings = {data_structures.Mapping(), data_structures.Mapping()}
    self.assertEqual(2, len(has_mappings))
    self.assertNotIn(data_structures.Mapping(), has_mappings)
    # In contrast to Mapping, dict wrappers are not hashable
    holder = autotrackable.AutoTrackable()
    holder.d = {}
    self.assertEqual({}, holder.d)
    self.assertFalse({} != holder.d)  # pylint: disable=g-explicit-bool-comparison
    self.assertNotEqual({1: 2}, holder.d)
    with self.assertRaisesRegex(TypeError, "unhashable"):
      set([holder.d])

  def testListShallowCopy(self):
    # `copy.copy` of a wrapped list yields a new outer object that shares its
    # elements with the original.
    root = autotrackable.AutoTrackable()
    backing_list = [[1.]]
    root.a = backing_list
    shallow = copy.copy(root.a)
    self.assertAllEqual([[1.]], shallow)
    self.assertIsNot(root.a, shallow)
    self.assertIs(root.a[0], shallow[0])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    backing_list.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.copy(root.a))

  def testListDeepCopy(self):
    # `copy.deepcopy` of a wrapped list yields new outer and inner objects.
    root = autotrackable.AutoTrackable()
    backing_list = [[1.]]
    root.a = backing_list
    deep = copy.deepcopy(root.a)
    self.assertAllEqual([[1.]], deep)
    self.assertIsNot(root.a, deep)
    self.assertIsNot(root.a[0], deep[0])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    backing_list.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))

  def testDictShallowCopy(self):
    # Both `copy.copy` and the `.copy()` method of a wrapped dict yield a new
    # outer object that shares its values with the original.
    root = autotrackable.AutoTrackable()
    backing_dict = {"a": [1.]}
    root.a = backing_dict
    shallow = copy.copy(root.a)
    self.assertAllEqual([1.], shallow["a"])
    self.assertIsNot(root.a, shallow)
    self.assertIs(root.a["a"], shallow["a"])

    shallow = root.a.copy()
    self.assertAllEqual([1.], shallow["a"])
    self.assertIsNot(root.a, shallow)
    self.assertIs(root.a["a"], shallow["a"])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    backing_dict["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.copy(root.a))

  def testDictDeepCopy(self):
    # `copy.deepcopy` of a wrapped dict yields new outer and inner objects.
    root = autotrackable.AutoTrackable()
    backing_dict = {"a": [1.]}
    root.a = backing_dict
    deep = copy.deepcopy(root.a)
    self.assertAllEqual([1.], deep["a"])
    self.assertIsNot(root.a, deep)
    self.assertIsNot(root.a["a"], deep["a"])

    # Dirtiness should be inherited
    util.list_objects(root.a)
    backing_dict["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))

  def testShallowCopyTrackable(self):
    # A shallow copy of an AutoTrackable shares its sub-objects, and its
    # container attributes show up in `list_objects`.
    source = autotrackable.AutoTrackable()
    source_child = autotrackable.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": source_child}
    shallow = copy.copy(source)
    self.assertIs(source_child, shallow.b["a"])
    self.assertIsNot(source, shallow)
    self.assertEqual([[1.]], shallow.a)
    shallow_deps = util.list_objects(shallow)
    self.assertIn(shallow.a, shallow_deps)
    self.assertIn(shallow.b, shallow_deps)
    self.assertIn(shallow.b["a"], shallow_deps)

  def testDeepCopyTrackable(self):
    # A deep copy of an AutoTrackable gets its own sub-objects; the original
    # sub-object is absent from the copy's `list_objects`.
    source = autotrackable.AutoTrackable()
    source_child = autotrackable.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": source_child}
    self.assertIsInstance(source.b, dict)
    deep = copy.deepcopy(source)
    self.assertIsInstance(deep.b, dict)
    self.assertIsNot(source, deep)
    self.assertIsNot(source_child, deep.b["a"])
    self.assertEqual([[1.]], deep.a)
    self.assertIsInstance(deep.b["a"], autotrackable.AutoTrackable)
    deep_deps = util.list_objects(deep)
    self.assertIn(deep.a, deep_deps)
    self.assertIn(deep.b, deep_deps)
    self.assertIn(deep.b["a"], deep_deps)
    self.assertNotIn(source_child, deep_deps)

  def testConstructableFromSequence(self):
    # `_DictWrapper` accepts a sequence of key/value pairs, like `dict`.
    wrapped = data_structures._DictWrapper([(1, 2), (3, 4)])
    self.assertIsInstance(wrapped, dict)
    self.assertEqual({1: 2, 3: 4}, wrapped)

  def testPickle(self):
    # A pickled `_DictWrapper` unpickles to something equal to the plain dict.
    original = data_structures._DictWrapper(dict(a=1, b=2))
    payload = pickle.dumps(original)
    del original
    restored = pickle.loads(payload)
    self.assertEqual(dict(a=1, b=2), restored)

  def testListAddOrder(self):
    # Concatenation keeps left-to-right order for every wrapper/list pairing.
    wrap = data_structures.ListWrapper
    self.assertEqual([1., 2.], wrap([1.]) + wrap([2.]))
    self.assertEqual([1., 2.], wrap([1.]) + [2.])
    self.assertEqual([1., 2.], [1.] + wrap([2.]))

  def testSameStructure(self):
    # `nest` sees a dict and a `_DictWrapper` of a copy as the same structure.
    plain = {1: "a"}
    nest.assert_same_structure(plain, data_structures._DictWrapper(plain.copy()))

  def testFunctionCaching(self):
    # A dict argument and a `_DictWrapper` argument yield the same concrete
    # function object.
    @def_function.function
    def f(dict_input):
      return dict_input["x"] + constant_op.constant(1.)

    from_dict = f.get_concrete_function({"x": constant_op.constant(2.)})
    from_wrapper = f.get_concrete_function(
        data_structures._DictWrapper({"x": constant_op.constant(3.)}))
    self.assertIs(from_dict, from_wrapper)


class TupleTests(test.TestCase, parameterized.TestCase):
  """Exercises `data_structures._TupleWrapper` and tuple attributes."""

  def testJSONSerialization(self):
    # A tuple attribute of an AutoTrackable must survive `json.dumps` when
    # `serialization.get_json_type` is supplied as the fallback encoder.
    root = autotrackable.AutoTrackable()
    root.l = (1,)
    json.dumps(root.l, default=serialization.get_json_type)

  def testNonLayerVariables(self):
    # Checks the variable/weight collections of a wrapper holding one variable.
    var = resource_variable_ops.ResourceVariable([1.])
    wrapped = data_structures._TupleWrapper((var,))
    self.assertEqual([], wrapped.layers)
    self.assertEqual([var], wrapped.variables)
    self.assertEqual([var], wrapped.trainable_weights)
    self.assertEqual([], wrapped.non_trainable_variables)

  def testCopy(self):
    # Shallow copies share elements, deep copies do not, and the copy has no
    # `append`.
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)

    source = data_structures._TupleWrapper((v1, v2))
    shallow = copy.copy(source)
    self.assertEqual(source, (v1, v2))
    self.assertEqual(shallow, (v1, v2))
    self.assertIs(source[0], shallow[0])
    deep = copy.deepcopy(source)
    self.assertIsNot(source[0], deep[0])
    with self.assertRaises(AttributeError):
      shallow.append(v1)

  def testSlicing(self):
    # Slices of a wrapper compare equal to the matching built-in tuple.
    v1, v2, v3, v4 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(4)]

    wrapped = data_structures._TupleWrapper((v1, v2, v3, v4))
    self.assertEqual(wrapped[1:], (v2, v3, v4))
    self.assertEqual(wrapped[1:-1], (v2, v3))
    self.assertEqual(wrapped[:-1], (v1, v2, v3))

  def testHash(self):
    # Two empty wrappers collapse to a single set member, and a third empty
    # wrapper is found in that set.
    has_sequences = {
        data_structures._TupleWrapper(), data_structures._TupleWrapper()}
    self.assertLen(has_sequences, 1)
    self.assertIn(data_structures._TupleWrapper(), has_sequences)

  def testIMul_zero(self):
    # In-place multiplication by zero yields an empty tuple.
    wrapped = data_structures._TupleWrapper((1,))
    wrapped *= 0
    self.assertEqual((), wrapped)

  def testIMul(self):
    # Note: tuple behavior differs from list behavior. Lists are mutated by
    # imul/iadd, tuples assign a new object to the left hand side of the
    # expression.
    var = resource_variable_ops.ResourceVariable(1.)
    wrapped = data_structures._TupleWrapper((var,))
    before = wrapped
    wrapped *= 2
    self.assertEqual(wrapped, (var,) * 2)
    self.assertNotEqual(before, (var,) * 2)

  def testIAdd(self):
    # In-place addition rebinds the name; the object bound before is unchanged.
    var = resource_variable_ops.ResourceVariable(1.)
    wrapped = data_structures._TupleWrapper((var,))
    before = wrapped
    wrapped += (1,)
    self.assertEqual(wrapped, (var, 1))
    self.assertNotEqual(before, (var, 1))
    self.assertEqual(before, (var,))

  def testMul(self):
    # `wrapper * int` matches `tuple * int`.
    var = resource_variable_ops.ResourceVariable(1.)
    wrapped = data_structures._TupleWrapper((var, var, var))
    self.assertEqual(wrapped * 2, (var, var, var) * 2)

  def testRMul(self):
    # `int * wrapper` matches `tuple * int`.
    var = resource_variable_ops.ResourceVariable(1.)
    wrapped = data_structures._TupleWrapper((var, var, var))
    self.assertEqual(2 * wrapped, (var, var, var) * 2)

  def testPickle(self):
    # A pickled `_TupleWrapper` unpickles to something equal to the tuple.
    original = data_structures._TupleWrapper((1, 2))
    payload = pickle.dumps(original)
    del original
    restored = pickle.loads(payload)
    self.assertEqual((1, 2), restored)

  def testNamedTuple(self):
    # A namedtuple assigned to a Module keeps field and index access, and its
    # variable field is reachable through `_trackable_children` by field name.
    named = collections.namedtuple("Named", ("x", "y"))
    var = variables.Variable(2)
    mod = module.Module()
    mod.nt = named(x=var, y=2)
    self.assertIs(var, mod.nt.x)
    self.assertIs(var, mod.nt[0])
    self.assertIs(
        var, mod._trackable_children()["nt"]._trackable_children()["x"])
    self.assertEqual(2, mod.nt.y)

  def testNamedTupleConflictingAttributes(self):
    # A namedtuple field called `weights` still returns the field value.
    named = collections.namedtuple("Named", ("x", "weights"))
    var = variables.Variable(2)
    mod = module.Module()
    mod.nt = named(x=var, weights=3)
    self.assertEqual(3, mod.nt.weights)

  def testNamedSubclassing(self):
    # A namedtuple subclass with a custom `__new__` and a property keeps both
    # when assigned to a Module; children are keyed by field name and index.
    named = collections.namedtuple("Named", ("x", "y"))
    var = variables.Variable(2)

    class NamedSubclass(named):

      def __new__(cls, x, y):
        del y  # unused
        return super().__new__(cls, x, 3)

      @property
      def summed(self):
        return self.x + self.y

    mod = module.Module()
    mod.nt = NamedSubclass(x=var, y=2)
    self.assertEqual(3, mod.nt.y)
    self.assertIs(var, mod.nt.x)
    self.assertIn(
        var, mod._trackable_children()["nt"]._trackable_children().values())
    self.assertIn("x", mod.nt._trackable_children())
    self.assertIn("0", mod.nt._trackable_children())
    self.assertEqual(5, self.evaluate(mod.nt.summed))

  def testUnnamedSubclassing(self):
    # A plain tuple subclass with a property keeps the property when assigned
    # to a Module, tracks only its variable element, and round-trips `nest`.
    var = variables.Variable(2)

    class UnnamedSubclass(tuple):

      @property
      def summed(self):
        return self[0] + self[1]

    mod = module.Module()
    mod.unt = UnnamedSubclass([var, 2])
    self.assertIn("0", mod.unt._trackable_children())
    self.assertLen(mod.unt._trackable_children(), 1)
    self.assertEqual(4, self.evaluate(mod.unt.summed))
    nest.assert_same_structure(
        [mod.unt], nest.map_structure(lambda x: x, [mod.unt]))

  def testNamedtupleSubclassWithCustomNew(self):
    # Saving raises ValueError after a variable is appended to the list field
    # of a namedtuple subclass whose `__new__` takes no arguments.
    class SubclassWithDifferentArgs(collections.namedtuple("A", ["x"])):

      def __new__(cls):
        return super().__new__(cls, [])

    mod = module.Module()
    mod.nt = SubclassWithDifferentArgs()
    mod.nt.x.append(variables.Variable(1.))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    ckpt = util.Checkpoint(m=mod)
    with self.assertRaises(ValueError):
      ckpt.save(prefix)

  def testSameStructure(self):
    # `nest` treats a tuple and its Module-attribute counterpart as the same
    # structure in either argument order, but not a namedtuple vs. a tuple.
    plain_tuple = (variables.Variable(1.),)
    mod = module.Module()
    mod.t = plain_tuple
    nest.assert_same_structure(plain_tuple, mod.t)
    nest.assert_same_structure(mod.t, plain_tuple)

    nt_type = collections.namedtuple("nt", ["x", "y"])
    plain_nt = nt_type(x=1, y=2)
    mod.nt = plain_nt
    nest.assert_same_structure(mod.nt, plain_nt)
    with self.assertRaises(TypeError):  # pylint: disable=g-error-prone-assert-raises
      nest.assert_same_structure(mod.nt, mod.t)

  def testFlatten(self):
    # `nest` flatten/pack utilities treat wrappers like the tuples they wrap,
    # including namedtuples.
    wrap = data_structures._TupleWrapper
    nested = wrap((1, wrap((2,))))
    self.assertEqual([1, 2], nest.flatten(nested))
    self.assertEqual(
        nest.flatten_with_tuple_paths((1, (2,))),
        nest.flatten_with_tuple_paths(nested))
    self.assertEqual((3, (4,)), nest.pack_sequence_as(nested, [3, 4]))

    nt_type = collections.namedtuple("nt", ["x", "y"])
    plain_nt = nt_type(1., 2.)
    wrapped_nt = wrap(plain_nt)
    self.assertEqual(
        nest.flatten_with_tuple_paths(plain_nt),
        nest.flatten_with_tuple_paths(wrapped_nt))
    self.assertEqual((3, 4,), nest.pack_sequence_as(wrapped_nt, [3, 4]))
    self.assertEqual(3, nest.pack_sequence_as(wrapped_nt, [3, 4]).x)

  def testFunctionCaching(self):
    # A tuple argument and a `_TupleWrapper` argument yield the same concrete
    # function object.
    @def_function.function
    def f(tuple_input):
      return tuple_input[0] + constant_op.constant(1.)

    from_tuple = f.get_concrete_function((constant_op.constant(2.),))
    from_wrapper = f.get_concrete_function(
        data_structures._TupleWrapper((constant_op.constant(3.),)))
    self.assertIs(from_tuple, from_wrapper)

  def testPythonMapImpl(self):
    # `map_structure_up_to` with type checking and `assert_shallow_structure`
    # accept nested wrappers against plain-tuple shallow structures.
    wrap = data_structures._TupleWrapper
    nested = wrap((1, wrap((2,))))
    self.assertEqual(
        (4, (5,)),
        nest.map_structure_up_to(
            (None, (None,)), lambda x: x + 3, nested, check_types=True))
    nest.assert_shallow_structure((None, None), nested)

  def testDatasetMap(self):
    # A Dataset `map` function may return a `_TupleWrapper`.
    dataset = dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([1, 2, 3]))
    dataset = dataset.map(lambda x: data_structures._TupleWrapper((x,)))
    for position, element in enumerate(dataset):
      self.assertEqual((position + 1,), self.evaluate(element))

  def testDatasetMapNamed(self):
    # A Dataset `map` function may return a wrapped namedtuple.
    nt_type = collections.namedtuple("A", ["x"])
    dataset = dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([1, 2, 3]))
    dataset = dataset.map(lambda x: data_structures._TupleWrapper(nt_type(x)))
    for position, element in enumerate(dataset):
      self.assertEqual((position + 1,), self.evaluate(element))

  def testLoopAssignedModule(self):
    # A Module stored inside its own tuple attribute has exactly one trackable
    # child ("s") and an empty `trainable_variables`.
    mod = module.Module()
    mod.s = (mod,)
    self.assertLen(mod._trackable_children(), 1)
    self.assertIn("s", mod._trackable_children())
    self.assertIs(mod.s, mod._trackable_children()["s"])
    self.assertEqual((), mod.trainable_variables)


if __name__ == "__main__":
  test.main()