• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type inference.
16
This analysis annotates all symbol nodes of an AST with type information
18extracted from static sources:
19 * type annotations
20 * global and local symbols visible to the function at analysis time
21 * literals
22
23Important: This analysis is static, and does not detect dynamic type changes.
24The analysis attempts to use the values of external symbols, if available. These
25values are also considered static for the purpose of analysis.
26
27Requires reaching function definitions analysis.
28"""
29
30import itertools
31
32from typing import Any, Callable, Dict, Set
33
34import gast
35
36from tensorflow.python.autograph.pyct import anno
37from tensorflow.python.autograph.pyct import cfg
38from tensorflow.python.autograph.pyct import qual_names
39from tensorflow.python.autograph.pyct import transformer
40from tensorflow.python.autograph.pyct.static_analysis import activity
41from tensorflow.python.autograph.pyct.static_analysis import annos
42
43
class Resolver(object):
  """Resolver objects handle the process of looking up actual names and types.

  This class only defines the interface: every method raises
  NotImplementedError and must be overridden by a subclass.

  Unless noted otherwise, all res_* methods:
    * have a first namespace argument (ns), mapping string to actual values
    * have a second types namespace argument (types_ns), mapping symbol names
      (qual_names.QN) to sets of inferred types
    * specify names as QN objects
    * specify types as a Set of inferred types

  Unless noted otherwise, all res_* methods must return either:
    * a set of `type` objects
    * None, meaning the types are unknown
  """

  def res_name(self, ns, types_ns, name):
    """Resolves the type/value of an external (e.g. closure, global) variable.

    Args:
      ns: namespace
      types_ns: types namespace
      name: symbol name, as a QN

    Returns:
      Tuple (types, static_value). The first element is the set of types to
      use for inference. The second is the static value to use. Either element
      may be None to treat it as unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_value(self, ns, value):
    """Resolves the type of a literal or static value.

    Args:
      ns: namespace
      value: the literal or static value

    Returns:
      Set of the value's types, or None if unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    """Resolves the type of a (possibly annotated) function argument.

    Args:
      ns: namespace
      types_ns: types namespace
      f_name: str, the function name
      name: str, the argument name
      type_anno: the QN of the type annotating the argument, or None if the
        argument is not annotated
      f_is_local: bool, whether the function is a local function

    Returns:
      Set of the argument types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the return type of an external function or method call.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the call AST node (gast.Call) being resolved
      f_type: types of the actual function being called, if known
      args: types of each respective argument in node.args
      keywords: types of each respective argument in node.keywords

    Returns:
      Tuple (return_type, side_effect_types). The first element is just the
      return types of the function. The second element is a map from symbol
      names (QN) to sets of types, and allows modelling side effects of
      functions (for example via global or nonlocal). Either element may be
      None.
    """
    raise NotImplementedError('subclasses must implement')

  # TODO(mdan): Clean this up.
  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    """Resolves the return type of a slice (subscript) operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node_or_slice: the subscript AST node being resolved or, when resolving
        the elements of an unpacking assignment, the int index of the element
      value: types of the value being sliced
      slice_: types of the slice

    Returns:
      Set of the result types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_compare(self, ns, types_ns, node, left, right):
    """Resolves the return type of a comparison operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the comparison AST node
      left: types of the leftmost operand
      right: list with the types of each respective comparator

    Returns:
      Set of the result types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_unop(self, ns, types_ns, node, opnd):
    """Resolves the return type of a unary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the unary operation AST node
      opnd: types of the operand

    Returns:
      Set of the result types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_binop(self, ns, types_ns, node, left, right):
    """Resolves the return type of a binary operation.

    Args:
      ns: namespace
      types_ns: types namespace
      node: the binary operation AST node
      left: types of the left operand
      right: types of the right operand

    Returns:
      Set of the result types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_list_literal(self, ns, elt_types):
    """Resolves the type of a list literal from its elements.

    Args:
      ns: namespace
      elt_types: tuple with the types of each respective element; entries may
        be None when an element's types are unknown

    Returns:
      Set of the list literal's types.
    """
    raise NotImplementedError('subclasses must implement')
131
132
133class _TypeMap(object):
134  """Abstraction for the state of the CFG walk for type inference.
135
136  This is a value type. Only implements the strictly necessary operators.
137
138  Attributes:
139    types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
140      possible types.
141  """
142
143  def __init__(self, init_from=None):
144    if init_from:
145      assert isinstance(init_from, _TypeMap)
146      self.types = {
147          s: set(other_types) for s, other_types in init_from.types.items()
148      }
149    else:
150      self.types = {}
151
152  def __eq__(self, other):
153    if frozenset(self.types.keys()) != frozenset(other.types.keys()):
154      return False
155    ret = all(self.types[s] == other.types[s] for s in self.types)
156    return ret
157
158  def __ne__(self, other):
159    return not self.__eq__(other)
160
161  def __or__(self, other):
162    assert isinstance(other, _TypeMap)
163    result = _TypeMap(self)
164    for s, other_types in other.types.items():
165      if s not in result.types:
166        self_types = set()
167        result.types[s] = self_types
168      else:
169        self_types = result.types[s]
170      self_types.update(other_types)
171    return result
172
173  def __repr__(self):
174    return 'SymbolTable {}'.format(self.types)
175
176
# Sentinel default for getattr(): tells a missing attribute apart from an
# attribute that exists but whose value is None.
NO_VALUE = object()
178
179
class StmtInferrer(gast.NodeVisitor):
  """Runs type inference on a single AST statement.

  This visitor annotates most nodes with type information. It also records the
  types of the symbols modified by this statement in its new_symbols property.

  Note: this inferrer is able to capture side effects of functions, however,
  these side effects will not be applied to the current expression. Doing so
  would create too much of a dependence on the runtime's internal rules about
  execution order.
  Example:

    def f():
      nonlocal a
      a = 1
      return a

    a = 0.0
    b = f() + a  # a = float; side effect of f() ignored
    print(a)  # a = int; side effect of f() accounted for
  """

  def __init__(self,
               resolver: Resolver,
               scope: activity.Scope,
               namespace: Dict[qual_names.QN, Any],
               closure_types: Dict[qual_names.QN, Set[Any]],
               types_in: _TypeMap):
    self.resolver = resolver
    self.scope = scope
    self.namespace = namespace
    self.closure_types = closure_types
    self.types_in = types_in
    # Maps the QN of each symbol defined or modified by the visited statement
    # (assignment targets, parameters, function definitions, call side
    # effects) to its set of inferred types.
    self.new_symbols = {}

    # rvalue type. This property is set when encountering an assign operation,
    # so that visiting nodes with Store ctx (typically found on left side of
    # assignments) can infer the type they should receive.
    self.rtype = None

  def visit(self, node):
    """Visits node and, when its types are known, annotates it with them."""
    types = super().visit(node)
    if __debug__:
      self._check_set(types)
    if types is not None:
      # TODO(mdan): Normalize by removing subtypes.
      anno.setanno(node, anno.Static.TYPES, tuple(types))
    return types

  def _check_set(self, value):
    """Raises ValueError unless value is None or a set."""
    if value is not None and not isinstance(value, set):
      raise ValueError('{} method expected to return set, got {}'.format(
          self.resolver, value))

  def visit_Constant(self, node):
    types = self.resolver.res_value(self.namespace, node.value)
    if __debug__:
      self._check_set(types)
    return types

  def _apply_unpacking(self, node):
    """Propagates self.rtype to each element of a Tuple/List assign target."""
    assert isinstance(node.ctx, gast.Store)
    if self.rtype is not None:
      original_stype = self.rtype
      # TODO(mdan): Find a better way to express unpacking.
      i_type = self.resolver.res_value(self.namespace, 0)
      for i, elt in enumerate(node.elts):
        # Each target element receives the types of the i-th element of the
        # rvalue; self.rtype is temporarily narrowed for the nested visit.
        self.rtype = self.resolver.res_slice(
            self.namespace, self.types_in.types, i, original_stype, i_type)
        self.visit(elt)
      self.rtype = original_stype
      return original_stype
    return None

  def visit_Tuple(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = ()
      for elt in node.elts:
        types_ = self.visit(elt)
        if types_ is None:
          return None
        elt_types += (types_,)
      # The tuple's possible types are all combinations of element types.
      return set(itertools.product(*elt_types))
    return self._apply_unpacking(node)

  def visit_List(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      return self.resolver.res_list_literal(self.namespace, elt_types)
    return self._apply_unpacking(node)

  def visit_Set(self, node):
    raise NotImplementedError()

  def visit_Name(self, node):
    name = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Load):
      types = self.types_in.types.get(name, None)
      if types is None:
        if (name not in self.scope.bound) or (name in self.scope.nonlocals):
          # TODO(mdan): Test with global variables.
          if name in self.closure_types:
            types = self.closure_types[name]
          else:
            types, value = self.resolver.res_name(
                self.namespace, self.types_in.types, name)
            if value is not None:
              anno.setanno(node, anno.Static.VALUE, value)

    elif isinstance(node.ctx, gast.Param):
      # The direct parent is the whole function scope. See activity.py.
      f_is_local = self.scope.parent.parent is not None

      type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
      types = self.resolver.res_arg(self.namespace, self.types_in.types,
                                    self.scope.function_name, name, type_name,
                                    f_is_local)
      if types is not None:
        self.new_symbols[name] = types

    elif isinstance(node.ctx, gast.Store):
      if self.rtype is not None:
        self.new_symbols[name] = self.rtype
      types = self.rtype

    else:
      assert False, 'unknown ctx'

    if __debug__:
      self._check_set(types)

    return types

  def visit_Attribute(self, node):
    parent_types = self.visit(node.value)

    # Attempt to use the static value if known.
    parent_value = anno.Static.VALUE.of(node.value, None)
    if parent_value is not None:
      static_value = getattr(parent_value, node.attr, NO_VALUE)

      if static_value is NO_VALUE:
        # Unexpected failure to resolve attribute. Ask the resolver about the
        # full name instead. Like every other res_* call site, pass the types
        # dict (not the _TypeMap wrapper) as the types namespace.
        types, static_value = self.resolver.res_name(
            self.namespace, self.types_in.types, anno.Basic.QN.of(node))
        anno.setanno(node, anno.Static.VALUE, static_value)
        if __debug__:
          self._check_set(types)
        return types

    else:
      # Fall back to the type if that is known.
      if parent_types is None:
        return None

      inferred_values = [getattr(t, node.attr, None) for t in parent_types]
      if not inferred_values:
        return None

      static_value = inferred_values[0]
      if static_value is None:
        return None

      if any(v is not static_value for v in inferred_values[1:]):
        # Static value not stable, assume it's dynamic.
        return None

    types = self.resolver.res_value(self.namespace, static_value)
    anno.setanno(node, anno.Static.VALUE, static_value)

    if __debug__:
      self._check_set(types)

    return types

  def visit_FunctionDef(self, node):
    f_name = qual_names.QN(node.name)

    if node.decorator_list:
      raise NotImplementedError('decorators: {}'.format(node.decorator_list))

    ret_types = None
    if node.returns:
      ret_types, _ = self.resolver.res_name(
          self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
      if __debug__:
        self._check_set(ret_types)

    if ret_types is None:
      ret_types = {Any}

    f_types = set()
    for rt in ret_types:
      f_types.add(Callable[[Any], rt])

    self.new_symbols[f_name] = f_types
    # A function definition is a statement, so it has no type of its own; its
    # callable type was recorded in new_symbols above.
    return None

  def _resolve_typed_callable(self, f_types, arg_types, keyword_types):
    ret_types = set()
    for t in f_types:

      if isinstance(t, Callable):
        # Note: these are undocumented - may be version-specific!
        # Callable[[x], y]: __args__ are (x, y)
        args = t.__args__
        if args:
          ret_types.add(args[-1])
        else:
          ret_types.add(Any)
      else:
        raise NotImplementedError('callable type {}'.format(type(t)))

    # Side effects can not be inferred based on type alone.
    side_effects = None
    return ret_types, side_effects

  def visit_Call(self, node):
    self.visit(node.func)

    f_name = anno.Basic.QN.of(node.func)
    arg_types = [self.visit(a) for a in node.args]
    keyword_types = [self.visit(kw.value) for kw in node.keywords]

    if f_name in self.scope.bound:
      # Local function, use local type definitions, if available.
      f_type = self.types_in.types.get(f_name, None)
      if f_type is None:
        # No static type info available, nothing more to do.
        ret_type, side_effects = None, None
      else:
        ret_type, side_effects = self._resolve_typed_callable(
            f_type, arg_types, keyword_types)

    else:
      # Nonlocal function, resolve externally.
      f_type = anno.Static.TYPES.of(node.func, None)
      ret_type, side_effects = self.resolver.res_call(self.namespace,
                                                      self.types_in.types, node,
                                                      f_type, arg_types,
                                                      keyword_types)

    if __debug__:
      self._check_set(ret_type)
      if side_effects:
        if not isinstance(side_effects, dict):
          raise ValueError(
              'side effects must be dict, got {}'.format(side_effects))
        for k, v in side_effects.items():
          if not isinstance(k, qual_names.QN):
            raise ValueError('side effect keys must be QNs, got {}'.format(k))
          self._check_set(v)

    if side_effects:
      self.new_symbols.update(side_effects)
    return ret_type

  def visit_Expr(self, node):
    return self.visit(node.value)

  def visit_Assign(self, node):
    self.rtype = self.visit(node.value)

    for t in node.targets:
      self.visit(t)

    self.rtype = None

  def visit_Subscript(self, node):
    val_types = self.visit(node.value)
    slice_types = self.visit(node.slice)

    if val_types is None or slice_types is None:
      return None

    types = self.resolver.res_slice(
        self.namespace, self.types_in.types, node, val_types, slice_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_Compare(self, node):
    left_types = self.visit(node.left)
    right_types = [self.visit(c) for c in node.comparators]

    if left_types is None or any(t is None for t in right_types):
      return None

    types = self.resolver.res_compare(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_BinOp(self, node):
    left_types = self.visit(node.left)
    right_types = self.visit(node.right)

    if left_types is None or right_types is None:
      return None

    types = self.resolver.res_binop(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_UnaryOp(self, node):
    opnd_types = self.visit(node.operand)

    if opnd_types is None:
      return None

    types = self.resolver.res_unop(
        self.namespace, self.types_in.types, node, opnd_types)

    if __debug__:
      self._check_set(types)

    return types
509
510
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that propagates type information across statements.

  Each CFG node is inferred with a StmtInferrer. Its incoming types are the
  union of its predecessors' outgoing types, plus the closure types at the
  graph entry.
  """

  def __init__(self, graph, resolver, namespace, scope, closure_types):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph
      resolver: Resolver
      namespace: Dict[str, Any]
      scope: activity.Scope
      closure_types: Dict[QN, Set]
    """
    super().__init__(graph)
    self.resolver = resolver
    self.namespace = namespace
    self.scope = scope
    self.closure_types = closure_types

    # Only closure symbols that this function does not bind itself seed the
    # entry node; symbols bound locally get their types from the body.
    inherited = {}
    for qn, qn_types in closure_types.items():
      if qn not in scope.bound:
        inherited[qn] = qn_types

    self.context_types = None
    if inherited:
      self.context_types = _TypeMap()
      self.context_types.types = inherited

  def init_state(self, _):
    """Every CFG node starts out with an empty type map."""
    return _TypeMap()

  def _update_closure_types(self, ast_node, types):
    """Merges types into the CLOSURE_TYPES annotation of ast_node."""
    closure_map = anno.Static.CLOSURE_TYPES.of(ast_node, None)
    if closure_map is None:
      closure_map = {}
      # Attach the dict first; it is then filled in place below.
      anno.Static.CLOSURE_TYPES.add_to(ast_node, closure_map)

    for qn, qn_types in types.types.items():
      closure_map.setdefault(qn, set()).update(qn_types)

  def visit_node(self, node):
    """Infers types for a single CFG node and records its in/out type maps.

    Returns:
      bool, whether the node's outgoing types differ from the ones stored
      before this visit.
    """
    previous_out = self.out[node]

    incoming = _TypeMap()
    for pred in node.prev:
      incoming |= self.out[pred]
    if self.context_types is not None and node is self.graph.entry:
      incoming |= self.context_types

    outgoing = _TypeMap(incoming)
    stmt = node.ast_node

    inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
                            self.closure_types, incoming)
    inferrer.visit(stmt)
    outgoing.types.update(inferrer.new_symbols)

    fndefs = anno.Static.DEFINED_FNS_IN.of(stmt)
    stmt_scope = anno.Static.SCOPE.of(stmt, None)
    if stmt_scope is not None:
      # TODO(mdan): Check that it's actually safe to skip nodes without scope.
      read_names = {str(qn) for qn in stmt_scope.read}
      # Functions read by this statement can observe the current types as
      # closure types.
      for fndef in fndefs:
        if fndef.name in read_names:
          self._update_closure_types(fndef, outgoing)

    self.in_[node] = incoming
    self.out[node] = outgoing

    return previous_out != outgoing
585
586
class FunctionVisitor(transformer.Base):
  """AST visitor that runs type inference on each function separately."""

  def __init__(self, source_info, graphs, resolver):
    super().__init__(source_info)
    self.graphs = graphs
    self.resolver = resolver

  def visit_FunctionDef(self, node):
    """Analyzes node's CFG, then recurses into the functions in its body."""
    fn_graph = self.graphs[node]
    fn_scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
    inherited_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})

    Analyzer(fn_graph, self.resolver, self.ctx.info.namespace, fn_scope,
             inherited_types).visit_forward()

    # Nested functions are visited after the enclosing one has been analyzed,
    # so they can read the CLOSURE_TYPES that analysis attached to them.
    node.body = self.visit_block(node.body)

    return node
608
609
def resolve(node, source_info, graphs, resolver):
  """Performs type inference on the functions contained in node.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    resolver: Resolver

  Returns:
    ast.AST, the node returned by the type inference visitor.
  """
  return FunctionVisitor(source_info, graphs, resolver).visit(node)
625