• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles builtins and other special functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gast
22
23from tensorflow.python.autograph.core import converter
24from tensorflow.python.autograph.operators import py_builtins
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import templates
27
28
class BuiltinFunctionTransformer(converter.Base):
  """Handles builtin functions.

  This transformer only covers functions that are translated into a
  TF equivalent, like `len`. Calls to such builtins are rewritten into calls
  to the corresponding `ag__` overload looked up via
  `py_builtins.overload_of`.
  """

  def _convert_builtin(self, f, args, as_expression, keywords=None):
    """Builds a call to the `ag__` overload of builtin `f`.

    Args:
      f: the builtin function object being called; its overload is found with
        `py_builtins.overload_of(f)`.
      args: AST nodes to pass as positional arguments to the overload.
      as_expression: if True, build the replacement with
        `templates.replace_as_expression` (for use in place of a `Call`
        expression); otherwise build it with `templates.replace` (for use in
        place of a statement, e.g. a Python 2 `Print`).
      keywords: optional `gast.keyword` nodes to forward to the overload.
        Only applied when `as_expression` is True.

    Returns:
      The replacement produced by `templates.replace_as_expression` (when
      `as_expression`) or by `templates.replace` (otherwise).
    """
    template = """
      ag__.func(args)
    """
    func_name = py_builtins.overload_of(f).__name__
    if not as_expression:
      return templates.replace(template, func=func_name, args=args)
    new_call = templates.replace_as_expression(
        template, func=func_name, args=args)
    if keywords:
      # The template only has a slot for positional arguments, so keyword
      # arguments (e.g. `print(x, end='')`) must be attached explicitly or
      # they would be silently dropped from the converted call.
      # NOTE(review): assumes replace_as_expression yields the gast.Call node
      # for this single-call template — confirm.
      new_call.keywords = list(keywords)
    return new_call

  def visit_Call(self, node):
    """Rewrites calls to supported builtins into calls to their overloads.

    Calls whose callee has no `live_val` annotation, or whose live value is
    not in `py_builtins.SUPPORTED_BUILTINS`, are returned unchanged (after
    their children have been visited).
    """
    node = self.generic_visit(node)
    if not anno.hasanno(node.func, 'live_val'):
      return node
    live_val = anno.getanno(node.func, 'live_val')
    # Keep the try body to the membership test only, so that TypeErrors raised
    # while building the replacement are not mistaken for "unhashable".
    try:
      is_supported = live_val in py_builtins.SUPPORTED_BUILTINS
    except TypeError:
      # Not everything in Python is hashable. If it isn't then it's definitely
      # not a supported built-in.
      return node
    if not is_supported:
      return node
    return self._convert_builtin(
        live_val, node.args, as_expression=True, keywords=node.keywords)

  def visit_Print(self, node):
    """Rewrites a Python 2 `print` statement into a call to the print overload.

    NOTE(review): only `node.values` is forwarded; the statement's other
    fields (presumably `dest`/`nl`, i.e. `print >>f, x` and a trailing comma)
    are ignored — confirm this is intended.
    """
    node = self.generic_visit(node)
    args = node.values
    # Without `print_function`, `print(a, b)` parses as a print statement whose
    # single value is the tuple `(a, b)`; unpack it so the elements are passed
    # as separate arguments.
    if len(args) == 1 and isinstance(args[0], gast.Tuple):
      args = args[0].elts
    return self._convert_builtin(print, args, as_expression=False)
67
68
def transform(node, ctx):
  """Runs BuiltinFunctionTransformer over an AST.

  Args:
    node: the AST node to transform.
    ctx: the converter context handed to the transformer's constructor.

  Returns:
    Whatever the transformer's `visit` returns for `node`.
  """
  transformer = BuiltinFunctionTransformer(ctx)
  return transformer.visit(node)
71