# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from abc import abstractmethod
import collections
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    BinaryIO,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Pattern,
    TextIO,
    Tuple,
    Union,
    ValuesView,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: Union[int, str]) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: Union[str, bytes], *, hash_length: Optional[int] = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
71    """
72    hash_value = len(string)
73    coefficient = TOKENIZER_HASH_CONSTANT
74
75    for char in string[:hash_length]:
76        hash_value = (hash_value + coefficient * _value(char)) % 2**32
77        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32
78
79    return hash_value
80
81
82def c_hash(
83    string: Union[str, bytes], hash_length: int = DEFAULT_C_HASH_LENGTH
84) -> int:
85    """Hashes the string with the hash function used in C."""
86    return pw_tokenizer_65599_hash(string, hash_length=hash_length)
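

# Example (illustrative): tokens are calculated in Python with the functions
# above. These values follow directly from the hash definition and are not
# taken from any real database.
#
#   pw_tokenizer_65599_hash('') == 0  # The hash starts at len(string).
#   c_hash('hi') == pw_tokenizer_65599_hash('hi', hash_length=128)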


class _EntryKey(NamedTuple):
    """Uniquely refers to an entry."""

    token: int
    string: str


@dataclass(eq=True, order=False)
class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    token: int
    string: str
    domain: str = DEFAULT_DOMAIN
    date_removed: Optional[datetime] = None

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return _EntryKey(self.token, self.string)

    def update_date_removed(self, new_date_removed: Optional[datetime]) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __lt__(self, other) -> bool:
        """Sorts the entry by token, date removed, then string."""
        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: Dict[_EntryKey, TokenizedStringEntry] = {}

        # This is a cache for fast token lookup that is built as needed.
        self._cache: Optional[Dict[int, List[TokenizedStringEntry]]] = None

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> 'Database':
        """Creates a Database from an iterable of strings."""
        return cls(
            (
                TokenizedStringEntry(tokenize(string), string, domain)
                for string in strings
            )
        )

    @classmethod
    def merged(cls, *databases: 'Database') -> 'Database':
166        """Creates a TokenDatabase from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

    @property
    def token_to_entries(self) -> Dict[int, List[TokenizedStringEntry]]:
        """Returns a dict that maps tokens to a list of TokenizedStringEntry."""
        if self._cache is None:  # Build the token -> entries cache.
            self._cache = collections.defaultdict(list)
            for entry in self._database.values():
                self._cache[entry.token].append(entry)

        return self._cache

    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(self) -> Iterator[Tuple[int, List[TokenizedStringEntry]]]:
186        """Returns tuple of (token, entries_list)) for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: Optional[datetime] = None,
    ) -> List[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for the
        database. Entries currently in the database not present in the provided
        entries are marked with a removal date but remain in the database.
        Entries in all_entries missing from the database are NOT added; call the
        add function to add these.

        Args:
          all_entries: the complete set of strings present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        self._cache = None

        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        self._cache = None

        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]
                entry.domain = new_entry.domain

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                # Make a copy to avoid unintentionally updating the database.
                self._database[new_entry.key()] = TokenizedStringEntry(
                    **vars(new_entry)
                )

    def purge(
        self, date_removed_cutoff: Optional[datetime] = None
    ) -> List[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        self._cache = None

        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for _, entry in self._database.items()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            del self._database[entry.key()]

        return to_delete

    def merge(self, *databases: 'Database') -> None:
        """Merges two or more databases together, keeping the newest dates."""
        self._cache = None

        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._database[key] = entry

    def filter(
        self,
        include: Iterable[Union[str, Pattern[str]]] = (),
        exclude: Iterable[Union[str, Pattern[str]]] = (),
        replace: Iterable[Tuple[Union[str, Pattern[str]], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        self._cache = None

        to_delete: List[_EntryKey] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                key
                for key, val in self._database.items()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        for key in to_delete:
            del self._database[key]

        for search, replacement in replace:
            search = re.compile(search)

            for value in self._database.values():
                value.string = search.sub(replacement, value.string)

    def difference(self, other: 'Database') -> 'Database':
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()
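

# Example (illustrative): building and inspecting a Database in memory. The
# strings here are arbitrary placeholders, not part of any real database.
#
#   db = Database.from_strings(['The answer is: %s', 'Button pressed: %d'])
#   for token, entries in db.collisions():
#       print(f'Token {token:08x} has {len(entries)} colliding strings')
#   csv_text = str(db)  # The database rendered in the CSV format.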


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    entries = []
    for line in csv.reader(fd):
        try:
            token_str, date_str, string_literal = line

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            entries.append(
                TokenizedStringEntry(
                    token, string_literal, DEFAULT_DOMAIN, date
                )
            )
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )
    return entries


def write_csv(database: Database, fd: BinaryIO) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: BinaryIO, entry: TokenizedStringEntry):
386    """Write a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.string.replace('"', '""'),
        ).encode()
    )  # escape " as ""
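

# Example (illustrative): CSV lines as written by _write_csv_line(). The token
# is eight hex digits, the removal date column is blank (padded to 10 spaces)
# for entries still present, and any '"' in the string is doubled:
#
#   00a1b2c3,          ,"Example message: %d"
#   00a1b2c3,2020-01-01,"A removed string"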


class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()
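

# Illustrative layout of the binary database format, as implied by the structs
# above and by parse_binary()/write_binary() below:
#
#   8 bytes    magic string b'TOKENS\0\0'
#   4 bytes    entry count (little-endian uint32), then 4 bytes of padding
#   per entry  uint32 token, uint8 day, uint8 month, uint16 year
#              (0xFF/0xFF/0xFFFF when the entry has no removal date)
#   remainder  the string table: one NUL-terminated string per entry, in the
#              same order as the entries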


class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: Optional[datetime] = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the maximum value (0xFF for the
            # day and month, 0xFFFF for the year). That ensures that
            # still-present tokens appear as the newest tokens when sorted by
            # removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)
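

# Example (illustrative): the binary format round-trips through an in-memory
# buffer using only the functions defined in this module.
#
#   buffer = io.BytesIO()
#   write_binary(Database.from_strings(['hello %d']), buffer)
#   buffer.seek(0)
#   entries = list(parse_binary(buffer))  # one entry for 'hello %d'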


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method, which writes to the file from
    which it was created in the correct format (CSV or binary).
527    """
528
529    def __init__(
530        self, path: Path, entries: Iterable[TokenizedStringEntry]
531    ) -> None:
532        super().__init__(entries)
533        self.path = path
534
535    @staticmethod
536    def load(path: Path) -> 'DatabaseFile':
537        """Creates a DatabaseFile that coincides to the file type."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards and adds entries to export in the original format.

        Adds entries after removing temporary entries from the database so
        that only recurring entries are written to memory and disk.
        """


class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO(b/241471465): Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO(b/241471465): Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
_DIR_DB_GLOB = '*' + DIR_DB_SUFFIX
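
# A directory database is a directory of CSV token databases. Each file has a
# generated UUID name with this suffix (see _create_filename below), e.g.
# '<uuid4 hex>.pw_tokenizer.csv'.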


def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
614    """Parses TokenizedStringEntries tokenizer CSV files in the directory."""
    for path in directory.glob(_DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(_DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: List) -> List[Path]:
648        """Returns a list of files from a Git command, filtered to matc."""
        try:
            output = subprocess.run(
                ['git', *commands, _DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )

        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create a
        # new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not in the directory."""
        # Generate names until one does not collide with an existing file.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed to
          this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV, but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())
746