• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/translate/translate.h"
18 
19 #include <algorithm>
20 #include <memory>
21 
22 #include "annotator/collections.h"
23 #include "annotator/entity-data_generated.h"
24 #include "annotator/model_generated.h"
25 #include "annotator/types.h"
26 #include "lang_id/lang-id-wrapper.h"
27 #include "utils/base/logging.h"
28 #include "utils/i18n/locale.h"
29 #include "utils/utf8/unicodetext.h"
30 #include "lang_id/lang-id.h"
31 
32 namespace libtextclassifier3 {
33 
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const34 bool TranslateAnnotator::ClassifyText(
35     const UnicodeText& context, CodepointSpan selection_indices,
36     const std::string& user_familiar_language_tags,
37     ClassificationResult* classification_result) const {
38   if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
39     return false;
40   }
41 
42   std::vector<TranslateAnnotator::LanguageConfidence> confidences;
43   if (options_->algorithm() ==
44       TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
45     if (options_->backoff_options() == nullptr) {
46       TC3_LOG(WARNING) << "No backoff options specified. Returning.";
47       return false;
48     }
49     confidences = BackoffDetectLanguages(context, selection_indices);
50   }
51 
52   if (confidences.empty()) {
53     return false;
54   }
55 
56   std::vector<Locale> user_familiar_languages;
57   if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
58     TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
59     return false;
60   }
61   if (user_familiar_languages.empty()) {
62     TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
63                       "translate action.";
64     return false;
65   }
66   bool user_can_understand_language_of_text = false;
67   for (const Locale& locale : user_familiar_languages) {
68     if (locale.Language() == confidences[0].language) {
69       user_can_understand_language_of_text = true;
70       break;
71     }
72   }
73 
74   if (!user_can_understand_language_of_text) {
75     classification_result->collection = Collections::Translate();
76     classification_result->score = options_->score();
77     classification_result->priority_score = options_->priority_score();
78     classification_result->serialized_entity_data =
79         CreateSerializedEntityData(confidences);
80     return true;
81   }
82 
83   return false;
84 }
85 
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const86 std::string TranslateAnnotator::CreateSerializedEntityData(
87     const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
88     const {
89   EntityDataT entity_data;
90   entity_data.translate.reset(new EntityData_::TranslateT());
91 
92   for (const LanguageConfidence& confidence : confidences) {
93     EntityData_::Translate_::LanguagePredictionResultT*
94         language_prediction_result =
95             new EntityData_::Translate_::LanguagePredictionResultT();
96     language_prediction_result->language_tag = confidence.language;
97     language_prediction_result->confidence_score = confidence.confidence;
98     entity_data.translate->language_prediction_results.emplace_back(
99         language_prediction_result);
100   }
101   flatbuffers::FlatBufferBuilder builder;
102   FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
103   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104                      builder.GetSize());
105 }
106 
107 std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const108 TranslateAnnotator::BackoffDetectLanguages(
109     const UnicodeText& context, CodepointSpan selection_indices) const {
110   const float penalize_ratio = options_->backoff_options()->penalize_ratio();
111   const int min_text_size = options_->backoff_options()->min_text_size();
112   if (selection_indices.second - selection_indices.first < min_text_size &&
113       penalize_ratio <= 0) {
114     return {};
115   }
116 
117   const UnicodeText entity =
118       UnicodeText::Substring(context, selection_indices.first,
119                              selection_indices.second, /*do_copy=*/false);
120   const std::vector<std::pair<std::string, float>> lang_id_result =
121       langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
122 
123   const float more_text_score_ratio =
124       1.0f - options_->backoff_options()->subject_text_score_ratio();
125   std::vector<std::pair<std::string, float>> more_lang_id_results;
126   if (more_text_score_ratio >= 0) {
127     const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
128         context, selection_indices, min_text_size);
129     more_lang_id_results =
130         langid::GetPredictions(langid_model_, entity_with_context.data(),
131                                entity_with_context.size_bytes());
132   }
133 
134   const float subject_text_score_ratio =
135       options_->backoff_options()->subject_text_score_ratio();
136 
137   std::map<std::string, float> result_map;
138   for (const auto& [language, score] : lang_id_result) {
139     result_map[language] = subject_text_score_ratio * score;
140   }
141   for (const auto& [language, score] : more_lang_id_results) {
142     result_map[language] += more_text_score_ratio * score * penalize_ratio;
143   }
144 
145   std::vector<TranslateAnnotator::LanguageConfidence> result;
146   result.reserve(result_map.size());
147   for (const auto& [key, value] : result_map) {
148     result.push_back({key, value});
149   }
150 
151   std::stable_sort(result.begin(), result.end(),
152                    [](const TranslateAnnotator::LanguageConfidence& a,
153                       const TranslateAnnotator::LanguageConfidence& b) {
154                      return a.confidence > b.confidence;
155                    });
156   return result;
157 }
158 
159 UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const160 TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
161     const UnicodeText& text, int start_index, int direction) const {
162   TC3_CHECK(direction == 1 || direction == -1);
163   auto it = text.begin();
164   std::advance(it, start_index);
165   while (it > text.begin() && it < text.end()) {
166     if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
167       break;
168     }
169     std::advance(it, direction);
170   }
171   return it;
172 }
173 
// Returns a non-copying view of `text` around the span `indices`, widened
// towards `minimum_length` codepoints and with both ends moved outwards to
// the nearest whitespace/punctuation (or the text boundary).
//
//  * If the whole text is shorter than `minimum_length`, the whole text is
//    returned.
//  * If the span itself already has at least `minimum_length` codepoints, the
//    span is returned unchanged.
//  * Otherwise a window of `minimum_length` codepoints is placed roughly
//    centred on the span, clamped to the text, and then token-aligned.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int num_codepoints = text.size_codepoints();
  if (num_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Split the missing length evenly on both sides, then clamp the window so
  // it stays inside the text.
  const int padding = (minimum_length - span_length) / 2;
  const int latest_window_begin = num_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - padding, latest_window_begin));
  const int window_end =
      std::min(num_codepoints, window_begin + minimum_length);

  auto aligned_begin =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_begin, -1);
  const auto aligned_end =
      FindIndexOfNextWhitespaceOrPunctuation(text, window_end, 1);

  // After the backwards search, aligned_begin sits on a separator unless the
  // start of the text was reached. When that separator is whitespace, step
  // over it so the result starts on actual text.
  if (aligned_begin != aligned_end && unilib_->IsWhitespace(*aligned_begin)) {
    std::advance(aligned_begin, 1);
  }

  return UnicodeText::Substring(aligned_begin, aligned_end, /*do_copy=*/false);
}
206 
207 }  // namespace libtextclassifier3
208