• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if."""
16
17import gast
18
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.lang import directives
21from tensorflow.python.autograph.pyct import anno
22from tensorflow.python.autograph.pyct import cfg
23from tensorflow.python.autograph.pyct import origin_info
24from tensorflow.python.autograph.pyct import parser
25from tensorflow.python.autograph.pyct import qual_names
26from tensorflow.python.autograph.pyct import templates
27from tensorflow.python.autograph.pyct.static_analysis import activity
28from tensorflow.python.autograph.pyct.static_analysis import annos
29from tensorflow.python.autograph.pyct.static_analysis import liveness
30from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
31from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
32
33
34class _Function(object):
35
36  scope = None
37
38
class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals.

  `if`, `while` and `for` statements are rewritten into calls to
  `ag__.if_stmt`, `ag__.while_stmt` and `ag__.for_stmt` respectively. The
  statement bodies are moved into locally defined functions, and the variables
  they modify are exposed to the call through a generated pair of state
  getter/setter functions.
  """

  def visit_Lambda(self, node):
    """Visits a lambda with its `anno.Static.SCOPE` as the current scope."""
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, anno.Static.SCOPE)
      return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    """Visits a function with its body scope as the current scope."""
    with self.state[_Function] as fn:
      fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      return self.generic_visit(node)

  def _create_nonlocal_declarations(self, vars_):
    """Builds the `global` / `nonlocal` statements covering `vars_`.

    Args:
      vars_: iterable of symbols modified by the generated functions.

    Returns:
      A list holding at most one `gast.Global` node (symbols that are globals
      of the enclosing function scope) followed by at most one `gast.Nonlocal`
      node (remaining non-composite symbols). Empty if neither is needed.
    """
    vars_ = set(vars_)
    results = []
    global_vars = self.state[_Function].scope.globals & vars_

    if global_vars:
      results.append(gast.Global([str(v) for v in global_vars]))

    # Composite symbols (e.g. `a.b`) cannot appear in a nonlocal statement.
    nonlocal_vars = [
        v for v in vars_ if not v.is_composite() and v not in global_vars]
    if nonlocal_vars:
      results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))

    return results

  def _create_state_functions(
      self, block_vars, nonlocal_declarations, getter_name, setter_name):
    """Generates the getter and setter functions for a block's variables.

    Args:
      block_vars: sequence of symbols the block reads/writes as state.
      nonlocal_declarations: list of declaration nodes placed at the top of
        the setter, as built by `_create_nonlocal_declarations`.
      getter_name: name to give the generated getter function.
      setter_name: name to give the generated setter function.

    Returns:
      The list of AST nodes produced by `templates.replace` for the two
      function definitions.
    """
    if not block_vars:
      # Nothing to track: the getter returns an empty tuple and the setter is
      # a no-op.
      template = """
        def getter_name():
          return ()
        def setter_name(block_vars):
          pass
      """
      return templates.replace(
          template, getter_name=getter_name, setter_name=setter_name)

    # Simple symbols are read directly by the getter; all others are read
    # through an `ag__.ldu` call that receives a lambda and the symbol's name.
    guarded_block_vars = []
    for v in block_vars:
      if v.is_simple():
        guarded_block_vars.append(v)
      else:
        guarded_block_vars.append(
            templates.replace_as_expression(
                'ag__.ldu(lambda: var_, name)',
                var_=v,
                name=gast.Constant(str(v), kind=None)))

    template = """
      def getter_name():
        return guarded_state_vars,
      def setter_name(vars_):
        nonlocal_declarations
        state_vars, = vars_
    """
    return templates.replace(
        template,
        nonlocal_declarations=nonlocal_declarations,
        getter_name=getter_name,
        guarded_state_vars=guarded_block_vars,
        setter_name=setter_name,
        state_vars=tuple(block_vars))

  def _create_loop_options(self, node):
    """Builds the options dict literal passed to the loop operators.

    Args:
      node: the loop node being converted.

    Returns:
      A `gast.Dict` whose entries come from the `directives.set_loop_options`
      directive annotated on `node`. The dict is empty when the node has no
      directives annotation, no such directive, or the directive has no
      options.
    """
    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
      return gast.Dict([], [])

    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
    if directives.set_loop_options not in loop_directives:
      return gast.Dict([], [])

    opts_dict = loop_directives[directives.set_loop_options]
    # Keys and values are collected separately rather than through
    # `zip(*opts_dict.items())`: unpacking that zip raises ValueError when the
    # directive carries no options. Lists are used because ast and gast don't
    # play well with tuples.
    keys = [gast.Constant(s, kind=None) for s in opts_dict]
    values = list(opts_dict.values())
    return gast.Dict(keys, values)

  def _create_undefined_assigns(self, undefined_symbols):
    """Generates `var = ag__.Undefined(name)` for each given symbol.

    Args:
      undefined_symbols: iterable of symbols to initialize.

    Returns:
      A flat list of the generated assignment nodes.
    """
    assignments = []
    for s in undefined_symbols:
      template = '''
        var = ag__.Undefined(symbol_name)
      '''
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Constant(s.ssf(), kind=None))
    return assignments

  def _get_block_basic_vars(self, modified, live_in, live_out):
    """Selects the non-composite modified symbols that are block state.

    Args:
      modified: symbols modified inside the block.
      live_in: symbols live on entry to the block.
      live_out: symbols live on exit from the block.

    Returns:
      A frozenset of the non-composite symbols in `modified` that are live
      into or out of the block, or are nonlocals of the enclosing function.
    """
    nonlocals = self.state[_Function].scope.nonlocals
    basic_scope_vars = []
    for s in modified:
      if s.is_composite():
        # TODO(mdan): Raise an error when this happens for a TF scope.
        continue
      # Variables not live into or out of the scope are considered local to the
      # scope.
      if s in live_in or s in live_out or s in nonlocals:
        basic_scope_vars.append(s)
    return frozenset(basic_scope_vars)

  def _get_block_composite_vars(self, modified, live_in):
    """Selects the composite modified symbols that are block state.

    Args:
      modified: symbols modified inside the block.
      live_in: symbols live on entry to the block.

    Returns:
      A frozenset of the composite symbols in `modified` whose symbol parents
      are all live into the block.
    """
    # The scope variables corresponding to composite symbols (e.g. `self.x`).
    composite_scope_vars = []
    for s in modified:
      if not s.is_composite():
        continue
      # Mutations made to objects created inside the scope will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the scope.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
      #
      # Note that some parents might not be symbols - for example, in x['foo'],
      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
      # liveness of literals.
      support_set_symbols = tuple(
          sss for sss in s.support_set if sss.is_symbol())
      if not all(sss in live_in for sss in support_set_symbols):
        continue
      composite_scope_vars.append(s)
    return frozenset(composite_scope_vars)

  def _get_block_vars(self, node, modified):
    """Determines the variables affected inside a control flow statement.

    Args:
      node: the control flow node; its `DEFINED_VARS_IN`, `LIVE_VARS_IN` and
        `LIVE_VARS_OUT` static annotations are read.
      modified: set of symbols modified inside the statement.

    Returns:
      A tuple `(scope_vars, undefined, nouts)`:
        scope_vars: sorted list of the state symbols, with the input-only
          symbols placed last.
        undefined: tuple of non-composite modified symbols that are neither
          defined on entry nor globals/nonlocals of the enclosing function.
        nouts: number of entries of `scope_vars` that are not input-only.
    """
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    fn_scope = self.state[_Function].scope

    basic_scope_vars = self._get_block_basic_vars(
        modified,
        live_in,
        live_out)
    composite_scope_vars = self._get_block_composite_vars(modified, live_in)
    scope_vars = tuple(basic_scope_vars | composite_scope_vars)

    # Variables that are modified inside the scope, but not defined
    # before entering it. Only simple variables must be defined. The
    # composite ones will be implicitly checked at runtime.
    possibly_undefined = (
        modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
    undefined = tuple(v for v in possibly_undefined if not v.is_composite())

    # Variables that are modified inside the scope, and depend on values outside
    # it. The parentheses only spell out the grouping that operator precedence
    # already gives (`-` binds tighter than `&`).
    input_only = basic_scope_vars & (live_in - live_out)

    # Place the outputs first, then sort lexicographically.
    scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
    nouts = len(scope_vars) - len(input_only)

    return scope_vars, undefined, nouts

  def visit_If(self, node):
    """Converts an `if` statement into a call to `ag__.if_stmt`.

    Args:
      node: the `if` node; children are visited first.

    Returns:
      The list of replacement nodes: state functions, the body and orelse
      functions, assignments for undefined symbols, and the `ag__.if_stmt`
      call.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

    cond_vars, undefined, nouts = self._get_block_vars(
        node, body_scope.bound | orelse_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

    # Generated names must not clash with anything the branches reference.
    reserved = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    # A function body cannot be empty, so a missing else branch becomes `pass`.
    orelse_body = node.orelse
    if not orelse_body:
      orelse_body = [gast.Pass()]

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('if_body', reserved),
        orelse=orelse_body,
        orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        nouts=gast.Constant(nouts, kind=None),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
        test=node.test,
        undefined_assigns=undefined_assigns)
    # The last generated node is the `ag__.if_stmt` call; give it the origin
    # information of the statement it replaces.
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    Args:
      node: the `while` node; children are visited first.

    Returns:
      The list of replacement nodes: state functions, the body and test
      functions, assignments for undefined symbols, and the
      `ag__.while_stmt` call.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    # Generated names must not clash with anything the body references.
    reserved = body_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    opts = self._create_loop_options(node)

    template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        test=node.test,
        test_name=self.ctx.namer.new_symbol('loop_test', reserved),
        undefined_assigns=undefined_assigns)
    # The last generated node is the `ag__.while_stmt` call; give it the origin
    # information of the statement it replaces.
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes

  def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    Args:
      node: the `for` node; children are visited first. If it carries an
        `anno.Basic.EXTRA_LOOP_TEST` annotation, that expression is wrapped
        in a generated function passed to `ag__.for_stmt`; otherwise `None`
        is passed in its place.

    Returns:
      The list of replacement nodes: state functions, the body function, the
      optional extra test function, assignments for undefined symbols, and
      the `ag__.for_stmt` call.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(
        node, body_scope.bound | iter_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    # Generated names must not clash with anything the loop references.
    reserved = body_scope.referenced | iter_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    # Besides any user-set options, record the source text of the loop target
    # under the 'iterate_names' key.
    opts = self._create_loop_options(node)
    opts.keys.append(gast.Constant('iterate_names', kind=None))
    opts.values.append(gast.Constant(
        parser.unparse(node.target, include_encoding_marker=False), kind=None))

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved)
      template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
      extra_test_function = templates.replace(
          template,
          extra_test_expr=extra_test,
          extra_test_name=extra_test_name,
          nonlocal_declarations=nonlocal_declarations)
    else:
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # iterate_arg_name holds a single arg with the iterates, which may be a
    # tuple.
    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
    template = """
      iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        template, iterate_arg_name=iterate_arg_name, iterates=node.target)
    origin_info.copy_origin(node, iterate_expansion)

    template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        undefined_assigns=undefined_assigns)
    # The last generated node is the `ag__.for_stmt` call; give it the origin
    # information of the statement it replaces.
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes
395
396
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching-definitions `Definition` that also carries a directives dict."""

  def __init__(self):
    super().__init__()
    # Starts out empty; nothing in this class populates it.
    self.directives = dict()
402
403
def transform(node, ctx):
  """Runs the required static analyses, then converts control flow in `node`.

  Args:
    node: the AST to transform.
    ctx: the conversion context, passed to every analysis pass and to the
      `ControlFlowTransformer`.

  Returns:
    The node returned by `ControlFlowTransformer(ctx).visit`.
  """
  # The control flow graphs are built from the node as received, before any
  # of the resolve passes run.
  control_flow_graphs = cfg.build(node)

  annotated = qual_names.resolve(node)
  annotated = activity.resolve(annotated, ctx, None)
  # These three passes share a signature and run in this fixed order.
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    annotated = analysis.resolve(annotated, ctx, control_flow_graphs)

  return ControlFlowTransformer(ctx).visit(annotated)
414