1# Licensed under the Apache License, Version 2.0 (the "License"); 2# you may not use this file except in compliance with the License. 3# You may obtain a copy of the License at 4# 5# http://www.apache.org/licenses/LICENSE-2.0 6# 7# Unless required by applicable law or agreed to in writing, software 8# distributed under the License is distributed on an "AS IS" BASIS, 9# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10# See the License for the specific language governing permissions and 11# limitations under the License. 12 13"""Helper classes use for fake file system implementation.""" 14 15import ctypes 16import importlib 17import io 18import locale 19import os 20import platform 21import stat 22import sys 23import sysconfig 24import time 25import traceback 26from collections import namedtuple 27from copy import copy 28from dataclasses import dataclass 29from enum import Enum 30from stat import S_IFLNK 31from typing import Union, Optional, Any, AnyStr, overload, cast 32 33AnyString = Union[str, bytes] 34AnyPath = Union[AnyStr, os.PathLike] 35 36IS_PYPY = platform.python_implementation() == "PyPy" 37IS_WIN = sys.platform == "win32" 38IN_DOCKER = os.path.exists("/.dockerenv") 39 40PERM_READ = 0o400 # Read permission bit. 41PERM_WRITE = 0o200 # Write permission bit. 42PERM_EXE = 0o100 # Execute permission bit. 43PERM_DEF = 0o777 # Default permission bits. 44PERM_DEF_FILE = 0o666 # Default permission bits (regular file) 45PERM_ALL = 0o7777 # All permission bits. 
# Locations used by ``is_called_from_skipped_module`` to tell standard library
# and pyfakefs frames apart from the calling code.
STDLIB_PATH = os.path.realpath(sysconfig.get_path("stdlib"))
PYFAKEFS_PATH = os.path.dirname(__file__)
PYFAKEFS_TEST_PATHS = [
    os.path.join(PYFAKEFS_PATH, test_dir) for test_dir in ("tests", "pytest_tests")
]

_OpenModes = namedtuple(
    "_OpenModes",
    "must_exist can_read can_write truncate append must_not_exist",
)

# Initial global IDs: under Windows a fake ID is used (0 for an administrator,
# 1 otherwise), elsewhere the IDs of the current process.
if sys.platform == "win32":
    fake_id = 0 if ctypes.windll.shell32.IsUserAnAdmin() else 1
    USER_ID = GROUP_ID = fake_id
else:
    USER_ID = os.getuid()
    GROUP_ID = os.getgid()


def get_uid() -> int:
    """Return the global user id (the counterpart of ``os.getuid()``)."""
    return USER_ID


def set_uid(uid: int) -> None:
    """Change the global user id.

    The value is used as ``st_uid`` for new files and to differentiate
    between a normal user and the root user (uid 0). For the root user,
    some permission restrictions are ignored.

    Args:
        uid: (int) the user ID of the user calling the file system functions.
    """
    global USER_ID
    USER_ID = uid


def get_gid() -> int:
    """Return the global group id (the counterpart of ``os.getgid()``)."""
    return GROUP_ID


def set_gid(gid: int) -> None:
    """Change the global group id.

    The value is only used to set ``st_gid`` for new files, no permission
    checks are performed.

    Args:
        gid: (int) the group ID of the user calling the file system functions.
    """
    global GROUP_ID
    GROUP_ID = gid


def reset_ids() -> None:
    """Restore the default values of the global user ID and group ID."""
    if sys.platform == "win32":
        default_uid = default_gid = 0 if ctypes.windll.shell32.IsUserAnAdmin() else 1
    else:
        default_uid, default_gid = os.getuid(), os.getgid()
    set_uid(default_uid)
    set_gid(default_gid)


def is_root() -> bool:
    """Tell whether the current global user is the root user (uid 0)."""
    return get_uid() == 0


def is_int_type(val: Any) -> bool:
    """Tell whether `val` is an instance of ``int``."""
    return isinstance(val, int)


def is_byte_string(val: Any) -> bool:
    """Return True if `val` is a bytes-like object, False for a unicode
    string. Anything without an ``encode`` attribute counts as bytes-like."""
    return not is_unicode_string(val)


def is_unicode_string(val: Any) -> bool:
    """Return True if `val` is a unicode string, False for a bytes-like
    object. The decision is based on the presence of an ``encode`` attribute."""
    return hasattr(val, "encode")


def get_locale_encoding() -> str:
    """Return the locale encoding, using ``locale.getencoding()`` where
    available (Python 3.11+) and ``locale.getpreferredencoding(False)``
    otherwise."""
    if sys.version_info < (3, 11):
        return locale.getpreferredencoding(False)
    return locale.getencoding()


@overload
def make_string_path(dir_name: AnyStr) -> AnyStr: ...


@overload
def make_string_path(dir_name: os.PathLike) -> str: ...
def make_string_path(dir_name: "AnyPath") -> AnyStr:  # type: ignore[type-var]
    """Return `dir_name` as ``str`` or ``bytes`` by passing it through
    ``os.fspath``."""
    return cast(AnyStr, os.fspath(dir_name))  # pytype: disable=invalid-annotation


def to_string(path: Union[AnyStr, Union[str, bytes]]) -> str:
    """Return the string representation of a byte string using the preferred
    encoding, or the string itself if path is a str."""
    if isinstance(path, bytes):
        return path.decode(get_locale_encoding())
    return path


def to_bytes(path: Union[AnyStr, Union[str, bytes]]) -> bytes:
    """Return the bytes representation of a string using the preferred
    encoding, or the byte string itself if path is a byte string."""
    if isinstance(path, str):
        return bytes(path, get_locale_encoding())
    return path


def join_strings(s1: AnyStr, s2: AnyStr) -> AnyStr:
    """Return the concatenation of `s1` and `s2`.

    This is a bit of a hack to satisfy mypy - may be refactored.
    """
    return s1 + s2


def real_encoding(encoding: Optional[str]) -> Optional[str]:
    """Since Python 3.10, the new function ``io.text_encoding`` returns
    "locale" as the encoding if None is defined. This will be handled
    as no encoding in pyfakefs.

    Returns:
        None if `encoding` is "locale" (Python 3.10+), otherwise `encoding`
        unchanged.
    """
    if sys.version_info >= (3, 10):
        return encoding if encoding != "locale" else None
    return encoding


def now() -> float:
    """Return the current time as given by ``time.time()``."""
    return time.time()


@overload
def matching_string(matched: bytes, string: AnyStr) -> bytes: ...


@overload
def matching_string(matched: str, string: AnyStr) -> str: ...


@overload
def matching_string(matched: AnyStr, string: None) -> None: ...


def matching_string(  # type: ignore[misc]
    matched: AnyStr, string: Optional[AnyStr]
) -> "Optional[AnyString]":
    """Return the string as byte or unicode depending
    on the type of matched, assuming string is an ASCII string.

    Only the ``str`` -> ``bytes`` direction is converted (using the locale
    encoding); in all other cases `string` is returned unchanged, including
    None.
    """
    if string is None:
        return string
    if isinstance(matched, bytes) and isinstance(string, str):
        return string.encode(get_locale_encoding())
    return string  # pytype: disable=bad-return-type


@dataclass
class FSProperties:
    """Separator, line ending and null device names of a file system type."""

    sep: str  # path component separator
    altsep: Optional[str]  # alternative separator, None if there is none
    pathsep: str  # separator between entries of a search path
    linesep: str  # line separator
    devnull: str  # name of the null device


# pure POSIX file system properties, for use with PosixPath
POSIX_PROPERTIES = FSProperties(
    sep="/",
    altsep=None,
    pathsep=":",
    linesep="\n",
    devnull="/dev/null",
)

# pure Windows file system properties, for use with WindowsPath
WINDOWS_PROPERTIES = FSProperties(
    sep="\\",
    altsep="/",
    pathsep=";",
    linesep="\r\n",
    devnull="NUL",
)


class FSType(Enum):
    """Defines which file system properties to use."""

    DEFAULT = 0  # use current OS file system + modifications in fake file system
    POSIX = 1  # pure POSIX properties, for use in PosixPath
    WINDOWS = 2  # pure Windows properties, for use in WindowsPath


class FakeStatResult:
    """Mimics os.stat_result for use as return type of `stat()` and similar.
    This is needed as `os.stat_result` has no possibility to set
    nanosecond times directly.

    Times are stored internally as integer nanoseconds; the second-based
    properties convert on access.
    """

    def __init__(
        self,
        is_windows: bool,
        user_id: int,
        group_id: int,
        initial_time: Optional[float] = None,
    ):
        """Create a stat result with zeroed mode, size and link count.

        Args:
            is_windows: whether Windows behavior shall be mimicked (affects
                `st_size` of symlinks and the availability of `st_blocks`,
                `st_file_attributes` and `st_reparse_tag`).
            user_id: initial value of `st_uid`.
            group_id: initial value of `st_gid`.
            initial_time: initial access, modification and creation time
                in seconds; None is handled as 0.
        """
        self.st_mode: int = 0
        self.st_ino: Optional[int] = None
        self.st_dev: int = 0
        self.st_nlink: int = 0
        self.st_uid: int = user_id
        self.st_gid: int = group_id
        self._st_size: int = 0
        self.is_windows: bool = is_windows
        self._st_atime_ns: int = int((initial_time or 0) * 1e9)
        self._st_mtime_ns: int = self._st_atime_ns
        self._st_ctime_ns: int = self._st_atime_ns

    def __eq__(self, other: Any) -> bool:
        """Compare times, size, ownership, link count, device, inode and mode."""
        return (
            isinstance(other, FakeStatResult)
            and self._st_atime_ns == other._st_atime_ns
            and self._st_ctime_ns == other._st_ctime_ns
            and self._st_mtime_ns == other._st_mtime_ns
            and self.st_size == other.st_size
            and self.st_gid == other.st_gid
            and self.st_uid == other.st_uid
            and self.st_nlink == other.st_nlink
            and self.st_dev == other.st_dev
            and self.st_ino == other.st_ino
            and self.st_mode == other.st_mode
        )

    def __ne__(self, other: Any) -> bool:
        return not self == other

    def copy(self) -> "FakeStatResult":
        """Return a shallow copy of this stat result."""
        return copy(self)

    def set_from_stat_result(self, stat_result: os.stat_result) -> None:
        """Set values from a real os.stat_result.
        Note: values that are controlled by the fake filesystem are not set.
        This includes st_ino, st_dev and st_nlink.
        """
        self.st_mode = stat_result.st_mode
        self.st_uid = stat_result.st_uid
        self.st_gid = stat_result.st_gid
        self._st_size = stat_result.st_size
        self._st_atime_ns = stat_result.st_atime_ns
        self._st_mtime_ns = stat_result.st_mtime_ns
        self._st_ctime_ns = stat_result.st_ctime_ns

    @property
    def st_ctime(self) -> Union[int, float]:
        """Return the creation time in seconds."""
        return self._st_ctime_ns / 1e9

    @st_ctime.setter
    def st_ctime(self, val: Union[int, float]) -> None:
        """Set the creation time in seconds."""
        self._st_ctime_ns = int(val * 1e9)

    @property
    def st_atime(self) -> Union[int, float]:
        """Return the access time in seconds."""
        return self._st_atime_ns / 1e9

    @st_atime.setter
    def st_atime(self, val: Union[int, float]) -> None:
        """Set the access time in seconds."""
        self._st_atime_ns = int(val * 1e9)

    @property
    def st_mtime(self) -> Union[int, float]:
        """Return the modification time in seconds."""
        return self._st_mtime_ns / 1e9

    @st_mtime.setter
    def st_mtime(self, val: Union[int, float]) -> None:
        """Set the modification time in seconds."""
        self._st_mtime_ns = int(val * 1e9)

    @property
    def st_size(self) -> int:
        """Return the size in bytes; always 0 for a symlink under Windows."""
        if self.is_windows and stat.S_ISLNK(self.st_mode):
            return 0
        return self._st_size

    @st_size.setter
    def st_size(self, val: int) -> None:
        """Set the size in bytes."""
        self._st_size = val

    @property
    def st_blocks(self) -> int:
        """Return the number of 512-byte blocks allocated for the file.
        Assumes a page size of 4096 (matches most systems).
        Ignores that this may not be available under some systems,
        and that the result may differ if the file has holes.

        Raises:
            AttributeError: if Windows behavior is mimicked.
        """
        if self.is_windows:
            raise AttributeError("'os.stat_result' object has no attribute 'st_blocks'")
        page_size = 4096
        blocks_in_page = page_size // 512
        # round the size up to whole pages
        pages = self._st_size // page_size
        if self._st_size % page_size:
            pages += 1
        return pages * blocks_in_page

    @property
    def st_file_attributes(self) -> int:
        """Return the Windows file attributes derived from the file type.

        Raises:
            AttributeError: if Windows behavior is not mimicked.
        """
        if not self.is_windows:
            raise AttributeError(
                "module 'os.stat_result' has no attribute 'st_file_attributes'"
            )
        # The S_IF* file type values share bits (e.g. S_IFLNK contains the
        # S_IFREG bit), so the type has to be compared via the S_IS*()
        # helpers instead of being tested with a plain bitwise AND.
        st_mode = self.st_mode
        attributes = 0
        if stat.S_ISDIR(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_DIRECTORY  # type:ignore[attr-defined]
        if stat.S_ISREG(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_NORMAL  # type:ignore[attr-defined]
        if stat.S_ISCHR(st_mode) or stat.S_ISBLK(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_DEVICE  # type:ignore[attr-defined]
        if stat.S_ISLNK(st_mode):
            attributes |= stat.FILE_ATTRIBUTE_REPARSE_POINT  # type:ignore
        return attributes

    @property
    def st_reparse_tag(self) -> int:
        """Return the symlink reparse tag for symlinks, 0 for anything else.

        Raises:
            AttributeError: if Windows behavior is not mimicked.
        """
        if not self.is_windows or sys.version_info < (3, 8):
            raise AttributeError(
                "module 'os.stat_result' has no attribute 'st_reparse_tag'"
            )
        if stat.S_ISLNK(self.st_mode):
            # Fall back to the documented numeric value in case the `stat`
            # module of the host does not define the constant.
            return getattr(stat, "IO_REPARSE_TAG_SYMLINK", 0xA000000C)
        return 0

    def __getitem__(self, item: int) -> Optional[int]:
        """Implement item access to mimic `os.stat_result` behavior.

        Raises:
            ValueError: if `item` is none of the ``stat.ST_*`` indexes.
        """
        if item == stat.ST_MODE:
            return self.st_mode
        if item == stat.ST_INO:
            return self.st_ino
        if item == stat.ST_DEV:
            return self.st_dev
        if item == stat.ST_NLINK:
            return self.st_nlink
        if item == stat.ST_UID:
            return self.st_uid
        if item == stat.ST_GID:
            return self.st_gid
        if item == stat.ST_SIZE:
            return self.st_size
        if item == stat.ST_ATIME:
            # item access always returns int for backward compatibility
            return int(self.st_atime)
        if item == stat.ST_MTIME:
            return int(self.st_mtime)
        if item == stat.ST_CTIME:
            return int(self.st_ctime)
        raise ValueError("Invalid item")

    @property
    def st_atime_ns(self) -> int:
        """Return the access time in nanoseconds."""
        return self._st_atime_ns

    @st_atime_ns.setter
    def st_atime_ns(self, val: int) -> None:
        """Set the access time in nanoseconds."""
        self._st_atime_ns = val

    @property
    def st_mtime_ns(self) -> int:
        """Return the modification time in nanoseconds."""
        return self._st_mtime_ns

    @st_mtime_ns.setter
    def st_mtime_ns(self, val: int) -> None:
        """Set the modification time of the fake file in nanoseconds."""
        self._st_mtime_ns = val

    @property
    def st_ctime_ns(self) -> int:
        """Return the creation time in nanoseconds."""
        return self._st_ctime_ns

    @st_ctime_ns.setter
    def st_ctime_ns(self, val: int) -> None:
        """Set the creation time of the fake file in nanoseconds."""
        self._st_ctime_ns = val


class BinaryBufferIO(io.BytesIO):
    """Stream class that handles byte contents for files."""

    def __init__(self, contents: Optional[bytes]):
        """Initialize the stream with `contents` (empty if None)."""
        super().__init__(contents or b"")

    def putvalue(self, value: bytes) -> None:
        """Write `value` to the stream at the current position."""
        self.write(value)


class TextBufferIO(io.TextIOWrapper):
    """Stream class that handles Python string contents for files."""

    def __init__(
        self,
        contents: Optional[bytes] = None,
        newline: Optional[str] = None,
        encoding: Optional[str] = None,
        errors: str = "strict",
    ):
        """Wrap a byte stream initialized with `contents` (empty if None);
        `encoding`, `errors` and `newline` are passed on to
        ``io.TextIOWrapper``."""
        self._bytestream = io.BytesIO(contents or b"")
        super().__init__(self._bytestream, encoding, errors, newline)

    def getvalue(self) -> bytes:
        """Return the contents of the underlying byte stream."""
        return self._bytestream.getvalue()

    def putvalue(self, value: bytes) -> None:
        """Write `value` directly to the underlying byte stream."""
        self._bytestream.write(value)


def is_called_from_skipped_module(
    skip_names: list, case_sensitive: bool, check_open_code: bool = False
) -> bool:
    """Check the call stack for a caller that shall not use the fake
    implementation.

    Args:
        skip_names: module names to be skipped.
        case_sensitive: whether file paths are compared case-sensitively.
        check_open_code: if True, also check if the call comes from the
            `open_code` wrapper that delegates to the original `open_code`.

    Returns:
        True if the call comes from that `open_code` wrapper (only if
        `check_open_code` is set), or if the file name of the first stack
        frame (seen from the innermost frame) outside of frozen modules, the
        standard library and pyfakefs itself (pyfakefs tests excepted) maps
        to a module name contained in `skip_names`.
    """

    def starts_with(path, string):
        if case_sensitive:
            return path.startswith(string)
        return path.lower().startswith(string.lower())

    # in most cases we don't have skip names and won't need the overhead
    # of analyzing the traceback, except when checking for open_code
    if not skip_names and not check_open_code:
        return False

    stack = traceback.extract_stack()

    # handle the case that we try to call the original `open_code`
    # (since Python 3.12)
    # The stack in this case is:
    # -1: helpers.is_called_from_skipped_module: 'stack = traceback.extract_stack()'
    # -2: fake_open.fake_open: 'if is_called_from_skipped_module('
    # -3: fake_io.open: 'return fake_open('
    # -4: fake_io.open_code : 'return self._io_module.open_code(path)'
    if (
        check_open_code
        and len(stack) >= 4  # a shallower stack cannot contain that call chain
        and stack[-4].name == "open_code"
        and stack[-4].line == "return self._io_module.open_code(path)"
    ):
        return True

    if not skip_names:
        return False

    caller_filename = next(
        (
            frame.filename
            for frame in stack[::-1]
            if not frame.filename.startswith("<frozen ")
            and not starts_with(frame.filename, STDLIB_PATH)
            and (
                not starts_with(frame.filename, PYFAKEFS_PATH)
                or any(
                    starts_with(frame.filename, test_path)
                    for test_path in PYFAKEFS_TEST_PATHS
                )
            )
        ),
        None,
    )
    if not caller_filename:
        return False

    # convert the file path into a dotted name that can be compared with
    # (the end of) the module names to skip
    caller_module_name = os.path.splitext(caller_filename)[0]
    caller_module_name = caller_module_name.replace(os.sep, ".")
    return any(
        caller_module_name == sn or caller_module_name.endswith("." + sn)
        for sn in skip_names
    )


def reload_cleanup_handler(name):
    """Cleanup handler that reloads the module with the given name.
    Maybe needed in cases where a module is imported locally.

    Returns:
        Always True; the module is only reloaded if it is found in
        ``sys.modules``.
    """
    if name in sys.modules:
        importlib.reload(sys.modules[name])
    return True