1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Generic source code transformation infrastructure.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import inspect 22import threading 23import types 24 25import gast 26 27from tensorflow.python.autograph.pyct import cache 28from tensorflow.python.autograph.pyct import inspect_utils 29from tensorflow.python.autograph.pyct import loader 30from tensorflow.python.autograph.pyct import naming 31from tensorflow.python.autograph.pyct import origin_info 32from tensorflow.python.autograph.pyct import parser 33from tensorflow.python.autograph.pyct import templates 34from tensorflow.python.autograph.pyct import transformer 35from tensorflow.python.autograph.utils import ag_logging as logging 36 37 38def _wrap_into_factory(nodes, entity_name, inner_factory_name, 39 outer_factory_name, closure_vars, factory_args, 40 future_features): 41 """Wraps an AST into the body of a factory with consistent lexical context. 42 43 The AST is expected to define some symbol with a name given by `entity_name`. 44 45 This mechanism ensures that the resulting transformed entity has lexical 46 scoping identical to that of the source entity, while allowing extra 47 parametrization. 48 49 Two nested factories achieve the following: 50 51 1. The inner factory dynamically creates the entity represented by `nodes`. 52 2. The inner factory is parametrized by a custom set of arguments. 53 3. The inner factory has a closure identical to that of the transformed 54 entity. 55 4. The inner factory has local variables named like `args`, which `nodes` may 56 use as additional parameters. 57 5. The inner factory returns the variables given by `entity_name`. 58 6. The outer factory is niladic. 59 7. The outer factory has no closure. 60 8. The outer factory creates the necessary lexical scope for the inner 61 factory, so that the loaded code has the given configuration for 62 closure/globals. 63 9. The outer factory returns the inner factory. 64 65 Roughly speaking, the following code is generated: 66 67 from __future__ import future_feature_1 68 from __future__ import future_feature_2 69 ... 70 71 def outer_factory(): 72 closure_var_1 = None 73 closure_var_2 = None 74 ... 75 76 def inner_factory(arg_1, arg_2, ...): 77 <<nodes>> 78 return entity 79 80 return inner_factory 81 82 The lexical scoping is created using dummy symbol declarations which create 83 local variables in the body of the outer factory, so that the Python parser 84 correctly marks them as free non-global variables upon load (that is, it 85 creates cell slots for each symbol. These symbols are initialized with None, 86 but their values are not expected to be used; instead, the caller is expected 87 to replace them with the cells of the source entity. For more details, see: 88 https://docs.python.org/3/reference/executionmodel.html#binding-of-names 89 90 Args: 91 nodes: Tuple[ast.AST], the source code to wrap. 92 entity_name: Union[Text, ast.AST], the name of the principal entity that 93 `nodes` define. 94 inner_factory_name: Text, the name of the inner factory. 95 outer_factory_name: Text, the name of the outer factory. 96 closure_vars: Iterable[Text], names of the closure variables for the inner 97 factory. 98 factory_args: Iterable[Text], names of additional arguments for the 99 inner factory. Useful to configure variables that the converted code can 100 use. Typically, these are modules. 101 future_features: Iterable[Text], names of future statements to associate the 102 code with. 103 104 Returns: 105 ast.AST 106 """ 107 dummy_closure_defs = [] 108 for var_name in closure_vars: 109 template = """ 110 var_name = None 111 """ 112 dummy_closure_defs.extend(templates.replace(template, var_name=var_name)) 113 114 if future_features: 115 future_imports = gast.ImportFrom( 116 module='__future__', 117 names=[gast.alias(name=name, asname=None) for name in future_features], 118 level=0) 119 else: 120 future_imports = [] 121 122 factory_args = [ 123 gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None) 124 for name in factory_args 125 ] 126 127 template = """ 128 future_imports 129 def outer_factory_name(): 130 dummy_closure_defs 131 def inner_factory_name(factory_args): 132 entity_defs 133 return entity_name 134 return inner_factory_name 135 """ 136 return templates.replace( 137 template, 138 dummy_closure_defs=dummy_closure_defs, 139 entity_defs=nodes, 140 entity_name=entity_name, 141 factory_args=factory_args, 142 future_imports=future_imports, 143 inner_factory_name=inner_factory_name, 144 outer_factory_name=outer_factory_name) 145 146 147class _PythonFnFactory(object): 148 """Helper object that wraps a Python function factory.""" 149 150 def __init__(self, name, freevars, extra_locals): 151 """Creates a new factory for a Python function. 152 153 Args: 154 name: The function name. 155 freevars: The list of non-global free variables for the function. 156 extra_locals: Dict[Text, Any], names and values for custom variables that 157 are accessible to the generated code as local variables. 158 """ 159 self._name = name 160 self._freevars = freevars 161 self._extra_locals = extra_locals 162 163 self._unbound_factory = None 164 self.module = None 165 self.source_map = None 166 167 def create(self, 168 nodes, 169 namer, 170 inner_factory_name='inner_factory', 171 outer_factory_name='outer_factory', 172 future_features=()): 173 """Initializes a function.""" 174 if self._unbound_factory is not None: 175 raise ValueError('double initialization; create a new object instead') 176 177 inner_factory_name = namer.new_symbol(inner_factory_name, ()) 178 outer_factory_name = namer.new_symbol(outer_factory_name, ()) 179 nodes = _wrap_into_factory(nodes, self._name, inner_factory_name, 180 outer_factory_name, self._freevars, 181 self._extra_locals.keys(), future_features) 182 183 module, _, source_map = loader.load_ast( 184 nodes, include_source_map=True) 185 outer_factory = getattr(module, outer_factory_name) 186 self._unbound_factory = outer_factory() 187 self.module = module 188 self.source_map = source_map 189 190 def instantiate(self, 191 globals_, 192 closure, 193 defaults=None, 194 kwdefaults=None): 195 """Creates a new function instance.""" 196 if self._unbound_factory is None: 197 raise ValueError('call create first') 198 199 factory_code = self._unbound_factory.__code__ 200 factory_freevars = factory_code.co_freevars 201 closure_map = dict(zip(self._freevars, closure)) 202 factory_closure = tuple( 203 closure_map[name] for name in factory_code.co_freevars) 204 if len(factory_closure) != len(closure): 205 raise ValueError( 206 'closure mismatch, requested {}, but source function had {}'.format( 207 self._freevars, factory_freevars)) 208 209 bound_factory = types.FunctionType( 210 code=factory_code, 211 globals=globals_, 212 name=self._name, 213 argdefs=(), 214 closure=factory_closure) 215 216 # The lint override is a false positive. 217 new_fn = bound_factory(**self._extra_locals) # pylint:disable=not-callable 218 219 if defaults: 220 new_fn.__defaults__ = defaults 221 if kwdefaults: 222 new_fn.__kwdefaults__ = kwdefaults 223 224 return new_fn 225 226 227class GenericTranspiler(object): 228 """A generic transpiler for Python functions. 229 230 Its interface is the `transform` API, which can process Python function 231 objects. Internally, it handles parsing. 232 233 Users typically subclass this, customizing the `transform_ast` method. The 234 output of transformed_ast is returned directly by `transform`. Existing 235 methods like `transform_function` may also be overloaded. 236 237 Example: 238 239 class MyTransformer(GenericTranspiler): 240 241 def transform_ast(self, node, ctx): 242 result = <<transform node>> 243 return result 244 245 transformer = MyTransfomer() 246 247 result = transformer.transform(f, ...) 248 # result is the output 249 """ 250 251 def get_transformed_name(self, node): 252 """Returns a name for the output function. Subclasses may override this.""" 253 if isinstance(node, gast.Lambda): 254 return 'lam' 255 elif isinstance(node, gast.FunctionDef): 256 return node.name 257 raise ValueError('Unknown node type {}'.format(node)) 258 259 def transform_ast(self, node, ctx): 260 """Performs an actual transformation of a function's AST. 261 262 Subclasses must implement this method, and do not usually call it. 263 264 Args: 265 node: One or more ast.AST nodes representing the AST to be transformed. 266 ctx: transformer.Context. 267 """ 268 raise NotImplementedError('subclasses must override this') 269 270 def transform(self, obj, user_context): 271 """Transforms a Python object. 272 273 Users typically call this method. 274 275 Args: 276 obj: A Python object, function, type, etc. 277 user_context: An opaque object (may be None) that is forwarded to 278 transform_ast, through the ctx.user_context argument. 279 Returns: 280 The result of calling transform_function. 281 282 Raises: 283 NotImplementedError: if the type of obj is not handled. 284 """ 285 if inspect.isfunction(obj) or inspect.ismethod(obj): 286 return self.transform_function(obj, user_context) 287 288 raise NotImplementedError('Non-function: {}'.format(type(obj))) 289 290 def _erase_arg_defaults(self, node): 291 """Erase arg default expressions, which would otherwise be unbound.""" 292 args = node.args 293 for i in range(len(args.defaults)): 294 args.defaults[i] = parser.parse_expression('None') 295 for i, d in enumerate(args.kw_defaults): 296 if d is not None: 297 args.kw_defaults[i] = parser.parse_expression('None') 298 return node 299 300 def transform_module(self, mod, user_context): 301 """Transforms a module. 302 303 Subclasses may override this method. The return value is opaque. 304 305 The method receives the original AST. The result is passed as-is to the 306 output of `transform`. 307 308 Args: 309 mod: A Python module. 310 user_context: An opaque object (may be None) that is forwarded to 311 transform_ast, through the ctx.user_context argument. 312 Returns: 313 List[Tuple[Any, Any]]. By default it returns the output of transform_ast, 314 evaluated on each supported member, other than modules, together with a 315 `transformer.Context` containing information about the transformation 316 process. 317 """ 318 result = [] 319 for member in mod.__dict__.values(): 320 if inspect.ismodule(member): 321 continue # Not transforming modules recursively. 322 try: 323 result.append(self.transform(member, user_context)) 324 except NotImplementedError: 325 pass # Skip unsupported elements. 326 return result 327 328 def transform_function(self, fn, user_context): 329 """Transforms a function. 330 331 Subclasses may override this method. The return value is opaque. 332 333 The method receives the original AST. The result is passed as-is to the 334 output of `transform`. 335 336 Args: 337 fn: A function or lambda. 338 user_context: An opaque object (may be None) that is forwarded to 339 transform_ast, through the ctx.user_context argument. 340 Returns: 341 Tuple[Any, Any]. By default it returns the output of transform_ast, 342 together with a `transformer.Context` containing information about the 343 transformation process. 344 """ 345 future_features = inspect_utils.getfutureimports(fn) 346 node, source = parser.parse_entity(fn, future_features=future_features) 347 logging.log(3, 'Source code of %s:\n\n%s\n', fn, source) 348 349 origin_info.resolve_entity(node, source, fn) 350 351 namespace = inspect_utils.getnamespace(fn) 352 namer = naming.Namer(namespace) 353 new_name = namer.new_symbol(self.get_transformed_name(node), ()) 354 entity_info = transformer.EntityInfo( 355 name=new_name, 356 source_code=source, 357 source_file='<fragment>', 358 future_features=future_features, 359 namespace=namespace) 360 context = transformer.Context(entity_info, namer, user_context) 361 362 node = self._erase_arg_defaults(node) 363 result = self.transform_ast(node, context) 364 365 return result, context 366 367 368class PyToPy(GenericTranspiler): 369 """A generic Python-to-Python transpiler. 370 371 Its `transform` method offers a function-in, function-out interface. 372 Internally, it takes care of parsing, caching and loading of the translated 373 code. 374 375 Users typically subclass this, overriding `transform_ast`. 376 377 Usually, instances of this class are singletons, since each instance manages 378 its own cache. The caching can be controlled by overriding `get_caching_key`. 379 380 Example: 381 382 class MyTransformer(PyToPy): 383 384 def transform_ast(self, node, ctx): 385 node = <<transform node, usually using ast.NodeTransformer classes>> 386 return node 387 388 transformer = MyTransfomer() 389 390 new_f, module, source_map = transformer.transform_function(f, ...) 391 # new_f is a function with signature identical to f 392 393 The transformed function has access to the same namespace as the original 394 function. To allow access to internal APIs, users may inject additional 395 symbols by overriding `get_extra_locals`. 396 """ 397 398 def __init__(self): 399 self._cache_lock = threading.RLock() 400 self._cache = cache.CodeObjectCache() 401 402 def get_extra_locals(self): 403 """Returns extra static local variables to be made to transformed code. 404 405 Subclasses must override this. 406 407 Returns: 408 extra_locals: A Dict[Text, Any] containing additional variables to make 409 available to the transformed code. 410 """ 411 raise NotImplementedError('subclasses must override this') 412 413 def get_caching_key(self, user_context): 414 """Returns a unique key to use for caching. 415 416 Subclasses must override this. 417 418 Calls made to `transform_function` with functions that have the same code 419 object and caching key will return a cached instance on subsequent 420 invocations. 421 422 Args: 423 user_context: The context object which was passed to `transform`. 424 425 Returns: 426 extra_locals: A hashable. 427 """ 428 raise NotImplementedError('subclasses must override this') 429 430 def _cached_factory(self, fn, cache_subkey): 431 cached_factory = self._cache[fn][cache_subkey] 432 logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, 433 cached_factory) 434 return cached_factory 435 436 def transform_function(self, fn, user_context): 437 """Transforms a function. See GenericTranspiler.trasnform_function. 438 439 This overload wraps the parent's `transform_function`, adding caching and 440 facilities to instantiate the output as a Python object. It also 441 adds facilities to make new symbols available to the generated Python code, 442 visible as local variables - see `get_extra_locals`. 443 444 Args: 445 fn: A function or lambda. 446 user_context: An opaque object (may be None) that is forwarded to 447 transform_ast, through the ctx.user_context argument. 448 Returns: 449 A tuple: 450 * A function or lambda with the same signature and closure as `fn` 451 * The temporary module into which the transformed function was loaded 452 * The source map as a 453 Dict[origin_info.LineLocation, origin_info.OriginInfo] 454 """ 455 cache_subkey = self.get_caching_key(user_context) 456 457 if self._cache.has(fn, cache_subkey): 458 # Fast path: use a lock-free check. 459 factory = self._cached_factory(fn, cache_subkey) 460 461 else: 462 with self._cache_lock: 463 # Check again under lock. 464 if self._cache.has(fn, cache_subkey): 465 factory = self._cached_factory(fn, cache_subkey) 466 467 else: 468 logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey) 469 # TODO(mdan): Confusing overloading pattern. Fix. 470 nodes, ctx = super(PyToPy, self).transform_function(fn, user_context) 471 472 if isinstance(nodes, gast.Lambda): 473 nodes = gast.Assign( 474 targets=[ 475 gast.Name( 476 ctx.info.name, 477 ctx=gast.Store(), 478 annotation=None, 479 type_comment=None) 480 ], 481 value=nodes) 482 else: 483 nodes.name = ctx.info.name 484 485 if logging.has_verbosity(2): 486 logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes)) 487 488 factory = _PythonFnFactory( 489 ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals()) 490 factory.create( 491 nodes, ctx.namer, future_features=ctx.info.future_features) 492 self._cache[fn][cache_subkey] = factory 493 494 transformed_fn = factory.instantiate( 495 globals_=fn.__globals__, 496 closure=fn.__closure__ or (), 497 defaults=fn.__defaults__, 498 kwdefaults=getattr(fn, '__kwdefaults__', None)) 499 return transformed_fn, factory.module, factory.source_map 500