# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Builds and manages databases of tokenized strings."""

from __future__ import annotations

from abc import abstractmethod
import bisect
from collections.abc import (
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Sequence,
    ValuesView,
)
import csv
from dataclasses import dataclass
from datetime import datetime
import io
import logging
from pathlib import Path
import re
import struct
import subprocess
from typing import (
    Any,
    BinaryIO,
    IO,
    NamedTuple,
    overload,
    Pattern,
    TextIO,
    TypeVar,
)
from uuid import uuid4

DEFAULT_DOMAIN = ''

# The default hash length to use for C-style hashes. This value only applies
# when manually hashing strings to recreate token calculations in C. The C++
# hash function does not have a maximum length.
#
# This MUST match the default value of PW_TOKENIZER_CFG_C_HASH_LENGTH in
# pw_tokenizer/public/pw_tokenizer/config.h.
DEFAULT_C_HASH_LENGTH = 128

TOKENIZER_HASH_CONSTANT = 65599

_LOG = logging.getLogger('pw_tokenizer')


def _value(char: int | str) -> int:
    return char if isinstance(char, int) else ord(char)


def pw_tokenizer_65599_hash(
    string: str | bytes, *, hash_length: int | None = None
) -> int:
    """Hashes the string with the hash function used to generate tokens in C++.

    This hash function is used to calculate tokens from strings in Python. It
    is not used when extracting tokens from an ELF, since the token is stored
    in the ELF as part of tokenization.
    """
    hash_value = len(string)
    coefficient = TOKENIZER_HASH_CONSTANT

    for char in string[:hash_length]:
        hash_value = (hash_value + coefficient * _value(char)) % 2**32
        coefficient = (coefficient * TOKENIZER_HASH_CONSTANT) % 2**32

    return hash_value


def c_hash(
    string: str | bytes, hash_length: int = DEFAULT_C_HASH_LENGTH
) -> int:
    """Hashes the string with the hash function used in C."""
    return pw_tokenizer_65599_hash(string, hash_length=hash_length)

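# Illustrative use of the hash functions above (a sketch, not part of the
# module API; the format string is hypothetical):
#
#   token = pw_tokenizer_65599_hash('The answer is: %s')
#   c_token = c_hash('The answer is: %s')  # Hashes at most 128 characters.
#
# Note that hashing the empty string always yields 0: hash_value starts at
# len(string) and the loop body never executes.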
@dataclass(frozen=True, eq=True, order=True)
class _EntryKey:
    """Uniquely refers to an entry."""

    domain: str
    token: int
    string: str


class TokenizedStringEntry:
    """A tokenized string with its metadata."""

    def __init__(
        self,
        token: int,
        string: str,
        domain: str = DEFAULT_DOMAIN,
        date_removed: datetime | None = None,
    ) -> None:
        self._key = _EntryKey(
            ''.join(domain.split()),
            token,
            string,
        )
        self.date_removed = date_removed

    @property
    def token(self) -> int:
        return self._key.token

    @property
    def string(self) -> str:
        return self._key.string

    @property
    def domain(self) -> str:
        return self._key.domain

    def key(self) -> _EntryKey:
        """The key determines uniqueness for a tokenized string."""
        return self._key

    def update_date_removed(self, new_date_removed: datetime | None) -> None:
        """Sets self.date_removed if the other date is newer."""
        # No removal date (None) is treated as the newest date.
        if self.date_removed is None:
            return

        if new_date_removed is None or new_date_removed > self.date_removed:
            self.date_removed = new_date_removed

    def __eq__(self, other: Any) -> bool:
        return (
            self.key() == other.key()
            and self.date_removed == other.date_removed
        )

    def __lt__(self, other: Any) -> bool:
        """Sorts the entry by domain, token, date removed, then string."""
        if self.domain != other.domain:
            return self.domain < other.domain

        if self.token != other.token:
            return self.token < other.token

        # Sort removal dates in reverse, so the most recently removed (or still
        # present) entry appears first.
        if self.date_removed != other.date_removed:
            return (other.date_removed or datetime.max) < (
                self.date_removed or datetime.max
            )

        return self.string < other.string

    def __str__(self) -> str:
        return self.string

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(token=0x{self.token:08x}, '
            f'string={self.string!r}, domain={self.domain!r})'
        )

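# Illustrative sketch of constructing an entry with the class above (the token
# and string values are hypothetical). Note that __init__ strips all
# whitespace from the domain via ''.join(domain.split()).
#
#   entry = TokenizedStringEntry(0x12345678, 'The answer is: %s', 'my domain')
#   entry.domain  # 'mydomain'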
_TokenToEntries = dict[int, list[TokenizedStringEntry]]
_K = TypeVar('_K')
_V = TypeVar('_V')
_T = TypeVar('_T')


class _TokenDatabaseView(Mapping[_K, _V]):  # pylint: disable=abstract-method
    """Read-only mapping view of a token database.

    Behaves like a read-only version of defaultdict(list).
    """

    def __init__(self, mapping: Mapping[_K, Any]) -> None:
        self._mapping = mapping

    def __contains__(self, key: object) -> bool:
        return key in self._mapping

    @overload
    def get(self, key: _K) -> _V | None:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: _K, default: _T) -> _V | _T:  # pylint: disable=W0222
        ...

    def get(self, key: _K, default: _T | None = None) -> _V | _T | None:
        return self._mapping.get(key, default)

    def __iter__(self) -> Iterator[_K]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)

    def __str__(self) -> str:
        return str(self._mapping)

    def __repr__(self) -> str:
        return repr(self._mapping)


class _TokenMapping(_TokenDatabaseView[int, Sequence[TokenizedStringEntry]]):
    def __getitem__(self, token: int) -> Sequence[TokenizedStringEntry]:
        """Returns entries that match the specified token; may be empty."""
        return self._mapping.get(token, ())  # Empty sequence if no match


class _DomainTokenMapping(_TokenDatabaseView[str, _TokenMapping]):
    def __getitem__(self, domain: str) -> _TokenMapping:
        """Returns the token-to-entries mapping for the specified domain."""
        return _TokenMapping(self._mapping.get(domain, {}))  # Empty if no match


def _add_entry(entries: _TokenToEntries, entry: TokenizedStringEntry) -> None:
    bisect.insort(
        entries.setdefault(entry.token, []),
        entry,
        key=TokenizedStringEntry.key,  # Keep lists of entries sorted by key.
    )


class Database:
    """Database of tokenized strings stored as TokenizedStringEntry objects."""

    def __init__(self, entries: Iterable[TokenizedStringEntry] = ()):
        """Creates a token database."""
        # The database dict stores each unique (token, string) entry.
        self._database: dict[_EntryKey, TokenizedStringEntry] = {}

        # Index by token and domain
        self._token_entries: _TokenToEntries = {}
        self._domain_token_entries: dict[str, _TokenToEntries] = {}

        self.add(entries)

    @classmethod
    def from_strings(
        cls,
        strings: Iterable[str],
        domain: str = DEFAULT_DOMAIN,
        tokenize: Callable[[str], int] = pw_tokenizer_65599_hash,
    ) -> Database:
        """Creates a Database from an iterable of strings."""
        return cls(
            TokenizedStringEntry(tokenize(string), string, domain)
            for string in strings
        )

    @classmethod
    def merged(cls, *databases: Database) -> Database:
        """Creates a Database from one or more other databases."""
        db = cls()
        db.merge(*databases)
        return db

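    # Illustrative sketch for the classmethods above (hypothetical strings):
    #
    #   db = Database.from_strings(['The answer: %d', 'Hello, %s'])
    #   combined = Database.merged(db, Database())
    #   len(combined)  # 2; entries are keyed by (domain, token, string).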
    @property
    def token_to_entries(self) -> Mapping[int, Sequence[TokenizedStringEntry]]:
        """Returns a mapping of tokens to a sequence of TokenizedStringEntry.

        Returns token database entries from all domains.
        """
        return _TokenMapping(self._token_entries)

    @property
    def domains(
        self,
    ) -> Mapping[str, Mapping[int, Sequence[TokenizedStringEntry]]]:
        """Returns a mapping of domains to tokens to a sequence of entries.

        `database.domains[domain][token]` returns a sequence of entries that
        match the token in the domain, or an empty sequence if there are no
        matches.
        """
        return _DomainTokenMapping(self._domain_token_entries)

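    # Illustrative lookups with the properties above; 0x12345678 is an
    # arbitrary example token.
    #
    #   db.token_to_entries[0x12345678]      # () if the token is not present.
    #   db.domains['my_domain'][0x12345678]  # Also () if there is no match.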
    def entries(self) -> ValuesView[TokenizedStringEntry]:
        """Returns iterable over all TokenizedStringEntries in the database."""
        return self._database.values()

    def collisions(
        self,
    ) -> Iterator[tuple[int, Sequence[TokenizedStringEntry]]]:
        """Yields (token, entries) tuples for all colliding tokens."""
        for token, entries in self.token_to_entries.items():
            if len(entries) > 1:
                yield token, entries

    def mark_removed(
        self,
        all_entries: Iterable[TokenizedStringEntry],
        removal_date: datetime | None = None,
    ) -> list[TokenizedStringEntry]:
        """Marks entries missing from all_entries as having been removed.

        The entries are assumed to represent the complete set of entries for
        the database. Entries currently in the database that are not present
        in the provided entries are marked with a removal date but remain in
        the database. Entries in all_entries missing from the database are
        NOT added; call the add function to add these.

        Args:
          all_entries: the complete set of entries present in the database
          removal_date: the datetime for removed entries; today by default

        Returns:
          A list of entries marked as removed.
        """
        if removal_date is None:
            removal_date = datetime.now()

        all_keys = frozenset(entry.key() for entry in all_entries)

        removed = []

        for entry in self._database.values():
            if entry.key() not in all_keys and (
                entry.date_removed is None or removal_date < entry.date_removed
            ):
                # Add a removal date, or update it to the oldest date.
                entry.date_removed = removal_date
                removed.append(entry)

        return removed

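    # Illustrative update cycle using mark_removed() with purge() (defined
    # below); current_entries is a hypothetical iterable of entries from the
    # latest build.
    #
    #   db.mark_removed(current_entries)
    #   db.purge(datetime(2020, 1, 1))  # Drop entries removed on/before this.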
    def add(self, entries: Iterable[TokenizedStringEntry]) -> None:
        """Adds new entries and updates date_removed for existing entries.

        If the added tokens have removal dates, the newest date is used.
        """
        for new_entry in entries:
            # Update an existing entry or create a new one.
            try:
                entry = self._database[new_entry.key()]

                # Keep the latest removal date between the two entries.
                if new_entry.date_removed is None:
                    entry.date_removed = None
                elif (
                    entry.date_removed
                    and entry.date_removed < new_entry.date_removed
                ):
                    entry.date_removed = new_entry.date_removed
            except KeyError:
                self._add_new_entry(new_entry)

    def purge(
        self, date_removed_cutoff: datetime | None = None
    ) -> list[TokenizedStringEntry]:
        """Removes and returns entries removed on/before date_removed_cutoff."""
        if date_removed_cutoff is None:
            date_removed_cutoff = datetime.max

        to_delete = [
            entry
            for entry in self._database.values()
            if entry.date_removed and entry.date_removed <= date_removed_cutoff
        ]

        for entry in to_delete:
            self._delete_entry(entry)

        return to_delete

    def merge(self, *databases: Database) -> None:
        """Merges two or more databases together.

        All entries are kept if there are token collisions.

        If there are two identical tokens (same domain, token, and string), the
        newest removal date (or no date if either token has no removal date) is
        used for the merged token.
        """
        for other_db in databases:
            for entry in other_db.entries():
                key = entry.key()

                if key in self._database:
                    self._database[key].update_date_removed(entry.date_removed)
                else:
                    self._add_new_entry(entry)

    def filter(
        self,
        include: Iterable[str | Pattern[str]] = (),
        exclude: Iterable[str | Pattern[str]] = (),
        replace: Iterable[tuple[str | Pattern[str], str]] = (),
    ) -> None:
        """Filters the database using regular expressions (strings or compiled).

        Args:
          include: regexes; only entries matching at least one are kept
          exclude: regexes; entries matching any of these are removed
          replace: (regex, str) tuples; replaces matching terms in all entries
        """
        to_delete: list[TokenizedStringEntry] = []

        if include:
            include_re = [re.compile(pattern) for pattern in include]
            to_delete.extend(
                val
                for val in self._database.values()
                if not any(rgx.search(val.string) for rgx in include_re)
            )

        if exclude:
            exclude_re = [re.compile(pattern) for pattern in exclude]
            to_delete.extend(
                val
                for val in self._database.values()
                if any(rgx.search(val.string) for rgx in exclude_re)
            )

        # Deduplicate by key: an entry that both fails the include filters and
        # matches an exclude filter would otherwise be deleted twice, raising
        # KeyError on the second deletion.
        for entry in {e.key(): e for e in to_delete}.values():
            self._delete_entry(entry)

        # Do the replacement after removing entries.
        for search, replacement in replace:
            search = re.compile(search)

            to_replace: list[TokenizedStringEntry] = []
            add: list[TokenizedStringEntry] = []

            for entry in self._database.values():
                new_string = search.sub(replacement, entry.string)
                if new_string != entry.string:
                    to_replace.append(entry)
                    add.append(
                        TokenizedStringEntry(
                            entry.token,
                            new_string,
                            entry.domain,
                            entry.date_removed,
                        )
                    )

            for entry in to_replace:
                self._delete_entry(entry)
            self.add(add)

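    # Illustrative filter() call (hypothetical patterns):
    #
    #   db.filter(
    #       include=[r'^LOG_'],               # Keep only entries matching this.
    #       exclude=[r'debug'],               # Drop entries containing 'debug'.
    #       replace=[(r'supersecret', '?')],  # Rewrite matching substrings.
    #   )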
    def difference(self, other: Database) -> Database:
        """Returns a new Database with entries in this DB not in the other."""
        # pylint: disable=protected-access
        return Database(
            e for k, e in self._database.items() if k not in other._database
        )
        # pylint: enable=protected-access

    def _add_new_entry(self, new_entry: TokenizedStringEntry) -> None:
        entry = TokenizedStringEntry(  # These are mutable, so make a copy.
            new_entry.token,
            new_entry.string,
            new_entry.domain,
            new_entry.date_removed,
        )
        self._database[entry.key()] = entry
        _add_entry(self._token_entries, entry)
        _add_entry(
            self._domain_token_entries.setdefault(entry.domain, {}), entry
        )

    def _delete_entry(self, entry: TokenizedStringEntry) -> None:
        del self._database[entry.key()]

        # Remove from the token / domain mappings and clean up empty lists.
        self._token_entries[entry.token].remove(entry)
        if not self._token_entries[entry.token]:
            del self._token_entries[entry.token]

        self._domain_token_entries[entry.domain][entry.token].remove(entry)
        if not self._domain_token_entries[entry.domain][entry.token]:
            del self._domain_token_entries[entry.domain][entry.token]
            if not self._domain_token_entries[entry.domain]:
                del self._domain_token_entries[entry.domain]

    def __len__(self) -> int:
        """Returns the number of entries in the database."""
        return len(self.entries())

    def __bool__(self) -> bool:
        """True if the database is non-empty."""
        return bool(self._database)

    def __str__(self) -> str:
        """Outputs the database as CSV."""
        csv_output = io.BytesIO()
        write_csv(self, csv_output)
        return csv_output.getvalue().decode()


def parse_csv(fd: TextIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a CSV token database file."""
    for line in csv.reader(fd):
        try:
            try:
                token_str, date_str, domain, string_literal = line
            except ValueError:
                # If there are only three columns, use the default domain.
                token_str, date_str, string_literal = line
                domain = DEFAULT_DOMAIN

            token = int(token_str, 16)
            date = (
                datetime.fromisoformat(date_str) if date_str.strip() else None
            )

            yield TokenizedStringEntry(token, string_literal, domain, date)
        except (ValueError, UnicodeDecodeError) as err:
            _LOG.error(
                'Failed to parse tokenized string entry %s: %s', line, err
            )

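# parse_csv() accepts rows with either four columns (token, date, domain,
# string) or three columns (token, date, string); illustrative rows with
# hypothetical values, where a blank date means the entry is still present:
#
#   00001234,2020-01-01,"my_domain","The answer is: %s"
#   00001234,,"The answer is: %s"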
def write_csv(database: Database, fd: IO[bytes]) -> None:
    """Writes the database as CSV to the provided binary file."""
    for entry in sorted(database.entries()):
        _write_csv_line(fd, entry)


def _write_csv_line(fd: IO[bytes], entry: TokenizedStringEntry):
    """Writes a line in CSV format to the provided binary file."""
    # Align the CSV output to 10-character columns for improved readability.
    # Use \n instead of RFC 4180's \r\n.
    fd.write(
        '{:08x},{:10},"{}","{}"\n'.format(
            entry.token,
            entry.date_removed.date().isoformat() if entry.date_removed else '',
            entry.domain.replace('"', '""'),  # escape " as ""
            entry.string.replace('"', '""'),
        ).encode()
    )

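# A line written by _write_csv_line() for a hypothetical entry with no removal
# date looks like the following; the date column is blank-padded to 10
# characters and '"' characters in the domain and string are doubled:
#
#   00001234,          ,"","The answer is: %s"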
class _BinaryFileFormat(NamedTuple):
    """Attributes of the binary token database file format."""

    magic: bytes = b'TOKENS\0\0'
    header: struct.Struct = struct.Struct('<8sI4x')
    entry: struct.Struct = struct.Struct('<IBBH')


BINARY_FORMAT = _BinaryFileFormat()

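# Binary database layout, as defined by the structs above: a 16-byte header
# (the 8-byte magic, a little-endian uint32 entry count, and 4 padding bytes),
# then one 8-byte record per entry (uint32 token, uint8 day, uint8 month,
# uint16 year), then the NUL-terminated strings in entry order.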
class DatabaseFormatError(Exception):
    """Failed to parse a token database file."""


def file_is_binary_database(fd: BinaryIO) -> bool:
    """True if the file starts with the binary token database magic string."""
    try:
        fd.seek(0)
        magic = fd.read(len(BINARY_FORMAT.magic))
        fd.seek(0)
        return BINARY_FORMAT.magic == magic
    except IOError:
        return False


def _check_that_file_is_csv_database(path: Path) -> None:
    """Raises an error unless the path appears to be a CSV token database."""
    try:
        with path.open('rb') as fd:
            data = fd.read(8)  # Read 8 bytes, which should be the first token.

        if not data:
            return  # File is empty, which is valid CSV.

        if len(data) != 8:
            raise DatabaseFormatError(
                f'Attempted to read {path} as a CSV token database, but the '
                f'file is too short ({len(data)} B)'
            )

        # Make sure the first 8 chars are a valid hexadecimal number.
        _ = int(data.decode(), 16)
    except (IOError, UnicodeDecodeError, ValueError) as err:
        raise DatabaseFormatError(
            f'Encountered error while reading {path} as a CSV token database'
        ) from err


def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from a binary token database file."""
    magic, entry_count = BINARY_FORMAT.header.unpack(
        fd.read(BINARY_FORMAT.header.size)
    )

    if magic != BINARY_FORMAT.magic:
        raise DatabaseFormatError(
            f'Binary token database magic number mismatch (found {magic!r}, '
            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}'
        )

    entries = []

    for _ in range(entry_count):
        token, day, month, year = BINARY_FORMAT.entry.unpack(
            fd.read(BINARY_FORMAT.entry.size)
        )

        try:
            date_removed: datetime | None = datetime(year, month, day)
        except ValueError:
            date_removed = None

        entries.append((token, date_removed))

    # Read the entire string table and define a function for looking up strings.
    string_table = fd.read()

    def read_string(start):
        end = string_table.find(b'\0', start)
        return string_table[start:end].decode(), end + 1

    offset = 0
    for token, removed in entries:
        string, offset = read_string(offset)
        yield TokenizedStringEntry(token, string, DEFAULT_DOMAIN, removed)


def write_binary(database: Database, fd: BinaryIO) -> None:
    """Writes the database as packed binary to the provided binary file."""
    entries = sorted(database.entries())

    fd.write(BINARY_FORMAT.header.pack(BINARY_FORMAT.magic, len(entries)))

    string_table = bytearray()

    for entry in entries:
        if entry.date_removed:
            removed_day = entry.date_removed.day
            removed_month = entry.date_removed.month
            removed_year = entry.date_removed.year
        else:
            # If there is no removal date, use the special value 0xffffffff for
            # the day/month/year. That ensures that still-present tokens appear
            # as the newest tokens when sorted by removal date.
            removed_day = 0xFF
            removed_month = 0xFF
            removed_year = 0xFFFF

        string_table += entry.string.encode()
        string_table.append(0)

        fd.write(
            BINARY_FORMAT.entry.pack(
                entry.token, removed_day, removed_month, removed_year
            )
        )

    fd.write(string_table)


class DatabaseFile(Database):
    """A token database that is associated with a particular file.

    This class adds the write_to_file() method that writes to the file from
    which it was created in the correct format (CSV or binary).
    """

    def __init__(
        self, path: Path, entries: Iterable[TokenizedStringEntry]
    ) -> None:
        super().__init__(entries)
        self.path = path

    @staticmethod
    def load(path: Path) -> DatabaseFile:
        """Creates a DatabaseFile that matches the file type."""
        if path.is_dir():
            return _DirectoryDatabase(path)

        # Read the path as a packed binary file.
        with path.open('rb') as fd:
            if file_is_binary_database(fd):
                return _BinaryDatabase(path, fd)

        # Read the path as a CSV file.
        _check_that_file_is_csv_database(path)
        return _CSVDatabase(path)

    @abstractmethod
    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the original format to the original path."""

    @abstractmethod
    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Discards temporary entries and adds new entries, then exports.

        Temporary entries are removed from the Database before the new entries
        are added, so that only recurring entries are written to memory and
        disk in the original format.
        """

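# Illustrative round trip with DatabaseFile (the path and entries are
# hypothetical). load() dispatches on the path: directories become directory
# databases, files starting with the binary magic become binary databases,
# and anything else is validated and parsed as CSV.
#
#   db = DatabaseFile.load(Path('tokens.csv'))
#   db.add(new_entries)  # new_entries: Iterable[TokenizedStringEntry]
#   db.write_to_file()   # Rewrites the file in its original format.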
class _BinaryDatabase(DatabaseFile):
    def __init__(self, path: Path, fd: BinaryIO) -> None:
        super().__init__(path, parse_binary(fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the binary format to the original path."""
        del rewrite  # Binary databases are always rewritten
        with self.path.open('wb') as fd:
            write_binary(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for binary databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


class _CSVDatabase(DatabaseFile):
    def __init__(self, path: Path) -> None:
        with path.open('r', newline='', encoding='utf-8') as csv_fd:
            super().__init__(path, parse_csv(csv_fd))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Exports in the CSV format to the original path."""
        del rewrite  # CSV databases are always rewritten
        with self.path.open('wb') as fd:
            write_csv(self, fd)

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        # TODO: b/241471465 - Implement adding new tokens and removing
        # temporary entries for CSV databases.
        raise NotImplementedError(
            '--discard-temporary is currently only '
            'supported for directory databases'
        )


# The suffix used for CSV files in a directory database.
DIR_DB_SUFFIX = '.pw_tokenizer.csv'
DIR_DB_GLOB = '*' + DIR_DB_SUFFIX

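# Files in a directory database are named with a random 32-character hex UUID
# plus the suffix above, e.g. (hypothetical name):
#
#   0123456789abcdef0123456789abcdef.pw_tokenizer.csv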
def _parse_directory(directory: Path) -> Iterable[TokenizedStringEntry]:
    """Parses TokenizedStringEntries from the CSV files in the directory."""
    for path in directory.glob(DIR_DB_GLOB):
        yield from _CSVDatabase(path).entries()


def _most_recently_modified_file(paths: Iterable[Path]) -> Path:
    return max(paths, key=lambda path: path.stat().st_mtime)


class _DirectoryDatabase(DatabaseFile):
    def __init__(self, directory: Path) -> None:
        super().__init__(directory, _parse_directory(directory))

    def write_to_file(self, *, rewrite: bool = False) -> None:
        """Creates a new CSV file in the directory with any new tokens."""
        if rewrite:
            # Write the entire database to a new CSV file
            new_file = self._create_filename()
            with new_file.open('wb') as fd:
                write_csv(self, fd)

            # Delete all CSV files except for the new CSV with everything.
            for csv_file in self.path.glob(DIR_DB_GLOB):
                if csv_file != new_file:
                    csv_file.unlink()
        else:
            # Reread the tokens from disk and write only the new entries to CSV.
            current_tokens = Database(_parse_directory(self.path))
            new_entries = self.difference(current_tokens)
            if new_entries:
                with self._create_filename().open('wb') as fd:
                    write_csv(new_entries, fd)

    def _git_paths(self, commands: list) -> list[Path]:
        """Returns a list of database CSVs from a Git command."""
        try:
            output = subprocess.run(
                ['git', *commands, DIR_DB_GLOB],
                capture_output=True,
                check=True,
                cwd=self.path,
                text=True,
            ).stdout.strip()
            return [self.path / repo_path for repo_path in output.splitlines()]
        except subprocess.CalledProcessError:
            return []

    def _find_latest_csv(self, commit: str) -> Path:
        """Finds or creates a CSV to which to write new entries.

        - Check for untracked CSVs. Use the most recently modified file, if any.
        - Check for CSVs added in HEAD, if HEAD is not an ancestor of commit.
          Use the most recently modified file, if any.
        - If no untracked or committed files were found, create a new file.
        """

        # Prioritize untracked files in the directory database.
        untracked_changes = self._git_paths(
            ['ls-files', '--others', '--exclude-standard']
        )
        if untracked_changes:
            return _most_recently_modified_file(untracked_changes)

        # Check if HEAD is an ancestor of the base commit. This checks whether
        # the top commit has been merged or not. If it has been merged, create a
        # new CSV to use. Otherwise, check if a CSV was added in the commit.
        head_is_not_merged = (
            subprocess.run(
                ['git', 'merge-base', '--is-ancestor', 'HEAD', commit],
                cwd=self.path,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            ).returncode
            != 0
        )

        if head_is_not_merged:
            # Find CSVs added in the top commit.
            csvs_from_top_commit = self._git_paths(
                [
                    'diff',
                    '--name-only',
                    '--diff-filter=A',
                    '--relative',
                    'HEAD~',
                ]
            )

            if csvs_from_top_commit:
                return _most_recently_modified_file(csvs_from_top_commit)

        return self._create_filename()

    def _create_filename(self) -> Path:
        """Generates a unique filename not in the directory."""
        # Repeat until the generated path does not already exist on disk.
        while (file := self.path / f'{uuid4().hex}{DIR_DB_SUFFIX}').exists():
            pass
        return file

    def add_and_discard_temporary(
        self, entries: Iterable[TokenizedStringEntry], commit: str
    ) -> None:
        """Adds new entries and discards temporary entries on disk.

        - Find the latest CSV in the directory database or create a new one.
        - Delete entries in the latest CSV that are not in the entries passed
          to this function.
        - Add the new entries to this database.
        - Overwrite the latest CSV with only the newly added entries.
        """
        # Find entries not currently in the database.
        added = Database(entries)
        new_entries = added.difference(self)

        csv_path = self._find_latest_csv(commit)
        if csv_path.exists():
            # Load the CSV as a DatabaseFile.
            csv_db = DatabaseFile.load(csv_path)

            # Delete entries added in the CSV, but not added in this function.
            for key in (e.key() for e in csv_db.difference(added).entries()):
                del self._database[key]
                del csv_db._database[key]  # pylint: disable=protected-access

            csv_db.add(new_entries.entries())
            csv_db.write_to_file()
        elif new_entries:  # If the CSV does not exist, write all new tokens.
            with csv_path.open('wb') as fd:
                write_csv(new_entries, fd)

        self.add(new_entries.entries())