# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import bisect
from collections.abc import (
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Sequence,
    ValuesView,
)
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    Any,
    BinaryIO,
    IO,
    NamedTuple,
    overload,
    Pattern,
    TextIO,
    TypeVar,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)
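
# Illustrative usage sketch (not part of the module): tokens can be computed
# directly from Python strings with the functions above. c_hash() matches
# pw_tokenizer_65599_hash() for strings shorter than the configured hash
# length, since the C-style hash only truncates longer strings. The string
# literal below is an arbitrary example.
#
#   cpp_token = pw_tokenizer_65599_hash('The answer is: %s')
#   c_token = c_hash('The answer is: %s')  # Same value for short strings.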


@dataclass(frozen=True, eq=True, order=True)
class _EntryKey:
    """Uniquely refers to an entry."""

    domain: str
    token: int
    string: str


class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    def __init__(
        self,
        token: int,
        string: str,
        domain: str = DEFAULT_DOMAIN,
        date_removed: datetime | None = None,
    ) -> None:
        self._key = _EntryKey(
            ''.join(domain.split()),
            token,
            string,
        )
        self.date_removed = date_removed

    @property
    def token(self) -> int:
        return self._key.token

    @property
    def string(self) -> str:
        return self._key.string

    @property
    def domain(self) -> str:
        return self._key.domain

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return self._key

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __eq__(self, other: Any) -> bool:
        return (
            self.key() == other.key()
            and self.date_removed == other.date_removed
        )

    def __lt__(self, other: Any) -> bool:
        """Sorts the entry by domain, token, date removed, then string."""
        if self.domain != other.domain:
            return self.domain < other.domain

        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(token=0x{self.token:08x}, '
            f'string={self.string!r}, domain={self.domain!r})'
        )
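
# Illustrative sketch (not part of the module): entries are keyed by
# (domain, token, string), and whitespace is removed from the domain by
# __init__. The string below is an arbitrary example.
#
#   token = pw_tokenizer_65599_hash('Jello, world!')
#   entry = TokenizedStringEntry(token, 'Jello, world!')
#   entry.key()                  # (domain='', token=token, string=...)
#   entry.date_removed is None   # Entries still in use have no removal date.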


_TokenToEntries = dict[int, list[TokenizedStringEntry]]
_K = TypeVar('_K')
_V = TypeVar('_V')
_T = TypeVar('_T')


class _TokenDatabaseView(Mapping[_K, _V]):  # pylint: disable=abstract-method
    """Read-only mapping view of a token database.

    Behaves like a read-only version of defaultdict(list).
    """

    def __init__(self, mapping: Mapping[_K, Any]) -> None:
        self._mapping = mapping

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    @overload
    def get(self, key: _K) -> _V | None:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: _K, default: _T) -> _V | _T:  # pylint: disable=W0222
        ...

    def get(self, key: _K, default: _T | None = None) -> _V | _T | None:
        return self._mapping.get(key, default)

    def __iter__(self) -> Iterator[_K]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return str(self._mapping)

    def __repr__(self) -> str:
        return repr(self._mapping)


class _TokenMapping(_TokenDatabaseView[int, Sequence[TokenizedStringEntry]]):
    def __getitem__(self, token: int) -> Sequence[TokenizedStringEntry]:
        """Returns strings that match the specified token; may be empty."""
        return self._mapping.get(token, ())  # Empty sequence if no match


class _DomainTokenMapping(_TokenDatabaseView[str, _TokenMapping]):
    def __getitem__(self, domain: str) -> _TokenMapping:
        """Returns the token-to-strings mapping for the specified domain."""
        return _TokenMapping(self._mapping.get(domain, {}))  # Empty if no match


def _add_entry(entries: _TokenToEntries, entry: TokenizedStringEntry) -> None:
    bisect.insort(
        entries.setdefault(entry.token, []),
        entry,
        key=TokenizedStringEntry.key,  # Keep lists of entries sorted by key.
    )


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique entry, keyed by
        # (domain, token, string).
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # Index by token and domain
        self._token_entries: _TokenToEntries = {}
        self._domain_token_entries: dict[str, _TokenToEntries] = {}

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            TokenizedStringEntry(tokenize(string), string, domain)
            for string in strings
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Mapping[int, Sequence[TokenizedStringEntry]]:
        """Returns a mapping of tokens to a sequence of TokenizedStringEntry.

        Returns token database entries from all domains.
        """
        return _TokenMapping(self._token_entries)

    @property
    def domains(
        self,
    ) -> Mapping[str, Mapping[int, Sequence[TokenizedStringEntry]]]:
        """Returns a mapping of domains to tokens to a sequence of entries.

        `database.domains[domain][token]` returns a sequence of strings
        matching the token in the domain, or an empty sequence if there are
        no matches.
        """
        return _DomainTokenMapping(self._domain_token_entries)

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns an iterable of all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(
        self,
    ) -> Iterator[tuple[int, Sequence[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries that are missing from the
        database are NOT added; call the add function to add those.

        Args:
            all_entries: the complete set of strings present in the database
            removal_date: the datetime for removed entries; today by default

        Returns:
            A list of entries marked as removed.
        """
        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                self._add_new_entry(new_entry)

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            self._delete_entry(entry)

        return to_delete

    def merge(self, *databases: Database) -> None:
        """Merges two or more databases together.

        All entries are kept if there are token collisions.

        If there are two identical tokens (same domain, token, and string), the
        newest removal date (or no date if either token has no removal date) is
        used for the merged token.
        """
        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._add_new_entry(entry)

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
            include: regexes; only entries matching at least one are kept
            exclude: regexes; entries matching any of these are removed
            replace: (regex, str) tuples; replaces matching terms in all entries
        """
        to_delete: list[TokenizedStringEntry] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                val
                for val in self._database.values()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                val
                for val in self._database.values()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for entry in to_delete:
            self._delete_entry(entry)

        # Do the replacement after removing entries.
        for search, replacement in replace:
            search = re.compile(search)

            to_replace: list[TokenizedStringEntry] = []
            add: list[TokenizedStringEntry] = []

            for entry in self._database.values():
                new_string = search.sub(replacement, entry.string)
                if new_string != entry.string:
                    to_replace.append(entry)
                    add.append(
                        TokenizedStringEntry(
                            entry.token,
                            new_string,
                            entry.domain,
                            entry.date_removed,
                        )
                    )

            for entry in to_replace:
                self._delete_entry(entry)
            self.add(add)

    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def _add_new_entry(self, new_entry: TokenizedStringEntry) -> None:
        entry = TokenizedStringEntry(  # These are mutable, so make a copy.
            new_entry.token,
            new_entry.string,
            new_entry.domain,
            new_entry.date_removed,
        )
        self._database[entry.key()] = entry
        _add_entry(self._token_entries, entry)
        _add_entry(
            self._domain_token_entries.setdefault(entry.domain, {}), entry
        )

    def _delete_entry(self, entry: TokenizedStringEntry) -> None:
        del self._database[entry.key()]

        # Remove from the token / domain mappings and clean up empty lists.
        self._token_entries[entry.token].remove(entry)
        if not self._token_entries[entry.token]:
            del self._token_entries[entry.token]

        self._domain_token_entries[entry.domain][entry.token].remove(entry)
        if not self._domain_token_entries[entry.domain][entry.token]:
            del self._domain_token_entries[entry.domain][entry.token]
        if not self._domain_token_entries[entry.domain]:
            del self._domain_token_entries[entry.domain]

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()
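
# Illustrative usage sketch (not part of the module): building and querying a
# Database. The strings and domain names below are arbitrary examples.
#
#   db = Database.from_strings(['Hello %s', 'The answer: %d'])
#   token = pw_tokenizer_65599_hash('Hello %s')
#   db.token_to_entries[token]        # Entries matching the token, any domain.
#   db.domains[''][token]             # Same lookup, scoped to the '' domain.
#   merged = Database.merged(db, Database.from_strings(['Goodbye %s']))
#   merged.difference(db).entries()   # Only the 'Goodbye %s' entry.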


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    for line in csv.reader(fd):
        try:
            try:
                token_str, date_str, domain, string_literal = line
            except ValueError:
                # If there are only three columns, use the default domain.
                token_str, date_str, string_literal = line
                domain = DEFAULT_DOMAIN

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            yield TokenizedStringEntry(token, string_literal, domain, date)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )


def write_csv(database: Database, fd: IO[bytes]) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: IO[bytes], entry: TokenizedStringEntry):
    """Write a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}","{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.domain.replace('"', '""'),  # escape " as ""
            entry.string.replace('"', '""'),
        ).encode()
    )
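
# Illustrative sketch (not part of the module): round-trip a database through
# the CSV format. The database contents and file name are arbitrary examples.
#
#   db = Database.from_strings(['Hello %s'])
#   with open('tokens.csv', 'wb') as fd:
#       write_csv(db, fd)
#   with open('tokens.csv', 'r', newline='', encoding='utf-8') as fd:
#       db_copy = Database(parse_csv(fd))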


class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff for
            # the day/month/year. That ensures that still-present tokens appear
            # as the newest tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)
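
# Illustrative sketch (not part of the module): the same round trip through the
# packed binary format. The database contents and file name are arbitrary
# examples.
#
#   db = Database.from_strings(['Hello %s'])
#   with open('tokens.bin', 'wb') as fd:
#       write_binary(db, fd)
#   with open('tokens.bin', 'rb') as fd:
#       assert file_is_binary_database(fd)
#       db_copy = Database(parse_binary(fd))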


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method that writes to the file from
    which it was created in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile that matches the file's type."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards temporary entries, then adds the given entries.

        Temporary entries are removed from the database before the new entries
        are added, so that only recurring entries are kept in memory and
        written to disk in the original format.
        """


class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )
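
# Illustrative sketch (not part of the module): DatabaseFile.load() detects the
# database type (directory, binary, or CSV) from the path. The path and string
# below are arbitrary examples.
#
#   db_file = DatabaseFile.load(Path('tokens.csv'))
#   db_file.add(Database.from_strings(['New string %d']).entries())
#   db_file.write_to_file()   # Rewritten in the file's original format.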


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)
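
# Illustrative sketch (not part of the module): a directory database is a
# directory of '*.pw_tokenizer.csv' fragments that are merged when loaded. The
# paths below are arbitrary examples.
#
#   merged = Database(_parse_directory(Path('token_db_dir')))
#   with open('merged' + DIR_DB_SUFFIX, 'wb') as fd:
#       write_csv(merged, fd)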


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if
          any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create
        # a new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not present in the directory."""
        # Generate filenames until one does not collide with an existing file.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())