/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "smartselect/text-classification-model.h"

#include <cmath>
#include <iterator>
#include <numeric>

#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "common/memory_image/memory-image-reader.h"
#include "common/mmap.h"
#include "common/softmax.h"
#include "smartselect/text-classification-model.pb.h"
#include "util/base/logging.h"
#include "util/utf8/unicodetext.h"
#include "unicode/uchar.h"

namespace libtextclassifier {

using nlp_core::EmbeddingNetwork;
using nlp_core::EmbeddingNetworkProto;
using nlp_core::FeatureVector;
using nlp_core::MemoryImageReader;
using nlp_core::MmapFile;
using nlp_core::MmapHandle;
using nlp_core::ScopedMmap;

namespace {

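// Returns the number of digit codepoints that fall inside the given codepoint
// span of the UTF8 string.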
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
  int count = 0;
  int i = 0;
  const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
  for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
    if (i >= selection_indices.first && i < selection_indices.second &&
        u_isdigit(*it)) {
      ++count;
    }
  }
  return count;
}

}  // namespace

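// Strips punctuation codepoints (those listed in the selection model options)
// from both ends of the selection. Returns {0, 0} when the selection is all
// punctuation.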
CodepointSpan TextClassificationModel::StripPunctuation(
    CodepointSpan selection, const std::string& context) const {
  UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
  int context_length =
      std::distance(context_unicode.begin(), context_unicode.end());

  // Check that the indices are valid.
  if (selection.first < 0 || selection.first > context_length ||
      selection.second < 0 || selection.second > context_length) {
    return selection;
  }

  // Move the left border until we encounter a non-punctuation character.
  UnicodeText::const_iterator it_from_begin = context_unicode.begin();
  std::advance(it_from_begin, selection.first);
  for (; punctuation_to_strip_.find(*it_from_begin) !=
         punctuation_to_strip_.end();
       ++it_from_begin, ++selection.first) {
  }

  // Unless we are already at the end, move the right border until we encounter
  // a non-punctuation character.
  UnicodeText::const_iterator it_from_end = context_unicode.begin();
  std::advance(it_from_end, selection.second);
  if (it_from_begin != it_from_end) {
    --it_from_end;
    for (; punctuation_to_strip_.find(*it_from_end) !=
           punctuation_to_strip_.end();
         --it_from_end, --selection.second) {
    }
    return selection;
  } else {
    // When the token is all punctuation.
    return {0, 0};
  }
}

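// Loads the selection and sharing models from the given merged model file
// descriptor and caches the model options. Leaves initialized_ false if
// loading fails.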
TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
  initialized_ = LoadModels(mmap_.handle());
  if (!initialized_) {
    TC_LOG(ERROR) << "Failed to load models";
    return;
  }

  selection_options_ = selection_params_->GetSelectionModelOptions();
  for (const int codepoint : selection_options_.punctuation_to_strip()) {
    punctuation_to_strip_.insert(codepoint);
  }

  sharing_options_ = selection_params_->GetSharingModelOptions();
}

namespace {

// Converts a vector of sparse features into a nlp_core::FeatureVector.
void SparseFeaturesToFeatureVector(
    const std::vector<int>& sparse_features,
    const nlp_core::NumericFeatureType& feature_type,
    nlp_core::FeatureVector* result) {
  for (int feature_id : sparse_features) {
    const int64 feature_value =
        nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
            .discrete_value;
    result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
                feature_value);
  }
}

// Returns a function that can be used for mapping sparse and dense features
// to a float feature vector.
// NOTE: The network object needs to be available at the time when the returned
// function object is used.
FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
                                      int sparse_embedding_size) {
  const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
  return [&network, sparse_embedding_size, feature_type](
             const std::vector<int>& sparse_features,
             const std::vector<float>& dense_features, float* embedding) {
    nlp_core::FeatureVector feature_vector;
    SparseFeaturesToFeatureVector(sparse_features, feature_type,
                                  &feature_vector);

    if (network.GetEmbedding(feature_vector, 0, embedding)) {
      for (int i = 0; i < dense_features.size(); i++) {
        embedding[sparse_embedding_size + i] = dense_features[i];
      }
      return true;
    } else {
      return false;
    }
  };
}

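// Locates the two models inside the merged model image. The image layout is:
// a little-endian uint32 with the selection model size, the selection model
// bytes, a little-endian uint32 with the sharing model size, and the sharing
// model bytes.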
void ParseMergedModel(const MmapHandle& mmap_handle,
                      const char** selection_model, int* selection_model_length,
                      const char** sharing_model, int* sharing_model_length) {
  // Read the length of the selection model.
  const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
  *selection_model_length =
      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
  model_data += sizeof(*selection_model_length);
  *selection_model = model_data;
  model_data += *selection_model_length;

  *sharing_model_length =
      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
  model_data += sizeof(*sharing_model_length);
  *sharing_model = model_data;
}

}  // namespace

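// Builds the selection and sharing networks, their feature processors and
// feature-extraction functions from the merged model image. The sharing model
// is built with access to the selection model's embedding parameters.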
bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
  if (!mmap_handle.ok()) {
    return false;
  }

  const char *selection_model, *sharing_model;
  int selection_model_length, sharing_model_length;
  ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
                   &sharing_model, &sharing_model_length);

  selection_params_.reset(
      ModelParamsBuilder(selection_model, selection_model_length, nullptr));
  if (!selection_params_.get()) {
    return false;
  }
  selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
  selection_feature_processor_.reset(
      new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
  selection_feature_fn_ = CreateFeatureVectorFn(
      *selection_network_, selection_network_->EmbeddingSize(0));

  sharing_params_.reset(
      ModelParamsBuilder(sharing_model, sharing_model_length,
                         selection_params_->GetEmbeddingParams()));
  if (!sharing_params_.get()) {
    return false;
  }
  sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
  sharing_feature_processor_.reset(
      new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
  sharing_feature_fn_ = CreateFeatureVectorFn(
      *sharing_network_, sharing_network_->EmbeddingSize(0));

  return true;
}

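// Mmaps the merged model file and extracts the ModelOptions proto stored as an
// extension of the selection model's embedding network proto. Returns false if
// the file cannot be mmapped or the extension is not present.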
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
  ScopedMmap mmap = ScopedMmap(fd);
  if (!mmap.handle().ok()) {
    TC_LOG(ERROR) << "Can't mmap.";
    return false;
  }

  const char *selection_model, *sharing_model;
  int selection_model_length, sharing_model_length;
  ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
                   &sharing_model, &sharing_model_length);

  MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
                                                  selection_model_length);

  auto model_options_extension_id = model_options_in_embedding_network_proto;
  if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
    *model_options =
        reader.trimmed_proto().GetExtension(model_options_extension_id);
    return true;
  } else {
    return false;
  }
}

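// Extracts features around the given span, runs the network on them and
// returns the raw logits. If selection_label_spans is not null, it is filled
// with the codepoint spans corresponding to the selection labels.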
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
    const std::string& context, CodepointSpan span,
    const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
    const FeatureVectorFn& feature_vector_fn,
    std::vector<CodepointSpan>* selection_label_spans) const {
  std::vector<Token> tokens;
  int click_pos;
  std::unique_ptr<CachedFeatures> cached_features;
  const int embedding_size = network.EmbeddingSize(0);
  if (!feature_processor.ExtractFeatures(
          context, span, /*relative_click_span=*/{0, 0},
          CreateFeatureVectorFn(network, embedding_size),
          embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
          &click_pos, &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  VectorSpan<float> features;
  VectorSpan<Token> output_tokens;
  if (!cached_features->Get(click_pos, &features, &output_tokens)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  if (selection_label_spans != nullptr) {
    if (!feature_processor.SelectionLabelSpans(output_tokens,
                                               selection_label_spans)) {
      TC_LOG(ERROR) << "Could not get spans for selection labels.";
      return {};
    }
  }

  std::vector<float> scores;
  network.ComputeLogits(features, &scores);
  return scores;
}

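// Suggests a selection around the clicked span: validates the indices, runs
// either the symmetrical or the plain selection inference, and optionally
// strips punctuation from the result.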
CodepointSpan TextClassificationModel::SuggestSelection(
    const std::string& context, CodepointSpan click_indices) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return click_indices;
  }

  if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
    TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
                  << std::get<0>(click_indices) << " "
                  << std::get<1>(click_indices);
    return click_indices;
  }

  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  const int context_length =
      std::distance(context_unicode.begin(), context_unicode.end());
  if (std::get<0>(click_indices) >= context_length ||
      std::get<1>(click_indices) > context_length) {
    return click_indices;
  }

  CodepointSpan result;
  if (selection_options_.enforce_symmetry()) {
    result = SuggestSelectionSymmetrical(context, click_indices);
  } else {
    float score;
    std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
  }

  if (selection_options_.strip_punctuation()) {
    result = StripPunctuation(result, context);
  }

  return result;
}

namespace {

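// Returns the highest-scoring selection span and its score, or the original
// click indices with a score of -1.0 if the scores are empty or the predicted
// span is invalid.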
std::pair<CodepointSpan, float> BestSelectionSpan(
    CodepointSpan original_click_indices, const std::vector<float>& scores,
    const std::vector<CodepointSpan>& selection_label_spans) {
  if (!scores.empty()) {
    const int prediction =
        std::max_element(scores.begin(), scores.end()) - scores.begin();
    std::pair<CodepointIndex, CodepointIndex> selection =
        selection_label_spans[prediction];

    if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
      TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
                    << prediction << " " << selection.first << " "
                    << selection.second;
      return {original_click_indices, -1.0};
    }

    return {{selection.first, selection.second}, scores[prediction]};
  } else {
    TC_LOG(ERROR) << "Returning default selection: scores.size() = "
                  << scores.size();
    return {original_click_indices, -1.0};
  }
}

}  // namespace

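// Runs selection inference for the clicked span and returns the best span
// together with its softmax probability.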
std::pair<CodepointSpan, float>
TextClassificationModel::SuggestSelectionInternal(
    const std::string& context, CodepointSpan click_indices) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {click_indices, -1.0};
  }

  std::vector<CodepointSpan> selection_label_spans;
  EmbeddingNetwork::Vector scores = InferInternal(
      context, click_indices, *selection_feature_processor_,
      *selection_network_, selection_feature_fn_, &selection_label_spans);
  scores = nlp_core::ComputeSoftmax(scores);

  return BestSelectionSpan(click_indices, scores, selection_label_spans);
}

// Implements a greedy-search-like algorithm for making selections symmetric.
//
// Steps:
// 1. Get a set of selection proposals from places around the clicked word.
// 2. For each proposal (going from the highest-scoring), check whether the
//    tokens it selects are still free; if so, the proposal claims them. If a
//    proposal that contains the clicked token is found, it is returned as the
//    suggestion.
//
// This algorithm should ensure that, if a selection is proposed, it does not
// matter which of its words was tapped - all of them lead to the same
// selection.
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
    const std::string& context, CodepointSpan click_indices) const {
  const int symmetry_context_size = selection_options_.symmetry_context_size();
  std::vector<Token> tokens;
  std::unique_ptr<CachedFeatures> cached_features;
  int click_index;
  int embedding_size = selection_network_->EmbeddingSize(0);
  if (!selection_feature_processor_->ExtractFeatures(
          context, click_indices, /*relative_click_span=*/
          {symmetry_context_size, symmetry_context_size + 1},
          selection_feature_fn_,
          embedding_size + selection_feature_processor_->DenseFeaturesCount(),
          &tokens, &click_index, &cached_features)) {
    TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
    return click_indices;
  }

  // Scan the symmetry context for selection span proposals.
  std::vector<std::pair<CodepointSpan, float>> proposals;

  for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
    const int token_index = click_index + i;
    if (token_index >= 0 && token_index < tokens.size() &&
        !tokens[token_index].is_padding) {
      float score;
      VectorSpan<float> features;
      VectorSpan<Token> output_tokens;

      CodepointSpan span;
      if (cached_features->Get(token_index, &features, &output_tokens)) {
        std::vector<float> scores;
        selection_network_->ComputeLogits(features, &scores);

        std::vector<CodepointSpan> selection_label_spans;
        if (selection_feature_processor_->SelectionLabelSpans(
                output_tokens, &selection_label_spans)) {
          scores = nlp_core::ComputeSoftmax(scores);
          std::tie(span, score) =
              BestSelectionSpan(click_indices, scores, selection_label_spans);
          if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
              score >= 0) {
            proposals.push_back({span, score});
          }
        }
      }
    }
  }

  // Sort the selection span proposals by decreasing probability.
  std::sort(
      proposals.begin(), proposals.end(),
      [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
        return a.second > b.second;
      });

  // Go through the proposals from the highest-scoring one and claim their
  // tokens. Tokens claimed by a higher-scoring proposal cannot be used by a
  // lower-scoring one. Return the first feasible proposal that contains the
  // clicked token.
  std::vector<int> used_tokens(tokens.size(), 0);
  for (auto span_result : proposals) {
    TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
    if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
      bool feasible = true;
      for (int i = span.first; i < span.second; i++) {
        if (used_tokens[i] != 0) {
          feasible = false;
          break;
        }
      }

      if (feasible) {
        if (span.first <= click_index && span.second > click_index) {
          return {span_result.first.first, span_result.first.second};
        }
        for (int i = span.first; i < span.second; i++) {
          used_tokens[i] = 1;
        }
      }
    }
  }

  return {click_indices.first, click_indices.second};
}

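// Classifies the selected text into collections. URL and email hints are
// honored when the sharing model options request it. Returns (collection,
// score) pairs sorted by decreasing score; a digit-count check can demote a
// phone prediction to the "other" collection.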
std::vector<std::pair<std::string, float>>
TextClassificationModel::ClassifyText(const std::string& context,
                                       CodepointSpan selection_indices,
                                       int hint_flags) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
                  << std::get<0>(selection_indices) << " "
                  << std::get<1>(selection_indices);
    return {};
  }

  if (hint_flags & SELECTION_IS_URL &&
      sharing_options_.always_accept_url_hint()) {
    return {{kUrlHintCollection, 1.0}};
  }

  if (hint_flags & SELECTION_IS_EMAIL &&
      sharing_options_.always_accept_email_hint()) {
    return {{kEmailHintCollection, 1.0}};
  }

  EmbeddingNetwork::Vector scores =
      InferInternal(context, selection_indices, *sharing_feature_processor_,
                    *sharing_network_, sharing_feature_fn_, nullptr);
  if (scores.empty() ||
      scores.size() != sharing_feature_processor_->NumCollections()) {
    TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
    return {};
  }

  scores = nlp_core::ComputeSoftmax(scores);

  std::vector<std::pair<std::string, float>> result;
  for (int i = 0; i < scores.size(); i++) {
    result.push_back(
        {sharing_feature_processor_->LabelToCollection(i), scores[i]});
  }
  std::sort(result.begin(), result.end(),
            [](const std::pair<std::string, float>& a,
               const std::pair<std::string, float>& b) {
              return a.second > b.second;
            });

  // Phone class sanity check.
  if (result.begin()->first == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count < sharing_options_.phone_min_num_digits() ||
        digit_count > sharing_options_.phone_max_num_digits()) {
      return {{kOtherCollection, 1.0}};
    }
  }

  return result;
}

}  // namespace libtextclassifier