• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
# Names exported by ``from ... import *``; everything else is private.
__all__ = [
    'SharedMemory',
    'ShareableList',
]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17
# Choose the platform backend once, at import time: Windows uses named file
# mappings through _winapi; every other platform uses shm_open() through the
# _posixshmem extension module.
_USE_POSIX = os.name != "nt"
if _USE_POSIX:
    import _posixshmem
else:
    import _winapi


# open() flags meaning "create it, but fail if it already exists".
_O_CREX = os.O_CREAT | os.O_EXCL

# FreeBSD (and perhaps other BSDs) limit names to 14 characters, so
# generated names must not exceed this length.
_SHM_SAFE_NAME_LENGTH = 14

# Prefix for generated shared memory block names; the POSIX variant
# starts with a slash, the Windows one does not.
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
36
37
def _make_filename():
    "Create a random filename for the shared memory object."
    # token_hex() yields two hex digits per random byte, so spend half of
    # the room left over after the prefix on random bytes.
    room_after_prefix = _SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)
    nbytes = room_after_prefix // 2
    assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(nbytes)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
46
47
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name.  This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them.  When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup.

    Parameters:
        name: name of the block to create or attach to.  May be None only
            when create=True, in which case a random name is generated.
        create: create a new block (raising FileExistsError if the name
            is already taken) instead of attaching to an existing one.
        size: size in bytes of a block being created; must be > 0 when
            create=True.  When attaching, the size is instead read from
            the existing block.
        track: keyword-only; only consulted on POSIX.  When true (the
            default) the block's name is registered with the
            multiprocessing resource tracker after opening and
            unregistered by unlink().  Pass False when some other process
            is responsible for the block's lifetime, so that this process
            does not hand it to its tracker."""

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = True if _USE_POSIX else False
    _track = True

    def __init__(self, name=None, create=False, size=0, *, track=True):
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                # Mapping a zero-length object fails after the segment has
                # already been created; reject it before anything exists.
                raise ValueError(
                    "'size' must be a positive number different from zero"
                )
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        self._track = track

        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        # Generated name collided with an existing block.
                        continue
                    self._name = name
                    break
            else:
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create:
                    os.ftruncate(self._fd, size)
                # The mapped size always comes from the object itself, so
                # an attaching caller's *size* argument is ignored.
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except BaseException:
                # Undo the partial setup, then re-raise.  Only a segment
                # this instance created may be unlinked: when attaching,
                # it belongs to another process and must survive.  The
                # name was not registered with the resource tracker yet,
                # so it is unlinked directly rather than through unlink().
                try:
                    if create:
                        _posixshmem.shm_unlink(self._name)
                finally:
                    os.close(self._fd)
                    self._fd = -1
                raise

            if self._track:
                from .resource_tracker import register
                register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Generated name collided; try another.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                # NOTE(review): the view mapped above is only used to query
                # the size and is never unmapped here; presumably this leaks
                # one view per attach -- TODO confirm and release it.
                size = _winapi.VirtualQuerySize(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        # Unpickling attaches to the existing block by name.
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        if _USE_POSIX and self._prepend_leading_slash:
            if self._name.startswith("/"):
                reported_name = self._name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block.

        This only does something on POSIX.  The name is also unregistered
        from the resource tracker when the instance was opened with
        track=True."""
        if _USE_POSIX and self._name:
            from .resource_tracker import unregister
            _posixshmem.shm_unlink(self._name)
            if self._track:
                unregister(self._name, "shared_memory")
234        called once (and only once) across all processes which have access
235        to the shared memory block."""
236        if _USE_POSIX and self._name:
237            from .resource_tracker import unregister
238            _posixshmem.shm_unlink(self._name)
239            unregister(self._name, "shared_memory")
240
241
# Codec ShareableList uses to turn str values and packing-format strings
# into the bytes stored in shared memory (and back again).
_encoding = "utf8"
243
class ShareableList:
    """Pattern for a mutable list-like object shareable via a shared
    memory block.  It differs from the built-in list type in that these
    lists can not change their overall length (i.e. no append, insert,
    etc.)

    Because values are packed into a memoryview as bytes, the struct
    packing format for any storable value must require no more than 8
    characters to describe its format.

    Buffer layout, in order:

    * a "q" holding the list length;
    * one "q" per item giving the number of bytes allocated to it;
    * the packed item data, each item in its allocated bytes;
    * one "8s" per item holding that item's struct format string;
    * one "b" per item holding its back-transform code (a key of
      ``_back_transforms_mapping``)."""

    _types_mapping = {
        int: "q",
        float: "d",
        bool: "xxxxxxx?",
        str: "%ds",
        bytes: "%ds",
        None.__class__: "xxxxxx?x",
    }
    _alignment = 8
    # NOTE: str/bytes are padded with NUL bytes and stripped with
    # rstrip(b'\x00') on read, so values that themselves end in NUL bytes
    # lose them on the way back out.
    _back_transforms_mapping = {
        0: lambda value: value,                   # int, float, bool
        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
        2: lambda value: value.rstrip(b'\x00'),   # bytes
        3: lambda _value: None,                   # None
    }

    @staticmethod
    def _extract_recreation_code(value):
        """Used in concert with _back_transforms_mapping to convert values
        into the appropriate Python objects when retrieving them from
        the list as well as when storing them."""
        if not isinstance(value, (str, bytes, None.__class__)):
            return 0
        elif isinstance(value, str):
            return 1
        elif isinstance(value, bytes):
            return 2
        else:
            return 3  # NoneType

    @staticmethod
    def _encode_value(value):
        """Return *value* in the form that is packed into the buffer.

        str is encoded with the module encoding; every other value is
        returned unchanged."""
        return value.encode(_encoding) if isinstance(value, str) else value

    def _normalize_position(self, position):
        """Map a possibly negative index onto 0 .. len-1.

        Raises IndexError for any index outside the list, including
        negative indices below -len."""
        if position < 0:
            position += self._list_len
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")
        return position

    def __init__(self, sequence=None, *, name=None):
        if sequence is not None:
            # Materialize once: the items are walked several times below,
            # which would silently exhaust a generator or other iterator.
            sequence = tuple(sequence)
            # Values exactly as they will be packed.  str slots must be
            # sized by their encoded length, which exceeds len() for
            # non-ASCII text.
            encoded_items = [self._encode_value(item) for item in sequence]
            _formats = [
                self._types_mapping[type(item)]
                    if not isinstance(item, (str, bytes))
                    else self._types_mapping[type(item)] % (
                        # Round up to the alignment, always leaving at
                        # least one trailing NUL byte.
                        self._alignment * (len(raw) // self._alignment + 1),
                    )
                for item, raw in zip(sequence, encoded_items)
            ]
            self._list_len = len(_formats)
            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
            self._allocated_bytes = tuple(
                    self._alignment if fmt[-1] != "s" else int(fmt[:-1])
                    for fmt in _formats
            )
            _recreation_codes = [
                self._extract_recreation_code(item) for item in sequence
            ]
            requested_size = struct.calcsize(
                "q" + self._format_size_metainfo +
                "".join(_formats) +
                self._format_packing_metainfo +
                self._format_back_transform_codes
            )

        else:
            requested_size = 8  # Some platforms require > 0.

        if name is not None and sequence is None:
            self.shm = SharedMemory(name)
        else:
            self.shm = SharedMemory(name, create=True, size=requested_size)

        if sequence is not None:
            struct.pack_into(
                "q" + self._format_size_metainfo,
                self.shm.buf,
                0,
                self._list_len,
                *(self._allocated_bytes)
            )
            struct.pack_into(
                "".join(_formats),
                self.shm.buf,
                self._offset_data_start,
                *encoded_items
            )
            struct.pack_into(
                self._format_packing_metainfo,
                self.shm.buf,
                self._offset_packing_formats,
                *(fmt.encode(_encoding) for fmt in _formats)
            )
            struct.pack_into(
                self._format_back_transform_codes,
                self.shm.buf,
                self._offset_back_transform_codes,
                *(_recreation_codes)
            )

        else:
            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
            self._allocated_bytes = struct.unpack_from(
                self._format_size_metainfo,
                self.shm.buf,
                1 * 8
            )

    def _get_packing_format(self, position):
        "Gets the packing format for a single value stored in the list."
        position = self._normalize_position(position)

        v = struct.unpack_from(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8
        )[0]
        fmt = v.rstrip(b'\x00')
        fmt_as_str = fmt.decode(_encoding)

        return fmt_as_str

    def _get_back_transform(self, position):
        "Gets the back transformation function for a single value."
        position = self._normalize_position(position)

        transform_code = struct.unpack_from(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position
        )[0]
        transform_function = self._back_transforms_mapping[transform_code]

        return transform_function

    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
        """Sets the packing format and back transformation code for a
        single value in the list at the specified position."""
        position = self._normalize_position(position)

        struct.pack_into(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8,
            fmt_as_str.encode(_encoding)
        )

        transform_code = self._extract_recreation_code(value)
        struct.pack_into(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position,
            transform_code
        )

    def __getitem__(self, position):
        try:
            # Normalize first so that the offset sum below and the bounds
            # check agree, including for negative indices.
            position = self._normalize_position(position)
            offset = self._offset_data_start \
                     + sum(self._allocated_bytes[:position])
            (v,) = struct.unpack_from(
                self._get_packing_format(position),
                self.shm.buf,
                offset
            )
        except IndexError:
            raise IndexError("index out of range")

        back_transform = self._get_back_transform(position)
        v = back_transform(v)

        return v

    def __setitem__(self, position, value):
        try:
            position = self._normalize_position(position)
            offset = self._offset_data_start \
                     + sum(self._allocated_bytes[:position])
            current_format = self._get_packing_format(position)
        except IndexError:
            raise IndexError("assignment index out of range")

        encoded_value = self._encode_value(value)
        if not isinstance(value, (str, bytes)):
            new_format = self._types_mapping[type(value)]
        else:
            allocated_length = self._allocated_bytes[position]
            # Compare the stored (encoded) size, not the character count;
            # otherwise non-ASCII text would be silently truncated.
            if len(encoded_value) > allocated_length:
                raise ValueError("exceeds available storage for existing str")
            if current_format[-1] == "s":
                new_format = current_format
            else:
                new_format = self._types_mapping[str] % (
                    allocated_length,
                )

        self._set_packing_format_and_transform(
            position,
            new_format,
            value
        )
        struct.pack_into(new_format, self.shm.buf, offset, encoded_value)

    def __reduce__(self):
        # Unpickling attaches to the existing block by name.
        return partial(self.__class__, name=self.shm.name), ()

    def __len__(self):
        return struct.unpack_from("q", self.shm.buf, 0)[0]

    def __repr__(self):
        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'

    @property
    def format(self):
        "The struct packing format used by all currently stored values."
        return "".join(
            self._get_packing_format(i) for i in range(self._list_len)
        )

    @property
    def _format_size_metainfo(self):
        "The struct packing format used for metainfo on storage sizes."
        return f"{self._list_len}q"

    @property
    def _format_packing_metainfo(self):
        "The struct packing format used for the values' packing formats."
        return "8s" * self._list_len

    @property
    def _format_back_transform_codes(self):
        "The struct packing format used for the values' back transforms."
        return "b" * self._list_len

    @property
    def _offset_data_start(self):
        return (self._list_len + 1) * 8  # 8 bytes per "q"

    @property
    def _offset_packing_formats(self):
        return self._offset_data_start + sum(self._allocated_bytes)

    @property
    def _offset_back_transform_codes(self):
        return self._offset_packing_formats + self._list_len * 8

    def count(self, value):
        "L.count(value) -> integer -- return number of occurrences of value."

        return sum(value == entry for entry in self)

    def index(self, value):
        """L.index(value) -> integer -- return first index of value.
        Raises ValueError if the value is not present."""

        for position, entry in enumerate(self):
            if value == entry:
                return position
        else:
            raise ValueError(f"{value!r} not in this container")
505
506        for position, entry in enumerate(self):
507            if value == entry:
508                return position
509        else:
510            raise ValueError(f"{value!r} not in this container")
511