# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter for list operations.

This includes converting Python lists to TensorArray/TensorList.
"""

# TODO(mdan): Elaborate the logic here.
# TODO(mdan): Does it even make sense to attempt to try to use TAs?
# The current rule (always convert to TensorArray) is naive and insufficient.
# In general, a better mechanism could look like:
#   * convert to TensorList by default
#   * leave as Python list if the user explicitly forbids it
#   * convert to TensorArray only when complete write once behavior can be
#     guaranteed (e.g.
#     list comprehensions)

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno


class _Statement(object):
  """Transformer state accessed as `self.state[_Statement]`."""

  def __init__(self):
    # List of (pop call node, replacement variable name) pairs appended by
    # ListTransformer._replace_pop_call; stays None until the first pop() call
    # is recorded.
    self.pop_uses = None


class ListTransformer(converter.Base):
  """Converts lists and related operations to their TF counterpart."""

  def _element_type_directive(self, target, arg_name):
    """Reads one argument of the `set_element_type` directive for `target`.

    Args:
      target: node handed to `get_definition_directive`.
      arg_name: name of the directive argument to read, e.g. 'dtype' or
        'shape'.

    Returns:
      Whatever `get_definition_directive` returns; the default passed to it is
      an expression built from the source string 'None'.
    """
    return self.get_definition_directive(
        target,
        directives.set_element_type,
        arg_name,
        default=templates.replace_as_expression('None'))

  def visit_List(self, node):
    """Wraps a list literal in a call to `ag__.new_list`."""
    node = self.generic_visit(node)
    template = """
      ag__.new_list(elements)
    """
    return templates.replace_as_expression(template, elements=node)

  def _replace_append_call(self, node):
    """Rewrites `target.append(element)` as an `ag__.list_append` assignment.

    Args:
      node: call node whose `func` is an attribute access and which has
        exactly one positional argument (both are asserted).

    Returns:
      The result of expanding the assignment template with `templates.replace`.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(
        template,
        target=node.func.value,
        element=node.args[0])

  def _replace_pop_call(self, node):
    """Swaps a `pop` call for a fresh variable and records it for later.

    Args:
      node: call node whose `func` is an attribute access (asserted).

    Returns:
      An expression referencing the newly generated variable name.
    """
    # Expressions that use pop() are converted to a statement + expression.
    #
    # For example:
    #
    #   print(target.pop())
    #
    # ... is converted to:
    #
    #   target, target_pop = ag__.list_pop(target)
    #   print(target_pop)
    #
    # Here, we just generate the variable name and swap it in,
    # and _generate_pop_operation will handle the rest.
    #
    # Multiple uses of pop() are allowed:
    #
    #   print(target.pop(), target.pop())
    #   print(target.pop().pop())
    #
    assert isinstance(node.func, gast.Attribute)
    scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    target_node = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    if anno.hasanno(target_node, anno.Basic.QN):
      target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
    else:
      target_name = 'list_'
    pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

    # Remember the call so that _postprocess_statement can emit the matching
    # list_pop statement.
    stmt = self.state[_Statement]
    if stmt.pop_uses is None:
      stmt.pop_uses = []
    stmt.pop_uses.append((node, pop_var_name))

    return templates.replace_as_expression('var_name', var_name=pop_var_name)

  def _replace_stack_call(self, node):
    """Rewrites a one-argument `stack` call as a call to `ag__.list_stack`.

    Args:
      node: call node with exactly one positional argument (asserted).

    Returns:
      The replacement expression. The original callee is forwarded to the
      generated call as `original_call`.
    """
    assert len(node.args) == 1
    dtype = self._element_type_directive(node.args[0], 'dtype')
    template = """
      ag__.list_stack(
          target,
          opts=ag__.ListStackOpts(
              element_dtype=dtype,
              original_call=orig_call))
    """
    return templates.replace_as_expression(
        template,
        dtype=dtype,
        target=node.args[0],
        orig_call=node.func)

  def visit_Call(self, node):
    """Dispatches `append`, `pop` and `stack` attribute calls to a rewrite.

    Calls are matched on the attribute name and argument count only; any other
    call is returned as produced by `generic_visit`.
    """
    node = self.generic_visit(node)

    # TODO(mdan): This is insufficient if target is a function argument.
    # In the case of function arguments, we need to add the list to the
    # function's return value, because it is being modified.
    # TODO(mdan): Checking just the name is brittle, can it be improved?
    if isinstance(node.func, gast.Attribute):
      func_name = node.func.attr
      if func_name == 'append' and (len(node.args) == 1):
        node = self._replace_append_call(node)
      elif func_name == 'pop' and (len(node.args) <= 1):
        node = self._replace_pop_call(node)
      elif (func_name == 'stack' and (len(node.args) == 1) and
            (not node.keywords or node.keywords[0].arg == 'strict')):
        # This avoids false positives with keyword args.
        # TODO(mdan): handle kwargs properly.
        node = self._replace_stack_call(node)

    return node

  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `ag__.list_pop` assignment for a recorded pop() call.

    Args:
      original_call_node: the pop call node recorded by `_replace_pop_call`;
        its `func` must be an attribute access (asserted).
      pop_var_name: name of the variable that receives the popped value.

    Returns:
      The result of expanding the assignment template with `templates.replace`.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    # pop() may be called with or without an argument; None stands in for a
    # missing one.
    if original_call_node.args:
      pop_element = original_call_node.args[0]
    else:
      pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype = self._element_type_directive(original_call_node.func.value, 'dtype')
    shape = self._element_type_directive(original_call_node.func.value, 'shape')

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)

  def _postprocess_statement(self, node):
    """Inserts any separate pop() calls that node may use.

    Args:
      node: the statement that has just been visited.

    Returns:
      A `(node, None)` tuple. When pop() uses were recorded, `node` is a list
      holding the generated pop statements followed by the original statement.
    """
    pop_uses = self.state[_Statement].pop_uses
    if pop_uses:
      replacements = []
      for original_call_node, pop_var_name in pop_uses:
        replacements.extend(
            self._generate_pop_operation(original_call_node, pop_var_name))
      # The statement itself goes last, after the pops it depends on.
      replacements.append(node)
      node = replacements
    self.state[_Statement].exit()
    return node, None

  def _visit_and_process_block(self, block):
    """Visits `block` with `_Statement` enter/postprocess hooks installed."""
    return self.visit_block(
        block,
        before_visit=self.state[_Statement].enter,
        after_visit=self._postprocess_statement)

  def visit_FunctionDef(self, node):
    """Visits the arguments, decorators and body of a function definition."""
    node.args = self.generic_visit(node.args)
    node.decorator_list = self.visit_block(node.decorator_list)
    node.body = self._visit_and_process_block(node.body)
    return node

  def visit_For(self, node):
    """Visits the target, body and else-block of a `for` loop."""
    node.target = self.visit(node.target)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits the test, body and else-block of a `while` loop."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits the test, body and else-block of an `if` statement."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_With(self, node):
    """Visits the items and body of a `with` statement."""
    node.items = self.visit_block(node.items)
    node.body = self._visit_and_process_block(node.body)
    return node


def transform(node, ctx):
  """Applies `ListTransformer` to `node`.

  Qualified names and activity information are resolved on the tree before the
  transformer runs.

  Args:
    node: root node of the tree to convert.
    ctx: context object passed to `activity.resolve` and `ListTransformer`.

  Returns:
    The value returned by `ListTransformer(ctx).visit(node)`.
  """
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)

  return ListTransformer(ctx).visit(node)