• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
# Public API of this module.
__all__ = [
    "SharedMemory",
    "ShareableList",
]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17import types
18
19if os.name == "nt":
20    import _winapi
21    _USE_POSIX = False
22else:
23    import _posixshmem
24    _USE_POSIX = True
25
26from . import resource_tracker
27
# open() flags meaning "create, and fail if the name already exists".
_O_CREX = os.O_EXCL | os.O_CREAT

# Longest name treated as portable: FreeBSD (and perhaps other BSDs)
# limit shared memory names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14

# Prefix for generated block names; the POSIX variant starts with "/".
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
38
39
def _make_filename():
    """Return a random name for a new shared memory object.

    The result is _SHM_NAME_PREFIX followed by random hex digits, sized
    so that the whole name fits within _SHM_SAFE_NAME_LENGTH characters.
    """
    # token_hex emits two characters per random byte, so halve the room
    # left after the prefix.
    room_for_token = _SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)
    nbytes = room_for_token // 2
    assert nbytes >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(nbytes)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
48
49
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name.  This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them.  When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup."""

    # Defaults; enables close() and unlink() to run without errors, even on
    # an instance whose __init__ did not complete (e.g. when __del__ runs).
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    # On POSIX a "/" is prepended to the user-facing name; the `name`
    # property strips it again.
    _prepend_leading_slash = True if _USE_POSIX else False
    _track = True

    def __init__(self, name=None, create=False, size=0, *, track=True):
        """Create a new shared memory block or attach to an existing one.

        name: block name.  May be None only when create=True, in which
            case a random, currently unused name is generated.
        create: if true, create a new block (the name must not already
            exist); otherwise attach to an existing block.
        size: number of bytes to allocate when creating; must be > 0 in
            that case.  The mapped size is always taken from the block
            itself (fstat on POSIX, VirtualQuerySize when attaching on
            Windows).
        track: on POSIX, register the block with resource_tracker here
            and unregister it in unlink().

        Raises ValueError for invalid arguments; OS errors raised while
        opening or mapping the block propagate to the caller.
        """
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                raise ValueError("'size' must be a positive number different from zero")
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        self._track = track
        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                # Retry with fresh random names until an unused one is found.
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        continue
                    self._name = name
                    break
            else:
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create and size:
                    os.ftruncate(self._fd, size)
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except OSError:
                # Undo only what this call set up.  Release the descriptor,
                # and remove the segment only if it was created here:
                # unlinking a segment we merely attached to would destroy
                # memory other processes may still be using.  The segment
                # is not registered with resource_tracker yet, so there is
                # nothing to unregister.
                os.close(self._fd)
                self._fd = -1
                if create:
                    _posixshmem.shm_unlink(self._name)
                raise
            if self._track:
                resource_tracker.register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Random name collided; try another one.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                try:
                    size = _winapi.VirtualQuerySize(p_buf)
                finally:
                    _winapi.UnmapViewOfFile(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        # Best-effort release of this instance's handles; never unlinks.
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        # Pickles as "attach to the block by name" (create=False).
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        # Hide the "/" that __init__ prepends on POSIX.  Guard against a
        # missing name on an instance whose __init__ did not finish.
        if (_USE_POSIX and self._prepend_leading_slash
                and reported_name is not None
                and reported_name.startswith("/")):
            reported_name = reported_name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        Unlink should be called once (and only once) across all handles
        which have access to the shared memory block, even if these
        handles belong to different processes. Closing and unlinking may
        happen in any order, but trying to access data inside a shared
        memory block after unlinking may result in memory errors,
        depending on platform.

        This method has no effect on Windows, where the only way to
        delete a shared memory block is to close all handles."""

        if _USE_POSIX and self._name:
            _posixshmem.shm_unlink(self._name)
            if self._track:
                resource_tracker.unregister(self._name, "shared_memory")
255
256
# Text codec ShareableList uses for str items and for the struct format
# strings stored in the shared block.
_encoding = 'utf8'
258
class ShareableList:
    """Pattern for a mutable list-like object shareable via a shared
    memory block.  It differs from the built-in list type in that these
    lists can not change their overall length (i.e. no append, insert,
    etc.)

    Because values are packed into a memoryview as bytes, the struct
    packing format for any storable value must require no more than 8
    characters to describe its format.

    str and bytes items are stored NUL-padded and read back with trailing
    NUL bytes stripped, so trailing b'\\x00' / '\\0' is not preserved."""

    # The shared memory area is organized as follows:
    # - 8 bytes: number of items (N) as a 64-bit integer
    # - (N + 1) * 8 bytes: offsets of each element from the start of the
    #                      data area
    # - K bytes: the data area storing item values (with encoding and size
    #            depending on their respective types)
    # - N * 8 bytes: `struct` format string for each element
    # - N bytes: index into _back_transforms_mapping for each element
    #            (for reconstructing the corresponding Python value)
    _types_mapping = {
        int: "q",
        float: "d",
        bool: "xxxxxxx?",
        str: "%ds",
        bytes: "%ds",
        None.__class__: "xxxxxx?x",
    }
    _alignment = 8
    _back_transforms_mapping = {
        0: lambda value: value,                   # int, float, bool
        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
        2: lambda value: value.rstrip(b'\x00'),   # bytes
        3: lambda _value: None,                   # None
    }

    @staticmethod
    def _extract_recreation_code(value):
        """Used in concert with _back_transforms_mapping to convert values
        into the appropriate Python objects when retrieving them from
        the list as well as when storing them."""
        if not isinstance(value, (str, bytes, None.__class__)):
            return 0
        elif isinstance(value, str):
            return 1
        elif isinstance(value, bytes):
            return 2
        else:
            return 3  # NoneType

    def __init__(self, sequence=None, *, name=None):
        """Create a new shared list from `sequence`, or attach to the
        existing shared list called `name` when only `name` is given.

        Any iterable is accepted as `sequence`; it is consumed once.
        """
        if name is None or sequence is not None:
            # Materialize once: the items are walked several times below
            # (formats, recreation codes, packing), which would silently
            # exhaust a one-shot iterator such as a generator.
            sequence = list(sequence or ())
            # Encode str items up front so their storage is sized in
            # bytes, not characters.  Non-ASCII text needs more than one
            # byte per character and would otherwise be truncated by the
            # fixed-width "%ds" struct format.
            _encoded = [
                item.encode(_encoding) if isinstance(item, str) else item
                for item in sequence
            ]
            _formats = [
                self._types_mapping[type(item)]
                    if not isinstance(item, (str, bytes))
                    else self._types_mapping[type(item)] % (
                        self._alignment * (len(raw) // self._alignment + 1),
                    )
                for item, raw in zip(sequence, _encoded)
            ]
            self._list_len = len(_formats)
            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
            offset = 0
            # The offsets of each list element into the shared memory's
            # data area (0 meaning the start of the data area, not the start
            # of the shared memory area).
            self._allocated_offsets = [0]
            for fmt in _formats:
                offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
                self._allocated_offsets.append(offset)
            _recreation_codes = [
                self._extract_recreation_code(item) for item in sequence
            ]
            requested_size = struct.calcsize(
                "q" + self._format_size_metainfo +
                "".join(_formats) +
                self._format_packing_metainfo +
                self._format_back_transform_codes
            )

            self.shm = SharedMemory(name, create=True, size=requested_size)
        else:
            self.shm = SharedMemory(name)

        if sequence is not None:
            # Newly created block: write header, data, formats and codes.
            struct.pack_into(
                "q" + self._format_size_metainfo,
                self.shm.buf,
                0,
                self._list_len,
                *(self._allocated_offsets)
            )
            struct.pack_into(
                "".join(_formats),
                self.shm.buf,
                self._offset_data_start,
                *_encoded
            )
            struct.pack_into(
                self._format_packing_metainfo,
                self.shm.buf,
                self._offset_packing_formats,
                *(v.encode(_encoding) for v in _formats)
            )
            struct.pack_into(
                self._format_back_transform_codes,
                self.shm.buf,
                self._offset_back_transform_codes,
                *(_recreation_codes)
            )

        else:
            # Attached to an existing block: read the layout back.
            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
            self._allocated_offsets = list(
                struct.unpack_from(
                    self._format_size_metainfo,
                    self.shm.buf,
                    1 * 8
                )
            )

    def _get_packing_format(self, position):
        "Gets the packing format for a single value stored in the list."
        position = position if position >= 0 else position + self._list_len
        # Reject positions still negative after normalization as well as
        # those past the end.
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")

        v = struct.unpack_from(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8
        )[0]
        fmt = v.rstrip(b'\x00')
        fmt_as_str = fmt.decode(_encoding)

        return fmt_as_str

    def _get_back_transform(self, position):
        "Gets the back transformation function for a single value."

        # Expects an already-normalized, non-negative position.
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")

        transform_code = struct.unpack_from(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position
        )[0]
        transform_function = self._back_transforms_mapping[transform_code]

        return transform_function

    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
        """Sets the packing format and back transformation code for a
        single value in the list at the specified position."""

        # Expects an already-normalized, non-negative position.
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")

        struct.pack_into(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8,
            fmt_as_str.encode(_encoding)
        )

        transform_code = self._extract_recreation_code(value)
        struct.pack_into(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position,
            transform_code
        )

    def __getitem__(self, position):
        position = position if position >= 0 else position + self._list_len
        # Bounds-check explicitly: _allocated_offsets has N + 1 entries, so
        # indexing it alone would accept some out-of-range positions and
        # read from the metadata area instead of raising.
        if not 0 <= position < self._list_len:
            raise IndexError("index out of range")
        try:
            offset = self._offset_data_start + self._allocated_offsets[position]
            (v,) = struct.unpack_from(
                self._get_packing_format(position),
                self.shm.buf,
                offset
            )
        except IndexError:
            raise IndexError("index out of range")

        back_transform = self._get_back_transform(position)
        return back_transform(v)

    def __setitem__(self, position, value):
        position = position if position >= 0 else position + self._list_len
        # Same explicit bounds check as __getitem__; without it an
        # out-of-range negative index would overwrite list metadata.
        if not 0 <= position < self._list_len:
            raise IndexError("assignment index out of range")
        try:
            item_offset = self._allocated_offsets[position]
            offset = self._offset_data_start + item_offset
            current_format = self._get_packing_format(position)
        except IndexError:
            raise IndexError("assignment index out of range")

        if not isinstance(value, (str, bytes)):
            new_format = self._types_mapping[type(value)]
            encoded_value = value
        else:
            # A slot's width is fixed at creation; str/bytes must fit in it.
            allocated_length = self._allocated_offsets[position + 1] - item_offset

            encoded_value = (value.encode(_encoding)
                             if isinstance(value, str) else value)
            if len(encoded_value) > allocated_length:
                raise ValueError("bytes/str item exceeds available storage")
            if current_format[-1] == "s":
                new_format = current_format
            else:
                new_format = self._types_mapping[str] % (
                    allocated_length,
                )

        # Pack into a scratch bytes object first, so a value that cannot be
        # packed (e.g. an int outside the signed 64-bit range) raises before
        # anything in shared memory, data or metadata, is modified.
        packed = struct.pack(new_format, encoded_value)
        self._set_packing_format_and_transform(
            position,
            new_format,
            value
        )
        self.shm.buf[offset:offset + len(packed)] = packed

    def __reduce__(self):
        # Pickles as "attach to the existing block by name".
        return partial(self.__class__, name=self.shm.name), ()

    def __len__(self):
        return struct.unpack_from("q", self.shm.buf, 0)[0]

    def __repr__(self):
        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'

    @property
    def format(self):
        "The struct packing format used by all currently stored items."
        return "".join(
            self._get_packing_format(i) for i in range(self._list_len)
        )

    @property
    def _format_size_metainfo(self):
        "The struct packing format used for the items' storage offsets."
        return "q" * (self._list_len + 1)

    @property
    def _format_packing_metainfo(self):
        "The struct packing format used for the items' packing formats."
        return "8s" * self._list_len

    @property
    def _format_back_transform_codes(self):
        "The struct packing format used for the items' back transforms."
        return "b" * self._list_len

    @property
    def _offset_data_start(self):
        # - 8 bytes for the list length
        # - (N + 1) * 8 bytes for the element offsets
        return (self._list_len + 2) * 8

    @property
    def _offset_packing_formats(self):
        return self._offset_data_start + self._allocated_offsets[-1]

    @property
    def _offset_back_transform_codes(self):
        return self._offset_packing_formats + self._list_len * 8

    def count(self, value):
        "L.count(value) -> integer -- return number of occurrences of value."

        return sum(value == entry for entry in self)

    def index(self, value):
        """L.index(value) -> integer -- return first index of value.
        Raises ValueError if the value is not present."""

        for position, entry in enumerate(self):
            if value == entry:
                return position
        else:
            raise ValueError(f"{value!r} not in this container")

    __class_getitem__ = classmethod(types.GenericAlias)
545