• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Licensed under the Apache License, Version 2.0 (the "License");
2# you may not use this file except in compliance with the License.
3# You may obtain a copy of the License at
4#
5#      http://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS,
9# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10# See the License for the specific language governing permissions and
11# limitations under the License.
12
13"""
14Provides patches for some commonly used modules that enable them to work
15with pyfakefs.
16"""
17
18import sys
19from importlib import reload
20
21try:
22    import pandas as pd
23
24    try:
25        import pandas.io.parsers as parsers
26    except ImportError:
27        parsers = None
28except ImportError:
29    pd = None
30    parsers = None
31
32
33try:
34    import xlrd
35except ImportError:
36    xlrd = None
37
38
39try:
40    import django
41
42    try:
43        from django.core.files import locks
44    except ImportError:
45        locks = None
46except ImportError:
47    django = None
48    locks = None
49
def pandas_release_numbers(version, length=3):
    """Return the leading numeric release components of a version string.

    Only the first `length` dot-separated components are considered, and
    parsing stops at the first component that is not purely numeric, so
    pre-release or development suffixes ("1.2.0rc1", "2.0.0.dev0+12.gabc")
    do not raise. Missing components are filled with zeros, which makes
    "1.2" compare equal to "1.2.0".

    :param version: the version string, e.g. ``pd.__version__``
    :param length: the number of components in the returned list
    :return: a list of `length` integers
    """
    release = []
    for component in version.split(".")[:length]:
        digits = ""
        for char in component:
            if not char.isdecimal():
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if digits != component:
            # a suffix like "rc1" terminates the numeric release part
            break
    release.extend([0] * (length - len(release)))
    return release


# From pandas v 1.2 onwards the python fs functions are used even when the engine
# selected is "c". This means that we don't explicitly have to change the engine.
patch_pandas = parsers is not None and pandas_release_numbers(pd.__version__) < [
    1,
    2,
    0,
]
57
58
def get_modules_to_patch():
    """Return the fake module classes, keyed by the name of the replaced module.

    An entry is only present if the respective package could be imported:
    ``XLRDModule`` for "xlrd" and ``FakeLocks`` for "django.core.files.locks".
    """
    patched = {}
    if xlrd is not None:
        patched.update({"xlrd": XLRDModule})
    if locks is not None:
        patched.update({"django.core.files.locks": FakeLocks})
    return patched
66
67
def get_classes_to_patch():
    """Return a mapping of class names to the list of modules containing them.

    The mapping is empty unless `patch_pandas` is set, in which case it
    contains "TextFileReader" from "pandas.io.parsers".
    """
    if not patch_pandas:
        return {}
    return {"TextFileReader": ["pandas.io.parsers"]}
73
74
def reload_handler(name):
    """Reload the module `name` if it is currently loaded.

    :param name: the key of the module in ``sys.modules``
    :return: always True, regardless of whether a reload happened
    """
    try:
        loaded_module = sys.modules[name]
    except KeyError:
        # module is not loaded - nothing to reload
        return True
    reload(loaded_module)
    return True
79
80
def get_cleanup_handlers():
    """Return the cleanup handlers, keyed by the name of the handled module.

    If pandas is available, the arrow extension types module gets a handler
    that unregisters its extension types. If django is available, each
    module found by ``django_view_modules()`` gets a handler that reloads it.
    """
    cleanup = {}
    if pd is not None:
        cleanup["pandas.core.arrays.arrow.extension_types"] = (
            handle_extension_type_cleanup
        )
    if django is not None:
        # bind the module name as default argument, so that each handler
        # refers to its own module even if called without an argument
        cleanup.update(
            {
                view_module: (lambda name=view_module: reload_handler(name))
                for view_module in django_view_modules()
            }
        )
    return cleanup
91
92
def get_fake_module_classes():
    """Return the fake classes, keyed by the name of the class they replace.

    Only contains ``FakeTextFileReader`` (for "TextFileReader"), and only
    if `patch_pandas` is set.
    """
    return {"TextFileReader": FakeTextFileReader} if patch_pandas else {}
98
99
if xlrd is not None:

    class XLRDModule:
        """Patches the xlrd module, which is used as the default Excel file
        reader by pandas. Disables using memory mapped files, which are
        implemented platform-specific on OS level."""

        def __init__(self, _):
            """Store the real xlrd module; the passed argument is ignored."""
            self._xlrd_module = xlrd

        def open_workbook(
            self,
            filename=None,
            logfile=sys.stdout,
            verbosity=0,
            use_mmap=False,
            file_contents=None,
            encoding_override=None,
            formatting_info=False,
            on_demand=False,
            ragged_rows=False,
            **kwargs,
        ):
            """Call the real ``xlrd.open_workbook`` with memory mapping disabled.

            All arguments are passed on unchanged, except for `use_mmap`,
            which is accepted for compatibility but always replaced by False.
            Additional keyword arguments are forwarded as is, so that options
            not listed here (e.g. ``ignore_workbook_corruption`` in newer
            xlrd versions) are handled by the installed xlrd version instead
            of being rejected by the fake.

            :return: the result of the real ``open_workbook`` call
            """
            return self._xlrd_module.open_workbook(
                filename,
                logfile,
                verbosity,
                False,  # never use memory mapped files
                file_contents,
                encoding_override,
                formatting_info,
                on_demand,
                ragged_rows,
                **kwargs,
            )

        def __getattr__(self, name):
            """Forwards any unfaked calls to the standard xlrd module."""
            return getattr(self._xlrd_module, name)
137
138
if patch_pandas:
    # Two fakes are currently needed: one for the parsers module and one for
    # the text reader class contained in it - this may be simplified later.

    class ParsersModule:
        """Stand-in for ``pandas.io.parsers`` with a patched text reader."""

        def __init__(self, _):
            """Store the real parsers module; the passed argument is ignored."""
            self._parsers_module = parsers

        def __getattr__(self, name):
            """Forwards any unfaked calls to the real pandas parsers module."""
            real_parsers = self._parsers_module
            return getattr(real_parsers, name)

        class TextFileReader(parsers.TextFileReader):
            """Text reader that always uses the "python" engine."""

            def __init__(self, *args, **kwargs):
                # overwrite any engine selected by the caller
                kwargs.update(engine="python")
                super().__init__(*args, **kwargs)

    class FakeTextFileReader:
        """Callable replacement for the ``TextFileReader`` class.

        Calls and attribute lookups are delegated to the patched reader of a
        ``ParsersModule`` instance that is shared by all instances.
        """

        # created by the first instance and stored on the class
        fake_parsers = None

        def __init__(self, filesystem):
            if self.fake_parsers is not None:
                return
            self.__class__.fake_parsers = ParsersModule(filesystem)

        def __call__(self, *args, **kwargs):
            reader_class = self.fake_parsers.TextFileReader
            return reader_class(*args, **kwargs)

        def __getattr__(self, name):
            reader_class = self.fake_parsers.TextFileReader
            return getattr(reader_class, name)
168
169
if pd is not None:

    def handle_extension_type_cleanup(_name):
        """Unregister the pyarrow extension types registered by pandas.

        The handled module registers two extension types on load and raises
        on reload if they have not been unregistered before. Does nothing if
        pyarrow cannot be imported.

        :param _name: unused
        :return: always False
        """
        try:
            import pyarrow

            # the code to register these types has been in the module
            # since it was created (in pandas 1.5)
            for type_name in ("pandas.interval", "pandas.period"):
                pyarrow.unregister_extension_type(type_name)
        except ImportError:
            pass
        return False
185
186
if locks is not None:

    class FakeLocks:
        """Replacement for ``django.core.files.locks``.

        The real module uses low level OS functions, so locking and
        unlocking are faked to always succeed. Any other attribute is taken
        from the real module.
        """

        # the real module, used as fallback in __getattr__
        _locks_module = locks

        def __init__(self, _):
            """Accept and ignore the single constructor argument."""

        @staticmethod
        def lock(f, flags):
            """Pretend to lock `f`; the arguments are ignored."""
            return True

        @staticmethod
        def unlock(f):
            """Pretend to unlock `f`; the argument is ignored."""
            return True

        def __getattr__(self, name):
            """Forward any unfaked attribute to the real locks module."""
            real_locks = self._locks_module
            return getattr(real_locks, name)
207
208
if django is not None:

    def get_all_view_modules(urlpatterns, modules=None):
        """Collect the names of the modules defining the views in `urlpatterns`.

        :param urlpatterns: an iterable of URL pattern objects; entries
            with a ``url_patterns`` attribute are traversed recursively
        :param modules: the set the module names are added to; a new set
            is created if it is not given
        :return: the set of collected module names
        """
        if modules is None:
            modules = set()
        for pattern in urlpatterns:
            if hasattr(pattern, "url_patterns"):
                # nested patterns (presumably created via include())
                get_all_view_modules(pattern.url_patterns, modules=modules)
            else:
                # the callback may be a wrapper around a view class that is
                # exposed as "cls" or "view_class" - in this case the module
                # of the class is wanted, otherwise that of the callback
                if hasattr(pattern.callback, "cls"):
                    view = pattern.callback.cls
                elif hasattr(pattern.callback, "view_class"):
                    view = pattern.callback.view_class
                else:
                    view = pattern.callback
                modules.add(view.__module__)
        return modules

    def django_view_modules():
        """Return the names of all view modules reachable from ROOT_URLCONF.

        This is a best effort: if the URL configuration cannot be resolved
        for any reason (e.g. the settings are not configured), an empty set
        is returned.
        """
        try:
            from importlib import import_module

            # import_module returns the URL configuration module itself,
            # whereas __import__ returns the top-level package, which would
            # only work for a configuration named "<package>.urls"
            urlconf = import_module(django.conf.settings.ROOT_URLCONF)
            return get_all_view_modules(urlconf.urlpatterns)
        except Exception:
            return set()
235