"""Extract files from a wheel's RECORD.

Usage: ``<script> <wheel> <out_dir> [regex_pattern]``

Every path listed in the wheel's top-level ``*.dist-info/RECORD`` whose
start matches ``regex_pattern`` (default: empty, i.e. everything) is
extracted into ``out_dir``.
"""

import csv
import re
import sys
import zipfile
from collections.abc import Iterable
from pathlib import Path

WhlRecord = Iterable[str]


def _is_top_level_record(name: str) -> bool:
    """Return True if *name* is ``<dir>.dist-info/RECORD`` at the archive root."""
    # The wheel format places the .dist-info directory at the root; vendored
    # distributions may carry their own nested .dist-info/RECORD, which must
    # not be mistaken for the wheel's own RECORD.
    dist_info_dir, _, basename = name.rpartition("/")
    return (
        basename == "RECORD"
        and dist_info_dir.endswith(".dist-info")
        and "/" not in dist_info_dir
    )


def get_record(whl_path: Path) -> WhlRecord:
    """Return the file paths listed in the wheel's RECORD, in file order.

    Args:
        whl_path: Path to the ``.whl`` archive.

    Returns:
        The first CSV column (the path) of every non-empty RECORD row.

    Raises:
        RuntimeError: If *whl_path* is not a zip file, or it does not
            contain exactly one top-level ``.dist-info/RECORD``.
    """
    try:
        zipf = zipfile.ZipFile(whl_path)
    except zipfile.BadZipFile as ex:
        raise RuntimeError(f"{whl_path} is not a valid zip file") from ex
    with zipf:
        candidates = [name for name in zipf.namelist() if _is_top_level_record(name)]
        try:
            (record_file,) = candidates
        except ValueError:
            raise RuntimeError(
                f"{whl_path} doesn't contain exactly one .dist-info/RECORD"
            ) from None
        # RECORD is a UTF-8 encoded CSV file.
        record_lines = zipf.read(record_file).decode("utf-8").splitlines()
    # csv.reader yields [] for blank lines; skip them instead of failing on row[0].
    return [row[0] for row in csv.reader(record_lines) if row]


def get_files(whl_record: WhlRecord, regex_pattern: str) -> list[str]:
    """Get files in a wheel that match a regex pattern.

    Matching uses ``re.match`` semantics: the pattern is anchored at the
    start of each path but not at the end. An empty pattern matches every
    path.
    """
    pattern = re.compile(regex_pattern)
    return [filepath for filepath in whl_record if pattern.match(filepath)]


def extract_files(whl_path: Path, files: Iterable[str], outdir: Path) -> None:
    """Extract *files* (archive member names) from *whl_path* into *outdir*."""
    with zipfile.ZipFile(whl_path) as zipf:
        for file in files:
            zipf.extract(file, outdir)


def main() -> None:
    """Command-line entry point; see the module docstring for usage."""
    if len(sys.argv) not in {3, 4}:
        print(
            f"Usage: {sys.argv[0]} <wheel> <out_dir> [regex_pattern]",
            file=sys.stderr,
        )
        sys.exit(1)

    whl_path = Path(sys.argv[1]).resolve()
    outdir = Path(sys.argv[2])
    regex_pattern = sys.argv[3] if len(sys.argv) == 4 else ""

    whl_record = get_record(whl_path)
    files = get_files(whl_record, regex_pattern)
    extract_files(whl_path, files, outdir)


if __name__ == "__main__":
    main()