# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AST node annotation support.

Adapted from Tangent.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum

# pylint:disable=g-bad-import-order

import gast
# pylint:enable=g-bad-import-order


# TODO(mdan): Shorten the names; they are used heavily by callers.
# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.


class NoValue(enum.Enum):
  """Base class for different types of AST annotations.

  Members of subclasses are used directly as annotation keys. The methods
  below are shortcuts over the module-level getanno/setanno/hasanno helpers,
  always using the default annotation field.
  """

  def of(self, node, default=None):
    """Returns this key's annotation on `node`, or `default` if absent."""
    return getanno(node, self, default=default)

  def add_to(self, node, value):
    """Annotates `node` with `value` under this key."""
    setanno(node, self, value)

  def exists(self, node):
    """Returns True if `node` has an annotation under this key."""
    return hasanno(node, self)

  def __repr__(self):
    # Print only the member name, not the (long) descriptive value.
    return str(self.name)


class Basic(NoValue):
  """Container for basic annotation keys.

  The enum values are used strictly for documentation purposes.
  """

  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
  SKIP_PROCESSING = (
      'This node should be preserved as is and not processed any further.')
  INDENT_BLOCK_REMAINDER = (
      'When a node is annotated with this, the remainder of the block should'
      ' be indented below it. The annotation contains a tuple'
      ' (new_body, name_map), where `new_body` is the new indented block and'
      ' `name_map` allows renaming symbols.')
  ORIGIN = ('Information about the source code that converted code originated'
            ' from. See origin_information.py.')
  DIRECTIVES = ('User directives associated with a statement or a variable.'
                ' Typically, they affect the immediately-enclosing statement.')

  EXTRA_LOOP_TEST = (
      'A special annotation containing additional test code to be executed in'
      ' for loops.')


class Static(NoValue):
  """Container for static analysis annotation keys.

  The enum values are used strictly for documentation purposes.
  """

  # Symbols
  # These flags are boolean.
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'

  # Scopes
  # Scopes are represented by objects of type activity.Scope.
  SCOPE = 'The scope for the annotated node. See activity.py.'
  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = (
      'The scope for the main body of a statement (True branch for if '
      'statements, main body for loops).')
  ORELSE_SCOPE = (
      'The scope for the orelse body of a statement (False branch for if '
      'statements, orelse body for loops).')

  # Static analysis annotations.
  DEFINITIONS = (
      'Reaching definition information. See reaching_definitions.py.')
  ORIG_DEFINITIONS = (
      'The value of DEFINITIONS that applied to the original code before any'
      ' conversion.')
  DEFINED_FNS_IN = (
      'Local function definitions that may exist when exiting the node. See'
      ' reaching_fndefs.py')
  DEFINED_VARS_IN = (
      'Symbols defined when entering the node. See reaching_definitions.py.')
  LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
  LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
  TYPES = 'Static type information. See type_inference.py.'
  CLOSURE_TYPES = 'Types of closure symbols at each detected call site.'
  VALUE = 'Static value information. See type_inference.py.'


# Sentinel for getanno's `default`: when left as FAIL, a missing annotation
# raises instead of returning a fallback value.
FAIL = object()


def keys(node, field_name='___pyct_anno'):
  """Returns a frozenset of the annotation keys present on `node`.

  Args:
    node: ast.AST
    field_name: str, the attribute under which annotations are stored.

  Returns:
    frozenset of keys; empty if `node` has no annotations.
  """
  if not hasattr(node, field_name):
    return frozenset()
  return frozenset(getattr(node, field_name).keys())


def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
  """Returns the annotation stored on `node` under `key`.

  Args:
    node: ast.AST
    key: Hashable, the annotation key.
    default: value returned when the annotation is missing. If left as FAIL,
      a missing annotation raises instead.
    field_name: str, the attribute under which annotations are stored.

  Returns:
    The annotation value, or `default` if it is missing and a default was
    given.

  Raises:
    AttributeError: if `default` is FAIL and `node` has no annotations.
    KeyError: if `default` is FAIL and `key` is not annotated on `node`.
  """
  if (default is FAIL or (hasattr(node, field_name) and
                          (key in getattr(node, field_name)))):
    return getattr(node, field_name)[key]
  return default


def hasanno(node, key, field_name='___pyct_anno'):
  """Returns True if `node` has an annotation under `key`."""
  return hasattr(node, field_name) and key in getattr(node, field_name)


def setanno(node, key, value, field_name='___pyct_anno'):
  """Annotates `node` with `value` under `key`, overwriting any prior value.

  The annotation dict is created on first use, and `field_name` is appended
  to `node._fields` if not already listed.

  Args:
    node: ast.AST
    key: Hashable, the annotation key.
    value: the annotation value.
    field_name: str, the attribute under which annotations are stored.
  """
  annotations = getattr(node, field_name, {})
  setattr(node, field_name, annotations)
  annotations[key] = value

  # So that the annotations survive gast_to_ast() and ast_to_gast()
  if field_name not in node._fields:
    node._fields += (field_name,)


def delanno(node, key, field_name='___pyct_anno'):
  """Removes the annotation under `key` from `node`.

  When the last annotation is removed, the annotation attribute is deleted
  and `field_name` is removed from `node._fields`.

  Args:
    node: ast.AST
    key: Hashable, the annotation key.
    field_name: str, the attribute under which annotations are stored.

  Raises:
    AttributeError: if `node` has no annotations.
    KeyError: if `key` is not annotated on `node`.
  """
  annotations = getattr(node, field_name)
  del annotations[key]
  if not annotations:
    delattr(node, field_name)
    node._fields = tuple(f for f in node._fields if f != field_name)


def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
  """Copies the `key` annotation from `from_node` to `to_node`, if present."""
  if hasanno(from_node, key, field_name=field_name):
    setanno(
        to_node,
        key,
        getanno(from_node, key, field_name=field_name),
        field_name=field_name)


def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively copies annotations in an AST tree.

  Modifies the nodes in place; returns None.

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
      key. All annotations with the source key will be copied to identical
      annotations with the destination key.
    field_name: str
  """
  for n in gast.walk(node):
    for k in copy_map:
      if hasanno(n, k, field_name=field_name):
        # field_name must be passed by keyword: getanno's third positional
        # parameter is `default`. Passing it positionally made getanno read
        # from the default field and, for custom fields, return the
        # field-name string itself as the "annotation".
        setanno(
            n,
            copy_map[k],
            getanno(n, k, field_name=field_name),
            field_name=field_name)