# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `pyct.cache` module.

Covers `CodeObjectCache` and `UnboundInstanceCache`. Each test stores a value
under a key derived from one callable, then checks that a second callable the
cache is expected to treat as equivalent resolves to that same entry.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.pyct import cache
from tensorflow.python.platform import test


class CacheTest(test.TestCase):
  """Unit tests for the caches defined in `pyct.cache`."""

  def test_code_object_cache(self):
    """Distinct closures created from one inner `def` share a cache entry."""

    def make_adder(bound_value):
      # Each call returns a new function object, but every one of them is
      # created from this same inner `def`; only the captured value differs.
      def add_one():
        return bound_value + 1
      return add_one

    code_cache = cache.CodeObjectCache()

    first_fn = make_adder(1)
    sentinel = object()

    code_cache[first_fn][1] = sentinel

    self.assertTrue(code_cache.has(first_fn, 1))
    self.assertFalse(code_cache.has(first_fn, 2))
    self.assertIs(code_cache[first_fn][1], sentinel)
    self.assertEqual(len(code_cache), 1)

    # A closure over a different value must hit the existing entry instead of
    # growing the cache.
    second_fn = make_adder(2)

    self.assertTrue(code_cache.has(second_fn, 1))
    self.assertIs(code_cache[second_fn][1], sentinel)
    self.assertEqual(len(code_cache), 1)

  def test_unbound_instance_cache(self):
    """Bound methods of different instances of one class share an entry."""

    class Holder(object):

      def method(self):
        pass

    instance_cache = cache.UnboundInstanceCache()

    first_obj = Holder()
    sentinel = object()

    instance_cache[first_obj.method][1] = sentinel

    self.assertTrue(instance_cache.has(first_obj.method, 1))
    self.assertFalse(instance_cache.has(first_obj.method, 2))
    self.assertIs(instance_cache[first_obj.method][1], sentinel)
    self.assertEqual(len(instance_cache), 1)

    # Looking up `method` through a fresh instance must resolve to the entry
    # stored via the first instance, leaving the cache size unchanged.
    second_obj = Holder()

    self.assertTrue(instance_cache.has(second_obj.method, 1))
    self.assertIs(instance_cache[second_obj.method][1], sentinel)
    self.assertEqual(len(instance_cache), 1)


if __name__ == '__main__':
  test.main()