• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gast
22
23from tensorflow.python.autograph.core import converter
24from tensorflow.python.autograph.lang import directives
25from tensorflow.python.autograph.pyct import anno
26from tensorflow.python.autograph.pyct import cfg
27from tensorflow.python.autograph.pyct import origin_info
28from tensorflow.python.autograph.pyct import parser
29from tensorflow.python.autograph.pyct import qual_names
30from tensorflow.python.autograph.pyct import templates
31from tensorflow.python.autograph.pyct.static_analysis import activity
32from tensorflow.python.autograph.pyct.static_analysis import annos
33from tensorflow.python.autograph.pyct.static_analysis import liveness
34from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
35from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
36
37
38class _Function(object):
39
40  scope = None
41
42
43class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals."""
45
46  def visit_Lambda(self, node):
47    with self.state[_Function] as fn:
48      fn.scope = anno.getanno(node, anno.Static.SCOPE)
49      return self.generic_visit(node)
50
51  def visit_FunctionDef(self, node):
52    with self.state[_Function] as fn:
53      fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
54      return self.generic_visit(node)
55
56  def _create_nonlocal_declarations(self, vars_):
57    vars_ = set(vars_)
58    results = []
59    global_vars = self.state[_Function].scope.globals & vars_
60
61    if global_vars:
62      results.append(gast.Global([str(v) for v in global_vars]))
63
64    nonlocal_vars = [
65        v for v in vars_ if not v.is_composite() and v not in global_vars]
66    if nonlocal_vars:
67      results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))
68
69    return results
70
71  def _create_state_functions(
72      self, block_vars, nonlocal_declarations, getter_name, setter_name):
73    if not block_vars:
74      template = """
75        def getter_name():
76          return ()
77        def setter_name(block_vars):
78          pass
79      """
80      return templates.replace(
81          template, getter_name=getter_name, setter_name=setter_name)
82
83    guarded_block_vars = []
84    for v in block_vars:
85      if v.is_simple():
86        guarded_block_vars.append(v)
87      else:
88        guarded_block_vars.append(
89            templates.replace_as_expression(
90                'ag__.ldu(lambda: var_, name)',
91                var_=v,
92                name=gast.Constant(str(v), kind=None)))
93
94    template = """
95      def getter_name():
96        return guarded_state_vars,
97      def setter_name(vars_):
98        nonlocal_declarations
99        state_vars, = vars_
100    """
101    return templates.replace(
102        template,
103        nonlocal_declarations=nonlocal_declarations,
104        getter_name=getter_name,
105        guarded_state_vars=guarded_block_vars,
106        setter_name=setter_name,
107        state_vars=tuple(block_vars))
108
109  def _create_loop_options(self, node):
110    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
111      return gast.Dict([], [])
112
113    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
114    if directives.set_loop_options not in loop_directives:
115      return gast.Dict([], [])
116
117    opts_dict = loop_directives[directives.set_loop_options]
118    str_keys, values = zip(*opts_dict.items())
119    keys = [gast.Constant(s, kind=None) for s in str_keys]
120    values = list(values)  # ast and gast don't play well with tuples.
121    return gast.Dict(keys, values)
122
123  def _create_undefined_assigns(self, undefined_symbols):
124    assignments = []
125    for s in undefined_symbols:
126      template = '''
127        var = ag__.Undefined(symbol_name)
128      '''
129      assignments += templates.replace(
130          template,
131          var=s,
132          symbol_name=gast.Constant(s.ssf(), kind=None))
133    return assignments
134
135  def _get_block_basic_vars(self, modified, live_in, live_out):
136    nonlocals = self.state[_Function].scope.nonlocals
137    basic_scope_vars = []
138    for s in modified:
139      if s.is_composite():
140        # TODO(mdan): Raise an error when this happens for a TF scope.
141        continue
142      # Variables not live into or out of the scope are considered local to the
143      # scope.
144      if s in live_in or s in live_out or s in nonlocals:
145        basic_scope_vars.append(s)
146      continue
147    return frozenset(basic_scope_vars)
148
149  def _get_block_composite_vars(self, modified, live_in):
150    # The scope variables corresponding to composite symbols (e.g. `self.x`).
151    composite_scope_vars = []
152    for s in modified:
153      if not s.is_composite():
154        continue
155      # Mutations made to objects created inside the scope will appear as writes
156      # to composite symbols. Because these mutations appear as modifications
157      # made to composite symbols, we check whether the composite's parent is
158      # actually live into the scope.
159      # Example:
160      #   while cond:
161      #     x = Foo()
162      #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
163      #
164      # Note that some parents might not be symbols - for example, in x['foo'],
165      # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
166      # liveness of literals.
167      support_set_symbols = tuple(
168          sss for sss in s.support_set if sss.is_symbol())
169      if not all(sss in live_in for sss in support_set_symbols):
170        continue
171      composite_scope_vars.append(s)
172    return frozenset(composite_scope_vars)
173
174  def _get_block_vars(self, node, modified):
175    """Determines the variables affected inside a control flow statement."""
176    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
177    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
178    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
179    fn_scope = self.state[_Function].scope
180
181    basic_scope_vars = self._get_block_basic_vars(
182        modified,
183        live_in,
184        live_out)
185    composite_scope_vars = self._get_block_composite_vars(modified, live_in)
186    scope_vars = tuple(basic_scope_vars | composite_scope_vars)
187
188    # Variables that are modified inside the scope, but not defined
189    # before entering it. Only simple variables must be defined. The
190    # composite ones will be implicitly checked at runtime.
191    possibly_undefined = (
192        modified - defined_in - fn_scope.globals - fn_scope.nonlocals)
193    undefined = tuple(v for v in possibly_undefined if not v.is_composite())
194
195    # Variables that are modified inside the scope, and depend on values outside
196    # it.
197    input_only = basic_scope_vars & live_in - live_out
198
199    # Place the outputs first, then sort lexicographically.
200    scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
201    nouts = len(scope_vars) - len(input_only)
202
203    return scope_vars, undefined, nouts
204
205  def visit_If(self, node):
206    node = self.generic_visit(node)
207    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
208    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
209
210    cond_vars, undefined, nouts = self._get_block_vars(
211        node, body_scope.bound | orelse_scope.bound)
212
213    undefined_assigns = self._create_undefined_assigns(undefined)
214
215    nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)
216
217    reserved = body_scope.referenced | orelse_scope.referenced
218    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
219    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
220    state_functions = self._create_state_functions(
221        cond_vars, nonlocal_declarations, state_getter_name, state_setter_name)
222
223    orelse_body = node.orelse
224    if not orelse_body:
225      orelse_body = [gast.Pass()]
226
227    template = """
228      state_functions
229      def body_name():
230        nonlocal_declarations
231        body
232      def orelse_name():
233        nonlocal_declarations
234        orelse
235      undefined_assigns
236      ag__.if_stmt(
237        test,
238        body_name,
239        orelse_name,
240        state_getter_name,
241        state_setter_name,
242        (symbol_names,),
243        nouts)
244    """
245    new_nodes = templates.replace(
246        template,
247        body=node.body,
248        body_name=self.ctx.namer.new_symbol('if_body', reserved),
249        orelse=orelse_body,
250        orelse_name=self.ctx.namer.new_symbol('else_body', reserved),
251        nonlocal_declarations=nonlocal_declarations,
252        nouts=gast.Constant(nouts, kind=None),
253        state_functions=state_functions,
254        state_getter_name=state_getter_name,
255        state_setter_name=state_setter_name,
256        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
257        test=node.test,
258        undefined_assigns=undefined_assigns)
259    origin_info.copy_origin(node, new_nodes[-1])
260    return new_nodes
261
262  def visit_While(self, node):
263    node = self.generic_visit(node)
264    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
265
266    loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)
267
268    undefined_assigns = self._create_undefined_assigns(undefined)
269
270    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)
271
272    reserved = body_scope.referenced
273    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
274    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
275    state_functions = self._create_state_functions(
276        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)
277
278    opts = self._create_loop_options(node)
279
280    template = """
281      state_functions
282      def body_name():
283        nonlocal_declarations
284        body
285      def test_name():
286        return test
287      undefined_assigns
288      ag__.while_stmt(
289          test_name,
290          body_name,
291          state_getter_name,
292          state_setter_name,
293          (symbol_names,),
294          opts)
295    """
296    new_nodes = templates.replace(
297        template,
298        body=node.body,
299        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
300        nonlocal_declarations=nonlocal_declarations,
301        opts=opts,
302        state_functions=state_functions,
303        state_getter_name=state_getter_name,
304        state_setter_name=state_setter_name,
305        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
306        test=node.test,
307        test_name=self.ctx.namer.new_symbol('loop_test', reserved),
308        undefined_assigns=undefined_assigns)
309    origin_info.copy_origin(node, new_nodes[-1])
310    return new_nodes
311
  def visit_For(self, node):
    """Rewrites a `for` loop into a call to `ag__.for_stmt`.

    The loop body is wrapped in a generated function that takes the current
    iterate as its single argument and assigns it to the original loop
    target. The generated code also contains the state getter/setter
    functions, an optional extra-test function and `ag__.Undefined`
    assignments for symbols that may be undefined on entry.

    Args:
      node: The `for` node to convert.

    Returns:
      The list of nodes that replaces the loop. The origin info of `node` is
      copied onto the last one (the `ag__.for_stmt` call).
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

    # Symbols bound by the loop target count as modified alongside those bound
    # in the body. The third return value (number of outputs) is unused here.
    loop_vars, undefined, _ = self._get_block_vars(
        node, body_scope.bound | iter_scope.bound)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    # Symbols referenced by the body or the target are passed to the namer as
    # reserved names for every generated symbol below.
    reserved = body_scope.referenced | iter_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    # Extend the loop options in place with the source text of the loop target
    # under the 'iterate_names' key.
    opts = self._create_loop_options(node)
    opts.keys.append(gast.Constant('iterate_names', kind=None))
    opts.values.append(gast.Constant(
        parser.unparse(node.target, include_encoding_marker=False), kind=None))

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
      # The annotated extra test expression is wrapped in its own function.
      extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
      extra_test_name = self.ctx.namer.new_symbol(
          'extra_test', reserved)
      template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
      # NOTE(review): `loop_vars` is passed as a replacement although the
      # template above has no `loop_vars` placeholder.
      extra_test_function = templates.replace(
          template,
          extra_test_expr=extra_test,
          extra_test_name=extra_test_name,
          loop_vars=loop_vars,
          nonlocal_declarations=nonlocal_declarations)
    else:
      # No extra test: pass the expression `None` in its place and emit no
      # function definition.
      extra_test_name = parser.parse_expression('None')
      extra_test_function = []

    # iterate_arg_name holds a single arg with the iterates, which may be a
    # tuple.
    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
    # Assign the body function's argument to the original loop target.
    template = """
      iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        template, iterate_arg_name=iterate_arg_name, iterates=node.target)
    origin_info.copy_origin(node, iterate_expansion)

    template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    new_nodes = templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        undefined_assigns=undefined_assigns)
    origin_info.copy_origin(node, new_nodes[-1])
    return new_nodes
399
400
class AnnotatedDef(reaching_definitions.Definition):
  """A reaching-definitions `Definition` that also carries a directives dict."""

  def __init__(self):
    super(AnnotatedDef, self).__init__()
    # Starts out empty; nothing in this class populates it.
    self.directives = dict()
406
407
def transform(node, ctx):
  """Runs the static analyses and applies `ControlFlowTransformer` to `node`.

  The control flow graphs are built from the incoming node first. Qualified
  name and activity resolution follow, then the graph-based analyses
  (reaching definitions, reaching function definitions, liveness) in that
  order.

  Args:
    node: The AST node to convert.
    ctx: The context passed to each analysis and to the transformer.

  Returns:
    The node returned by `ControlFlowTransformer(ctx).visit`.
  """
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    node = analysis.resolve(node, ctx, graphs)
  return ControlFlowTransformer(ctx).visit(node)
418