• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15import os
16
17import numpy as np
18from tensorflow.python.checkpoint import checkpoint as util
19from tensorflow.python.framework import test_util
20from tensorflow.python.ops import array_ops
21from tensorflow.python.platform import test
22from tensorflow.python.trackable import autotrackable
23from tensorflow.python.trackable import data_structures
24from tensorflow.python.util import nest
25
26
class InterfaceTests(test.TestCase):
  """Tests for how AutoTrackable records attribute and list dependencies."""

  def _checkpoint_prefix(self):
    """Returns the checkpoint file prefix in this test's temp directory."""
    return os.path.join(self.get_temp_dir(), "ckpt")

  def testMultipleAssignment(self):
    """_track_trackable rejects a reused name unless overwrite=True."""
    root = autotrackable.AutoTrackable()
    root.leaf = autotrackable.AutoTrackable()
    # Re-assigning the same object to the same attribute raises no error.
    root.leaf = root.leaf
    duplicate_name_dep = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "already declared"):
      root._track_trackable(duplicate_name_dep, name="leaf")
    # No error; we're overriding __setattr__, so we can't really stop people
    # from doing this while maintaining backward compatibility.
    root.leaf = duplicate_name_dep
    root._track_trackable(duplicate_name_dep, name="leaf", overwrite=True)
    self.assertIs(duplicate_name_dep, root._lookup_dependency("leaf"))
    self.assertIs(duplicate_name_dep, root._trackable_children()["leaf"])

  def testRemoveDependency(self):
    """Deleting a tracked attribute drops its dependency; re-adding restores it."""
    root = autotrackable.AutoTrackable()
    root.a = autotrackable.AutoTrackable()
    self.assertEqual(1, len(root._trackable_children()))
    self.assertEqual(1, len(root._unconditional_checkpoint_dependencies))
    self.assertIs(root.a, root._trackable_children()["a"])
    del root.a
    self.assertFalse(hasattr(root, "a"))
    self.assertEqual(0, len(root._trackable_children()))
    self.assertEqual(0, len(root._unconditional_checkpoint_dependencies))
    root.a = autotrackable.AutoTrackable()
    self.assertEqual(1, len(root._trackable_children()))
    self.assertEqual(1, len(root._unconditional_checkpoint_dependencies))
    self.assertIs(root.a, root._trackable_children()["a"])

  def testListBasic(self):
    """Objects in an attribute list, including later appends, are tracked."""
    a = autotrackable.AutoTrackable()
    b = autotrackable.AutoTrackable()
    a.l = [b]
    c = autotrackable.AutoTrackable()
    a.l.append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertIn("l", a._trackable_children())
    direct_a_dep = a._trackable_children()["l"]
    self.assertIn(b, direct_a_dep)
    self.assertIn(c, direct_a_dep)

  @test_util.run_in_graph_and_eager_modes
  def testMutationDirtiesList(self):
    """Inserting at the front of a tracked list makes saving fail."""
    a = autotrackable.AutoTrackable()
    b = autotrackable.AutoTrackable()
    a.l = [b]
    c = autotrackable.AutoTrackable()
    # insert(0, ...) shifts b out of index 0, which counts as a replacement.
    a.l.insert(0, c)
    checkpoint = util.Checkpoint(a=a)
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
      checkpoint.save(self._checkpoint_prefix())

  @test_util.run_in_graph_and_eager_modes
  def testOutOfBandEditDirtiesList(self):
    """Editing the original list object after assignment makes saving fail."""
    a = autotrackable.AutoTrackable()
    b = autotrackable.AutoTrackable()
    held_reference = [b]
    a.l = held_reference
    c = autotrackable.AutoTrackable()
    # Mutate the plain list directly rather than through a.l.
    held_reference.append(c)
    checkpoint = util.Checkpoint(a=a)
    with self.assertRaisesRegex(ValueError, "The wrapped list was modified"):
      checkpoint.save(self._checkpoint_prefix())

  @test_util.run_in_graph_and_eager_modes
  def testNestedLists(self):
    """Nested lists are tracked, and dirtying an inner list blocks saving."""
    a = autotrackable.AutoTrackable()
    a.l = []
    b = autotrackable.AutoTrackable()
    a.l.append([b])
    c = autotrackable.AutoTrackable()
    a.l[0].append(c)
    a_deps = util.list_objects(a)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    # A plain int in the list is allowed but is not listed as an object.
    a.l[0].append(1)
    d = autotrackable.AutoTrackable()
    a.l[0].append(d)
    a_deps = util.list_objects(a)
    self.assertIn(d, a_deps)
    self.assertIn(b, a_deps)
    self.assertIn(c, a_deps)
    self.assertNotIn(1, a_deps)
    e = autotrackable.AutoTrackable()
    f = autotrackable.AutoTrackable()
    a.l1 = [[], [e]]
    a.l1[0].append(f)
    a_deps = util.list_objects(a)
    self.assertIn(e, a_deps)
    self.assertIn(f, a_deps)
    checkpoint = util.Checkpoint(a=a)
    checkpoint.save(self._checkpoint_prefix())
    # Mutating a list wrapped in NoDependency does not prevent saving.
    a.l[0].append(data_structures.NoDependency([]))
    a.l[0][-1].append(5)
    checkpoint.save(self._checkpoint_prefix())
    # Dirtying the inner list means the root object is unsaveable.
    a.l[0][1] = 2
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
      checkpoint.save(self._checkpoint_prefix())

  @test_util.run_in_graph_and_eager_modes
  def testAssertions(self):
    """Wrapped structures compare like the plain structures they were built from."""
    a = autotrackable.AutoTrackable()
    a.l = {"k": [np.zeros([2, 2])]}
    self.assertAllEqual(nest.flatten({"k": [np.zeros([2, 2])]}),
                        nest.flatten(a.l))
    self.assertAllClose({"k": [np.zeros([2, 2])]}, a.l)
    nest.map_structure(self.assertAllClose, a.l, {"k": [np.zeros([2, 2])]})
    a.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
    self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]},
                        self.evaluate(a.tensors))
142
143
if __name__ == "__main__":
  # Run this module's tests with TensorFlow's test runner.
  test.main()
146