/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "annotator/vocab/vocab-annotator-impl.h" #include "annotator/feature-processor.h" #include "annotator/model_generated.h" #include "utils/base/logging.h" #include "utils/optional.h" #include "utils/strings/numbers.h" namespace libtextclassifier3 { VocabAnnotator::VocabAnnotator( std::unique_ptr vocab_level_table, const std::vector& triggering_locales, const FeatureProcessor& feature_processor, const UniLib& unilib, const VocabModel* model) : vocab_level_table_(std::move(vocab_level_table)), triggering_locales_(triggering_locales), feature_processor_(feature_processor), unilib_(unilib), model_(model) {} std::unique_ptr VocabAnnotator::Create( const VocabModel* model, const FeatureProcessor& feature_processor, const UniLib& unilib) { std::unique_ptr vocab_lebel_table = VocabLevelTable::Create(model); if (vocab_lebel_table == nullptr) { TC3_LOG(ERROR) << "Failed to create vocab level table."; return nullptr; } std::vector triggering_locales; if (model->triggering_locales() && !ParseLocales(model->triggering_locales()->c_str(), &triggering_locales)) { TC3_LOG(ERROR) << "Could not parse model supported locales."; return nullptr; } return std::unique_ptr( new VocabAnnotator(std::move(vocab_lebel_table), triggering_locales, feature_processor, unilib, model)); } bool VocabAnnotator::Annotate( const UnicodeText& context, const std::vector detected_text_language_tags, bool trigger_on_beginner_words, std::vector* results) const { std::vector tokens = feature_processor_.Tokenize(context); for (const Token& token : tokens) { ClassificationResult classification_result; CodepointSpan stripped_span; bool found = ClassifyTextInternal( context, {token.start, token.end}, detected_text_language_tags, trigger_on_beginner_words, &classification_result, &stripped_span); if (found) { results->push_back(AnnotatedSpan{stripped_span, {classification_result}}); } } return true; } bool VocabAnnotator::ClassifyText( const UnicodeText& context, CodepointSpan click, const std::vector detected_text_language_tags, bool trigger_on_beginner_words, ClassificationResult* result) const { CodepointSpan stripped_span; return ClassifyTextInternal(context, click, detected_text_language_tags, trigger_on_beginner_words, result, &stripped_span); } bool VocabAnnotator::ClassifyTextInternal( const UnicodeText& context, const CodepointSpan click, const std::vector detected_text_language_tags, bool trigger_on_beginner_words, ClassificationResult* classification_result, CodepointSpan* classified_span) const { if (vocab_level_table_ == nullptr) { return false; } if (!Locale::IsAnyLocaleSupported(detected_text_language_tags, triggering_locales_, /*default_value=*/false)) { return false; } const CodepointSpan stripped_span = feature_processor_.StripBoundaryCodepoints(context, {click.first, click.second}); const UnicodeText stripped_token = UnicodeText::Substring( context, stripped_span.first, stripped_span.second, /*do_copy=*/false); const std::string lower_token = unilib_.ToLowerText(stripped_token).ToUTF8String(); const Optional result = vocab_level_table_->Lookup(lower_token); if (!result.has_value()) { return false; } if (result.value().do_not_trigger_in_upper_case && unilib_.IsUpper(*stripped_token.begin())) { TC3_VLOG(INFO) << "Not trigger define: proper noun in upper case."; return false; } if (result.value().beginner_level && !trigger_on_beginner_words) { TC3_VLOG(INFO) << "Not trigger define: for beginner only."; return false; } *classification_result = ClassificationResult("dictionary", model_->target_classification_score(), model_->priority_score()); *classified_span = stripped_span; return true; } } // namespace libtextclassifier3