# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for layer_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import multiprocessing.dummy
import pickle
import time
import timeit

import numpy as np

from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import test
from tensorflow.python.training.tracking import tracking
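

# NOTE: The tests below exercise `layer_utils.cached_per_instance`. As a
# rough, illustrative sketch only (an assumption about the semantics these
# tests rely on, not necessarily the actual `layer_utils` implementation),
# such a decorator can memoize per instance by keeping a WeakKeyDictionary in
# its closure, so cached values are keyed by the instance but never stored on
# it, which is why they are not shared across instances and are not pickled
# with the object:
#
#   import functools
#   import weakref
#
#   def cached_per_instance_sketch(f):
#     cache = weakref.WeakKeyDictionary()
#
#     @functools.wraps(f)
#     def wrapped(item):
#       output = cache.get(item)
#       if output is None:
#         cache[item] = output = f(item)
#       return output
#
#     wrapped.cache = cache  # Cache lives on the wrapper, not the instance.
#     return wrapped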


_PICKLEABLE_CALL_COUNT = collections.Counter()


class MyPickleableObject(tracking.AutoTrackable):
  """Needed for LayerUtilsTest.test_property_cache_serialization.

  This class must be at the top level. This is a constraint of pickle,
  unrelated to `cached_per_instance`.
  """

  @property
  @layer_utils.cached_per_instance
  def my_id(self):
    _PICKLEABLE_CALL_COUNT[self] += 1
    return id(self)


class LayerUtilsTest(test.TestCase):

  def test_property_cache(self):
    test_counter = collections.Counter()

    class MyObject(tracking.AutoTrackable):

      def __init__(self):
        super(MyObject, self).__init__()
        self._frozen = True

      def __setattr__(self, key, value):
        """Enforce that cache does not set attribute on MyObject."""
        if getattr(self, "_frozen", False):
          raise ValueError("Cannot mutate when frozen.")
        return super(MyObject, self).__setattr__(key, value)

      @property
      @layer_utils.cached_per_instance
      def test_property(self):
        test_counter[id(self)] += 1
        return id(self)

    first_object = MyObject()
    second_object = MyObject()

    # Make sure the objects return the correct values
    self.assertEqual(first_object.test_property, id(first_object))
    self.assertEqual(second_object.test_property, id(second_object))

    # Make sure the cache does not share across objects
    self.assertNotEqual(first_object.test_property, second_object.test_property)

    # Check again (Now the values should be cached.)
    self.assertEqual(first_object.test_property, id(first_object))
    self.assertEqual(second_object.test_property, id(second_object))

    # Count the function calls to make sure the cache is actually being used.
    self.assertAllEqual(tuple(test_counter.values()), (1, 1))

  def test_property_cache_threaded(self):
    call_count = collections.Counter()

    class MyObject(tracking.AutoTrackable):

      @property
      @layer_utils.cached_per_instance
      def test_property(self):
        # Random sleeps to ensure that the execution thread changes
        # mid-computation.
        call_count["test_property"] += 1
        time.sleep(np.random.random() + 1.)

        # Use a RandomState which is seeded off the instance's id (the mod is
        # because numpy limits the range of seeds) to ensure that an instance
        # returns the same value in different threads, but different instances
        # return different values.
        return int(np.random.RandomState(id(self) % (2 ** 31)).randint(2 ** 16))

      def get_test_property(self, _):
        """Function provided to .map for threading test."""
        return self.test_property

    # Test that multiple threads return the same value. This requires that
    # the underlying function is repeatable, as `cached_per_instance` makes no
    # attempt to prioritize the first call.
    test_obj = MyObject()
    with contextlib.closing(multiprocessing.dummy.Pool(32)) as pool:
      # Intentionally make a large pool (even when there are only a small
      # number of CPUs) to ensure that the runtime switches threads.
      results = pool.map(test_obj.get_test_property, range(64))
    self.assertEqual(len(set(results)), 1)

    # Make sure we are actually testing threaded behavior.
    self.assertGreater(call_count["test_property"], 1)

    # Make sure new threads still hit the cache.
    with contextlib.closing(multiprocessing.dummy.Pool(2)) as pool:
      start_time = timeit.default_timer()  # Don't time pool instantiation.
      results = pool.map(test_obj.get_test_property, range(4))
    total_time = timeit.default_timer() - start_time

    # Note(taylorrobie): The reason that it is safe to time a unit test is that
    #                    a cache hit will be << 1 second, and a cache miss is
    #                    guaranteed to be >= 1 second. Empirically confirmed by
    #                    100,000 runs with no flakes.
    self.assertLess(total_time, 0.95)

  def test_property_cache_serialization(self):
    # Reset the call count. `.keys()` must be wrapped in a list, because
    # otherwise we would mutate the iterator while iterating.
    for k in list(_PICKLEABLE_CALL_COUNT.keys()):
      _PICKLEABLE_CALL_COUNT.pop(k)

    first_instance = MyPickleableObject()
    self.assertEqual(id(first_instance), first_instance.my_id)

    # Test that we can pickle and un-pickle.
    second_instance = pickle.loads(pickle.dumps(first_instance))

    self.assertEqual(id(second_instance), second_instance.my_id)
    self.assertNotEqual(first_instance.my_id, second_instance.my_id)

    # Make sure the de-serialized object uses the cache.
    self.assertEqual(_PICKLEABLE_CALL_COUNT[second_instance], 1)

    # Make sure the decorator cache is not being serialized with the object.
    expected_size = len(pickle.dumps(second_instance))
    for _ in range(5):
      # Add some more entries to the cache.
      _ = MyPickleableObject().my_id
    self.assertEqual(len(_PICKLEABLE_CALL_COUNT), 7)
    size_check_instance = MyPickleableObject()
    _ = size_check_instance.my_id
    self.assertEqual(expected_size, len(pickle.dumps(size_check_instance)))


if __name__ == "__main__":
  test.main()