# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.trackable import resource


def run_inside_wrap_function_in_eager_mode(graph_function):
  """Decorator to execute the same graph code in eager and graph modes.

  In graph mode, we just execute the graph_function passed as argument. In eager
  mode, we wrap the function using wrap_function and then execute the wrapped
  result.

  Args:
    graph_function: python function containing graph code to be wrapped. It is
      called with the test case instance as its only argument.

  Returns:
    decorated function, carrying graph_function's name and docstring.
  """

  # Preserve the test method's __name__/__doc__ so outer decorators and test
  # logs report the real test name instead of "wrap_and_execute".
  @functools.wraps(graph_function)
  def wrap_and_execute(self):
    if context.executing_eagerly():
      wrapped = wrap_function.wrap_function(graph_function, [self])
      # use the wrapped graph function
      wrapped()
    else:
      # use the original function
      graph_function(self)

  return wrap_and_execute


class _DummyResource(resource.TrackableResource):
  """Minimal TrackableResource whose created resource is a plain string."""

  def __init__(self, handle_name):
    # Stored before the base-class constructor runs. NOTE(review): presumably
    # so it is available if the base class calls _create_resource() during
    # construction -- confirm.
    self._handle_name = handle_name
    super(_DummyResource, self).__init__()

  def _create_resource(self):
    # The "resource" is simply the handle name string.
    return self._handle_name


class _DummyResource1(resource.TrackableResource):
  """Like _DummyResource, plus a `_value` field that creators can modify."""

  def __init__(self, handle_name):
    self._handle_name = handle_name
    # testResourceCreatorNesting checks that a creator can set this after
    # construction.
    self._value = 0
    super(_DummyResource1, self).__init__()

  def _create_resource(self):
    return self._handle_name


class ResourceTrackerTest(test.TestCase):
  """Tests that resource_tracker_scope records resources created inside it."""

  def testBasic(self):
    """Resources created in one scope are tracked in creation order."""
    resource_tracker = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker):
      dummy_resource1 = _DummyResource("test1")
      dummy_resource2 = _DummyResource("test2")

    self.assertEqual(2, len(resource_tracker.resources))
    self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
    self.assertEqual("test2", resource_tracker.resources[1].resource_handle)

  def testTwoScopes(self):
    """Sibling scopes each track only the resources created inside them."""
    resource_tracker1 = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker1):
      dummy_resource1 = _DummyResource("test1")

    resource_tracker2 = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker2):
      dummy_resource2 = _DummyResource("test2")

    self.assertEqual(1, len(resource_tracker1.resources))
    self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
    self.assertEqual(1, len(resource_tracker2.resources))
    self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)

  def testNestedScopesScopes(self):
    """An outer tracker also sees resources created in nested inner scopes."""
    resource_tracker = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker):
      resource_tracker1 = resource.ResourceTracker()
      with resource.resource_tracker_scope(resource_tracker1):
        dummy_resource1 = _DummyResource("test1")

      resource_tracker2 = resource.ResourceTracker()
      with resource.resource_tracker_scope(resource_tracker2):
        dummy_resource2 = _DummyResource("test2")

    self.assertEqual(1, len(resource_tracker1.resources))
    self.assertEqual("test1", resource_tracker1.resources[0].resource_handle)
    self.assertEqual(1, len(resource_tracker2.resources))
    self.assertEqual("test2", resource_tracker2.resources[0].resource_handle)
    self.assertEqual(2, len(resource_tracker.resources))
    self.assertEqual("test1", resource_tracker.resources[0].resource_handle)
    self.assertEqual("test2", resource_tracker.resources[1].resource_handle)


class ResourceCreatorScopeTest(test.TestCase):
  """Tests for ops.resource_creator_scope(resource_type, resource_creator)."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreator(self):
    """One creator registered for a list of types applies to each type."""

    def resource_creator_fn(next_creator, *a, **kwargs):
      kwargs["handle_name"] = "forced_name"
      return next_creator(*a, **kwargs)

    # test that two resource classes use the same creator function
    with ops.resource_creator_scope(["_DummyResource", "_DummyResource1"],
                                    resource_creator_fn):
      dummy_0 = _DummyResource(handle_name="fake_name_0")
      dummy_1 = _DummyResource1(handle_name="fake_name_1")

    self.assertEqual(dummy_0._handle_name, "forced_name")
    self.assertEqual(dummy_1._handle_name, "forced_name")

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreatorNestingError(self):
    """Exiting an outer creator scope while an inner one is open raises."""

    def creator(next_creator, *a, **kwargs):
      return next_creator(*a, **kwargs)

    # Save the state so we can clean up at the end.
    graph = ops.get_default_graph()
    old_creator_stack = graph._resource_creator_stack["_DummyResource"]

    try:
      # Arguments are (resource_type, resource_creator), the same order used
      # by the other tests in this file. With the order reversed, the scope
      # would not be registered under "_DummyResource", so the stack saved
      # above (and restored below) would never be exercised.
      scope = ops.resource_creator_scope("_DummyResource", creator)
      scope.__enter__()
      with ops.resource_creator_scope("_DummyResource", creator):
        # Exit the outer scope while the inner one is still active.
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
    finally:
      graph._resource_creator_stack["_DummyResource"] = old_creator_stack

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreatorNesting(self):
    """Nested creators for the same type compose; both take effect."""

    def resource_creator_fn_0(next_creator, *a, **kwargs):
      # Post-processes the instance produced by the rest of the chain.
      instance = next_creator(*a, **kwargs)
      instance._value = 1
      return instance

    def resource_creator_fn_1(next_creator, *a, **kwargs):
      # Rewrites the constructor kwargs before creation.
      kwargs["handle_name"] = "forced_name1"
      return next_creator(*a, **kwargs)

    with ops.resource_creator_scope(["_DummyResource1"], resource_creator_fn_0):
      with ops.resource_creator_scope(["_DummyResource1"],
                                      resource_creator_fn_1):
        dummy_0 = _DummyResource1(handle_name="fake_name")

    self.assertEqual(dummy_0._handle_name, "forced_name1")
    self.assertEqual(dummy_0._value, 1)


if __name__ == "__main__":
  test.main()