1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ 19 20 #include <functional> 21 #include <vector> 22 23 #include "utils/base/endian.h" 24 #include "utils/base/integral_types.h" 25 #include "utils/sentencepiece/matcher.h" 26 #include "utils/strings/stringpiece.h" 27 28 namespace libtextclassifier3 { 29 30 // A trie node specifies a node in the tree, either an intermediate node or 31 // a leaf node. 32 // A leaf node contains the id as an int of the string match. This id is encoded 33 // in the lower 30 bits, thus the number of distinct ids is 2^30. 34 // An intermediate node has an associated label and an offset to it's children. 35 // The label is encoded in the least significant byte and must match the input 36 // character during matching. 37 // We account for endianness when using the node values, as they are serialized 38 // (in little endian) as bytes in the flatbuffer model. 39 typedef uint32 TrieNode; 40 41 // A memory mappable trie, compatible with Darts::DoubleArray. 42 class DoubleArrayTrie : public SentencePieceMatcher { 43 public: 44 // nodes and nodes_length specify the array of the nodes of the trie. DoubleArrayTrie(const TrieNode * nodes,const int nodes_length)45 DoubleArrayTrie(const TrieNode* nodes, const int nodes_length) 46 : nodes_(nodes), nodes_length_(nodes_length) {} 47 48 // Find matches that are prefixes of a string. 49 bool FindAllPrefixMatches(StringPiece input, 50 std::vector<TrieMatch>* matches) const override; 51 // Find the longest prefix match of a string. 52 bool LongestPrefixMatch(StringPiece input, 53 TrieMatch* longest_match) const override; 54 55 private: 56 // Returns whether a node as a leaf as a child. has_leaf(uint32 i)57 bool has_leaf(uint32 i) const { return nodes_[i] & 0x100; } 58 59 // Available when a node is a leaf. value(uint32 i)60 int value(uint32 i) const { 61 return static_cast<int>(LittleEndian::ToHost32(nodes_[i]) & 0x7fffffff); 62 } 63 64 // Label associated with a node. 65 // A leaf node will have the MSB set and thus return an invalid label. label(uint32 i)66 uint32 label(uint32 i) const { 67 return LittleEndian::ToHost32(nodes_[i]) & 0x800000ff; 68 } 69 70 // Returns offset to children. offset(uint32 i)71 uint32 offset(uint32 i) const { 72 const uint32 node = LittleEndian::ToHost32(nodes_[i]); 73 return (node >> 10) << ((node & 0x200) >> 6); 74 } 75 76 bool GatherPrefixMatches( 77 StringPiece input, const std::function<void(TrieMatch)>& update_fn) const; 78 79 const TrieNode* nodes_; 80 const int nodes_length_; 81 }; 82 83 } // namespace libtextclassifier3 84 85 #endif // LIBTEXTCLASSIFIER_UTILS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ 86