• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Generic source code transformation infrastructure."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import inspect
22import threading
23import types
24
25import gast
26
27from tensorflow.python.autograph.pyct import cache
28from tensorflow.python.autograph.pyct import inspect_utils
29from tensorflow.python.autograph.pyct import loader
30from tensorflow.python.autograph.pyct import naming
31from tensorflow.python.autograph.pyct import origin_info
32from tensorflow.python.autograph.pyct import parser
33from tensorflow.python.autograph.pyct import templates
34from tensorflow.python.autograph.pyct import transformer
35from tensorflow.python.autograph.utils import ag_logging as logging
36
37
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Wraps an AST into the body of a factory with consistent lexical context.

  The AST is expected to define a symbol whose name is given by `entity_name`.

  The generated code gives the transformed entity the same lexical scoping as
  the source entity, while still allowing extra parametrization. It does so
  with two nested factories that have these properties:

   1. The inner factory dynamically creates the entity represented by `nodes`.
   2. The inner factory is parametrized by a custom set of arguments.
   3. The inner factory has a closure identical to that of the transformed
      entity.
   4. The inner factory has local variables named like `factory_args`, which
      `nodes` may use as additional parameters.
   5. The inner factory returns the variable named by `entity_name`.
   6. The outer factory is niladic (it takes no arguments).
   7. The outer factory has no closure.
   8. The outer factory creates the lexical scope needed by the inner factory,
      so that the loaded code has the given configuration for closure/globals.
   9. The outer factory returns the inner factory.

  Roughly speaking, the following code is generated:

      from __future__ import future_feature_1, future_feature_2, ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The `closure_var = None` lines are dummy declarations. They make each
  closure variable a local of the outer factory, so that when the code is
  loaded Python marks the inner factory's references to them as free,
  non-global variables (that is, it creates a cell slot for each symbol). The
  `None` values are not expected to be used; instead, the caller is expected
  to replace them with the cells of the source entity. For more details, see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate the
      code with.

  Returns:
    ast.AST, the wrapper code produced by `templates.replace`.
  """
  # One `<name> = None` statement per closure variable, flattened into a
  # single list of statements.
  closure_var_template = """
      var_name = None
    """
  dummy_closure_defs = [
      stmt
      for closure_var in closure_vars
      for stmt in templates.replace(closure_var_template, var_name=closure_var)
  ]

  future_imports = []
  if future_features:
    # All features go into a single `from __future__ import a, b, ...` node.
    future_imports = gast.ImportFrom(
        module='__future__',
        names=[
            gast.alias(name=feature, asname=None)
            for feature in future_features
        ],
        level=0)

  # The extra factory arguments become parameters of the inner factory.
  param_nodes = [
      gast.Name(arg_name, ctx=gast.Param(), annotation=None, type_comment=None)
      for arg_name in factory_args
  ]

  wrapper_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      wrapper_template,
      dummy_closure_defs=dummy_closure_defs,
      entity_defs=nodes,
      entity_name=entity_name,
      factory_args=param_nodes,
      future_imports=future_imports,
      inner_factory_name=inner_factory_name,
      outer_factory_name=outer_factory_name)
145
146
class _PythonFnFactory(object):
  """Helper object that wraps a Python function factory.

  It is used in two phases. `create` generates and loads the factory code
  once. `instantiate` can then be called repeatedly to bind the loaded code to
  a specific set of globals, closure cells and argument defaults.
  """

  def __init__(self, name, freevars, extra_locals):
    """Creates a new factory for a Python function.

    Args:
      name: The function name.
      freevars: The list of non-global free variables for the function.
      extra_locals: Dict[Text, Any], names and values for custom variables that
        are accessible to the generated code as local variables.
    """
    self._name = name
    self._freevars = freevars
    self._extra_locals = extra_locals

    # Populated by `create`.
    self._unbound_factory = None
    self.module = None
    self.source_map = None

  def create(self,
             nodes,
             namer,
             inner_factory_name='inner_factory',
             outer_factory_name='outer_factory',
             future_features=()):
    """Generates and loads the factory code for the function.

    Args:
      nodes: the AST nodes that define the function named by `name`.
      namer: object whose `new_symbol` method is used to pick unique names for
        the two factories.
      inner_factory_name: base name for the inner factory.
      outer_factory_name: base name for the outer factory.
      future_features: Iterable[Text], names of `__future__` features to
        compile the generated code with.

    Raises:
      ValueError: if called more than once on the same object.
    """
    if self._unbound_factory is not None:
      raise ValueError('double initialization; create a new object instead')

    inner_factory_name = namer.new_symbol(inner_factory_name, ())
    outer_factory_name = namer.new_symbol(outer_factory_name, ())
    nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
                               outer_factory_name, self._freevars,
                               self._extra_locals.keys(), future_features)

    module, _, source_map = loader.load_ast(
        nodes, include_source_map=True)
    outer_factory = getattr(module, outer_factory_name)
    # Calling the outer factory yields the inner factory. Its closure still
    # holds the dummy placeholder cells; `instantiate` substitutes real ones.
    self._unbound_factory = outer_factory()
    self.module = module
    self.source_map = source_map

  def instantiate(self,
                  globals_,
                  closure,
                  defaults=None,
                  kwdefaults=None):
    """Creates a new function instance.

    Args:
      globals_: Dict, the globals visible to the new function.
      closure: Tuple of cells matching `freevars` positionally (typically the
        source function's `__closure__`). None is treated as an empty closure.
      defaults: Optional tuple assigned to the new function's `__defaults__`.
      kwdefaults: Optional dict assigned to the new function's
        `__kwdefaults__`.

    Returns:
      The function created by the factory.

    Raises:
      ValueError: if `create` was not called first, or if the free variables
        of the loaded factory cannot be matched against `closure`.
    """
    if self._unbound_factory is None:
      raise ValueError('call create first')

    if closure is None:
      closure = ()

    factory_code = self._unbound_factory.__code__
    factory_freevars = factory_code.co_freevars
    closure_map = dict(zip(self._freevars, closure))
    if (len(factory_freevars) != len(closure) or
        any(name not in closure_map for name in factory_freevars)):
      # "requested" is what the loaded factory needs; "had" is what the
      # source function provides.
      raise ValueError(
          'closure mismatch, requested {}, but source function had {}'.format(
              factory_freevars, self._freevars))
    factory_closure = tuple(closure_map[name] for name in factory_freevars)

    # Rebind the factory's code to the caller's globals and real closure cells.
    bound_factory = types.FunctionType(
        code=factory_code,
        globals=globals_,
        name=self._name,
        argdefs=(),
        closure=factory_closure)

    # The lint override is a false positive.
    new_fn = bound_factory(**self._extra_locals)  # pylint:disable=not-callable

    if defaults:
      new_fn.__defaults__ = defaults
    if kwdefaults:
      new_fn.__kwdefaults__ = kwdefaults

    return new_fn
225
226
class GenericTranspiler(object):
  """A generic transpiler for Python functions.

  The public entry point is `transform`, which accepts Python function objects
  and takes care of parsing them internally.

  Users typically subclass this and implement `transform_ast`. By default,
  `transform` returns the output of `transform_ast` together with the
  `transformer.Context` used for the transformation. Existing methods such as
  `transform_function` may also be overridden.

  Example:

      class MyTransformer(GenericTranspiler):

        def transform_ast(self, node, ctx):
          result = <<transform node>>
          return result

      transformer = MyTransformer()

      result, ctx = transformer.transform(f, ...)
      # result is the output of transform_ast
  """

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this."""
    if isinstance(node, gast.FunctionDef):
      return node.name
    if isinstance(node, gast.Lambda):
      return 'lam'
    raise ValueError('Unknown node type {}'.format(node))

  def transform_ast(self, node, ctx):
    """Performs an actual transformation of a function's AST.

    Subclasses must implement this method, and do not usually call it.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      ctx: transformer.Context.
    """
    raise NotImplementedError('subclasses must override this')

  def transform(self, obj, user_context):
    """Transforms a Python object.

    This is the method users typically call.

    Args:
      obj: A Python object, function, type, etc.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      Whatever `transform_function` returns for `obj`.

    Raises:
      NotImplementedError: if obj is neither a function nor a method.
    """
    is_function_like = inspect.isfunction(obj) or inspect.ismethod(obj)
    if not is_function_like:
      raise NotImplementedError('Non-function: {}'.format(type(obj)))
    return self.transform_function(obj, user_context)

  def _erase_arg_defaults(self, node):
    """Replaces arg default expressions with `None`.

    The default expressions refer to the source function's scope and would be
    unbound in the generated code. Only the placeholders are kept here; PyToPy,
    for example, reattaches the source function's defaults on instantiation.
    Keyword-only args without a default (None entries) are left untouched.
    """
    arguments = node.args
    arguments.defaults[:] = [
        parser.parse_expression('None') for _ in arguments.defaults
    ]
    arguments.kw_defaults[:] = [
        None if default is None else parser.parse_expression('None')
        for default in arguments.kw_defaults
    ]
    return node

  def transform_module(self, mod, user_context):
    """Transforms a module.

    Subclasses may override this method. The return value is opaque.

    Each member of the module is passed to `transform`; submodules and
    members that `transform` does not support are skipped.

    Args:
      mod: A Python module.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      List[Tuple[Any, Any]]. By default it returns the output of transform_ast,
      evaluated on each supported member, other than modules, together with a
      `transformer.Context` containing information about the transformation
      process.
    """
    results = []
    for member in mod.__dict__.values():
      if inspect.ismodule(member):
        continue  # Modules are deliberately not transformed recursively.
      try:
        transformed = self.transform(member, user_context)
      except NotImplementedError:
        continue  # Unsupported members are skipped.
      results.append(transformed)
    return results

  def transform_function(self, fn, user_context):
    """Transforms a function.

    Subclasses may override this method. The return value is opaque.

    The source of `fn` is parsed, annotated with origin information, stripped
    of argument defaults, and then handed to `transform_ast`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      Tuple[Any, Any]. By default it returns the output of transform_ast,
      together with a `transformer.Context` containing information about the
      transformation process.
    """
    features = inspect_utils.getfutureimports(fn)
    fn_node, fn_source = parser.parse_entity(fn, future_features=features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, fn_source)

    origin_info.resolve_entity(fn_node, fn_source, fn)

    fn_namespace = inspect_utils.getnamespace(fn)
    fn_namer = naming.Namer(fn_namespace)
    output_name = fn_namer.new_symbol(self.get_transformed_name(fn_node), ())
    info = transformer.EntityInfo(
        name=output_name,
        source_code=fn_source,
        source_file='<fragment>',
        future_features=features,
        namespace=fn_namespace)
    ctx = transformer.Context(info, fn_namer, user_context)

    stripped_node = self._erase_arg_defaults(fn_node)
    return self.transform_ast(stripped_node, ctx), ctx
366
367
class PyToPy(GenericTranspiler):
  """A generic Python-to-Python transpiler.

  Its `transform` method offers a function-in, function-out interface.
  Internally, it takes care of parsing, caching and loading of the translated
  code.

  Users typically subclass this, overriding `transform_ast`.

  Usually, instances of this class are singletons, since each instance manages
  its own cache. The caching can be controlled by overriding `get_caching_key`.

  Example:

      class MyTransformer(PyToPy):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols by overriding `get_extra_locals`.
  """

  def __init__(self):
    # The lock serializes cache population; lookups first try a lock-free
    # check (see `_transformed_factory`).
    self._cache_lock = threading.RLock()
    self._cache = cache.CodeObjectCache()

  def get_extra_locals(self):
    """Returns extra static local variables to be made available to transformed code.

    Subclasses must override this.

    Returns:
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code.
    """
    raise NotImplementedError('subclasses must override this')

  def get_caching_key(self, user_context):
    """Returns a unique key to use for caching.

    Subclasses must override this.

    Calls made to `transform_function` with functions that have the same code
    object and caching key will return a cached instance on subsequent
    invocations.

    Args:
      user_context: The context object which was passed to `transform`.

    Returns:
      A hashable value.
    """
    raise NotImplementedError('subclasses must override this')

  def _cached_factory(self, fn, cache_subkey):
    """Returns the cached factory for `fn` under `cache_subkey`."""
    cached_factory = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
                cached_factory)
    return cached_factory

  def _transformed_factory(self, fn, cache_subkey, user_context):
    """Returns the factory for `fn` under `cache_subkey`, creating it on a miss.

    The cache is first checked without the lock. On a miss, the check is
    repeated under `self._cache_lock`, so that concurrent callers that missed
    together do not each build and store a factory.
    """
    if self._cache.has(fn, cache_subkey):
      # Fast path: use a lock-free check.
      return self._cached_factory(fn, cache_subkey)

    with self._cache_lock:
      # Check again under lock.
      if self._cache.has(fn, cache_subkey):
        return self._cached_factory(fn, cache_subkey)

      logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
      factory = self._create_factory(fn, user_context)
      self._cache[fn][cache_subkey] = factory
      return factory

  def _create_factory(self, fn, user_context):
    """Transforms `fn` and loads the result into a new `_PythonFnFactory`."""
    # TODO(mdan): Confusing overloading pattern. Fix.
    nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)

    if isinstance(nodes, gast.Lambda):
      # A lambda is an expression: bind it to the output name so that the
      # factory can return it by that name.
      nodes = gast.Assign(
          targets=[
              gast.Name(
                  ctx.info.name,
                  ctx=gast.Store(),
                  annotation=None,
                  type_comment=None)
          ],
          value=nodes)
    else:
      nodes.name = ctx.info.name

    if logging.has_verbosity(2):
      logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

    factory = _PythonFnFactory(
        ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
    factory.create(
        nodes, ctx.namer, future_features=ctx.info.future_features)
    return factory

  def transform_function(self, fn, user_context):
    """Transforms a function. See GenericTranspiler.transform_function.

    This overload wraps the parent's `transform_function`, adding caching and
    facilities to instantiate the output as a Python object. It also
    adds facilities to make new symbols available to the generated Python code,
    visible as local variables - see `get_extra_locals`.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.
    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    cache_subkey = self.get_caching_key(user_context)
    factory = self._transformed_factory(fn, cache_subkey, user_context)

    # Bind the (possibly cached) factory to this particular `fn`: its globals,
    # closure cells and argument defaults.
    transformed_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=fn.__closure__ or (),
        defaults=fn.__defaults__,
        kwdefaults=getattr(fn, '__kwdefaults__', None))
    return transformed_fn, factory.module, factory.source_map
500