• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "utils/tflite/skipgram_finder.h"

#include <cctype>
#include <cstring>
#include <deque>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "utils/strings/utf8.h"
#include "utils/utf8/unilib-common.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/string_util.h"
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
36 using ::tflite::StringRef;
37 
PreprocessToken(std::string & token)38 void PreprocessToken(std::string& token) {
39   size_t in = 0;
40   size_t out = 0;
41   while (in < token.size()) {
42     const char* in_data = token.data() + in;
43     const int n = GetNumBytesForUTF8Char(in_data);
44     if (n < 0 || n > token.size() - in) {
45       // Invalid Utf8 sequence.
46       break;
47     }
48     in += n;
49     const char32 r = ValidCharToRune(in_data);
50     if (IsPunctuation(r)) {
51       continue;
52     }
53     const char32 rl = ToLower(r);
54     char output_buffer[4];
55     int encoded_length = ValidRuneToChar(rl, output_buffer);
56     if (encoded_length > n) {
57       // This is a hack, but there are exactly two unicode characters whose
58       // lowercase versions have longer UTF-8 encodings (0x23a to 0x2c65,
59       // 0x23e to 0x2c66).  So, to avoid sizing issues, they're not lowercased.
60       encoded_length = ValidRuneToChar(r, output_buffer);
61     }
62     memcpy(token.data() + out, output_buffer, encoded_length);
63     out += encoded_length;
64   }
65 
66   size_t remaining = token.size() - in;
67   if (remaining > 0) {
68     memmove(token.data() + out, token.data() + in, remaining);
69     out += remaining;
70   }
71   token.resize(out);
72 }
73 
74 }  // namespace
75 
AddSkipgram(const std::string & skipgram,int category)76 void SkipgramFinder::AddSkipgram(const std::string& skipgram, int category) {
77   std::vector<std::string> tokens = absl::StrSplit(skipgram, ' ');
78 
79   // Store the skipgram in a trie-like structure that uses tokens as the
80   // edge labels, instead of characters.  Each node represents a skipgram made
81   // from the tokens used to reach the node, and stores the categories the
82   // skipgram is associated with.
83   TrieNode* cur = &skipgram_trie_;
84   for (auto& token : tokens) {
85     if (absl::EndsWith(token, ".*")) {
86       token.resize(token.size() - 2);
87       PreprocessToken(token);
88       auto iter = cur->prefix_to_node.find(token);
89       if (iter != cur->prefix_to_node.end()) {
90         cur = &iter->second;
91       } else {
92         cur = &cur->prefix_to_node
93                    .emplace(std::piecewise_construct,
94                             std::forward_as_tuple(token), std::make_tuple<>())
95                    .first->second;
96       }
97       continue;
98     }
99 
100     PreprocessToken(token);
101     auto iter = cur->token_to_node.find(token);
102     if (iter != cur->token_to_node.end()) {
103       cur = &iter->second;
104     } else {
105       cur = &cur->token_to_node
106                  .emplace(std::piecewise_construct,
107                           std::forward_as_tuple(token), std::make_tuple<>())
108                  .first->second;
109     }
110   }
111   cur->categories.insert(category);
112 }
113 
FindSkipgrams(const std::string & input) const114 absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
115     const std::string& input) const {
116   std::vector<std::string> tokens = absl::StrSplit(input, ' ');
117   std::vector<absl::string_view> sv_tokens;
118   sv_tokens.reserve(tokens.size());
119   for (auto& token : tokens) {
120     PreprocessToken(token);
121     sv_tokens.emplace_back(token.data(), token.size());
122   }
123   return FindSkipgrams(sv_tokens);
124 }
125 
FindSkipgrams(const std::vector<StringRef> & tokens) const126 absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
127     const std::vector<StringRef>& tokens) const {
128   std::vector<absl::string_view> sv_tokens;
129   sv_tokens.reserve(tokens.size());
130   for (auto& token : tokens) {
131     sv_tokens.emplace_back(token.str, token.len);
132   }
133   return FindSkipgrams(sv_tokens);
134 }
135 
FindSkipgrams(const std::vector<absl::string_view> & tokens) const136 absl::flat_hash_set<int> SkipgramFinder::FindSkipgrams(
137     const std::vector<absl::string_view>& tokens) const {
138   absl::flat_hash_set<int> categories;
139 
140   // Tracks skipgram prefixes and the index of their last token.
141   std::deque<std::pair<int, const TrieNode*>> indices_and_skipgrams;
142 
143   for (int token_i = 0; token_i < tokens.size(); token_i++) {
144     const absl::string_view& token = tokens[token_i];
145 
146     std::vector<absl::string_view> token_prefixes;
147     {
148       const char* s = token.data();
149       int n = token.size();
150       while (n > 0) {
151         const int rlen = GetNumBytesForUTF8Char(s);
152         if (rlen < 0 || rlen > n) {
153           // Invalid UTF8.
154           break;
155         }
156         n -= rlen;
157         s += rlen;
158         token_prefixes.emplace_back(token.data(), token.size() - n);
159       }
160     }
161 
162     // Drop any skipgrams prefixes which would skip more than `max_skip_size_`
163     // tokens between the end of the prefix and the current token.
164     while (!indices_and_skipgrams.empty()) {
165       if (indices_and_skipgrams.front().first + max_skip_size_ + 1 < token_i) {
166         indices_and_skipgrams.pop_front();
167       } else {
168         break;
169       }
170     }
171 
172     // Check if we can form a valid skipgram prefix (or skipgram) by adding
173     // the current token to any of the existing skipgram prefixes, or
174     // if the current token is a valid skipgram prefix (or skipgram).
175     size_t size = indices_and_skipgrams.size();
176     for (size_t skipgram_i = 0; skipgram_i <= size; skipgram_i++) {
177       const auto& node = skipgram_i < size
178                              ? *indices_and_skipgrams[skipgram_i].second
179                              : skipgram_trie_;
180 
181       auto iter = node.token_to_node.find(token);
182       if (iter != node.token_to_node.end()) {
183         categories.insert(iter->second.categories.begin(),
184                           iter->second.categories.end());
185         indices_and_skipgrams.push_back(std::make_pair(token_i, &iter->second));
186       }
187 
188       for (const auto& token_prefix : token_prefixes) {
189         auto iter = node.prefix_to_node.find(token_prefix);
190         if (iter != node.prefix_to_node.end()) {
191           categories.insert(iter->second.categories.begin(),
192                             iter->second.categories.end());
193           indices_and_skipgrams.push_back(
194               std::make_pair(token_i, &iter->second));
195         }
196       }
197     }
198   }
199 
200   return categories;
201 }
202 
203 }  // namespace libtextclassifier3
204