1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
17 #define TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
18
19 #include <functional>
20 #include <vector>
21
22 #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/config_generated.h"
23 #include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/utils.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace custom {
28 namespace sentencepiece {
29
// A trie node is either an intermediate node or a leaf node.
// A leaf node stores, as an int, the id of the matched string. The id is
// encoded in the lower 31 bits, so at most 2^31 distinct ids can be stored.
// An intermediate node has an associated label and an offset to its children.
// The label is encoded in the least significant byte and must equal the input
// byte during matching.
37
38 // A memory mappable trie, compatible with Darts::DoubleArray.
39 class DoubleArrayTrie {
40 public:
41 struct Match {
MatchMatch42 Match() {}
MatchMatch43 Match(int id, int match_length) : id(id), match_length(match_length) {}
44 int id = -1;
45 int match_length = -1;
emptyMatch46 bool empty() const { return match_length == -1; }
47 bool operator==(const Match& m) const {
48 return m.id == id && m.match_length == match_length;
49 }
50 };
51
52 // nodes and nodes_length specify the array of the nodes of the trie.
DoubleArrayTrie(const flatbuffers::Vector<uint32_t> * nodes)53 explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes)
54 : nodes_(nodes) {}
55
56 // Finds matches that are prefixes of a string.
57 template <typename callback>
58 void IteratePrefixMatches(const utils::string_view& input,
59 callback update_fn) const;
60
61 // Finds the longest prefix match of a string.
LongestPrefixMatch(const utils::string_view & input)62 Match LongestPrefixMatch(const utils::string_view& input) const {
63 Match match;
64 IteratePrefixMatches(input, [&match](const Match& m) { match = m; });
65 return match;
66 }
67
68 private:
69 // Returns whether a node as a leaf as a child.
has_leaf(uint32_t i)70 bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; }
71
72 // Returns a value associated with a node. Available when a node is a leaf.
value(uint32_t i)73 int value(uint32_t i) const {
74 return static_cast<int>(((*nodes_)[i]) & 0x7fffffff);
75 }
76
77 // Returns a label associated with a node.
78 // A leaf node will have the MSB set and thus return an invalid label.
label(uint32_t i)79 int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; }
80
81 // Returns offset to children.
offset(uint32_t i)82 int32_t offset(uint32_t i) const {
83 const uint32_t node = (*nodes_)[i];
84 return (node >> 10) << ((node & 0x200) >> 6);
85 }
86
87 const flatbuffers::Vector<uint32_t>* nodes_;
88 };
89
90 template <typename callback>
IteratePrefixMatches(const utils::string_view & input,callback update_fn)91 void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input,
92 callback update_fn) const {
93 if (nodes_->size() == 0) {
94 return;
95 }
96 uint32_t pos = offset(0);
97 for (int i = 0; i < input.length(); ++i) {
98 pos ^= static_cast<unsigned char>(input.at(i));
99 if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) {
100 // No match, exit.
101 return;
102 }
103 const bool node_has_leaf = has_leaf(pos);
104 pos ^= offset(pos);
105 if (pos < 0 || pos >= nodes_->size()) {
106 // We can get here only if the structure is corrupted.
107 return;
108 }
109 if (node_has_leaf) {
110 update_fn(Match(value(pos), i + 1));
111 }
112 }
113 }
114
115 } // namespace sentencepiece
116 } // namespace custom
117 } // namespace ops
118 } // namespace tflite
119
120 #endif // TENSORFLOW_LITE_SUPPORT_CUSTOM_OPS_KERNEL_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
121