1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/actions-suggestions.h"
18
19 #include <memory>
20
21 #include "actions/lua-actions.h"
22 #include "actions/types.h"
23 #include "actions/zlib-utils.h"
24 #include "utils/base/logging.h"
25 #include "utils/flatbuffers.h"
26 #include "utils/lua-utils.h"
27 #include "utils/regex-match.h"
28 #include "utils/strings/split.h"
29 #include "utils/strings/stringpiece.h"
30 #include "utils/utf8/unicodetext.h"
31 #include "utils/zlib/zlib_regex.h"
32 #include "tensorflow/lite/string_util.h"
33
34 namespace libtextclassifier3 {
35
// Definitions of the well-known action type name constants declared on
// ActionsSuggestions. Each is initialized through an immediately-invoked
// lambda that heap-allocates the string and never frees it: the
// intentional leak keeps the referenced object alive for the whole
// process lifetime and avoids static destruction-order problems at exit.
const std::string& ActionsSuggestions::kViewCalendarType =
    *[]() { return new std::string("view_calendar"); }();
const std::string& ActionsSuggestions::kViewMapType =
    *[]() { return new std::string("view_map"); }();
const std::string& ActionsSuggestions::kTrackFlightType =
    *[]() { return new std::string("track_flight"); }();
const std::string& ActionsSuggestions::kOpenUrlType =
    *[]() { return new std::string("open_url"); }();
const std::string& ActionsSuggestions::kSendSmsType =
    *[]() { return new std::string("send_sms"); }();
const std::string& ActionsSuggestions::kCallPhoneType =
    *[]() { return new std::string("call_phone"); }();
const std::string& ActionsSuggestions::kSendEmailType =
    *[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
    *[]() { return new std::string("share_location"); }();
52
53 namespace {
54
LoadAndVerifyModel(const uint8_t * addr,int size)55 const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
56 flatbuffers::Verifier verifier(addr, size);
57 if (VerifyActionsModelBuffer(verifier)) {
58 return GetActionsModel(addr);
59 } else {
60 return nullptr;
61 }
62 }
63
64 template <typename T>
ValueOrDefault(const flatbuffers::Table * values,const int32 field_offset,const T default_value)65 T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
66 const T default_value) {
67 if (values == nullptr) {
68 return default_value;
69 }
70 return values->GetField<T>(field_offset, default_value);
71 }
72
73 // Returns number of (tail) messages of a conversation to consider.
NumMessagesToConsider(const Conversation & conversation,const int max_conversation_history_length)74 int NumMessagesToConsider(const Conversation& conversation,
75 const int max_conversation_history_length) {
76 return ((max_conversation_history_length < 0 ||
77 conversation.messages.size() < max_conversation_history_length)
78 ? conversation.messages.size()
79 : max_conversation_history_length);
80 }
81
82 } // namespace
83
FromUnownedBuffer(const uint8_t * buffer,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)84 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
85 const uint8_t* buffer, const int size, const UniLib* unilib,
86 const std::string& triggering_preconditions_overlay) {
87 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
88 const ActionsModel* model = LoadAndVerifyModel(buffer, size);
89 if (model == nullptr) {
90 return nullptr;
91 }
92 actions->model_ = model;
93 actions->SetOrCreateUnilib(unilib);
94 actions->triggering_preconditions_overlay_buffer_ =
95 triggering_preconditions_overlay;
96 if (!actions->ValidateAndInitialize()) {
97 return nullptr;
98 }
99 return actions;
100 }
101
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,const UniLib * unilib,const std::string & triggering_preconditions_overlay)102 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
103 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
104 const std::string& triggering_preconditions_overlay) {
105 if (!mmap->handle().ok()) {
106 TC3_VLOG(1) << "Mmap failed.";
107 return nullptr;
108 }
109 const ActionsModel* model = LoadAndVerifyModel(
110 reinterpret_cast<const uint8_t*>(mmap->handle().start()),
111 mmap->handle().num_bytes());
112 if (!model) {
113 TC3_LOG(ERROR) << "Model verification failed.";
114 return nullptr;
115 }
116 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
117 actions->model_ = model;
118 actions->mmap_ = std::move(mmap);
119 actions->SetOrCreateUnilib(unilib);
120 actions->triggering_preconditions_overlay_buffer_ =
121 triggering_preconditions_overlay;
122 if (!actions->ValidateAndInitialize()) {
123 return nullptr;
124 }
125 return actions;
126 }
127
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)128 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
129 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
130 std::unique_ptr<UniLib> unilib,
131 const std::string& triggering_preconditions_overlay) {
132 if (!mmap->handle().ok()) {
133 TC3_VLOG(1) << "Mmap failed.";
134 return nullptr;
135 }
136 const ActionsModel* model = LoadAndVerifyModel(
137 reinterpret_cast<const uint8_t*>(mmap->handle().start()),
138 mmap->handle().num_bytes());
139 if (!model) {
140 TC3_LOG(ERROR) << "Model verification failed.";
141 return nullptr;
142 }
143 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
144 actions->model_ = model;
145 actions->mmap_ = std::move(mmap);
146 actions->owned_unilib_ = std::move(unilib);
147 actions->unilib_ = actions->owned_unilib_.get();
148 actions->triggering_preconditions_overlay_buffer_ =
149 triggering_preconditions_overlay;
150 if (!actions->ValidateAndInitialize()) {
151 return nullptr;
152 }
153 return actions;
154 }
155
FromFileDescriptor(const int fd,const int offset,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)156 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
157 const int fd, const int offset, const int size, const UniLib* unilib,
158 const std::string& triggering_preconditions_overlay) {
159 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
160 if (offset >= 0 && size >= 0) {
161 mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
162 } else {
163 mmap.reset(new libtextclassifier3::ScopedMmap(fd));
164 }
165 return FromScopedMmap(std::move(mmap), unilib,
166 triggering_preconditions_overlay);
167 }
168
FromFileDescriptor(const int fd,const int offset,const int size,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)169 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
170 const int fd, const int offset, const int size,
171 std::unique_ptr<UniLib> unilib,
172 const std::string& triggering_preconditions_overlay) {
173 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
174 if (offset >= 0 && size >= 0) {
175 mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
176 } else {
177 mmap.reset(new libtextclassifier3::ScopedMmap(fd));
178 }
179 return FromScopedMmap(std::move(mmap), std::move(unilib),
180 triggering_preconditions_overlay);
181 }
182
FromFileDescriptor(const int fd,const UniLib * unilib,const std::string & triggering_preconditions_overlay)183 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
184 const int fd, const UniLib* unilib,
185 const std::string& triggering_preconditions_overlay) {
186 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
187 new libtextclassifier3::ScopedMmap(fd));
188 return FromScopedMmap(std::move(mmap), unilib,
189 triggering_preconditions_overlay);
190 }
191
FromFileDescriptor(const int fd,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)192 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
193 const int fd, std::unique_ptr<UniLib> unilib,
194 const std::string& triggering_preconditions_overlay) {
195 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
196 new libtextclassifier3::ScopedMmap(fd));
197 return FromScopedMmap(std::move(mmap), std::move(unilib),
198 triggering_preconditions_overlay);
199 }
200
FromPath(const std::string & path,const UniLib * unilib,const std::string & triggering_preconditions_overlay)201 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
202 const std::string& path, const UniLib* unilib,
203 const std::string& triggering_preconditions_overlay) {
204 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
205 new libtextclassifier3::ScopedMmap(path));
206 return FromScopedMmap(std::move(mmap), unilib,
207 triggering_preconditions_overlay);
208 }
209
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)210 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
211 const std::string& path, std::unique_ptr<UniLib> unilib,
212 const std::string& triggering_preconditions_overlay) {
213 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
214 new libtextclassifier3::ScopedMmap(path));
215 return FromScopedMmap(std::move(mmap), std::move(unilib),
216 triggering_preconditions_overlay);
217 }
218
SetOrCreateUnilib(const UniLib * unilib)219 void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
220 if (unilib != nullptr) {
221 unilib_ = unilib;
222 } else {
223 owned_unilib_.reset(new UniLib);
224 unilib_ = owned_unilib_.get();
225 }
226 }
227
// Validates the loaded model and initializes all derived state:
// triggering preconditions, locales, the TFLite executor, annotation
// entity types, rules, the entity data schema, precompiled lua
// bytecode, the ranker and the feature processor. Returns false (with a
// logged reason) on any failure. Order matters: later steps depend on
// state set up by earlier ones (e.g. rules read the preconditions
// overlay, token embeddings need the feature processor).
bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  // Merge the optional overlay buffer with the model's default
  // preconditions.
  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  // Optional TensorFlow Lite model for suggestion generation/scoring.
  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Collect the annotation collections that map to actions.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  // A single decompressor is shared by all optionally compressed model
  // fields handled below.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (!InitializeRules(decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize rules.";
    return false;
  }

  // Optional schema for structured entity data attached to suggestions.
  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Precompile the (optionally compressed) lua actions snippet to
  // bytecode so it is not recompiled on every request.
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache embedding of padding, start and end token.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    ngram_model_ = NGramModel::Create(model_->low_confidence_ngram_model(),
                                      feature_processor_ == nullptr
                                          ? nullptr
                                          : feature_processor_->tokenizer(),
                                      unilib_);
    if (ngram_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }

  return true;
}
350
// Builds `preconditions_` by reading each threshold from the optional
// overlay buffer (a serialized TriggeringPreconditions flatbuffer
// supplied at construction) and falling back to the model's defaults
// for fields the overlay does not set.
bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  // An empty overlay buffer is fine; a non-empty buffer that fails to
  // verify is an error.
  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
    return false;
  }
  // View the overlay as a generic flatbuffer table so that per-field
  // presence can be queried via ValueOrDefault.
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  // For each precondition: use the overlay value when present,
  // otherwise the model default.
  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.diversification_distance_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
      defaults->diversification_distance_threshold());
  preconditions_.confidence_threshold =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
                     defaults->confidence_threshold());
  preconditions_.empirical_probability_factor = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
      defaults->empirical_probability_factor());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}
412
// Appends the embedding for a single token id to `embedding` using the
// feature processor and the cached embedding executor. Returns false if
// feature extraction fails.
bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}
419
InitializeRules(ZlibDecompressor * decompressor)420 bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
421 if (model_->rules() != nullptr) {
422 if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
423 TC3_LOG(ERROR) << "Could not initialize action rules.";
424 return false;
425 }
426 }
427
428 if (model_->low_confidence_rules() != nullptr) {
429 if (!InitializeRules(decompressor, model_->low_confidence_rules(),
430 &low_confidence_rules_)) {
431 TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
432 return false;
433 }
434 }
435
436 // Extend by rules provided by the overwrite.
437 // NOTE: The rules from the original models are *not* cleared.
438 if (triggering_preconditions_overlay_ != nullptr &&
439 triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
440 // These rules are optionally compressed, but separately.
441 std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
442 ZlibDecompressor::Instance();
443 if (overwrite_decompressor == nullptr) {
444 TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
445 return false;
446 }
447 if (!InitializeRules(
448 overwrite_decompressor.get(),
449 triggering_preconditions_overlay_->low_confidence_rules(),
450 &low_confidence_rules_)) {
451 TC3_LOG(ERROR)
452 << "Could not initialize low confidence rules from overwrite.";
453 return false;
454 }
455 }
456
457 return true;
458 }
459
// Compiles each rule of `rules` into `compiled_rules`: the trigger
// pattern is mandatory, the output pattern (a check applied to the
// produced reply) is optional. Patterns may be stored compressed;
// `decompressor` handles inflation. Returns false on the first pattern
// that fails to compile.
bool ActionsSuggestions::InitializeRules(
    ZlibDecompressor* decompressor, const RulesModel* rules,
    std::vector<CompiledRule>* compiled_rules) const {
  for (const RulesModel_::Rule* rule : *rules->rule()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(
            *unilib_, rule->pattern(), rule->compressed_pattern(),
            rules->lazy_regex_compilation(), decompressor);
    if (compiled_pattern == nullptr) {
      TC3_LOG(ERROR) << "Failed to load rule pattern.";
      return false;
    }

    // Check whether there is a check on the output.
    std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
    if (rule->output_pattern() != nullptr ||
        rule->compressed_output_pattern() != nullptr) {
      compiled_output_pattern = UncompressMakeRegexPattern(
          *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
          rules->lazy_regex_compilation(), decompressor);
      if (compiled_output_pattern == nullptr) {
        TC3_LOG(ERROR) << "Failed to load rule output pattern.";
        return false;
      }
    }

    // Rules without an output pattern carry a null output pattern here;
    // consumers (e.g. IsLowConfidenceInput) rely on that to distinguish
    // input-only rules from input-output rules.
    compiled_rules->emplace_back(rule, std::move(compiled_pattern),
                                 std::move(compiled_output_pattern));
  }

  return true;
}
492
IsLowConfidenceInput(const Conversation & conversation,const int num_messages,std::vector<int> * post_check_rules) const493 bool ActionsSuggestions::IsLowConfidenceInput(
494 const Conversation& conversation, const int num_messages,
495 std::vector<int>* post_check_rules) const {
496 for (int i = 1; i <= num_messages; i++) {
497 const std::string& message =
498 conversation.messages[conversation.messages.size() - i].text;
499 const UnicodeText message_unicode(
500 UTF8ToUnicodeText(message, /*do_copy=*/false));
501
502 // Run ngram linear regression model.
503 if (ngram_model_ != nullptr) {
504 if (ngram_model_->Eval(message_unicode)) {
505 return true;
506 }
507 }
508
509 // Run the regex based rules.
510 for (int low_confidence_rule = 0;
511 low_confidence_rule < low_confidence_rules_.size();
512 low_confidence_rule++) {
513 const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
514 const std::unique_ptr<UniLib::RegexMatcher> matcher =
515 rule.pattern->Matcher(message_unicode);
516 int status = UniLib::RegexMatcher::kNoError;
517 if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
518 // Rule only applies to input-output pairs, so defer the check.
519 if (rule.output_pattern != nullptr) {
520 post_check_rules->push_back(low_confidence_rule);
521 continue;
522 }
523 return true;
524 }
525 }
526 }
527 return false;
528 }
529
// Removes text replies that match any deferred low-confidence
// input-output rule (collected earlier by IsLowConfidenceInput).
// Non-text actions (empty response_text) always pass. Returns false
// only if a matcher could not be created; the filtered set is written
// back into `actions`.
bool ActionsSuggestions::FilterConfidenceOutput(
    const std::vector<int>& post_check_rules,
    std::vector<ActionSuggestion>* actions) const {
  // Nothing to filter when no rules were deferred or no actions exist.
  if (post_check_rules.empty() || actions->empty()) {
    return true;
  }
  std::vector<ActionSuggestion> filtered_text_replies;
  for (const ActionSuggestion& action : *actions) {
    // Actions without reply text are not subject to output checks.
    if (action.response_text.empty()) {
      filtered_text_replies.push_back(action);
      continue;
    }
    bool passes_post_check = true;
    const UnicodeText text_reply_unicode(
        UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
    for (const int rule_id : post_check_rules) {
      const std::unique_ptr<UniLib::RegexMatcher> matcher =
          low_confidence_rules_[rule_id].output_pattern->Matcher(
              text_reply_unicode);
      if (matcher == nullptr) {
        TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
        return false;
      }
      int status = UniLib::RegexMatcher::kNoError;
      // A match (or a matcher error) suppresses the reply.
      if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
        passes_post_check = false;
        break;
      }
    }
    if (passes_post_check) {
      filtered_text_replies.push_back(action);
    }
  }
  *actions = std::move(filtered_text_replies);
  return true;
}
566
SuggestionFromSpec(const ActionSuggestionSpec * action,const std::string & default_type,const std::string & default_response_text,const std::string & default_serialized_entity_data,const float default_score,const float default_priority_score) const567 ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
568 const ActionSuggestionSpec* action, const std::string& default_type,
569 const std::string& default_response_text,
570 const std::string& default_serialized_entity_data,
571 const float default_score, const float default_priority_score) const {
572 ActionSuggestion suggestion;
573 suggestion.score = action != nullptr ? action->score() : default_score;
574 suggestion.priority_score =
575 action != nullptr ? action->priority_score() : default_priority_score;
576 suggestion.type = action != nullptr && action->type() != nullptr
577 ? action->type()->str()
578 : default_type;
579 suggestion.response_text =
580 action != nullptr && action->response_text() != nullptr
581 ? action->response_text()->str()
582 : default_response_text;
583 suggestion.serialized_entity_data =
584 action != nullptr && action->serialized_entity_data() != nullptr
585 ? action->serialized_entity_data()->str()
586 : default_serialized_entity_data;
587 return suggestion;
588 }
589
Tokenize(const std::vector<std::string> & context) const590 std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
591 const std::vector<std::string>& context) const {
592 std::vector<std::vector<Token>> tokens;
593 tokens.reserve(context.size());
594 for (const std::string& message : context) {
595 tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
596 }
597 return tokens;
598 }
599
// Embeds the tokens of each message into `embeddings`, padding every
// message to a common per-message token count which is returned in
// `max_num_tokens_per_message` (clamped to the model's configured
// min/max). When a message exceeds the limit, tokens at its *beginning*
// are dropped.
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  // Find the longest message in the conversation.
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  // Clamp the per-message token count to the configured lower bound ...
  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  // ... and upper bound (0 or negative means "no upper bound").
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens and add paddings to pad tokens of each message to the
  // maximum number of tokens in a message of the conversation.
  // If a number of tokens is specified in the model config, tokens at the
  // beginning of a message are dropped if they don't fit in the limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}
647
// Embeds the whole conversation into one flat sequence: each message is
// wrapped in start/end sentinel tokens, older messages are trimmed
// first when a maximum total token budget is configured, and the result
// is padded up to the configured minimum. `total_token_count` receives
// the number of tokens emitted (including sentinels and padding).
// NOTE(review): `tokens` is taken by value, copying every token vector
// on each call; changing it to a const reference would require touching
// the header declaration as well — TODO confirm and fix together.
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>> tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    // Walk backwards from the most recent message, accumulating token
    // counts until the budget is exhausted.
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        // Number of tokens to skip at the start of `start_message`.
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    // The start sentinel is only emitted when the message is not
    // truncated at its beginning.
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}
714
// Resizes the interpreter's input tensors to the current conversation
// and token counts (when the model spec enables dynamic resizing) and
// then allocates all tensors. A spec tensor index < 0 means the model
// does not take that input. Returns true iff allocation succeeds.
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    // Context: one string per conversation message.
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    // User id per message.
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    // Time difference per message.
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    // Token count per message.
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    // Per-message token embeddings, padded to `max_tokens` each.
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    // Single flattened embedding sequence for the whole conversation.
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}
758
SetupModelInput(const std::vector<std::string> & context,const std::vector<int> & user_ids,const std::vector<float> & time_diffs,const int num_suggestions,const float confidence_threshold,const float diversification_distance,const float empirical_probability_factor,tflite::Interpreter * interpreter) const759 bool ActionsSuggestions::SetupModelInput(
760 const std::vector<std::string>& context, const std::vector<int>& user_ids,
761 const std::vector<float>& time_diffs, const int num_suggestions,
762 const float confidence_threshold, const float diversification_distance,
763 const float empirical_probability_factor,
764 tflite::Interpreter* interpreter) const {
765 // Compute token embeddings.
766 std::vector<std::vector<Token>> tokens;
767 std::vector<float> token_embeddings;
768 std::vector<float> flattened_token_embeddings;
769 int max_tokens = 0;
770 int total_token_count = 0;
771 if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
772 model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
773 model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
774 if (feature_processor_ == nullptr) {
775 TC3_LOG(ERROR) << "No feature processor specified.";
776 return false;
777 }
778
779 // Tokenize the messages in the conversation.
780 tokens = Tokenize(context);
781 if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
782 if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
783 TC3_LOG(ERROR) << "Could not extract token features.";
784 return false;
785 }
786 }
787 if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
788 if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
789 &total_token_count)) {
790 TC3_LOG(ERROR) << "Could not extract token features.";
791 return false;
792 }
793 }
794 }
795
796 if (!AllocateInput(context.size(), max_tokens, total_token_count,
797 interpreter)) {
798 TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
799 return false;
800 }
801 if (model_->tflite_model_spec()->input_context() >= 0) {
802 model_executor_->SetInput<std::string>(
803 model_->tflite_model_spec()->input_context(), context, interpreter);
804 }
805 if (model_->tflite_model_spec()->input_context_length() >= 0) {
806 model_executor_->SetInput<int>(
807 model_->tflite_model_spec()->input_context_length(), context.size(),
808 interpreter);
809 }
810 if (model_->tflite_model_spec()->input_user_id() >= 0) {
811 model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
812 user_ids, interpreter);
813 }
814 if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
815 model_executor_->SetInput<int>(
816 model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
817 interpreter);
818 }
819 if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
820 model_executor_->SetInput<float>(
821 model_->tflite_model_spec()->input_time_diffs(), time_diffs,
822 interpreter);
823 }
824 if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
825 model_executor_->SetInput<float>(
826 model_->tflite_model_spec()->input_diversification_distance(),
827 diversification_distance, interpreter);
828 }
829 if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
830 model_executor_->SetInput<float>(
831 model_->tflite_model_spec()->input_confidence_threshold(),
832 confidence_threshold, interpreter);
833 }
834 if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
835 model_executor_->SetInput<float>(
836 model_->tflite_model_spec()->input_empirical_probability_factor(),
837 confidence_threshold, interpreter);
838 }
839 if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
840 std::vector<int> num_tokens_per_message(tokens.size());
841 for (int i = 0; i < tokens.size(); i++) {
842 num_tokens_per_message[i] = tokens[i].size();
843 }
844 model_executor_->SetInput<int>(
845 model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
846 interpreter);
847 }
848 if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
849 model_executor_->SetInput<float>(
850 model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
851 interpreter);
852 }
853 if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
854 model_executor_->SetInput<float>(
855 model_->tflite_model_spec()->input_flattened_token_embeddings(),
856 flattened_token_embeddings, interpreter);
857 }
858 return true;
859 }
860
// Extracts the model predictions from the interpreter's output tensors into
// `response`: sensitivity and triggering scores, smart reply texts with
// scores, and per-class action scores. Returns false if an expected output
// tensor is missing or malformed.
bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float>& triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    // Flag when the score is below the smart-reply triggering threshold; the
    // reply-reading block below is skipped in that case.
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float>& sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->output_filtered_sensitivity =
        (response->sensitivity_score >
         preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->output_filtered_sensitivity) {
    return true;
  }

  // Read smart reply predictions.
  std::vector<ActionSuggestion> text_replies;
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    const std::vector<tflite::StringRef> replies =
        model_executor_->Output<tflite::StringRef>(
            model_->tflite_model_spec()->output_replies(), interpreter);
    // Replies and their scores are parallel outputs with matching indices.
    TensorView<float> scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_replies_scores(), interpreter);
    for (int i = 0; i < replies.size(); i++) {
      if (replies[i].len == 0) continue;  // Skip empty reply slots.
      const float score = scores.data()[i];
      if (score < preconditions_.min_reply_score_threshold) {
        continue;
      }
      response->actions.push_back({std::string(replies[i].str, replies[i].len),
                                   model_->smart_reply_action_type()->str(),
                                   score});
    }
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->Length(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }
      // Build the suggestion from the class spec, carrying over the score.
      ActionSuggestion suggestion =
          SuggestionFromSpec(action_type->action(),
                             /*default_type=*/action_type->name()->str());
      suggestion.score = score;
      response->actions.push_back(suggestion);
    }
  }

  return true;
}
945
SuggestActionsFromModel(const Conversation & conversation,const int num_messages,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response,std::unique_ptr<tflite::Interpreter> * interpreter) const946 bool ActionsSuggestions::SuggestActionsFromModel(
947 const Conversation& conversation, const int num_messages,
948 const ActionSuggestionOptions& options,
949 ActionsSuggestionsResponse* response,
950 std::unique_ptr<tflite::Interpreter>* interpreter) const {
951 TC3_CHECK_LE(num_messages, conversation.messages.size());
952
953 if (!model_executor_) {
954 return true;
955 }
956 *interpreter = model_executor_->CreateInterpreter();
957
958 if (!*interpreter) {
959 TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
960 "actions suggestions model.";
961 return false;
962 }
963
964 std::vector<std::string> context;
965 std::vector<int> user_ids;
966 std::vector<float> time_diffs;
967 context.reserve(num_messages);
968 user_ids.reserve(num_messages);
969 time_diffs.reserve(num_messages);
970
971 // Gather last `num_messages` messages from the conversation.
972 int64 last_message_reference_time_ms_utc = 0;
973 const float second_in_ms = 1000;
974 for (int i = conversation.messages.size() - num_messages;
975 i < conversation.messages.size(); i++) {
976 const ConversationMessage& message = conversation.messages[i];
977 context.push_back(message.text);
978 user_ids.push_back(message.user_id);
979
980 float time_diff_secs = 0;
981 if (message.reference_time_ms_utc != 0 &&
982 last_message_reference_time_ms_utc != 0) {
983 time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
984 last_message_reference_time_ms_utc) /
985 second_in_ms);
986 }
987 if (message.reference_time_ms_utc != 0) {
988 last_message_reference_time_ms_utc = message.reference_time_ms_utc;
989 }
990 time_diffs.push_back(time_diff_secs);
991 }
992
993 if (!SetupModelInput(context, user_ids, time_diffs,
994 /*num_suggestions=*/model_->num_smart_replies(),
995 preconditions_.confidence_threshold,
996 preconditions_.diversification_distance_threshold,
997 preconditions_.empirical_probability_factor,
998 interpreter->get())) {
999 TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1000 return false;
1001 }
1002
1003 if ((*interpreter)->Invoke() != kTfLiteOk) {
1004 TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1005 return false;
1006 }
1007
1008 return ReadModelOutput(interpreter->get(), options, response);
1009 }
1010
AnnotationOptionsForMessage(const ConversationMessage & message) const1011 AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1012 const ConversationMessage& message) const {
1013 AnnotationOptions options;
1014 options.detected_text_language_tags = message.detected_text_language_tags;
1015 options.reference_time_ms_utc = message.reference_time_ms_utc;
1016 options.reference_timezone = message.reference_timezone;
1017 options.annotation_usecase =
1018 model_->annotation_actions_spec()->annotation_usecase();
1019 options.is_serialized_entity_data_enabled =
1020 model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1021 options.entity_types = annotation_entity_types_;
1022 return options;
1023 }
1024
// Creates actions from annotations found in recent conversation messages.
// The history walk starts at the newest message and stops once both history
// budgets (messages from the last sender / from anyone) are exhausted, or
// earlier for local-user messages depending on the spec. Messages without
// precomputed annotations are annotated on the fly when an annotator is
// available.
void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
  // Nothing to do without an annotation-to-action mapping.
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  // True while every message seen so far came from `last_person`.
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
       message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    // Stop once both history budgets are exceeded.
    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    // Annotate on the fly when the message has no precomputed annotations.
    if (annotations.empty() && annotator != nullptr) {
      annotations = annotator->Annotate(message.text,
                                        AnnotationOptionsForMessage(message));
    }
    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      // Only the top-ranked classification of each span is considered.
      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(action_annotation);
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}
1112
// Creates actions for a single annotation by matching its collection against
// the model's annotation-to-action mappings. One action is appended per
// matching mapping; the annotation text can additionally be written into an
// entity data field when the mapping specifies one.
void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      // Skip annotations below this mapping's score threshold.
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }
      ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        std::unique_ptr<ReflectiveFlatbuffer> entity_data =
            entity_data_builder_->NewRoot();
        TC3_CHECK(entity_data != nullptr);

        // Merge existing static entity data.
        if (!suggestion.serialized_entity_data.empty()) {
          entity_data->MergeFromSerializedFlatbuffer(
              StringPiece(suggestion.serialized_entity_data.c_str(),
                          suggestion.serialized_entity_data.size()));
        }

        entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
        suggestion.serialized_entity_data = entity_data->Serialize();
      }

      suggestion.annotations = {annotation};
      actions->push_back(suggestion);
    }
  }
}
1150
DeduplicateAnnotations(const std::vector<ActionSuggestionAnnotation> & annotations) const1151 std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1152 const std::vector<ActionSuggestionAnnotation>& annotations) const {
1153 std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1154
1155 for (int i = 0; i < annotations.size(); i++) {
1156 const std::pair<std::string, std::string> key = {annotations[i].name,
1157 annotations[i].span.text};
1158 auto entry = deduplicated_annotations.find(key);
1159 if (entry != deduplicated_annotations.end()) {
1160 // Kepp the annotation with the higher score.
1161 if (annotations[entry->second].entity.score <
1162 annotations[i].entity.score) {
1163 entry->second = i;
1164 }
1165 continue;
1166 }
1167 deduplicated_annotations.insert(entry, {key, i});
1168 }
1169
1170 std::vector<int> result;
1171 result.reserve(deduplicated_annotations.size());
1172 for (const auto& key_and_annotation : deduplicated_annotations) {
1173 result.push_back(key_and_annotation.second);
1174 }
1175 return result;
1176 }
1177
// Fills `annotation` (span, text, name and entity collection) from a regex
// capturing group, if the rule requests an annotation for that group.
// Returns false when span extraction fails or the group did not take part in
// the match; returns true (without touching `annotation`) when the group
// requests no annotation at all.
bool ActionsSuggestions::FillAnnotationFromMatchGroup(
    const UniLib::RegexMatcher* matcher,
    const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
    const int message_index, ActionSuggestionAnnotation* annotation) const {
  if (group->annotation_name() != nullptr ||
      group->annotation_type() != nullptr) {
    // `status` accumulates errors across all three matcher calls and is
    // checked once afterwards.
    int status = UniLib::RegexMatcher::kNoError;
    const CodepointSpan span = {matcher->Start(group->group_id(), &status),
                                matcher->End(group->group_id(), &status)};
    std::string text =
        matcher->Group(group->group_id(), &status).ToUTF8String();
    if (status != UniLib::RegexMatcher::kNoError) {
      TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
      return false;
    }

    // The capturing group was not part of the match.
    if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
      return false;
    }
    annotation->span.span = span;
    annotation->span.message_index = message_index;
    annotation->span.text = text;
    if (group->annotation_name() != nullptr) {
      annotation->name = group->annotation_name()->str();
    }
    if (group->annotation_type() != nullptr) {
      annotation->entity.collection = group->annotation_type()->str();
    }
  }
  return true;
}
1210
SuggestActionsFromRules(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1211 bool ActionsSuggestions::SuggestActionsFromRules(
1212 const Conversation& conversation,
1213 std::vector<ActionSuggestion>* actions) const {
1214 // Create actions based on rules checking the last message.
1215 const int message_index = conversation.messages.size() - 1;
1216 const std::string& message = conversation.messages.back().text;
1217 const UnicodeText message_unicode(
1218 UTF8ToUnicodeText(message, /*do_copy=*/false));
1219 for (const CompiledRule& rule : rules_) {
1220 const std::unique_ptr<UniLib::RegexMatcher> matcher =
1221 rule.pattern->Matcher(message_unicode);
1222 int status = UniLib::RegexMatcher::kNoError;
1223 while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
1224 for (const RulesModel_::Rule_::RuleActionSpec* rule_action :
1225 *rule.rule->actions()) {
1226 const ActionSuggestionSpec* action = rule_action->action();
1227 std::vector<ActionSuggestionAnnotation> annotations;
1228
1229 bool sets_entity_data = false;
1230 std::unique_ptr<ReflectiveFlatbuffer> entity_data =
1231 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1232 : nullptr;
1233
1234 // Set static entity data.
1235 if (action != nullptr && action->serialized_entity_data() != nullptr) {
1236 TC3_CHECK(entity_data != nullptr);
1237 sets_entity_data = true;
1238 entity_data->MergeFromSerializedFlatbuffer(
1239 StringPiece(action->serialized_entity_data()->c_str(),
1240 action->serialized_entity_data()->size()));
1241 }
1242
1243 // Add entity data from rule capturing groups.
1244 if (rule_action->capturing_group() != nullptr) {
1245 for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
1246 group : *rule_action->capturing_group()) {
1247 if (group->entity_field() != nullptr) {
1248 TC3_CHECK(entity_data != nullptr);
1249 sets_entity_data = true;
1250 if (!SetFieldFromCapturingGroup(
1251 group->group_id(), group->entity_field(), matcher.get(),
1252 entity_data.get())) {
1253 TC3_LOG(ERROR)
1254 << "Could not set entity data from rule capturing group.";
1255 return false;
1256 }
1257 }
1258
1259 // Create a text annotation for the group span.
1260 ActionSuggestionAnnotation annotation;
1261 if (FillAnnotationFromMatchGroup(matcher.get(), group,
1262 message_index, &annotation)) {
1263 annotations.push_back(annotation);
1264 }
1265
1266 // Create text reply.
1267 if (group->text_reply() != nullptr) {
1268 int status = UniLib::RegexMatcher::kNoError;
1269 const std::string group_text =
1270 matcher->Group(group->group_id(), &status).ToUTF8String();
1271 if (status != UniLib::RegexMatcher::kNoError) {
1272 TC3_LOG(ERROR) << "Could get text from capturing group.";
1273 return false;
1274 }
1275 if (group_text.empty()) {
1276 // The group was not part of the match, ignore and continue.
1277 continue;
1278 }
1279 actions->push_back(SuggestionFromSpec(
1280 group->text_reply(),
1281 /*default_type=*/model_->smart_reply_action_type()->str(),
1282 /*default_response_text=*/group_text));
1283 }
1284 }
1285 }
1286
1287 if (action != nullptr) {
1288 ActionSuggestion suggestion = SuggestionFromSpec(action);
1289 suggestion.annotations = annotations;
1290 if (sets_entity_data) {
1291 suggestion.serialized_entity_data = entity_data->Serialize();
1292 }
1293 actions->push_back(suggestion);
1294 }
1295 }
1296 }
1297 }
1298 return true;
1299 }
1300
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1301 bool ActionsSuggestions::SuggestActionsFromLua(
1302 const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1303 const tflite::Interpreter* interpreter,
1304 const reflection::Schema* annotation_entity_data_schema,
1305 std::vector<ActionSuggestion>* actions) const {
1306 if (lua_bytecode_.empty()) {
1307 return true;
1308 }
1309
1310 auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1311 lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1312 interpreter, entity_data_schema_, annotation_entity_data_schema);
1313 if (lua_actions == nullptr) {
1314 TC3_LOG(ERROR) << "Could not create lua actions.";
1315 return false;
1316 }
1317 return lua_actions->SuggestActions(actions);
1318 }
1319
GatherActionsSuggestions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const1320 bool ActionsSuggestions::GatherActionsSuggestions(
1321 const Conversation& conversation, const Annotator* annotator,
1322 const ActionSuggestionOptions& options,
1323 ActionsSuggestionsResponse* response) const {
1324 if (conversation.messages.empty()) {
1325 return true;
1326 }
1327
1328 const int num_messages = NumMessagesToConsider(
1329 conversation, model_->max_conversation_history_length());
1330
1331 if (num_messages <= 0) {
1332 TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1333 return false;
1334 }
1335
1336 SuggestActionsFromAnnotations(conversation, options, annotator,
1337 &response->actions);
1338
1339 int input_text_length = 0;
1340 int num_matching_locales = 0;
1341 for (int i = conversation.messages.size() - num_messages;
1342 i < conversation.messages.size(); i++) {
1343 input_text_length += conversation.messages[i].text.length();
1344 std::vector<Locale> message_languages;
1345 if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
1346 &message_languages)) {
1347 continue;
1348 }
1349 if (Locale::IsAnyLocaleSupported(
1350 message_languages, locales_,
1351 preconditions_.handle_unknown_locale_as_supported)) {
1352 ++num_matching_locales;
1353 }
1354 }
1355
1356 // Bail out if we are provided with too few or too much input.
1357 if (input_text_length < preconditions_.min_input_length ||
1358 (preconditions_.max_input_length >= 0 &&
1359 input_text_length > preconditions_.max_input_length)) {
1360 TC3_LOG(INFO) << "Too much or not enough input for inference.";
1361 return response;
1362 }
1363
1364 // Bail out if the text does not look like it can be handled by the model.
1365 const float matching_fraction =
1366 static_cast<float>(num_matching_locales) / num_messages;
1367 if (matching_fraction < preconditions_.min_locale_match_fraction) {
1368 TC3_LOG(INFO) << "Not enough locale matches.";
1369 response->output_filtered_locale_mismatch = true;
1370 return true;
1371 }
1372
1373 std::vector<int> post_check_rules;
1374 if (preconditions_.suppress_on_low_confidence_input &&
1375 IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) {
1376 response->output_filtered_low_confidence = true;
1377 return true;
1378 }
1379
1380 std::unique_ptr<tflite::Interpreter> interpreter;
1381 if (!SuggestActionsFromModel(conversation, num_messages, options, response,
1382 &interpreter)) {
1383 TC3_LOG(ERROR) << "Could not run model.";
1384 return false;
1385 }
1386
1387 // Suppress all predictions if the conversation was deemed sensitive.
1388 if (preconditions_.suppress_on_sensitive_topic &&
1389 response->output_filtered_sensitivity) {
1390 return true;
1391 }
1392
1393 if (!SuggestActionsFromLua(
1394 conversation, model_executor_.get(), interpreter.get(),
1395 annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1396 &response->actions)) {
1397 TC3_LOG(ERROR) << "Could not suggest actions from script.";
1398 return false;
1399 }
1400
1401 if (!SuggestActionsFromRules(conversation, &response->actions)) {
1402 TC3_LOG(ERROR) << "Could not suggest actions from rules.";
1403 return false;
1404 }
1405
1406 if (preconditions_.suppress_on_low_confidence_input &&
1407 !FilterConfidenceOutput(post_check_rules, &response->actions)) {
1408 TC3_LOG(ERROR) << "Could not post-check actions.";
1409 return false;
1410 }
1411
1412 return true;
1413 }
1414
SuggestActions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options) const1415 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1416 const Conversation& conversation, const Annotator* annotator,
1417 const ActionSuggestionOptions& options) const {
1418 ActionsSuggestionsResponse response;
1419 if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1420 TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1421 response.actions.clear();
1422 } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1423 annotator != nullptr
1424 ? annotator->entity_data_schema()
1425 : nullptr)) {
1426 TC3_LOG(ERROR) << "Could not rank actions.";
1427 response.actions.clear();
1428 }
1429 return response;
1430 }
1431
// Convenience overload that suggests actions without an annotator (no
// on-the-fly annotation of messages).
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}
1437
// Accessor for the flatbuffer actions model backing this instance.
const ActionsModel* ActionsSuggestions::model() const { return model_; }
// Accessor for the reflection schema used for the actions' serialized
// entity data.
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}
1442
ViewActionsModel(const void * buffer,int size)1443 const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1444 if (buffer == nullptr) {
1445 return nullptr;
1446 }
1447 return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1448 }
1449
1450 } // namespace libtextclassifier3
1451