# Original Algorithm:
# By Steve Hanov, 2011. Released to the public domain.
# Please see http://stevehanov.ca/blog/index.php?id=115 for the accompanying article.
#
# Adapted for PyPy/CPython by Carl Friedrich Bolz-Tereick
#
# Based on Daciuk, Jan, et al. "Incremental construction of minimal acyclic finite-state automata."
# Computational linguistics 26.1 (2000): 3-16.
#
# Updated 2014 to use DAWG as a mapping; see
# Kowaltowski, T.; CL. Lucchesi (1993), "Applications of finite automata representing large vocabularies",
# Software-Practice and Experience 1993
from collections import defaultdict
from functools import cached_property
16
17
class DawgNode:
    """A node in the directed acyclic word graph (DAWG).

    A node has a dict of labeled edges to other nodes. Two nodes are
    equivalent when they agree on finality and have identical edges leading
    to nodes with identical ids; __hash__ and __eq__ implement exactly that
    notion so nodes can be deduplicated via a python dictionary.
    """

    def __init__(self, dawg):
        # take the next free id from the owning Dawg and advance its counter
        self.id = dawg.next_id
        dawg.next_id += 1
        self.final = False
        self.edges = {}

        # filled in later by Dawg._linearize_edges: list of (string, next_state)
        self.linear_edges = None

    def __str__(self):
        parts = ["1" if self.final else "0"]
        for label, target in sorted(self.edges.items()):
            parts.append(label)
            parts.append(str(target.id))
        return "_".join(parts)
    __repr__ = __str__

    def _as_tuple(self):
        # canonical hashable representation: finality plus sorted edge labels
        # paired with the target node ids
        edge_ids = tuple(
            (label, target.id) for label, target in sorted(self.edges.items()))
        return (self.final, edge_ids)

    def __hash__(self):
        return hash(self._as_tuple())

    def __eq__(self, other):
        return self._as_tuple() == other._as_tuple()

    @cached_property
    def num_reachable_linear(self):
        """Number of different paths to final nodes reachable from here
        (following linear_edges; staying put counts if self is final)."""
        base = 1 if self.final else 0
        return base + sum(target.num_reachable_linear
                          for _label, target in self.linear_edges)
72
73
class Dawg:
    """Incrementally builds a minimal acyclic finite-state automaton (DAWG)
    mapping words to values. Words must be inserted in strictly increasing
    alphabetical order; call finish() once at the end to minimize,
    linearize and pack the graph into a bytestring."""

    def __init__(self):
        self.previous_word = ""
        self.next_id = 0
        self.root = DawgNode(self)

        # Here is a list of nodes that have not been checked for duplication.
        self.unchecked_nodes = []

        # To deduplicate, maintain a dictionary with
        # minimized_nodes[canonical_node] is canonical_node.
        # Based on __hash__ and __eq__, minimized_nodes[n] is the
        # canonical node equal to n.
        # In other words, self.minimized_nodes[x] == x for all nodes found in
        # the dict.
        self.minimized_nodes = {}

        # word: value mapping
        self.data = {}
        # value: word mapping
        self.inverse = {}

    def insert(self, word, value):
        """Add word mapping to value.

        Raises ValueError for non-ASCII words, out-of-order insertion, or
        duplicate values.
        """
        if not all(0 <= ord(c) < 128 for c in word):
            raise ValueError("Use 7-bit ASCII characters only")
        if word <= self.previous_word:
            raise ValueError("Error: Words must be inserted in alphabetical order.")
        if value in self.inverse:
            raise ValueError(f"value {value} is duplicate, got it for word {self.inverse[value]} and now {word}")

        # find common prefix between word and previous word
        common_prefix = 0
        for i in range(min(len(word), len(self.previous_word))):
            if word[i] != self.previous_word[i]:
                break
            common_prefix += 1

        # Check the unchecked_nodes for redundant nodes, proceeding from last
        # one down to the common prefix size. Then truncate the list at that
        # point.
        self._minimize(common_prefix)

        self.data[word] = value
        self.inverse[value] = word

        # add the suffix, starting from the correct node mid-way through the
        # graph
        if len(self.unchecked_nodes) == 0:
            node = self.root
        else:
            node = self.unchecked_nodes[-1][2]

        for letter in word[common_prefix:]:
            next_node = DawgNode(self)
            node.edges[letter] = next_node
            self.unchecked_nodes.append((node, letter, next_node))
            node = next_node

        node.final = True
        self.previous_word = word

    def finish(self):
        """Finish construction: minimize, linearize and pack the dawg.

        Returns (packed bytestring, value list indexed by word number,
        value -> word-number dict).
        """
        if not self.data:
            raise ValueError("need at least one word in the dawg")
        # minimize all unchecked_nodes
        self._minimize(0)

        self._linearize_edges()

        topoorder, linear_data, inverse = self._topological_order()
        return self.compute_packed(topoorder), linear_data, inverse

    def _minimize(self, down_to):
        # proceed from the leaf up to a certain point
        for i in range(len(self.unchecked_nodes) - 1, down_to - 1, -1):
            (parent, letter, child) = self.unchecked_nodes[i]
            if child in self.minimized_nodes:
                # replace the child with the previously encountered one
                parent.edges[letter] = self.minimized_nodes[child]
            else:
                # add the state to the minimized nodes.
                self.minimized_nodes[child] = child
            self.unchecked_nodes.pop()

    def _lookup(self, word):
        """ Return an integer 0 <= k < number of strings in dawg
        where word is the kth successful traversal of the dawg. """
        node = self.root
        skipped = 0  # keep track of number of final nodes that we skipped
        index = 0
        while index < len(word):
            for label, child in node.linear_edges:
                if word[index] == label[0]:
                    if word[index:index + len(label)] == label:
                        if node.final:
                            skipped += 1
                        index += len(label)
                        node = child
                        break
                    else:
                        # first character matched but the rest of the
                        # (multi-char) label did not: word is not in the dawg
                        return None
                # all words below this edge sort before word: skip them
                skipped += child.num_reachable_linear
            else:
                # no edge starts with word[index]
                return None
        return skipped

    def enum_all_nodes(self):
        """Yield every node reachable from the root, each exactly once."""
        stack = [self.root]
        done = set()
        while stack:
            node = stack.pop()
            if node.id in done:
                continue
            yield node
            done.add(node.id)
            for label, child in sorted(node.edges.items()):
                stack.append(child)

    def prettyprint(self):
        """Print a human-readable dump of all nodes and edges (debug aid)."""
        for node in sorted(self.enum_all_nodes(), key=lambda e: e.id):
            s_final = " final" if node.final else ""
            print(f"{node.id}: ({node}) {s_final}")
            for label, child in sorted(node.edges.items()):
                print(f"    {label} goto {child.id}")

    def _inverse_lookup(self, number):
        # NOTE(review): disabled — `pos` below is never initialized from
        # `number`. Kept only as the readable pure-python counterpart of the
        # module-level _inverse_lookup on the packed representation.
        assert 0, "not working in the current form, but keep it as the pure python version of compact lookup"
        result = []
        node = self.root
        while 1:
            if node.final:
                if pos == 0:
                    return "".join(result)
                pos -= 1
            for label, child in sorted(node.edges.items()):
                nextpos = pos - child.num_reachable_linear
                if nextpos < 0:
                    result.append(label)
                    node = child
                    break
                else:
                    pos = nextpos
            else:
                assert 0

    def _linearize_edges(self):
        # compute "linear" edges. the idea is that long chains of edges without
        # any of the intermediate states being final or any extra incoming or
        # outgoing edges can be represented by removing them, and
        # instead using longer strings as edge labels (instead of single
        # characters)
        incoming = defaultdict(list)
        nodes = sorted(self.enum_all_nodes(), key=lambda e: e.id)
        for node in nodes:
            for label, child in sorted(node.edges.items()):
                incoming[child].append(node)
        for node in nodes:
            node.linear_edges = []
            for label, child in sorted(node.edges.items()):
                s = [label]
                # extend the label while child is a pure pass-through node:
                # exactly one incoming edge, one outgoing edge, not final
                while len(child.edges) == 1 and len(incoming[child]) == 1 and not child.final:
                    (c, child), = child.edges.items()
                    s.append(c)
                node.linear_edges.append((''.join(s), child))

    def _topological_order(self):
        """Order nodes topologically, re-sort each node's linear_edges by
        target position, and number the words: returns (topoorder,
        value list indexed by word number, value -> index dict)."""
        # compute reachable linear nodes, and the set of incoming edges for each node
        order = []
        stack = [self.root]
        seen = set()
        while stack:
            # depth first traversal
            node = stack.pop()
            if node.id in seen:
                continue
            seen.add(node.id)
            order.append(node)
            for label, child in node.linear_edges:
                stack.append(child)

        # do a (slightly bad) topological sort
        incoming = defaultdict(set)
        for node in order:
            for label, child in node.linear_edges:
                incoming[child].add((label, node))
        no_incoming = [order[0]]
        topoorder = []
        positions = {}
        while no_incoming:
            node = no_incoming.pop()
            topoorder.append(node)
            positions[node] = len(topoorder)
            # use "reversed" to make sure that the linear_edges get reordered
            # from their alphabetical order as little as necessary (no_incoming
            # is LIFO)
            for label, child in reversed(node.linear_edges):
                incoming[child].discard((label, node))
                if not incoming[child]:
                    no_incoming.append(child)
                    del incoming[child]
        # check result
        assert set(topoorder) == set(order)
        assert len(set(topoorder)) == len(topoorder)

        for node in order:
            node.linear_edges.sort(key=lambda element: positions[element[1]])

        for node in order:
            for label, child in node.linear_edges:
                assert positions[child] > positions[node]
        # number the nodes. afterwards every input string in the set has a
        # unique number in the 0 <= number < len(data). We then put the data in
        # self.data into a linear list using these numbers as indexes.
        # (touching the cached_property here forces the reachability counts
        # to be computed for the whole graph before the lookups below)
        topoorder[0].num_reachable_linear
        linear_data = [None] * len(self.data)
        inverse = {} # maps value back to index
        for word, value in self.data.items():
            index = self._lookup(word)
            linear_data[index] = value
            inverse[value] = index

        return topoorder, linear_data, inverse

    def compute_packed(self, order):
        """Pack the nodes (given in topological order) into one bytestring.

        Iterates to a fixpoint: jump offsets are varint-encoded, so they can
        shrink as the layout gets tighter, which can in turn shrink the
        layout further.
        """
        def compute_chunk(node, offsets):
            """ compute the packed node/edge data for a node. result is a
            list of bytes as long as order. the jump distance calculations use
            the offsets dictionary to know where in the final big output
            bytestring the individual nodes will end up. """
            result = bytearray()
            offset = offsets[node]
            encode_varint_unsigned(number_add_bits(node.num_reachable_linear, node.final), result)
            if len(node.linear_edges) == 0:
                assert node.final
                encode_varint_unsigned(0, result) # add a 0 saying "done"
            prev_child_offset = offset + len(result)
            for edgeindex, (label, targetnode) in enumerate(node.linear_edges):
                label = label.encode('ascii')
                child_offset = offsets[targetnode]
                child_offset_difference = child_offset - prev_child_offset

                # edge info word: delta to child, "label is a single char"
                # flag, "this is the last edge" flag
                info = number_add_bits(child_offset_difference, len(label) == 1, edgeindex == len(node.linear_edges) - 1)
                if edgeindex == 0:
                    # info == 0 on the first edge is reserved as the
                    # "final node, no edges" terminator
                    assert info != 0
                encode_varint_unsigned(info, result)
                prev_child_offset = child_offset
                if len(label) > 1:
                    encode_varint_unsigned(len(label), result)
                result.extend(label)
            return result

        def compute_new_offsets(chunks, offsets):
            """ Given a list of chunks, compute the new offsets (by adding the
            chunk lengths together). Also check if we cannot shrink the output
            further because none of the node offsets are smaller now. if that's
            the case return None. """
            new_offsets = {}
            curr_offset = 0
            should_continue = False
            for node, result in zip(order, chunks):
                if curr_offset < offsets[node]:
                    # the new offset is below the current assumption, this
                    # means we can shrink the output more
                    should_continue = True
                new_offsets[node] = curr_offset
                curr_offset += len(result)
            if not should_continue:
                return None
            return new_offsets

        # assign initial offsets to every node
        offsets = {}
        for i, node in enumerate(order):
            # we don't know position of the edge yet, just use something big as
            # the starting position. we'll have to do further iterations anyway,
            # but the size is at least a lower limit then
            offsets[node] = i * 2 ** 30


        # due to the variable integer width encoding of edge targets we need to
        # run this to fixpoint. in the process we shrink the output more and
        # more until we can't any more. at any point we can stop and use the
        # output, but we might need padding zero bytes when joining the chunks
        # to have the correct jump distances
        last_offsets = None
        while 1:
            chunks = [compute_chunk(node, offsets) for node in order]
            last_offsets = offsets
            offsets = compute_new_offsets(chunks, offsets)
            if offsets is None: # couldn't shrink
                break

        # build the final packed string
        total_result = bytearray()
        for node, result in zip(order, chunks):
            node_offset = last_offsets[node]
            if node_offset > len(total_result):
                # need to pad to get the offsets correct
                padding = b"\x00" * (node_offset - len(total_result))
                total_result.extend(padding)
            assert node_offset == len(total_result)
            total_result.extend(result)
        return bytes(total_result)
377
378
# ______________________________________________________________________
# the following functions operate on the packed representation
381
def number_add_bits(x, *bits):
    """Append the given 0/1 bits to x, most significant first
    (each bit becomes the new lowest bit)."""
    result = x
    for bit in bits:
        assert bit == 0 or bit == 1
        result = (result << 1) | bit
    return result
387
def encode_varint_unsigned(i, res):
    """Append the LEB128 unsigned encoding of i to the bytearray res.

    Returns the number of bytes appended. Raises ValueError for negative i.
    https://en.wikipedia.org/wiki/LEB128 unsigned variant
    """
    if i < 0:
        raise ValueError("only positive numbers supported", i)
    startlen = len(res)
    while True:
        chunk = i & 0x7f
        i >>= 7
        if i:
            # more bytes follow: set the continuation bit
            res.append(chunk | 0x80)
        else:
            res.append(chunk)
            break
    return len(res) - startlen
403
def number_split_bits(x, n, acc=()):
    """Split the n lowest bits off x (inverse of number_add_bits).

    Returns the tuple (x >> n, b_{n-1}, ..., b_0): the remaining number
    followed by the n removed bits, most significant first. Generalized
    from the original hard-coded n in {1, 2} to any n >= 0.
    The acc parameter is unused; kept for interface compatibility.
    Raises ValueError for negative n.
    """
    if n < 0:
        raise ValueError("number of bits must be non-negative", n)
    return (x >> n,) + tuple((x >> shift) & 1 for shift in range(n - 1, -1, -1))
410
def decode_varint_unsigned(b, index=0):
    """Decode a LEB128 unsigned integer from b starting at index.

    Returns (value, index just past the encoding).
    """
    res = 0
    shift = 0
    while True:
        byte = b[index]
        index += 1
        res |= (byte & 0x7f) << shift
        if byte & 0x80 == 0:
            return res, index
        shift += 7
421
def decode_node(packed, node):
    """Decode the node header at offset node in packed.

    Returns (num_reachable_linear count, final flag, offset just past the
    header, i.e. where the edge records start).
    """
    header, after = decode_varint_unsigned(packed, node)
    count, is_final = number_split_bits(header, 1)
    return count, is_final, after
426
def decode_edge(packed, edgeindex, prev_child_offset, offset):
    """Decode one edge record at offset.

    Returns (absolute child offset, last-edge flag, label size in bytes,
    offset of the label characters). Raises KeyError when the first edge
    slot holds the 0 that terminates a final node with no outgoing edges.
    """
    info, offset = decode_varint_unsigned(packed, offset)
    if edgeindex == 0 and info == 0:
        raise KeyError # trying to decode past a final node
    delta, single_char, last_edge = number_split_bits(info, 2)
    child_offset = prev_child_offset + delta
    if single_char:
        size = 1
    else:
        size, offset = decode_varint_unsigned(packed, offset)
    return child_offset, last_edge, size, offset
438
439def _match_edge(packed, s, size, node_offset, stringpos):
440    if size > 1 and stringpos + size > len(s):
441        # past the end of the string, can't match
442        return False
443    for i in range(size):
444        if packed[node_offset + i] != s[stringpos + i]:
445            # if a subsequent char of an edge doesn't match, the word isn't in
446            # the dawg
447            if i > 0:
448                raise KeyError
449            return False
450    return True
451
def lookup(packed, data, s):
    """Map the bytestring s to its value via the packed dawg."""
    index = _lookup(packed, s)
    return data[index]
454
def _lookup(packed, s):
    """Return an integer 0 <= k < number of strings in the packed dawg
    where the bytestring s is the kth successful traversal.

    Raises KeyError if s is not in the dawg.
    (Fix: removed the unused local `false = False` left over from debugging.)
    """
    stringpos = 0
    node_offset = 0
    skipped = 0  # keep track of number of final nodes that we skipped
    while stringpos < len(s):
        _, final, edge_offset = decode_node(packed, node_offset)
        prev_child_offset = edge_offset
        edgeindex = 0
        while 1:
            child_offset, last_edge, size, edgelabel_chars_offset = decode_edge(packed, edgeindex, prev_child_offset, edge_offset)
            edgeindex += 1
            prev_child_offset = child_offset
            if _match_edge(packed, s, size, edgelabel_chars_offset, stringpos):
                # match: consume the label and descend into the child
                if final:
                    skipped += 1
                stringpos += size
                node_offset = child_offset
                break
            if last_edge:
                raise KeyError
            # no match: all words below this edge sort before s, count them
            descendant_count, _, _ = decode_node(packed, child_offset)
            skipped += descendant_count
            edge_offset = edgelabel_chars_offset + size
    _, final, _ = decode_node(packed, node_offset)
    if final:
        return skipped
    raise KeyError
486
def inverse_lookup(packed, inverse, x):
    """Map a value x back to its word (as bytes) via the packed dawg."""
    return _inverse_lookup(packed, inverse[x])
490
def _inverse_lookup(packed, pos):
    """Return the pos-th word (as bytes) of the packed dawg, where words are
    numbered in the same order _lookup assigns.

    Raises KeyError if pos is out of range.
    (Fix: removed the unreachable `else: raise KeyError` attached to the
    inner `while 1` loop — a `while 1` never exits without break, so the
    loop's else clause could never run.)
    """
    result = bytearray()
    node_offset = 0
    while 1:
        node_count, final, edge_offset = decode_node(packed, node_offset)
        if final:
            if pos == 0:
                return bytes(result)
            # this final node accounts for one word before pos
            pos -= 1
        prev_child_offset = edge_offset
        edgeindex = 0
        while 1:
            child_offset, last_edge, size, edgelabel_chars_offset = decode_edge(packed, edgeindex, prev_child_offset, edge_offset)
            edgeindex += 1
            prev_child_offset = child_offset
            descendant_count, _, _ = decode_node(packed, child_offset)
            nextpos = pos - descendant_count
            if nextpos < 0:
                # the pos-th word goes through this edge: emit its label
                # and descend
                assert edgelabel_chars_offset >= 0
                result.extend(packed[edgelabel_chars_offset: edgelabel_chars_offset + size])
                node_offset = child_offset
                break
            elif not last_edge:
                # skip this edge and all the words below it
                pos = nextpos
                edge_offset = edgelabel_chars_offset + size
            else:
                raise KeyError
520
521
def build_compression_dawg(ucdata):
    """Build a packed dawg from (name, value) pairs and sanity-check that
    both lookup directions round-trip every entry.

    Returns (packed bytestring, value list indexed by word number).
    """
    dawg = Dawg()
    ucdata.sort()
    for name, value in ucdata:
        dawg.insert(name, value)
    packed, pos_to_code, reversedict = dawg.finish()
    print("size of dawg [KiB]", round(len(packed) / 1024, 2))
    # check that lookup and inverse_lookup work correctly on the input data
    for name, value in ucdata:
        assert lookup(packed, pos_to_code, name.encode('ascii')) == value
        assert inverse_lookup(packed, reversedict, value) == name.encode('ascii')
    return packed, pos_to_code
534