1 #ifndef MARISA_GRIMOIRE_TRIE_KEY_H_ 2 #define MARISA_GRIMOIRE_TRIE_KEY_H_ 3 4 #include "marisa/base.h" 5 6 namespace marisa { 7 namespace grimoire { 8 namespace trie { 9 10 class Key { 11 public: Key()12 Key() : ptr_(NULL), length_(0), union_(), id_(0) { 13 union_.terminal = 0; 14 } Key(const Key & entry)15 Key(const Key &entry) 16 : ptr_(entry.ptr_), length_(entry.length_), 17 union_(entry.union_), id_(entry.id_) {} 18 19 Key &operator=(const Key &entry) { 20 ptr_ = entry.ptr_; 21 length_ = entry.length_; 22 union_ = entry.union_; 23 id_ = entry.id_; 24 return *this; 25 } 26 27 char operator[](std::size_t i) const { 28 MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); 29 return ptr_[i]; 30 } 31 substr(std::size_t pos,std::size_t length)32 void substr(std::size_t pos, std::size_t length) { 33 MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); 34 MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); 35 MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); 36 ptr_ += pos; 37 length_ = (UInt32)length; 38 } 39 set_str(const char * ptr,std::size_t length)40 void set_str(const char *ptr, std::size_t length) { 41 MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); 42 MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 43 ptr_ = ptr; 44 length_ = (UInt32)length; 45 } set_weight(float weight)46 void set_weight(float weight) { 47 union_.weight = weight; 48 } set_terminal(std::size_t terminal)49 void set_terminal(std::size_t terminal) { 50 MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 51 union_.terminal = (UInt32)terminal; 52 } set_id(std::size_t id)53 void set_id(std::size_t id) { 54 MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 55 id_ = (UInt32)id; 56 } 57 ptr()58 const char *ptr() const { 59 return ptr_; 60 } length()61 std::size_t length() const { 62 return length_; 63 } weight()64 float weight() const { 65 return union_.weight; 66 } terminal()67 std::size_t terminal() const { 68 return union_.terminal; 69 } id()70 std::size_t id() const { 71 return id_; 72 } 73 74 private: 75 const char *ptr_; 76 UInt32 length_; 77 union Union { 78 float weight; 79 UInt32 terminal; 80 } union_; 81 UInt32 id_; 82 }; 83 84 inline bool operator==(const Key &lhs, const Key &rhs) { 85 if (lhs.length() != rhs.length()) { 86 return false; 87 } 88 for (std::size_t i = 0; i < lhs.length(); ++i) { 89 if (lhs[i] != rhs[i]) { 90 return false; 91 } 92 } 93 return true; 94 } 95 96 inline bool operator!=(const Key &lhs, const Key &rhs) { 97 return !(lhs == rhs); 98 } 99 100 inline bool operator<(const Key &lhs, const Key &rhs) { 101 for (std::size_t i = 0; i < lhs.length(); ++i) { 102 if (i == rhs.length()) { 103 return false; 104 } 105 if (lhs[i] != rhs[i]) { 106 return (UInt8)lhs[i] < (UInt8)rhs[i]; 107 } 108 } 109 return lhs.length() < rhs.length(); 110 } 111 112 inline bool operator>(const Key &lhs, const Key &rhs) { 113 return rhs < lhs; 114 } 115 116 class ReverseKey { 117 public: ReverseKey()118 ReverseKey() : ptr_(NULL), length_(0), union_(), id_(0) { 119 union_.terminal = 0; 120 } ReverseKey(const ReverseKey & entry)121 ReverseKey(const ReverseKey &entry) 122 : ptr_(entry.ptr_), length_(entry.length_), 123 union_(entry.union_), id_(entry.id_) {} 124 125 ReverseKey &operator=(const ReverseKey &entry) { 126 ptr_ = entry.ptr_; 127 length_ = entry.length_; 128 union_ = entry.union_; 129 id_ = entry.id_; 130 return *this; 131 } 132 133 char operator[](std::size_t i) const { 134 MARISA_DEBUG_IF(i >= length_, MARISA_BOUND_ERROR); 135 return *(ptr_ - i - 1); 136 } 137 substr(std::size_t pos,std::size_t length)138 void substr(std::size_t pos, std::size_t length) { 139 MARISA_DEBUG_IF(pos > length_, MARISA_BOUND_ERROR); 140 MARISA_DEBUG_IF(length > length_, MARISA_BOUND_ERROR); 141 MARISA_DEBUG_IF(pos > (length_ - length), MARISA_BOUND_ERROR); 142 ptr_ -= pos; 143 length_ = (UInt32)length; 144 } 145 set_str(const char * ptr,std::size_t length)146 void set_str(const char *ptr, std::size_t length) { 147 MARISA_DEBUG_IF((ptr == NULL) && (length != 0), MARISA_NULL_ERROR); 148 MARISA_DEBUG_IF(length > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 149 ptr_ = ptr + length; 150 length_ = (UInt32)length; 151 } set_weight(float weight)152 void set_weight(float weight) { 153 union_.weight = weight; 154 } set_terminal(std::size_t terminal)155 void set_terminal(std::size_t terminal) { 156 MARISA_DEBUG_IF(terminal > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 157 union_.terminal = (UInt32)terminal; 158 } set_id(std::size_t id)159 void set_id(std::size_t id) { 160 MARISA_DEBUG_IF(id > MARISA_UINT32_MAX, MARISA_SIZE_ERROR); 161 id_ = (UInt32)id; 162 } 163 ptr()164 const char *ptr() const { 165 return ptr_ - length_; 166 } length()167 std::size_t length() const { 168 return length_; 169 } weight()170 float weight() const { 171 return union_.weight; 172 } terminal()173 std::size_t terminal() const { 174 return union_.terminal; 175 } id()176 std::size_t id() const { 177 return id_; 178 } 179 180 private: 181 const char *ptr_; 182 UInt32 length_; 183 union Union { 184 float weight; 185 UInt32 terminal; 186 } union_; 187 UInt32 id_; 188 }; 189 190 inline bool operator==(const ReverseKey &lhs, const ReverseKey &rhs) { 191 if (lhs.length() != rhs.length()) { 192 return false; 193 } 194 for (std::size_t i = 0; i < lhs.length(); ++i) { 195 if (lhs[i] != rhs[i]) { 196 return false; 197 } 198 } 199 return true; 200 } 201 202 inline bool operator!=(const ReverseKey &lhs, const ReverseKey &rhs) { 203 return !(lhs == rhs); 204 } 205 206 inline bool operator<(const ReverseKey &lhs, const ReverseKey &rhs) { 207 for (std::size_t i = 0; i < lhs.length(); ++i) { 208 if (i == rhs.length()) { 209 return false; 210 } 211 if (lhs[i] != rhs[i]) { 212 return (UInt8)lhs[i] < (UInt8)rhs[i]; 213 } 214 } 215 return lhs.length() < rhs.length(); 216 } 217 218 inline bool operator>(const ReverseKey &lhs, const ReverseKey &rhs) { 219 return rhs < lhs; 220 } 221 222 } // namespace trie 223 } // namespace grimoire 224 } // namespace marisa 225 226 #endif // MARISA_GRIMOIRE_TRIE_KEY_H_ 227