• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "smartselect/text-classification-model.h"
18 
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iterator>
#include <numeric>
22 
23 #include "common/embedding-network.h"
24 #include "common/feature-extractor.h"
25 #include "common/memory_image/embedding-network-params-from-image.h"
26 #include "common/memory_image/memory-image-reader.h"
27 #include "common/mmap.h"
28 #include "common/softmax.h"
29 #include "smartselect/text-classification-model.pb.h"
30 #include "util/base/logging.h"
31 #include "util/utf8/unicodetext.h"
32 #include "unicode/uchar.h"
33 
34 namespace libtextclassifier {
35 
36 using nlp_core::EmbeddingNetwork;
37 using nlp_core::EmbeddingNetworkProto;
38 using nlp_core::FeatureVector;
39 using nlp_core::MemoryImageReader;
40 using nlp_core::MmapFile;
41 using nlp_core::MmapHandle;
42 using nlp_core::ScopedMmap;
43 
44 namespace {
45 
CountDigits(const std::string & str,CodepointSpan selection_indices)46 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
47   int count = 0;
48   int i = 0;
49   const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
50   for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
51     if (i >= selection_indices.first && i < selection_indices.second &&
52         u_isdigit(*it)) {
53       ++count;
54     }
55   }
56   return count;
57 }
58 
59 }  // namespace
60 
StripPunctuation(CodepointSpan selection,const std::string & context) const61 CodepointSpan TextClassificationModel::StripPunctuation(
62     CodepointSpan selection, const std::string& context) const {
63   UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
64   int context_length =
65       std::distance(context_unicode.begin(), context_unicode.end());
66 
67   // Check that the indices are valid.
68   if (selection.first < 0 || selection.first > context_length ||
69       selection.second < 0 || selection.second > context_length) {
70     return selection;
71   }
72 
73   // Move the left border until we encounter a non-punctuation character.
74   UnicodeText::const_iterator it_from_begin = context_unicode.begin();
75   std::advance(it_from_begin, selection.first);
76   for (; punctuation_to_strip_.find(*it_from_begin) !=
77          punctuation_to_strip_.end();
78        ++it_from_begin, ++selection.first) {
79   }
80 
81   // Unless we are already at the end, move the right border until we encounter
82   // a non-punctuation character.
83   UnicodeText::const_iterator it_from_end = context_unicode.begin();
84   std::advance(it_from_end, selection.second);
85   if (it_from_begin != it_from_end) {
86     --it_from_end;
87     for (; punctuation_to_strip_.find(*it_from_end) !=
88            punctuation_to_strip_.end();
89          --it_from_end, --selection.second) {
90     }
91     return selection;
92   } else {
93     // When the token is all punctuation.
94     return {0, 0};
95   }
96 }
97 
TextClassificationModel(int fd)98 TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
99   initialized_ = LoadModels(mmap_.handle());
100   if (!initialized_) {
101     TC_LOG(ERROR) << "Failed to load models";
102     return;
103   }
104 
105   selection_options_ = selection_params_->GetSelectionModelOptions();
106   for (const int codepoint : selection_options_.punctuation_to_strip()) {
107     punctuation_to_strip_.insert(codepoint);
108   }
109 
110   sharing_options_ = selection_params_->GetSharingModelOptions();
111 }
112 
113 namespace {
114 
115 // Converts sparse features vector to nlp_core::FeatureVector.
SparseFeaturesToFeatureVector(const std::vector<int> sparse_features,const nlp_core::NumericFeatureType & feature_type,nlp_core::FeatureVector * result)116 void SparseFeaturesToFeatureVector(
117     const std::vector<int> sparse_features,
118     const nlp_core::NumericFeatureType& feature_type,
119     nlp_core::FeatureVector* result) {
120   for (int feature_id : sparse_features) {
121     const int64 feature_value =
122         nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
123             .discrete_value;
124     result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
125                 feature_value);
126   }
127 }
128 
129 // Returns a function that can be used for mapping sparse and dense features
130 // to a float feature vector.
131 // NOTE: The network object needs to be available at the time when the returned
132 // function object is used.
CreateFeatureVectorFn(const EmbeddingNetwork & network,int sparse_embedding_size)133 FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
134                                       int sparse_embedding_size) {
135   const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
136   return [&network, sparse_embedding_size, feature_type](
137              const std::vector<int>& sparse_features,
138              const std::vector<float>& dense_features, float* embedding) {
139     nlp_core::FeatureVector feature_vector;
140     SparseFeaturesToFeatureVector(sparse_features, feature_type,
141                                   &feature_vector);
142 
143     if (network.GetEmbedding(feature_vector, 0, embedding)) {
144       for (int i = 0; i < dense_features.size(); i++) {
145         embedding[sparse_embedding_size + i] = dense_features[i];
146       }
147       return true;
148     } else {
149       return false;
150     }
151   };
152 }
153 
ParseMergedModel(const MmapHandle & mmap_handle,const char ** selection_model,int * selection_model_length,const char ** sharing_model,int * sharing_model_length)154 void ParseMergedModel(const MmapHandle& mmap_handle,
155                       const char** selection_model, int* selection_model_length,
156                       const char** sharing_model, int* sharing_model_length) {
157   // Read the length of the selection model.
158   const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
159   *selection_model_length =
160       LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
161   model_data += sizeof(*selection_model_length);
162   *selection_model = model_data;
163   model_data += *selection_model_length;
164 
165   *sharing_model_length =
166       LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
167   model_data += sizeof(*sharing_model_length);
168   *sharing_model = model_data;
169 }
170 
171 }  // namespace
172 
LoadModels(const MmapHandle & mmap_handle)173 bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
174   if (!mmap_handle.ok()) {
175     return false;
176   }
177 
178   const char *selection_model, *sharing_model;
179   int selection_model_length, sharing_model_length;
180   ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
181                    &sharing_model, &sharing_model_length);
182 
183   selection_params_.reset(
184       ModelParamsBuilder(selection_model, selection_model_length, nullptr));
185   if (!selection_params_.get()) {
186     return false;
187   }
188   selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
189   selection_feature_processor_.reset(
190       new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
191   selection_feature_fn_ = CreateFeatureVectorFn(
192       *selection_network_, selection_network_->EmbeddingSize(0));
193 
194   sharing_params_.reset(
195       ModelParamsBuilder(sharing_model, sharing_model_length,
196                          selection_params_->GetEmbeddingParams()));
197   if (!sharing_params_.get()) {
198     return false;
199   }
200   sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
201   sharing_feature_processor_.reset(
202       new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
203   sharing_feature_fn_ = CreateFeatureVectorFn(
204       *sharing_network_, sharing_network_->EmbeddingSize(0));
205 
206   return true;
207 }
208 
ReadSelectionModelOptions(int fd,ModelOptions * model_options)209 bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
210   ScopedMmap mmap = ScopedMmap(fd);
211   if (!mmap.handle().ok()) {
212     TC_LOG(ERROR) << "Can't mmap.";
213     return false;
214   }
215 
216   const char *selection_model, *sharing_model;
217   int selection_model_length, sharing_model_length;
218   ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
219                    &sharing_model, &sharing_model_length);
220 
221   MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
222                                                   selection_model_length);
223 
224   auto model_options_extension_id = model_options_in_embedding_network_proto;
225   if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
226     *model_options =
227         reader.trimmed_proto().GetExtension(model_options_extension_id);
228     return true;
229   } else {
230     return false;
231   }
232 }
233 
// Runs `network` on the features extracted from `context` around `span` and
// returns the resulting logits.
//
// `feature_vector_fn` is the feature vector function handed to the feature
// processor; it must be callable and must match `network`. If
// `selection_label_spans` is not null, it receives the codepoint spans of the
// selection labels for the extracted tokens.
//
// Returns an empty vector (after logging an error) if the features cannot be
// extracted or the selection label spans cannot be computed.
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
    const std::string& context, CodepointSpan span,
    const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
    const FeatureVectorFn& feature_vector_fn,
    std::vector<CodepointSpan>* selection_label_spans) const {
  std::vector<Token> tokens;
  int click_pos = 0;
  std::unique_ptr<CachedFeatures> cached_features;
  const int embedding_size = network.EmbeddingSize(0);
  // Use the caller-provided feature function instead of building a fresh one
  // on every call.
  if (!feature_processor.ExtractFeatures(
          context, span, /*relative_click_span=*/{0, 0}, feature_vector_fn,
          embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
          &click_pos, &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  VectorSpan<float> features;
  VectorSpan<Token> output_tokens;
  // Guard against a missing cache before dereferencing it.
  if (cached_features == nullptr ||
      !cached_features->Get(click_pos, &features, &output_tokens)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  if (selection_label_spans != nullptr) {
    if (!feature_processor.SelectionLabelSpans(output_tokens,
                                               selection_label_spans)) {
      TC_LOG(ERROR) << "Could not get spans for selection labels.";
      return {};
    }
  }

  std::vector<float> scores;
  network.ComputeLogits(features, &scores);
  return scores;
}
271 
SuggestSelection(const std::string & context,CodepointSpan click_indices) const272 CodepointSpan TextClassificationModel::SuggestSelection(
273     const std::string& context, CodepointSpan click_indices) const {
274   if (!initialized_) {
275     TC_LOG(ERROR) << "Not initialized";
276     return click_indices;
277   }
278 
279   if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
280     TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:"
281                   << std::get<0>(click_indices) << " "
282                   << std::get<1>(click_indices);
283     return click_indices;
284   }
285 
286   const UnicodeText context_unicode =
287       UTF8ToUnicodeText(context, /*do_copy=*/false);
288   const int context_length =
289       std::distance(context_unicode.begin(), context_unicode.end());
290   if (std::get<0>(click_indices) >= context_length ||
291       std::get<1>(click_indices) > context_length) {
292     return click_indices;
293   }
294 
295   CodepointSpan result;
296   if (selection_options_.enforce_symmetry()) {
297     result = SuggestSelectionSymmetrical(context, click_indices);
298   } else {
299     float score;
300     std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
301   }
302 
303   if (selection_options_.strip_punctuation()) {
304     result = StripPunctuation(result, context);
305   }
306 
307   return result;
308 }
309 
310 namespace {
311 
BestSelectionSpan(CodepointSpan original_click_indices,const std::vector<float> & scores,const std::vector<CodepointSpan> & selection_label_spans)312 std::pair<CodepointSpan, float> BestSelectionSpan(
313     CodepointSpan original_click_indices, const std::vector<float>& scores,
314     const std::vector<CodepointSpan>& selection_label_spans) {
315   if (!scores.empty()) {
316     const int prediction =
317         std::max_element(scores.begin(), scores.end()) - scores.begin();
318     std::pair<CodepointIndex, CodepointIndex> selection =
319         selection_label_spans[prediction];
320 
321     if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
322       TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
323                     << prediction << " " << selection.first << " "
324                     << selection.second;
325       return {original_click_indices, -1.0};
326     }
327 
328     return {{selection.first, selection.second}, scores[prediction]};
329   } else {
330     TC_LOG(ERROR) << "Returning default selection: scores.size() = "
331                   << scores.size();
332     return {original_click_indices, -1.0};
333   }
334 }
335 
336 }  // namespace
337 
338 std::pair<CodepointSpan, float>
SuggestSelectionInternal(const std::string & context,CodepointSpan click_indices) const339 TextClassificationModel::SuggestSelectionInternal(
340     const std::string& context, CodepointSpan click_indices) const {
341   if (!initialized_) {
342     TC_LOG(ERROR) << "Not initialized";
343     return {click_indices, -1.0};
344   }
345 
346   std::vector<CodepointSpan> selection_label_spans;
347   EmbeddingNetwork::Vector scores = InferInternal(
348       context, click_indices, *selection_feature_processor_,
349       *selection_network_, selection_feature_fn_, &selection_label_spans);
350   scores = nlp_core::ComputeSoftmax(scores);
351 
352   return BestSelectionSpan(click_indices, scores, selection_label_spans);
353 }
354 
355 // Implements a greedy-search-like algorithm for making selections symmetric.
356 //
357 // Steps:
358 // 1. Get a set of selection proposals from places around the clicked word.
359 // 2. For each proposal (going from highest-scoring), check if the tokens that
360 //    the proposal selects are still free, in which case it claims them, if a
361 //    proposal that contains the clicked token is found, it is returned as the
362 //    suggestion.
363 //
364 // This algorithm should ensure that if a selection is proposed, it does not
365 // matter which word of it was tapped - all of them will lead to the same
366 // selection.
SuggestSelectionSymmetrical(const std::string & context,CodepointSpan click_indices) const367 CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
368     const std::string& context, CodepointSpan click_indices) const {
369   const int symmetry_context_size = selection_options_.symmetry_context_size();
370   std::vector<Token> tokens;
371   std::unique_ptr<CachedFeatures> cached_features;
372   int click_index;
373   int embedding_size = selection_network_->EmbeddingSize(0);
374   if (!selection_feature_processor_->ExtractFeatures(
375           context, click_indices, /*relative_click_span=*/
376           {symmetry_context_size, symmetry_context_size + 1},
377           selection_feature_fn_,
378           embedding_size + selection_feature_processor_->DenseFeaturesCount(),
379           &tokens, &click_index, &cached_features)) {
380     TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
381     return click_indices;
382   }
383 
384   // Scan in the symmetry context for selection span proposals.
385   std::vector<std::pair<CodepointSpan, float>> proposals;
386 
387   for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
388     const int token_index = click_index + i;
389     if (token_index >= 0 && token_index < tokens.size() &&
390         !tokens[token_index].is_padding) {
391       float score;
392       VectorSpan<float> features;
393       VectorSpan<Token> output_tokens;
394 
395       CodepointSpan span;
396       if (cached_features->Get(token_index, &features, &output_tokens)) {
397         std::vector<float> scores;
398         selection_network_->ComputeLogits(features, &scores);
399 
400         std::vector<CodepointSpan> selection_label_spans;
401         if (selection_feature_processor_->SelectionLabelSpans(
402                 output_tokens, &selection_label_spans)) {
403           scores = nlp_core::ComputeSoftmax(scores);
404           std::tie(span, score) =
405               BestSelectionSpan(click_indices, scores, selection_label_spans);
406           if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
407               score >= 0) {
408             proposals.push_back({span, score});
409           }
410         }
411       }
412     }
413   }
414 
415   // Sort selection span proposals by their respective probabilities.
416   std::sort(
417       proposals.begin(), proposals.end(),
418       [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
419         return a.second > b.second;
420       });
421 
422   // Go from the highest-scoring proposal and claim tokens. Tokens are marked as
423   // claimed by the higher-scoring selection proposals, so that the
424   // lower-scoring ones cannot use them. Returns the selection proposal if it
425   // contains the clicked token.
426   std::vector<int> used_tokens(tokens.size(), 0);
427   for (auto span_result : proposals) {
428     TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
429     if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
430       bool feasible = true;
431       for (int i = span.first; i < span.second; i++) {
432         if (used_tokens[i] != 0) {
433           feasible = false;
434           break;
435         }
436       }
437 
438       if (feasible) {
439         if (span.first <= click_index && span.second > click_index) {
440           return {span_result.first.first, span_result.first.second};
441         }
442         for (int i = span.first; i < span.second; i++) {
443           used_tokens[i] = 1;
444         }
445       }
446     }
447   }
448 
449   return {click_indices.first, click_indices.second};
450 }
451 
452 std::vector<std::pair<std::string, float>>
ClassifyText(const std::string & context,CodepointSpan selection_indices,int hint_flags) const453 TextClassificationModel::ClassifyText(const std::string& context,
454                                       CodepointSpan selection_indices,
455                                       int hint_flags) const {
456   if (!initialized_) {
457     TC_LOG(ERROR) << "Not initialized";
458     return {};
459   }
460 
461   if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
462     TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
463                   << std::get<0>(selection_indices) << " "
464                   << std::get<1>(selection_indices);
465     return {};
466   }
467 
468   if (hint_flags & SELECTION_IS_URL &&
469       sharing_options_.always_accept_url_hint()) {
470     return {{kUrlHintCollection, 1.0}};
471   }
472 
473   if (hint_flags & SELECTION_IS_EMAIL &&
474       sharing_options_.always_accept_email_hint()) {
475     return {{kEmailHintCollection, 1.0}};
476   }
477 
478   EmbeddingNetwork::Vector scores =
479       InferInternal(context, selection_indices, *sharing_feature_processor_,
480                     *sharing_network_, sharing_feature_fn_, nullptr);
481   if (scores.empty() ||
482       scores.size() != sharing_feature_processor_->NumCollections()) {
483     TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
484     return {};
485   }
486 
487   scores = nlp_core::ComputeSoftmax(scores);
488 
489   std::vector<std::pair<std::string, float>> result;
490   for (int i = 0; i < scores.size(); i++) {
491     result.push_back(
492         {sharing_feature_processor_->LabelToCollection(i), scores[i]});
493   }
494   std::sort(result.begin(), result.end(),
495             [](const std::pair<std::string, float>& a,
496                const std::pair<std::string, float>& b) {
497               return a.second > b.second;
498             });
499 
500   // Phone class sanity check.
501   if (result.begin()->first == kPhoneCollection) {
502     const int digit_count = CountDigits(context, selection_indices);
503     if (digit_count < sharing_options_.phone_min_num_digits() ||
504         digit_count > sharing_options_.phone_max_num_digits()) {
505       return {{kOtherCollection, 1.0}};
506     }
507   }
508 
509   return result;
510 }
511 
512 }  // namespace libtextclassifier
513