1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for manipulating qualified names. 16 17A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite 18(e.g. 'foo.bar') syntactic symbols. 19 20This is *not* related to the __qualname__ attribute used by inspect, which 21refers to scopes. 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27 28import collections 29 30import gast 31 32from tensorflow.python.autograph.pyct import anno 33from tensorflow.python.autograph.pyct import parser 34 35 36class Symbol(collections.namedtuple('Symbol', ['name'])): 37 """Represents a Python symbol.""" 38 39 40class StringLiteral(collections.namedtuple('StringLiteral', ['value'])): 41 """Represents a Python string literal.""" 42 43 def __str__(self): 44 return '\'%s\'' % self.value 45 46 def __repr__(self): 47 return str(self) 48 49 50class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])): 51 """Represents a Python numeric literal.""" 52 53 def __str__(self): 54 return '%s' % self.value 55 56 def __repr__(self): 57 return str(self) 58 59 60# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans. 
class QN(object):
  """Represents a qualified name.

  A QN takes one of three forms, selected by the constructor arguments:

    * simple: wraps a single `str`, `StringLiteral` or `NumberLiteral`;
    * attribute: a parent QN plus an attribute name (rendered `parent.attr`);
    * subscript: a parent QN plus a subscript (rendered `parent[subscript]`).

  QNs compare equal and hash by their components and their form, so they can
  be used as set members and dict keys.
  """

  def __init__(self, base, attr=None, subscript=None):
    """Creates a simple, attribute or subscript QN.

    Args:
      base: For simple QNs, a `str`, `StringLiteral` or `NumberLiteral`. For
        attribute and subscript QNs, the parent `QN`.
      attr: Optional `str`; when given, this QN represents `base.attr`.
      subscript: Optional; when given, this QN represents `base[subscript]`.
        It is not type-checked here, but `ssf`, `support_set` and `ast` call
        QN methods on it, so it is expected to be a `QN`.

    Raises:
      ValueError: If both `attr` and `subscript` are given, if `base` or
        `attr` has the wrong type, or if a simple string name contains any of
        the characters '.', '[' or ']'.
    """
    if attr is not None and subscript is not None:
      raise ValueError('A QN can only be either an attr or a subscript, not '
                       'both: attr={}, subscript={}.'.format(attr, subscript))
    self._has_attr = False
    self._has_subscript = False

    if attr is not None:
      # The offending value is wrapped in a 1-tuple so that a tuple argument
      # is formatted as a whole instead of being unpacked by the % operator.
      if not isinstance(base, QN):
        raise ValueError(
            'for attribute QNs, base must be a QN; got instead "%s"' % (base,))
      if not isinstance(attr, str):
        raise ValueError(
            'attr may only be a string; got instead "%s"' % (attr,))
      self._parent = base
      # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now.
      self.qn = (base, attr)
      self._has_attr = True

    elif subscript is not None:
      if not isinstance(base, QN):
        raise ValueError('For subscript QNs, base must be a QN.')
      self._parent = base
      self.qn = (base, subscript)
      self._has_subscript = True

    else:
      if not isinstance(base, (str, StringLiteral, NumberLiteral)):
        # TODO(mdan): Require Symbol instead of string.
        raise ValueError(
            'for simple QNs, base must be a string or a Literal object;'
            ' got instead "%s"' % type(base))
      # Only string names are checked for composite-name syntax. Literal
      # objects are tuples, for which `in` would test tuple membership and
      # wrongly reject e.g. StringLiteral('.').
      if isinstance(base, str) and any(c in base for c in '.[]'):
        raise ValueError(
            'simple QNs may not contain ".", "[" or "]"; got instead "%s"' %
            base)
      self._parent = None
      self.qn = (base,)

  def is_symbol(self):
    """Returns True if the first component of this QN is a plain string."""
    return isinstance(self.qn[0], str)

  def is_simple(self):
    """Returns True if this QN has a single component (no attr/subscript)."""
    return len(self.qn) <= 1

  def is_composite(self):
    """Returns True if this QN is an attribute or subscript QN."""
    return len(self.qn) > 1

  def has_subscript(self):
    """Returns True if this QN was constructed with a subscript."""
    return self._has_subscript

  def has_attr(self):
    """Returns True if this QN was constructed with an attr."""
    return self._has_attr

  @property
  def parent(self):
    """The QN that this attribute or subscript QN was built on.

    Raises:
      ValueError: If this is a simple QN, which has no parent.
    """
    if self._parent is None:
      raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
    return self._parent

  @property
  def owner_set(self):
    """Returns all the symbols (simple or composite) that own this QN.

    In other words, if this symbol was modified, the symbols in the owner set
    may also be affected.

    The set is built by following the parent chain; subscript expressions are
    not included.

    Examples:
      'a.b[c.d]' has two owners, 'a' and 'a.b'
    """
    owners = set()
    if self.has_attr() or self.has_subscript():
      owners.add(self.parent)
      owners.update(self.parent.owner_set)
    return owners

  @property
  def support_set(self):
    """Returns the set of simple symbols that this QN relies on.

    This would be the smallest set of symbols necessary for the QN to
    statically resolve (assuming properties and index ranges are verified
    at runtime).

    A simple QN is its own support, including when it wraps a literal.

    Examples:
      'a.b' has only one support symbol, 'a'
      'a[i]' has two support symbols, 'a' and 'i'
    """
    # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
    roots = set()
    if self.has_attr():
      roots.update(self.parent.support_set)
    elif self.has_subscript():
      roots.update(self.parent.support_set)
      roots.update(self.qn[1].support_set)
    else:
      roots.add(self)
    return roots

  def __hash__(self):
    # Includes the form flags so that the hash agrees with __eq__.
    return hash(self.qn + (self._has_attr, self._has_subscript))

  def __eq__(self, other):
    return (isinstance(other, QN) and self.qn == other.qn and
            self.has_subscript() == other.has_subscript() and
            self.has_attr() == other.has_attr())

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__; define it explicitly.
    return not self == other

  def __str__(self):
    if self.has_subscript():
      return str(self.qn[0]) + '[' + str(self.qn[1]) + ']'
    if self.has_attr():
      return '.'.join(map(str, self.qn))
    else:
      return str(self.qn[0])

  def __repr__(self):
    return str(self)

  def ssf(self):
    """Simple symbol form.

    Joins the simple symbol forms of the components with '_sub_' for
    subscript QNs and '_' otherwise, e.g. 'a.b' -> 'a_b', 'a[b]' -> 'a_sub_b'.
    Literal components contribute `str()` of their value; the result is not
    guaranteed to be a valid identifier in that case.

    Returns:
      A string.
    """

    def component(part):
      if isinstance(part, QN):
        return part.ssf()
      if isinstance(part, (StringLiteral, NumberLiteral)):
        # Literals are tuples and cannot be concatenated to a string directly.
        return str(part.value)
      return part

    delimiter = '_sub_' if self.has_subscript() else '_'
    return delimiter.join(component(part) for part in self.qn)

  def ast(self):
    """Builds a gast node for this QN.

    The context arguments of the generated nodes are left as None.

    Returns:
      A gast.Subscript, gast.Attribute, gast.Name, gast.Str or gast.Num node,
      depending on the form of this QN.
    """
    # The caller must adjust the context appropriately.
    if self.has_subscript():
      return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()),
                            None)
    if self.has_attr():
      return gast.Attribute(self.parent.ast(), self.qn[-1], None)

    base = self.qn[0]
    if isinstance(base, str):
      return gast.Name(base, None, None)
    elif isinstance(base, StringLiteral):
      return gast.Str(base.value)
    elif isinstance(base, NumberLiteral):
      return gast.Num(base.value)
    else:
      # Raised explicitly so that the guard is not stripped under -O.
      raise AssertionError('the constructor should prevent types other than '
                           'str, StringLiteral and NumberLiteral')


class QnResolver(gast.NodeTransformer):
  """Annotates nodes with QN information.

  Name, Attribute and Subscript nodes are annotated under `anno.Basic.QN`.
  Attribute and Subscript nodes are only annotated when their value node
  already carries a QN annotation.

  Note: Not using NodeAnnos to avoid circular dependencies.
  """

  def visit_Name(self, node):
    node = self.generic_visit(node)
    anno.setanno(node, anno.Basic.QN, QN(node.id))
    return node

  def visit_Attribute(self, node):
    # Children are visited first so that node.value is annotated, if it can be.
    node = self.generic_visit(node)
    if anno.hasanno(node.value, anno.Basic.QN):
      anno.setanno(node, anno.Basic.QN,
                   QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr))
    return node

  def visit_Subscript(self, node):
    # TODO(mdan): This may no longer apply if we overload getitem.
    node = self.generic_visit(node)
    s = node.slice
    if not isinstance(s, gast.Index):
      # TODO(mdan): Support range and multi-dimensional indices.
      # Continuing silently because some demos use these.
      return node
    if isinstance(s.value, gast.Num):
      subscript = QN(NumberLiteral(s.value.n))
    elif isinstance(s.value, gast.Str):
      subscript = QN(StringLiteral(s.value.s))
    else:
      # The index may be an expression, case in which a name doesn't make sense.
      if anno.hasanno(s.value, anno.Basic.QN):
        subscript = anno.getanno(s.value, anno.Basic.QN)
      else:
        return node
    if anno.hasanno(node.value, anno.Basic.QN):
      anno.setanno(node, anno.Basic.QN,
                   QN(anno.getanno(node.value, anno.Basic.QN),
                      subscript=subscript))
    return node


def resolve(node):
  """Runs QnResolver over `node` and returns the visited node."""
  return QnResolver().visit(node)


def from_str(qn_str):
  """Builds a QN from the source text of an expression, e.g. 'a.b[c]'.

  Args:
    qn_str: Text passed to `parser.parse_expression`.

  Returns:
    The `anno.Basic.QN` annotation of the parsed and resolved node.
    NOTE(review): behavior when the expression has no QN annotation depends
    on `anno.getanno` - confirm.
  """
  node = parser.parse_expression(qn_str)
  node = resolve(node)
  return anno.getanno(node, anno.Basic.QN)