1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/ngram-model.h"
18
19 #include <algorithm>
20
21 #include "actions/feature-processor.h"
22 #include "utils/hash/farmhash.h"
23 #include "utils/strings/stringpiece.h"
24
25 namespace libtextclassifier3 {
26 namespace {
27
28 // An iterator to iterate over the initial tokens of the n-grams of a model.
29 class FirstTokenIterator
30 : public std::iterator<std::random_access_iterator_tag,
31 /*value_type=*/uint32, /*difference_type=*/ptrdiff_t,
32 /*pointer=*/const uint32*,
33 /*reference=*/uint32&> {
34 public:
FirstTokenIterator(const NGramLinearRegressionModel * model,int index)35 explicit FirstTokenIterator(const NGramLinearRegressionModel* model,
36 int index)
37 : model_(model), index_(index) {}
38
operator ++()39 FirstTokenIterator& operator++() {
40 index_++;
41 return *this;
42 }
operator +=(ptrdiff_t dist)43 FirstTokenIterator& operator+=(ptrdiff_t dist) {
44 index_ += dist;
45 return *this;
46 }
operator -(const FirstTokenIterator & other_it) const47 ptrdiff_t operator-(const FirstTokenIterator& other_it) const {
48 return index_ - other_it.index_;
49 }
operator *() const50 uint32 operator*() const {
51 const uint32 token_offset = (*model_->ngram_start_offsets())[index_];
52 return (*model_->hashed_ngram_tokens())[token_offset];
53 }
index() const54 int index() const { return index_; }
55
56 private:
57 const NGramLinearRegressionModel* model_;
58 int index_;
59 };
60
61 } // anonymous namespace
62
Create(const UniLib * unilib,const NGramLinearRegressionModel * model,const Tokenizer * tokenizer)63 std::unique_ptr<NGramSensitiveModel> NGramSensitiveModel::Create(
64 const UniLib* unilib, const NGramLinearRegressionModel* model,
65 const Tokenizer* tokenizer) {
66 if (model == nullptr) {
67 return nullptr;
68 }
69 if (tokenizer == nullptr && model->tokenizer_options() == nullptr) {
70 TC3_LOG(ERROR) << "No tokenizer options specified.";
71 return nullptr;
72 }
73 return std::unique_ptr<NGramSensitiveModel>(
74 new NGramSensitiveModel(unilib, model, tokenizer));
75 }
76
NGramSensitiveModel(const UniLib * unilib,const NGramLinearRegressionModel * model,const Tokenizer * tokenizer)77 NGramSensitiveModel::NGramSensitiveModel(
78 const UniLib* unilib, const NGramLinearRegressionModel* model,
79 const Tokenizer* tokenizer)
80 : model_(model) {
81 // Create new tokenizer if options are specified, reuse feature processor
82 // tokenizer otherwise.
83 if (model->tokenizer_options() != nullptr) {
84 owned_tokenizer_ = CreateTokenizer(model->tokenizer_options(), unilib);
85 tokenizer_ = owned_tokenizer_.get();
86 } else {
87 tokenizer_ = tokenizer;
88 }
89 }
90
91 // Returns whether a given n-gram matches the token stream.
IsNGramMatch(const uint32 * tokens,size_t num_tokens,const uint32 * ngram_tokens,size_t num_ngram_tokens,int max_skips) const92 bool NGramSensitiveModel::IsNGramMatch(const uint32* tokens, size_t num_tokens,
93 const uint32* ngram_tokens,
94 size_t num_ngram_tokens,
95 int max_skips) const {
96 int token_idx = 0, ngram_token_idx = 0, skip_remain = 0;
97 for (; token_idx < num_tokens && ngram_token_idx < num_ngram_tokens;) {
98 if (tokens[token_idx] == ngram_tokens[ngram_token_idx]) {
99 // Token matches. Advance both and reset the skip budget.
100 ++token_idx;
101 ++ngram_token_idx;
102 skip_remain = max_skips;
103 } else if (skip_remain > 0) {
104 // No match, but we have skips left, so just advance over the token.
105 ++token_idx;
106 skip_remain--;
107 } else {
108 // No match and we're out of skips. Reject.
109 return false;
110 }
111 }
112 return ngram_token_idx == num_ngram_tokens;
113 }
114
115 // Calculates the total number of skip-grams that can be created for a stream
116 // with the given number of tokens.
GetNumSkipGrams(int num_tokens,int max_ngram_length,int max_skips)117 uint64 NGramSensitiveModel::GetNumSkipGrams(int num_tokens,
118 int max_ngram_length,
119 int max_skips) {
120 // Start with unigrams.
121 uint64 total = num_tokens;
122 for (int ngram_len = 2;
123 ngram_len <= max_ngram_length && ngram_len <= num_tokens; ++ngram_len) {
124 // We can easily compute the expected length of the n-gram (with skips),
125 // but it doesn't account for the fact that they may be longer than the
126 // input and should be pruned.
127 // Instead, we iterate over the distribution of effective n-gram lengths
128 // and add each length individually.
129 const int num_gaps = ngram_len - 1;
130 const int len_min = ngram_len;
131 const int len_max = ngram_len + num_gaps * max_skips;
132 const int len_mid = (len_max + len_min) / 2;
133 for (int len_i = len_min; len_i <= len_max; ++len_i) {
134 if (len_i > num_tokens) continue;
135 const int num_configs_of_len_i =
136 len_i <= len_mid ? len_i - len_min + 1 : len_max - len_i + 1;
137 const int num_start_offsets = num_tokens - len_i + 1;
138 total += num_configs_of_len_i * num_start_offsets;
139 }
140 }
141 return total;
142 }
143
GetFirstTokenMatches(uint32 token_hash) const144 std::pair<int, int> NGramSensitiveModel::GetFirstTokenMatches(
145 uint32 token_hash) const {
146 const int num_ngrams = model_->ngram_weights()->size();
147 const auto start_it = FirstTokenIterator(model_, 0);
148 const auto end_it = FirstTokenIterator(model_, num_ngrams);
149 const int start = std::lower_bound(start_it, end_it, token_hash).index();
150 const int end = std::upper_bound(start_it, end_it, token_hash).index();
151 return std::make_pair(start, end);
152 }
153
Eval(const UnicodeText & text) const154 std::pair<bool, float> NGramSensitiveModel::Eval(
155 const UnicodeText& text) const {
156 const std::vector<Token> raw_tokens = tokenizer_->Tokenize(text);
157
158 // If we have no tokens, then just bail early.
159 if (raw_tokens.empty()) {
160 return std::make_pair(false, model_->default_token_weight());
161 }
162
163 // Hash the tokens.
164 std::vector<uint32> tokens;
165 tokens.reserve(raw_tokens.size());
166 for (const Token& raw_token : raw_tokens) {
167 tokens.push_back(tc3farmhash::Fingerprint32(raw_token.value.data(),
168 raw_token.value.length()));
169 }
170
171 // Calculate the total number of skip-grams that can be generated for the
172 // input text.
173 const uint64 num_candidates = GetNumSkipGrams(
174 tokens.size(), model_->max_denom_ngram_length(), model_->max_skips());
175
176 // For each token, see whether it denotes the start of an n-gram in the model.
177 int num_matches = 0;
178 float weight_matches = 0.f;
179 for (size_t start_i = 0; start_i < tokens.size(); ++start_i) {
180 const std::pair<int, int> ngram_range =
181 GetFirstTokenMatches(tokens[start_i]);
182 for (int ngram_idx = ngram_range.first; ngram_idx < ngram_range.second;
183 ++ngram_idx) {
184 const uint16 ngram_tokens_begin =
185 (*model_->ngram_start_offsets())[ngram_idx];
186 const uint16 ngram_tokens_end =
187 (*model_->ngram_start_offsets())[ngram_idx + 1];
188 if (IsNGramMatch(
189 /*tokens=*/tokens.data() + start_i,
190 /*num_tokens=*/tokens.size() - start_i,
191 /*ngram_tokens=*/model_->hashed_ngram_tokens()->data() +
192 ngram_tokens_begin,
193 /*num_ngram_tokens=*/ngram_tokens_end - ngram_tokens_begin,
194 /*max_skips=*/model_->max_skips())) {
195 ++num_matches;
196 weight_matches += (*model_->ngram_weights())[ngram_idx];
197 }
198 }
199 }
200
201 // Calculate the score.
202 const int num_misses = num_candidates - num_matches;
203 const float internal_score =
204 (weight_matches + (model_->default_token_weight() * num_misses)) /
205 num_candidates;
206 return std::make_pair(internal_score > model_->threshold(), internal_score);
207 }
208
EvalConversation(const Conversation & conversation,const int num_messages) const209 std::pair<bool, float> NGramSensitiveModel::EvalConversation(
210 const Conversation& conversation, const int num_messages) const {
211 float score = 0.0;
212 for (int i = 1; i <= num_messages; i++) {
213 const std::string& message =
214 conversation.messages[conversation.messages.size() - i].text;
215 const UnicodeText message_unicode(
216 UTF8ToUnicodeText(message, /*do_copy=*/false));
217 // Run ngram linear regression model.
218 const auto prediction = Eval(message_unicode);
219 if (prediction.first) {
220 return prediction;
221 }
222 score = std::max(score, prediction.second);
223 }
224 return std::make_pair(false, score);
225 }
226
227 } // namespace libtextclassifier3
228