# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dependency tracking on AutoTrackable objects."""
import os

import numpy as np
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest


class InterfaceTests(test.TestCase):
  """Checks how AutoTrackable records attribute, list and dict dependencies."""

  def testMultipleAssignment(self):
    """Explicit re-tracking of a name needs overwrite=True; setattr does not."""
    parent = autotrackable.AutoTrackable()
    parent.leaf = autotrackable.AutoTrackable()
    # Assigning the very same object back to the attribute must not raise.
    parent.leaf = parent.leaf
    replacement = autotrackable.AutoTrackable()
    with self.assertRaisesRegex(ValueError, "already declared"):
      parent._track_trackable(replacement, name="leaf")
    # Plain attribute reassignment is accepted: __setattr__ is overridden, and
    # rejecting this would break backward compatibility.
    parent.leaf = replacement
    parent._track_trackable(replacement, name="leaf", overwrite=True)
    self.assertIs(replacement, parent._lookup_dependency("leaf"))
    self.assertIs(replacement, parent._trackable_children()["leaf"])

  def testRemoveDependency(self):
    """Deleting an attribute drops its dependency; reassigning restores it."""
    parent = autotrackable.AutoTrackable()

    def check_single_child():
      # Exactly one dependency, named "a", which is the current attribute.
      self.assertEqual(1, len(parent._trackable_children()))
      self.assertEqual(1, len(parent._unconditional_checkpoint_dependencies))
      self.assertIs(parent.a, parent._trackable_children()["a"])

    parent.a = autotrackable.AutoTrackable()
    check_single_child()
    del parent.a
    self.assertFalse(hasattr(parent, "a"))
    self.assertEqual(0, len(parent._trackable_children()))
    self.assertEqual(0, len(parent._unconditional_checkpoint_dependencies))
    parent.a = autotrackable.AutoTrackable()
    check_single_child()

  def testListBasic(self):
    """Trackables held in a list attribute are reachable dependencies."""
    owner = autotrackable.AutoTrackable()
    first = autotrackable.AutoTrackable()
    owner.l = [first]
    second = autotrackable.AutoTrackable()
    owner.l.append(second)
    reachable = util.list_objects(owner)
    for element in (first, second):
      self.assertIn(element, reachable)
    # The list itself is registered as the direct child named "l".
    self.assertIn("l", owner._trackable_children())
    wrapped = owner._trackable_children()["l"]
    for element in (first, second):
      self.assertIn(element, wrapped)

  @test_util.run_in_graph_and_eager_modes
  def testMutationDirtiesList(self):
    """Inserting at the front of a tracked list makes saving fail."""
    owner = autotrackable.AutoTrackable()
    tail = autotrackable.AutoTrackable()
    owner.l = [tail]
    head = autotrackable.AutoTrackable()
    owner.l.insert(0, head)
    ckpt = util.Checkpoint(a=owner)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
      ckpt.save(prefix)

  @test_util.run_in_graph_and_eager_modes
  def testOutOfBandEditDirtiesList(self):
    """Editing the original list object after assignment makes saving fail."""
    owner = autotrackable.AutoTrackable()
    original = autotrackable.AutoTrackable()
    outside_ref = [original]
    owner.l = outside_ref
    extra = autotrackable.AutoTrackable()
    # Mutate the plain list directly rather than going through owner.l.
    outside_ref.append(extra)
    ckpt = util.Checkpoint(a=owner)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegex(ValueError, "The wrapped list was modified"):
      ckpt.save(prefix)

  @test_util.run_in_graph_and_eager_modes
  def testNestedLists(self):
    """Trackables in nested lists are found, and nested edits are checked."""
    owner = autotrackable.AutoTrackable()
    owner.l = []
    first = autotrackable.AutoTrackable()
    owner.l.append([first])
    second = autotrackable.AutoTrackable()
    owner.l[0].append(second)
    reachable = util.list_objects(owner)
    self.assertIn(first, reachable)
    self.assertIn(second, reachable)

    # A non-trackable value mixed into the inner list is not reported.
    owner.l[0].append(1)
    third = autotrackable.AutoTrackable()
    owner.l[0].append(third)
    reachable = util.list_objects(owner)
    for element in (third, first, second):
      self.assertIn(element, reachable)
    self.assertNotIn(1, reachable)

    fourth = autotrackable.AutoTrackable()
    fifth = autotrackable.AutoTrackable()
    owner.l1 = [[], [fourth]]
    owner.l1[0].append(fifth)
    reachable = util.list_objects(owner)
    self.assertIn(fourth, reachable)
    self.assertIn(fifth, reachable)

    ckpt = util.Checkpoint(a=owner)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    ckpt.save(prefix)
    # Appending a NoDependency-wrapped list and then editing it still allows
    # saving.
    owner.l[0].append(data_structures.NoDependency([]))
    owner.l[0][-1].append(5)
    ckpt.save(prefix)
    # Replacing an element of the inner list makes the root unsaveable.
    owner.l[0][1] = 2
    with self.assertRaisesRegex(ValueError, "A list element was replaced"):
      ckpt.save(prefix)

  @test_util.run_in_graph_and_eager_modes
  def testAssertions(self):
    """Tracked dicts/lists still work with nest and the assertion helpers."""
    owner = autotrackable.AutoTrackable()
    owner.l = {"k": [np.zeros([2, 2])]}
    expected = {"k": [np.zeros([2, 2])]}
    self.assertAllEqual(nest.flatten(expected), nest.flatten(owner.l))
    self.assertAllClose(expected, owner.l)
    nest.map_structure(self.assertAllClose, owner.l, expected)

    owner.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
    evaluated = self.evaluate(owner.tensors)
    self.assertAllClose({"k": [np.ones([2, 2]), np.zeros([3, 3])]}, evaluated)


if __name__ == "__main__":
  test.main()