# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from abc import abstractmethod
import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    BinaryIO,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Pattern,
    TextIO,
    Tuple,
    Union,
    ValuesView,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: Union[int, str]) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: Union[str, bytes], *, hash_length: Optional[int] = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: Union[str, bytes], hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)


class _EntryKey(NamedTuple):
    """Uniquely refers to an entry."""

    token: int
    string: str


@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: Optional[datetime] = None

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return _EntryKey(self.token, self.string)

    def update_date_removed(self, new_date_removed: Optional[datetime]) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
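        # (If this entry has no removal date, it is still present, so there is
        # nothing newer to record.)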
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Sorts the entry by token, date removed, then string."""
        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or
        # still present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: Dict[_EntryKey, TokenizedStringEntry] = {}

        # This is a cache for fast token lookup that is built as needed.
        self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> 'Database':
        """Creates a Database from an iterable of strings."""
        return cls(
            (
                TokenizedStringEntry(tokenize(string), string, domain)
                for string in strings
            )
        )

    @classmethod
    def merged(cls, *databases: 'Database') -> 'Database':
        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache on demand.
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns an iterable of all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[Tuple[int, List[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: Optional[datetime] = None,
    ) -> List[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries missing from the database are NOT
        added; call the add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                # Make a copy to avoid unintentionally updating the database.
                self._database[new_entry.key()] = TokenizedStringEntry(
                    **vars(new_entry)
                )

    def purge(
        self, date_removed_cutoff: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for _, entry in self._database.items()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: 'Database') -> None:
        """Merges two or more databases together, keeping the newest dates."""
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
        self,
        include: Iterable[Union[str, Pattern[str]]] = (),
        exclude: Iterable[Union[str, Pattern[str]]] = (),
        replace: Iterable[Tuple[Union[str, Pattern[str]], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        to_delete: List[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for key in to_delete:
            del self._database[key]

        for search, replacement in replace:
            search = re.compile(search)

            for value in self._database.values():
                value.string = search.sub(replacement, value.string)

    def difference(self, other: 'Database') -> 'Database':
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    entries = []
    for line in csv.reader(fd):
        try:
            token_str, date_str, string_literal = line

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            entries.append(
                TokenizedStringEntry(
                    token, string_literal, DEFAULT_DOMAIN, date
                )
            )
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )
    return entries


def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: BinaryIO, entry: TokenizedStringEntry):
    """Writes a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
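    # A resulting line looks like (illustrative values):
    #   0012abcd,2022-01-01,"Hello, ""world""!"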
    fd.write(
        '{:08x},{:10},"{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.string.replace('"', '""'),  # Escape " as "".
        ).encode()
    )


class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: Optional[datetime] = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up
    # strings by offset.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff
            # for the day/month/year.
            # That ensures that still-present tokens appear as the newest
            # tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which the database was created, in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> 'DatabaseFile':
        """Creates a DatabaseFile of the type that matches the file."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards and adds entries to export in the original format.

        Adds entries after removing temporary entries from the Database, so
        that only re-occurring entries are kept in memory and written to disk.
        """


class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten.
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO(b/241471465): Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten.
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO(b/241471465): Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


# The suffix used for CSV files in a directory database.
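# Files in a directory database are named <uuid4 hex>.pw_tokenizer.csv, e.g.
# 0123456789abcdef0123456789abcdef.pw_tokenizer.csv (see _create_filename).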
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
_DIR_DB_GLOB = '*' + DIR_DB_SUFFIX


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(_DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file.
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(_DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to
            # a new CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: List) -> List[Path]:
        """Returns files from a Git command, filtered to the DB glob."""
        try:
            output = subprocess.run(
                ['git', *commands, _DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if
          any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create
        # a new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not present in the directory."""
        # A random (uuid4) name should not match any tracked or untracked file.
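        # Regenerate the name in the unlikely event of an on-disk collision.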
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())
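

# A minimal usage sketch (illustrative only; the file name below is
# hypothetical and not part of this module):
#
#   db = Database.from_strings(['Hello, %s!', 'Battery: %d%%'])
#   with open('example_tokens.csv', 'wb') as fd:
#       write_csv(db, fd)
#
#   loaded = DatabaseFile.load(Path('example_tokens.csv'))
#   for token, entries in loaded.token_to_entries.items():
#       print(f'{token:08x}: {entries}')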