# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type resolution.

This analyzer uses known live values to further infer object types. This
may include, for instance, constructed objects and object member functions.

In addition, the analyzer handles user annotations made in the code (for
example, the autograph.set_element_type function).

Requires annotations generated by LiveValuesResolver.
"""

# TODO(mdan): This would be more robust with a CFG.
# Situations with multiple reaching modifications (e.g. modified inside and
# outside a control flow statement) should be more robustly detected and
# analyzed.

# TODO(mdan): Look into using Python AST's type annotation fields instead.
# It would be desirable to use that mechanism if we can.
# Some caveats to consider: We may need to annotate other nodes like
# Attribute. It may also not be feasible for us to faithfully replicate
# PY3's type annotations where they aren't available. It would also require
# us to design rigorous type definitions that can accommodate Python types
# as well as TensorFlow dtypes and shapes.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect


# TODO(mdan): Remove the duplication between this and activity.py.
# In particular, the symbol definitions we track here could as well be tracked
# there because they follow the same rules for visibility.
# TODO(mdan): Use a CFG based Defined analysis instead.
class Scope(object):
  """Tracks symbol value references.

  Scopes form a chain through `parent`; lookups that miss in this scope fall
  back to the enclosing scopes.

  Attributes:
    parent: The enclosing Scope, or None for the outermost scope.
    values: A dict mapping a symbol key (in this module, the anno.Basic.QN
      annotation of the assigned node) to the gast.Node holding the value that
      was most recently assigned to the symbol in this scope.
  """

  def __init__(self, parent):
    """Create a new scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    return 'Scope[%s]' % self.values.keys()

  def copy(self):
    """Returns a new Scope with the same parent and a shallow copy of values."""
    s = Scope(self.parent)
    s.values = self.values.copy()
    return s

  def setval(self, name, value):
    """Records `value` as the latest value of `name` in this scope only."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` is defined in this scope or any ancestor."""
    return (name in self.values or
            (self.parent is not None and self.parent.hasval(name)))

  def getval(self, name):
    """Returns the innermost value recorded for `name`.

    Raises:
      KeyError: if `name` is not defined in this scope or any ancestor.
    """
    if name in self.values:
      return self.values[name]
    if self.parent is not None:
      return self.parent.getval(name)
    raise KeyError(name)


class TypeInfoResolver(transformer.Base):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
    * Call (helps detect class constructors)
    * Attribute (helps resolve object methods)

  Call nodes that construct a non-builtin class receive 'is_constructor',
  'type' and 'type_fqn'. Arguments of the top-level function receive 'type'
  and 'type_fqn' when the context's `info.arg_types` lists them. Names read in
  Load context get the 'type', 'type_fqn', 'element_type' and 'element_shape'
  annotations copied (via anno.copyanno) from the value most recently
  assigned to them.
  """

  def __init__(self, context):
    super(TypeInfoResolver, self).__init__(context)
    # Outermost scope; functions and control-flow blocks push child scopes.
    self.scope = Scope(None)

  def visit_FunctionDef(self, node):
    """Visits a function body inside a fresh child scope."""
    self.scope = Scope(self.scope)
    node = self.generic_visit(node)
    self.scope = self.scope.parent
    return node

  def _visit_block(self, block):
    """Visits a list of statements inside a fresh child scope.

    Args:
      block: A list of statement nodes (possibly empty).

    Returns:
      The transformed list of statements.
    """
    self.scope = Scope(self.scope)
    block = self.visit_block(block)
    self.scope = self.scope.parent
    return block

  def visit_For(self, node):
    """Visits a for loop; its body and else clause get their own scopes."""
    # Dispatch through `visit` (not `generic_visit`) so that the target and
    # iterator nodes themselves are processed. With `generic_visit`, a bare
    # Name iterator (e.g. `for x in a`) would never reach visit_Name and would
    # miss its type annotations.
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits a while loop; its body and else clause get their own scopes."""
    # See visit_For: a bare Name test (e.g. `while cond`) must reach
    # visit_Name itself.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits an if statement; each branch gets its own scope."""
    # See visit_For: a bare Name test (e.g. `if cond`) must reach visit_Name
    # itself.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def _process_function_arg(self, arg_node):
    """Registers a function argument in the current scope.

    For arguments of the outermost function being converted, type
    information from the context's `info.arg_types` is attached to the node.

    Args:
      arg_node: The AST node of the argument. It must carry the
        anno.Basic.QN annotation.
    """
    qn = anno.getanno(arg_node, anno.Basic.QN)
    arg_name = str(qn)
    self.scope.setval(qn, arg_node)
    if (len(self.enclosing_entities) == 1 and
        arg_name in self.ctx.info.arg_types):
      # Forge a node to hold the type information, so that method calls on
      # it can resolve the type.
      type_string, type_obj = self.ctx.info.arg_types[arg_name]
      anno.setanno(arg_node, 'type', type_obj)
      anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))

  def visit_arg(self, node):
    """Handles argument nodes of `arg` type."""
    # Pass the node itself. _process_function_arg reads and writes
    # annotations on its argument, which the bare name string `node.arg`
    # cannot carry.
    self._process_function_arg(node)
    return node

  def visit_Name(self, node):
    """Registers parameters and propagates annotations onto symbol reads."""
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Param):
      self._process_function_arg(node)
    elif isinstance(node.ctx, gast.Load):
      qn = anno.getanno(node, anno.Basic.QN)
      if self.scope.hasval(qn):
        # E.g. if we had
        # a = b
        # then for future references to `a` we should have definition = `b`
        definition = self.scope.getval(qn)
        anno.copyanno(definition, node, 'type')
        anno.copyanno(definition, node, 'type_fqn')

        # TODO(mdan): Remove this when the directives module is in.
        anno.copyanno(definition, node, 'element_type')
        anno.copyanno(definition, node, 'element_shape')
    return node

  def _process_variable_assignment(self, target, value):
    """Records a single `target = value` assignment.

    If `value` is a call to a non-builtin class (according to its 'live_val'
    annotation), the call is annotated as a constructor with the class and
    its 'fqn' annotation. Name and Attribute targets are then bound to
    `value` in the current scope. Subscript targets are ignored.

    Args:
      target: The assignment target node.
      value: The node of the value assigned to `target`.

    Raises:
      ValueError: if `target` is not a Name, Attribute or Subscript.
    """
    # Constructors
    if isinstance(value, gast.Call):
      func = value.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if (tf_inspect.isclass(func_obj) and
            not inspect_utils.isbuiltin(func_obj)):
          anno.setanno(value, 'is_constructor', True)
          anno.setanno(value, 'type', func_obj)
          anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    if isinstance(target, (gast.Name, gast.Attribute)):
      target_symbol = anno.getanno(target, anno.Basic.QN)
      self.scope.setval(target_symbol, value)
    elif isinstance(target, gast.Subscript):
      pass
    else:
      raise ValueError('assignment target has unknown type: %s' % target)

  def visit_With(self, node):
    """Binds `as` targets to their context expressions before the body."""
    for item in node.items:
      if item.optional_vars is not None:
        ast_util.apply_to_single_assignments((item.optional_vars,),
                                             item.context_expr,
                                             self._process_variable_assignment)
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    """Visits the assignment, then records each target/value pair."""
    self.generic_visit(node)
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._process_variable_assignment)
    return node


def resolve(node, context):
  """Runs TypeInfoResolver over `node`.

  Args:
    node: The AST to annotate.
    context: The transformer context handed to transformer.Base; its
      `info.arg_types` is consulted for argument types.

  Returns:
    The transformed node.
  """
  return TypeInfoResolver(context).visit(node)