• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import collections
16import copy
17import json
18import os
19import pickle
20
21from absl.testing import parameterized
22from tensorflow.python.checkpoint import checkpoint as util
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.eager import context
25from tensorflow.python.eager import def_function
26from tensorflow.python.eager import test
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.layers import core as non_keras_core
30from tensorflow.python.module import module
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.trackable import autotrackable
35from tensorflow.python.trackable import data_structures
36from tensorflow.python.util import nest
37from tensorflow.python.util import serialization
38
39
class ListTests(test.TestCase):
  """Behavioral tests for the explicit trackable `data_structures.List`."""

  def testJSONSerialization(self):
    # Whatever AutoTrackable stores for an assigned list must stay
    # serializable through `serialization.get_json_type`.
    holder = autotrackable.AutoTrackable()
    holder.l = [1]
    json.dumps(holder.l, default=serialization.get_json_type)

  def testNotTrackable(self):

    class Opaque(object):
      pass

    # Non-trackable elements are rejected when the List is constructed.
    with self.assertRaises(ValueError):
      data_structures.List([Opaque()])

  def testCallNotImplemented(self):
    with self.assertRaisesRegex(TypeError, "not callable"):
      empty = data_structures.List()
      empty(1.)  # pylint: disable=not-callable

  def testNoPop(self):
    with self.assertRaises(AttributeError):
      empty = data_structures.List()
      empty.pop()

  def testNesting(self):
    with context.graph_mode():
      child = data_structures.List()
      parent = data_structures.List([child])
      # The layer is added to the child only after the child is already
      # stored in the parent; its variables must still show up there.
      child.append(non_keras_core.Dense(1))
      child[0](array_ops.ones([2, 3]))
      self.assertEqual(2, len(parent.variables))
      self.assertIsInstance(parent.variables[0],
                            resource_variable_ops.ResourceVariable)

  def testNonLayerVariables(self):
    trainable_var = resource_variable_ops.ResourceVariable([1.])
    wrapped = data_structures.List([trainable_var])
    self.assertTrue(wrapped.trainable)
    self.assertEqual([], wrapped.layers)
    self.assertEqual([trainable_var], wrapped.variables)
    self.assertEqual([trainable_var], wrapped.trainable_weights)
    self.assertEqual([], wrapped.non_trainable_variables)

    # Marking the container non-trainable moves its variable to the
    # non-trainable side without dropping it from `variables`.
    wrapped.trainable = False
    self.assertEqual([trainable_var], wrapped.variables)
    self.assertEqual([], wrapped.trainable_variables)
    self.assertEqual([trainable_var], wrapped.non_trainable_variables)

    # A variable created with trainable=False stays non-trainable even while
    # the container itself is trainable.
    wrapped.trainable = True
    frozen_var = resource_variable_ops.ResourceVariable(1., trainable=False)
    wrapped.append(frozen_var)
    self.assertEqual([trainable_var, frozen_var], wrapped.weights)
    self.assertEqual([trainable_var], wrapped.trainable_weights)
    self.assertEqual([frozen_var], wrapped.non_trainable_weights)

  def testCopy(self):
    first, second, third = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(3)]

    source = data_structures.List([first, second])
    duplicate = source.copy()
    duplicate.append(third)
    # Appending to the copy must not leak back into the source.
    self.assertEqual(list(source), [first, second])
    self.assertEqual(list(duplicate), [first, second, third])

  def testSlicing(self):
    v1, v2, v3, v4 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(4)]

    seq = data_structures.List([v1, v2, v3, v4])
    # Equivalent to seq[1:], seq[1:-1] and seq[:-1] respectively.
    expectations = (
        (slice(1, None), [v2, v3, v4]),
        (slice(1, -1), [v2, v3]),
        (slice(None, -1), [v1, v2, v3]),
    )
    for key, expected in expectations:
      self.assertEqual(seq[key], expected)

  def testHash(self):
    # Distinct empty Lists are treated as different set members.
    seen = set()
    seen.add(data_structures.List())
    seen.add(data_structures.List())
    self.assertEqual(2, len(seen))
    self.assertNotIn(data_structures.List(), seen)

  def testIMul_zero(self):
    empty = data_structures.List([])
    with self.assertRaisesRegex(ValueError, "List only supports append"):
      empty *= 0

  def testIMul(self):
    var = resource_variable_ops.ResourceVariable(1.)
    doubled = data_structures.List([var])
    doubled *= 2
    self.assertEqual(list(doubled), [var, var])

  def testMul(self):
    var = resource_variable_ops.ResourceVariable(1.)
    triple = data_structures.List([var, var, var])
    self.assertEqual(list(triple * 2), [var] * 6)

  def testRMul(self):
    var = resource_variable_ops.ResourceVariable(1.)
    triple = data_structures.List([var, var, var])
    self.assertEqual(list(2 * triple), [var] * 6)
139
140
class ListWrapperTest(test.TestCase):
  """Tests for `data_structures.ListWrapper`, the list-compatible wrapper."""

  # Names from `dir(list)` that the override check below skips.
  IGNORED = ("__new__", "__init__", "__subclasshook__", "__getattribute__")

  def test_overrides_all_list_methods(self):
    not_overridden = []

    for name in dir(list):
      if name in ListWrapperTest.IGNORED:
        continue

      list_method = getattr(list, name)

      if not callable(list_method):
        continue

      object_method = getattr(object, name, None)
      if object_method is not None and object_method == list_method:
        # Skip methods that aren't overridden from object.
        continue

      if list_method == getattr(data_structures.ListWrapper, name):
        not_overridden.append(name)

    if not_overridden:
      self.fail("ListWrapper does not override %s" % (not_overridden))

  def testPickle(self):
    original = data_structures.ListWrapper([1, 2])
    serialized = pickle.dumps(original)
    del original
    deserialized = pickle.loads(serialized)
    self.assertEqual([1, 2], deserialized)

  def testSameStructure(self):
    l = [1]
    nest.assert_same_structure(l, data_structures.ListWrapper(copy.copy(l)))

  def testMutateWithoutTrackableComponents(self):
    m = module.Module()
    m.l = [1, 2]
    m.l.insert(0, 0)
    self.assertEqual(m.l, [0, 1, 2])
    self.assertEqual(m.l._trackable_children(), {})

  def testFunctionCaching(self):
    @def_function.function
    def f(list_input):
      return list_input[0] + constant_op.constant(1.)

    first_trace = f.get_concrete_function([constant_op.constant(2.)])
    second_trace = f.get_concrete_function(
        data_structures.ListWrapper([constant_op.constant(3.)]))
    # A plain list and a ListWrapper with the same structure share a trace.
    self.assertIs(first_trace, second_trace)

  def testListWrapperBasic(self):
    # ListWrapper, unlike List, compares like the built-in list type (since it
    # is used to automatically replace lists).
    a = autotrackable.AutoTrackable()
    b = autotrackable.AutoTrackable()
    self.assertEqual([a, a],
                     [a, a])
    self.assertEqual(data_structures.ListWrapper([a, a]),
                     data_structures.ListWrapper([a, a]))
    self.assertEqual([a, a],
                     data_structures.ListWrapper([a, a]))
    self.assertEqual(data_structures.ListWrapper([a, a]),
                     [a, a])
    self.assertNotEqual([a, a],
                        [b, a])
    self.assertNotEqual(data_structures.ListWrapper([a, a]),
                        data_structures.ListWrapper([b, a]))
    self.assertNotEqual([a, a],
                        data_structures.ListWrapper([b, a]))
    self.assertLess([a], [a, b])
    self.assertLess(data_structures.ListWrapper([a]),
                    data_structures.ListWrapper([a, b]))
    self.assertLessEqual([a], [a, b])
    self.assertLessEqual(data_structures.ListWrapper([a]),
                         data_structures.ListWrapper([a, b]))
    self.assertGreater([a, b], [a])
    self.assertGreater(data_structures.ListWrapper([a, b]),
                       data_structures.ListWrapper([a]))
    self.assertGreaterEqual([a, b], [a])
    self.assertGreaterEqual(data_structures.ListWrapper([a, b]),
                            data_structures.ListWrapper([a]))
    self.assertEqual([a], data_structures.ListWrapper([a]))
    self.assertEqual([a], list(data_structures.List([a])))
    self.assertEqual([a, a], data_structures.ListWrapper([a]) + [a])
    self.assertEqual([a, a], [a] + data_structures.ListWrapper([a]))
    self.assertIsInstance(data_structures.ListWrapper([a]), list)
    self.assertEqual(
        tensor_shape.TensorShape([None, 2]).as_list(),
        (data_structures.ListWrapper([None])
         + tensor_shape.TensorShape([2])).as_list())

  def testAcceptsNonTrackableContent(self):
    l = data_structures.ListWrapper([1, 2, 3])
    self.assertEqual(l, [1, 2, 3])

  def testWrapperChangesList(self):
    l = []
    l_wrapper = data_structures.ListWrapper(l)
    l_wrapper.append(1)
    self.assertEqual([1], l)

  def testListChangesWrapper(self):
    l = []
    l_wrapper = data_structures.ListWrapper(l)
    l.append(1)
    self.assertEqual([1], l_wrapper)

  def testNotHashable(self):
    # Build the wrapper outside the assertion so that only `hash` can satisfy
    # it. The previous `ListWrapper()` call was flagged by pylint as missing a
    # parameter, so its own TypeError could make this check pass vacuously.
    wrapper = data_structures.ListWrapper([])
    with self.assertRaises(TypeError):
      hash(wrapper)

  def testDelItem(self):
    l = data_structures.ListWrapper([1, 2, 3, [4]])
    del l[0]
    self.assertEqual(l, [2, 3, [4]])
    self.assertUnableToSave(l, "Unable to save .*__delitem__")

  def testDelSlice(self):
    l = data_structures.ListWrapper([1, 2, 3, [4]])
    del l[2:3]
    self.assertEqual(l, [1, 2, [4]])
    self.assertUnableToSave(l, "Unable to save .*__delslice__")

  def testSetSlice_canSaveForNonTrackableItems(self):
    l = data_structures.ListWrapper([1, 2, 3, 4])
    l[:] = 2, 8, 9, 0
    self.assertEqual(l, [2, 8, 9, 0])
    l._maybe_initialize_trackable()  # pylint: disable=protected-access
    self.assertEqual(len(l._trackable_children()), 0)  # pylint: disable=protected-access

  def testSetSlice_cannotSaveIfTrackableModified(self):
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)
    l = data_structures.ListWrapper([1, 2, v1, v2])
    l[:] = 2, 8, 9, v2
    self.assertEqual(l, [2, 8, 9, v2])
    self.assertUnableToSave(l, "Unable to save .*__setslice__")

  def testSetSlice_truncate(self):
    l = data_structures.ListWrapper([1, 2, 3, 4])
    l[:] = []
    self.assertEqual(l, [])

  def testSetSlice_extend(self):
    l = data_structures.ListWrapper([1, 2, 3, 4])
    l[2:] = 1, 2, 3, 4
    self.assertEqual(l, [1, 2, 1, 2, 3, 4])

  def testIMulNegative(self):
    l = data_structures.ListWrapper([1, 2, 3, [4]])
    l *= -1
    self.assertEqual(l, [1, 2, 3, [4]] * -1)
    self.assertUnableToSave(l, "Unable to save")

  def testIMulPositive(self):
    v = variables.Variable(1.)
    l = data_structures.ListWrapper([1, 2, 3, 4, v])
    self.assertDictEqual({"4": v}, l._trackable_children())
    root = util.Checkpoint(l=l)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    path = root.save(prefix)
    v.assign(5.)
    l *= 2
    self.assertEqual(l, [1, 2, 3, 4, v, 1, 2, 3, 4, v])
    self.assertDictEqual({"4": v, "9": v}, l._trackable_children())
    root.restore(path)
    self.assertAllClose(1., v.numpy())

  def testSort(self):
    l = data_structures.ListWrapper([[1], [2], [3], [4]])
    l.sort()
    self.assertAllEqual(l, [[1], [2], [3], [4]])
    # Regardless of being a no-op for the input list, we still refuse to save.
    # This is intentional since otherwise we would end up with a hard to debug
    # case for users (e.g. sometimes sort on a ListWrapper is trackable and
    # other times it is not).
    self.assertUnableToSave(l, "Unable to save .*sort")

  def assertUnableToSave(self, l, msg):
    """Asserts that collecting `l`'s trackable children raises ValueError.

    Args:
      l: The ListWrapper under test.
      msg: Regular expression the ValueError message must match.
    """
    l._maybe_initialize_trackable()  # pylint: disable=protected-access
    with self.assertRaisesRegex(ValueError, msg):
      l._trackable_children()  # pylint: disable=protected-access
328
329
class MappingTests(test.TestCase):
  """Tests for trackable mappings, dict wrappers, and wrapper copy semantics."""

  def testJSONSerialization(self):
    holder = autotrackable.AutoTrackable()
    holder.d = {"a": 2}
    json.dumps(holder.d, default=serialization.get_json_type)

  def testNoOverwrite(self):
    trackable_map = data_structures.Mapping()
    first_value = data_structures.List()
    trackable_map["a"] = first_value
    # Re-assigning an existing key fails and leaves the first value in place.
    with self.assertRaises(ValueError):
      trackable_map["a"] = data_structures.List()
    self.assertIs(first_value, trackable_map["a"])
    # Keys cannot be deleted.
    with self.assertRaises(AttributeError):
      del trackable_map["a"]  # pylint: disable=unsupported-delete-operation
    # `update` follows the same rule: a new key is accepted, a repeat is not.
    trackable_map.update(b=data_structures.Mapping())
    with self.assertRaises(ValueError):
      trackable_map.update({"b": data_structures.Mapping()})

  def testNonStringKeys(self):
    trackable_map = data_structures.Mapping()
    with self.assertRaises(TypeError):
      trackable_map[1] = data_structures.List()

  def testHashing(self):
    mappings = {data_structures.Mapping(), data_structures.Mapping()}
    self.assertEqual(2, len(mappings))
    self.assertNotIn(data_structures.Mapping(), mappings)
    # Dict wrappers, unlike Mapping, compare like dicts and are not hashable.
    holder = autotrackable.AutoTrackable()
    holder.d = {}
    self.assertEqual({}, holder.d)
    self.assertFalse({} != holder.d)  # pylint: disable=g-explicit-bool-comparison
    self.assertNotEqual({1: 2}, holder.d)
    with self.assertRaisesRegex(TypeError, "unhashable"):
      set([holder.d])

  def testListShallowCopy(self):
    root = autotrackable.AutoTrackable()
    backing = [[1.]]
    root.a = backing
    clone = copy.copy(root.a)
    self.assertAllEqual([[1.]], clone)
    self.assertIsNot(root.a, clone)
    self.assertIs(root.a[0], clone[0])

    # Dirtiness should be inherited: once the backing list is mutated
    # directly, listing the wrapper or a fresh copy of it raises ValueError.
    util.list_objects(root.a)
    backing.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.copy(root.a))

  def testListDeepCopy(self):
    root = autotrackable.AutoTrackable()
    backing = [[1.]]
    root.a = backing
    clone = copy.deepcopy(root.a)
    self.assertAllEqual([[1.]], clone)
    self.assertIsNot(root.a, clone)
    self.assertIsNot(root.a[0], clone[0])

    # Dirtiness should be inherited by deep copies as well.
    util.list_objects(root.a)
    backing.append(1.)
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))

  def testDictShallowCopy(self):
    root = autotrackable.AutoTrackable()
    backing = {"a": [1.]}
    root.a = backing
    # Both `copy.copy` and the dict-style `.copy()` method must produce a
    # distinct wrapper whose values are shared with the original.
    for make_copy in (copy.copy, lambda wrapped: wrapped.copy()):
      clone = make_copy(root.a)
      self.assertAllEqual([1.], clone["a"])
      self.assertIsNot(root.a, clone)
      self.assertIs(root.a["a"], clone["a"])

    # Dirtiness should be inherited.
    util.list_objects(root.a)
    backing["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.copy(root.a))

  def testDictDeepCopy(self):
    root = autotrackable.AutoTrackable()
    backing = {"a": [1.]}
    root.a = backing
    clone = copy.deepcopy(root.a)
    self.assertAllEqual([1.], clone["a"])
    self.assertIsNot(root.a, clone)
    self.assertIsNot(root.a["a"], clone["a"])

    # Dirtiness should be inherited by deep copies as well.
    util.list_objects(root.a)
    backing["b"] = []
    with self.assertRaises(ValueError):
      util.list_objects(root.a)
    with self.assertRaises(ValueError):
      util.list_objects(copy.deepcopy(root.a))

  def testShallowCopyTrackable(self):
    source = autotrackable.AutoTrackable()
    child = autotrackable.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": child}
    clone = copy.copy(source)
    self.assertIs(child, clone.b["a"])
    self.assertIsNot(source, clone)
    self.assertEqual([[1.]], clone.a)
    dependencies = util.list_objects(clone)
    for tracked in (clone.a, clone.b, clone.b["a"]):
      self.assertIn(tracked, dependencies)

  def testDeepCopyTrackable(self):
    source = autotrackable.AutoTrackable()
    child = autotrackable.AutoTrackable()
    source.a = [[1.]]
    source.b = {"a": child}
    self.assertIsInstance(source.b, dict)
    clone = copy.deepcopy(source)
    self.assertIsInstance(clone.b, dict)
    self.assertIsNot(source, clone)
    self.assertIsNot(child, clone.b["a"])
    self.assertEqual([[1.]], clone.a)
    self.assertIsInstance(clone.b["a"], autotrackable.AutoTrackable)
    dependencies = util.list_objects(clone)
    self.assertIn(clone.a, dependencies)
    self.assertIn(clone.b, dependencies)
    self.assertIn(clone.b["a"], dependencies)
    # The deep copy's object graph must not reach the original child.
    self.assertNotIn(child, dependencies)

  def testConstructableFromSequence(self):
    wrapped = data_structures._DictWrapper([(1, 2), (3, 4)])
    self.assertIsInstance(wrapped, dict)
    self.assertEqual({1: 2, 3: 4}, wrapped)

  def testPickle(self):
    source = data_structures._DictWrapper(dict(a=1, b=2))
    payload = pickle.dumps(source)
    del source
    restored = pickle.loads(payload)
    self.assertEqual(dict(a=1, b=2), restored)

  def testListAddOrder(self):
    expected = [1., 2.]
    self.assertEqual(
        expected,
        data_structures.ListWrapper([1.]) + data_structures.ListWrapper([2.]))
    self.assertEqual(expected, data_structures.ListWrapper([1.]) + [2.])
    self.assertEqual(expected, [1.] + data_structures.ListWrapper([2.]))

  def testSameStructure(self):
    plain = {1: "a"}
    wrapped = data_structures._DictWrapper(plain.copy())
    nest.assert_same_structure(plain, wrapped)

  def testFunctionCaching(self):

    @def_function.function
    def add_one(dict_input):
      return dict_input["x"] + constant_op.constant(1.)

    plain_trace = add_one.get_concrete_function(
        {"x": constant_op.constant(2.)})
    wrapped_trace = add_one.get_concrete_function(
        data_structures._DictWrapper({"x": constant_op.constant(3.)}))
    # A plain dict and a _DictWrapper with the same structure share a trace.
    self.assertIs(plain_trace, wrapped_trace)
510
511
class TupleTests(test.TestCase, parameterized.TestCase):
  """Tests for `data_structures._TupleWrapper` and tuple-valued attributes."""

  def testJSONSerialization(self):
    obj = autotrackable.AutoTrackable()
    obj.l = (1,)
    json.dumps(obj.l, default=serialization.get_json_type)

  def testNonLayerVariables(self):
    v = resource_variable_ops.ResourceVariable([1.])
    l = data_structures._TupleWrapper((v,))
    self.assertEqual([], l.layers)
    self.assertEqual([v], l.variables)
    self.assertEqual([v], l.trainable_weights)
    self.assertEqual([], l.non_trainable_variables)

  def testCopy(self):
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)

    l1 = data_structures._TupleWrapper((v1, v2))
    l2 = copy.copy(l1)
    self.assertEqual(l1, (v1, v2))
    self.assertEqual(l2, (v1, v2))
    # A shallow copy shares elements; a deep copy duplicates them.
    self.assertIs(l1[0], l2[0])
    l2_deep = copy.deepcopy(l1)
    self.assertIsNot(l1[0], l2_deep[0])
    # Tuple wrappers expose no `append`.
    with self.assertRaises(AttributeError):
      l2.append(v1)

  def testSlicing(self):
    v1 = resource_variable_ops.ResourceVariable(1.)
    v2 = resource_variable_ops.ResourceVariable(1.)
    v3 = resource_variable_ops.ResourceVariable(1.)
    v4 = resource_variable_ops.ResourceVariable(1.)

    l = data_structures._TupleWrapper((v1, v2, v3, v4))
    self.assertEqual(l[1:], (v2, v3, v4))
    self.assertEqual(l[1:-1], (v2, v3))
    self.assertEqual(l[:-1], (v1, v2, v3))

  def testHash(self):
    # Two empty tuple wrappers collapse into a single set entry, and a third
    # one is found in the set.
    has_sequences = {data_structures._TupleWrapper(),
                     data_structures._TupleWrapper()}
    self.assertLen(has_sequences, 1)
    self.assertIn(data_structures._TupleWrapper(), has_sequences)

  def testIMul_zero(self):
    l = data_structures._TupleWrapper((1,))
    l *= 0
    self.assertEqual((), l)

  def testIMul(self):
    # Note: tuple behavior differs from list behavior. Lists are mutated by
    # imul/iadd, tuples assign a new object to the left hand side of the
    # expression.
    v = resource_variable_ops.ResourceVariable(1.)
    l = data_structures._TupleWrapper((v,))
    original = l
    l *= 2
    self.assertEqual(l, (v,) * 2)
    self.assertNotEqual(original, (v,) * 2)
    # As in testIAdd, the original wrapper must be left unchanged.
    self.assertEqual(original, (v,))

  def testIAdd(self):
    v = resource_variable_ops.ResourceVariable(1.)
    l = data_structures._TupleWrapper((v,))
    original = l
    l += (1,)
    self.assertEqual(l, (v, 1))
    self.assertNotEqual(original, (v, 1))
    self.assertEqual(original, (v,))

  def testMul(self):
    v = resource_variable_ops.ResourceVariable(1.)
    l = data_structures._TupleWrapper((v, v, v))
    self.assertEqual(l * 2, (v, v, v) * 2)

  def testRMul(self):
    v = resource_variable_ops.ResourceVariable(1.)
    l = data_structures._TupleWrapper((v, v, v))
    self.assertEqual(2 * l, (v, v, v) * 2)

  def testPickle(self):
    original = data_structures._TupleWrapper((1, 2))
    serialized = pickle.dumps(original)
    del original
    deserialized = pickle.loads(serialized)
    self.assertEqual((1, 2), deserialized)

  def testNamedTuple(self):
    named = collections.namedtuple("Named", ("x", "y"))
    v = variables.Variable(2)
    nt = named(x=v, y=2)
    m = module.Module()
    m.nt = nt
    self.assertIs(v, m.nt.x)
    self.assertIs(v, m.nt[0])
    self.assertIs(
        v, m._trackable_children()["nt"]._trackable_children()["x"])
    self.assertEqual(2, m.nt.y)

  def testNamedTupleConflictingAttributes(self):
    # A field named like a wrapper property ("weights") must still resolve to
    # the namedtuple field.
    named = collections.namedtuple("Named", ("x", "weights"))
    v = variables.Variable(2)
    nt = named(x=v, weights=3)
    m = module.Module()
    m.nt = nt
    self.assertEqual(3, m.nt.weights)

  def testNamedSubclassing(self):
    named = collections.namedtuple("Named", ("x", "y"))
    v = variables.Variable(2)

    class NamedSubclass(named):

      def __new__(cls, x, y):
        del y  # unused
        return super(NamedSubclass, cls).__new__(cls, x, 3)

      @property
      def summed(self):
        return self.x + self.y

    nt = NamedSubclass(x=v, y=2)
    m = module.Module()
    m.nt = nt
    self.assertEqual(3, m.nt.y)
    self.assertIs(v, m.nt.x)
    self.assertIn(v,
                  m._trackable_children()["nt"]._trackable_children().values())
    self.assertIn("x", m.nt._trackable_children())
    self.assertIn("0", m.nt._trackable_children())
    self.assertEqual(5, self.evaluate(m.nt.summed))

  def testUnnamedSubclassing(self):
    v = variables.Variable(2)

    class UnnamedSubclass(tuple):

      @property
      def summed(self):
        return self[0] + self[1]

    unt = UnnamedSubclass([v, 2])
    m = module.Module()
    m.unt = unt
    self.assertIn("0", m.unt._trackable_children())
    self.assertLen(m.unt._trackable_children(), 1)
    self.assertEqual(4, self.evaluate(m.unt.summed))
    nest.assert_same_structure(
        [m.unt], nest.map_structure(lambda x: x, [m.unt]))

  def testNamedtupleSubclassWithCustomNew(self):
    class SubclassWithDifferentArgs(collections.namedtuple("A", ["x"])):

      def __new__(cls):
        return super(SubclassWithDifferentArgs, cls).__new__(cls, [])

    nt = SubclassWithDifferentArgs()
    m = module.Module()
    m.nt = nt
    m.nt.x.append(variables.Variable(1.))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    ckpt = util.Checkpoint(m=m)
    with self.assertRaises(ValueError):
      ckpt.save(prefix)

  def testSameStructure(self):
    t = (variables.Variable(1.),)
    m = module.Module()
    m.t = t
    nest.assert_same_structure(t, m.t)
    nest.assert_same_structure(m.t, t)

    nt_type = collections.namedtuple("nt", ["x", "y"])
    nt = nt_type(x=1, y=2)
    m.nt = nt
    nest.assert_same_structure(m.nt, nt)
    with self.assertRaises(TypeError):  # pylint: disable=g-error-prone-assert-raises
      nest.assert_same_structure(m.nt, m.t)

  def testFlatten(self):
    t = data_structures._TupleWrapper((1, data_structures._TupleWrapper((2,))))
    self.assertEqual([1, 2], nest.flatten(t))
    self.assertEqual(
        nest.flatten_with_tuple_paths((1, (2,))),
        nest.flatten_with_tuple_paths(t))
    self.assertEqual((3, (4,)),
                     nest.pack_sequence_as(t, [3, 4]))
    nt_type = collections.namedtuple("nt", ["x", "y"])
    nt = nt_type(1., 2.)
    wrapped_nt = data_structures._TupleWrapper(nt)
    self.assertEqual(
        nest.flatten_with_tuple_paths(nt),
        nest.flatten_with_tuple_paths(wrapped_nt))
    self.assertEqual((3, 4),
                     nest.pack_sequence_as(wrapped_nt, [3, 4]))
    # Packing against a wrapped namedtuple keeps the field accessors.
    self.assertEqual(3, nest.pack_sequence_as(wrapped_nt, [3, 4]).x)

  def testFunctionCaching(self):
    @def_function.function
    def f(tuple_input):
      return tuple_input[0] + constant_op.constant(1.)

    first_trace = f.get_concrete_function((constant_op.constant(2.),))
    second_trace = f.get_concrete_function(
        data_structures._TupleWrapper((constant_op.constant(3.),)))
    # A plain tuple and a _TupleWrapper with the same structure share a trace.
    self.assertIs(first_trace, second_trace)

  def testPythonMapImpl(self):
    t = data_structures._TupleWrapper((1, data_structures._TupleWrapper((2,))))
    self.assertEqual(
        (4, (5,)),
        nest.map_structure_up_to((None, (None,)), lambda x: x + 3, t,
                                 check_types=True))
    nest.assert_shallow_structure((None, None), t)

  def testDatasetMap(self):
    dataset = dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([1, 2, 3]))
    dataset = dataset.map(lambda x: data_structures._TupleWrapper((x,)))
    for index, element in enumerate(dataset):
      self.assertEqual((index + 1,), self.evaluate(element))

  def testDatasetMapNamed(self):
    nt_type = collections.namedtuple("A", ["x"])
    dataset = dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([1, 2, 3]))
    dataset = dataset.map(lambda x: data_structures._TupleWrapper(nt_type(x)))
    for index, element in enumerate(dataset):
      self.assertEqual((index + 1,), self.evaluate(element))

  def testLoopAssignedModule(self):
    # A module stored in a tuple on itself is tracked once, without recursing
    # into a duplicate set of variables.
    m = module.Module()
    m.s = (m,)
    self.assertLen(m._trackable_children(), 1)
    self.assertIn("s", m._trackable_children())
    self.assertIs(m.s, m._trackable_children()["s"])
    self.assertEqual((), m.trainable_variables)
750
751
if __name__ == "__main__":
  # Script entry point: hand control to the `test.main` runner imported from
  # tensorflow.python.eager.
  test.main()
754