• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Converter for list operations.
16
17This includes converting Python lists to TensorArray/TensorList.
18"""
19
# TODO(mdan): Elaborate the logic here.
# TODO(mdan): Does it even make sense to try to use TensorArrays (TAs)?
# The current rule (always convert to TensorArray) is naive and insufficient.
# In general, a better mechanism could look like:
#   * convert to TensorList by default
#   * leave as a Python list if the user explicitly forbids conversion
#   * convert to TensorArray only when complete write-once behavior can be
#     guaranteed (e.g. list comprehensions)
28
29import gast
30
31from tensorflow.python.autograph.core import converter
32from tensorflow.python.autograph.lang import directives
33from tensorflow.python.autograph.pyct import anno
34from tensorflow.python.autograph.pyct import parser
35from tensorflow.python.autograph.pyct import qual_names
36from tensorflow.python.autograph.pyct import templates
37from tensorflow.python.autograph.pyct.static_analysis import activity
38from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
39
40
41class _Statement(object):
42
43  def __init__(self):
44    self.pop_uses = None
45
46
class ListTransformer(converter.Base):
  """Converts lists and related operations to their TF counterpart.

  Rewrites performed:
    * list literals `[a, b]` become `ag__.new_list([a, b])`;
    * `target.append(x)` becomes `target = ag__.list_append(target, x)`;
    * `target.pop()` / `target.pop(i)` is replaced by a fresh variable, and a
      `target, var = ag__.list_pop(...)` statement is inserted before the
      statement that contained the call;
    * `<obj>.stack(l)` becomes
      `ag__.list_stack(l, opts=ag__.ListStackOpts(...))`.
  """

  def visit_List(self, node):
    """Wraps a list literal (after converting its elements) in new_list."""
    node = self.generic_visit(node)
    template = """
      ag__.new_list(elements)
    """
    return templates.replace_as_expression(template, elements=node)

  def _replace_append_call(self, node):
    """Rewrites `target.append(element)` as a reassignment of `target`.

    Args:
      node: a gast.Call whose func is a gast.Attribute and which has exactly
        one positional argument.

    Returns:
      The node(s) produced by templates.replace for
      `target = ag__.list_append(target, element)`.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(
        template,
        target=node.func.value,
        element=node.args[0])

  def _replace_pop_call(self, node):
    """Replaces a `target.pop(...)` expression with a fresh variable name.

    The pop itself is only recorded on the current _Statement state. It is
    emitted later by _postprocess_statement via _generate_pop_operation.

    Args:
      node: a gast.Call whose func is a gast.Attribute, with the
        NodeAnno.ARGS_SCOPE annotation set.

    Returns:
      An expression node naming the variable that will hold the popped value.
    """
    # Expressions that use pop() are converted to a statement + expression.
    #
    # For example:
    #
    #   print(target.pop())
    #
    # ... is converted to:
    #
    #   target, target_pop = ag__.list_pop(target)
    #   print(target_pop)
    #
    # Here, we just generate the variable name and swap it in,
    # and _generate_pop_operation will handle the rest.
    #
    # Multiple uses of pop() are allowed:
    #
    #   print(target.pop(), target.pop())
    #   print(target.pop().pop())
    #
    assert isinstance(node.func, gast.Attribute)
    scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    target_node = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    if anno.hasanno(target_node, anno.Basic.QN):
      target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
    else:
      target_name = 'list_'
    # The symbols referenced in the call's arguments are passed to the namer,
    # presumably so the new name does not clash with them.
    pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

    stmt = self.state[_Statement]
    if stmt.pop_uses is None:
      stmt.pop_uses = []
    stmt.pop_uses.append((node, pop_var_name))

    return templates.replace_as_expression('var_name', var_name=pop_var_name)

  def _replace_stack_call(self, node):
    """Rewrites `<obj>.stack(target)` into an `ag__.list_stack` call.

    The element dtype is taken from a `set_element_type` directive on the
    argument's definition, defaulting to None. The original callable is
    passed through as `original_call`.
    """
    assert len(node.args) == 1
    dtype = self.get_definition_directive(
        node.args[0],
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    template = """
      ag__.list_stack(
          target,
          opts=ag__.ListStackOpts(
              element_dtype=dtype,
              original_call=orig_call))
    """
    return templates.replace_as_expression(
        template,
        dtype=dtype,
        target=node.args[0],
        orig_call=node.func)

  def visit_Call(self, node):
    """Dispatches `.append`, `.pop` and `.stack` attribute calls."""
    node = self.generic_visit(node)

    # TODO(mdan): This is insufficient if target is a function argument.
    # In the case of function arguments, we need to add the list to the
    # function's return value, because it is being modified.
    # TODO(mdan): Checking just the name is brittle, can it be improved?
    if isinstance(node.func, gast.Attribute):
      func_name = node.func.attr
      if func_name == 'append' and (len(node.args) == 1):
        node = self._replace_append_call(node)
      elif func_name == 'pop' and (len(node.args) <= 1):
        node = self._replace_pop_call(node)
      elif (func_name == 'stack' and (len(node.args) == 1) and
            all(kw.arg == 'strict' for kw in (node.keywords or ()))):
        # Only `strict=` is tolerated; any other keyword (including **kwargs,
        # whose arg is None) disqualifies the call. Every keyword is checked,
        # not just the first, so that keywords the rewrite would silently
        # drop are never accepted.
        # TODO(mdan): handle kwargs properly.
        node = self._replace_stack_call(node)

    return node

  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `target, pop_var_name = ag__.list_pop(...)` statement.

    Args:
      original_call_node: the original `target.pop(...)` gast.Call.
      pop_var_name: the name generated for it by _replace_pop_call.

    Returns:
      The node(s) produced by templates.replace.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    if original_call_node.args:
      pop_element = original_call_node.args[0]
    else:
      pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    shape = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)

  def _postprocess_statement(self, node):
    """Inserts any separate pop() calls that node may use.

    The generated list_pop statements are placed before `node`, in the order
    the pop() calls were recorded. The _Statement state is then exited.

    Returns:
      A `(node_or_nodes, None)` tuple, as consumed by visit_block's
      after_visit hook. The meaning of the second element is defined by
      converter.Base (not visible here).
    """
    pop_uses = self.state[_Statement].pop_uses
    if pop_uses:
      replacements = []
      for original_call_node, pop_var_name in pop_uses:
        replacements.extend(
            self._generate_pop_operation(original_call_node, pop_var_name))
      replacements.append(node)
      node = replacements
    self.state[_Statement].exit()
    return node, None

  def _visit_and_process_block(self, block):
    """Visits a statement list, giving each statement its own _Statement."""
    return self.visit_block(
        block,
        before_visit=self.state[_Statement].enter,
        after_visit=self._postprocess_statement)

  def visit_FunctionDef(self, node):
    node.args = self.generic_visit(node.args)
    node.decorator_list = self.visit_block(node.decorator_list)
    node.body = self._visit_and_process_block(node.body)
    return node

  def visit_For(self, node):
    # The iterable is visited too (it was previously skipped), so list
    # literals, pop() and stack() calls in it are converted like anywhere
    # else. Any pop() uses found here are recorded on the enclosing statement
    # state, just as for the test of `while`/`if`. The iterable is visited
    # before the target, matching Python's evaluation order.
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_While(self, node):
    # NOTE(review): pop() calls in the loop test are recorded on the While
    # statement's own state. The hoisted list_pop is therefore emitted once,
    # before the loop, rather than on every iteration. Confirm this is
    # intended.
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_With(self, node):
    node.items = self.visit_block(node.items)
    node.body = self._visit_and_process_block(node.body)
    return node
233
234
def transform(node, ctx):
  """Converts list literals and list operations found in `node`.

  Args:
    node: the AST to convert.
    ctx: the converter context. It is forwarded to activity analysis and to
      ListTransformer.

  Returns:
    The converted AST, as returned by ListTransformer.visit.
  """
  # ListTransformer reads anno.Basic.QN and NodeAnno.ARGS_SCOPE annotations.
  # These are presumably produced by the qual_names and activity passes, so
  # both must run first, in this order.
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  return ListTransformer(ctx).visit(annotated)
240