1from __future__ import absolute_import, division, unicode_literals 2from six import text_type 3 4from bisect import bisect_left 5 6from ._base import Trie as ABCTrie 7 8 9class Trie(ABCTrie): 10 def __init__(self, data): 11 if not all(isinstance(x, text_type) for x in data.keys()): 12 raise TypeError("All keys must be strings") 13 14 self._data = data 15 self._keys = sorted(data.keys()) 16 self._cachestr = "" 17 self._cachepoints = (0, len(data)) 18 19 def __contains__(self, key): 20 return key in self._data 21 22 def __len__(self): 23 return len(self._data) 24 25 def __iter__(self): 26 return iter(self._data) 27 28 def __getitem__(self, key): 29 return self._data[key] 30 31 def keys(self, prefix=None): 32 if prefix is None or prefix == "" or not self._keys: 33 return set(self._keys) 34 35 if prefix.startswith(self._cachestr): 36 lo, hi = self._cachepoints 37 start = i = bisect_left(self._keys, prefix, lo, hi) 38 else: 39 start = i = bisect_left(self._keys, prefix) 40 41 keys = set() 42 if start == len(self._keys): 43 return keys 44 45 while self._keys[i].startswith(prefix): 46 keys.add(self._keys[i]) 47 i += 1 48 49 self._cachestr = prefix 50 self._cachepoints = (start, i) 51 52 return keys 53 54 def has_keys_with_prefix(self, prefix): 55 if prefix in self._data: 56 return True 57 58 if prefix.startswith(self._cachestr): 59 lo, hi = self._cachepoints 60 i = bisect_left(self._keys, prefix, lo, hi) 61 else: 62 i = bisect_left(self._keys, prefix) 63 64 if i == len(self._keys): 65 return False 66 67 return self._keys[i].startswith(prefix) 68