1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "annotator/translate/translate.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "annotator/collections.h"
#include "annotator/entity-data_generated.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "lang_id/lang-id-wrapper.h"
#include "lang_id/lang-id.h"
#include "utils/base/logging.h"
#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
31
32 namespace libtextclassifier3 {
33
ClassifyText(const UnicodeText & context,CodepointSpan selection_indices,const std::string & user_familiar_language_tags,ClassificationResult * classification_result) const34 bool TranslateAnnotator::ClassifyText(
35 const UnicodeText& context, CodepointSpan selection_indices,
36 const std::string& user_familiar_language_tags,
37 ClassificationResult* classification_result) const {
38 if (!(options_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
39 return false;
40 }
41
42 std::vector<TranslateAnnotator::LanguageConfidence> confidences;
43 if (options_->algorithm() ==
44 TranslateAnnotatorOptions_::Algorithm::Algorithm_BACKOFF) {
45 if (options_->backoff_options() == nullptr) {
46 TC3_LOG(WARNING) << "No backoff options specified. Returning.";
47 return false;
48 }
49 confidences = BackoffDetectLanguages(context, selection_indices);
50 }
51
52 if (confidences.empty()) {
53 return false;
54 }
55
56 std::vector<Locale> user_familiar_languages;
57 if (!ParseLocales(user_familiar_language_tags, &user_familiar_languages)) {
58 TC3_LOG(WARNING) << "Couldn't parse the user-understood languages.";
59 return false;
60 }
61 if (user_familiar_languages.empty()) {
62 TC3_VLOG(INFO) << "user_familiar_languages is not set, not suggesting "
63 "translate action.";
64 return false;
65 }
66 bool user_can_understand_language_of_text = false;
67 for (const Locale& locale : user_familiar_languages) {
68 if (locale.Language() == confidences[0].language) {
69 user_can_understand_language_of_text = true;
70 break;
71 }
72 }
73
74 if (!user_can_understand_language_of_text) {
75 classification_result->collection = Collections::Translate();
76 classification_result->score = options_->score();
77 classification_result->priority_score = options_->priority_score();
78 classification_result->serialized_entity_data =
79 CreateSerializedEntityData(confidences);
80 return true;
81 }
82
83 return false;
84 }
85
CreateSerializedEntityData(const std::vector<TranslateAnnotator::LanguageConfidence> & confidences) const86 std::string TranslateAnnotator::CreateSerializedEntityData(
87 const std::vector<TranslateAnnotator::LanguageConfidence>& confidences)
88 const {
89 EntityDataT entity_data;
90 entity_data.translate.reset(new EntityData_::TranslateT());
91
92 for (const LanguageConfidence& confidence : confidences) {
93 EntityData_::Translate_::LanguagePredictionResultT*
94 language_prediction_result =
95 new EntityData_::Translate_::LanguagePredictionResultT();
96 language_prediction_result->language_tag = confidence.language;
97 language_prediction_result->confidence_score = confidence.confidence;
98 entity_data.translate->language_prediction_results.emplace_back(
99 language_prediction_result);
100 }
101 flatbuffers::FlatBufferBuilder builder;
102 FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
103 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
104 builder.GetSize());
105 }
106
107 std::vector<TranslateAnnotator::LanguageConfidence>
BackoffDetectLanguages(const UnicodeText & context,CodepointSpan selection_indices) const108 TranslateAnnotator::BackoffDetectLanguages(
109 const UnicodeText& context, CodepointSpan selection_indices) const {
110 const float penalize_ratio = options_->backoff_options()->penalize_ratio();
111 const int min_text_size = options_->backoff_options()->min_text_size();
112 if (selection_indices.second - selection_indices.first < min_text_size &&
113 penalize_ratio <= 0) {
114 return {};
115 }
116
117 const UnicodeText entity =
118 UnicodeText::Substring(context, selection_indices.first,
119 selection_indices.second, /*do_copy=*/false);
120 const std::vector<std::pair<std::string, float>> lang_id_result =
121 langid::GetPredictions(langid_model_, entity.data(), entity.size_bytes());
122
123 const float more_text_score_ratio =
124 1.0f - options_->backoff_options()->subject_text_score_ratio();
125 std::vector<std::pair<std::string, float>> more_lang_id_results;
126 if (more_text_score_ratio >= 0) {
127 const UnicodeText entity_with_context = TokenAlignedSubstringAroundSpan(
128 context, selection_indices, min_text_size);
129 more_lang_id_results =
130 langid::GetPredictions(langid_model_, entity_with_context.data(),
131 entity_with_context.size_bytes());
132 }
133
134 const float subject_text_score_ratio =
135 options_->backoff_options()->subject_text_score_ratio();
136
137 std::map<std::string, float> result_map;
138 for (const auto& [language, score] : lang_id_result) {
139 result_map[language] = subject_text_score_ratio * score;
140 }
141 for (const auto& [language, score] : more_lang_id_results) {
142 result_map[language] += more_text_score_ratio * score * penalize_ratio;
143 }
144
145 std::vector<TranslateAnnotator::LanguageConfidence> result;
146 result.reserve(result_map.size());
147 for (const auto& [key, value] : result_map) {
148 result.push_back({key, value});
149 }
150
151 std::stable_sort(result.begin(), result.end(),
152 [](const TranslateAnnotator::LanguageConfidence& a,
153 const TranslateAnnotator::LanguageConfidence& b) {
154 return a.confidence > b.confidence;
155 });
156 return result;
157 }
158
159 UnicodeText::const_iterator
FindIndexOfNextWhitespaceOrPunctuation(const UnicodeText & text,int start_index,int direction) const160 TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation(
161 const UnicodeText& text, int start_index, int direction) const {
162 TC3_CHECK(direction == 1 || direction == -1);
163 auto it = text.begin();
164 std::advance(it, start_index);
165 while (it > text.begin() && it < text.end()) {
166 if (unilib_->IsWhitespace(*it) || unilib_->IsPunctuation(*it)) {
167 break;
168 }
169 std::advance(it, direction);
170 }
171 return it;
172 }
173
// Returns a piece of `text` that covers the span `indices` and, when
// possible, is at least `minimum_length` codepoints long. Its edges are
// widened to token boundaries so words are not cut.
//
// Cases:
//   * The whole text is shorter than `minimum_length`: the whole text is
//     returned.
//   * The span is already `minimum_length` or longer: exactly the span is
//     returned.
//   * Otherwise: a window of `minimum_length` codepoints, roughly centered on
//     the span and clamped to the text, is extended outward to the nearest
//     whitespace/punctuation on each side. A whitespace codepoint at the left
//     edge is dropped; a punctuation codepoint there is kept.
//
// All results are built with do_copy=false, so they presumably reference
// `text`'s storage and must not outlive it.
UnicodeText TranslateAnnotator::TokenAlignedSubstringAroundSpan(
    const UnicodeText& text, CodepointSpan indices, int minimum_length) const {
  const int num_codepoints = text.size_codepoints();
  if (num_codepoints < minimum_length) {
    return UnicodeText(text, /*do_copy=*/false);
  }

  const int span_begin = indices.first;
  const int span_end = indices.second;
  const int span_length = span_end - span_begin;
  if (span_length >= minimum_length) {
    return UnicodeText::Substring(text, span_begin, span_end,
                                  /*do_copy=*/false);
  }

  // Put half of the missing length before the span, then clamp so that the
  // window [window_begin, window_end) lies inside the text.
  const int left_padding = (minimum_length - span_length) / 2;
  const int latest_window_begin = num_codepoints - minimum_length;
  const int window_begin =
      std::max(0, std::min(span_begin - left_padding, latest_window_begin));
  const int window_end =
      std::min(num_codepoints, window_begin + minimum_length);

  // Push each window edge outward to the nearest whitespace/punctuation.
  auto token_begin = FindIndexOfNextWhitespaceOrPunctuation(
      text, window_begin, /*direction=*/-1);
  const auto token_end = FindIndexOfNextWhitespaceOrPunctuation(
      text, window_end, /*direction=*/1);

  // The left edge normally lands on the boundary codepoint itself (or on the
  // very start of the text). Skip it only if it is whitespace.
  const bool begins_on_whitespace =
      token_begin != token_end && unilib_->IsWhitespace(*token_begin);
  if (begins_on_whitespace) {
    ++token_begin;
  }

  return UnicodeText::Substring(token_begin, token_end, /*do_copy=*/false);
}
206
207 } // namespace libtextclassifier3
208