• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Converter construction support.

This module contains a base class for all converters, as well as supporting
structures. These structures are referred to as contexts.

The class hierarchy is as follows:

    <your converter>
      [extends] converter.Base
        [extends] transformer.Base
            [extends] gast.NodeTransformer
          [uses] transformer.SourceInfo
        [uses] converter.EntityContext
          [uses] converter.ProgramContext
          [uses] transformer.SourceInfo

converter.Base is a specialization of transformer.Base for AutoGraph. It's a
very lightweight subclass that adds a `ctx` attribute holding the corresponding
EntityContext object (see below). Note that converters are not reusable, and
`visit` will raise an error if called more than once.

converter.EntityContext contains mutable state associated with an entity that
the converter processes.

converter.ProgramContext contains mutable state across related entities. For
example, when converting several functions that call one another, the
ProgramContext should be shared across these entities.

Below is the overall flow at conversion:

    program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
    while <program_ctx has more entities to convert>:
      entity, source_info = <get next entity from program_ctx>
      entity_ctx = EntityContext(program_ctx, source_info)
      for <each ConverterClass>:
        converter = ConverterClass(entity_ctx)

        # May update entity_ctx and program_ctx
        entity = converter.visit(entity)

      <add entity's dependencies to program_ctx>

Note that pyct contains a small number of transformers used for static analysis.
These implement transformer.Base, rather than converter.Base, to avoid a
dependency on AutoGraph.
"""
61
62from __future__ import absolute_import
63from __future__ import division
64from __future__ import print_function
65
66import enum
67
68from tensorflow.python.autograph.pyct import anno
69from tensorflow.python.autograph.pyct import ast_util
70from tensorflow.python.autograph.pyct import parser
71from tensorflow.python.autograph.pyct import templates
72from tensorflow.python.autograph.pyct import transformer
73from tensorflow.python.util.tf_export import tf_export
74
75# TODO(mdan): These contexts can be refactored into first class objects.
76# For example, we could define Program and Entity abstractions that hold on
77# to the actual entity and have conversion methods.
78
79# TODO(mdan): Add a test specific to this converter.
80
81
@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """This enumeration represents optional conversion options.

  These conversion options are experimental. They are subject to change without
  notice and offer no guarantees.

  _Example Usage_

  ```python
  optionals = tf.autograph.experimental.Feature.EQUALITY_OPERATORS
  @tf.function(experimental_autograph_options=optionals)
  def f(i):
    if i == 0:  # EQUALITY_OPERATORS allows the use of == here.
      tf.print('i is zero')
  ```

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insertion of control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    EQUALITY_OPERATORS: Whether to convert the comparison operators, like
      equality. This is soon to be deprecated as support is being added to the
      Tensor class.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
  LISTS = 'LISTS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options."""
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: Union[Feature, list, tuple, set, frozenset], a single feature or
        a collection of features to leave out.

    Returns:
      Tuple[Feature, ...], the remaining features in definition order. `ALL` is
      never included, since it would re-enable the excluded features.
    """
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    excluded = set(exclude)
    excluded.add(cls.ALL)
    # Filtering the ordered member tuple (rather than subtracting sets) keeps
    # the result order deterministic across runs.
    return tuple(feature for feature in cls.all() if feature not in excluded)
133
134
# Forward definition: binds the name early so it can be referenced before the
# real value is assigned further down in this module.
STANDARD_OPTIONS = None
136
137
class ConversionOptions(object):
  """Container for global conversion flags, treated as immutable.

  Equality and hashing are value-based over all attributes (see `as_tuple`), so
  attributes should not be reassigned after construction.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    user_requested: bool, whether the conversion was explicitly requested by
      the user, as opposed to being performed as a result of other logic. This
      value always auto-resets to False in child conversions.
    internal_convert_user_code: bool, stored as given; `call_options` sets it
      to the value of `recursive` for child conversions.
      NOTE(review): presumably controls whether user code reached from
      converted code is itself converted - confirm against callers.
    optional_features: FrozenSet[Feature], controls the use of optional
      features in the conversion process. The constructor accepts None, a
      single Feature or an iterable of Features and normalizes it to a
      frozenset. See Feature for available options.
  """

  def __init__(self,
               recursive=False,
               user_requested=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    self.recursive = recursive
    self.user_requested = user_requested
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize to a frozenset so that instances stay hashable.
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    self.optional_features = frozenset(optional_features)

  def as_tuple(self):
    """Returns all attributes as a hashable tuple, in constructor order."""
    return (self.recursive, self.user_requested,
            self.internal_convert_user_code, self.optional_features)

  def __hash__(self):
    return hash(self.as_tuple())

  def __eq__(self, other):
    # Defer to the other operand for foreign types instead of asserting:
    # asserts vanish under -O, and comparing with e.g. None should be False.
    if not isinstance(other, ConversionOptions):
      return NotImplemented
    return self.as_tuple() == other.as_tuple()

  def __ne__(self, other):
    # Python 2 does not derive != from ==, so define it explicitly.
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __str__(self):
    # Feature names are sorted so that the text is stable across runs.
    features = ', '.join(sorted(str(f) for f in self.optional_features))
    return 'ConversionOptions[{}]'.format(
        'recursive={}, user_requested={}, internal_convert_user_code={}, '
        'optional_features=({})'.format(
            self.recursive, self.user_requested,
            self.internal_convert_user_code, features))

  def uses(self, feature):
    """Returns True if `feature` is enabled, directly or via Feature.ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def call_options(self):
    """Returns the corresponding options to be used for recursive conversion."""
    return ConversionOptions(
        recursive=self.recursive,
        user_requested=False,
        internal_convert_user_code=self.recursive,
        optional_features=self.optional_features)

  def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents. Options equal to STANDARD_OPTIONS are encoded as the
    shorthand expression `ag__.STD`.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Every element carries a trailing comma so that a single feature still
      # produces a tuple rather than a parenthesized expression; elements are
      # sorted so that the generated code is reproducible.
      return parser.parse_expression('({})'.format(''.join(
          'ag__.{}, '.format(str(v)) for v in sorted(values, key=str))))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    # templates.replace yields a list of statements; unwrap the expression.
    return expr_ast[0].value
227
228
# Baseline options: recursive conversion that was not user requested, with no
# optional features enabled.
STANDARD_OPTIONS = ConversionOptions(
    optional_features=None,
    internal_convert_user_code=True,
    user_requested=False,
    recursive=True)
234
235
class ProgramContext(object):
  """Holds state shared while converting a hierarchy of functions.

  Attributes:
    options: ConversionOptions
    autograph_module: Deprecated. Do not use.
  """

  def __init__(self, options, autograph_module=None):
    # Both values are stored as given, without validation.
    self.autograph_module = autograph_module
    self.options = options
247
248
class Base(transformer.Base):
  """Common parent class for all converters.

  Instances are single-use: a second top-level call to `visit` raises.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Nesting level of in-flight `visit` calls; zero means no visit is active.
    self._ast_depth = 0
    # Becomes True once the first top-level `visit` has started.
    self._used = False

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, the value returned when the node has no definitions or
        none of them carries the requested directive argument.

    Returns:
      The directive argument value found on the definitions of `node`, or
      `default` if there is none.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    reaching_defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not reaching_defs:
      return default

    candidates = [
        definition.directives[directive][arg]
        for definition in reaching_defs
        if (directive in definition.directives and
            arg in definition.directives[directive])
    ]
    if not candidates:
      return default

    # When several annotations reach the symbol they all have to agree; the
    # first one is used as the reference and is the value returned.
    reference = candidates[0]
    for candidate in candidates[1:]:
      if ast_util.matches(reference, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError(
          '%s has ambiguous annotations for %s(%s): %s, %s' %
          (qn, directive.__name__, arg, parser.unparse(candidate).strip(),
           parser.unparse(reference).strip()))
    return reference

  def visit(self, node):
    """Visits `node`, rejecting a second top-level use of this converter.

    Args:
      node: the AST node to visit; forwarded to `transformer.Base.visit`.

    Returns:
      Whatever `transformer.Base.visit` returns for `node`.

    Raises:
      ValueError: if a top-level visit is started on a converter that has
        already been used.
    """
    # Only the outermost call (depth zero) checks and sets the reuse flag;
    # nested calls made during traversal pass straight through.
    if not self._ast_depth:
      if self._used:
        raise ValueError('converter objects cannot be reused')
      self._used = True

    self._ast_depth += 1
    try:
      result = super(Base, self).visit(node)
    finally:
      # Restore the depth even when the traversal raises.
      self._ast_depth -= 1
    return result
322