• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Unit tests for the memoryview
2
3   Some tests are in test_bytes. Many tests that require _testbuffer.ndarray
4   are in test_buffer.
5"""
6
7import unittest
8import test.support
9import sys
10import gc
11import weakref
12import array
13import io
14import copy
15import pickle
16
17
class AbstractMemoryTests:
    """Mixin holding the memoryview tests shared by all concrete test cases.

    Classes mixing this in (together with unittest.TestCase) must provide:

    * ``ro_type`` / ``rw_type`` -- callables building a read-only or a
      writable buffer object from a bytes payload, or None when no such
      type exists for the variant;
    * ``getitem_type`` -- callable mapping a one-byte ``bytes`` object to
      the bytes ``tobytes()`` is expected to produce for that item;
    * ``itemsize`` and ``format`` -- the expected memoryview attributes;
    * ``_view(obj)`` -- build the memoryview under test from *obj*;
    * ``_check_contents(tp, obj, contents)`` -- assert that the region of
      *obj* covered by the view holds ``contents`` (compared as
      ``tp(contents)``).
    """

    source_bytes = b"abcdef"

    @property
    def _source(self):
        """Payload used to build every buffer object under test."""
        return self.source_bytes

    @property
    def _types(self):
        """Tuple of the available buffer factories (read-only first).

        A tuple (rather than a one-shot iterator) is returned so the result
        can safely be iterated more than once.
        """
        return tuple(tp for tp in (self.ro_type, self.rw_type) if tp)

    def check_getitem_with_type(self, tp):
        """Check indexing, bounds and key-type errors on a view of *tp*."""
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], ord(b"a"))
        self.assertIsInstance(m[0], int)
        self.assertEqual(m[5], ord(b"f"))
        self.assertEqual(m[-1], ord(b"f"))
        self.assertEqual(m[-6], ord(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Dropping the view must give back every reference it took on b.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    def test_iter(self):
        # Iteration must yield exactly what indexing yields.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    def test_setitem_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        tp = self.rw_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = ord(b'1')
        self._check_contents(tp, b, b"1bcdef")
        m[0:1] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, (slice(0,1,1), 0), b"a")
        self.assertRaises(TypeError, setitem, (0, slice(0,1,1)), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Not implemented: multidimensional slices
        slices = (slice(0,1,1), slice(0,1,2))
        self.assertRaises(NotImplementedError, setitem, slices, b"a")
        # Trying to resize the memory object.  For a single item, format
        # 'c' is expected to raise ValueError on a wrong-length value, while
        # every other format is expected to reject the value with TypeError.
        exc = ValueError if m.format == 'c' else TypeError
        self.assertRaises(exc, setitem, 0, b"")
        self.assertRaises(exc, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_delitem(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with self.assertRaises(TypeError):
                del m[1]
            with self.assertRaises(TypeError):
                del m[1:4]

    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(bytes([c])) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertIsInstance(b, bytes)

    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            l = m.tolist()
            self.assertEqual(l, list(b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == "abcdef")
            self.assertTrue(m != "abcdef")
            self.assertFalse("abcdef" == m)
            self.assertTrue("abcdef" != m)

            # Unordered comparisons
            for c in (m, b"abcdef"):
                self.assertRaises(TypeError, lambda: m < c)
                self.assertRaises(TypeError, lambda: c <= m)
                self.assertRaises(TypeError, lambda: m >= c)
                self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        """Check the shape-related attributes of a view of *tp*; return it."""
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertEqual(m.itemsize, self.itemsize)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, ())
        return m

    def test_attributes_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        # Test PyObject_GetBuffer() on a memoryview object.
        for tp in self._types:
            b = tp(self._source)
            oldrefcount = sys.getrefcount(b)
            m = self._view(b)
            oldviewrefcount = sys.getrefcount(m)
            # str(obj, encoding) decodes through the buffer protocol.
            s = str(m, "utf-8")
            self._check_contents(tp, b, s.encode("utf-8"))
            self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            m = None
            self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MyView:
                def __init__(self, base):
                    self.m = memoryview(base)
            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object.
            # This exercises mbuf_clear().
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            test.support.gc_collect()
            self.assertIsNone(wr())

            # This exercises memory_clear().
            m = MyView(tp(b'abc'))
            o = MyObject()
            m.x = m
            m.o = o
            wr = weakref.ref(o)
            m = o = None
            # The cycle must be broken
            test.support.gc_collect()
            self.assertIsNone(wr())

    def _check_released(self, m, tp):
        """Assert that the released view *m* rejects buffer operations.

        str(), repr() and equality comparisons must keep working.
        """
        check = self.assertRaisesRegex(ValueError, "released")
        with check: bytes(m)
        with check: m.tobytes()
        with check: m.tolist()
        with check: m[0]
        with check: m[0] = b'x'
        with check: len(m)
        with check: m.format
        with check: m.itemsize
        with check: m.ndim
        with check: m.readonly
        with check: m.shape
        with check: m.strides
        with check:
            with m:
                pass
        # str() and repr() still function
        self.assertIn("released memory", str(m))
        self.assertIn("released memory", repr(m))
        self.assertEqual(m, m)
        self.assertNotEqual(m, memoryview(tp(self._source)))
        self.assertNotEqual(m, tp(self._source))

    def test_contextmanager(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with m as cm:
                self.assertIs(cm, m)
            self._check_released(m, tp)
            m = self._view(b)
            # Can release explicitly inside the context manager
            with m:
                m.release()

    def test_release(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            m.release()
            self._check_released(m, tp)
            # Can be called a second time (it's a no-op)
            m.release()
            self._check_released(m, tp)

    def test_writable_readonly(self):
        # Issue #10451: memoryview incorrectly exposes a readonly
        # buffer as writable causing a segfault if using mmap
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        i = io.BytesIO(b'ZZZZ')
        self.assertRaises(TypeError, i.readinto, m)

    def test_getbuf_fail(self):
        self.assertRaises(TypeError, self._view, {})

    def test_hash(self):
        # Memoryviews of readonly (hashable) types are hashable, and they
        # hash as hash(obj.tobytes()).
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Releasing the memoryview keeps the stored hash value (as with weakrefs)
        m.release()
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Hashing a memoryview for the first time after it is released
        # results in an error (as with weakrefs).
        m = self._view(b)
        m.release()
        self.assertRaises(ValueError, hash, m)

    def test_hash_writable(self):
        # Memoryviews of writable types are unhashable
        tp = self.rw_type
        if tp is None:
            self.skipTest("no writable type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertRaises(ValueError, hash, m)

    def test_weakref(self):
        # Check memoryviews are weakrefable
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            seen = []
            def callback(wr, b=b):
                seen.append(b)
            wr = weakref.ref(m, callback)
            self.assertIs(wr(), m)
            del m
            test.support.gc_collect()
            self.assertIs(wr(), None)
            self.assertIs(seen[0], b)

    def test_reversed(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            aslist = list(reversed(m.tolist()))
            self.assertEqual(list(reversed(m)), aslist)
            self.assertEqual(list(reversed(m)), list(m[::-1]))

    def test_toreadonly(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            mm = m.toreadonly()
            self.assertTrue(mm.readonly)
            self.assertTrue(memoryview(mm).readonly)
            self.assertEqual(mm.tolist(), m.tolist())
            mm.release()
            # Releasing the read-only view must leave the original usable.
            m.tolist()

    def test_issue22668(self):
        # Views derived via cast() and slicing must keep their own format
        # after the view they came from is deleted and the base is recast.
        a = array.array('H', [256, 256, 256, 256])
        x = memoryview(a)
        m = x.cast('B')
        b = m.cast('H')
        c = b[0:2]
        d = memoryview(b)

        del b

        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")

        _ = m.cast('I')
        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")
396
397
398# Variations on source objects for the buffer: bytes-like objects, then arrays
399# with itemsize > 1.
400# NOTE: support for multi-dimensional objects is unimplemented.
401
class BaseBytesMemoryTests(AbstractMemoryTests):
    """Source variant: views over bytes (read-only) and bytearray (writable).

    Every item is a single unsigned byte, hence format 'B' and itemsize 1.
    """
    format = 'B'
    itemsize = 1
    getitem_type = bytes
    rw_type = bytearray
    ro_type = bytes
408
class BaseArrayMemoryTests(AbstractMemoryTests):
    """Source variant: views over ``array.array('i')`` objects.

    Each byte of the payload becomes one C int item, so the views have an
    itemsize larger than one.  No read-only source type is provided.
    """
    ro_type = None
    itemsize = array.array('i').itemsize
    format = 'i'

    def rw_type(self, b):
        """Build a writable int array holding one item per byte of *b*."""
        # list(b) is required: passing bytes directly to array.array()
        # would hand them to frombytes() and reinterpret the raw bytes
        # as machine ints instead of making one item per byte.
        return array.array('i', list(b))

    def getitem_type(self, b):
        """Return the raw bytes ``tobytes()`` yields for the items of *b*."""
        return self.rw_type(b).tobytes()

    @unittest.skip('XXX test should be adapted for non-byte buffers')
    def test_getbuffer(self):
        pass

    @unittest.skip('XXX NotImplementedError: tolist() only supports byte views')
    def test_tolist(self):
        pass
423
424
425# Variations on indirection levels: memoryview, slice of memoryview,
426# slice of slice of memoryview.
427# This is important to test allocation subtleties.
428
class BaseMemoryviewTests:
    """Indirection variant: the view under test covers the whole source."""

    def _view(self, obj):
        """Return a plain memoryview spanning all of *obj*."""
        view = memoryview(obj)
        return view

    def _check_contents(self, tp, obj, contents):
        """Assert that *obj* as a whole equals ``tp(contents)``."""
        expected = tp(contents)
        self.assertEqual(obj, expected)
435
class BaseMemorySliceTests:
    """Indirection variant: the view under test is a slice of a memoryview.

    The payload carries one sentinel item on each side ('X' and 'Y'), and
    the slice selects only the six items in between.
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        """Return a memoryview of *obj* without its first and last item."""
        return memoryview(obj)[1:7]

    def _check_contents(self, tp, obj, contents):
        """Assert that items 1..6 of *obj* equal ``tp(contents)``."""
        inner = obj[1:7]
        self.assertEqual(inner, tp(contents))

    def test_refs(self):
        # Taking a slice must not leave an extra reference on the parent view.
        for tp in self._types:
            parent = memoryview(tp(self._source))
            before = sys.getrefcount(parent)
            parent[1:2]
            self.assertEqual(sys.getrefcount(parent), before)
452
class BaseMemorySliceSliceTests:
    """Indirection variant: the view under test is a slice of a slice.

    The payload carries one sentinel item on each side ('X' and 'Y');
    slicing ``[:7]`` and then ``[1:]`` selects the six items in between.
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        """Return ``memoryview(obj)[:7][1:]``, built in two slicing steps."""
        head = memoryview(obj)[:7]
        return head[1:]

    def _check_contents(self, tp, obj, contents):
        """Assert that items 1..6 of *obj* equal ``tp(contents)``."""
        inner = obj[1:7]
        self.assertEqual(inner, tp(contents))
462
463
464# Concrete test classes
465
class BytesMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseBytesMemoryTests):
    """Shared tests on plain memoryviews of bytes and bytearray objects."""

    def test_constructor(self):
        # memoryview() takes exactly one argument, positionally or as
        # the keyword 'object'; anything else is a TypeError.
        for tp in self._types:
            source = tp(self._source)
            self.assertTrue(memoryview(source))
            self.assertTrue(memoryview(object=source))
            with self.assertRaises(TypeError):
                memoryview()
            with self.assertRaises(TypeError):
                memoryview(source, source)
            with self.assertRaises(TypeError):
                memoryview(argument=source)
            with self.assertRaises(TypeError):
                memoryview(source, argument=True)
class ArrayMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseArrayMemoryTests):
    """Shared tests on plain memoryviews of array.array('i') objects."""

    def test_array_assign(self):
        # Issue #4569: segfault when mutating a memoryview with itemsize != 1
        target = array.array('i', range(10))
        view = memoryview(target)
        replacement = array.array('i', reversed(range(10)))
        view[:] = replacement
        self.assertEqual(target, replacement)
489
490
class BytesMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseBytesMemoryTests):
    """Shared tests on a slice of a memoryview of bytes/bytearray."""
494
class ArrayMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseArrayMemoryTests):
    """Shared tests on a slice of a memoryview of array.array('i')."""
498
class BytesMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    """Shared tests on a slice of a slice of a bytes/bytearray memoryview."""
502
class ArrayMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    """Shared tests on a slice of a slice of an array.array('i') memoryview."""
506
507
class OtherTest(unittest.TestCase):
    """memoryview tests that do not fit the shared source/view matrix."""

    def test_ctypes_cast(self):
        # Issue 15944: Allow all source formats when casting to bytes.
        try:
            from test.support.import_helper import import_module
        except ImportError:
            # Older test.support layouts expose import_module directly.
            import_module = test.support.import_module
        ctypes = import_module("ctypes")
        p6 = bytes(ctypes.c_double(0.6))

        d = ctypes.c_double()
        m = memoryview(d).cast("B")
        m[:2] = p6[:2]
        m[2:] = p6[2:]
        self.assertEqual(d.value, 0.6)

        # 'fmt' rather than 'format' so the builtin is not shadowed.
        for fmt in "Bbc":
            with self.subTest(fmt):
                d = ctypes.c_double()
                m = memoryview(d).cast(fmt)
                m[:2] = memoryview(p6).cast(fmt)[:2]
                m[2:] = memoryview(p6).cast(fmt)[2:]
                self.assertEqual(d.value, 0.6)

    def test_memoryview_hex(self):
        # Issue #9951: memoryview.hex() segfaults with non-contiguous buffers.
        x = b'0' * 200000
        m1 = memoryview(x)
        m2 = m1[::-1]
        self.assertEqual(m2.hex(), '30' * 200000)

    def test_copy(self):
        m = memoryview(b'abc')
        with self.assertRaises(TypeError):
            copy.copy(m)

    def test_pickle(self):
        m = memoryview(b'abc')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.assertRaises(TypeError):
                pickle.dumps(m, proto)
545
546
# Running this file directly as a script hands control to unittest.main(),
# which loads and runs the test cases defined in this module.
if __name__ == "__main__":
    unittest.main()
549