• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from tensorflow.python.eager import context
17from tensorflow.python.eager import wrap_function
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import test_util
20from tensorflow.python.platform import test
21from tensorflow.python.trackable import resource
22
23
def run_inside_wrap_function_in_eager_mode(graph_function):
  """Decorator to execute the same graph code in eager and graph modes.

  In graph mode, we just execute the graph_function passed as argument. In eager
  mode, we wrap the function using wrap_function and then execute the wrapped
  result.

  Args:
    graph_function: python function containing graph code to be wrapped. It is
      called with the test case instance (`self`) as its only argument.

  Returns:
    decorated function taking the test case instance. It carries the name,
    docstring and other metadata of `graph_function` (via `functools.wraps`),
    so anything that inspects `__name__`/`__doc__` -- e.g. decorators stacked
    on top of this one, or test reports -- sees the original test.
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(graph_function)
  def wrap_and_execute(self):
    if context.executing_eagerly():
      # Build a wrapped function from graph_function, with `self` as its
      # signature argument, then call the wrapped result.
      wrapped = wrap_function.wrap_function(graph_function, [self])
      wrapped()
    else:
      # Graph mode: run the original graph code directly.
      graph_function(self)

  return wrap_and_execute
46
47
class _DummyResource(resource.TrackableResource):
  """Minimal TrackableResource whose "handle" is simply the name it was given."""

  def __init__(self, handle_name):
    # The name is stored before the base-class constructor is invoked.
    self._handle_name = handle_name
    super().__init__()

  def _create_resource(self):
    """Returns the stored name in place of a real resource handle."""
    return self._handle_name
56
57
class _DummyResource1(resource.TrackableResource):
  """Dummy resource with a name-based handle plus a `_value` field.

  `_value` starts at 0; the creator-scope tests overwrite it from inside a
  creator function to check that the creator ran.
  """

  def __init__(self, handle_name):
    # Both fields are set (name first, then value) before the base constructor.
    self._handle_name, self._value = handle_name, 0
    super().__init__()

  def _create_resource(self):
    """Returns the stored name in place of a real resource handle."""
    return self._handle_name
67
68
class ResourceTrackerTest(test.TestCase):
  """Checks which resources a ResourceTracker records inside its scope."""

  def _handles(self, tracker):
    """Returns the handles of `tracker.resources`, in the order it lists them."""
    return [r.resource_handle for r in tracker.resources]

  def testBasic(self):
    resource_tracker = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker):
      # The tracker itself holds the created resources, so no local names are
      # needed to keep them around.
      _DummyResource("test1")
      _DummyResource("test2")

    self.assertEqual(["test1", "test2"], self._handles(resource_tracker))

  def testTwoScopes(self):
    resource_tracker1 = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker1):
      _DummyResource("test1")

    resource_tracker2 = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker2):
      _DummyResource("test2")

    # Sequential scopes do not leak resources into each other.
    self.assertEqual(["test1"], self._handles(resource_tracker1))
    self.assertEqual(["test2"], self._handles(resource_tracker2))

  def testNestedScopesScopes(self):
    resource_tracker = resource.ResourceTracker()
    with resource.resource_tracker_scope(resource_tracker):
      resource_tracker1 = resource.ResourceTracker()
      with resource.resource_tracker_scope(resource_tracker1):
        _DummyResource("test1")

      resource_tracker2 = resource.ResourceTracker()
      with resource.resource_tracker_scope(resource_tracker2):
        _DummyResource("test2")

    # Each inner tracker sees only its own resource; the enclosing tracker
    # also records everything created in the nested scopes.
    self.assertEqual(["test1"], self._handles(resource_tracker1))
    self.assertEqual(["test2"], self._handles(resource_tracker2))
    self.assertEqual(["test1", "test2"], self._handles(resource_tracker))
113
114
class ResourceCreatorScopeTest(test.TestCase):
  """Tests for `ops.resource_creator_scope(resource_type, creator)`."""

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreator(self):
    """One creator registered for several resource types applies to each."""

    def resource_creator_fn(next_creator, *a, **kwargs):
      kwargs["handle_name"] = "forced_name"
      return next_creator(*a, **kwargs)

    # test that two resource classes use the same creator function
    with ops.resource_creator_scope(["_DummyResource", "_DummyResource1"],
                                    resource_creator_fn):
      dummy_0 = _DummyResource(handle_name="fake_name_0")
      dummy_1 = _DummyResource1(handle_name="fake_name_1")

    self.assertEqual("forced_name", dummy_0._handle_name)
    self.assertEqual("forced_name", dummy_1._handle_name)

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreatorNestingError(self):
    """Exiting an outer creator scope while an inner one is active raises."""

    def creator(next_creator, *a, **kwargs):
      return next_creator(*a, **kwargs)

    # Save the "_DummyResource" creator stack so we can clean up at the end:
    # the out-of-order exit below is expected to raise, so the outer scope may
    # not restore it on its own.
    graph = ops.get_default_graph()
    old_creator_stack = graph._resource_creator_stack["_DummyResource"]

    try:
      # Arguments are (resource_type, creator), the same order used by the
      # other tests in this class; the swapped order registered the creator
      # function object itself as the resource type.
      scope = ops.resource_creator_scope("_DummyResource", creator)
      scope.__enter__()
      with ops.resource_creator_scope("_DummyResource", creator):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)
    finally:
      graph._resource_creator_stack["_DummyResource"] = old_creator_stack

  @test_util.run_in_graph_and_eager_modes
  @run_inside_wrap_function_in_eager_mode
  def testResourceCreatorNesting(self):
    """The effects of both nested creators show up on the created resource."""

    def resource_creator_fn_0(next_creator, *a, **kwargs):
      instance = next_creator(*a, **kwargs)
      instance._value = 1
      return instance

    def resource_creator_fn_1(next_creator, *a, **kwargs):
      kwargs["handle_name"] = "forced_name1"
      return next_creator(*a, **kwargs)

    with ops.resource_creator_scope(["_DummyResource1"], resource_creator_fn_0):
      with ops.resource_creator_scope(["_DummyResource1"],
                                      resource_creator_fn_1):
        dummy_0 = _DummyResource1(handle_name="fake_name")

    self.assertEqual("forced_name1", dummy_0._handle_name)
    self.assertEqual(1, dummy_0._value)
173
174
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point.
  test.main()
177