1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Type inference. 16 17This analysis annotates all symbols nodes of an AST with type information 18extracted from static sources: 19 * type annotations 20 * global and local symbols visible to the function at analysis time 21 * literals 22 23Important: This analysis is static, and does not detect dynamic type changes. 24The analysis attempts to use the values of external symbols, if available. These 25values are also considered static for the purpose of analysis. 26 27Requires reaching function definitions analysis. 28""" 29 30from __future__ import absolute_import 31from __future__ import division 32from __future__ import print_function 33 34import itertools 35 36from typing import Any, Callable, Dict, Set 37 38import gast 39 40from tensorflow.python.autograph.pyct import anno 41from tensorflow.python.autograph.pyct import cfg 42from tensorflow.python.autograph.pyct import qual_names 43from tensorflow.python.autograph.pyct import transformer 44from tensorflow.python.autograph.pyct.static_analysis import activity 45from tensorflow.python.autograph.pyct.static_analysis import annos 46 47 48class Resolver(object): 49 """Resolver objects handle the process of looking up actual names and types. 50 51 Unless noted otherwise, all resolve_* methods: 52 * have a first namespace argument, mapping string to actual values 53 * have a second types_namespace argument, mapping string to actual inferred 54 types 55 * specify names as QN objects 56 * specify types as a Set of inferred types 57 58 Unless noted otherwise, all resolve_* methods must return either: 59 * a set of `type` objects 60 * None 61 """ 62 63 def res_name(self, ns, types_ns, name): 64 """Resolves the type/value an external (e.g. closure, global) variable. 65 66 Args: 67 ns: namespace 68 types_ns: types namespace 69 name: symbol name 70 Returns: 71 Tuple (type, static_value). The first element is the type to use for 72 inferrence. The second is the static value to use. Return None to treat it 73 as unknown. 74 """ 75 raise NotImplementedError('subclasses must implement') 76 77 def res_value(self, ns, value): 78 """Resolves the type a literal or static value.""" 79 raise NotImplementedError('subclasses must implement') 80 81 def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local): 82 """Resolves the type of a (possibly annotated) function argument. 83 84 Args: 85 ns: namespace 86 types_ns: types namespace 87 f_name: str, the function name 88 name: str, the argument name 89 type_anno: the type annotating the argument, if any 90 f_is_local: bool, whether the function is a local function 91 Returns: 92 Set of the argument types. 93 """ 94 raise NotImplementedError('subclasses must implement') 95 96 def res_call(self, ns, types_ns, node, f_type, args, keywords): 97 """Resolves the return type an external function or method call. 98 99 Args: 100 ns: namespace 101 types_ns: types namespace 102 node: str, the function name 103 f_type: types of the actual function being called, if known 104 args: types of each respective argument in node.args 105 keywords: types of each respective argument in node.keywords 106 107 Returns: 108 Tuple (return_type, side_effect_types). The first element is just the 109 return types of the function. The second element is a map from 110 argument names to sets of types, and allow modelling side effects of 111 functions (for example via global or nonlocal). 112 """ 113 raise NotImplementedError('subclasses must implement') 114 115 # TODO(mdan): Clean this up. 116 def res_slice(self, ns, types_ns, node_or_slice, value, slice_): 117 """Resolves the return type of slice operation.""" 118 raise NotImplementedError('subclasses must implement') 119 120 def res_compare(self, ns, types_ns, node, left, right): 121 """Resolves the return type of a unary operation.""" 122 raise NotImplementedError('subclasses must implement') 123 124 def res_unop(self, ns, types_ns, node, opnd): 125 """Resolves the return type of a unary operation.""" 126 raise NotImplementedError('subclasses must implement') 127 128 def res_binop(self, ns, types_ns, node, left, right): 129 """Resolves the return type of a binary operation.""" 130 raise NotImplementedError('subclasses must implement') 131 132 def res_list_literal(self, ns, elt_types): 133 """Resolves the type of a list literal from its elements.""" 134 raise NotImplementedError('subclasses must implement') 135 136 137class _TypeMap(object): 138 """Abstraction for the state of the CFG walk for type inference. 139 140 This is a value type. Only implements the strictly necessary operators. 141 142 Attributes: 143 types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of 144 possible types. 145 """ 146 147 def __init__(self, init_from=None): 148 if init_from: 149 assert isinstance(init_from, _TypeMap) 150 self.types = { 151 s: set(other_types) for s, other_types in init_from.types.items() 152 } 153 else: 154 self.types = {} 155 156 def __eq__(self, other): 157 if frozenset(self.types.keys()) != frozenset(other.types.keys()): 158 return False 159 ret = all(self.types[s] == other.types[s] for s in self.types) 160 return ret 161 162 def __ne__(self, other): 163 return not self.__eq__(other) 164 165 def __or__(self, other): 166 assert isinstance(other, _TypeMap) 167 result = _TypeMap(self) 168 for s, other_types in other.types.items(): 169 if s not in result.types: 170 self_types = set() 171 result.types[s] = self_types 172 else: 173 self_types = result.types[s] 174 self_types.update(other_types) 175 return result 176 177 def __repr__(self): 178 return 'SymbolTable {}'.format(self.types) 179 180 181NO_VALUE = object() 182 183 184class StmtInferrer(gast.NodeVisitor): 185 """Runs type inference on a single AST statement. 186 187 This visitor annotates most nodes with type information. It also sets types 188 for the symbols modified by this statement in its types_out property. 189 190 Note: this inferrer is able to capture side effects of functions, however, 191 these side effects will not be applied to the current expression. Doing so 192 would create too much of a dependence on the runtime's internal rules about 193 execution order. 194 Example: 195 196 def f(): 197 nonlocal a 198 a = 1 199 return a 200 201 a = 0.0 202 b = f() + a # a = float; side effect of f() ignored 203 print(a) # a = int; side effect of f() accounted for 204 """ 205 206 def __init__(self, 207 resolver: Resolver, 208 scope: activity.Scope, 209 namespace: Dict[qual_names.QN, Any], 210 closure_types: Dict[qual_names.QN, Set[Any]], 211 types_in: _TypeMap): 212 self.resolver = resolver 213 self.scope = scope 214 self.namespace = namespace 215 self.closure_types = closure_types 216 self.types_in = types_in 217 self.new_symbols = {} 218 219 # rvalue type. This property is set when encountering an assign operation, 220 # so that visiting nodes with Store ctx (typically found on left side of 221 # assignments) can infer the type they should receive. 222 self.rtype = None 223 224 def visit(self, node): 225 types = super().visit(node) 226 if __debug__: 227 self._check_set(types) 228 if types is not None: 229 # TODO(mdan): Normalize by removing subtypes. 230 anno.setanno(node, anno.Static.TYPES, tuple(types)) 231 return types 232 233 def _check_set(self, value): 234 if value is not None and not isinstance(value, set): 235 raise ValueError('{} method expected to return set, got {}'.format( 236 self.resolver, value)) 237 238 def visit_Constant(self, node): 239 types = self.resolver.res_value(self.namespace, node.value) 240 if __debug__: 241 self._check_set(types) 242 return types 243 244 def _apply_unpacking(self, node): 245 assert isinstance(node.ctx, gast.Store) 246 if self.rtype is not None: 247 original_stype = self.rtype 248 # TODO(mdan): Find a better way to express unpacking. 249 i_type = self.resolver.res_value(self.namespace, 0) 250 for i, elt in enumerate(node.elts): 251 self.rtype = self.resolver.res_slice( 252 self.namespace, self.types_in.types, i, original_stype, i_type) 253 self.visit(elt) 254 self.rtype = original_stype 255 return original_stype 256 return None 257 258 def visit_Tuple(self, node): 259 if isinstance(node.ctx, gast.Load): 260 elt_types = () 261 for elt in node.elts: 262 types_ = self.visit(elt) 263 if types_ is None: 264 return None 265 elt_types += (types_,) 266 return set(itertools.product(*elt_types)) 267 return self._apply_unpacking(node) 268 269 def visit_List(self, node): 270 if isinstance(node.ctx, gast.Load): 271 elt_types = tuple(self.visit(elt) for elt in node.elts) 272 return self.resolver.res_list_literal(self.namespace, elt_types) 273 return self._apply_unpacking(node) 274 275 def visit_Set(self, node): 276 raise NotImplementedError() 277 278 def visit_Name(self, node): 279 name = anno.getanno(node, anno.Basic.QN) 280 281 if isinstance(node.ctx, gast.Load): 282 types = self.types_in.types.get(name, None) 283 if types is None: 284 if (name not in self.scope.bound) or (name in self.scope.nonlocals): 285 # TODO(mdan): Test with global variables. 286 if name in self.closure_types: 287 types = self.closure_types[name] 288 else: 289 types, value = self.resolver.res_name( 290 self.namespace, self.types_in.types, name) 291 if value is not None: 292 anno.setanno(node, anno.Static.VALUE, value) 293 294 elif isinstance(node.ctx, gast.Param): 295 # The direct parent it the whole function scope. See activity.py. 296 f_is_local = self.scope.parent.parent is not None 297 298 type_name = anno.getanno(node.annotation, anno.Basic.QN, None) 299 types = self.resolver.res_arg(self.namespace, self.types_in.types, 300 self.scope.function_name, name, type_name, 301 f_is_local) 302 if types is not None: 303 self.new_symbols[name] = types 304 305 elif isinstance(node.ctx, gast.Store): 306 if self.rtype is not None: 307 self.new_symbols[name] = self.rtype 308 types = self.rtype 309 310 else: 311 assert False, 'unknown ctx' 312 313 if __debug__: 314 self._check_set(types) 315 316 return types 317 318 def visit_Attribute(self, node): 319 parent_types = self.visit(node.value) 320 321 # Attempt to use the static value if known. 322 parent_value = anno.Static.VALUE.of(node.value, None) 323 if parent_value is not None: 324 static_value = getattr(parent_value, node.attr, NO_VALUE) 325 326 if static_value is NO_VALUE: 327 # Unexpected failure to resolve attribute. Ask the resolver about the 328 # full name instead. 329 types, static_value = self.resolver.res_name( 330 self.namespace, self.types_in, anno.Basic.QN.of(node)) 331 anno.setanno(node, anno.Static.VALUE, static_value) 332 if __debug__: 333 self._check_set(types) 334 return types 335 336 else: 337 # Fall back to the type if that is known. 338 if parent_types is None: 339 return None 340 341 inferred_values = [getattr(t, node.attr, None) for t in parent_types] 342 if not inferred_values: 343 return None 344 345 static_value = inferred_values[0] 346 if static_value is None: 347 return None 348 349 if any(v is not static_value for v in inferred_values[1:]): 350 # Static value not stable, assume it's dynamic. 351 return None 352 353 types = self.resolver.res_value(self.namespace, static_value) 354 anno.setanno(node, anno.Static.VALUE, static_value) 355 356 if __debug__: 357 self._check_set(types) 358 359 return types 360 361 def visit_FunctionDef(self, node): 362 f_name = qual_names.QN(node.name) 363 364 if node.decorator_list: 365 raise NotImplementedError('decorators: {}'.format(node.decorator_list)) 366 367 ret_types = None 368 if node.returns: 369 ret_types, _ = self.resolver.res_name( 370 self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns)) 371 if __debug__: 372 self._check_set(ret_types) 373 374 if ret_types is None: 375 ret_types = {Any} 376 377 f_types = set() 378 for rt in ret_types: 379 f_types.add(Callable[[Any], rt]) 380 381 self.new_symbols[f_name] = f_types 382 # The definition of a function is an expression, hence has no return value. 383 return None 384 385 def _resolve_typed_callable(self, f_types, arg_types, keyword_types): 386 ret_types = set() 387 for t in f_types: 388 389 if isinstance(t, Callable): 390 # Note: these are undocummented - may be version-specific! 391 # Callable[[x], y]: __args__ are (x, y) 392 args = t.__args__ 393 if args: 394 ret_types.add(args[-1]) 395 else: 396 ret_types.add(Any) 397 else: 398 raise NotImplementedError('callable type {}'.format(type(t))) 399 400 # Side effects can not be inferred based on type alone. 401 side_effects = None 402 return ret_types, side_effects 403 404 def visit_Call(self, node): 405 self.visit(node.func) 406 407 f_name = anno.Basic.QN.of(node.func) 408 arg_types = [self.visit(a) for a in node.args] 409 keyword_types = [self.visit(kw.value) for kw in node.keywords] 410 411 if f_name in self.scope.bound: 412 # Local function, use local type definitions, if available. 413 f_type = self.types_in.types.get(f_name, None) 414 if f_type is None: 415 # No static type info available, nothing more to do. 416 ret_type, side_effects = None, None 417 else: 418 ret_type, side_effects = self._resolve_typed_callable( 419 f_type, arg_types, keyword_types) 420 421 else: 422 # Nonlocal function, resolve externally. 423 f_type = anno.Static.TYPES.of(node.func, None) 424 ret_type, side_effects = self.resolver.res_call(self.namespace, 425 self.types_in.types, node, 426 f_type, arg_types, 427 keyword_types) 428 429 if __debug__: 430 self._check_set(ret_type) 431 if side_effects: 432 if not isinstance(side_effects, dict): 433 raise ValueError( 434 'side effects must be dict, got {}'.format(side_effects)) 435 for k, v in side_effects.items(): 436 if not isinstance(k, qual_names.QN): 437 raise ValueError('side effect keys must be QNs, got {}'.format(k)) 438 self._check_set(v) 439 440 if side_effects: 441 self.new_symbols.update(side_effects) 442 return ret_type 443 444 def visit_Expr(self, node): 445 return self.visit(node.value) 446 447 def visit_Assign(self, node): 448 self.rtype = self.visit(node.value) 449 450 for t in node.targets: 451 self.visit(t) 452 453 self.rtype = None 454 455 def visit_Subscript(self, node): 456 val_types = self.visit(node.value) 457 slice_types = self.visit(node.slice) 458 459 if val_types is None or slice_types is None: 460 return None 461 462 types = self.resolver.res_slice( 463 self.namespace, self.types_in.types, node, val_types, slice_types) 464 465 if __debug__: 466 self._check_set(types) 467 468 return types 469 470 def visit_Compare(self, node): 471 left_types = self.visit(node.left) 472 right_types = [self.visit(c) for c in node.comparators] 473 474 if left_types is None or any(t is None for t in right_types): 475 return None 476 477 types = self.resolver.res_compare( 478 self.namespace, self.types_in.types, node, left_types, right_types) 479 480 if __debug__: 481 self._check_set(types) 482 483 return types 484 485 def visit_BinOp(self, node): 486 left_types = self.visit(node.left) 487 right_types = self.visit(node.right) 488 489 if left_types is None or right_types is None: 490 return None 491 492 types = self.resolver.res_binop( 493 self.namespace, self.types_in.types, node, left_types, right_types) 494 495 if __debug__: 496 self._check_set(types) 497 498 return types 499 500 def visit_UnaryOp(self, node): 501 opnd_types = self.visit(node.operand) 502 503 if opnd_types is None: 504 return None 505 506 types = self.resolver.res_unop( 507 self.namespace, self.types_in.types, node, opnd_types) 508 509 if __debug__: 510 self._check_set(types) 511 512 return types 513 514 515class Analyzer(cfg.GraphVisitor): 516 """CFG visitor that propagates type information across statements.""" 517 518 def __init__(self, graph, resolver, namespace, scope, closure_types): 519 """Creates a new analyzer. 520 521 Args: 522 graph: cfg.Graph 523 resolver: Resolver 524 namespace: Dict[str, Any] 525 scope: activity.Scope 526 closure_types: Dict[QN, Set] 527 """ 528 super(Analyzer, self).__init__(graph) 529 self.resolver = resolver 530 self.namespace = namespace 531 self.scope = scope 532 self.closure_types = closure_types 533 534 context_types = { 535 n: t for n, t in closure_types.items() if n not in scope.bound 536 } 537 if context_types: 538 self.context_types = _TypeMap() 539 self.context_types.types = context_types 540 else: 541 self.context_types = None 542 543 def init_state(self, _): 544 return _TypeMap() 545 546 def _update_closure_types(self, ast_node, types): 547 existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None) 548 549 if existing_types is None: 550 existing_types = {} 551 anno.Static.CLOSURE_TYPES.add_to(ast_node, existing_types) 552 553 for k, v in types.types.items(): 554 if k in existing_types: 555 existing_types[k].update(v) 556 else: 557 existing_types[k] = set(v) 558 559 def visit_node(self, node): 560 prev_types_out = self.out[node] 561 562 types_in = _TypeMap() 563 for n in node.prev: 564 types_in |= self.out[n] 565 if (self.context_types is not None) and (node is self.graph.entry): 566 types_in |= self.context_types 567 568 types_out = _TypeMap(types_in) 569 ast_node = node.ast_node 570 571 inferrer = StmtInferrer(self.resolver, self.scope, self.namespace, 572 self.closure_types, types_in) 573 inferrer.visit(ast_node) 574 types_out.types.update(inferrer.new_symbols) 575 576 reaching_fndefs = anno.Static.DEFINED_FNS_IN.of(ast_node) 577 node_scope = anno.Static.SCOPE.of(ast_node, None) 578 if node_scope is not None: 579 # TODO(mdan): Check that it's actually safe to skip nodes without scope. 580 reads = {str(qn) for qn in node_scope.read} 581 for def_node in reaching_fndefs: 582 if def_node.name in reads: 583 self._update_closure_types(def_node, types_out) 584 585 self.in_[node] = types_in 586 self.out[node] = types_out 587 588 return prev_types_out != types_out 589 590 591class FunctionVisitor(transformer.Base): 592 """AST visitor that applies type inference to each function separately.""" 593 594 def __init__(self, source_info, graphs, resolver): 595 super(FunctionVisitor, self).__init__(source_info) 596 self.graphs = graphs 597 self.resolver = resolver 598 599 def visit_FunctionDef(self, node): 600 subgraph = self.graphs[node] 601 scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE) 602 closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {}) 603 604 analyzer = Analyzer(subgraph, self.resolver, self.ctx.info.namespace, scope, 605 closure_types) 606 analyzer.visit_forward() 607 608 # Recursively process any remaining subfunctions. 609 node.body = self.visit_block(node.body) 610 611 return node 612 613 614def resolve(node, source_info, graphs, resolver): 615 """Performs type inference. 616 617 Args: 618 node: ast.AST 619 source_info: transformer.SourceInfo 620 graphs: Dict[ast.FunctionDef, cfg.Graph] 621 resolver: Resolver 622 623 Returns: 624 ast.AST 625 """ 626 visitor = FunctionVisitor(source_info, graphs, resolver) 627 node = visitor.visit(node) 628 return node 629