• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import collections
2import contextlib
3import itertools
4import pathlib
5import operator
6import re
7import warnings
8import zipfile
9
10from . import abc
11
12from ._itertools import only
13
14
def remove_duplicates(items):
    """
    Return an iterator over *items* with repeated values dropped.

    The first occurrence of each value is kept and the original order is
    preserved. The values must be hashable.
    """
    first_seen = dict.fromkeys(items)
    return iter(first_seen)
17
18
19class FileReader(abc.TraversableResources):
20    def __init__(self, loader):
21        self.path = pathlib.Path(loader.path).parent
22
23    def resource_path(self, resource):
24        """
25        Return the file system path to prevent
26        `resources.path()` from creating a temporary
27        copy.
28        """
29        return str(self.path.joinpath(resource))
30
31    def files(self):
32        return self.path
33
34
35class ZipReader(abc.TraversableResources):
36    def __init__(self, loader, module):
37        self.prefix = loader.prefix.replace('\\', '/')
38        if loader.is_package(module):
39            _, _, name = module.rpartition('.')
40            self.prefix += name + '/'
41        self.archive = loader.archive
42
43    def open_resource(self, resource):
44        try:
45            return super().open_resource(resource)
46        except KeyError as exc:
47            raise FileNotFoundError(exc.args[0])
48
49    def is_resource(self, path):
50        """
51        Workaround for `zipfile.Path.is_file` returning true
52        for non-existent paths.
53        """
54        target = self.files().joinpath(path)
55        return target.is_file() and target.exists()
56
57    def files(self):
58        return zipfile.Path(self.archive, self.prefix)
59
60
class MultiplexedPath(abc.Traversable):
    """
    Merge several Traversable directories into one read-only directory view.

    Children of all underlying directories are combined, and children that
    share a name are collapsed into a single entry. Useful for namespace
    packages, whose portions may live in several locations under one name.
    """

    def __init__(self, *paths):
        collected = []
        for entry in remove_duplicates(paths):
            # Call the converter directly from this frame (not from a nested
            # comprehension scope) so the call depth that its warning
            # stacklevel relies on stays the same.
            collected.append(_ensure_traversable(entry))
        self._paths = collected
        if not collected:
            message = 'MultiplexedPath must contain at least one path'
            raise FileNotFoundError(message)
        if any(not entry.is_dir() for entry in collected):
            raise NotADirectoryError('MultiplexedPath only supports directories')

    @property
    def name(self):
        # Report the name of the first underlying directory.
        return self._paths[0].name

    def is_dir(self):
        return True

    def is_file(self):
        return False

    def iterdir(self):
        """
        Return an iterator over the merged children of all directories.

        Children are sorted and grouped by name; each group is collapsed
        into a single entry by ``_follow``.
        """
        name_of = operator.attrgetter('name')
        everything = itertools.chain.from_iterable(
            root.iterdir() for root in self._paths
        )
        grouped = itertools.groupby(sorted(everything, key=name_of), key=name_of)
        return (self._follow(members) for _, members in grouped)

    def joinpath(self, *descendants):
        """
        Resolve *descendants* against the merged directories.

        If traversal fails (some directory along the way does not exist),
        join onto the first directory instead, yielding a path that will
        not exist rather than raising.
        """
        with contextlib.suppress(abc.TraversalError):
            return super().joinpath(*descendants)
        return self._paths[0].joinpath(*descendants)

    @classmethod
    def _follow(cls, children):
        """
        Collapse a group of same-named children into one entry.

        A sole child is returned as-is. Several children become a
        MultiplexedPath, unless one of them is not a directory, in which
        case the first child is returned.
        """
        for_multiplex, for_single, for_fallback = itertools.tee(children, 3)
        try:
            return only(for_single)
        except ValueError:
            try:
                return cls(*for_multiplex)
            except NotADirectoryError:
                return next(for_fallback)

    # A merged view is always a directory, so every file-access method fails.
    def read_bytes(self):
        raise FileNotFoundError(f'{self} is not a file')

    def read_text(self, *args, **kwargs):
        raise FileNotFoundError(f'{self} is not a file')

    def open(self, *args, **kwargs):
        raise FileNotFoundError(f'{self} is not a file')

    def __repr__(self):
        rendered = ', '.join(f"'{entry}'" for entry in self._paths)
        return f'MultiplexedPath({rendered})'
132
133
134class NamespaceReader(abc.TraversableResources):
135    def __init__(self, namespace_path):
136        if 'NamespacePath' not in str(namespace_path):
137            raise ValueError('Invalid path')
138        self.path = MultiplexedPath(*map(self._resolve, namespace_path))
139
140    @classmethod
141    def _resolve(cls, path_str) -> abc.Traversable:
142        r"""
143        Given an item from a namespace path, resolve it to a Traversable.
144
145        path_str might be a directory on the filesystem or a path to a
146        zipfile plus the path within the zipfile, e.g. ``/foo/bar`` or
147        ``/foo/baz.zip/inner_dir`` or ``foo\baz.zip\inner_dir\sub``.
148        """
149        (dir,) = (cand for cand in cls._candidate_paths(path_str) if cand.is_dir())
150        return dir
151
152    @classmethod
153    def _candidate_paths(cls, path_str):
154        yield pathlib.Path(path_str)
155        yield from cls._resolve_zip_path(path_str)
156
157    @staticmethod
158    def _resolve_zip_path(path_str):
159        for match in reversed(list(re.finditer(r'[\\/]', path_str))):
160            with contextlib.suppress(
161                FileNotFoundError,
162                IsADirectoryError,
163                NotADirectoryError,
164                PermissionError,
165            ):
166                inner = path_str[match.end() :].replace('\\', '/') + '/'
167                yield zipfile.Path(path_str[: match.start()], inner.lstrip('/'))
168
169    def resource_path(self, resource):
170        """
171        Return the file system path to prevent
172        `resources.path()` from creating a temporary
173        copy.
174        """
175        return str(self.path.joinpath(resource))
176
177    def files(self):
178        return self.path
179
180
def _ensure_traversable(path):
    """
    Return *path* as a Traversable, still accepting legacy strings.

    Non-string values are passed through untouched. A string triggers a
    DeprecationWarning and is wrapped in a pathlib.Path.

    Remove with Python 3.15.
    """
    if isinstance(path, str):
        # stacklevel=3 attributes the warning two frames above this helper.
        warnings.warn(
            "String arguments are deprecated. Pass a Traversable instead.",
            DeprecationWarning,
            stacklevel=3,
        )
        return pathlib.Path(path)
    return path
197