# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
from typing import (BinaryIO, Callable, Dict, Iterable, Iterator, List,
                    NamedTuple, Optional, Pattern, Tuple, Union, ValuesView)

DATE_FORMAT = '%Y-%m-%d'
DEFAULT_DOMAIN = ''

# The default hash length to use. This value only applies when hashing strings
# from a legacy-style ELF with plain strings. New tokenized string entries
# include the token alongside the string.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: Union[int, str]) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_fixed_length_hash(string: Union[str, bytes],
                                         hash_length: int) -> int:
    """Hashes the provided string.

    This hash function is only used when adding tokens from legacy-style
    tokenized strings in an ELF, which do not include the token.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def default_hash(string: Union[str, bytes]) -> int:
    return pw_tokenizer_65599_fixed_length_hash(string, DEFAULT_C_HASH_LENGTH)


class _EntryKey(NamedTuple):
    """Uniquely refers to an entry."""
    token: int
    string: str


@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """A tokenized string with its metadata."""
    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: Optional[datetime] = None

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return _EntryKey(self.token, self.string)

    def update_date_removed(self,
                            new_date_removed: Optional[datetime]) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Sorts the entry by token, date removed, then string."""
        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (self.date_removed
                                                           or datetime.max)

        return self.string < other.string

    def __str__(self) -> str:
        return self.string

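
# Illustrative sketch (the format string below is made up, not taken from a
# real database): computing a token with the default fixed-length hash and
# wrapping it in an entry.
#
#   token = default_hash('Temperature is %d degrees')
#   entry = TokenizedStringEntry(token, 'Temperature is %d degrees')
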
class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""
    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: Dict[_EntryKey, TokenizedStringEntry] = {
            entry.key(): entry
            for entry in entries
        }

        # This is a cache for fast token lookup that is built as needed.
        self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None

    @classmethod
    def from_strings(
            cls,
            strings: Iterable[str],
            domain: str = DEFAULT_DOMAIN,
            tokenize: Callable[[str], int] = default_hash) -> 'Database':
        """Creates a Database from an iterable of strings."""
        return cls((TokenizedStringEntry(tokenize(string), string, domain)
                    for string in strings))

    @classmethod
    def merged(cls, *databases: 'Database') -> 'Database':
        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache on demand.
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns an iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[Tuple[int, List[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
            self,
            all_entries: Iterable[TokenizedStringEntry],
            removal_date: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries that are missing from the
        database are NOT added; call the add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if (entry.key() not in all_keys
                    and (entry.date_removed is None
                         or removal_date < entry.date_removed)):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries."""
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain
                entry.date_removed = None
            except KeyError:
                self._database[new_entry.key()] = TokenizedStringEntry(
                    new_entry.token, new_entry.string, new_entry.domain)

    def purge(
            self,
            date_removed_cutoff: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry for _, entry in self._database.items()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: 'Database') -> None:
        """Merges two or more databases together, keeping the newest dates."""
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
            self,
            include: Iterable[Union[str, Pattern[str]]] = (),
            exclude: Iterable[Union[str, Pattern[str]]] = (),
            replace: Iterable[Tuple[Union[str, Pattern[str]], str]] = ()
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        to_delete: List[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re))

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(key for key, val in self._database.items() if any(
                rgx.search(val.string) for rgx in exclude_re))

        for key in to_delete:
            del self._database[key]

        for search, replacement in replace:
            search = re.compile(search)

            for value in self._database.values():
                value.string = search.sub(replacement, value.string)

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()

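
# Illustrative sketch of typical Database usage; the strings and dates below
# are made up for the example.
#
#   db = Database.from_strings(['Battery low: %d%%', 'Reboot in %u s'])
#   current = Database.from_strings(['Reboot in %u s'])
#   db.mark_removed(current.entries())       # Date-stamp the missing entry.
#   db.purge(datetime(2020, 1, 1))           # Drop entries removed on/before 2020-01-01.
#   for token, entries in db.collisions():   # Inspect any hash collisions.
#       print(f'{token:08x} has {len(entries)} entries')
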
def parse_csv(fd) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    for line in csv.reader(fd):
        try:
            token_str, date_str, string_literal = line

            token = int(token_str, 16)
            date = (datetime.strptime(date_str, DATE_FORMAT)
                    if date_str.strip() else None)

            yield TokenizedStringEntry(token, string_literal, DEFAULT_DOMAIN,
                                       date)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error('Failed to parse tokenized string entry %s: %s', line,
                       err)


def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        # Align the CSV output to 10-character columns for improved
        # readability. Use \n instead of RFC 4180's \r\n.
        fd.write('{:08x},{:10},"{}"\n'.format(
            entry.token,
            entry.date_removed.strftime(DATE_FORMAT)
            if entry.date_removed else '',
            entry.string.replace('"', '""')).encode())  # escape " as ""

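
# Example of the CSV row layout produced by write_csv (tokens and strings are
# illustrative values, not from a real database):
#
#   00a9b0e2,2020-01-01,"Battery low: %d%%"
#   1e5d24fc,          ,"Reboot in %u s"
#
# The first column is the token in hexadecimal, the second is the removal date
# (padded to 10 characters; blank for still-present entries), and the third is
# the quoted string with " escaped as "".
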
class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)')

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err

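
# Sketch of the binary file layout implied by BINARY_FORMAT (little-endian
# packing, per the '<' prefix of the struct format strings):
#
#   offset 0:  8-byte magic b'TOKENS\0\0'
#   offset 8:  uint32 entry count, then 4 bytes of padding ('<8sI4x')
#   offset 16: entry_count records of '<IBBH' = token, day, month, year
#   then:      the string table, one NUL-terminated string per entry
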
def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size))

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}')

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size))

        try:
            date_removed: Optional[datetime] = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up
    # NUL-terminated strings by their offset in the table.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff
            # for the day/month/year. That ensures that still-present tokens
            # appear as the newest tokens when sorted by removal date.
            removed_day = 0xff
            removed_month = 0xff
            removed_year = 0xffff

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(entry.token, removed_day, removed_month,
                                     removed_year))

    fd.write(string_table)

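
# Illustrative round trip through the binary format, assuming `db` is an
# existing Database and the path is hypothetical:
#
#   with open('tokens.bin', 'wb') as fd:
#       write_binary(db, fd)
#   with open('tokens.bin', 'rb') as fd:
#       db_copy = Database(parse_binary(fd))
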
class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which the database was created, in the correct format (CSV or binary).
    """
    def __init__(self, path: Union[Path, str]):
        self.path = Path(path)

        # Read the path as a packed binary file.
        with self.path.open('rb') as fd:
            if file_is_binary_database(fd):
                super().__init__(parse_binary(fd))
                self._export = write_binary
                return

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(self.path)
        with self.path.open('r', newline='', encoding='utf-8') as file:
            super().__init__(parse_csv(file))
            self._export = write_csv

    def write_to_file(self, path: Optional[Union[Path, str]] = None) -> None:
        """Exports in the original format to the original or provided path."""
        with open(self.path if path is None else path, 'wb') as fd:
            self._export(self, fd)
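
# Illustrative DatabaseFile usage (the path and string are hypothetical); the
# file's original format, CSV or binary, is preserved on write:
#
#   db = DatabaseFile('tokens.csv')
#   db.add(Database.from_strings(['New string: %d']).entries())
#   db.write_to_file()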