1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for layer_utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import contextlib 23import multiprocessing.dummy 24import pickle 25import time 26import timeit 27 28import numpy as np 29 30from tensorflow.python.keras.utils import layer_utils 31from tensorflow.python.platform import test 32from tensorflow.python.training.tracking import tracking 33 34 35_PICKLEABLE_CALL_COUNT = collections.Counter() 36 37 38class MyPickleableObject(tracking.AutoTrackable): 39 """Needed for InterfaceTests.test_property_cache_serialization. 40 41 This class must be at the top level. This is a constraint of pickle, 42 unrelated to `cached_per_instance`. 43 """ 44 45 @property 46 @layer_utils.cached_per_instance 47 def my_id(self): 48 _PICKLEABLE_CALL_COUNT[self] += 1 49 return id(self) 50 51 52class LayerUtilsTest(test.TestCase): 53 54 def test_property_cache(self): 55 test_counter = collections.Counter() 56 57 class MyObject(tracking.AutoTrackable): 58 59 def __init__(self): 60 super(MyObject, self).__init__() 61 self._frozen = True 62 63 def __setattr__(self, key, value): 64 """Enforce that cache does not set attribute on MyObject.""" 65 if getattr(self, "_frozen", False): 66 raise ValueError("Cannot mutate when frozen.") 67 return super(MyObject, self).__setattr__(key, value) 68 69 @property 70 @layer_utils.cached_per_instance 71 def test_property(self): 72 test_counter[id(self)] += 1 73 return id(self) 74 75 first_object = MyObject() 76 second_object = MyObject() 77 78 # Make sure the objects return the correct values 79 self.assertEqual(first_object.test_property, id(first_object)) 80 self.assertEqual(second_object.test_property, id(second_object)) 81 82 # Make sure the cache does not share across objects 83 self.assertNotEqual(first_object.test_property, second_object.test_property) 84 85 # Check again (Now the values should be cached.) 86 self.assertEqual(first_object.test_property, id(first_object)) 87 self.assertEqual(second_object.test_property, id(second_object)) 88 89 # Count the function calls to make sure the cache is actually being used. 90 self.assertAllEqual(tuple(test_counter.values()), (1, 1)) 91 92 def test_property_cache_threaded(self): 93 call_count = collections.Counter() 94 95 class MyObject(tracking.AutoTrackable): 96 97 @property 98 @layer_utils.cached_per_instance 99 def test_property(self): 100 # Random sleeps to ensure that the execution thread changes 101 # mid-computation. 102 call_count["test_property"] += 1 103 time.sleep(np.random.random() + 1.) 104 105 # Use a RandomState which is seeded off the instance's id (the mod is 106 # because numpy limits the range of seeds) to ensure that an instance 107 # returns the same value in different threads, but different instances 108 # return different values. 109 return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16)) 110 111 def get_test_property(self, _): 112 """Function provided to .map for threading test.""" 113 return self.test_property 114 115 # Test that multiple threads return the same value. This requires that 116 # the underlying function is repeatable, as cached_property makes no attempt 117 # to prioritize the first call. 118 test_obj = MyObject() 119 with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool: 120 # Intentionally make a large pool (even when there are only a small number 121 # of cpus) to ensure that the runtime switches threads. 122 results = pool.map(test_obj.get_test_property, range(64)) 123 self.assertEqual(len(set(results)), 1) 124 125 # Make sure we actually are testing threaded behavior. 126 self.assertGreater(call_count["test_property"], 1) 127 128 # Make sure new threads still cache hit. 129 with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool: 130 start_time = timeit.default_timer() # Don't time pool instantiation. 131 results = pool.map(test_obj.get_test_property, range(4)) 132 total_time = timeit.default_timer() - start_time 133 134 # Note(taylorrobie): The reason that it is safe to time a unit test is that 135 # a cache hit will be << 1 second, and a cache miss is 136 # guaranteed to be >= 1 second. Empirically confirmed by 137 # 100,000 runs with no flakes. 138 self.assertLess(total_time, 0.95) 139 140 def test_property_cache_serialization(self): 141 # Reset call count. .keys() must be wrapped in a list, because otherwise we 142 # would mutate the iterator while iterating. 143 for k in list(_PICKLEABLE_CALL_COUNT.keys()): 144 _PICKLEABLE_CALL_COUNT.pop(k) 145 146 first_instance = MyPickleableObject() 147 self.assertEqual(id(first_instance), first_instance.my_id) 148 149 # Test that we can pickle and un-pickle 150 second_instance = pickle.loads(pickle.dumps(first_instance)) 151 152 self.assertEqual(id(second_instance), second_instance.my_id) 153 self.assertNotEqual(first_instance.my_id, second_instance.my_id) 154 155 # Make sure de-serialized object uses the cache. 156 self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1) 157 158 # Make sure the decorator cache is not being serialized with the object. 159 expected_size = len(pickle.dumps(second_instance)) 160 for _ in range(5): 161 # Add some more entries to the cache. 162 _ = MyPickleableObject().my_id 163 self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7) 164 size_check_instance = MyPickleableObject() 165 _ = size_check_instance.my_id 166 self.assertEqual(expected_size, len(pickle.dumps(size_check_instance))) 167 168 169if __name__ == "__main__": 170 test.main() 171