• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type resolution.
16
17This analyzer uses known live values to further infer object types. This
18may include for instance constructed objects and object member functions.
19
20In addition, the analyzer also handles user annotations made in the code (for
21example, the autograph.set_element_type function).
22
23Requires annotations generated by LiveValuesResolver.
24"""
25
26# TODO(mdan): This would be more robust with a CFG.
27# Situations with multiple reaching modifications (e.g. modified inside and
28# outside a control flow statement) should be more robustly detected and
29# analyzed.
30
31# TODO(mdan): Look into using Python AST's type annotation fields instead.
32# It would be desirable to use that mechanism if we can.
33# Some caveats to consider: We may need to annotate other nodes like
34# Attribute. It may also not be feasible for us to faithfully replicate
35# PY3's type annotations where they aren't available. It would also require us
36# to design rigorous type definitions that can accommodate Python types
37# as well as TensorFlow dtypes and shapes.
38
39
40from __future__ import absolute_import
41from __future__ import division
42from __future__ import print_function
43
44import gast
45
46from tensorflow.python.autograph.pyct import anno
47from tensorflow.python.autograph.pyct import ast_util
48from tensorflow.python.autograph.pyct import inspect_utils
49from tensorflow.python.autograph.pyct import transformer
50from tensorflow.python.util import tf_inspect
51
52
53# TODO(mdan): Remove the duplication between this and activity.py.
54# In particular, the symbol definitions we track here could just as well be tracked
55# there because they follow the same rules for visibility.
56# TODO(mdan): Use a CFG based Defined analysis instead.
class Scope(object):
  """A lexical scope holding the most recent definition of each symbol.

  Scopes form a chain through `parent`. Lookups fall back to the parent chain
  when a symbol is not defined locally; writes always go to this scope.

  Attributes:
    parent: The enclosing Scope, or None for the outermost scope.
    values: A dict mapping a symbol key (in this module, a qualified name
        object) to the gast.Node most recently assigned to that symbol in
        this scope.
  """

  def __init__(self, parent):
    """Create a new, empty scope.

    Args:
      parent: A Scope or None.
    """
    self.parent = parent
    self.values = {}

  def __repr__(self):
    return 'Scope[%s]' % (self.values.keys(),)

  def copy(self):
    """Returns a shallow copy sharing the same parent but its own value dict."""
    duplicate = Scope(self.parent)
    duplicate.values = dict(self.values)
    return duplicate

  def setval(self, name, value):
    """Records `value` as the current definition of `name` in this scope."""
    self.values[name] = value

  def hasval(self, name):
    """Returns True if `name` is defined here or in any enclosing scope."""
    current = self
    while current is not None:
      if name in current.values:
        return True
      current = current.parent
    return False

  def getval(self, name):
    """Returns the innermost definition of `name`.

    Raises:
      KeyError: if no scope in the chain defines `name`.
    """
    current = self
    while current is not None:
      if name in current.values:
        return current.values[name]
      current = current.parent
    raise KeyError(name)
95
96
class TypeInfoResolver(transformer.Base):
  """Annotates symbols with type information where possible.

  Nodes currently annotated:
    * Call (helps detect class constructors): a call to a non-builtin class
      that is the assigned value of an assignment or `with ... as` item gets
      'is_constructor', 'type' and 'type_fqn'.
    * Attribute (helps resolve object methods)
    * Name: a load of a symbol with a tracked definition receives copies of
      the definition's 'type', 'type_fqn', 'element_type' and
      'element_shape' annotations.
    * Arguments of the top-level function that are listed in
      `ctx.info.arg_types` get 'type' and 'type_fqn'.

  Definitions are tracked in a chain of `Scope` objects. Function bodies and
  the bodies of `for`, `while` and `if` statements each get a child scope, so
  assignments made inside them are dropped once the statement has been
  visited.
  """

  def __init__(self, context):
    super(TypeInfoResolver, self).__init__(context)
    # Innermost active scope; enclosing scopes are reached via `.parent`.
    self.scope = Scope(None)

  def visit_FunctionDef(self, node):
    self.scope = Scope(self.scope)
    try:
      node = self.generic_visit(node)
    finally:
      # Restore the enclosing scope even if visiting the body raises.
      self.scope = self.scope.parent
    return node

  def _visit_block(self, block):
    """Visits a list of statements inside a fresh child scope.

    Definitions made inside `block` are discarded afterwards. This is a
    simplification: definitions that reach past the block are not modeled.
    """
    self.scope = Scope(self.scope)
    try:
      block = self.visit_block(block)
    finally:
      self.scope = self.scope.parent
    return block

  def visit_For(self, node):
    # The loop header is visited in the enclosing scope. `self.visit` (unlike
    # `generic_visit`) dispatches the header node itself, so e.g. a bare Name
    # used as the iterable reaches visit_Name rather than only its children.
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_While(self, node):
    # See visit_For: visit the test node itself, not just its children.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def visit_If(self, node):
    # See visit_For: visit the test node itself, not just its children.
    node.test = self.visit(node.test)
    node.body = self._visit_block(node.body)
    node.orelse = self._visit_block(node.orelse)
    return node

  def _process_function_arg(self, arg_node):
    """Records a function argument as a definition in the current scope.

    For arguments of the top-level function (exactly one enclosing entity)
    whose name appears in `self.ctx.info.arg_types`, also attaches the
    declared 'type' and 'type_fqn' annotations.

    Args:
      arg_node: The AST node for the argument. It must carry an
          anno.Basic.QN annotation.
    """
    qn = anno.getanno(arg_node, anno.Basic.QN)
    arg_name = str(qn)
    self.scope.setval(qn, arg_node)
    if (len(self.enclosing_entities) == 1 and
        arg_name in self.ctx.info.arg_types):
      # Attach the declared type to the argument node itself, so that method
      # calls on the argument can resolve the type.
      type_string, type_obj = self.ctx.info.arg_types[arg_name]
      anno.setanno(arg_node, 'type', type_obj)
      anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))

  def visit_arg(self, node):
    # Pass the node itself. `node.arg` is only the argument's name (a plain
    # string), which carries no QN annotation for _process_function_arg.
    self._process_function_arg(node)
    return node

  def visit_Name(self, node):
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Param):
      self._process_function_arg(node)
    elif isinstance(node.ctx, gast.Load):
      qn = anno.getanno(node, anno.Basic.QN)
      if self.scope.hasval(qn):
        # E.g. if we had
        # a = b
        # then for future references to `a` we should have definition = `b`
        definition = self.scope.getval(qn)
        anno.copyanno(definition, node, 'type')
        anno.copyanno(definition, node, 'type_fqn')

        # TODO(mdan): Remove this when the directives module is in.
        anno.copyanno(definition, node, 'element_type')
        anno.copyanno(definition, node, 'element_shape')
    return node

  def _process_variable_assignment(self, target, value):
    """Records a single `target = value` binding in the current scope.

    If `value` is a call to a non-builtin class, it is also annotated as a
    constructor call with that class as its type.

    Raises:
      ValueError: if `target` is of an unsupported node type.
    """
    # Constructors
    if isinstance(value, gast.Call):
      func = value.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if (tf_inspect.isclass(func_obj) and
            not inspect_utils.isbuiltin(func_obj)):
          anno.setanno(value, 'is_constructor', True)
          anno.setanno(value, 'type', func_obj)
          anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    if isinstance(target, (gast.Name, gast.Attribute)):
      target_symbol = anno.getanno(target, anno.Basic.QN)
      self.scope.setval(target_symbol, value)
    elif isinstance(target, (gast.Subscript, gast.Starred)):
      # Not tracked. Starred shows up when star-unpacking, e.g. the `*b` in
      # `a, *b = x`; treating it as unknown avoids failing the conversion.
      pass
    else:
      raise ValueError('assignment target has unknown type: %s' % target)

  def visit_With(self, node):
    # Record `with ctx as var` bindings before visiting the body, so that
    # uses of `var` inside the body can see its definition.
    for item in node.items:
      if item.optional_vars is not None:
        ast_util.apply_to_single_assignments((item.optional_vars,),
                                             item.context_expr,
                                             self._process_variable_assignment)
    self.generic_visit(node)
    return node

  def visit_Assign(self, node):
    self.generic_visit(node)
    ast_util.apply_to_single_assignments(node.targets, node.value,
                                         self._process_variable_assignment)
    return node
213
214
def resolve(node, context):
  """Runs a TypeInfoResolver over `node` and returns the visited result.

  Args:
    node: The AST to annotate. It is expected to already carry the
        annotations TypeInfoResolver reads (QN and, for calls, 'live_val').
    context: The transformer context handed to the resolver. Its
        `info.arg_types` is consulted for top-level function arguments.

  Returns:
    The node returned by the resolver's `visit`.
  """
  resolver = TypeInfoResolver(context)
  annotated = resolver.visit(node)
  return annotated
217