# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base TFDecorator class and utility functions for working with decorators.

There are two ways to create decorators that TensorFlow can introspect into.
This is important for documentation generation purposes, so that function
signatures aren't obscured by the (*args, **kwds) signature that decorators
often provide.

1. Call `tf_decorator.make_decorator` on your wrapper function. If your
decorator is stateless, or can capture all of the variables it needs to work
with through lexical closure, this is the simplest option. Create your wrapper
function as usual, but instead of returning it, return
`tf_decorator.make_decorator(target, your_wrapper)`. This will attach some
decorator introspection metadata onto your wrapper and return it.

Example:

  def print_hello_before_calling(target):
    def wrapper(*args, **kwargs):
      print('hello')
      return target(*args, **kwargs)
    return tf_decorator.make_decorator(target, wrapper)

2. Derive from TFDecorator. If your decorator needs to be stateful, you can
implement it in terms of a TFDecorator. Store whatever state you need in your
derived class, and implement the `__call__` method to do your work before
calling into your target. You can retrieve the target via
`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
parameters it needs.

Example:

  class CallCounter(tf_decorator.TFDecorator):
    def __init__(self, target):
      super(CallCounter, self).__init__('count_calls', target)
      self.call_count = 0

    def __call__(self, *args, **kwargs):
      self.call_count += 1
      return super(CallCounter, self).decorated_target(*args, **kwargs)

  def count_calls(target):
    return CallCounter(target)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect


def make_decorator(target,
                   decorator_func,
                   decorator_name=None,
                   decorator_doc='',
                   decorator_argspec=None):
  """Make a decorator from a wrapper and a target.

  Attaches a `TFDecorator` describing `target` to `decorator_func` as its
  `_tf_decorator` attribute, and copies the target's identifying metadata
  (`__name__`, `__qualname__`, `__module__`, any `__dict__` entries the wrapper
  does not already define, and the docstring) onto the wrapper. It also sets
  `__wrapped__` and `__original_wrapped__` on the wrapper to `target`.

  Args:
    target: The final callable to be wrapped.
    decorator_func: The wrapper function. It is modified in place.
    decorator_name: The name of the decorator. If `None`, the name of the
      function calling make_decorator. If the interpreter does not expose
      Python stack frames, the wrapper's own `__name__` (or `''`) is used.
    decorator_doc: Documentation specific to this application of
      `decorator_func` to `target`.
    decorator_argspec: The new callable signature of this decorator.

  Returns:
    The `decorator_func` argument with new metadata attached.
  """
  if decorator_name is None:
    # `inspect.currentframe()` returns None on interpreters without Python
    # stack-frame support, so guard before walking to the caller's frame.
    # This lookup must stay in this function so `f_back` is the caller.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    if caller is not None:
      decorator_name = caller.f_code.co_name
    else:
      decorator_name = getattr(decorator_func, '__name__', '')
    # Release frame references promptly to avoid reference cycles.
    del frame, caller
  decorator = TFDecorator(decorator_name, target, decorator_doc,
                          decorator_argspec)
  setattr(decorator_func, '_tf_decorator', decorator)
  # Objects that are callables (e.g., a functools.partial object) may not have
  # the following attributes.
  if hasattr(target, '__name__'):
    decorator_func.__name__ = target.__name__
  if hasattr(target, '__qualname__'):
    decorator_func.__qualname__ = target.__qualname__
  if hasattr(target, '__module__'):
    decorator_func.__module__ = target.__module__
  if hasattr(target, '__dict__'):
    # Copy dict entries from target which are not overridden by decorator_func.
    for name in target.__dict__:
      if name not in decorator_func.__dict__:
        decorator_func.__dict__[name] = target.__dict__[name]
  if hasattr(target, '__doc__'):
    # The TFDecorator already resolved decorator_doc vs. target.__doc__.
    decorator_func.__doc__ = decorator.__doc__
  decorator_func.__wrapped__ = target
  # Keeping a second handle to `target` allows callers to detect whether the
  # decorator was modified using `rewrap`.
  decorator_func.__original_wrapped__ = target
  return decorator_func


def _has_tf_decorator_attr(obj):
  """Checks if object has _tf_decorator attribute.

  This check would work for mocked object as well since it would
  check if returned attribute has the right type.

  Args:
    obj: Python object.

  Returns:
    True if `obj._tf_decorator` exists and is a `TFDecorator` instance.
  """
  return (
      hasattr(obj, '_tf_decorator') and
      isinstance(getattr(obj, '_tf_decorator'), TFDecorator))


def rewrap(decorator_func, previous_target, new_target):
  """Injects a new target into a function built by make_decorator.

  This function allows replacing a function wrapped by `decorator_func`,
  assuming the decorator that wraps the function is written as described below.

  The decorator function must use `<decorator name>.__wrapped__` instead of the
  wrapped function that is normally used:

  Example:

    # Instead of this:
    def simple_parametrized_wrapper(*args, **kwds):
      return wrapped_fn(*args, **kwds)

    tf_decorator.make_decorator(wrapped_fn, simple_parametrized_wrapper)

    # Write this:
    def simple_parametrized_wrapper(*args, **kwds):
      return simple_parametrized_wrapper.__wrapped__(*args, **kwds)

    tf_decorator.make_decorator(wrapped_fn, simple_parametrized_wrapper)

  Note that this process modifies decorator_func.

  Args:
    decorator_func: Callable returned by `make_decorator`.
    previous_target: Callable that needs to be replaced.
    new_target: Callable to replace previous_target with.

  Returns:
    The updated decorator. If decorator_func is not a tf_decorator, new_target
    is returned.

  Raises:
    AssertionError: (only when assertions are enabled) if `decorator_func` is
      not decorated but is not `previous_target` either, or if the decorator
      chain contains a `None` target.
  """
  # Because the process mutates the decorator, we only need to alter the
  # innermost function that wraps previous_target.
  cur = decorator_func
  innermost_decorator = None
  target = None
  while _has_tf_decorator_attr(cur):
    innermost_decorator = cur
    target = getattr(cur, '_tf_decorator')
    if target.decorated_target is previous_target:
      break
    cur = target.decorated_target
    assert cur is not None

  # If decorator_func is not a decorator, new_target replaces it directly.
  if innermost_decorator is None:
    # Consistency check. The caller should always pass the result of
    # tf_decorator.unwrap as previous_target. If decorator_func is not a
    # decorator, that will have returned decorator_func itself.
    assert decorator_func is previous_target
    return new_target

  target.decorated_target = new_target

  if inspect.ismethod(innermost_decorator):
    # Bound methods can't be assigned attributes. Thankfully, they seem to
    # be just proxies for their unbound counterpart, and we can modify that.
    if hasattr(innermost_decorator, '__func__'):
      innermost_decorator.__func__.__wrapped__ = new_target
    elif hasattr(innermost_decorator, 'im_func'):
      # Python 2 spelling of `__func__`.
      innermost_decorator.im_func.__wrapped__ = new_target
    else:
      innermost_decorator.__wrapped__ = new_target
  else:
    innermost_decorator.__wrapped__ = new_target

  return decorator_func


def unwrap(maybe_tf_decorator):
  """Unwraps an object into a list of TFDecorators and a final target.

  Args:
    maybe_tf_decorator: Any callable object.

  Returns:
    A tuple whose first element is a list of TFDecorator-derived objects that
    were applied to the final callable target, and whose second element is the
    final undecorated callable target. If the `maybe_tf_decorator` parameter is
    not decorated by any TFDecorators, the first tuple element will be an empty
    list. The `TFDecorator` list is ordered from outermost to innermost
    decorators.
  """
  decorators = []
  cur = maybe_tf_decorator
  while True:
    if isinstance(cur, TFDecorator):
      decorators.append(cur)
    elif _has_tf_decorator_attr(cur):
      decorators.append(getattr(cur, '_tf_decorator'))
    else:
      break
    if not hasattr(decorators[-1], 'decorated_target'):
      break
    cur = decorators[-1].decorated_target
  return decorators, cur


class TFDecorator(object):
  """Base class for all TensorFlow decorators.

  TFDecorator captures and exposes the wrapped target, and provides details
  about the current decorator.

  Its `__name__`/`__qualname__` mirror the target's when the target has them.
  Its `__doc__` is `decorator_doc` if non-empty, else the target's non-empty
  docstring, else `''`.
  """

  def __init__(self,
               decorator_name,
               target,
               decorator_doc='',
               decorator_argspec=None):
    self._decorated_target = target
    self._decorator_name = decorator_name
    self._decorator_doc = decorator_doc
    self._decorator_argspec = decorator_argspec
    if hasattr(target, '__name__'):
      self.__name__ = target.__name__
    if hasattr(target, '__qualname__'):
      self.__qualname__ = target.__qualname__
    if self._decorator_doc:
      self.__doc__ = self._decorator_doc
    elif hasattr(target, '__doc__') and target.__doc__:
      self.__doc__ = target.__doc__
    else:
      self.__doc__ = ''

  def __get__(self, instance, owner):
    # Delegate attribute binding to the target, so a decorated function looked
    # up through an instance becomes a bound method of the target. Targets
    # that are not descriptors (builtins, functools.partial objects, callable
    # instances) are returned unchanged, as they would be when stored directly
    # as a class attribute, instead of raising AttributeError.
    target_get = getattr(self._decorated_target, '__get__', None)
    if target_get is None:
      return self._decorated_target
    return target_get(instance, owner)

  def __call__(self, *args, **kwargs):
    return self._decorated_target(*args, **kwargs)

  @property
  def decorated_target(self):
    """The callable this decorator wraps (reassignable, e.g. by `rewrap`)."""
    return self._decorated_target

  @decorated_target.setter
  def decorated_target(self, decorated_target):
    self._decorated_target = decorated_target

  @property
  def decorator_name(self):
    """The name given to this decorator at construction time."""
    return self._decorator_name

  @property
  def decorator_doc(self):
    """The decorator-specific documentation passed at construction time."""
    return self._decorator_doc

  @property
  def decorator_argspec(self):
    """The argspec passed at construction time (may be None)."""
    return self._decorator_argspec