/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ #define LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_ #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { // SkipgramFinder finds skipgrams in strings. // // To use: First, add skipgrams using AddSkipgram() - each skipgram is // associated with some category. Then, call FindSkipgrams() on a string, // which will return the set of categories of the skipgrams in the string. // // Both the skipgrams and the input strings will be tokenzied by splitting // on spaces. Additionally, the tokens will be lowercased and have any // trailing punctuation removed. class SkipgramFinder { public: explicit SkipgramFinder(int max_skip_size) : max_skip_size_(max_skip_size) {} // Adds a skipgram that SkipgramFinder should look for in input strings. // Tokens may use the regex '.*' as a suffix. void AddSkipgram(const std::string& skipgram, int category); // Find all of the skipgrams in `input`, and return their categories. absl::flat_hash_set FindSkipgrams(const std::string& input) const; // Find all of the skipgrams in `tokens`, and return their categories. absl::flat_hash_set FindSkipgrams( const std::vector& tokens) const; absl::flat_hash_set FindSkipgrams( const std::vector<::tflite::StringRef>& tokens) const; private: struct TrieNode { absl::flat_hash_set categories; // Maps tokens to the next node in the trie. absl::flat_hash_map token_to_node; // Maps token prefixes (.*) to the next node in the trie. absl::flat_hash_map prefix_to_node; }; TrieNode skipgram_trie_; int max_skip_size_; }; } // namespace libtextclassifier3 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_SKIPGRAM_FINDER_H_