1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 18 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 19 20 #include <string> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/container/flat_hash_set.h" 25 #include "absl/strings/string_view.h" 26 #include "tensorflow/lite/string_util.h" 27 28 namespace libtextclassifier3 { 29 30 // SkipgramFinder finds skipgrams in strings. 31 // 32 // To use: First, add skipgrams using AddSkipgram() - each skipgram is 33 // associated with some category. Then, call FindSkipgrams() on a string, 34 // which will return the set of categories of the skipgrams in the string. 35 // 36 // Both the skipgrams and the input strings will be tokenzied by splitting 37 // on spaces. Additionally, the tokens will be lowercased and have any 38 // trailing punctuation removed. 39 class SkipgramFinder { 40 public: SkipgramFinder(int max_skip_size)41 explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {} 42 43 // Adds a skipgram that SkipgramFinder should look for in input strings. 44 // Tokens may use the regex '.*' as a suffix. 45 void AddSkipgram(const std::string& skipgram, int category); 46 47 // Find all of the skipgrams in `input`, and return their categories. 48 absl::flat_hash_set<int> FindSkipgrams(const std::string& input) const; 49 50 // Find all of the skipgrams in `tokens`, and return their categories. 51 absl::flat_hash_set<int> FindSkipgrams( 52 const std::vector<absl::string_view>& tokens) const; 53 absl::flat_hash_set<int> FindSkipgrams( 54 const std::vector<::tflite::StringRef>& tokens) const; 55 56 private: 57 struct TrieNode { 58 absl::flat_hash_set<int> categories; 59 // Maps tokens to the next node in the trie. 60 absl::flat_hash_map<std::string, TrieNode> token_to_node; 61 // Maps token prefixes (<prefix>.*) to the next node in the trie. 62 absl::flat_hash_map<std::string, TrieNode> prefix_to_node; 63 }; 64 65 TrieNode skipgram_trie_; 66 int max_skip_size_; 67 }; 68 69 } // namespace libtextclassifier3 70 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ 71