1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/vocab/vocab-annotator-impl.h"
18
19 #include "annotator/feature-processor.h"
20 #include "annotator/model_generated.h"
21 #include "utils/base/logging.h"
22 #include "utils/optional.h"
23 #include "utils/strings/numbers.h"
24
25 namespace libtextclassifier3 {
26
VocabAnnotator(std::unique_ptr<VocabLevelTable> vocab_level_table,const std::vector<Locale> & triggering_locales,const FeatureProcessor & feature_processor,const UniLib & unilib,const VocabModel * model)27 VocabAnnotator::VocabAnnotator(
28 std::unique_ptr<VocabLevelTable> vocab_level_table,
29 const std::vector<Locale>& triggering_locales,
30 const FeatureProcessor& feature_processor, const UniLib& unilib,
31 const VocabModel* model)
32 : vocab_level_table_(std::move(vocab_level_table)),
33 triggering_locales_(triggering_locales),
34 feature_processor_(feature_processor),
35 unilib_(unilib),
36 model_(model) {}
37
Create(const VocabModel * model,const FeatureProcessor & feature_processor,const UniLib & unilib)38 std::unique_ptr<VocabAnnotator> VocabAnnotator::Create(
39 const VocabModel* model, const FeatureProcessor& feature_processor,
40 const UniLib& unilib) {
41 std::unique_ptr<VocabLevelTable> vocab_lebel_table =
42 VocabLevelTable::Create(model);
43 if (vocab_lebel_table == nullptr) {
44 TC3_LOG(ERROR) << "Failed to create vocab level table.";
45 return nullptr;
46 }
47 std::vector<Locale> triggering_locales;
48 if (model->triggering_locales() &&
49 !ParseLocales(model->triggering_locales()->c_str(),
50 &triggering_locales)) {
51 TC3_LOG(ERROR) << "Could not parse model supported locales.";
52 return nullptr;
53 }
54
55 return std::unique_ptr<VocabAnnotator>(
56 new VocabAnnotator(std::move(vocab_lebel_table), triggering_locales,
57 feature_processor, unilib, model));
58 }
59
Annotate(const UnicodeText & context,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,std::vector<AnnotatedSpan> * results) const60 bool VocabAnnotator::Annotate(
61 const UnicodeText& context,
62 const std::vector<Locale> detected_text_language_tags,
63 bool trigger_on_beginner_words, std::vector<AnnotatedSpan>* results) const {
64 std::vector<Token> tokens = feature_processor_.Tokenize(context);
65 for (const Token& token : tokens) {
66 ClassificationResult classification_result;
67 CodepointSpan stripped_span;
68 bool found = ClassifyTextInternal(
69 context, {token.start, token.end}, detected_text_language_tags,
70 trigger_on_beginner_words, &classification_result, &stripped_span);
71 if (found) {
72 results->push_back(AnnotatedSpan{stripped_span, {classification_result}});
73 }
74 }
75 return true;
76 }
77
ClassifyText(const UnicodeText & context,CodepointSpan click,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,ClassificationResult * result) const78 bool VocabAnnotator::ClassifyText(
79 const UnicodeText& context, CodepointSpan click,
80 const std::vector<Locale> detected_text_language_tags,
81 bool trigger_on_beginner_words, ClassificationResult* result) const {
82 CodepointSpan stripped_span;
83 return ClassifyTextInternal(context, click, detected_text_language_tags,
84 trigger_on_beginner_words, result,
85 &stripped_span);
86 }
87
ClassifyTextInternal(const UnicodeText & context,const CodepointSpan click,const std::vector<Locale> detected_text_language_tags,bool trigger_on_beginner_words,ClassificationResult * classification_result,CodepointSpan * classified_span) const88 bool VocabAnnotator::ClassifyTextInternal(
89 const UnicodeText& context, const CodepointSpan click,
90 const std::vector<Locale> detected_text_language_tags,
91 bool trigger_on_beginner_words, ClassificationResult* classification_result,
92 CodepointSpan* classified_span) const {
93 if (vocab_level_table_ == nullptr) {
94 return false;
95 }
96
97 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
98 triggering_locales_,
99 /*default_value=*/false)) {
100 return false;
101 }
102 const CodepointSpan stripped_span =
103 feature_processor_.StripBoundaryCodepoints(context,
104 {click.first, click.second});
105 const UnicodeText stripped_token = UnicodeText::Substring(
106 context, stripped_span.first, stripped_span.second, /*do_copy=*/false);
107 const std::string lower_token =
108 unilib_.ToLowerText(stripped_token).ToUTF8String();
109
110 const Optional<LookupResult> result = vocab_level_table_->Lookup(lower_token);
111 if (!result.has_value()) {
112 return false;
113 }
114 if (result.value().do_not_trigger_in_upper_case &&
115 unilib_.IsUpper(*stripped_token.begin())) {
116 TC3_VLOG(INFO) << "Not trigger define: proper noun in upper case.";
117 return false;
118 }
119 if (result.value().beginner_level && !trigger_on_beginner_words) {
120 TC3_VLOG(INFO) << "Not trigger define: for beginner only.";
121 return false;
122 }
123 *classification_result =
124 ClassificationResult("dictionary", model_->target_classification_score(),
125 model_->priority_score());
126 *classified_span = stripped_span;
127
128 return true;
129 }
130 } // namespace libtextclassifier3
131