• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for qual_names module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import textwrap

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct.qual_names import QN
from tensorflow.python.autograph.pyct.qual_names import resolve
from tensorflow.python.platform import test


class QNTest(test.TestCase):
  """Unit tests for the `QN` qualified-name class and `qual_names.from_str`."""

  def test_from_str(self):
    """`from_str` parses plain names, attribute accesses and subscripts."""
    a = QN('a')
    b = QN('b')
    a_dot_b = QN(a, attr='b')
    a_sub_b = QN(a, subscript=b)
    self.assertEqual(qual_names.from_str('a.b'), a_dot_b)
    self.assertEqual(qual_names.from_str('a'), a)
    self.assertEqual(qual_names.from_str('a[b]'), a_sub_b)

  def test_basic(self):
    """Simple and attribute QNs expose qn, str, ssf, ast and parent."""
    a = QN('a')
    self.assertEqual(a.qn, ('a',))
    self.assertEqual(str(a), 'a')
    self.assertEqual(a.ssf(), 'a')
    self.assertEqual(a.ast().id, 'a')
    self.assertFalse(a.is_composite())
    # A non-composite name has no parent to return.
    with self.assertRaises(ValueError):
      _ = a.parent

    a_b = QN(a, attr='b')
    self.assertEqual(a_b.qn, (a, 'b'))
    self.assertEqual(str(a_b), 'a.b')
    self.assertEqual(a_b.ssf(), 'a_b')
    self.assertEqual(a_b.ast().value.id, 'a')
    self.assertEqual(a_b.ast().attr, 'b')
    self.assertTrue(a_b.is_composite())
    self.assertEqual(a_b.parent.qn, ('a',))

  def test_subscripts(self):
    """Subscript QNs, including nested subscripts, render and parent correctly."""
    a = QN('a')
    b = QN('b')
    a_sub_b = QN(a, subscript=b)
    self.assertEqual(a_sub_b.qn, (a, b))
    self.assertEqual(str(a_sub_b), 'a[b]')
    self.assertEqual(a_sub_b.ssf(), 'a_sub_b')
    self.assertEqual(a_sub_b.ast().value.id, 'a')
    self.assertEqual(a_sub_b.ast().slice.value.id, 'b')
    self.assertTrue(a_sub_b.is_composite())
    self.assertTrue(a_sub_b.has_subscript())
    self.assertEqual(a_sub_b.parent.qn, ('a',))

    c = QN('c')
    b_sub_c = QN(b, subscript=c)
    a_sub_b_sub_c = QN(a, subscript=b_sub_c)
    self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c))
    # Check both the inner (b[c]) and the outer (a[b[c]]) subscript QNs.
    self.assertTrue(b_sub_c.is_composite())
    self.assertTrue(a_sub_b_sub_c.is_composite())
    self.assertTrue(b_sub_c.has_subscript())
    self.assertTrue(a_sub_b_sub_c.has_subscript())
    self.assertEqual(b_sub_c.qn, (b, c))
    self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
    self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c')
    self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a')
    self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b')
    self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c')
    self.assertEqual(b_sub_c.ast().slice.value.id, 'c')
    self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',))
    # Building a QN from two plain strings is rejected.
    with self.assertRaises(ValueError):
      QN('a', 'b')

  def test_equality(self):
    """Structurally identical QNs compare equal even when distinct objects."""
    a = QN('a')
    a2 = QN('a')
    a_b = QN(a, attr='b')
    self.assertEqual(a2.qn, ('a',))
    with self.assertRaises(ValueError):
      _ = a2.parent

    a_b2 = QN(a, attr='b')
    self.assertEqual(a_b2.qn, (a, 'b'))
    self.assertEqual(a_b2.parent.qn, ('a',))

    # Explicit `==` keeps QN.__eq__ itself as the thing under test.
    self.assertTrue(a2 == a)
    self.assertIsNot(a2, a)

    self.assertTrue(a_b.parent == a)
    self.assertTrue(a_b2.parent == a)

    self.assertTrue(a_b2 == a_b)
    self.assertIsNot(a_b2, a_b)
    self.assertFalse(a_b2 == a)
    a_sub_b = QN(a, subscript='b')
    a_sub_b2 = QN(a, subscript='b')
    self.assertTrue(a_sub_b == a_sub_b2)
    # Attribute access and subscript with the same name must not collide.
    self.assertFalse(a_sub_b == a_b)

  def test_nested_attrs_subscripts(self):
    """Different nestings of attributes and subscripts stay distinct."""
    a = QN('a')
    b = QN('b')
    c = QN('c')
    b_sub_c = QN(b, subscript=c)
    a_sub_b_sub_c = QN(a, subscript=b_sub_c)

    b_dot_c = QN(b, attr='c')
    a_sub__b_dot_c = QN(a, subscript=b_dot_c)

    a_sub_b = QN(a, subscript=b)
    a_sub_b__dot_c = QN(a_sub_b, attr='c')

    a_dot_b = QN(a, attr='b')
    a_dot_b_sub_c = QN(a_dot_b, subscript=c)

    self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
    self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]')
    self.assertEqual(str(a_sub_b__dot_c), 'a[b].c')
    self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]')

    self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c)
    self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c)
    self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c)

    self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c)
    self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c)

    self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c)

  def test_hashable(self):
    """Equal QNs can be used interchangeably as dict keys."""
    d = {QN('a'): 'a', QN('b'): 'b'}
    self.assertEqual(d[QN('a')], 'a')
    self.assertEqual(d[QN('b')], 'b')
    self.assertNotIn(QN('c'), d)

  def test_literals(self):
    """String and number literal subscripts differ from name subscripts."""
    a = QN('a')
    a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b')))
    a_sub_b = QN(a, subscript=QN('b'))

    self.assertNotEqual(a_sub_str_b, a_sub_b)
    self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b))

    a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3)))
    self.assertEqual(a_sub_three.ast().slice.value.n, 3)

  def test_support_set(self):
    """support_set holds the base name plus any subscript names."""
    a = QN('a')
    b = QN('b')
    c = QN('c')
    a_sub_b = QN(a, subscript=b)
    a_dot_b = QN(a, attr='b')
    a_dot_b_dot_c = QN(a_dot_b, attr='c')
    a_dot_b_sub_c = QN(a_dot_b, subscript=c)

    self.assertSetEqual(a.support_set, {a})
    self.assertSetEqual(a_sub_b.support_set, {a, b})
    self.assertSetEqual(a_dot_b.support_set, {a})
    self.assertSetEqual(a_dot_b_dot_c.support_set, {a})
    self.assertSetEqual(a_dot_b_sub_c.support_set, {a, c})
180
181
class QNResolverTest(test.TestCase):
  """Tests for `resolve`, which attaches `anno.Basic.QN` annotations to nodes."""

  def assertQNStringIs(self, node, qn_str):
    """Asserts that the QN annotation on `node` renders as `qn_str`."""
    self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str)

  def _resolve_expressions(self, samples):
    """Parses and resolves `samples`, returning each top-level statement's value.

    Every line of `samples` is expected to be a bare expression statement, so
    `.value` on each body node is the expression whose QN is inspected.
    """
    module = resolve(parser.parse_str(textwrap.dedent(samples)))
    return tuple(stmt.value for stmt in module.body)

  def test_resolve(self):
    """Names and attribute chains resolve inside tuples, lists and calls."""
    nodes = self._resolve_expressions("""
      a
      a.b
      (c, d.e)
      [f, (g.h.i)]
      j(k, l)
    """)

    self.assertQNStringIs(nodes[0], 'a')
    self.assertQNStringIs(nodes[1], 'a.b')
    self.assertQNStringIs(nodes[2].elts[0], 'c')
    self.assertQNStringIs(nodes[2].elts[1], 'd.e')
    self.assertQNStringIs(nodes[3].elts[0], 'f')
    self.assertQNStringIs(nodes[3].elts[1], 'g.h.i')
    self.assertQNStringIs(nodes[4].func, 'j')
    self.assertQNStringIs(nodes[4].args[0], 'k')
    self.assertQNStringIs(nodes[4].args[1], 'l')

  def test_subscript_resolve(self):
    """Arbitrarily nested subscript/attribute mixes resolve to full QNs."""
    nodes = self._resolve_expressions("""
      x[i]
      x[i.b]
      a.b[c]
      a.b[x.y]
      a[z[c]]
      a[b[c[d]]]
      a[b].c
      a.b.c[d].e.f
      a.b[c[d]].e.f
      a.b[c[d.e.f].g].h
    """)

    self.assertQNStringIs(nodes[0], 'x[i]')
    self.assertQNStringIs(nodes[1], 'x[i.b]')
    self.assertQNStringIs(nodes[2], 'a.b[c]')
    self.assertQNStringIs(nodes[3], 'a.b[x.y]')
    self.assertQNStringIs(nodes[4], 'a[z[c]]')
    self.assertQNStringIs(nodes[5], 'a[b[c[d]]]')
    self.assertQNStringIs(nodes[6], 'a[b].c')
    self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f')
    self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f')
    self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h')

  def test_function_calls(self):
    """Resolution stops at call boundaries: only the callee operand gets a QN."""
    nodes = self._resolve_expressions("""
      a.b
      a.b()
      a().b
      z[i]
      z[i]()
      z()[i]
    """)
    self.assertQNStringIs(nodes[0], 'a.b')
    self.assertQNStringIs(nodes[1].func, 'a.b')
    self.assertQNStringIs(nodes[2].value.func, 'a')
    self.assertQNStringIs(nodes[3], 'z[i]')
    self.assertQNStringIs(nodes[4].func, 'z[i]')
    self.assertQNStringIs(nodes[5].value.func, 'z')
252
253
if __name__ == '__main__':
  # Run the test cases through TensorFlow's test runner.
  test.main()
256