# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    BinaryIO,
    Callable,
    Iterable,
    Iterator,
    NamedTuple,
    Pattern,
    TextIO,
    ValuesView,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)
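

# Example usage (illustrative only; the strings below are arbitrary): tokens
# may be computed directly in Python with the functions above, matching the
# values produced by the C/C++ tokenizer.
#
#   pw_tokenizer_65599_hash('')   # == 0; the hash starts at len(string).
#   pw_tokenizer_65599_hash('A')  # == (1 + 65599 * ord('A')) % 2**32
#                                 # == 0x00411000
#   c_hash('The answer is: %s')   # C-style token; hashes at most
#                                 # DEFAULT_C_HASH_LENGTH characters.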


class _EntryKey(NamedTuple):
    """Uniquely refers to an entry."""

    token: int
    string: str


@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: datetime | None = None

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return _EntryKey(self.token, self.string)

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Sorts the entry by token, date removed, then string."""
        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # This is a cache for fast token lookup that is built as needed.
        self._cache: dict[int, list[TokenizedStringEntry]] | None = None

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            (
                TokenizedStringEntry(tokenize(string), string, domain)
                for string in strings
            )
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> dict[int, list[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache as needed.
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[tuple[int, list[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all tokens with collisions."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries that are missing from the
        database are NOT added; call the add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                # Make a copy to avoid unintentionally updating the database.
                self._database[new_entry.key()] = TokenizedStringEntry(
                    **vars(new_entry)
                )

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for _, entry in self._database.items()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: Database) -> None:
        """Merges two or more databases together, keeping the newest dates."""
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        to_delete: list[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for key in to_delete:
            del self._database[key]

        for search, replacement in replace:
            search = re.compile(search)

            for value in self._database.values():
                value.string = search.sub(replacement, value.string)

    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()
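

# Illustrative sketch (not part of the library; the strings are hypothetical):
# typical in-memory Database usage.
#
#   db = Database.from_strings(['The answer is: %s', 'Temperature: %d'])
#
#   # Look up entries for a token and report any hash collisions.
#   token = pw_tokenizer_65599_hash('Temperature: %d')
#   matches = db.token_to_entries.get(token, [])
#   for token, entries in db.collisions():
#       print(f'Token {token:08x} collides: {[str(e) for e in entries]}')
#
#   # Mark strings absent from the latest build as removed, then drop every
#   # entry that has a removal date.
#   db.mark_removed(Database.from_strings(['The answer is: %s']).entries())
#   db.purge()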


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    entries = []
    for line in csv.reader(fd):
        try:
            token_str, date_str, string_literal = line

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            entries.append(
                TokenizedStringEntry(
                    token, string_literal, DEFAULT_DOMAIN, date
                )
            )
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )
    return entries


def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: BinaryIO, entry: TokenizedStringEntry):
    """Write a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.string.replace('"', '""'),  # Escape " as "".
        ).encode()
    )
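

# For reference, each CSV line written above has the form (the token values
# here are made up for illustration):
#
#   01234567,2020-01-01,"A removed string with an arg: %d"
#   89abcdef,          ,"A string that is still present"
#
# The columns are the token as eight hex digits, the removal date (blank,
# padded to 10 characters, if the entry is still present), and the quoted
# string with " escaped as "".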


class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff
            # for the day/month/year. That ensures that still-present tokens
            # appear as the newest tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)
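

# Illustrative round-trip sketch (not part of the library): the binary format
# can be exercised in memory with io.BytesIO.
#
#   db = Database.from_strings(['Hello, %s'])
#   buffer = io.BytesIO()
#   write_binary(db, buffer)
#   buffer.seek(0)
#   assert [e.string for e in parse_binary(buffer)] == ['Hello, %s']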


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which the database was created, in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile of the appropriate type for the given path."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries, then exports.

        Entries that were added previously but do not appear in the provided
        entries are treated as temporary and removed, so that only recurring
        entries are kept in memory and written to disk in the original format.
        """


class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten.
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten.
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )
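

# Illustrative sketch (the path below is hypothetical): DatabaseFile.load()
# detects the database type and returns the matching subclass.
#
#   db = DatabaseFile.load(Path('tokens.csv'))
#   db.add(Database.from_strings(['New string: %u']).entries())
#   db.write_to_file()  # Rewrites tokens.csv in its original (CSV) format.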


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file.
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if
          any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create
        # a new CSV to use. Otherwise, check if a CSV was added in the top
        # commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not in the directory."""
        # Generate a new random name if the file already exists on disk.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())
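

# End-to-end sketch (illustrative; the directory name and strings are
# hypothetical): a directory database stores newly added tokens as separate
# CSV files inside the directory.
#
#   db = DatabaseFile.load(Path('token_database_dir/'))
#   db.add_and_discard_temporary(
#       Database.from_strings(['Battery: %d%%']).entries(), 'origin/main'
#   )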