• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Extract files from a wheel's RECORD."""
2
3import csv
4import re
5import sys
6import zipfile
7from collections.abc import Iterable
8from pathlib import Path
9
# A wheel's RECORD as an iterable of archive-relative file paths (the first CSV
# column of each RECORD row, as produced by get_record).
WhlRecord = Iterable[str]
11
12
def get_record(whl_path: Path) -> WhlRecord:
    """Return the file paths listed in a wheel's top-level ``.dist-info/RECORD``.

    The archive is opened, the RECORD file is read in full, and the archive is
    closed again before returning.

    Args:
        whl_path: Path to the ``.whl`` archive.

    Returns:
        A lazy iterator over the first CSV column (the archive-relative path)
        of each non-empty RECORD row, in file order.

    Raises:
        RuntimeError: if ``whl_path`` is not a valid zip file, or if it does not
            contain exactly one ``*.dist-info/RECORD`` at the archive root.
    """
    try:
        zipf = zipfile.ZipFile(whl_path)
    except zipfile.BadZipFile as ex:
        raise RuntimeError(f"{whl_path} is not a valid zip file") from ex
    with zipf:
        # Only the archive-root .dist-info directory belongs to this wheel.
        # Vendored distributions can ship their own RECORD deeper in the tree
        # (e.g. "pkg/_vendor/dep-1.0.dist-info/RECORD"), and those must not be
        # counted or the wheel would be rejected as ambiguous.
        candidates = [
            name
            for name in zipf.namelist()
            if name.count("/") == 1 and name.endswith(".dist-info/RECORD")
        ]
        if len(candidates) != 1:
            raise RuntimeError(
                f"{whl_path} doesn't contain exactly one .dist-info/RECORD"
            )
        (record_file,) = candidates
        # Decode explicitly as UTF-8 (identical to the bytes.decode() default).
        record_lines = zipf.read(record_file).decode("utf-8").splitlines()
    # csv.reader yields [] for a blank line; skip those so row[0] cannot fail.
    return (row[0] for row in csv.reader(record_lines) if row)
25
26
def get_files(whl_record: WhlRecord, regex_pattern: str) -> list[str]:
    """Get files in a wheel that match a regex pattern.

    The pattern is applied with ``re.Pattern.match``. It is anchored at the
    start of each path but not at the end, so ``"pkg/"`` selects everything
    under ``pkg/``. An empty pattern matches every entry.

    Args:
        whl_record: Archive-relative paths, e.g. as returned by ``get_record``.
        regex_pattern: Regular expression tested against each path.

    Returns:
        The matching paths, in input order.

    Raises:
        re.error: if ``regex_pattern`` is not a valid regular expression.
    """
    pattern = re.compile(regex_pattern)
    # Call the compiled pattern's own method rather than re.match(pattern, ...),
    # which would route every call back through the module-level cache lookup.
    return [filepath for filepath in whl_record if pattern.match(filepath)]
31
32
def extract_files(whl_path: Path, files: Iterable[str], outdir: Path) -> None:
    """Extract files from whl_path to outdir.

    Each member is written under ``outdir`` at its archive-relative path.

    Args:
        whl_path: Path to the wheel (zip) archive.
        files: Archive member names to extract.
        outdir: Destination directory.

    Raises:
        KeyError: if a name in ``files`` is not a member of the archive.
    """
    # Use the archive as a context manager so its file handle is released even
    # if an extraction fails part-way through.
    with zipfile.ZipFile(whl_path) as zipf:
        for file in files:
            zipf.extract(file, outdir)
38
39
def main() -> None:
    """Command-line entry point: ``<wheel> <out_dir> [regex_pattern]``.

    With a wrong number of arguments, prints a usage line to stderr and exits
    with status 1. Otherwise, extracts every RECORD entry of the wheel that
    matches the optional regex (empty pattern, i.e. everything, by default)
    into ``out_dir``.
    """
    args = sys.argv
    if not 3 <= len(args) <= 4:
        print(
            f"Usage: {args[0]} <wheel> <out_dir> [regex_pattern]",
            file=sys.stderr,
        )
        sys.exit(1)

    wheel_arg, outdir_arg, *extra = args[1:]
    whl_path = Path(wheel_arg).resolve()
    outdir = Path(outdir_arg)
    regex_pattern = extra[0] if extra else ""

    selected = get_files(get_record(whl_path), regex_pattern)
    extract_files(whl_path, selected, outdir)


if __name__ == "__main__":
    main()
59