"""Tests for the built-in enumerate() and reversed() iterators.

The small helper classes at the top are input fixtures: each one exposes a
different (or a deliberately broken) iteration protocol so the test cases can
check how enumerate() consumes it.
"""

import unittest
import operator
import sys
import pickle
import gc

# CPython's regression-test support package; supplies the cpython_only
# decorator used below.
from test import support

class G:
    'Sequence using __getitem__'
    # Defines only __getitem__, so iterating it relies on the legacy
    # sequence protocol (indexing with 0, 1, 2, ...).
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        return self.seqn[i]

class I:
    'Sequence using iterator protocol'
    # Hand-written iterator: __iter__ returns self and __next__ walks
    # self.seqn by position until it runs out.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def __next__(self):
        if self.i >= len(self.seqn): raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v

class Ig:
    'Sequence using iterator protocol defined with a generator'
    # __iter__ is a generator function, so each iter() call returns a new
    # generator over self.seqn.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        for val in self.seqn:
            yield val

class X:
    'Missing __getitem__ and __iter__'
    # Has __next__ but is not iterable; test_noniterable expects
    # enumerate(X(...)) to raise TypeError.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __next__(self):
        if self.i >= len(self.seqn): raise StopIteration
        v = self.seqn[self.i]
        self.i += 1
        return v

class E:
    'Test propagation of exceptions'
    # __next__ always raises ZeroDivisionError (3 // 0).
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def __next__(self):
        3 // 0

class N:
    'Iterator missing __next__()'
    # __iter__ returns an object with no __next__; test_illformediterable
    # expects enumerate(N(...)) to raise TypeError.
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self

class PickleTest:
    # Helper to check picklability
    # Mixin: check_pickle calls self.assertEqual / self.assertFalse, so it is
    # meant to be combined with unittest.TestCase (as both suites below do).
    def check_pickle(self, itorg, seq):
        # itorg: the iterator under test; it is only pickled here, never
        #        advanced.
        # seq:   the complete list of items itorg is expected to yield.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # Round-trip the untouched iterator: the copy must have the same
            # type and yield every expected item.
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            self.assertEqual(type(itorg), type(it))
            self.assertEqual(list(it), seq)

            # Take a fresh copy, advance it one step, then pickle that
            # partially consumed copy: its position must survive the
            # round-trip.
            it = pickle.loads(d)
            try:
                next(it)
            except StopIteration:
                # Exhausted immediately, so seq must have no items past the
                # first.
                self.assertFalse(seq[1:])
                continue
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(list(it), seq[1:])

class EnumerateTestCase(unittest.TestCase, PickleTest):
    # Base suite for enumerate(). Subclasses below rerun every test with a
    # different `enum` callable and/or different `seq`/`res` data.

    # Callable under test.
    enum = enumerate
    # Input sequence and the list enum(seq) is expected to produce.
    seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]

    def test_basicfunction(self):
        self.assertEqual(type(self.enum(self.seq)), self.enum)
        e = self.enum(self.seq)
        # The enumerate object is its own iterator.
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
        # Only checks that __doc__ can be read; its value is not asserted.
        self.enum.__doc__

    def test_pickle(self):
        self.check_pickle(self.enum(self.seq), self.res)

    # The next three tests feed seq through a different input protocol
    # (G, I, Ig), and check that an empty input raises StopIteration on the
    # first next().
    def test_getitemseqn(self):
        self.assertEqual(list(self.enum(G(self.seq))), self.res)
        e = self.enum(G(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorseqn(self):
        self.assertEqual(list(self.enum(I(self.seq))), self.res)
        e = self.enum(I(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorgenerator(self):
        self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
        e = self.enum(Ig(''))
        self.assertRaises(StopIteration, next, e)

    def test_noniterable(self):
        self.assertRaises(TypeError, self.enum, X(self.seq))

    def test_illformediterable(self):
        self.assertRaises(TypeError, self.enum, N(self.seq))

    def test_exception_propagation(self):
        # The ZeroDivisionError raised by E.__next__ must reach the caller.
        self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))

    def test_argumentcheck(self):
        self.assertRaises(TypeError, self.enum) # no arguments
        self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
        self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
        self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments

    @support.cpython_only
    def test_tuple_reuse(self):
        # Tests an implementation detail where tuple is reused
        # whenever nothing else holds a reference to it
        # First assertion: list() keeps every result tuple alive, so each has
        # a distinct id. Second: each tuple is dropped right after id() reads
        # it, and the test requires one tuple object to be recycled for every
        # item (zero ids for empty input). Keep both expressions exactly as
        # written; binding a result to a name would keep it alive and defeat
        # the reuse. Note this uses enumerate directly, not self.enum.
        self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
        self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))

    @support.cpython_only
    def test_enumerate_result_gc(self):
        # bpo-42536: enumerate's tuple-reuse speed trick breaks the GC's
        # assumptions about what can be untracked. Make sure we re-track result
        # tuples whenever we reuse them.
        it = self.enum([[]])
        gc.collect()
        # That GC collection probably untracked the recycled internal result
        # tuple, which is initialized to (None, None). Make sure it's re-tracked
        # when it's mutated and returned from __next__:
        self.assertTrue(gc.is_tracked(next(it)))

# Subclass with no additions, so the enumerate suite can be rerun against a
# subclass of the built-in.
class MyEnum(enumerate):
    pass

class SubclassTestCase(EnumerateTestCase):
    # Reruns the whole enumerate suite with the MyEnum subclass.

    enum = MyEnum

class TestEmpty(EnumerateTestCase):
    # Reruns the suite with empty input; nothing is yielded.

    seq, res = '', []

class TestBig(EnumerateTestCase):
    # Reruns the suite with a larger range input; res pairs each value with
    # its 0-based index.

    seq = range(10,20000,2)
    res = list(zip(range(20000), seq))

class TestReversed(unittest.TestCase, PickleTest):
    # Tests for the built-in reversed().

    def test_simple(self):
        # A supports reversal only through __len__ and __getitem__.
        class A:
            def __getitem__(self, i):
                if i < 5:
                    return str(i)
                raise StopIteration
            def __len__(self):
                return 5
        # For every input, reversed() must yield the same items as
        # list(data)[::-1].
        for data in ('abc', range(5), tuple(enumerate('abc')), A(),
                     range(1,17,5), dict.fromkeys('abcde')):
            self.assertEqual(list(data)[::-1], list(reversed(data)))
        # don't allow keyword arguments
        self.assertRaises(TypeError, reversed, [], a=1)

    def test_range_optimization(self):
        # reversed() on a range returns the same iterator type as iter() on a
        # range.
        x = range(1)
        self.assertEqual(type(reversed(x)), type(iter(x)))

    def test_len(self):
        # length_hint reports the remaining count: len(s) before consumption,
        # 0 once the iterator is exhausted.
        for s in ('hello', tuple('hello'), list('hello'), range(5)):
            self.assertEqual(operator.length_hint(reversed(s)), len(s))
            r = reversed(s)
            list(r)
            self.assertEqual(operator.length_hint(r), 0)
        # __len__ succeeds on its first call and raises on every later call.
        # reversed() is built without error; length_hint() on it must then
        # propagate the ZeroDivisionError.
        class SeqWithWeirdLen:
            called = False
            def __len__(self):
                if not self.called:
                    self.called = True
                    return 10
                raise ZeroDivisionError
            def __getitem__(self, index):
                return index
        r = reversed(SeqWithWeirdLen())
        self.assertRaises(ZeroDivisionError, operator.length_hint, r)


    def test_gc(self):
        # Builds a reference cycle s -> r -> s through the reversed iterator.
        # There is no assertion. Presumably this leaves the cycle for the
        # garbage collector to reclaim, exercising the iterator's GC support.
        # NOTE(review): confirm intent.
        class Seq:
            def __len__(self):
                return 10
            def __getitem__(self, index):
                return index
        s = Seq()
        r = reversed(s)
        s.r = r

    def test_args(self):
        # reversed() requires exactly one positional argument.
        self.assertRaises(TypeError, reversed)
        self.assertRaises(TypeError, reversed, [], 'extra')

    @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    def test_bug1229429(self):
        # this bug was never in reversed, it was in
        # PyObject_CallMethod, and reversed_new calls that sometimes.
        # f.__reversed__ is set to a non-callable object. Every reversed(f)
        # call must raise TypeError. Afterwards, that object's refcount must
        # be unchanged, i.e. no reference is leaked on the error path.
        def f():
            pass
        r = f.__reversed__ = object()
        rc = sys.getrefcount(r)
        for i in range(10):
            try:
                reversed(f)
            except TypeError:
                pass
            else:
                self.fail("non-callable __reversed__ didn't raise!")
        self.assertEqual(rc, sys.getrefcount(r))

    def test_objmethods(self):
        # Objects must have __len__() and __getitem__() implemented.
        class NoLen(object):
            def __getitem__(self, i): return 1
        nl = NoLen()
        self.assertRaises(TypeError, reversed, nl)

        class NoGetItem(object):
            def __len__(self): return 2
        ngi = NoGetItem()
        self.assertRaises(TypeError, reversed, ngi)

        # Setting __reversed__ = None opts the class out of reversed() even
        # though it defines both __len__ and __getitem__.
        class Blocked(object):
            def __getitem__(self, i): return 1
            def __len__(self): return 2
            __reversed__ = None
        b = Blocked()
        self.assertRaises(TypeError, reversed, b)

    def test_pickle(self):
        # Pickle round-trips of reversed iterators over several sequence
        # types; expected output is the input reversed.
        for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
            self.check_pickle(reversed(data), list(data)[::-1])


class EnumerateStartTestCase(EnumerateTestCase):
    # Base for suites whose `enum` is a lambda wrapping
    # enumerate(..., start=...). test_basicfunction is overridden without the
    # type check, because self.enum is then a function rather than the
    # enumerate type itself.

    def test_basicfunction(self):
        e = self.enum(self.seq)
        self.assertEqual(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)


class TestStart(EnumerateStartTestCase):
    # Indices begin at start=11.

    enum = lambda self, i: enumerate(i, start=11)
    seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]


class TestLongStart(EnumerateStartTestCase):
    # start is just past sys.maxsize, so every index exceeds sys.maxsize.

    enum = lambda self, i: enumerate(i, start=sys.maxsize+1)
    seq, res = 'abc', [(sys.maxsize+1,'a'), (sys.maxsize+2,'b'),
                       (sys.maxsize+3,'c')]


if __name__ == "__main__":
    unittest.main()