• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Licensed under the Apache License, Version 2.0 (the "License");
2# you may not use this file except in compliance with the License.
3# You may obtain a copy of the License at
4#
5#      http://www.apache.org/licenses/LICENSE-2.0
6#
7# Unless required by applicable law or agreed to in writing, software
8# distributed under the License is distributed on an "AS IS" BASIS,
9# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10# See the License for the specific language governing permissions and
11# limitations under the License.
12
"""Helper classes used for the fake file system implementation."""
14
15import ctypes
16import importlib
17import io
18import locale
19import os
20import platform
21import stat
22import sys
23import sysconfig
24import time
25import traceback
26from collections import namedtuple
27from copy import copy
28from dataclasses import dataclass
29from enum import Enum
30from stat import S_IFLNK
31from typing import Union, Optional, Any, AnyStr, overload, cast
32
# Either a text string or a byte string.
AnyString = Union[str, bytes]
# A path given as a string (text or bytes) or as a path-like object.
AnyPath = Union[AnyStr, os.PathLike]

# True if the running interpreter reports itself as PyPy.
IS_PYPY = platform.python_implementation() == "PyPy"
# True if the code runs under Windows.
IS_WIN = sys.platform == "win32"
# True if the marker file "/.dockerenv" exists in the real file system.
IN_DOCKER = os.path.exists("/.dockerenv")

PERM_READ = 0o400  # Read permission bit.
PERM_WRITE = 0o200  # Write permission bit.
PERM_EXE = 0o100  # Execute permission bit.
PERM_DEF = 0o777  # Default permission bits.
PERM_DEF_FILE = 0o666  # Default permission bits (regular file)
PERM_ALL = 0o7777  # All permission bits.

# Resolved location of the standard library; used below to recognize stack
# frames that belong to the standard library.
STDLIB_PATH = os.path.realpath(sysconfig.get_path("stdlib"))
# Directory that contains this module.
PYFAKEFS_PATH = os.path.dirname(__file__)
# Test directories located below PYFAKEFS_PATH; frames from these paths are
# treated like external callers in `is_called_from_skipped_module`.
PYFAKEFS_TEST_PATHS = [
    os.path.join(PYFAKEFS_PATH, "tests"),
    os.path.join(PYFAKEFS_PATH, "pytest_tests"),
]

# Record of boolean-like flags describing a file open mode.
# NOTE(review): the consumers of this tuple are outside this module chunk -
# presumably each flag is derived from the mode string passed to open().
_OpenModes = namedtuple(
    "_OpenModes",
    "must_exist can_read can_write truncate append must_not_exist",
)
58
# Initialize the global user and group IDs used by the fake file system.
# NOTE(review): the platform test is spelled out instead of using IS_WIN,
# presumably so that static type checkers can narrow on it - confirm.
if sys.platform == "win32":
    # There is no getuid()/getgid() here: use 0 (treated as root by
    # `is_root()`) for an administrator and 1 for any other user.
    fake_id = 0 if ctypes.windll.shell32.IsUserAnAdmin() else 1
    USER_ID = fake_id
    GROUP_ID = fake_id
else:
    # Take over the IDs of the real process.
    USER_ID = os.getuid()
    GROUP_ID = os.getgid()
66
67
def get_uid() -> int:
    """Return the global user id used by the fake file system.

    This is the counterpart of ``os.getuid()``: the value is initialized
    from the real system and may be overridden using ``set_uid()``.
    """
    return USER_ID
71
72
def set_uid(uid: int) -> None:
    """Override the global user id.

    The user id ends up as ``st_uid`` of newly created files, and it decides
    whether the caller is regarded as a normal user or as the root user
    (uid 0), for whom some permission restrictions are ignored.

    Args:
        uid: (int) the user ID of the user calling the file system functions.
    """
    global USER_ID
    USER_ID = uid
83
84
def get_gid() -> int:
    """Return the global group id used by the fake file system.

    This is the counterpart of ``os.getgid()``: the value is initialized
    from the real system and may be overridden using ``set_gid()``.
    """
    return GROUP_ID
88
89
def set_gid(gid: int) -> None:
    """Override the global group id.

    The group id only ends up as ``st_gid`` of newly created files;
    no permission checks are performed based on it.

    Args:
        gid: (int) the group ID of the user calling the file system functions.
    """
    global GROUP_ID
    GROUP_ID = gid
99
100
def reset_ids() -> None:
    """Restore the global user ID and group ID to their default values.

    Outside of Windows the IDs of the real process are used. Under Windows
    both IDs become 0 for an administrator and 1 for any other user.
    """
    if sys.platform != "win32":
        set_uid(os.getuid())
        set_gid(os.getgid())
    else:
        reset_id = 0 if ctypes.windll.shell32.IsUserAnAdmin() else 1
        set_uid(reset_id)
        set_gid(reset_id)
110
111
def is_root() -> bool:
    """Tell whether the global user id denotes the root user (uid 0)."""
    root_uid = 0
    return USER_ID == root_uid
115
116
def is_int_type(val: Any) -> bool:
    """Tell whether `val` is an instance of `int` (or of a subclass of it)."""
    is_integer = isinstance(val, int)
    return is_integer
120
121
def is_byte_string(val: Any) -> bool:
    """Tell whether `val` is a bytes-like object rather than a unicode string.

    The decision is made by duck typing: everything without an ``encode``
    attribute is regarded as a byte string.
    """
    can_encode = hasattr(val, "encode")
    return not can_encode
126
127
def is_unicode_string(val: Any) -> bool:
    """Tell whether `val` is a unicode string rather than a bytes-like object.

    The decision is made by duck typing: everything providing an ``encode``
    attribute is regarded as a unicode string.
    """
    can_encode = hasattr(val, "encode")
    return can_encode
132
133
def get_locale_encoding():
    """Return the name of the encoding defined by the current locale.

    ``locale.getencoding()`` is used from Python 3.11 on; older versions
    fall back to ``locale.getpreferredencoding(False)``.
    """
    if sys.version_info < (3, 11):
        return locale.getpreferredencoding(False)
    return locale.getencoding()
138
139
@overload
def make_string_path(dir_name: AnyStr) -> AnyStr: ...


@overload
def make_string_path(dir_name: os.PathLike) -> str: ...


def make_string_path(dir_name: AnyPath) -> AnyStr:  # type: ignore[type-var]
    """Return the file system representation of `dir_name`.

    Args:
        dir_name: a path given as `str`, `bytes` or path-like object.

    Returns:
        The result of ``os.fspath(dir_name)``: string arguments are returned
        as they are, path-like objects are converted via their
        ``__fspath__()`` method.
    """
    # the cast only exists for the type checkers and does nothing at runtime
    return cast(AnyStr, os.fspath(dir_name))  # pytype: disable=invalid-annotation
150
151
def to_string(path: Union[AnyStr, Union[str, bytes]]) -> str:
    """Convert `path` to a `str`.

    A byte string is decoded using the preferred (locale) encoding,
    anything else is handed back unchanged.
    """
    if not isinstance(path, bytes):
        return path
    encoding = get_locale_encoding()
    return path.decode(encoding)
158
159
def to_bytes(path: Union[AnyStr, Union[str, bytes]]) -> bytes:
    """Convert `path` to `bytes`.

    A `str` is encoded using the preferred (locale) encoding,
    anything else is handed back unchanged.
    """
    if not isinstance(path, str):
        return path
    encoding = get_locale_encoding()
    return bytes(path, encoding)
166
167
def join_strings(s1: AnyStr, s2: AnyStr) -> AnyStr:
    """Return the concatenation of `s1` and `s2`.

    This is a bit of a hack to satisfy mypy - may be refactored.
    """
    joined = s1 + s2
    return joined
171
172
def real_encoding(encoding: Optional[str]) -> Optional[str]:
    """Map the pseudo encoding "locale" to None.

    Since Python 3.10, the new function ``io.text_encoding`` returns
    "locale" as the encoding if None is defined. This is handled as
    no encoding in pyfakefs. Any other value is returned unchanged.
    """
    if sys.version_info >= (3, 10) and encoding == "locale":
        return None
    return encoding
180
181
def now():
    """Return the current time as given by ``time.time()``."""
    current_time = time.time()
    return current_time
184
185
@overload
def matching_string(matched: bytes, string: AnyStr) -> bytes: ...


@overload
def matching_string(matched: str, string: AnyStr) -> str: ...


@overload
def matching_string(matched: AnyStr, string: None) -> None: ...


def matching_string(  # type: ignore[misc]
    matched: AnyStr, string: Optional[AnyStr]
) -> Optional[AnyString]:
    """Return `string` in the string flavor (bytes or unicode) of `matched`.

    A `str` is encoded using the locale encoding if `matched` is a byte
    string (the string is assumed to be an ASCII string); in every other
    case - including `string` being None - `string` is returned unchanged.
    """
    # None is never an instance of str, so it falls through to the last line
    if isinstance(string, str) and isinstance(matched, bytes):
        return string.encode(get_locale_encoding())
    return string  # pytype: disable=bad-return-type
209
210
@dataclass
class FSProperties:
    """Bundles the separator and naming properties of a file system flavor.

    The field names follow the attributes of the same name in the `os`
    module.
    """

    # path component separator
    sep: str
    # alternative path component separator, or None if there is none
    altsep: Optional[str]
    # separator between the entries of a search path
    pathsep: str
    # line terminator
    linesep: str
    # path of the null device
    devnull: str
218
219
# Pure POSIX file system properties, for use with PosixPath:
# "/" is the only path separator, and search path entries are split at ":".
POSIX_PROPERTIES = FSProperties(
    sep="/",
    altsep=None,
    pathsep=":",
    linesep="\n",
    devnull="/dev/null",
)

# Pure Windows file system properties, for use with WindowsPath:
# the backslash is the separator with "/" as alternative separator,
# and search path entries are split at ";".
WINDOWS_PROPERTIES = FSProperties(
    sep="\\",
    altsep="/",
    pathsep=";",
    linesep="\r\n",
    devnull="NUL",
)
237
238
class FSType(Enum):
    """Selects the set of file system properties that shall be used."""

    # use current OS file system + modifications in fake file system
    DEFAULT = 0
    # pure POSIX properties, for use in PosixPath
    POSIX = 1
    # pure Windows properties, for use in WindowsPath
    WINDOWS = 2
245
246
class FakeStatResult:
    """Mimics os.stat_result for use as return type of `stat()` and similar.
    This is needed as `os.stat_result` has no possibility to set
    nanosecond times directly.

    All time stamps are stored as integer nanoseconds; the second-based
    attributes (`st_atime`, `st_mtime`, `st_ctime`) are derived from them.
    """

    def __init__(
        self,
        is_windows: bool,
        user_id: int,
        group_id: int,
        initial_time: Optional[float] = None,
    ):
        """Create a stat result with empty mode and size.

        Args:
            is_windows: True if Windows behavior shall be mimicked.
            user_id: the initial value of `st_uid`.
            group_id: the initial value of `st_gid`.
            initial_time: the initial access, modification and creation
                time in seconds; None (or 0) results in a time of 0.
        """
        self.st_mode: int = 0
        self.st_ino: Optional[int] = None
        self.st_dev: int = 0
        self.st_nlink: int = 0
        self.st_uid: int = user_id
        self.st_gid: int = group_id
        self._st_size: int = 0
        self.is_windows: bool = is_windows
        self._st_atime_ns: int = int((initial_time or 0) * 1e9)
        self._st_mtime_ns: int = self._st_atime_ns
        self._st_ctime_ns: int = self._st_atime_ns

    def __eq__(self, other: Any) -> bool:
        """Return True if `other` is a `FakeStatResult` with the same times,
        size, ownership, link count, device, inode and mode."""
        return (
            isinstance(other, FakeStatResult)
            and self._st_atime_ns == other._st_atime_ns
            and self._st_ctime_ns == other._st_ctime_ns
            and self._st_mtime_ns == other._st_mtime_ns
            and self.st_size == other.st_size
            and self.st_gid == other.st_gid
            and self.st_uid == other.st_uid
            and self.st_nlink == other.st_nlink
            and self.st_dev == other.st_dev
            and self.st_ino == other.st_ino
            and self.st_mode == other.st_mode
        )

    def __ne__(self, other: Any) -> bool:
        """Return the negation of `__eq__`."""
        return not self == other

    def copy(self) -> "FakeStatResult":
        """Return a shallow copy of this stat result."""
        return copy(self)

    def set_from_stat_result(self, stat_result: os.stat_result) -> None:
        """Set values from a real os.stat_result.
        Note: values that are controlled by the fake filesystem are not set.
        This includes st_ino, st_dev and st_nlink.
        """
        self.st_mode = stat_result.st_mode
        self.st_uid = stat_result.st_uid
        self.st_gid = stat_result.st_gid
        self._st_size = stat_result.st_size
        self._st_atime_ns = stat_result.st_atime_ns
        self._st_mtime_ns = stat_result.st_mtime_ns
        self._st_ctime_ns = stat_result.st_ctime_ns

    @property
    def st_ctime(self) -> Union[int, float]:
        """Return the creation time in seconds."""
        return self._st_ctime_ns / 1e9

    @st_ctime.setter
    def st_ctime(self, val: Union[int, float]) -> None:
        """Set the creation time in seconds."""
        self._st_ctime_ns = int(val * 1e9)

    @property
    def st_atime(self) -> Union[int, float]:
        """Return the access time in seconds."""
        return self._st_atime_ns / 1e9

    @st_atime.setter
    def st_atime(self, val: Union[int, float]) -> None:
        """Set the access time in seconds."""
        self._st_atime_ns = int(val * 1e9)

    @property
    def st_mtime(self) -> Union[int, float]:
        """Return the modification time in seconds."""
        return self._st_mtime_ns / 1e9

    @st_mtime.setter
    def st_mtime(self, val: Union[int, float]) -> None:
        """Set the modification time in seconds."""
        self._st_mtime_ns = int(val * 1e9)

    @property
    def st_size(self) -> int:
        """Return the size in bytes; symlinks under Windows report 0."""
        if self.st_mode & S_IFLNK == S_IFLNK and self.is_windows:
            return 0
        return self._st_size

    @st_size.setter
    def st_size(self, val: int) -> None:
        """Set the size in bytes."""
        self._st_size = val

    @property
    def st_blocks(self) -> int:
        """Return the number of 512-byte blocks allocated for the file.
        Assumes a page size of 4096 (matches most systems).
        Ignores that this may not be available under some systems,
        and that the result may differ if the file has holes.

        Raises:
            AttributeError: if Windows is mimicked.
        """
        if self.is_windows:
            raise AttributeError("'os.stat_result' object has no attribute 'st_blocks'")
        page_size = 4096
        blocks_in_page = page_size // 512
        # round up to whole pages
        pages = self._st_size // page_size
        if self._st_size % page_size:
            pages += 1
        return pages * blocks_in_page

    @property
    def st_file_attributes(self) -> int:
        """Return the Windows file attributes derived from the file type.

        Raises:
            AttributeError: if Windows is not mimicked.
        """
        if not self.is_windows:
            raise AttributeError(
                "module 'os.stat_result' has no attribute 'st_file_attributes'"
            )
        st_mode = self.st_mode
        attributes = 0
        # The S_IF* file type values are overlapping codes, not independent
        # flags (e.g. S_IFREG & S_IFLNK != 0), so the type has to be checked
        # using the S_IS*() predicates instead of a bitwise AND.
        if stat.S_ISDIR(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_DIRECTORY  # type:ignore[attr-defined]
        if stat.S_ISREG(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_NORMAL  # type:ignore[attr-defined]
        if stat.S_ISCHR(st_mode) or stat.S_ISBLK(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_DEVICE  # type:ignore[attr-defined]
        if stat.S_ISLNK(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_REPARSE_POINT  # type:ignore
        return attributes

    @property
    def st_reparse_tag(self) -> int:
        """Return the symlink reparse tag for symlinks, 0 for anything else.

        Raises:
            AttributeError: if Windows is not mimicked, or under Python
                versions before 3.8.
        """
        if not self.is_windows or sys.version_info < (3, 8):
            raise AttributeError(
                "module 'os.stat_result' has no attribute 'st_reparse_tag'"
            )
        # S_ISLNK() is needed here: a bitwise AND with S_IFLNK would also
        # match regular files, as S_IFLNK shares a bit with S_IFREG
        if stat.S_ISLNK(self.st_mode):
            return stat.IO_REPARSE_TAG_SYMLINK  # type: ignore[attr-defined]
        return 0

    def __getitem__(self, item: int) -> Optional[int]:
        """Implement item access to mimic `os.stat_result` behavior.

        Args:
            item: one of the `stat.ST_*` index constants.

        Raises:
            ValueError: if `item` is none of the supported indexes.
        """
        if item == stat.ST_MODE:
            return self.st_mode
        if item == stat.ST_INO:
            return self.st_ino
        if item == stat.ST_DEV:
            return self.st_dev
        if item == stat.ST_NLINK:
            return self.st_nlink
        if item == stat.ST_UID:
            return self.st_uid
        if item == stat.ST_GID:
            return self.st_gid
        if item == stat.ST_SIZE:
            return self.st_size
        if item == stat.ST_ATIME:
            # item access always returns int for backward compatibility
            return int(self.st_atime)
        if item == stat.ST_MTIME:
            return int(self.st_mtime)
        if item == stat.ST_CTIME:
            return int(self.st_ctime)
        raise ValueError("Invalid item")

    @property
    def st_atime_ns(self) -> int:
        """Return the access time in nanoseconds."""
        return self._st_atime_ns

    @st_atime_ns.setter
    def st_atime_ns(self, val: int) -> None:
        """Set the access time in nanoseconds."""
        self._st_atime_ns = val

    @property
    def st_mtime_ns(self) -> int:
        """Return the modification time in nanoseconds."""
        return self._st_mtime_ns

    @st_mtime_ns.setter
    def st_mtime_ns(self, val: int) -> None:
        """Set the modification time of the fake file in nanoseconds."""
        self._st_mtime_ns = val

    @property
    def st_ctime_ns(self) -> int:
        """Return the creation time in nanoseconds."""
        return self._st_ctime_ns

    @st_ctime_ns.setter
    def st_ctime_ns(self, val: int) -> None:
        """Set the creation time of the fake file in nanoseconds."""
        self._st_ctime_ns = val
450
451
class BinaryBufferIO(io.BytesIO):
    """In-memory byte stream used to hold binary file contents."""

    def __init__(self, contents: Optional[bytes]):
        """Initialize the stream with `contents`, or empty if there are none."""
        initial_bytes = contents if contents else b""
        super().__init__(initial_bytes)

    def putvalue(self, value: bytes) -> None:
        """Write `value` into the stream at the current position."""
        self.write(value)
460
461
class TextBufferIO(io.TextIOWrapper):
    """Text stream for string file contents, backed by an in-memory byte
    stream that holds the encoded data."""

    def __init__(
        self,
        contents: Optional[bytes] = None,
        newline: Optional[str] = None,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        """Wrap a new byte stream initialized with `contents`.

        `encoding`, `errors` and `newline` are handed over to
        `io.TextIOWrapper`.
        """
        initial_bytes = contents if contents else b""
        self._bytestream = io.BytesIO(initial_bytes)
        super().__init__(
            self._bytestream, encoding=encoding, errors=errors, newline=newline
        )

    def getvalue(self) -> bytes:
        """Return the raw (encoded) contents of the underlying byte stream."""
        raw_contents = self._bytestream.getvalue()
        return raw_contents

    def putvalue(self, value: bytes) -> None:
        """Write `value` directly into the underlying byte stream."""
        self._bytestream.write(value)
480
481
def is_called_from_skipped_module(
    skip_names: list, case_sensitive: bool, check_open_code: bool = False
) -> bool:
    """Check if the current call originates from a module that shall be
    skipped, i.e. for which the real file system shall be used.

    Args:
        skip_names: dotted names of the modules to skip. A caller matches
            if its module path (file path without extension, with the path
            separators replaced by dots) equals a name or ends with
            "." followed by that name.
        case_sensitive: if False, path prefixes are compared ignoring case.
        check_open_code: if True, a call coming from the fake `open_code`
            that forwards to the original `open_code` also counts as skipped.

    Returns:
        True if the caller is to be skipped, False otherwise.
    """

    def starts_with(path, string):
        if case_sensitive:
            return path.startswith(string)
        return path.lower().startswith(string.lower())

    # in most cases we don't have skip names and won't need the overhead
    # of analyzing the traceback, except when checking for open_code
    if not skip_names and not check_open_code:
        return False

    stack = traceback.extract_stack()

    # handle the case that we try to call the original `open_code`
    # (since Python 3.12)
    # The stack in this case is:
    # -1: helpers.is_called_from_skipped_module: 'stack = traceback.extract_stack()'
    # -2: fake_open.fake_open: 'if is_called_from_skipped_module('
    # -3: fake_io.open: 'return fake_open('
    # -4: fake_io.open_code : 'return self._io_module.open_code(path)'
    # The length check avoids an IndexError if the stack is too shallow
    # to contain that call chain at all.
    if (
        check_open_code
        and len(stack) >= 4
        and stack[-4].name == "open_code"
        and stack[-4].line == "return self._io_module.open_code(path)"
    ):
        return True

    if not skip_names:
        return False

    # The caller is the innermost frame that is neither frozen code nor
    # part of the standard library, and that does not belong to pyfakefs
    # itself - with the exception of the pyfakefs tests.
    caller_filename = next(
        (
            frame.filename
            for frame in stack[::-1]
            if not frame.filename.startswith("<frozen ")
            and not starts_with(frame.filename, STDLIB_PATH)
            and (
                not starts_with(frame.filename, PYFAKEFS_PATH)
                or any(
                    starts_with(frame.filename, test_path)
                    for test_path in PYFAKEFS_TEST_PATHS
                )
            )
        ),
        None,
    )
    if not caller_filename:
        return False

    # convert the file path into a dotted module path for the comparison
    caller_module_name = os.path.splitext(caller_filename)[0]
    caller_module_name = caller_module_name.replace(os.sep, ".")
    return any(
        caller_module_name == sn or caller_module_name.endswith("." + sn)
        for sn in skip_names
    )
543
544
def reload_cleanup_handler(name):
    """Cleanup handler that reloads the module with the given name.
    Maybe needed in cases where a module is imported locally.

    Modules that are not currently loaded are left alone.
    Always returns True.
    """
    try:
        loaded_module = sys.modules[name]
    except KeyError:
        # the module is not loaded - nothing to reload
        pass
    else:
        importlib.reload(loaded_module)
    return True
552