• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for object_identity."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.platform import test
22from tensorflow.python.util import nest
23from tensorflow.python.util import object_identity
24
25
class ObjectIdentityWrapperTest(test.TestCase):
  """Tests for `object_identity._ObjectIdentityWrapper`."""

  def testWrapperNotEqualToWrapped(self):
    """Wrappers of one object are equal; comparing with the raw object raises."""

    class SettableHash(object):
      """Object whose hash value can be changed after construction."""

      def __init__(self):
        self.hash_value = 8675309

      def __hash__(self):
        return self.hash_value

    o = SettableHash()
    wrap1 = object_identity._ObjectIdentityWrapper(o)
    wrap2 = object_identity._ObjectIdentityWrapper(o)

    # Two distinct wrappers around the same object compare equal.
    self.assertEqual(wrap1, wrap1)
    self.assertEqual(wrap1, wrap2)
    # `unwrapped` must hand back the very same object, not merely an equal
    # one, so assert identity rather than equality.
    self.assertIs(o, wrap1.unwrapped)
    self.assertIs(o, wrap2.unwrapped)
    # Comparing a wrapper against an unwrapped object raises TypeError,
    # whether the wrapper is the left or the right operand.
    with self.assertRaises(TypeError):
      bool(o == wrap1)
    with self.assertRaises(TypeError):
      bool(wrap1 != o)

    # While the hashes differ, the membership test completes without raising.
    self.assertNotIn(o, {wrap1})
    # Make `hash(o)` collide with the wrapper's hash so the set lookup has to
    # fall back to an equality comparison, which must raise TypeError.
    o.hash_value = id(o)
    with self.assertRaises(TypeError):
      bool(o in {wrap1})

  def testNestFlatten(self):
    """`nest.flatten` treats wrappers as leaves and preserves their order."""
    a = object_identity._ObjectIdentityWrapper('a')
    b = object_identity._ObjectIdentityWrapper('b')
    c = object_identity._ObjectIdentityWrapper('c')
    flat = nest.flatten([[[(a, b)]], c])
    self.assertEqual(flat, [a, b, c])

  def testNestMapStructure(self):
    """`nest.map_structure` works on dicts whose keys are wrappers."""
    k = object_identity._ObjectIdentityWrapper('k')
    v1 = object_identity._ObjectIdentityWrapper('v1')
    v2 = object_identity._ObjectIdentityWrapper('v2')
    struct = nest.map_structure(lambda a, b: (a, b), {k: v1}, {k: v2})
    self.assertEqual(struct, {k: (v1, v2)})
70
class ObjectIdentitySetTest(test.TestCase):
  """Tests for `object_identity.ObjectIdentitySet`."""

  def testDifference(self):
    """`difference` keeps only members that are absent from the other set."""

    class Element(object):
      pass

    only_left, shared, only_right = Element(), Element(), Element()
    left = object_identity.ObjectIdentitySet([only_left, shared])
    right = object_identity.ObjectIdentitySet([shared, only_right])
    result = left.difference(right)
    self.assertIn(only_left, result)
    # Neither the common member nor the right-only member may survive.
    for excluded in (shared, only_right):
      self.assertNotIn(excluded, result)

  def testDiscard(self):
    """`discard` removes exactly the requested member."""
    removed, kept = object(), object()
    members = object_identity.ObjectIdentitySet([removed, kept])
    members.discard(removed)
    self.assertIn(kept, members)
    self.assertNotIn(removed, members)

  def testClear(self):
    """`clear` leaves the set with no members."""
    first, second = object(), object()
    members = object_identity.ObjectIdentitySet([first, second])
    members.clear()
    self.assertLen(members, 0)
103
if __name__ == '__main__':
  # Running this file directly hands control to TensorFlow's test runner.
  test.main()
106