# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    BinaryIO,
    Callable,
    Iterable,
    Iterator,
    NamedTuple,
    Pattern,
    TextIO,
    ValuesView,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
63    """Hashes the string with the hash function used to generate tokens in C++.
64
65    This hash function is used calculate tokens from strings in Python. It is
66    not used when extracting tokens from an ELF, since the token is stored in
67    the ELF as part of tokenization.
68    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


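# Note: a string's token is simply its hash as computed above. Two hand-worked
# examples (with the default, unbounded hash length):
#
#   pw_tokenizer_65599_hash('') == 0  # the hash starts at len(''), i.e. 0
#   pw_tokenizer_65599_hash('A') == 1 + 65599 * ord('A')  # == 4263936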
def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)


class _EntryKey(NamedTuple):
    """Uniquely refers to an entry."""

    token: int
    string: str


@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: datetime | None = None

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return _EntryKey(self.token, self.string)

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Sorts the entry by token, date removed, then string."""
        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string


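# Note: entries are keyed on the (token, string) pair, so different strings
# that hash to the same token can coexist in a database. Database.collisions()
# (below) reports tokens that map to more than one string.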
class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # This is a cache for fast token lookup that is built as needed.
        self._cache: dict[int, list[TokenizedStringEntry]] | None = None

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            (
                TokenizedStringEntry(tokenize(string), string, domain)
                for string in strings
            )
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> dict[int, list[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache.
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[tuple[int, list[TokenizedStringEntry]]]:
        """Yields a (token, entries) tuple for each colliding token."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries that are missing from the
        database are NOT added; call the add function to add them.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                # Make a copy to avoid unintentionally updating the database.
                self._database[new_entry.key()] = TokenizedStringEntry(
                    **vars(new_entry)
                )

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for _, entry in self._database.items()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: Database) -> None:
        """Merges two or more databases together, keeping the newest dates."""
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        to_delete: list[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for key in to_delete:
            del self._database[key]

        for search, replacement in replace:
            search = re.compile(search)

            for value in self._database.values():
                value.string = search.sub(replacement, value.string)

    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()


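# A typical round trip through the CSV format (illustrative sketch; the output
# path is hypothetical):
#
#   db = Database.from_strings(['The answer is: %s', '%d errors'])
#   with open('tokens.csv', 'wb') as fd:
#       write_csv(db, fd)
#
# parse_csv() below reads such a file back into TokenizedStringEntry objects.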
def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    entries = []
    for line in csv.reader(fd):
        try:
            token_str, date_str, string_literal = line

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            entries.append(
                TokenizedStringEntry(
                    token, string_literal, DEFAULT_DOMAIN, date
                )
            )
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )
    return entries


def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: BinaryIO, entry: TokenizedStringEntry):
    """Write a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.string.replace('"', '""'),
        ).encode()
    )  # escape " as ""

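# Each line written has the form <8-hex-digit token>,<removal date or blank>,
# "<string>", with any double quotes in the string escaped by doubling them.
# A hypothetical entry with token 0x00000001 removed on 2020-01-01 would be
# written as:
#
#   00000001,2020-01-01,"Example with ""quotes"""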

class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()
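# Layout implied by the structs above: a 16-byte header (the 8-byte magic, a
# uint32 entry count, and 4 bytes of padding), followed by one 8-byte record
# per entry (uint32 token, removal day and month bytes, uint16 year), followed
# by the entries' NUL-terminated strings, in the same order as the entries.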


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff for
            # the day/month/year. That ensures that still-present tokens appear
            # as the newest tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which the database was created, in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile of the appropriate type for the path."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards temporary entries and adds new entries to the database.

        Removes temporary entries from the Database before adding the provided
        entries, so that only recurring entries are written to memory and to
        disk in the original format.
        """

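# Loading and updating a database file looks roughly like this (illustrative
# sketch; the path is hypothetical and new_entries stands for any iterable of
# TokenizedStringEntry objects):
#
#   db = DatabaseFile.load(Path('tokens.csv'))
#   db.add(new_entries)
#   db.write_to_file()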

class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX
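# Each CSV in a directory database is named '<uuid4().hex>' + DIR_DB_SUFFIX
# (see _DirectoryDatabase._create_filename() below). All files matching
# DIR_DB_GLOB in the directory are merged when the database is loaded.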


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create a
        # new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not in the directory."""
        # Generate names until one doesn't collide with an existing file.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed to
          this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV, but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())
743