• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Handles control flow statements: while, for, if."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gast
22
23from tensorflow.python.autograph.core import converter
24from tensorflow.python.autograph.pyct import anno
25from tensorflow.python.autograph.pyct import ast_util
26from tensorflow.python.autograph.pyct import templates
27from tensorflow.python.autograph.pyct.static_analysis import annos
28
29
class SymbolNamer(object):
  """Interface that ControlFlowTransformer expects its namer to implement."""

  def new_symbol(self, name_root, reserved_locals):
    """Produces a fresh, unique symbol name.

    Args:
      name_root: String, the stem from which the new name is derived.
      reserved_locals: Set(string), extra local symbol names that are already
          taken and must not be returned.

    Returns:
      String.

    Raises:
      NotImplementedError: always; this base implementation is only a stub.
    """
    raise NotImplementedError
44
45
46class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals."""
48
49  def _create_cond_branch(self, body_name, aliased_orig_names,
50                          aliased_new_names, body, returns):
51    if not returns:
52      # TODO(b/110167197): Replace with a plain return.
53      template = """
54        return 1
55      """
56      return_stmt = templates.replace(template)
57    elif len(returns) == 1:
58      template = """
59        return retval
60      """
61      return_stmt = templates.replace(template, retval=returns[0])
62    else:
63      template = """
64        return (retvals,)
65      """
66      return_stmt = templates.replace(template, retvals=returns)
67
68    if aliased_orig_names:
69      template = """
70        def body_name():
71          aliased_new_names, = aliased_orig_names,
72          body
73          return_stmt
74      """
75      return templates.replace(
76          template,
77          body_name=body_name,
78          body=body,
79          aliased_orig_names=aliased_orig_names,
80          aliased_new_names=aliased_new_names,
81          return_stmt=return_stmt)
82    else:
83      template = """
84        def body_name():
85          body
86          return_stmt
87      """
88      return templates.replace(
89          template, body_name=body_name, body=body, return_stmt=return_stmt)
90
91  def _create_cond_expr(self, results, test, body_name, orelse_name,
92                        state_getter_name,
93                        state_setter_name):
94    if results is not None:
95      template = """
96        results = ag__.if_stmt(test, body_name, orelse_name,
97                               state_getter_name, state_setter_name)
98      """
99      return templates.replace(
100          template,
101          test=test,
102          results=results,
103          body_name=body_name,
104          orelse_name=orelse_name,
105          state_getter_name=state_getter_name,
106          state_setter_name=state_setter_name)
107    else:
108      template = """
109        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
110      """
111      return templates.replace(
112          template,
113          test=test,
114          body_name=body_name,
115          orelse_name=orelse_name,
116          getter_name=state_getter_name,
117          setter_name=state_setter_name)
118
119  def _fmt_symbols(self, symbol_set):
120    if not symbol_set:
121      return 'no variables'
122    return ', '.join(map(str, symbol_set))
123
124  def _determine_aliased_symbols(self, scope, node_defined_in, block):
125    if block:
126      block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
127    else:
128      block_live_in = set()
129
130    # For the purpose of aliasing, composite symbols with live owners are live
131    # as well. Otherwise this would leak tensors from the conditional's body.
132    #
133    # For example:
134    #
135    #   obj = some_obj
136    #   if cond:
137    #     obj.a = val
138    #
139    # Thanslating to the code below would be incorrect:
140    #
141    #   def true_fn():
142    #     obj.a = val()  # Wrong! leaks ops owned by true_fn
143    #     return obj.a
144    for s in scope.modified:
145      if s.is_composite():
146        live_parents = block_live_in & s.owner_set
147        if live_parents:
148          block_live_in.add(s)
149    return scope.modified & node_defined_in & block_live_in
150
151  def _create_state_functions(self, composites,
152                              state_getter_name, state_setter_name):
153    if composites:
154      composite_tuple = tuple(composites)
155      template = """
156        def state_getter_name():
157          return composite_tuple,
158        def state_setter_name(vals):
159          composite_tuple, = vals
160      """
161      node = templates.replace(
162          template,
163          state_getter_name=state_getter_name,
164          state_setter_name=state_setter_name,
165          composite_tuple=composite_tuple)
166    else:
167      template = """
168        def state_getter_name():
169          return ()
170        def state_setter_name(_):
171          pass
172        """
173      node = templates.replace(
174          template,
175          state_getter_name=state_getter_name,
176          state_setter_name=state_setter_name)
177
178    return node
179
180  def _create_undefined_assigns(self, undefined_symbols):
181    assignments = []
182    for s in undefined_symbols:
183      template = '''
184        var = ag__.Undefined(symbol_name)
185      '''
186      assignments += templates.replace(
187          template,
188          var=s,
189          symbol_name=gast.Str(s.ssf()))
190    return assignments
191
  def visit_If(self, node):
    """Replaces an `if` statement with a call to ag__.if_stmt.

    Each branch is wrapped in its own function (see _create_cond_branch), the
    test is assigned to a fresh variable, and state getter/setter functions
    are generated for the composite symbols modified by either branch.

    Args:
      node: the `if` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
          DEFINED_VARS_IN and LIVE_VARS_OUT annotations read below.

    Returns:
      The concatenation (via +) of: the assignments for possibly undefined
      symbols, the assignment of the test, the state getter/setter
      definitions, the two branch functions, and the if_stmt call.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # Collect the symbols the conditional must return: those modified in
    # either branch that are live afterwards, plus all modified composites.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        returned_from_cond.add(s)
        composites.add(s)

    # Note on precedence: `-` binds tighter than `&`, so each expression reads
    # as `modified & (returned_from_cond - defined_in)`, i.e. the returned
    # symbols that the branch modifies and that were not defined beforehand.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    # Rewrite each branch so that it refers to the aliases rather than to the
    # original symbols.
    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Fresh names for the generated test variable and helper functions.
    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    # Freeze the iteration order: the same order is used for the assignment
    # target and for the values returned by both branches.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Inside a branch, an aliased symbol is returned under its alias.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    # NOTE(review): create_assignment is not defined in this class; presumably
    # it is inherited from converter.Base - confirm.
    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name)

    return (undefined_assigns + cond_assign + composite_defs + body_def +
            orelse_def + cond_expr)
314
315  def _get_loop_state(self, node):
316    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
317    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
318    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
319    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
320    reserved_symbols = body_scope.referenced
321
322    loop_state = []
323    for s in body_scope.modified:
324
325      # Variables not live into or out of the loop are considered local to the
326      # loop.
327      if s not in live_in and s not in live_out:
328        continue
329
330      # Mutations made to objects created inside the loop will appear as writes
331      # to composite symbols. Because these mutations appear as modifications
332      # made to composite symbols, we check whether the composite's parent is
333      # actually live into the loop.
334      # Example:
335      #   while cond:
336      #     x = Foo()
337      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
338      if s.is_composite() and not all(p in live_in for p in s.support_set):
339        continue
340
341      loop_state.append(s)
342    loop_state = frozenset(loop_state)
343
344    # Variable that are used or defined inside the loop, but not defined
345    # before entering the loop
346    undefined_lives = loop_state - defined_in
347
348    # Only simple variables must be defined. The composite ones will be
349    # implicitly checked at runtime.
350    possibly_undefs = {v for v in undefined_lives if v.is_simple()}
351
352    return loop_state, reserved_symbols, possibly_undefs
353
354  def _state_constructs(self, loop_state, reserved_symbols):
355    loop_state = tuple(loop_state)
356    state_ssf = [
357        self.ctx.namer.new_symbol(s.ssf(), reserved_symbols) for s in loop_state
358    ]
359    ssf_map = {
360        name: ssf
361        for name, ssf in zip(loop_state, state_ssf)
362        if str(name) != ssf
363    }
364
365    state_ast_tuple = gast.Tuple([n.ast() for n in loop_state], None)
366
367    if len(loop_state) == 1:
368      loop_state = loop_state[0]
369      state_ssf = state_ssf[0]
370
371    return loop_state, state_ssf, state_ast_tuple, ssf_map
372
373  def visit_While(self, node):
374    self.generic_visit(node)
375
376    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
377
378    # Note: one might expect we can dispatch based on the loop condition.
379    # But because that is dependent on the state, it cannot be evaluated ahead
380    # of time - doing that would risk duplicating any effects the condition has.
381    # Furthermore, we cannot evaluate slices and attributes, because they might
382    # trigger __getitem__ or __getattribute__.
383    #
384    # A case where this fails includes ops with side effects on a stateful
385    # resource captured in an object:
386    #
387    #   while self.v.read() > 0:
388    #     self.v.assign(1)
389    #
390    # TODO(mdan): Handle the case above.
391    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
392    cond_closure = set()
393    for s in cond_scope.read:
394      cond_closure |= s.support_set
395
396    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
397        loop_state, reserved_symbols)
398    node_body = ast_util.rename_symbols(node.body, ssf_map)
399    test = ast_util.rename_symbols(node.test, ssf_map)
400
401    if loop_state:
402      template = """
403        def test_name(state_ssf):
404          return test
405        def body_name(state_ssf):
406          body
407          return state_ssf,
408        state_ast_tuple = ag__.while_stmt(
409            test_name, body_name, (state,), (extra_deps,))
410      """
411      node = templates.replace(
412          template,
413          state=loop_state,
414          state_ssf=state_ssf,
415          state_ast_tuple=state_ast_tuple,
416          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
417          test=test,
418          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
419          body=node_body,
420          extra_deps=tuple(s.ast() for s in cond_closure),
421      )
422    else:
423      template = """
424        def test_name():
425          return test
426        def body_name():
427          body
428          return ()
429        ag__.while_stmt(test_name, body_name, (), (extra_deps,))
430      """
431      node = templates.replace(
432          template,
433          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
434          test=test,
435          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
436          body=node_body,
437          extra_deps=tuple(s.ast() for s in cond_closure),
438      )
439
440    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
441    return undefined_assigns + node
442
443  def _create_for_loop_early_stopping(self, loop_state, state_ssf,
444                                      state_ast_tuple, original_node,
445                                      extra_test_name, extra_test,
446                                      body_name, loop_body):
447    """Create node for for-loop with early stopping (e.g. break or return)."""
448    template = """
449      def extra_test_name(state_ssf):
450        return extra_test_expr
451      def body_name(loop_vars, state_ssf):
452        # Workaround for PEP-3113
453        iterate = loop_vars
454        body
455        return state_ssf,
456      state_ast_tuple = ag__.for_stmt(
457          iter_, extra_test_name, body_name, (state,))
458    """
459    return templates.replace(
460        template,
461        state=loop_state,
462        state_ssf=state_ssf,
463        state_ast_tuple=state_ast_tuple,
464        iter_=original_node.iter,
465        iterate=original_node.target,
466        extra_test_name=extra_test_name,
467        extra_test_expr=extra_test,
468        body_name=body_name,
469        body=loop_body)
470
471  def _create_for_loop_with_state(self, loop_state, state_ssf, state_ast_tuple,
472                                  original_node, body_name, loop_body):
473    """Create node for for-loop with loop-carried state, no early stopping."""
474    template = """
475      def body_name(loop_vars, state_ssf):
476        # Workaround for PEP-3113
477        iterate = loop_vars
478        body
479        return state_ssf,
480      state_ast_tuple = ag__.for_stmt(
481          iter_, None, body_name, (state,))
482    """
483    return templates.replace(
484        template,
485        state=loop_state,
486        state_ssf=state_ssf,
487        state_ast_tuple=state_ast_tuple,
488        iter_=original_node.iter,
489        iterate=original_node.target,
490        body_name=body_name,
491        body=loop_body)
492
493  def _create_for_loop_without_state(self, original_node, body_name, loop_body):
494    """Create node for for-loop with loop-carried state, no early stopping."""
495    template = """
496      def body_name(loop_vars):
497        # Workaround for PEP-3113
498        iterate = loop_vars
499        body
500        return ()
501      ag__.for_stmt(iter_, None, body_name, ())
502    """
503    return templates.replace(
504        template,
505        iter_=original_node.iter,
506        iterate=original_node.target,
507        body_name=body_name,
508        body=loop_body)
509
510  def visit_For(self, node):
511    self.generic_visit(node)
512
513    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
514    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
515        loop_state, reserved_symbols)
516    node_body = ast_util.rename_symbols(node.body, ssf_map)
517    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
518
519    has_extra_test = anno.hasanno(node, 'extra_test')
520    if loop_state:
521      if has_extra_test:
522        # Loop with early stopping (e.g. break or return)
523        extra_test = anno.getanno(node, 'extra_test')
524        extra_test = ast_util.rename_symbols(extra_test, ssf_map)
525        extra_test_name = self.ctx.namer.new_symbol('extra_test',
526                                                    reserved_symbols)
527        node = self._create_for_loop_early_stopping(
528            loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
529            extra_test, body_name, node_body)
530      else:
531        # Loop with loop-carried state and no early stopping
532        node = self._create_for_loop_with_state(
533            loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
534    else:
535      # Loop with no loop-carried state and no early stopping
536      assert not has_extra_test, ('Early stoppiong (e.g. break and/or return) '
537                                  'should create state variables.')
538      node = self._create_for_loop_without_state(node, body_name, node_body)
539
540    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
541    return undefined_assigns + node
542
543
def transform(node, ctx):
  """Runs ControlFlowTransformer over the given node.

  Args:
    node: the node to visit.
    ctx: passed to the ControlFlowTransformer constructor.

  Returns:
    Whatever ControlFlowTransformer.visit returns for node.
  """
  return ControlFlowTransformer(ctx).visit(node)
547