# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type inference.

This analysis annotates all symbol nodes of an AST with type information
extracted from static sources:
 * type annotations
 * global and local symbols visible to the function at analysis time
 * literals

Important: This analysis is static, and does not detect dynamic type changes.
The analysis attempts to use the values of external symbols, if available. These
values are also considered static for the purpose of analysis.

Requires reaching function definitions analysis.
"""

import itertools

from typing import Any, Callable, Dict, Set

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import annos


class Resolver(object):
  """Resolver objects handle the process of looking up actual names and types.

  Unless noted otherwise, all res_* methods:
    * have a first namespace argument (`ns`), mapping string to actual values
    * have a second types namespace argument (`types_ns`), mapping symbol
      names (QN objects) to sets of inferred types
    * specify names as QN objects
    * specify types as a Set of inferred types

  Unless noted otherwise, all res_* methods must return either:
    * a set of `type` objects
    * None
  """

  def res_name(self, ns, types_ns, name):
    """Resolves the type/value of an external (e.g. closure, global) variable.

    Args:
      ns: namespace
      types_ns: types namespace
      name: symbol name
    Returns:
      Tuple (type, static_value). The first element is the type to use for
      inference. The second is the static value to use. Either element may be
      None to treat it as unknown.
    """
    raise NotImplementedError('subclasses must implement')

  def res_value(self, ns, value):
    """Resolves the type of a literal or static value."""
    raise NotImplementedError('subclasses must implement')

  def res_arg(self, ns, types_ns, f_name, name, type_anno, f_is_local):
    """Resolves the type of a (possibly annotated) function argument.

    Args:
      ns: namespace
      types_ns: types namespace
      f_name: str, the function name
      name: str, the argument name
      type_anno: the type annotating the argument, if any
      f_is_local: bool, whether the function is a local function
    Returns:
      Set of the argument types.
    """
    raise NotImplementedError('subclasses must implement')

  def res_call(self, ns, types_ns, node, f_type, args, keywords):
    """Resolves the return type of an external function or method call.

    Args:
      ns: namespace
      types_ns: types namespace
      node: gast.Call, the call node being resolved
      f_type: types of the actual function being called, if known
      args: types of each respective argument in node.args
      keywords: types of each respective argument in node.keywords

    Returns:
      Tuple (return_type, side_effect_types). The first element is just the
      return types of the function. The second element is a map from
      symbol names (QN objects) to sets of types, and allows modelling side
      effects of functions (for example via global or nonlocal).
    """
    raise NotImplementedError('subclasses must implement')

  # TODO(mdan): Clean this up.
  def res_slice(self, ns, types_ns, node_or_slice, value, slice_):
    """Resolves the return type of slice operation.

    `node_or_slice` is either the subscript AST node, or an integer element
    index when modelling tuple/list unpacking.
    """
    raise NotImplementedError('subclasses must implement')

  def res_compare(self, ns, types_ns, node, left, right):
    """Resolves the return type of a comparison operation.

    `left` is the set of types of the left operand; `right` is a list holding
    one set of types per comparator.
    """
    raise NotImplementedError('subclasses must implement')

  def res_unop(self, ns, types_ns, node, opnd):
    """Resolves the return type of a unary operation."""
    raise NotImplementedError('subclasses must implement')

  def res_binop(self, ns, types_ns, node, left, right):
    """Resolves the return type of a binary operation."""
    raise NotImplementedError('subclasses must implement')

  def res_list_literal(self, ns, elt_types):
    """Resolves the type of a list literal from its elements.

    `elt_types` is a tuple with one entry per element; an entry is None when
    that element's type could not be inferred.
    """
    raise NotImplementedError('subclasses must implement')


class _TypeMap(object):
  """Abstraction for the state of the CFG walk for type inference.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
      possible types.
  """

  def __init__(self, init_from=None):
    if init_from is not None:
      assert isinstance(init_from, _TypeMap)
      # Copy each type set so that mutating this map never aliases the source.
      self.types = {
          s: set(other_types) for s, other_types in init_from.types.items()
      }
    else:
      self.types = {}

  def __eq__(self, other):
    if frozenset(self.types.keys()) != frozenset(other.types.keys()):
      return False
    ret = all(self.types[s] == other.types[s] for s in self.types)
    return ret

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    """Returns a new map holding, per symbol, the union of both type sets."""
    assert isinstance(other, _TypeMap)
    result = _TypeMap(self)
    for s, other_types in other.types.items():
      if s not in result.types:
        self_types = set()
        result.types[s] = self_types
      else:
        self_types = result.types[s]
      self_types.update(other_types)
    return result

  def __repr__(self):
    return 'SymbolTable {}'.format(self.types)


NO_VALUE = object()


class StmtInferrer(gast.NodeVisitor):
  """Runs type inference on a single AST statement.

  This visitor annotates most nodes with type information. It also sets types
  for the symbols modified by this statement in its new_symbols property.

  Note: this inferrer is able to capture side effects of functions, however,
  these side effects will not be applied to the current expression. Doing so
  would create too much of a dependence on the runtime's internal rules about
  execution order.

  Example:

    def f():
      nonlocal a
      a = 1
      return a

    a = 0.0
    b = f() + a  # a = float; side effect of f() ignored
    print(a)  # a = int; side effect of f() accounted for
  """

  def __init__(self,
               resolver: Resolver,
               scope: activity.Scope,
               namespace: Dict[qual_names.QN, Any],
               closure_types: Dict[qual_names.QN, Set[Any]],
               types_in: _TypeMap):
    self.resolver = resolver
    self.scope = scope
    self.namespace = namespace
    self.closure_types = closure_types
    self.types_in = types_in
    # Symbols (QN) whose types are (re)defined by this statement.
    self.new_symbols = {}

    # rvalue type. This property is set when encountering an assign operation,
    # so that visiting nodes with Store ctx (typically found on left side of
    # assignments) can infer the type they should receive.
    self.rtype = None

  def visit(self, node):
    """Visits `node` and records its inferred types as a TYPES annotation."""
    types = super().visit(node)
    if __debug__:
      self._check_set(types)
    if types is not None:
      # TODO(mdan): Normalize by removing subtypes.
      anno.setanno(node, anno.Static.TYPES, tuple(types))
    return types

  def _check_set(self, value):
    """Raises ValueError unless `value` is None or a set."""
    if value is not None and not isinstance(value, set):
      raise ValueError('{} method expected to return set, got {}'.format(
          self.resolver, value))

  def visit_Constant(self, node):
    types = self.resolver.res_value(self.namespace, node.value)
    if __debug__:
      self._check_set(types)
    return types

  def _apply_unpacking(self, node):
    """Propagates self.rtype element-wise to the targets of a Store tuple/list.

    Each target receives the type obtained by slicing the rvalue type at the
    element's index. Returns the original rvalue type, or None if unknown.
    """
    assert isinstance(node.ctx, gast.Store)
    if self.rtype is not None:
      original_stype = self.rtype
      # TODO(mdan): Find a better way to express unpacking.
      i_type = self.resolver.res_value(self.namespace, 0)
      for i, elt in enumerate(node.elts):
        self.rtype = self.resolver.res_slice(
            self.namespace, self.types_in.types, i, original_stype, i_type)
        self.visit(elt)
      self.rtype = original_stype
      return original_stype
    return None

  def visit_Tuple(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = ()
      for elt in node.elts:
        types_ = self.visit(elt)
        if types_ is None:
          return None
        elt_types += (types_,)
      # Each possible tuple type is one combination of element types.
      return set(itertools.product(*elt_types))
    return self._apply_unpacking(node)

  def visit_List(self, node):
    if isinstance(node.ctx, gast.Load):
      elt_types = tuple(self.visit(elt) for elt in node.elts)
      return self.resolver.res_list_literal(self.namespace, elt_types)
    return self._apply_unpacking(node)

  def visit_Set(self, node):
    raise NotImplementedError()

  def visit_Name(self, node):
    name = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Load):
      types = self.types_in.types.get(name, None)
      if types is None:
        if (name not in self.scope.bound) or (name in self.scope.nonlocals):
          # TODO(mdan): Test with global variables.
          if name in self.closure_types:
            types = self.closure_types[name]
          else:
            types, value = self.resolver.res_name(
                self.namespace, self.types_in.types, name)
            if value is not None:
              anno.setanno(node, anno.Static.VALUE, value)

    elif isinstance(node.ctx, gast.Param):
      # The direct parent is the whole function scope. See activity.py.
      f_is_local = self.scope.parent.parent is not None

      type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
      types = self.resolver.res_arg(self.namespace, self.types_in.types,
                                    self.scope.function_name, name, type_name,
                                    f_is_local)
      if types is not None:
        self.new_symbols[name] = types

    elif isinstance(node.ctx, gast.Store):
      if self.rtype is not None:
        self.new_symbols[name] = self.rtype
      types = self.rtype

    else:
      assert False, 'unknown ctx'

    if __debug__:
      self._check_set(types)

    return types

  def visit_Attribute(self, node):
    parent_types = self.visit(node.value)

    # Attempt to use the static value if known.
    parent_value = anno.Static.VALUE.of(node.value, None)
    if parent_value is not None:
      static_value = getattr(parent_value, node.attr, NO_VALUE)

      if static_value is NO_VALUE:
        # Unexpected failure to resolve attribute. Ask the resolver about the
        # full name instead.
        # Pass the types dict, like every other res_* call site does.
        types, static_value = self.resolver.res_name(
            self.namespace, self.types_in.types, anno.Basic.QN.of(node))
        anno.setanno(node, anno.Static.VALUE, static_value)
        if __debug__:
          self._check_set(types)
        return types

    else:
      # Fall back to the type if that is known.
      if parent_types is None:
        return None

      inferred_values = [getattr(t, node.attr, None) for t in parent_types]
      if not inferred_values:
        return None

      static_value = inferred_values[0]
      if static_value is None:
        return None

      if any(v is not static_value for v in inferred_values[1:]):
        # Static value not stable, assume it's dynamic.
        return None

    types = self.resolver.res_value(self.namespace, static_value)
    anno.setanno(node, anno.Static.VALUE, static_value)

    if __debug__:
      self._check_set(types)

    return types

  def visit_FunctionDef(self, node):
    """Records a Callable type for a locally defined function."""
    f_name = qual_names.QN(node.name)

    if node.decorator_list:
      raise NotImplementedError('decorators: {}'.format(node.decorator_list))

    ret_types = None
    if node.returns:
      ret_types, _ = self.resolver.res_name(
          self.namespace, self.types_in.types, anno.Basic.QN.of(node.returns))
      if __debug__:
        self._check_set(ret_types)

    if ret_types is None:
      ret_types = {Any}

    f_types = set()
    for rt in ret_types:
      f_types.add(Callable[[Any], rt])

    self.new_symbols[f_name] = f_types
    # The definition of a function is an expression, hence has no return value.
    return None

  def _resolve_typed_callable(self, f_types, arg_types, keyword_types):
    """Extracts return types from a set of typing.Callable types.

    Returns:
      Tuple (return_types, side_effects); side_effects is always None.
    """
    ret_types = set()
    for t in f_types:

      if isinstance(t, Callable):
        # Note: these are undocumented - may be version-specific!
        # Callable[[x], y]: __args__ are (x, y)
        args = t.__args__
        if args:
          ret_types.add(args[-1])
        else:
          ret_types.add(Any)
      else:
        raise NotImplementedError('callable type {}'.format(type(t)))

    # Side effects can not be inferred based on type alone.
    side_effects = None
    return ret_types, side_effects

  def visit_Call(self, node):
    self.visit(node.func)

    f_name = anno.Basic.QN.of(node.func)
    arg_types = [self.visit(a) for a in node.args]
    keyword_types = [self.visit(kw.value) for kw in node.keywords]

    if f_name in self.scope.bound:
      # Local function, use local type definitions, if available.
      f_type = self.types_in.types.get(f_name, None)
      if f_type is None:
        # No static type info available, nothing more to do.
        ret_type, side_effects = None, None
      else:
        ret_type, side_effects = self._resolve_typed_callable(
            f_type, arg_types, keyword_types)

    else:
      # Nonlocal function, resolve externally.
      f_type = anno.Static.TYPES.of(node.func, None)
      ret_type, side_effects = self.resolver.res_call(self.namespace,
                                                      self.types_in.types, node,
                                                      f_type, arg_types,
                                                      keyword_types)

    if __debug__:
      self._check_set(ret_type)
      if side_effects:
        if not isinstance(side_effects, dict):
          raise ValueError(
              'side effects must be dict, got {}'.format(side_effects))
        for k, v in side_effects.items():
          if not isinstance(k, qual_names.QN):
            raise ValueError('side effect keys must be QNs, got {}'.format(k))
          self._check_set(v)

    if side_effects:
      self.new_symbols.update(side_effects)
    return ret_type

  def visit_Expr(self, node):
    return self.visit(node.value)

  def visit_Assign(self, node):
    # Expose the rvalue type to the Store-ctx targets visited below.
    self.rtype = self.visit(node.value)

    for t in node.targets:
      self.visit(t)

    self.rtype = None

  def visit_Subscript(self, node):
    val_types = self.visit(node.value)
    slice_types = self.visit(node.slice)

    if val_types is None or slice_types is None:
      return None

    types = self.resolver.res_slice(
        self.namespace, self.types_in.types, node, val_types, slice_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_Compare(self, node):
    left_types = self.visit(node.left)
    right_types = [self.visit(c) for c in node.comparators]

    if left_types is None or any(t is None for t in right_types):
      return None

    types = self.resolver.res_compare(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_BinOp(self, node):
    left_types = self.visit(node.left)
    right_types = self.visit(node.right)

    if left_types is None or right_types is None:
      return None

    types = self.resolver.res_binop(
        self.namespace, self.types_in.types, node, left_types, right_types)

    if __debug__:
      self._check_set(types)

    return types

  def visit_UnaryOp(self, node):
    opnd_types = self.visit(node.operand)

    if opnd_types is None:
      return None

    types = self.resolver.res_unop(
        self.namespace, self.types_in.types, node, opnd_types)

    if __debug__:
      self._check_set(types)

    return types


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that propagates type information across statements."""

  def __init__(self, graph, resolver, namespace, scope, closure_types):
    """Creates a new analyzer.

    Args:
      graph: cfg.Graph
      resolver: Resolver
      namespace: Dict[str, Any]
      scope: activity.Scope
      closure_types: Dict[QN, Set]
    """
    super(Analyzer, self).__init__(graph)
    self.resolver = resolver
    self.namespace = namespace
    self.scope = scope
    self.closure_types = closure_types

    # Closure symbols not rebound in this scope seed the entry node's types.
    context_types = {
        n: t for n, t in closure_types.items() if n not in scope.bound
    }
    if context_types:
      self.context_types = _TypeMap()
      self.context_types.types = context_types
    else:
      self.context_types = None

  def init_state(self, _):
    return _TypeMap()

  def _update_closure_types(self, ast_node, types):
    """Merges `types` into the CLOSURE_TYPES annotation of `ast_node`."""
    existing_types = anno.Static.CLOSURE_TYPES.of(ast_node, None)

    if existing_types is None:
      existing_types = {}
      anno.Static.CLOSURE_TYPES.add_to(ast_node, existing_types)

    for k, v in types.types.items():
      if k in existing_types:
        existing_types[k].update(v)
      else:
        existing_types[k] = set(v)

  def visit_node(self, node):
    """Infers types for one CFG node; returns True if its output changed."""
    prev_types_out = self.out[node]

    # Join the outputs of all predecessors.
    types_in = _TypeMap()
    for n in node.prev:
      types_in |= self.out[n]
    if (self.context_types is not None) and (node is self.graph.entry):
      types_in |= self.context_types

    types_out = _TypeMap(types_in)
    ast_node = node.ast_node

    inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
                            self.closure_types, types_in)
    inferrer.visit(ast_node)
    types_out.types.update(inferrer.new_symbols)

    reaching_fndefs = anno.Static.DEFINED_FNS_IN.of(ast_node)
    node_scope = anno.Static.SCOPE.of(ast_node, None)
    if node_scope is not None:
      # TODO(mdan): Check that it's actually safe to skip nodes without scope.
      reads = {str(qn) for qn in node_scope.read}
      for def_node in reaching_fndefs:
        # Functions read by this statement see its output types as closure.
        if def_node.name in reads:
          self._update_closure_types(def_node, types_out)

    self.in_[node] = types_in
    self.out[node] = types_out

    return prev_types_out != types_out


class FunctionVisitor(transformer.Base):
  """AST visitor that applies type inference to each function separately."""

  def __init__(self, source_info, graphs, resolver):
    super(FunctionVisitor, self).__init__(source_info)
    self.graphs = graphs
    self.resolver = resolver

  def visit_FunctionDef(self, node):
    """Runs the forward type analysis on `node`'s CFG, then its subfunctions."""
    subgraph = self.graphs[node]
    scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
    closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})

    analyzer = Analyzer(subgraph, self.resolver, self.ctx.info.namespace, scope,
                        closure_types)
    analyzer.visit_forward()

    # Recursively process any remaining subfunctions.
    node.body = self.visit_block(node.body)

    return node


def resolve(node, source_info, graphs, resolver):
  """Performs type inference.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    resolver: Resolver

  Returns:
    ast.AST
  """
  visitor = FunctionVisitor(source_info, graphs, resolver)
  node = visitor.visit(node)
  return node