• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for manipulating qualified names.
16
17A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite
18(e.g. 'foo.bar') syntactic symbols.
19
20This is *not* related to the __qualname__ attribute used by inspect, which
21refers to scopes.
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import collections
29
30import gast
31
32from tensorflow.python.autograph.pyct import anno
33from tensorflow.python.autograph.pyct import parser
34
35
class Symbol(collections.namedtuple('Symbol', ('name',))):
  """A Python symbol, stored as a one-field tuple holding its `name`."""
38
39
class StringLiteral(collections.namedtuple('StringLiteral', ('value',))):
  """A Python string literal, stored as a one-field tuple holding `value`."""

  def __str__(self):
    # Renders the value wrapped in single quotes; the value is not escaped.
    quoted = "'%s'" % self.value
    return quoted

  def __repr__(self):
    # The repr is intentionally the same text as the str form.
    return self.__str__()
48
49
class NumberLiteral(collections.namedtuple('NumberLiteral', ('value',))):
  """A Python numeric literal, stored as a one-field tuple holding `value`."""

  def __str__(self):
    # Renders the value using its plain string form.
    return '{}'.format(self.value)

  def __repr__(self):
    # The repr is intentionally the same text as the str form.
    return self.__str__()
58
59
60# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans.
class QN(object):
  """Represents a qualified name.

  A QN takes one of three forms:
    * simple: wraps a single string symbol or a literal object, e.g. `a`;
    * attribute: a parent QN plus an attribute name, e.g. `a.b`;
    * subscript: a parent QN plus a subscript, e.g. `a[b]`.

  QNs compare equal and hash equal when their components and form match, so
  they can be used as set members and dict keys.
  """

  def __init__(self, base, attr=None, subscript=None):
    """Creates a simple, attribute or subscript QN.

    Args:
      base: For simple QNs, a str, StringLiteral or NumberLiteral. For
        attribute and subscript QNs, the parent QN.
      attr: Optional str, the attribute name. Mutually exclusive with
        subscript.
      subscript: Optional subscript. Its type is not validated here, but
        support_set and ast() use it as a QN. Mutually exclusive with attr.

    Raises:
      ValueError: if both attr and subscript are given, if base or attr has an
        unsupported type, or if a string base contains '.', '[' or ']'.
    """
    if attr is not None and subscript is not None:
      raise ValueError('A QN can only be either an attr or a subscript, not '
                       'both: attr={}, subscript={}.'.format(attr, subscript))
    self._has_attr = False
    self._has_subscript = False

    if attr is not None:
      if not isinstance(base, QN):
        raise ValueError(
            'for attribute QNs, base must be a QN; got instead "%s"' % base)
      if not isinstance(attr, str):
        raise ValueError('attr may only be a string; got instead "%s"' % attr)
      self._parent = base
      # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now.
      self.qn = (base, attr)
      self._has_attr = True

    elif subscript is not None:
      if not isinstance(base, QN):
        raise ValueError('For subscript QNs, base must be a QN.')
      self._parent = base
      self.qn = (base, subscript)
      self._has_subscript = True

    else:
      if isinstance(base, str):
        # Only plain names are checked for delimiters. Literal objects are
        # namedtuples, so an `in` test on them would be tuple membership and
        # would wrongly reject a string literal whose value is e.g. '.'.
        if any(delimiter in base for delimiter in '.[]'):
          raise ValueError(
              'simple QNs may not contain ".", "[" or "]"; got instead "%s"' %
              base)
      elif not isinstance(base, (StringLiteral, NumberLiteral)):
        # TODO(mdan): Require Symbol instead of string.
        raise ValueError(
            'for simple QNs, base must be a string or a Literal object;'
            ' got instead "%s"' % type(base))
      self._parent = None
      self.qn = (base,)

  def is_symbol(self):
    """Returns True if the first component is a plain string name."""
    return isinstance(self.qn[0], str)

  def is_simple(self):
    """Returns True if this QN has a single component (no parent)."""
    return len(self.qn) <= 1

  def is_composite(self):
    """Returns True if this QN has a parent (attribute or subscript form)."""
    return len(self.qn) > 1

  def has_subscript(self):
    """Returns True if this QN was built with a subscript."""
    return self._has_subscript

  def has_attr(self):
    """Returns True if this QN was built with an attribute."""
    return self._has_attr

  @property
  def parent(self):
    """The parent QN of a composite QN.

    Raises:
      ValueError: if this is a simple QN, which has no parent.
    """
    if self._parent is None:
      raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
    return self._parent

  @property
  def owner_set(self):
    """Returns all the symbols (simple or composite) that own this QN.

    In other words, if this symbol was modified, the symbols in the owner set
    may also be affected.

    Examples:
      'a.b[c.d]' has two owners, 'a' and 'a.b'
    """
    owners = set()
    if self.has_attr() or self.has_subscript():
      # Walk up the parent chain; the subscript itself is not an owner.
      owners.add(self.parent)
      owners.update(self.parent.owner_set)
    return owners

  @property
  def support_set(self):
    """Returns the set of simple symbols that this QN relies on.

    This would be the smallest set of symbols necessary for the QN to
    statically resolve (assuming properties and index ranges are verified
    at runtime).

    Examples:
      'a.b' has only one support symbol, 'a'
      'a[i]' has two support symbols, 'a' and 'i'
    """
    # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
    roots = set()
    if self.has_attr():
      roots.update(self.parent.support_set)
    elif self.has_subscript():
      roots.update(self.parent.support_set)
      roots.update(self.qn[1].support_set)
    else:
      roots.add(self)
    return roots

  def __hash__(self):
    # Include the form flags so that 'a.b' and 'a[b]' hash differently.
    return hash(self.qn + (self._has_attr, self._has_subscript))

  def __eq__(self, other):
    return (isinstance(other, QN) and self.qn == other.qn and
            self.has_subscript() == other.has_subscript() and
            self.has_attr() == other.has_attr())

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__, so define it explicitly.
    return not self == other

  def __str__(self):
    if self.has_subscript():
      return str(self.qn[0]) + '[' + str(self.qn[1]) + ']'
    if self.has_attr():
      return '.'.join(map(str, self.qn))
    else:
      return str(self.qn[0])

  def __repr__(self):
    return str(self)

  @staticmethod
  def _ssf_part(part):
    """Returns the simple symbol form of a single component of a QN."""
    if isinstance(part, QN):
      return part.ssf()
    if isinstance(part, str):
      return part
    # Literal objects expose the wrapped value as `value`. Render that value
    # rather than str(literal), which adds quotes around string literals.
    return str(getattr(part, 'value', part))

  def ssf(self):
    """Simple symbol form.

    Components are joined with '_' for attributes and '_sub_' for subscripts,
    e.g. 'a.b' becomes 'a_b' and 'a[i]' becomes 'a_sub_i'. Literal components
    are rendered from their value, e.g. 'a[0]' becomes 'a_sub_0'.

    Returns:
      A string. NOTE(review): for string literals with arbitrary content the
      result is not guaranteed to be a valid identifier.
    """
    delimiter = '_sub_' if self.has_subscript() else '_'
    return delimiter.join(self._ssf_part(part) for part in self.qn)

  def ast(self):
    """Builds the gast node corresponding to this QN.

    Returns:
      A gast Subscript, Attribute, Name, Str or Num node. The context argument
      is passed as None.

    Raises:
      AssertionError: if the base is of a type the constructor should have
        rejected.
    """
    # The caller must adjust the context appropriately.
    if self.has_subscript():
      return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()),
                            None)
    if self.has_attr():
      return gast.Attribute(self.parent.ast(), self.qn[-1], None)

    base = self.qn[0]
    if isinstance(base, str):
      return gast.Name(base, None, None)
    elif isinstance(base, StringLiteral):
      return gast.Str(base.value)
    elif isinstance(base, NumberLiteral):
      return gast.Num(base.value)
    else:
      # Raised explicitly so the check is not stripped under `python -O`.
      raise AssertionError('the constructor should prevent types other than '
                           'str, StringLiteral and NumberLiteral')
208
209
class QnResolver(gast.NodeTransformer):
  """AST transformer that attaches QN annotations to nodes.

  Name, Attribute and Subscript nodes are annotated under anno.Basic.QN when a
  qualified name can be built for them; otherwise they are returned without
  that annotation.

  Note: Not using NodeAnnos to avoid circular dependencies.
  """

  def visit_Name(self, node):
    """Annotates a Name node with the simple QN built from its id."""
    visited = self.generic_visit(node)
    simple_qn = QN(visited.id)
    anno.setanno(visited, anno.Basic.QN, simple_qn)
    return visited

  def visit_Attribute(self, node):
    """Annotates an Attribute node when its value already carries a QN."""
    visited = self.generic_visit(node)
    if not anno.hasanno(visited.value, anno.Basic.QN):
      return visited
    owner_qn = anno.getanno(visited.value, anno.Basic.QN)
    anno.setanno(visited, anno.Basic.QN, QN(owner_qn, attr=visited.attr))
    return visited

  def visit_Subscript(self, node):
    """Annotates a Subscript node when both value and index resolve to QNs."""
    # TODO(mdan): This may no longer apply if we overload getitem.
    visited = self.generic_visit(node)
    index = visited.slice
    if not isinstance(index, gast.Index):
      # TODO(mdan): Support range and multi-dimensional indices.
      # Continuing silently because some demos use these.
      return visited

    key = index.value
    if isinstance(key, gast.Num):
      subscript = QN(NumberLiteral(key.n))
    elif isinstance(key, gast.Str):
      subscript = QN(StringLiteral(key.s))
    elif anno.hasanno(key, anno.Basic.QN):
      subscript = anno.getanno(key, anno.Basic.QN)
    else:
      # The index may be an expression, case in which a name doesn't make sense.
      return visited

    # The subscript is resolved before the base is checked, so a failure while
    # building it surfaces even when the base carries no QN.
    if anno.hasanno(visited.value, anno.Basic.QN):
      owner_qn = anno.getanno(visited.value, anno.Basic.QN)
      anno.setanno(visited, anno.Basic.QN, QN(owner_qn, subscript=subscript))
    return visited
251
252
def resolve(node):
  """Runs a fresh QnResolver over `node` and returns the visited node."""
  resolver = QnResolver()
  return resolver.visit(node)
255
256
def from_str(qn_str):
  """Parses `qn_str` as an expression and returns its QN annotation.

  NOTE(review): presumably fails if the expression does not resolve to a QN
  (e.g. an arbitrary call) - confirm against anno.getanno.
  """
  resolved = resolve(parser.parse_expression(qn_str))
  return anno.getanno(resolved, anno.Basic.QN)
261