# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type resolution.

Requires annotations generated by LiveValuesResolver.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.python.util import tf_inspect


class Scope(object):
  """Encloses symbol value references.

  Scopes are chained through `parent`: lookups that miss in this scope fall
  through to the enclosing ones, while writes only affect this scope.

  Attributes:
    parent: The enclosing Scope, or None for the outermost scope.
    values: A dict mapping qualified names (the values stored under
      anno.Basic.QN) to the gast.Node that was most recently assigned to the
      symbol in this scope.
  """

  def __init__(self, parent):
    """Create a new scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    return 'Scope[%s]' % self.values.keys()

  def copy(self):
    """Returns a shallow copy with the same parent and its own `values` dict."""
    s = Scope(self.parent)
    s.values = self.values.copy()
    return s

  def setval(self, name, value):
    """Records `value` for `name` in this scope only; parents are untouched."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` has a value in this scope or any ancestor."""
    return (name in self.values or
            (self.parent is not None and self.parent.hasval(name)))

  def getval(self, name):
    """Returns the innermost value recorded for `name`.

    Raises:
      KeyError: if `name` has no value in this scope or any ancestor.
    """
    if name in self.values:
      return self.values[name]
    if self.parent is not None:
      return self.parent.getval(name)
    raise KeyError(name)


class TypeInfoResolver(transformer.Base):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
   * Call (helps detect class constructors): receives 'is_constructor',
     'type' and 'type_fqn' when its callee's 'live_val' is a class.
   * Name (helps resolve object methods): a loaded Name receives 'type' and
     'type_fqn' when the value last assigned to it in scope carries them.
  """

  def __init__(self, context):
    super(TypeInfoResolver, self).__init__(context)
    self.scope = Scope(None)
    # Nesting depth of FunctionDefs. Only arguments of the outermost function
    # (depth 1) are looked up in context.arg_types.
    self.function_level = 0

  def visit_FunctionDef(self, node):
    self.scope = Scope(self.scope)
    self.function_level += 1
    self.generic_visit(node)
    self.function_level -= 1
    self.scope = self.scope.parent
    return node

  def _visit_block(self, block):
    """Visits a list of statements inside a fresh child scope.

    Values recorded while visiting the block are dropped when it ends.

    Args:
      block: A list of statement nodes; it is updated in place.

    Returns:
      The same list.
    """
    self.scope = Scope(self.scope)
    for i, stmt in enumerate(block):
      # Dispatch through visit() so that each statement's own handler
      # (visit_Assign, visit_With, visit_If, visit_FunctionDef, ...) runs.
      # generic_visit(stmt) would only visit the statement's children, so
      # assignments and nested functions inside blocks were never processed.
      block[i] = self.visit(stmt)
    self.scope = self.scope.parent
    return block

  def visit_For(self, node):
    # visit() rather than generic_visit() so that a bare Name used as the
    # target or iterable is itself passed to visit_Name.
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_While(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_If(self, node):
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def _process_function_arg(self, arg_name):
    """Seeds the scope with a type holder for a typed top-level argument.

    Args:
      arg_name: The argument's qualified name (anno.Basic.QN value).
    """
    str_name = str(arg_name)
    if self.function_level == 1 and str_name in self.context.arg_types:
      # Forge a node to hold the type information, so that method calls on
      # it can resolve the type.
      type_holder = arg_name.ast()
      type_string, type_obj = self.context.arg_types[str_name]
      anno.setanno(type_holder, 'type', type_obj)
      anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
      self.scope.setval(arg_name, type_holder)

  def visit_arg(self, node):
    # NOTE(review): this reads a QN annotation from `node.arg`; arguments in
    # Param context are also handled in visit_Name. Confirm which path the
    # gast version in use actually produces.
    self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
    return node

  def visit_Name(self, node):
    self.generic_visit(node)
    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Param):
      self._process_function_arg(qn)
    elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
      # E.g. if we had
      #   a = b
      # then for future references to `a` we should have traced_source = `b`
      traced_source = self.scope.getval(qn)
      if anno.hasanno(traced_source, 'type'):
        anno.setanno(node, 'type', anno.getanno(traced_source, 'type'))
        anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn'))
    return node

  def _record_assignment_target(self, target, value):
    """Records `value` in the current scope for every symbol bound by `target`.

    Tuple and list targets are unpacked recursively; each element is bound to
    a synthetic Subscript of its enclosing value, so nested destructuring such
    as `(a, b), c = f()` is handled.

    Raises:
      ValueError: if `target` is not a Name, Attribute, Tuple or List.
    """
    if isinstance(target, (gast.Tuple, gast.List)):
      for i, elt in enumerate(target.elts):
        self._record_assignment_target(
            elt, gast.Subscript(value, gast.Index(i), ctx=gast.Store()))
    elif isinstance(target, (gast.Name, gast.Attribute)):
      self.scope.setval(anno.getanno(target, anno.Basic.QN), value)
    else:
      raise ValueError('Dont know how to handle assignment to %s' % target)

  def _process_variable_assignment(self, source, targets):
    """Annotates constructor calls and records assigned values in scope.

    Args:
      source: The assigned value node.
      targets: An iterable of target nodes that `source` is assigned to.

    Raises:
      ValueError: if a target kind is not supported.
    """
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    for t in targets:
      self._record_assignment_target(t, source)

  def visit_With(self, node):
    # Bind the `as` variables before visiting, so that the body sees them.
    for wi in node.items:
      if wi.optional_vars is not None:
        self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    # Visit the right-hand side first so that it resolves against the values
    # in scope before this assignment takes effect.
    self.generic_visit(node)
    self._process_variable_assignment(node.value, node.targets)
    return node


def resolve(node, context):
  """Runs TypeInfoResolver over `node` and returns the visited node.

  Args:
    node: The gast AST to annotate.
    context: An object whose `arg_types` maps argument names (str) to
      (type_string, type_obj) pairs; type_string is a dotted name.

  Returns:
    The node returned by the visitor, with annotations set in place.
  """
  return TypeInfoResolver(context).visit(node)