1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for qual_names module.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import textwrap 22 23from tensorflow.python.autograph.pyct import anno 24from tensorflow.python.autograph.pyct import parser 25from tensorflow.python.autograph.pyct import qual_names 26from tensorflow.python.autograph.pyct.qual_names import QN 27from tensorflow.python.autograph.pyct.qual_names import resolve 28from tensorflow.python.platform import test 29 30 31class QNTest(test.TestCase): 32 33 def test_from_str(self): 34 a = QN('a') 35 b = QN('b') 36 a_dot_b = QN(a, attr='b') 37 a_sub_b = QN(a, subscript=b) 38 self.assertEqual(qual_names.from_str('a.b'), a_dot_b) 39 self.assertEqual(qual_names.from_str('a'), a) 40 self.assertEqual(qual_names.from_str('a[b]'), a_sub_b) 41 42 def test_basic(self): 43 a = QN('a') 44 self.assertEqual(a.qn, ('a',)) 45 self.assertEqual(str(a), 'a') 46 self.assertEqual(a.ssf(), 'a') 47 self.assertEqual(a.ast().id, 'a') 48 self.assertFalse(a.is_composite()) 49 with self.assertRaises(ValueError): 50 _ = a.parent 51 52 a_b = QN(a, attr='b') 53 self.assertEqual(a_b.qn, (a, 'b')) 54 self.assertEqual(str(a_b), 'a.b') 55 self.assertEqual(a_b.ssf(), 'a_b') 56 self.assertEqual(a_b.ast().value.id, 'a') 57 self.assertEqual(a_b.ast().attr, 'b') 58 self.assertTrue(a_b.is_composite()) 59 self.assertEqual(a_b.parent.qn, ('a',)) 60 61 def test_subscripts(self): 62 a = QN('a') 63 b = QN('b') 64 a_sub_b = QN(a, subscript=b) 65 self.assertEqual(a_sub_b.qn, (a, b)) 66 self.assertEqual(str(a_sub_b), 'a[b]') 67 self.assertEqual(a_sub_b.ssf(), 'a_sub_b') 68 self.assertEqual(a_sub_b.ast().value.id, 'a') 69 self.assertEqual(a_sub_b.ast().slice.value.id, 'b') 70 self.assertTrue(a_sub_b.is_composite()) 71 self.assertTrue(a_sub_b.has_subscript()) 72 self.assertEqual(a_sub_b.parent.qn, ('a',)) 73 74 c = QN('c') 75 b_sub_c = QN(b, subscript=c) 76 a_sub_b_sub_c = QN(a, subscript=b_sub_c) 77 self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c)) 78 self.assertTrue(a_sub_b.is_composite()) 79 self.assertTrue(a_sub_b_sub_c.is_composite()) 80 self.assertTrue(a_sub_b.has_subscript()) 81 self.assertTrue(a_sub_b_sub_c.has_subscript()) 82 self.assertEqual(b_sub_c.qn, (b, c)) 83 self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') 84 self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c') 85 self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a') 86 self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b') 87 self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c') 88 self.assertEqual(b_sub_c.ast().slice.value.id, 'c') 89 self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',)) 90 with self.assertRaises(ValueError): 91 QN('a', 'b') 92 93 def test_equality(self): 94 a = QN('a') 95 a2 = QN('a') 96 a_b = QN(a, attr='b') 97 self.assertEqual(a2.qn, ('a',)) 98 with self.assertRaises(ValueError): 99 _ = a.parent 100 101 a_b2 = QN(a, attr='b') 102 self.assertEqual(a_b2.qn, (a, 'b')) 103 self.assertEqual(a_b2.parent.qn, ('a',)) 104 105 self.assertTrue(a2 == a) 106 self.assertFalse(a2 is a) 107 108 self.assertTrue(a_b.parent == a) 109 self.assertTrue(a_b2.parent == a) 110 111 self.assertTrue(a_b2 == a_b) 112 self.assertFalse(a_b2 is a_b) 113 self.assertFalse(a_b2 == a) 114 a_sub_b = QN(a, subscript='b') 115 a_sub_b2 = QN(a, subscript='b') 116 self.assertTrue(a_sub_b == a_sub_b2) 117 self.assertFalse(a_sub_b == a_b) 118 119 def test_nested_attrs_subscripts(self): 120 a = QN('a') 121 b = QN('b') 122 c = QN('c') 123 b_sub_c = QN(b, subscript=c) 124 a_sub_b_sub_c = QN(a, subscript=b_sub_c) 125 126 b_dot_c = QN(b, attr='c') 127 a_sub__b_dot_c = QN(a, subscript=b_dot_c) 128 129 a_sub_b = QN(a, subscript=b) 130 a_sub_b__dot_c = QN(a_sub_b, attr='c') 131 132 a_dot_b = QN(a, attr='b') 133 a_dot_b_sub_c = QN(a_dot_b, subscript=c) 134 135 self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') 136 self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]') 137 self.assertEqual(str(a_sub_b__dot_c), 'a[b].c') 138 self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]') 139 140 self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c) 141 self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c) 142 self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c) 143 144 self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c) 145 self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c) 146 147 self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c) 148 149 def test_hashable(self): 150 d = {QN('a'): 'a', QN('b'): 'b'} 151 self.assertEqual(d[QN('a')], 'a') 152 self.assertEqual(d[QN('b')], 'b') 153 self.assertTrue(QN('c') not in d) 154 155 def test_literals(self): 156 a = QN('a') 157 a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b'))) 158 a_sub_b = QN(a, subscript=QN('b')) 159 160 self.assertNotEqual(a_sub_str_b, a_sub_b) 161 self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b)) 162 163 a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3))) 164 self.assertEqual(a_sub_three.ast().slice.value.n, 3) 165 166 def test_support_set(self): 167 a = QN('a') 168 b = QN('b') 169 c = QN('c') 170 a_sub_b = QN(a, subscript=b) 171 a_dot_b = QN(a, attr='b') 172 a_dot_b_dot_c = QN(a_dot_b, attr='c') 173 a_dot_b_sub_c = QN(a_dot_b, subscript=c) 174 175 self.assertSetEqual(a.support_set, set((a,))) 176 self.assertSetEqual(a_sub_b.support_set, set((a, b))) 177 self.assertSetEqual(a_dot_b.support_set, set((a,))) 178 self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,))) 179 self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c))) 180 181 182class QNResolverTest(test.TestCase): 183 184 def assertQNStringIs(self, node, qn_str): 185 self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) 186 187 def test_resolve(self): 188 samples = """ 189 a 190 a.b 191 (c, d.e) 192 [f, (g.h.i)] 193 j(k, l) 194 """ 195 nodes = resolve(parser.parse_str(textwrap.dedent(samples))) 196 nodes = tuple(n.value for n in nodes.body) 197 198 self.assertQNStringIs(nodes[0], 'a') 199 self.assertQNStringIs(nodes[1], 'a.b') 200 self.assertQNStringIs(nodes[2].elts[0], 'c') 201 self.assertQNStringIs(nodes[2].elts[1], 'd.e') 202 self.assertQNStringIs(nodes[3].elts[0], 'f') 203 self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') 204 self.assertQNStringIs(nodes[4].func, 'j') 205 self.assertQNStringIs(nodes[4].args[0], 'k') 206 self.assertQNStringIs(nodes[4].args[1], 'l') 207 208 def test_subscript_resolve(self): 209 samples = """ 210 x[i] 211 x[i.b] 212 a.b[c] 213 a.b[x.y] 214 a[z[c]] 215 a[b[c[d]]] 216 a[b].c 217 a.b.c[d].e.f 218 a.b[c[d]].e.f 219 a.b[c[d.e.f].g].h 220 """ 221 nodes = resolve(parser.parse_str(textwrap.dedent(samples))) 222 nodes = tuple(n.value for n in nodes.body) 223 224 self.assertQNStringIs(nodes[0], 'x[i]') 225 self.assertQNStringIs(nodes[1], 'x[i.b]') 226 self.assertQNStringIs(nodes[2], 'a.b[c]') 227 self.assertQNStringIs(nodes[3], 'a.b[x.y]') 228 self.assertQNStringIs(nodes[4], 'a[z[c]]') 229 self.assertQNStringIs(nodes[5], 'a[b[c[d]]]') 230 self.assertQNStringIs(nodes[6], 'a[b].c') 231 self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f') 232 self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f') 233 self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h') 234 235 def test_function_calls(self): 236 samples = """ 237 a.b 238 a.b() 239 a().b 240 z[i] 241 z[i]() 242 z()[i] 243 """ 244 nodes = resolve(parser.parse_str(textwrap.dedent(samples))) 245 nodes = tuple(n.value for n in nodes.body) 246 self.assertQNStringIs(nodes[0], 'a.b') 247 self.assertQNStringIs(nodes[1].func, 'a.b') 248 self.assertQNStringIs(nodes[2].value.func, 'a') 249 self.assertQNStringIs(nodes[3], 'z[i]') 250 self.assertQNStringIs(nodes[4].func, 'z[i]') 251 self.assertQNStringIs(nodes[5].value.func, 'z') 252 253 254if __name__ == '__main__': 255 test.main() 256