1# Original Algorithm: 2# By Steve Hanov, 2011. Released to the public domain. 3# Please see http://stevehanov.ca/blog/index.php?id=115 for the accompanying article. 4# 5# Adapted for PyPy/CPython by Carl Friedrich Bolz-Tereick 6# 7# Based on Daciuk, Jan, et al. "Incremental construction of minimal acyclic finite-state automata." 8# Computational linguistics 26.1 (2000): 3-16. 9# 10# Updated 2014 to use DAWG as a mapping; see 11# Kowaltowski, T.; CL. Lucchesi (1993), "Applications of finite automata representing large vocabularies", 12# Software-Practice and Experience 1993 13 14from collections import defaultdict 15from functools import cached_property 16 17 18# This class represents a node in the directed acyclic word graph (DAWG). It 19# has a list of edges to other nodes. It has functions for testing whether it 20# is equivalent to another node. Nodes are equivalent if they have identical 21# edges, and each identical edge leads to identical states. The __hash__ and 22# __eq__ functions allow it to be used as a key in a python dictionary. 23 24 25class DawgNode: 26 27 def __init__(self, dawg): 28 self.id = dawg.next_id 29 dawg.next_id += 1 30 self.final = False 31 self.edges = {} 32 33 self.linear_edges = None # later: list of (string, next_state) 34 35 def __str__(self): 36 if self.final: 37 arr = ["1"] 38 else: 39 arr = ["0"] 40 41 for (label, node) in sorted(self.edges.items()): 42 arr.append(label) 43 arr.append(str(node.id)) 44 45 return "_".join(arr) 46 __repr__ = __str__ 47 48 def _as_tuple(self): 49 edges = sorted(self.edges.items()) 50 edge_tuple = tuple((label, node.id) for label, node in edges) 51 return (self.final, edge_tuple) 52 53 def __hash__(self): 54 return hash(self._as_tuple()) 55 56 def __eq__(self, other): 57 return self._as_tuple() == other._as_tuple() 58 59 @cached_property 60 def num_reachable_linear(self): 61 # returns the number of different paths to final nodes reachable from 62 # this one 63 64 count = 0 65 # staying at self counts as a path if self is final 66 if self.final: 67 count += 1 68 for label, node in self.linear_edges: 69 count += node.num_reachable_linear 70 71 return count 72 73 74class Dawg: 75 def __init__(self): 76 self.previous_word = "" 77 self.next_id = 0 78 self.root = DawgNode(self) 79 80 # Here is a list of nodes that have not been checked for duplication. 81 self.unchecked_nodes = [] 82 83 # To deduplicate, maintain a dictionary with 84 # minimized_nodes[canonical_node] is canonical_node. 85 # Based on __hash__ and __eq__, minimized_nodes[n] is the 86 # canonical node equal to n. 87 # In other words, self.minimized_nodes[x] == x for all nodes found in 88 # the dict. 89 self.minimized_nodes = {} 90 91 # word: value mapping 92 self.data = {} 93 # value: word mapping 94 self.inverse = {} 95 96 def insert(self, word, value): 97 if not all(0 <= ord(c) < 128 for c in word): 98 raise ValueError("Use 7-bit ASCII characters only") 99 if word <= self.previous_word: 100 raise ValueError("Error: Words must be inserted in alphabetical order.") 101 if value in self.inverse: 102 raise ValueError(f"value {value} is duplicate, got it for word {self.inverse[value]} and now {word}") 103 104 # find common prefix between word and previous word 105 common_prefix = 0 106 for i in range(min(len(word), len(self.previous_word))): 107 if word[i] != self.previous_word[i]: 108 break 109 common_prefix += 1 110 111 # Check the unchecked_nodes for redundant nodes, proceeding from last 112 # one down to the common prefix size. Then truncate the list at that 113 # point. 114 self._minimize(common_prefix) 115 116 self.data[word] = value 117 self.inverse[value] = word 118 119 # add the suffix, starting from the correct node mid-way through the 120 # graph 121 if len(self.unchecked_nodes) == 0: 122 node = self.root 123 else: 124 node = self.unchecked_nodes[-1][2] 125 126 for letter in word[common_prefix:]: 127 next_node = DawgNode(self) 128 node.edges[letter] = next_node 129 self.unchecked_nodes.append((node, letter, next_node)) 130 node = next_node 131 132 node.final = True 133 self.previous_word = word 134 135 def finish(self): 136 if not self.data: 137 raise ValueError("need at least one word in the dawg") 138 # minimize all unchecked_nodes 139 self._minimize(0) 140 141 self._linearize_edges() 142 143 topoorder, linear_data, inverse = self._topological_order() 144 return self.compute_packed(topoorder), linear_data, inverse 145 146 def _minimize(self, down_to): 147 # proceed from the leaf up to a certain point 148 for i in range(len(self.unchecked_nodes) - 1, down_to - 1, -1): 149 (parent, letter, child) = self.unchecked_nodes[i] 150 if child in self.minimized_nodes: 151 # replace the child with the previously encountered one 152 parent.edges[letter] = self.minimized_nodes[child] 153 else: 154 # add the state to the minimized nodes. 155 self.minimized_nodes[child] = child 156 self.unchecked_nodes.pop() 157 158 def _lookup(self, word): 159 """ Return an integer 0 <= k < number of strings in dawg 160 where word is the kth successful traversal of the dawg. """ 161 node = self.root 162 skipped = 0 # keep track of number of final nodes that we skipped 163 index = 0 164 while index < len(word): 165 for label, child in node.linear_edges: 166 if word[index] == label[0]: 167 if word[index:index + len(label)] == label: 168 if node.final: 169 skipped += 1 170 index += len(label) 171 node = child 172 break 173 else: 174 return None 175 skipped += child.num_reachable_linear 176 else: 177 return None 178 return skipped 179 180 def enum_all_nodes(self): 181 stack = [self.root] 182 done = set() 183 while stack: 184 node = stack.pop() 185 if node.id in done: 186 continue 187 yield node 188 done.add(node.id) 189 for label, child in sorted(node.edges.items()): 190 stack.append(child) 191 192 def prettyprint(self): 193 for node in sorted(self.enum_all_nodes(), key=lambda e: e.id): 194 s_final = " final" if node.final else "" 195 print(f"{node.id}: ({node}) {s_final}") 196 for label, child in sorted(node.edges.items()): 197 print(f" {label} goto {child.id}") 198 199 def _inverse_lookup(self, number): 200 assert 0, "not working in the current form, but keep it as the pure python version of compact lookup" 201 result = [] 202 node = self.root 203 while 1: 204 if node.final: 205 if pos == 0: 206 return "".join(result) 207 pos -= 1 208 for label, child in sorted(node.edges.items()): 209 nextpos = pos - child.num_reachable_linear 210 if nextpos < 0: 211 result.append(label) 212 node = child 213 break 214 else: 215 pos = nextpos 216 else: 217 assert 0 218 219 def _linearize_edges(self): 220 # compute "linear" edges. the idea is that long chains of edges without 221 # any of the intermediate states being final or any extra incoming or 222 # outgoing edges can be represented by having removing them, and 223 # instead using longer strings as edge labels (instead of single 224 # characters) 225 incoming = defaultdict(list) 226 nodes = sorted(self.enum_all_nodes(), key=lambda e: e.id) 227 for node in nodes: 228 for label, child in sorted(node.edges.items()): 229 incoming[child].append(node) 230 for node in nodes: 231 node.linear_edges = [] 232 for label, child in sorted(node.edges.items()): 233 s = [label] 234 while len(child.edges) == 1 and len(incoming[child]) == 1 and not child.final: 235 (c, child), = child.edges.items() 236 s.append(c) 237 node.linear_edges.append((''.join(s), child)) 238 239 def _topological_order(self): 240 # compute reachable linear nodes, and the set of incoming edges for each node 241 order = [] 242 stack = [self.root] 243 seen = set() 244 while stack: 245 # depth first traversal 246 node = stack.pop() 247 if node.id in seen: 248 continue 249 seen.add(node.id) 250 order.append(node) 251 for label, child in node.linear_edges: 252 stack.append(child) 253 254 # do a (slightly bad) topological sort 255 incoming = defaultdict(set) 256 for node in order: 257 for label, child in node.linear_edges: 258 incoming[child].add((label, node)) 259 no_incoming = [order[0]] 260 topoorder = [] 261 positions = {} 262 while no_incoming: 263 node = no_incoming.pop() 264 topoorder.append(node) 265 positions[node] = len(topoorder) 266 # use "reversed" to make sure that the linear_edges get reorderd 267 # from their alphabetical order as little as necessary (no_incoming 268 # is LIFO) 269 for label, child in reversed(node.linear_edges): 270 incoming[child].discard((label, node)) 271 if not incoming[child]: 272 no_incoming.append(child) 273 del incoming[child] 274 # check result 275 assert set(topoorder) == set(order) 276 assert len(set(topoorder)) == len(topoorder) 277 278 for node in order: 279 node.linear_edges.sort(key=lambda element: positions[element[1]]) 280 281 for node in order: 282 for label, child in node.linear_edges: 283 assert positions[child] > positions[node] 284 # number the nodes. afterwards every input string in the set has a 285 # unique number in the 0 <= number < len(data). We then put the data in 286 # self.data into a linear list using these numbers as indexes. 287 topoorder[0].num_reachable_linear 288 linear_data = [None] * len(self.data) 289 inverse = {} # maps value back to index 290 for word, value in self.data.items(): 291 index = self._lookup(word) 292 linear_data[index] = value 293 inverse[value] = index 294 295 return topoorder, linear_data, inverse 296 297 def compute_packed(self, order): 298 def compute_chunk(node, offsets): 299 """ compute the packed node/edge data for a node. result is a 300 list of bytes as long as order. the jump distance calculations use 301 the offsets dictionary to know where in the final big output 302 bytestring the individual nodes will end up. """ 303 result = bytearray() 304 offset = offsets[node] 305 encode_varint_unsigned(number_add_bits(node.num_reachable_linear, node.final), result) 306 if len(node.linear_edges) == 0: 307 assert node.final 308 encode_varint_unsigned(0, result) # add a 0 saying "done" 309 prev_child_offset = offset + len(result) 310 for edgeindex, (label, targetnode) in enumerate(node.linear_edges): 311 label = label.encode('ascii') 312 child_offset = offsets[targetnode] 313 child_offset_difference = child_offset - prev_child_offset 314 315 info = number_add_bits(child_offset_difference, len(label) == 1, edgeindex == len(node.linear_edges) - 1) 316 if edgeindex == 0: 317 assert info != 0 318 encode_varint_unsigned(info, result) 319 prev_child_offset = child_offset 320 if len(label) > 1: 321 encode_varint_unsigned(len(label), result) 322 result.extend(label) 323 return result 324 325 def compute_new_offsets(chunks, offsets): 326 """ Given a list of chunks, compute the new offsets (by adding the 327 chunk lengths together). Also check if we cannot shrink the output 328 further because none of the node offsets are smaller now. if that's 329 the case return None. """ 330 new_offsets = {} 331 curr_offset = 0 332 should_continue = False 333 for node, result in zip(order, chunks): 334 if curr_offset < offsets[node]: 335 # the new offset is below the current assumption, this 336 # means we can shrink the output more 337 should_continue = True 338 new_offsets[node] = curr_offset 339 curr_offset += len(result) 340 if not should_continue: 341 return None 342 return new_offsets 343 344 # assign initial offsets to every node 345 offsets = {} 346 for i, node in enumerate(order): 347 # we don't know position of the edge yet, just use something big as 348 # the starting position. we'll have to do further iterations anyway, 349 # but the size is at least a lower limit then 350 offsets[node] = i * 2 ** 30 351 352 353 # due to the variable integer width encoding of edge targets we need to 354 # run this to fixpoint. in the process we shrink the output more and 355 # more until we can't any more. at any point we can stop and use the 356 # output, but we might need padding zero bytes when joining the chunks 357 # to have the correct jump distances 358 last_offsets = None 359 while 1: 360 chunks = [compute_chunk(node, offsets) for node in order] 361 last_offsets = offsets 362 offsets = compute_new_offsets(chunks, offsets) 363 if offsets is None: # couldn't shrink 364 break 365 366 # build the final packed string 367 total_result = bytearray() 368 for node, result in zip(order, chunks): 369 node_offset = last_offsets[node] 370 if node_offset > len(total_result): 371 # need to pad to get the offsets correct 372 padding = b"\x00" * (node_offset - len(total_result)) 373 total_result.extend(padding) 374 assert node_offset == len(total_result) 375 total_result.extend(result) 376 return bytes(total_result) 377 378 379# ______________________________________________________________________ 380# the following functions operate on the packed representation 381 382def number_add_bits(x, *bits): 383 for bit in bits: 384 assert bit == 0 or bit == 1 385 x = (x << 1) | bit 386 return x 387 388def encode_varint_unsigned(i, res): 389 # https://en.wikipedia.org/wiki/LEB128 unsigned variant 390 more = True 391 startlen = len(res) 392 if i < 0: 393 raise ValueError("only positive numbers supported", i) 394 while more: 395 lowest7bits = i & 0b1111111 396 i >>= 7 397 if i == 0: 398 more = False 399 else: 400 lowest7bits |= 0b10000000 401 res.append(lowest7bits) 402 return len(res) - startlen 403 404def number_split_bits(x, n, acc=()): 405 if n == 1: 406 return x >> 1, x & 1 407 if n == 2: 408 return x >> 2, (x >> 1) & 1, x & 1 409 assert 0, "implement me!" 410 411def decode_varint_unsigned(b, index=0): 412 res = 0 413 shift = 0 414 while True: 415 byte = b[index] 416 res = res | ((byte & 0b1111111) << shift) 417 index += 1 418 shift += 7 419 if not (byte & 0b10000000): 420 return res, index 421 422def decode_node(packed, node): 423 x, node = decode_varint_unsigned(packed, node) 424 node_count, final = number_split_bits(x, 1) 425 return node_count, final, node 426 427def decode_edge(packed, edgeindex, prev_child_offset, offset): 428 x, offset = decode_varint_unsigned(packed, offset) 429 if x == 0 and edgeindex == 0: 430 raise KeyError # trying to decode past a final node 431 child_offset_difference, len1, last_edge = number_split_bits(x, 2) 432 child_offset = prev_child_offset + child_offset_difference 433 if len1: 434 size = 1 435 else: 436 size, offset = decode_varint_unsigned(packed, offset) 437 return child_offset, last_edge, size, offset 438 439def _match_edge(packed, s, size, node_offset, stringpos): 440 if size > 1 and stringpos + size > len(s): 441 # past the end of the string, can't match 442 return False 443 for i in range(size): 444 if packed[node_offset + i] != s[stringpos + i]: 445 # if a subsequent char of an edge doesn't match, the word isn't in 446 # the dawg 447 if i > 0: 448 raise KeyError 449 return False 450 return True 451 452def lookup(packed, data, s): 453 return data[_lookup(packed, s)] 454 455def _lookup(packed, s): 456 stringpos = 0 457 node_offset = 0 458 skipped = 0 # keep track of number of final nodes that we skipped 459 false = False 460 while stringpos < len(s): 461 #print(f"{node_offset=} {stringpos=}") 462 _, final, edge_offset = decode_node(packed, node_offset) 463 prev_child_offset = edge_offset 464 edgeindex = 0 465 while 1: 466 child_offset, last_edge, size, edgelabel_chars_offset = decode_edge(packed, edgeindex, prev_child_offset, edge_offset) 467 #print(f" {edge_offset=} {child_offset=} {last_edge=} {size=} {edgelabel_chars_offset=}") 468 edgeindex += 1 469 prev_child_offset = child_offset 470 if _match_edge(packed, s, size, edgelabel_chars_offset, stringpos): 471 # match 472 if final: 473 skipped += 1 474 stringpos += size 475 node_offset = child_offset 476 break 477 if last_edge: 478 raise KeyError 479 descendant_count, _, _ = decode_node(packed, child_offset) 480 skipped += descendant_count 481 edge_offset = edgelabel_chars_offset + size 482 _, final, _ = decode_node(packed, node_offset) 483 if final: 484 return skipped 485 raise KeyError 486 487def inverse_lookup(packed, inverse, x): 488 pos = inverse[x] 489 return _inverse_lookup(packed, pos) 490 491def _inverse_lookup(packed, pos): 492 result = bytearray() 493 node_offset = 0 494 while 1: 495 node_count, final, edge_offset = decode_node(packed, node_offset) 496 if final: 497 if pos == 0: 498 return bytes(result) 499 pos -= 1 500 prev_child_offset = edge_offset 501 edgeindex = 0 502 while 1: 503 child_offset, last_edge, size, edgelabel_chars_offset = decode_edge(packed, edgeindex, prev_child_offset, edge_offset) 504 edgeindex += 1 505 prev_child_offset = child_offset 506 descendant_count, _, _ = decode_node(packed, child_offset) 507 nextpos = pos - descendant_count 508 if nextpos < 0: 509 assert edgelabel_chars_offset >= 0 510 result.extend(packed[edgelabel_chars_offset: edgelabel_chars_offset + size]) 511 node_offset = child_offset 512 break 513 elif not last_edge: 514 pos = nextpos 515 edge_offset = edgelabel_chars_offset + size 516 else: 517 raise KeyError 518 else: 519 raise KeyError 520 521 522def build_compression_dawg(ucdata): 523 d = Dawg() 524 ucdata.sort() 525 for name, value in ucdata: 526 d.insert(name, value) 527 packed, pos_to_code, reversedict = d.finish() 528 print("size of dawg [KiB]", round(len(packed) / 1024, 2)) 529 # check that lookup and inverse_lookup work correctly on the input data 530 for name, value in ucdata: 531 assert lookup(packed, pos_to_code, name.encode('ascii')) == value 532 assert inverse_lookup(packed, reversedict, value) == name.encode('ascii') 533 return packed, pos_to_code 534