#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for the tokens module."""

import datetime
import io
import logging
from pathlib import Path
import tempfile
from typing import Iterator
import unittest

from pw_tokenizer import tokens
from pw_tokenizer.tokens import default_hash, _LOG

CSV_DATABASE = '''\
00000000,2019-06-10,""
141c35d5,          ,"The answer: ""%s"""
2db1515f,          ,"%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c"
2e668cd6,2019-06-11,"Jello, world!"
31631781,          ,"%d"
61fd1e26,          ,"%ld"
68ab92da,          ,"%s there are %x (%.2f) of them%c"
7b940e2a,          ,"Hello %s! %hd %e"
851beeb6,          ,"%u %d"
881436a0,          ,"The answer is: %s"
ad002c97,          ,"%llx"
b3653e13,2019-06-12,"Jello!"
b912567b,          ,"%x%lld%1.2f%s"
cc6d3131,2020-01-01,"Jello?"
e13b0f94,          ,"%llu"
e65aefef,2019-06-10,"Won't fit : %s%d"
'''

# The date 2019-06-10 is 07E3-06-0A in hex. In database order, it's 0A 06 E3 07.
# Entries that have not been removed store FF FF FF FF in the date field.
BINARY_DATABASE = (
    b'TOKENS\x00\x00\x10\x00\x00\x00\0\0\0\0'  # header (0x10 entries)
    b'\x00\x00\x00\x00\x0a\x06\xe3\x07'  # 0x01
    b'\xd5\x35\x1c\x14\xff\xff\xff\xff'  # 0x02
    b'\x5f\x51\xb1\x2d\xff\xff\xff\xff'  # 0x03
    b'\xd6\x8c\x66\x2e\x0b\x06\xe3\x07'  # 0x04
    b'\x81\x17\x63\x31\xff\xff\xff\xff'  # 0x05
    b'\x26\x1e\xfd\x61\xff\xff\xff\xff'  # 0x06
    b'\xda\x92\xab\x68\xff\xff\xff\xff'  # 0x07
    b'\x2a\x0e\x94\x7b\xff\xff\xff\xff'  # 0x08
    b'\xb6\xee\x1b\x85\xff\xff\xff\xff'  # 0x09
    b'\xa0\x36\x14\x88\xff\xff\xff\xff'  # 0x0a
    b'\x97\x2c\x00\xad\xff\xff\xff\xff'  # 0x0b
    b'\x13\x3e\x65\xb3\x0c\x06\xe3\x07'  # 0x0c
    b'\x7b\x56\x12\xb9\xff\xff\xff\xff'  # 0x0d
    b'\x31\x31\x6d\xcc\x01\x01\xe4\x07'  # 0x0e
    b'\x94\x0f\x3b\xe1\xff\xff\xff\xff'  # 0x0f
    b'\xef\xef\x5a\xe6\x0a\x06\xe3\x07'  # 0x10
    b'\x00'
    b'The answer: "%s"\x00'
    b'%u%d%02x%X%hu%hhu%d%ld%lu%lld%llu%c%c%c\x00'
    b'Jello, world!\x00'
    b'%d\x00'
    b'%ld\x00'
    b'%s there are %x (%.2f) of them%c\x00'
    b'Hello %s! %hd %e\x00'
    b'%u %d\x00'
    b'The answer is: %s\x00'
    b'%llx\x00'
    b'Jello!\x00'
    b'%x%lld%1.2f%s\x00'
    b'Jello?\x00'
    b'%llu\x00'
    b'Won\'t fit : %s%d\x00')

INVALID_CSV = """\
1,,"Whoa there!"
2,this is totally invalid,"Whoa there!"
863,,"This one's OK" 87,,"Also broken" 885,1845-2-2,"I'm %s fine" 896,"Missing fields" 90""" 91 92 93def read_db_from_csv(csv_str: str) -> tokens.Database: 94 with io.StringIO(csv_str) as csv_db: 95 return tokens.Database(tokens.parse_csv(csv_db)) 96 97 98def _entries(*strings: str) -> Iterator[tokens.TokenizedStringEntry]: 99 for string in strings: 100 yield tokens.TokenizedStringEntry(default_hash(string), string) 101 102 103class TokenDatabaseTest(unittest.TestCase): 104 """Tests the token database class.""" 105 def test_csv(self): 106 db = read_db_from_csv(CSV_DATABASE) 107 self.assertEqual(str(db), CSV_DATABASE) 108 109 db = read_db_from_csv('') 110 self.assertEqual(str(db), '') 111 112 def test_csv_formatting(self): 113 db = read_db_from_csv('') 114 self.assertEqual(str(db), '') 115 116 db = read_db_from_csv('abc123,2048-4-1,Fake string\n') 117 self.assertEqual(str(db), '00abc123,2048-04-01,"Fake string"\n') 118 119 db = read_db_from_csv('1,1990-01-01,"Quotes"""\n' 120 '0,1990-02-01,"Commas,"",,"\n') 121 self.assertEqual(str(db), ('00000000,1990-02-01,"Commas,"",,"\n' 122 '00000001,1990-01-01,"Quotes"""\n')) 123 124 def test_bad_csv(self): 125 with self.assertLogs(_LOG, logging.ERROR) as logs: 126 db = read_db_from_csv(INVALID_CSV) 127 128 self.assertGreaterEqual(len(logs.output), 3) 129 self.assertEqual(len(db.token_to_entries), 3) 130 131 self.assertEqual(db.token_to_entries[1][0].string, 'Whoa there!') 132 self.assertFalse(db.token_to_entries[2]) 133 self.assertEqual(db.token_to_entries[3][0].string, "This one's OK") 134 self.assertFalse(db.token_to_entries[4]) 135 self.assertEqual(db.token_to_entries[5][0].string, "I'm %s fine") 136 self.assertFalse(db.token_to_entries[6]) 137 138 def test_lookup(self): 139 db = read_db_from_csv(CSV_DATABASE) 140 self.assertEqual(db.token_to_entries[0x9999], []) 141 142 matches = db.token_to_entries[0x2e668cd6] 143 self.assertEqual(len(matches), 1) 144 jello = matches[0] 145 146 self.assertEqual(jello.token, 0x2e668cd6) 147 self.assertEqual(jello.string, 'Jello, world!') 148 self.assertEqual(jello.date_removed, datetime.datetime(2019, 6, 11)) 149 150 matches = db.token_to_entries[0xe13b0f94] 151 self.assertEqual(len(matches), 1) 152 llu = matches[0] 153 self.assertEqual(llu.token, 0xe13b0f94) 154 self.assertEqual(llu.string, '%llu') 155 self.assertIsNone(llu.date_removed) 156 157 answer, = db.token_to_entries[0x141c35d5] 158 self.assertEqual(answer.string, 'The answer: "%s"') 159 160 def test_collisions(self): 161 hash_1 = tokens.pw_tokenizer_65599_hash('o000', 96) 162 hash_2 = tokens.pw_tokenizer_65599_hash('0Q1Q', 96) 163 self.assertEqual(hash_1, hash_2) 164 165 db = tokens.Database.from_strings(['o000', '0Q1Q']) 166 167 self.assertEqual(len(db.token_to_entries[hash_1]), 2) 168 self.assertCountEqual( 169 [entry.string for entry in db.token_to_entries[hash_1]], 170 ['o000', '0Q1Q']) 171 172 def test_purge(self): 173 db = read_db_from_csv(CSV_DATABASE) 174 original_length = len(db.token_to_entries) 175 176 self.assertEqual(db.token_to_entries[0][0].string, '') 177 self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d') 178 self.assertEqual(db.token_to_entries[0x2e668cd6][0].string, 179 'Jello, world!') 180 self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!') 181 self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?') 182 self.assertEqual(db.token_to_entries[0xe65aefef][0].string, 183 "Won't fit : %s%d") 184 185 db.purge(datetime.datetime(2019, 6, 11)) 186 self.assertLess(len(db.token_to_entries), 

        self.assertFalse(db.token_to_entries[0])
        self.assertEqual(db.token_to_entries[0x31631781][0].string, '%d')
        self.assertFalse(db.token_to_entries[0x2e668cd6])
        self.assertEqual(db.token_to_entries[0xb3653e13][0].string, 'Jello!')
        self.assertEqual(db.token_to_entries[0xcc6d3131][0].string, 'Jello?')
        self.assertFalse(db.token_to_entries[0xe65aefef])

    def test_merge(self):
        """Tests the tokens.Database merge method."""

        db = tokens.Database()

        # Test basic merging into an empty database.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    1, 'one', date_removed=datetime.datetime.min),
                tokens.TokenizedStringEntry(
                    2, 'two', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})
        self.assertEqual(db.token_to_entries[1][0].date_removed,
                         datetime.datetime.min)
        self.assertEqual(db.token_to_entries[2][0].date_removed,
                         datetime.datetime.min)

        # Test merging in an entry with a removal date.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(3, 'three'),
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.min),
            ]))
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four'})
        self.assertIsNone(db.token_to_entries[3][0].date_removed)
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)

        # Test merging in one entry.
        db.merge(tokens.Database([
            tokens.TokenizedStringEntry(5, 'five'),
        ]))
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.min)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in repeated entries with different removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.max),
                tokens.TokenizedStringEntry(
                    5, 'five', date_removed=datetime.datetime.max),
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertEqual(db.token_to_entries[4][0].date_removed,
                         datetime.datetime.max)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in the same repeated entries, now without removal dates.
        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(4, 'four'),
                tokens.TokenizedStringEntry(5, 'five')
            ]))
        self.assertEqual(len(db.entries()), 5)
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})
        self.assertIsNone(db.token_to_entries[4][0].date_removed)
        self.assertIsNone(db.token_to_entries[5][0].date_removed)

        # Merge in an empty database.
        db.merge(tokens.Database([]))
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four', 'five'})

    def test_merge_multiple_databases_in_one_call(self):
        """Tests the merge and merged methods with multiple databases."""
        db = tokens.Database.merged(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    1, 'one', date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(
                    2, 'two', date_removed=datetime.datetime.min)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(
                    1, 'one', date_removed=datetime.datetime.min)
            ]))
        self.assertEqual({str(e) for e in db.entries()}, {'one', 'two'})

        db.merge(
            tokens.Database([
                tokens.TokenizedStringEntry(
                    4, 'four', date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(
                    2, 'two', date_removed=datetime.datetime.max)
            ]),
            tokens.Database([
                tokens.TokenizedStringEntry(
                    3, 'three', date_removed=datetime.datetime.min)
            ]))
        self.assertEqual({str(e) for e in db.entries()},
                         {'one', 'two', 'three', 'four'})

    def test_entry_counts(self):
        self.assertEqual(len(CSV_DATABASE.splitlines()), 16)

        db = read_db_from_csv(CSV_DATABASE)
        self.assertEqual(len(db.entries()), 16)
        self.assertEqual(len(db.token_to_entries), 16)

        # Add two strings with the same hash.
        db.add(_entries('o000', '0Q1Q'))

        self.assertEqual(len(db.entries()), 18)
        self.assertEqual(len(db.token_to_entries), 17)

    def test_mark_removed(self):
        """Tests that the date_removed field is set by mark_removed()."""
        db = tokens.Database.from_strings(
            ['MILK', 'apples', 'oranges', 'CHEESE', 'pears'])

        self.assertTrue(
            all(entry.date_removed is None for entry in db.entries()))
        date_1 = datetime.datetime(1, 2, 3)

        db.mark_removed(_entries('apples', 'oranges', 'pears'), date_1)

        # Strings not in the list passed to mark_removed() are marked removed.
        self.assertEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        now = datetime.datetime.now()
        db.mark_removed(_entries('MILK', 'CHEESE', 'pears'))

        # New strings are not added or re-added in mark_removed().
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('MILK')][0].date_removed, date_1)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('CHEESE')][0].date_removed,
            date_1)

        # These strings were removed.
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('apples')][0].date_removed, now)
        self.assertGreaterEqual(
            db.token_to_entries[default_hash('oranges')][0].date_removed, now)
        self.assertIsNone(
            db.token_to_entries[default_hash('pears')][0].date_removed)

    def test_add(self):
        db = tokens.Database()
        db.add(_entries('MILK', 'apples'))
        self.assertEqual({e.string for e in db.entries()}, {'MILK', 'apples'})

        db.add(_entries('oranges', 'CHEESE', 'pears'))
        self.assertEqual(len(db.entries()), 5)

        db.add(_entries('MILK', 'apples', 'only this one is new'))
        self.assertEqual(len(db.entries()), 6)

        db.add(_entries('MILK'))
        self.assertEqual({e.string for e in db.entries()},
                         {'MILK', 'apples', 'oranges', 'CHEESE', 'pears',
                          'only this one is new'})

    def test_binary_format_write(self):
        db = read_db_from_csv(CSV_DATABASE)

        with io.BytesIO() as fd:
            tokens.write_binary(db, fd)
            binary_db = fd.getvalue()

        self.assertEqual(BINARY_DATABASE, binary_db)

    def test_binary_format_parse(self):
        with io.BytesIO(BINARY_DATABASE) as binary_db:
            db = tokens.Database(tokens.parse_binary(binary_db))

        self.assertEqual(str(db), CSV_DATABASE)


class TestDatabaseFile(unittest.TestCase):
    """Tests the DatabaseFile class."""
    def setUp(self):
        file = tempfile.NamedTemporaryFile(delete=False)
        file.close()
        self._path = Path(file.name)

    def tearDown(self):
        self._path.unlink()

    def test_update_csv_file(self):
        self._path.write_text(CSV_DATABASE)
        db = tokens.DatabaseFile(self._path)
        self.assertEqual(str(db), CSV_DATABASE)

        db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])

        db.write_to_file()

        self.assertEqual(self._path.read_text(),
                         CSV_DATABASE + 'ffffffff,          ,"New entry!"\n')

    def test_csv_file_too_short_raises_exception(self):
        self._path.write_text('1234')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_invalid_format_raises_exception(self):
        self._path.write_text('MK34567890')

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)

    def test_csv_not_utf8(self):
        self._path.write_bytes(b'\x80' * 20)

        with self.assertRaises(tokens.DatabaseFormatError):
            tokens.DatabaseFile(self._path)


class TestFilter(unittest.TestCase):
    """Tests the filtering functionality."""
    def setUp(self):
        self.db = tokens.Database([
            tokens.TokenizedStringEntry(1, 'Luke'),
            tokens.TokenizedStringEntry(2, 'Leia'),
            tokens.TokenizedStringEntry(2, 'Darth Vader'),
            tokens.TokenizedStringEntry(2, 'Emperor Palpatine'),
            tokens.TokenizedStringEntry(3, 'Han'),
            tokens.TokenizedStringEntry(4, 'Chewbacca'),
            tokens.TokenizedStringEntry(5, 'Darth Maul'),
            tokens.TokenizedStringEntry(6, 'Han Solo'),
        ])

    def test_filter_include_single_regex(self):
        self.db.filter(include=[' '])  # anything with a space
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Darth Vader', 'Emperor Palpatine', 'Darth Maul', 'Han Solo'})

    def test_filter_include_multiple_regexes(self):
        self.db.filter(include=['Darth', 'cc', '^Han$'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Darth Vader', 'Darth Maul', 'Han', 'Chewbacca'})

    def test_filter_include_no_matches(self):
        self.db.filter(include=['Gandalf'])
        self.assertFalse(self.db.entries())

    def test_filter_exclude_single_regex(self):
        self.db.filter(exclude=['^[^L]'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_multiple_regexes(self):
        self.db.filter(exclude=[' ', 'Han', 'Chewbacca'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Luke', 'Leia'})

    def test_filter_exclude_no_matches(self):
        self.db.filter(exclude=['.*'])
        self.assertFalse(self.db.entries())

    def test_filter_include_and_exclude(self):
        self.db.filter(include=[' '], exclude=['Darth', 'Emperor'])
        self.assertEqual(set(e.string for e in self.db.entries()),
                         {'Han Solo'})

    def test_filter_neither_include_nor_exclude(self):
        self.db.filter()
        self.assertEqual(
            set(e.string for e in self.db.entries()),
            {'Luke', 'Leia', 'Darth Vader', 'Emperor Palpatine', 'Han',
             'Chewbacca', 'Darth Maul', 'Han Solo'})


if __name__ == '__main__':
    unittest.main()