• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for cache module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.pyct import cache
22from tensorflow.python.platform import test
23
24
class CacheTest(test.TestCase):
  """Checks how the `cache` containers map related entities onto one entry."""

  def test_code_object_cache(self):
    # Closures built by the same factory are distinct function objects that
    # share one code object; the cache should treat them as the same key.

    def make_fn(offset):
      def inner():
        return offset + 1
      return inner

    code_cache = cache.CodeObjectCache()
    first_fn = make_fn(1)
    sentinel = object()

    code_cache[first_fn][1] = sentinel

    # The stored sub-key is visible, an unrelated sub-key is not, and exactly
    # one top-level entry exists.
    self.assertTrue(code_cache.has(first_fn, 1))
    self.assertFalse(code_cache.has(first_fn, 2))
    self.assertIs(code_cache[first_fn][1], sentinel)
    self.assertEqual(len(code_cache), 1)

    # A second closure (different captured value) resolves to the entry stored
    # for the first one, without growing the cache.
    second_fn = make_fn(2)

    self.assertTrue(code_cache.has(second_fn, 1))
    self.assertIs(code_cache[second_fn][1], sentinel)
    self.assertEqual(len(code_cache), 1)

  def test_unbound_instance_cache(self):
    # Bound methods of two different instances of one class should resolve to
    # the same cache entry.

    class Owner(object):

      def method(self):
        pass

    instance_cache = cache.UnboundInstanceCache()
    first_owner = Owner()
    sentinel = object()

    instance_cache[first_owner.method][1] = sentinel

    # The stored sub-key is visible, an unrelated sub-key is not, and exactly
    # one top-level entry exists.
    self.assertTrue(instance_cache.has(first_owner.method, 1))
    self.assertFalse(instance_cache.has(first_owner.method, 2))
    self.assertIs(instance_cache[first_owner.method][1], sentinel)
    self.assertEqual(len(instance_cache), 1)

    # A fresh instance's bound method maps onto the same entry, without
    # growing the cache.
    second_owner = Owner()

    self.assertTrue(instance_cache.has(second_owner.method, 1))
    self.assertIs(instance_cache[second_owner.method][1], sentinel)
    self.assertEqual(len(instance_cache), 1)
76
77
if __name__ == "__main__":
  # Executed directly (not imported): hand control to TensorFlow's test runner.
  test.main()
80