# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for manipulating qualified names.

A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite
(e.g. 'foo.bar') syntactic symbols.

This is *not* related to the __qualname__ attribute used by inspect, which
refers to scopes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser


class CallerMustSetThis(object):
  """Marker used as the `ctx` of AST nodes whose context is not yet known.

  AST nodes built from a qualified name get this class as their `ctx`; the
  caller is expected to replace it with the appropriate context.
  """
  pass


class Symbol(collections.namedtuple('Symbol', ['name'])):
  """Represents a Python symbol."""


class Literal(collections.namedtuple('Literal', ['value'])):
  """Represents a Python literal constant, such as a number or a string.

  Used as the base of simple QNs that stand for constant subscript indices,
  e.g. the `1` in `a[1]` or the `'k'` in `a['k']`.
  """

  def __str__(self):
    if isinstance(self.value, str):
      # repr() produces a valid Python string literal even when the value
      # contains quotes, backslashes or newlines; for plain strings it yields
      # the same single-quoted text as naive quoting.
      return repr(self.value)
    return str(self.value)

  def __repr__(self):
    return str(self)


# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans.
class QN(object):
  """Represents a qualified name.

  A QN has one of three shapes, selected by the constructor arguments:

    * simple: `QN(base)`, where `base` is a `str` (a symbol name) or a
      `Literal`; stored as the 1-tuple `(base,)`.
    * attribute: `QN(parent, attr=name)`, where `parent` is a QN and `name`
      is a `str`; stored as `(parent, name)`.
    * subscript: `QN(parent, subscript=index)`, where `parent` is a QN and
      `index` is the QN of the index; stored as `(parent, index)`.
  """

  def __init__(self, base, attr=None, subscript=None):
    """Creates a QN.

    Args:
      base: the parent QN for attribute/subscript QNs; otherwise a `str` or a
        `Literal`.
      attr: optional `str`, the attribute name. Mutually exclusive with
        `subscript`.
      subscript: optional QN of the index. Mutually exclusive with `attr`.

    Raises:
      ValueError: if both `attr` and `subscript` are given, or if `base` or
        `attr` has an unsupported type.
    """
    if attr is not None and subscript is not None:
      raise ValueError('A QN can only be either an attr or a subscript, not '
                       'both: attr={}, subscript={}.'.format(attr, subscript))
    self._has_attr = False
    self._has_subscript = False

    if attr is not None:
      if not isinstance(base, QN):
        # The operand is wrapped in a 1-tuple so that %-formatting cannot
        # fail (with TypeError) when `base` itself is a tuple.
        raise ValueError(
            'for attribute QNs, base must be a QN; got instead "%s"' % (base,))
      if not isinstance(attr, str):
        raise ValueError(
            'attr may only be a string; got instead "%s"' % (attr,))
      self._parent = base
      # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now.
      self.qn = (base, attr)
      self._has_attr = True

    elif subscript is not None:
      if not isinstance(base, QN):
        raise ValueError('For subscript QNs, base must be a QN.')
      self._parent = base
      self.qn = (base, subscript)
      self._has_subscript = True

    else:
      if not isinstance(base, (str, Literal)):
        # TODO(mdan): Require Symbol instead of string.
        raise ValueError(
            'for simple QNs, base must be a string or a Literal object;'
            ' got instead "%s"' % type(base))
      if isinstance(base, str):
        # Only symbol names may not contain composite-name delimiters. A
        # Literal is a tuple, so `in` would test element membership and
        # wrongly reject literals such as Literal('.') (from `a['.']`).
        assert '.' not in base and '[' not in base and ']' not in base
      self._parent = None
      self.qn = (base,)

  def is_symbol(self):
    """Returns True if this is a simple QN whose base is a `str` name."""
    return isinstance(self.qn[0], str)

  def is_simple(self):
    """Returns True if this QN has no parent (a name or a literal)."""
    return len(self.qn) <= 1

  def is_composite(self):
    """Returns True if this QN is an attribute or a subscript."""
    return len(self.qn) > 1

  def has_subscript(self):
    return self._has_subscript

  def has_attr(self):
    return self._has_attr

  @property
  def attr(self):
    """The attribute name of an attribute QN.

    Raises:
      ValueError: if this QN is not an attribute QN.
    """
    if not self._has_attr:
      raise ValueError('Cannot get attr of non-attribute "%s".' % self)
    return self.qn[1]

  @property
  def parent(self):
    """The parent QN of an attribute or subscript QN.

    Raises:
      ValueError: if this QN is simple.
    """
    if self._parent is None:
      # qn[0] may be a Literal (a tuple); wrap it so %-formatting treats it
      # as a single operand.
      raise ValueError(
          'Cannot get parent of simple name "%s".' % (self.qn[0],))
    return self._parent

  @property
  def owner_set(self):
    """Returns all the symbols (simple or composite) that own this QN.

    In other words, if this symbol was modified, the symbols in the owner set
    may also be affected.

    Examples:
      'a.b[c.d]' has two owners, 'a' and 'a.b'
    """
    owners = set()
    if self.has_attr() or self.has_subscript():
      owners.add(self.parent)
      owners.update(self.parent.owner_set)
    return owners

  @property
  def support_set(self):
    """Returns the set of simple symbols that this QN relies on.

    This would be the smallest set of symbols necessary for the QN to
    statically resolve (assuming properties and index ranges are verified
    at runtime).

    Examples:
      'a.b' has only one support symbol, 'a'
      'a[i]' has two support symbols, 'a' and 'i'
    """
    # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
    roots = set()
    if self.has_attr():
      roots.update(self.parent.support_set)
    elif self.has_subscript():
      roots.update(self.parent.support_set)
      roots.update(self.qn[1].support_set)
    else:
      roots.add(self)
    return roots

  def __hash__(self):
    return hash(self.qn + (self._has_attr, self._has_subscript))

  def __eq__(self, other):
    return (isinstance(other, QN) and self.qn == other.qn and
            self.has_subscript() == other.has_subscript() and
            self.has_attr() == other.has_attr())

  def __lt__(self, other):
    # QNs compare by their component tuples; anything else by string form.
    if isinstance(other, QN):
      return self.qn < other.qn
    else:
      return str(self) < str(other)

  def __gt__(self, other):
    if isinstance(other, QN):
      return self.qn > other.qn
    else:
      return str(self) > str(other)

  def __str__(self):
    root = self.qn[0]
    if self.has_subscript():
      return '{}[{}]'.format(root, self.qn[1])
    if self.has_attr():
      return '.'.join(map(str, self.qn))
    else:
      return str(root)

  def __repr__(self):
    return str(self)

  def ssf(self):
    """Returns the simple symbol form: this QN flattened into one string.

    Components are joined with '_sub_' for subscript QNs and with '_'
    otherwise, e.g. 'a.b' -> 'a_b' and 'a[i]' -> 'a_sub_i'. Literal
    components use their str() form, so the result is not necessarily a
    valid Python identifier (e.g. for float or string literal indices).
    """
    # str() matters for Literal bases: previously the Literal object itself
    # was returned, which broke string concatenation in the parent's ssf().
    parts = [n.ssf() if isinstance(n, QN) else str(n) for n in self.qn]
    delimiter = '_sub_' if self.has_subscript() else '_'
    return delimiter.join(parts)

  def ast(self):
    """Returns a gast node equivalent to this QN.

    The `ctx` fields of Name, Attribute and Subscript nodes are set to
    CallerMustSetThis; the caller must adjust them.
    """
    # The caller must adjust the context appropriately.
    if self.has_subscript():
      return gast.Subscript(
          value=self.parent.ast(),
          slice=self.qn[-1].ast(),
          ctx=CallerMustSetThis)
    if self.has_attr():
      return gast.Attribute(
          value=self.parent.ast(), attr=self.qn[-1], ctx=CallerMustSetThis)

    base = self.qn[0]
    if isinstance(base, str):
      return gast.Name(
          base, ctx=CallerMustSetThis, annotation=None, type_comment=None)
    elif isinstance(base, Literal):
      return gast.Constant(base.value, kind=None)
    else:
      assert False, ('the constructor should prevent types other than '
                     'str and Literal')


class QnResolver(gast.NodeTransformer):
  """Annotates nodes with QN information.

  Name, Attribute and Subscript nodes receive an `anno.Basic.QN` annotation
  when a qualified name can be derived for them.

  Note: Not using NodeAnnos to avoid circular dependencies.
  """

  def visit_Name(self, node):
    node = self.generic_visit(node)
    anno.setanno(node, anno.Basic.QN, QN(node.id))
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    # Only annotated when the object being accessed has a QN itself.
    if anno.hasanno(node.value, anno.Basic.QN):
      anno.setanno(node, anno.Basic.QN,
                   QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr))
    return node

  def visit_Subscript(self, node):
    # TODO(mdan): This may no longer apply if we overload getitem.
    node = self.generic_visit(node)
    s = node.slice
    if isinstance(s, (gast.Tuple, gast.Slice)):
      # TODO(mdan): Support range and multi-dimensional indices.
      # Continuing silently because some demos use these.
      return node
    if isinstance(s, gast.Constant) and s.value != Ellipsis:
      # Constant indices (other than `...`) become Literal-based QNs.
      subscript = QN(Literal(s.value))
    else:
      # The index may be an expression, case in which a name doesn't make sense.
      if anno.hasanno(s, anno.Basic.QN):
        subscript = anno.getanno(s, anno.Basic.QN)
      else:
        return node
    if anno.hasanno(node.value, anno.Basic.QN):
      anno.setanno(node, anno.Basic.QN,
                   QN(anno.getanno(node.value, anno.Basic.QN),
                      subscript=subscript))
    return node


def resolve(node):
  """Annotates `node` and its descendants with QNs; returns the visited node."""
  return QnResolver().visit(node)


def from_str(qn_str):
  """Parses `qn_str` as an expression and returns its QN annotation.

  NOTE(review): the lookup goes through anno.getanno; what happens when the
  expression has no QN (e.g. 'a + b') depends on that helper — confirm.
  """
  node = parser.parse_expression(qn_str)
  node = resolve(node)
  return anno.getanno(node, anno.Basic.QN)