#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the tokens module."""

from datetime import datetime
import io
import logging
from pathlib import Path
import shutil
import tempfile
from typing import Iterator
import unittest

from pw_tokenizer import tokens
from pw_tokenizer.tokens import c_hash, DIR_DB_SUFFIX, _LOG

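# A minimal usage sketch of the APIs exercised in these tests: build a
# Database from strings (tokenized with c_hash), then look entries up by
# token via token_to_entries.
#
#   db = tokens.Database.from_strings(['Hello %s'])
#   entry = db.token_to_entries[c_hash('Hello %s')][0]
#   assert entry.string == 'Hello %s'
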
CSV_DATABASE = '''\
00000000,2019-06-10,""
141c35d5,          ,"The answer: ""%s"""
2db1515f,          ,"%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c"
2e668cd6,2019-06-11,"Jello, world!"
31631781,          ,"%d"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
ad002c97,          ,"%llx"
b3653e13,2019-06-12,"Jello!"
b912567b,          ,"%x%lld%1.2f%s"
cc6d3131,2020-01-01,"Jello?"
e13b0f94,          ,"%llu"
e65aefef,2019-06-10,"Won't fit : %s%d"
'''
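
# Each CSV row is: an 8-hex-digit token (zero-padded; see test_csv_formatting),
# an optional date on which the string was removed (blank when still present),
# and the string itself, double-quoted with embedded quotes doubled.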

# The date 2019-06-10 is 07E3-06-0A in hex. In database order, it's 0A 06 E3 07.
BINARY_DATABASE = (
    b'TOKENS\x00\x00\x10\x00\x00\x00\0\0\0\0'  # header (0x10 entries)
    b'\x00\x00\x00\x00\x0a\x06\xe3\x07'  # 0x01
    b'\xd5\x35\x1c\x14\xff\xff\xff\xff'  # 0x02
    b'\x5f\x51\xb1\x2d\xff\xff\xff\xff'  # 0x03
    b'\xd6\x8c\x66\x2e\x0b\x06\xe3\x07'  # 0x04
    b'\x81\x17\x63\x31\xff\xff\xff\xff'  # 0x05
    b'\x26\x1e\xfd\x61\xff\xff\xff\xff'  # 0x06
    b'\xda\x92\xab\x68\xff\xff\xff\xff'  # 0x07
    b'\x2a\x0e\x94\x7b\xff\xff\xff\xff'  # 0x08
    b'\xb6\xee\x1b\x85\xff\xff\xff\xff'  # 0x09
    b'\xa0\x36\x14\x88\xff\xff\xff\xff'  # 0x0a
    b'\x97\x2c\x00\xad\xff\xff\xff\xff'  # 0x0b
    b'\x13\x3e\x65\xb3\x0c\x06\xe3\x07'  # 0x0c
    b'\x7b\x56\x12\xb9\xff\xff\xff\xff'  # 0x0d
    b'\x31\x31\x6d\xcc\x01\x01\xe4\x07'  # 0x0e
    b'\x94\x0f\x3b\xe1\xff\xff\xff\xff'  # 0x0f
    b'\xef\xef\x5a\xe6\x0a\x06\xe3\x07'  # 0x10
    b'\x00'
    b'The answer: "%s"\x00'
    b'%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c\x00'
    b'Jello, world!\x00'
    b'%d\x00'
    b'%ld\x00'
    b'%s there are %x (%.2f) of them%c\x00'
    b'Hello %s! %hd %e\x00'
    b'%u %d\x00'
    b'The answer is: %s\x00'
    b'%llx\x00'
    b'Jello!\x00'
    b'%x%lld%1.2f%s\x00'
    b'Jello?\x00'
    b'%llu\x00'
    b'Won\'t fit : %s%d\x00'
)
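
# A sketch of the binary layout, inferred from the constant above: a 16-byte
# header ('TOKENS' magic, two version bytes, a little-endian u32 entry count
# of 0x10, and four reserved bytes), then one 8-byte record per entry (4-byte
# little-endian token plus 4 date bytes, 0xFFFFFFFF when never removed), then
# each entry's NUL-terminated string in token order.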

INVALID_CSV = """\
1,,"Whoa there!"
2,this is totally invalid,"Whoa there!"
3,,"This one's OK"
,,"Also broken"
5,1845-02-02,"I'm %s fine"
6,"Missing fields"
"""
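
# Only rows 1, 3, and 5 above are parseable; test_bad_csv checks that row 2
# (malformed date), row 4 (missing token), and row 6 (missing a field) are
# rejected with an error logged for each.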

CSV_DATABASE_2 = '''\
00000000,          ,""
141c35d5,          ,"The answer: ""%s"""
29aef586,          ,"1234"
2b78825f,          ,"[:-)"
2e668cd6,          ,"Jello, world!"
31631781,          ,"%d"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
7da55d52,          ,">:-[]"
7f35a9a5,          ,"TestName"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
88808930,          ,"%u%d%02x%X%hu%hhd%d%ld%lu%lld%llu%c%c%c"
92723f44,          ,"???"
a09d6698,          ,"won-won-won-wonderful"
aa9ffa66,          ,"void pw::tokenizer::{anonymous}::TestName()"
ad002c97,          ,"%llx"
b3653e13,          ,"Jello!"
cc6d3131,          ,"Jello?"
e13b0f94,          ,"%llu"
e65aefef,          ,"Won't fit : %s%d"
'''

CSV_DATABASE_3 = """\
17fa86d3,          ,"hello"
18c5017c,          ,"yes"
59b2701c,          ,"The answer was: %s"
881436a0,          ,"The answer is: %s"
d18ada0f,          ,"something"
"""

CSV_DATABASE_4 = '''\
00000000,          ,""
141c35d5,          ,"The answer: ""%s"""
17fa86d3,          ,"hello"
18c5017c,          ,"yes"
29aef586,          ,"1234"
2b78825f,          ,"[:-)"
2e668cd6,          ,"Jello, world!"
31631781,          ,"%d"
59b2701c,          ,"The answer was: %s"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
7da55d52,          ,">:-[]"
7f35a9a5,          ,"TestName"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
88808930,          ,"%u%d%02x%X%hu%hhd%d%ld%lu%lld%llu%c%c%c"
92723f44,          ,"???"
a09d6698,          ,"won-won-won-wonderful"
aa9ffa66,          ,"void pw::tokenizer::{anonymous}::TestName()"
ad002c97,          ,"%llx"
b3653e13,          ,"Jello!"
cc6d3131,          ,"Jello?"
d18ada0f,          ,"something"
e13b0f94,          ,"%llu"
e65aefef,          ,"Won't fit : %s%d"
'''


def read_db_from_csv(csv_str: str) -> tokens.Database:
    with io.StringIO(csv_str) as csv_db:
        return tokens.Database(tokens.parse_csv(csv_db))


def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]:
    for string in strings:
        yield tokens.TokenizedStringEntry(c_hash(string), string)
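

# Note: c_hash() is the Python implementation of the token hash used by the
# C/C++ tokenization macros (a 65599-based hash); its optional second argument
# caps how many characters are hashed, which is how test_collisions below
# constructs two strings that map to the same token.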


class TokenDatabaseTest(unittest.TestCase):
    """Tests the token database class."""

    def test_csv(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(str(db), CSV_DATABASE)

        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

    def test_csv_formatting(self) -> None:
        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

        db = read_db_from_csv('abc123,2048-04-01,Fake string\n')
        self.assertEqual(str(db), '00abc123,2048-04-01,"Fake string"\n')

        db = read_db_from_csv(
            '1,1990-01-01,"Quotes"""\n' '0,1990-02-01,"Commas,"",,"\n'
        )
        self.assertEqual(
            str(db),
            (
                '00000000,1990-02-01,"Commas,"",,"\n'
                '00000001,1990-01-01,"Quotes"""\n'
            ),
        )

    def test_bad_csv(self) -> None:
        with self.assertLogs(_LOG, logging.ERROR) as logs:
            db = read_db_from_csv(INVALID_CSV)

        self.assertGreaterEqual(len(logs.output), 3)
        self.assertEqual(len(db.token_to_entries), 3)

        self.assertEqual(db.token_to_entries[1][0].string, 'Whoa there!')
        self.assertFalse(db.token_to_entries[2])
        self.assertEqual(db.token_to_entries[3][0].string, "This one's OK")
        self.assertFalse(db.token_to_entries[4])
        self.assertEqual(db.token_to_entries[5][0].string, "I'm %s fine")
        self.assertFalse(db.token_to_entries[6])

    def test_lookup(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(db.token_to_entries[0x9999], [])

        matches = db.token_to_entries[0x2E668CD6]
        self.assertEqual(len(matches), 1)
        jello = matches[0]

        self.assertEqual(jello.token, 0x2E668CD6)
        self.assertEqual(jello.string, 'Jello, world!')
        self.assertEqual(jello.date_removed, datetime(2019, 6, 11))

        matches = db.token_to_entries[0xE13B0F94]
        self.assertEqual(len(matches), 1)
        llu = matches[0]
        self.assertEqual(llu.token, 0xE13B0F94)
        self.assertEqual(llu.string, '%llu')
        self.assertIsNone(llu.date_removed)

        (answer,) = db.token_to_entries[0x141C35D5]
        self.assertEqual(answer.string, 'The answer: "%s"')

    def test_collisions(self) -> None:
        hash_1 = tokens.c_hash('o000', 96)
        hash_2 = tokens.c_hash('0Q1Q', 96)
        self.assertEqual(hash_1, hash_2)

        db = tokens.Database.from_strings(['o000', '0Q1Q'])

        self.assertEqual(len(db.token_to_entries[hash_1]), 2)
        self.assertCountEqual(
            [entry.string for entry in db.token_to_entries[hash_1]],
            ['o000', '0Q1Q'],
        )

    def test_purge(self) -> None:
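        # purge(date) drops every entry whose date_removed is on or before
        # the given date; entries with no removal date are always kept.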
        db = read_db_from_csv(CSV_DATABASE)
        original_length = len(db.token_to_entries)

        self.assertEqual(db.token_to_entries[0][0].string, '')
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertEqual(
            db.token_to_entries[0x2E668CD6][0].string, 'Jello, world!'
        )
        self.assertEqual(db.token_to_entries[0xB3653E13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xCC6D3131][0].string, 'Jello?')
        self.assertEqual(
            db.token_to_entries[0xE65AEFEF][0].string, "Won't fit : %s%d"
        )

        db.purge(datetime(2019, 6, 11))
        self.assertLess(len(db.token_to_entries), original_length)

        self.assertFalse(db.token_to_entries[0])
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertFalse(db.token_to_entries[0x2E668CD6])
        self.assertEqual(db.token_to_entries[0xB3653E13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xCC6D3131][0].string, 'Jello?')
        self.assertFalse(db.token_to_entries[0xE65AEFEF])

    def test_merge(self) -> None:
        """Tests the tokens.Database merge method."""
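        # Merge semantics, per the assertions below: when duplicate entries
        # disagree on date_removed, an entry with no removal date (still in
        # use) takes precedence; otherwise the most recent date wins.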

        db = tokens.Database()

        # Test basic merging into an empty database.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        1, 'one', date_removed=datetime.min
                    ),
                    tokens.TokenizedStringEntry(
                        2, 'two', date_removed=datetime.min
                    ),
                ]
            )
        )
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
        self.assertEqual(db.token_to_entries[1][0].date_removed, datetime.min)
        self.assertEqual(db.token_to_entries[2][0].date_removed, datetime.min)

        # Test merging in an entry with a removal date.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(3, 'three'),
                    tokens.TokenizedStringEntry(
                        4, 'four', date_removed=datetime.min
                    ),
                ]
            )
        )
        self.assertEqual(
            {str(e) for e in db.entries()}, {'one', 'two', 'three', 'four'}
        )
        self.assertIsNone(db.token_to_entries[3][0].date_removed)
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.min)

        # Test merging in one entry.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(5, 'five'),
                ]
            )
        )
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.min)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in repeated entries with different removal dates.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        4, 'four', date_removed=datetime.max
                    ),
                    tokens.TokenizedStringEntry(
                        5, 'five', date_removed=datetime.max
                    ),
                ]
            )
        )
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertEqual(db.token_to_entries[4][0].date_removed, datetime.max)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in the same repeated entries now without removal dates.
        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(4, 'four'),
                    tokens.TokenizedStringEntry(5, 'five'),
                ]
            )
        )
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )
        self.assertIsNone(db.token_to_entries[4][0].date_removed)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in an empty database.
        db.merge(tokens.Database([]))
        self.assertEqual(
            {str(e) for e in db.entries()},
            {'one', 'two', 'three', 'four', 'five'},
        )

    def test_merge_multiple_databases_in_one_call(self) -> None:
        """Tests the merge and merged methods with multiple databases."""
        db = tokens.Database.merged(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        1, 'one', date_removed=datetime.max
                    )
                ]
            ),
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        2, 'two', date_removed=datetime.min
                    )
                ]
            ),
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        1, 'one', date_removed=datetime.min
                    )
                ]
            ),
        )
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})

        db.merge(
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        4, 'four', date_removed=datetime.max
                    )
                ]
            ),
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        2, 'two', date_removed=datetime.max
                    )
                ]
            ),
            tokens.Database(
                [
                    tokens.TokenizedStringEntry(
                        3, 'three', date_removed=datetime.min
                    )
                ]
            ),
        )
        self.assertEqual(
            {str(e) for e in db.entries()}, {'one', 'two', 'three', 'four'}
        )

    def test_entry_counts(self) -> None:
        self.assertEqual(len(CSV_DATABASE.splitlines()), 16)

        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(len(db.entries()), 16)
        self.assertEqual(len(db.token_to_entries), 16)

        # Add two strings with the same hash.
        db.add(_entries('o000', '0Q1Q'))

        self.assertEqual(len(db.entries()), 18)
        self.assertEqual(len(db.token_to_entries), 17)

    def test_mark_removed(self) -> None:
        """Tests that the date_removed field is set by mark_removed()."""
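        # mark_removed() takes the entries that are still in use and stamps
        # the removal date (defaulting to datetime.now()) on every database
        # entry *not* in that set.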
        db = tokens.Database.from_strings(
            ['MILK', 'apples', 'oranges', 'CHEESE', 'pears']
        )

        self.assertTrue(
            all(entry.date_removed is None for entry in db.entries())
        )
        date_1 = datetime(1, 2, 3)

        db.mark_removed(_entries('apples', 'oranges', 'pears'), date_1)

        self.assertEqual(
            db.token_to_entries[c_hash('MILK')][0].date_removed, date_1
        )
        self.assertEqual(
            db.token_to_entries[c_hash('CHEESE')][0].date_removed, date_1
        )

        now = datetime.now()
        db.mark_removed(_entries('MILK', 'CHEESE', 'pears'))

        # New strings are not added or re-added in mark_removed().
        milk_date = db.token_to_entries[c_hash('MILK')][0].date_removed
        assert milk_date is not None
        self.assertGreaterEqual(milk_date, date_1)

        cheese_date = db.token_to_entries[c_hash('CHEESE')][0].date_removed
        assert cheese_date is not None
        self.assertGreaterEqual(cheese_date, date_1)

        # These strings were removed.
        apples_date = db.token_to_entries[c_hash('apples')][0].date_removed
        assert apples_date is not None
        self.assertGreaterEqual(apples_date, now)

        oranges_date = db.token_to_entries[c_hash('oranges')][0].date_removed
        assert oranges_date is not None
        self.assertGreaterEqual(oranges_date, now)
        self.assertIsNone(db.token_to_entries[c_hash('pears')][0].date_removed)

    def test_add(self) -> None:
        db = tokens.Database()
        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

        db.add(_entries('MILK'))
        self.assertEqual(
            {e.string for e in db.entries()},
            {
                'MILK',
                'apples',
                'oranges',
                'CHEESE',
                'pears',
                'only this one is new',
            },
        )

    def test_add_duplicate_entries_keeps_none_as_removal_date(self) -> None:
        db = tokens.Database()
        db.add(
            [
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', ''),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.min),
            ]
        )
        self.assertEqual(len(db), 1)
        self.assertIsNone(db.token_to_entries[1][0].date_removed)

    def test_add_duplicate_entries_keeps_newest_removal_date(self) -> None:
        db = tokens.Database()
        db.add(
            [
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.max),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.now()),
                tokens.TokenizedStringEntry(1, 'Spam', '', datetime.min),
            ]
        )
        self.assertEqual(len(db), 1)
        self.assertEqual(db.token_to_entries[1][0].date_removed, datetime.max)

    def test_difference(self) -> None:
        first = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'one'),
                tokens.TokenizedStringEntry(2, 'two'),
                tokens.TokenizedStringEntry(3, 'three'),
            ]
        )
        second = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'one'),
                tokens.TokenizedStringEntry(3, 'three'),
                tokens.TokenizedStringEntry(4, 'four'),
            ]
        )
        difference = first.difference(second)
        self.assertEqual({e.string for e in difference.entries()}, {'two'})

    def test_binary_format_write(self) -> None:
        db = read_db_from_csv(CSV_DATABASE)

        with io.BytesIO() as fd:
            tokens.write_binary(db, fd)
            binary_db = fd.getvalue()

        self.assertEqual(BINARY_DATABASE, binary_db)

    def test_binary_format_parse(self) -> None:
        with io.BytesIO(BINARY_DATABASE) as binary_db:
            db = tokens.Database(tokens.parse_binary(binary_db))

        self.assertEqual(str(db), CSV_DATABASE)


class TestDatabaseFile(unittest.TestCase):
    """Tests the DatabaseFile class."""

    def setUp(self) -> None:
        file = tempfile.NamedTemporaryFile(delete=False)
        file.close()
        self._path = Path(file.name)

    def tearDown(self) -> None:
        self._path.unlink()

    def test_update_csv_file(self) -> None:
        self._path.write_text(CSV_DATABASE)
        db = tokens.DatabaseFile.load(self._path)
        self.assertEqual(str(db), CSV_DATABASE)

        db.add([tokens.TokenizedStringEntry(0xFFFFFFFF, 'New entry!')])

        db.write_to_file()

        self.assertEqual(
            self._path.read_text(),
            CSV_DATABASE + 'ffffffff,          ,"New entry!"\n',
        )

    def test_csv_file_too_short_raises_exception(self) -> None:
        self._path.write_text('1234')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)

    def test_csv_invalid_format_raises_exception(self) -> None:
        self._path.write_text('MK34567890')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)

    def test_csv_not_utf8(self) -> None:
        self._path.write_bytes(b'\x80' * 20)

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile.load(self._path)


class TestFilter(unittest.TestCase):
    """Tests the filtering functionality."""

    def setUp(self) -> None:
        self.db = tokens.Database(
            [
                tokens.TokenizedStringEntry(1, 'Luke'),
                tokens.TokenizedStringEntry(2, 'Leia'),
                tokens.TokenizedStringEntry(2, 'Darth Vader'),
                tokens.TokenizedStringEntry(2, 'Emperor Palpatine'),
                tokens.TokenizedStringEntry(3, 'Han'),
                tokens.TokenizedStringEntry(4, 'Chewbacca'),
                tokens.TokenizedStringEntry(5, 'Darth Maul'),
                tokens.TokenizedStringEntry(6, 'Han Solo'),
            ]
        )

    def test_filter_include_single_regex(self) -> None:
        self.db.filter(include=[' '])  # anything with a space
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Emperor Palpatine', 'Darth Maul', 'Han Solo'},
        )

    def test_filter_include_multiple_regexes(self) -> None:
        self.db.filter(include=['Darth', 'cc', '^Han$'])
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Darth Maul', 'Han', 'Chewbacca'},
        )

    def test_filter_include_no_matches(self) -> None:
        self.db.filter(include=['Gandalf'])
        self.assertFalse(self.db.entries())

    def test_filter_exclude_single_regex(self) -> None:
        self.db.filter(exclude=['^[^L]'])
        self.assertEqual(
            set(e.string for e in self.db.entries()), {'Luke', 'Leia'}
        )

    def test_filter_exclude_multiple_regexes(self) -> None:
        self.db.filter(exclude=[' ', 'Han', 'Chewbacca'])
        self.assertEqual(
            set(e.string for e in self.db.entries()), {'Luke', 'Leia'}
        )

    def test_filter_exclude_no_matches(self) -> None:
        self.db.filter(exclude=['.*'])
        self.assertFalse(self.db.entries())

    def test_filter_include_and_exclude(self) -> None:
        self.db.filter(include=[' '], exclude=['Darth', 'Emperor'])
        self.assertEqual(set(e.string for e in self.db.entries()), {'Han Solo'})

    def test_filter_neither_include_nor_exclude(self) -> None:
        self.db.filter()
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {
                'Luke',
                'Leia',
                'Darth Vader',
                'Emperor Palpatine',
                'Han',
                'Chewbacca',
                'Darth Maul',
                'Han Solo',
            },
        )


class TestDirectoryDatabase(unittest.TestCase):
    """Tests loading a token database from a directory of files."""

    def setUp(self) -> None:
        self._dir = Path(tempfile.mkdtemp('_pw_tokenizer_test'))
        self._db_dir = self._dir / '_dir_database_test'
        self._db_dir.mkdir(exist_ok=True)
        self._db_csv = self._db_dir / f'first{DIR_DB_SUFFIX}'
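        # Only files ending in DIR_DB_SUFFIX count as part of the directory
        # database; other files (like the junk_file in test_rewrite below)
        # are left alone.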

    def tearDown(self) -> None:
        shutil.rmtree(self._dir)

    def test_loading_empty_directory(self) -> None:
        self.assertFalse(tokens.DatabaseFile.load(self._db_dir).entries())

    def test_loading_a_single_file(self) -> None:
        self._db_csv.write_text(CSV_DATABASE)
        csv = tokens.DatabaseFile.load(self._db_csv)
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(1, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(csv), str(directory_db))

    def test_loading_multiple_files(self) -> None:
        self._db_csv.write_text(CSV_DATABASE_3)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_4)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(3, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(all_databases_merged), str(directory_db))

    def test_loading_multiple_files_with_removal_dates(self) -> None:
        self._db_csv.write_text(CSV_DATABASE)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_3)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )
        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(3, len(list(self._db_dir.iterdir())))
        self.assertEqual(str(all_databases_merged), str(directory_db))

    def test_rewrite(self) -> None:
        self._db_dir.joinpath('junk_file').write_text('should be ignored')

        self._db_csv.write_text(CSV_DATABASE_3)
        first_csv = tokens.DatabaseFile.load(self._db_csv)

        path_to_second_csv = self._db_dir / f'second{DIR_DB_SUFFIX}'
        path_to_second_csv.write_text(CSV_DATABASE_2)
        second_csv = tokens.DatabaseFile.load(path_to_second_csv)

        path_to_third_csv = self._db_dir / f'third{DIR_DB_SUFFIX}'
        path_to_third_csv.write_text(CSV_DATABASE_4)
        third_csv = tokens.DatabaseFile.load(path_to_third_csv)

        all_databases_merged = tokens.Database.merged(
            first_csv, second_csv, third_csv
        )

        directory_db = tokens.DatabaseFile.load(self._db_dir)
        directory_db.write_to_file(rewrite=True)

        self.assertEqual(1, len(list(self._db_dir.glob(f'*{DIR_DB_SUFFIX}'))))
        self.assertEqual(
            self._db_dir.joinpath('junk_file').read_text(), 'should be ignored'
        )

        directory_db = tokens.DatabaseFile.load(self._db_dir)
        self.assertEqual(str(all_databases_merged), str(directory_db))


if __name__ == '__main__':
    unittest.main()