1"""Unit tests for the PickleBuffer object. 2 3Pickling tests themselves are in pickletester.py. 4""" 5 6import gc 7from pickle import PickleBuffer 8import weakref 9import unittest 10 11from test.support import import_helper 12 13 14class B(bytes): 15 pass 16 17 18class PickleBufferTest(unittest.TestCase): 19 20 def check_memoryview(self, pb, equiv): 21 with memoryview(pb) as m: 22 with memoryview(equiv) as expected: 23 self.assertEqual(m.nbytes, expected.nbytes) 24 self.assertEqual(m.readonly, expected.readonly) 25 self.assertEqual(m.itemsize, expected.itemsize) 26 self.assertEqual(m.shape, expected.shape) 27 self.assertEqual(m.strides, expected.strides) 28 self.assertEqual(m.c_contiguous, expected.c_contiguous) 29 self.assertEqual(m.f_contiguous, expected.f_contiguous) 30 self.assertEqual(m.format, expected.format) 31 self.assertEqual(m.tobytes(), expected.tobytes()) 32 33 def test_constructor_failure(self): 34 with self.assertRaises(TypeError): 35 PickleBuffer() 36 with self.assertRaises(TypeError): 37 PickleBuffer("foo") 38 # Released memoryview fails taking a buffer 39 m = memoryview(b"foo") 40 m.release() 41 with self.assertRaises(ValueError): 42 PickleBuffer(m) 43 44 def test_basics(self): 45 pb = PickleBuffer(b"foo") 46 self.assertEqual(b"foo", bytes(pb)) 47 with memoryview(pb) as m: 48 self.assertTrue(m.readonly) 49 50 pb = PickleBuffer(bytearray(b"foo")) 51 self.assertEqual(b"foo", bytes(pb)) 52 with memoryview(pb) as m: 53 self.assertFalse(m.readonly) 54 m[0] = 48 55 self.assertEqual(b"0oo", bytes(pb)) 56 57 def test_release(self): 58 pb = PickleBuffer(b"foo") 59 pb.release() 60 with self.assertRaises(ValueError) as raises: 61 memoryview(pb) 62 self.assertIn("operation forbidden on released PickleBuffer object", 63 str(raises.exception)) 64 # Idempotency 65 pb.release() 66 67 def test_cycle(self): 68 b = B(b"foo") 69 pb = PickleBuffer(b) 70 b.cycle = pb 71 wpb = weakref.ref(pb) 72 del b, pb 73 gc.collect() 74 self.assertIsNone(wpb()) 75 76 def test_ndarray_2d(self): 77 # C-contiguous 78 ndarray = import_helper.import_module("_testbuffer").ndarray 79 arr = ndarray(list(range(12)), shape=(4, 3), format='<i') 80 self.assertTrue(arr.c_contiguous) 81 self.assertFalse(arr.f_contiguous) 82 pb = PickleBuffer(arr) 83 self.check_memoryview(pb, arr) 84 # Non-contiguous 85 arr = arr[::2] 86 self.assertFalse(arr.c_contiguous) 87 self.assertFalse(arr.f_contiguous) 88 pb = PickleBuffer(arr) 89 self.check_memoryview(pb, arr) 90 # F-contiguous 91 arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i') 92 self.assertTrue(arr.f_contiguous) 93 self.assertFalse(arr.c_contiguous) 94 pb = PickleBuffer(arr) 95 self.check_memoryview(pb, arr) 96 97 # Tests for PickleBuffer.raw() 98 99 def check_raw(self, obj, equiv): 100 pb = PickleBuffer(obj) 101 with pb.raw() as m: 102 self.assertIsInstance(m, memoryview) 103 self.check_memoryview(m, equiv) 104 105 def test_raw(self): 106 for obj in (b"foo", bytearray(b"foo")): 107 with self.subTest(obj=obj): 108 self.check_raw(obj, obj) 109 110 def test_raw_ndarray(self): 111 # 1-D, contiguous 112 ndarray = import_helper.import_module("_testbuffer").ndarray 113 arr = ndarray(list(range(3)), shape=(3,), format='<h') 114 equiv = b"\x00\x00\x01\x00\x02\x00" 115 self.check_raw(arr, equiv) 116 # 2-D, C-contiguous 117 arr = ndarray(list(range(6)), shape=(2, 3), format='<h') 118 equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00" 119 self.check_raw(arr, equiv) 120 # 2-D, F-contiguous 121 arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4), 122 format='<h') 123 # Note this is different from arr.tobytes() 124 equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00" 125 self.check_raw(arr, equiv) 126 # 0-D 127 arr = ndarray(456, shape=(), format='<i') 128 equiv = b'\xc8\x01\x00\x00' 129 self.check_raw(arr, equiv) 130 131 def check_raw_non_contiguous(self, obj): 132 pb = PickleBuffer(obj) 133 with self.assertRaisesRegex(BufferError, "non-contiguous"): 134 pb.raw() 135 136 def test_raw_non_contiguous(self): 137 # 1-D 138 ndarray = import_helper.import_module("_testbuffer").ndarray 139 arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2] 140 self.check_raw_non_contiguous(arr) 141 # 2-D 142 arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2] 143 self.check_raw_non_contiguous(arr) 144 145 def test_raw_released(self): 146 pb = PickleBuffer(b"foo") 147 pb.release() 148 with self.assertRaises(ValueError) as raises: 149 pb.raw() 150 151 152if __name__ == "__main__": 153 unittest.main() 154