• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Base TFDecorator class and utility functions for working with decorators.
16
17There are two ways to create decorators that TensorFlow can introspect into.
18This is important for documentation generation purposes, so that function
19signatures aren't obscured by the (*args, **kwds) signature that decorators
20often provide.
21
1. Call `tf_decorator.make_decorator` on your wrapper function. If your
23decorator is stateless, or can capture all of the variables it needs to work
24with through lexical closure, this is the simplest option. Create your wrapper
25function as usual, but instead of returning it, return
26`tf_decorator.make_decorator(target, your_wrapper)`. This will attach some
27decorator introspection metadata onto your wrapper and return it.
28
29Example:
30
31  def print_hello_before_calling(target):
32    def wrapper(*args, **kwargs):
33      print('hello')
34      return target(*args, **kwargs)
35    return tf_decorator.make_decorator(target, wrapper)
36
2. Derive from TFDecorator. If your decorator needs to be stateful, you can
38implement it in terms of a TFDecorator. Store whatever state you need in your
39derived class, and implement the `__call__` method to do your work before
40calling into your target. You can retrieve the target via
41`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
42parameters it needs.
43
44Example:
45
46  class CallCounter(tf_decorator.TFDecorator):
47    def __init__(self, target):
48      super(CallCounter, self).__init__('count_calls', target)
49      self.call_count = 0
50
51    def __call__(self, *args, **kwargs):
52      self.call_count += 1
53      return super(CallCounter, self).decorated_target(*args, **kwargs)
54
55  def count_calls(target):
56    return CallCounter(target)
57"""
58from __future__ import absolute_import
59from __future__ import division
60from __future__ import print_function
61
62import inspect
63
64
def make_decorator(target,
                   decorator_func,
                   decorator_name=None,
                   decorator_doc='',
                   decorator_argspec=None):
  """Attaches TFDecorator metadata to a wrapper so that it mirrors `target`.

  The wrapper is mutated in place: it receives a `_tf_decorator` attribute
  holding a new `TFDecorator`, takes over the identifying attributes of
  `target` where present, and gets `__wrapped__` / `__original_wrapped__`
  handles pointing at `target`.

  Args:
    target: The final callable to be wrapped.
    decorator_func: The wrapper function.
    decorator_name: The name of the decorator. If `None`, the name of the
      function calling make_decorator is used.
    decorator_doc: Documentation specific to this application of
      `decorator_func` to `target`.
    decorator_argspec: The new callable signature of this decorator.

  Returns:
    The `decorator_func` argument with new metadata attached.
  """
  if decorator_name is None:
    # Read straight from this frame: `f_back` is whoever called us.
    decorator_name = inspect.currentframe().f_back.f_code.co_name

  tf_decorator_obj = TFDecorator(decorator_name, target, decorator_doc,
                                 decorator_argspec)
  decorator_func._tf_decorator = tf_decorator_obj

  # Callable objects (e.g., a functools.partial object) may lack some of
  # these attributes, so each one is mirrored only when the target has it.
  for attr_name in ('__name__', '__qualname__', '__module__'):
    if hasattr(target, attr_name):
      setattr(decorator_func, attr_name, getattr(target, attr_name))

  if hasattr(target, '__dict__'):
    # Entries already present on the wrapper win over the target's entries.
    for key in target.__dict__:
      if key in decorator_func.__dict__:
        continue
      decorator_func.__dict__[key] = target.__dict__[key]

  if hasattr(target, '__doc__'):
    # The doc comes from the TFDecorator, which resolves decorator_doc first
    # and falls back to the target's own doc.
    decorator_func.__doc__ = tf_decorator_obj.__doc__

  # `__original_wrapped__` is a second handle to `target`; comparing it with
  # `__wrapped__` lets callers detect whether `rewrap` modified the decorator.
  decorator_func.__wrapped__ = target
  decorator_func.__original_wrapped__ = target
  return decorator_func
109
110
111def _has_tf_decorator_attr(obj):
112  """Checks if object has _tf_decorator attribute.
113
114  This check would work for mocked object as well since it would
115  check if returned attribute has the right type.
116
117  Args:
118    obj: Python object.
119  """
120  return (
121      hasattr(obj, '_tf_decorator') and
122      isinstance(getattr(obj, '_tf_decorator'), TFDecorator))
123
124
def rewrap(decorator_func, previous_target, new_target):
  """Swaps the callable wrapped by a `make_decorator`-built function.

  Replacing the wrapped function only takes effect if the wrapper looks its
  target up through `<decorator name>.__wrapped__` at call time, rather than
  closing over the wrapped function directly.

  Example:

      # Instead of this:
      def simple_parametrized_wrapper(*args, **kwds):
        return wrapped_fn(*args, **kwds)

      tf_decorator.make_decorator(wrapped_fn, simple_parametrized_wrapper)

      # Write this:
      def simple_parametrized_wrapper(*args, **kwds):
        return simple_parametrized_wrapper.__wrapped__(*args, **kwds)

      tf_decorator.make_decorator(wrapped_fn, simple_parametrized_wrapper)

  Note that this process modifies decorator_func.

  Args:
    decorator_func: Callable returned by `make_decorator`.
    previous_target: Callable that needs to be replaced.
    new_target: Callable to replace previous_target with.

  Returns:
    The updated decorator. If decorator_func is not a tf_decorator, new_target
    is returned.
  """
  # The decorator chain is mutated in place, so only the innermost wrapper
  # around previous_target has to be located and altered.
  innermost_decorator = None
  tf_decorator_obj = None
  node = decorator_func
  while _has_tf_decorator_attr(node):
    innermost_decorator = node
    tf_decorator_obj = node._tf_decorator
    wrapped = tf_decorator_obj.decorated_target
    if wrapped is previous_target:
      break
    node = wrapped
    assert node is not None

  if innermost_decorator is None:
    # decorator_func is not a decorator at all, so new_target simply takes
    # its place. Consistency check: callers are expected to pass the result
    # of tf_decorator.unwrap as previous_target, which for a non-decorator
    # is decorator_func itself.
    assert decorator_func is previous_target
    return new_target

  tf_decorator_obj.decorated_target = new_target

  # Bound methods can't be assigned attributes; they proxy their underlying
  # function, so `__wrapped__` is written there instead when it is reachable.
  holder = innermost_decorator
  if inspect.ismethod(innermost_decorator):
    if hasattr(innermost_decorator, '__func__'):
      holder = innermost_decorator.__func__
    elif hasattr(innermost_decorator, 'im_func'):
      holder = innermost_decorator.im_func
  holder.__wrapped__ = new_target

  return decorator_func
195
196
def unwrap(maybe_tf_decorator):
  """Peels TFDecorator layers off an object down to its final target.

  Args:
    maybe_tf_decorator: Any callable object.

  Returns:
    A `(decorators, target)` tuple. `decorators` is a list of the
    TFDecorator-derived objects that were applied, ordered from outermost to
    innermost; it is empty when `maybe_tf_decorator` is not decorated by any
    TFDecorators. `target` is the final undecorated callable.
  """
  layers = []
  candidate = maybe_tf_decorator
  while True:
    # A layer is either a TFDecorator itself or an object carrying one in
    # its `_tf_decorator` attribute; anything else is the final target.
    if isinstance(candidate, TFDecorator):
      layer = candidate
    elif _has_tf_decorator_attr(candidate):
      layer = candidate._tf_decorator
    else:
      break
    layers.append(layer)
    if not hasattr(layer, 'decorated_target'):
      break
    candidate = layer.decorated_target
  return layers, candidate
224
225
class TFDecorator(object):
  """Base class for all TensorFlow decorators.

  A TFDecorator holds the wrapped target and exposes it, together with the
  decorator's name, documentation and argspec, through properties. Calls and
  descriptor lookups are forwarded to the wrapped target.
  """

  def __init__(self,
               decorator_name,
               target,
               decorator_doc='',
               decorator_argspec=None):
    """Stores the decorator details and mirrors the target's identity.

    Args:
      decorator_name: Name of the decorator.
      target: The object being decorated.
      decorator_doc: Documentation for this decorator application. When
        empty, the target's own `__doc__` is used if it has a non-empty one.
      decorator_argspec: Signature to report for the decorated callable.
    """
    self._decorator_name = decorator_name
    self._decorator_doc = decorator_doc
    self._decorator_argspec = decorator_argspec
    self._decorated_target = target

    # Not every callable has these attributes, so copy only those present.
    for attr_name in ('__name__', '__qualname__'):
      if hasattr(target, attr_name):
        setattr(self, attr_name, getattr(target, attr_name))

    # Precedence: explicit decorator doc, then the target's doc, then ''.
    self.__doc__ = (
        self._decorator_doc or getattr(target, '__doc__', None) or '')

  def __get__(self, instance, owner):
    """Forwards descriptor binding to the decorated target."""
    bind = self._decorated_target.__get__
    return bind(instance, owner)

  def __call__(self, *args, **kwargs):
    """Invokes the decorated target with the given arguments."""
    wrapped = self._decorated_target
    return wrapped(*args, **kwargs)

  @property
  def decorated_target(self):
    """The object currently wrapped by this decorator."""
    return self._decorated_target

  @decorated_target.setter
  def decorated_target(self, decorated_target):
    self._decorated_target = decorated_target

  @property
  def decorator_name(self):
    """The name given to this decorator."""
    return self._decorator_name

  @property
  def decorator_doc(self):
    """The documentation string supplied for this decorator."""
    return self._decorator_doc

  @property
  def decorator_argspec(self):
    """The argspec supplied for this decorator, or `None`."""
    return self._decorator_argspec
278