# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for object_identity."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


class ObjectIdentityWrapperTest(test.TestCase):
  """Exercises `object_identity._ObjectIdentityWrapper`."""

  def testWrapperNotEqualToWrapped(self):
    """Wrappers of one object are equal; comparing to the raw object raises."""

    class SettableHash(object):
      """Object whose hash value can be reassigned after construction."""

      def __init__(self):
        self.hash_value = 8675309

      def __hash__(self):
        return self.hash_value

    target = SettableHash()
    first_wrapper = object_identity._ObjectIdentityWrapper(target)
    second_wrapper = object_identity._ObjectIdentityWrapper(target)

    # Two wrappers around the same object compare equal to each other.
    self.assertEqual(first_wrapper, first_wrapper)
    self.assertEqual(first_wrapper, second_wrapper)

    # Each wrapper hands back the object it was built from.
    for wrapper in (first_wrapper, second_wrapper):
      self.assertEqual(target, wrapper.unwrapped)

    # Comparing a wrapper with the raw object is an error in either direction.
    with self.assertRaises(TypeError):
      bool(target == first_wrapper)
    with self.assertRaises(TypeError):
      bool(first_wrapper != target)

    # With its initial hash the raw object is simply not found in the set.
    self.assertNotIn(target, {first_wrapper})

    # Once the raw object's hash is set to id(target) there is a hash
    # collision with the wrapper, so the membership test raises instead.
    target.hash_value = id(target)
    with self.assertRaises(TypeError):
      bool(target in {first_wrapper})

  def testNestFlatten(self):
    """`nest.flatten` returns wrappers as leaves, in structure order."""
    a = object_identity._ObjectIdentityWrapper('a')
    b = object_identity._ObjectIdentityWrapper('b')
    c = object_identity._ObjectIdentityWrapper('c')
    nested = [[[(a, b)]], c]
    self.assertEqual(nest.flatten(nested), [a, b, c])

  def testNestMapStructure(self):
    """`nest.map_structure` pairs values stored under a wrapper dict key."""
    key = object_identity._ObjectIdentityWrapper('k')
    first_value = object_identity._ObjectIdentityWrapper('v1')
    second_value = object_identity._ObjectIdentityWrapper('v2')

    def pair_up(a, b):
      return (a, b)

    mapped = nest.map_structure(
        pair_up, {key: first_value}, {key: second_value})
    self.assertEqual(mapped, {key: (first_value, second_value)})


class ObjectIdentitySetTest(test.TestCase):
  """Exercises `object_identity.ObjectIdentitySet`."""

  def testDifference(self):
    """`difference` keeps only the elements absent from the other set."""

    class Element(object):
      pass

    only_left, shared, only_right = Element(), Element(), Element()
    left = object_identity.ObjectIdentitySet([only_left, shared])
    right = object_identity.ObjectIdentitySet([shared, only_right])

    remaining = left.difference(right)

    self.assertIn(only_left, remaining)
    self.assertNotIn(shared, remaining)
    self.assertNotIn(only_right, remaining)

  def testDiscard(self):
    """`discard` removes the given element and leaves the other in place."""
    removed, kept = object(), object()
    members = object_identity.ObjectIdentitySet([removed, kept])

    members.discard(removed)

    self.assertIn(kept, members)
    self.assertNotIn(removed, members)

  def testClear(self):
    """`clear` leaves the set with length zero."""
    members = object_identity.ObjectIdentitySet([object(), object()])
    members.clear()
    self.assertLen(members, 0)


if __name__ == '__main__':
  test.main()