#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the tokens module."""

import datetime
import io
import logging
from pathlib import Path
import tempfile
from typing import Iterator
import unittest

from pw_tokenizer import tokens
from pw_tokenizer.tokens import default_hash, _LOG

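# A sample CSV database. Each row is an 8-digit hexadecimal token, an optional
# removal date (blank while the string is still in use), and the string in
# double quotes with embedded quotes doubled, CSV-style.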
CSV_DATABASE = '''\
00000000,2019-06-10,""
141c35d5,          ,"The answer: ""%s"""
2db1515f,          ,"%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c"
2e668cd6,2019-06-11,"Jello, world!"
31631781,          ,"%d"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
ad002c97,          ,"%llx"
b3653e13,2019-06-12,"Jello!"
b912567b,          ,"%x%lld%1.2f%s"
cc6d3131,2020-01-01,"Jello?"
e13b0f94,          ,"%llu"
e65aefef,2019-06-10,"Won't fit : %s%d"
'''

# The date 2019-06-10 is 07E3-06-0A in hex. In database order, it's 0A 06 E3 07.
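# The binary layout, as inferred from the entries below: a 16-byte header
# (b'TOKENS', two reserved bytes, a little-endian entry count, four padding
# bytes), then one 8-byte record per entry (4-byte little-endian token plus a
# day/month/year-LE date, with b'\xff\xff\xff\xff' meaning "never removed"),
# and finally the null-terminated strings in token order. As a sketch of the
# assumed date packing:
#
#   import struct
#   assert struct.pack('<BBH', 10, 6, 2019) == b'\x0a\x06\xe3\x07'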
BINARY_DATABASE = (
    b'TOKENS\x00\x00\x10\x00\x00\x00\0\0\0\0'  # header (0x10 entries)
    b'\x00\x00\x00\x00\x0a\x06\xe3\x07'  # 0x01
    b'\xd5\x35\x1c\x14\xff\xff\xff\xff'  # 0x02
    b'\x5f\x51\xb1\x2d\xff\xff\xff\xff'  # 0x03
    b'\xd6\x8c\x66\x2e\x0b\x06\xe3\x07'  # 0x04
    b'\x81\x17\x63\x31\xff\xff\xff\xff'  # 0x05
    b'\x26\x1e\xfd\x61\xff\xff\xff\xff'  # 0x06
    b'\xda\x92\xab\x68\xff\xff\xff\xff'  # 0x07
    b'\x2a\x0e\x94\x7b\xff\xff\xff\xff'  # 0x08
    b'\xb6\xee\x1b\x85\xff\xff\xff\xff'  # 0x09
    b'\xa0\x36\x14\x88\xff\xff\xff\xff'  # 0x0a
    b'\x97\x2c\x00\xad\xff\xff\xff\xff'  # 0x0b
    b'\x13\x3e\x65\xb3\x0c\x06\xe3\x07'  # 0x0c
    b'\x7b\x56\x12\xb9\xff\xff\xff\xff'  # 0x0d
    b'\x31\x31\x6d\xcc\x01\x01\xe4\x07'  # 0x0e
    b'\x94\x0f\x3b\xe1\xff\xff\xff\xff'  # 0x0f
    b'\xef\xef\x5a\xe6\x0a\x06\xe3\x07'  # 0x10
    b'\x00'
    b'The answer: "%s"\x00'
    b'%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c\x00'
    b'Jello, world!\x00'
    b'%d\x00'
    b'%ld\x00'
    b'%s there are %x (%.2f) of them%c\x00'
    b'Hello %s! %hd %e\x00'
    b'%u %d\x00'
    b'The answer is: %s\x00'
    b'%llx\x00'
    b'Jello!\x00'
    b'%x%lld%1.2f%s\x00'
    b'Jello?\x00'
    b'%llu\x00'
    b'Won\'t fit : %s%d\x00')

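# Deliberately malformed CSV: rows 2 (bad removal date), 4 (missing token),
# and 6 (missing field) should be rejected with logged errors, while rows 1,
# 3, and 5 should still parse.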
INVALID_CSV = """\
1,,"Whoa there!"
2,this is totally invalid,"Whoa there!"
3,,"This one's OK"
,,"Also broken"
5,1845-2-2,"I'm %s fine"
6,"Missing fields"
"""


def read_db_from_csv(csv_str: str) -> tokens.Database:
    """Parses a Database from a CSV string."""
    with io.StringIO(csv_str) as csv_db:
        return tokens.Database(tokens.parse_csv(csv_db))


def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]:
    """Yields a TokenizedStringEntry for each string, using default_hash."""
    for string in strings:
        yield tokens.TokenizedStringEntry(default_hash(string), string)


class TokenDatabaseTest(unittest.TestCase):
    """Tests the token database class."""
    def test_csv(self):
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(str(db), CSV_DATABASE)

        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

    def test_csv_formatting(self):
        db = read_db_from_csv('')
        self.assertEqual(str(db), '')

        db = read_db_from_csv('abc123,2048-4-1,Fake string\n')
        self.assertEqual(str(db), '00abc123,2048-04-01,"Fake string"\n')

        db = read_db_from_csv('1,1990-01-01,"Quotes"""\n'
                              '0,1990-02-01,"Commas,"",,"\n')
        self.assertEqual(str(db), ('00000000,1990-02-01,"Commas,"",,"\n'
                                   '00000001,1990-01-01,"Quotes"""\n'))

    def test_bad_csv(self):
        with self.assertLogs(_LOG, logging.ERROR) as logs:
            db = read_db_from_csv(INVALID_CSV)

        self.assertGreaterEqual(len(logs.output), 3)
        self.assertEqual(len(db.token_to_entries), 3)

        self.assertEqual(db.token_to_entries[1][0].string, 'Whoa there!')
        self.assertFalse(db.token_to_entries[2])
        self.assertEqual(db.token_to_entries[3][0].string, "This one's OK")
        self.assertFalse(db.token_to_entries[4])
        self.assertEqual(db.token_to_entries[5][0].string, "I'm %s fine")
        self.assertFalse(db.token_to_entries[6])

    def test_lookup(self):
        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(db.token_to_entries[0x9999], [])

        matches = db.token_to_entries[0x2e668cd6]
        self.assertEqual(len(matches), 1)
        jello = matches[0]

        self.assertEqual(jello.token, 0x2e668cd6)
        self.assertEqual(jello.string, 'Jello, world!')
        self.assertEqual(jello.date_removed, datetime.datetime(2019, 6, 11))

        matches = db.token_to_entries[0xe13b0f94]
        self.assertEqual(len(matches), 1)
        llu = matches[0]
        self.assertEqual(llu.token, 0xe13b0f94)
        self.assertEqual(llu.string, '%llu')
        self.assertIsNone(llu.date_removed)

        answer, = db.token_to_entries[0x141c35d5]
        self.assertEqual(answer.string, 'The answer: "%s"')

    def test_collisions(self):
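        # 'o000' and '0Q1Q' produce the same 32-bit token under the 65599
        # fixed-length hash (hashed with a 96-character limit here), so both
        # entries must coexist under a single token.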
        hash_1 = tokens.pw_tokenizer_65599_hash('o000', 96)
        hash_2 = tokens.pw_tokenizer_65599_hash('0Q1Q', 96)
        self.assertEqual(hash_1, hash_2)

        db = tokens.Database.from_strings(['o000', '0Q1Q'])

        self.assertEqual(len(db.token_to_entries[hash_1]), 2)
        self.assertCountEqual(
            [entry.string for entry in db.token_to_entries[hash_1]],
            ['o000', '0Q1Q'])

    def test_purge(self):
        db = read_db_from_csv(CSV_DATABASE)
        original_length = len(db.token_to_entries)

        self.assertEqual(db.token_to_entries[0][0].string, '')
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertEqual(db.token_to_entries[0x2e668cd6][0].string,
                         'Jello, world!')
        self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?')
        self.assertEqual(db.token_to_entries[0xe65aefef][0].string,
                         "Won't fit : %s%d")

        db.purge(datetime.datetime(2019, 6, 11))
        self.assertLess(len(db.token_to_entries), original_length)

        self.assertFalse(db.token_to_entries[0])
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertFalse(db.token_to_entries[0x2e668cd6])
        self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?')
        self.assertFalse(db.token_to_entries[0xe65aefef])

    def test_merge(self):
        """Tests the tokens.Database merge method."""

        db = tokens.Database()

        # Test basic merging into an empty database.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    1, 'one', date_removed=datetime.datetime.min),
                tokens.TokenizedStringEntry(
                    2, 'two', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
        self.assertEqual(db.token_to_entries[1][0].date_removed,
                         datetime.datetime.min)
        self.assertEqual(db.token_to_entries[2][0].date_removed,
                         datetime.datetime.min)

        # Test merging in an entry with a removal date.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(3, 'three'),
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four'})
        self.assertIsNone(db.token_to_entries[3][0].date_removed)
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)

        # Test merging in one entry.
        db.merge(tokens.Database([
            tokens.TokenizedStringEntry(5, 'five'),
        ]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in repeated entries with different removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.max),
                tokens.TokenizedStringEntry(
                    5, 'five', date_removed=datetime.datetime.max),
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.max)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in the same repeated entries now without removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(4, 'four'),
                tokens.TokenizedStringEntry(5, 'five')
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertIsNone(db.token_to_entries[4][0].date_removed)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in an empty database.
        db.merge(tokens.Database([]))
        self.assertEqual({str(e)
                          for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})

    def test_merge_multiple_databases_in_one_call(self):
275        """Tests the merge and merged methods with multiple databases."""
276        db = tokens.Database.merged(
277            tokens.Database([
278                tokens.TokenizedStringEntry(1,
279                                            'one',
280                                            date_removed=datetime.datetime.max)
281            ]),
282            tokens.Database([
283                tokens.TokenizedStringEntry(2,
284                                            'two',
285                                            date_removed=datetime.datetime.min)
286            ]),
287            tokens.Database([
288                tokens.TokenizedStringEntry(1,
289                                            'one',
290                                            date_removed=datetime.datetime.min)
291            ]))
292        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
293
294        db.merge(
295            tokens.Database([
296                tokens.TokenizedStringEntry(4,
297                                            'four',
298                                            date_removed=datetime.datetime.max)
299            ]),
300            tokens.Database([
301                tokens.TokenizedStringEntry(2,
302                                            'two',
303                                            date_removed=datetime.datetime.max)
304            ]),
305            tokens.Database([
306                tokens.TokenizedStringEntry(3,
307                                            'three',
308                                            date_removed=datetime.datetime.min)
309            ]))
310        self.assertEqual({str(e)
311                          for e in db.entries()},
312                         {'one', 'two', 'three', 'four'})
313
314    def test_entry_counts(self):
315        self.assertEqual(len(CSV_DATABASE.splitlines()), 16)
316
317        db = read_db_from_csv(CSV_DATABASE)
318        self.assertEqual(len(db.entries()), 16)
319        self.assertEqual(len(db.token_to_entries), 16)
320
321        # Add two strings with the same hash.
322        db.add(_entries('o000', '0Q1Q'))
323
324        self.assertEqual(len(db.entries()), 18)
325        self.assertEqual(len(db.token_to_entries), 17)
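        # 18 entries now map to only 17 tokens, since 'o000' and '0Q1Q'
        # share a token.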

    def test_mark_removed(self):
        """Tests that the date_removed field is set by mark_removed()."""
        db = tokens.Database.from_strings(
            ['MILK', 'apples', 'oranges', 'CHEESE', 'pears'])

        self.assertTrue(
            all(entry.date_removed is None for entry in db.entries()))
        date_1 = datetime.datetime(1, 2, 3)

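        # mark_removed() takes the set of strings that are still present;
        # entries missing from that set are marked removed as of the given
        # date (or now, if no date is given).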
        db.mark_removed(_entries('apples', 'oranges', 'pears'), date_1)

        self.assertEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        now = datetime.datetime.now()
        db.mark_removed(_entries('MILK', 'CHEESE', 'pears'))

        # mark_removed() never adds new strings or clears existing removal
        # dates; MILK and CHEESE remain removed as of date_1.
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        # These strings were removed.
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('apples')][0].date_removed, now)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('oranges')][0].date_removed, now)
        self.assertIsNone(
            db.token_to_entries[default_hash('pears')][0].date_removed)

    def test_add(self):
        db = tokens.Database()
        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

        db.add(_entries('MILK'))
        self.assertEqual({e.string
                          for e in db.entries()}, {
                              'MILK', 'apples', 'oranges', 'CHEESE', 'pears',
                              'only this one is new'
                          })

    def test_binary_format_write(self):
        db = read_db_from_csv(CSV_DATABASE)

        with io.BytesIO() as fd:
            tokens.write_binary(db, fd)
            binary_db = fd.getvalue()

        self.assertEqual(BINARY_DATABASE, binary_db)

    def test_binary_format_parse(self):
        with io.BytesIO(BINARY_DATABASE) as binary_db:
            db = tokens.Database(tokens.parse_binary(binary_db))

        self.assertEqual(str(db), CSV_DATABASE)


class TestDatabaseFile(unittest.TestCase):
    """Tests the DatabaseFile class."""
    def setUp(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        file.close()
        self._path = Path(file.name)

    def tearDown(self):
        self._path.unlink()

    def test_update_csv_file(self):
        self._path.write_text(CSV_DATABASE)
        db = tokens.DatabaseFile(self._path)
        self.assertEqual(str(db), CSV_DATABASE)

        db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])

        db.write_to_file()

        self.assertEqual(self._path.read_text(),
                         CSV_DATABASE + 'ffffffff,          ,"New entry!"\n')

    def test_csv_file_too_short_raises_exception(self):
        self._path.write_text('1234')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_invalid_format_raises_exception(self):
        self._path.write_text('MK34567890')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_not_utf8(self):
        self._path.write_bytes(b'\x80' * 20)

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)


class TestFilter(unittest.TestCase):
    """Tests the filtering functionality."""
    def setUp(self):
        self.db = tokens.Database([
            tokens.TokenizedStringEntry(1, 'Luke'),
            tokens.TokenizedStringEntry(2, 'Leia'),
            tokens.TokenizedStringEntry(2, 'Darth Vader'),
            tokens.TokenizedStringEntry(2, 'Emperor Palpatine'),
            tokens.TokenizedStringEntry(3, 'Han'),
            tokens.TokenizedStringEntry(4, 'Chewbacca'),
            tokens.TokenizedStringEntry(5, 'Darth Maul'),
            tokens.TokenizedStringEntry(6, 'Han Solo'),
        ])

    def test_filter_include_single_regex(self):
        self.db.filter(include=[' '])  # anything with a space
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Emperor Palpatine', 'Darth Maul', 'Han Solo'})

    def test_filter_include_multiple_regexes(self):
        self.db.filter(include=['Darth', 'cc', '^Han$'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Darth Vader', 'Darth Maul', 'Han', 'Chewbacca'})

    def test_filter_include_no_matches(self):
        self.db.filter(include=['Gandalf'])
        self.assertFalse(self.db.entries())

    def test_filter_exclude_single_regex(self):
        self.db.filter(exclude=['^[^L]'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_multiple_regexes(self):
        self.db.filter(exclude=[' ', 'Han', 'Chewbacca'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_no_matches(self):
        self.db.filter(exclude=['.*'])
        self.assertFalse(self.db.entries())

    def test_filter_include_and_exclude(self):
        self.db.filter(include=[' '], exclude=['Darth', 'Emperor'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Han Solo'})

    def test_filter_neither_include_nor_exclude(self):
        self.db.filter()
        self.assertEqual(
            set(e.string for e in self.db.entries()), {
                'Luke', 'Leia', 'Darth Vader', 'Emperor Palpatine', 'Han',
                'Chewbacca', 'Darth Maul', 'Han Solo'
            })


if __name__ == '__main__':
    unittest.main()