• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/actions-suggestions.h"
18 
19 #include <memory>
20 
21 #include "actions/lua-actions.h"
22 #include "actions/types.h"
23 #include "actions/zlib-utils.h"
24 #include "utils/base/logging.h"
25 #include "utils/flatbuffers.h"
26 #include "utils/lua-utils.h"
27 #include "utils/regex-match.h"
28 #include "utils/strings/split.h"
29 #include "utils/strings/stringpiece.h"
30 #include "utils/utf8/unicodetext.h"
31 #include "utils/zlib/zlib_regex.h"
32 #include "tensorflow/lite/string_util.h"
33 
34 namespace libtextclassifier3 {
35 
36 const std::string& ActionsSuggestions::kViewCalendarType =
__anon2dbc54b70102() 37     *[]() { return new std::string("view_calendar"); }();
38 const std::string& ActionsSuggestions::kViewMapType =
__anon2dbc54b70202() 39     *[]() { return new std::string("view_map"); }();
40 const std::string& ActionsSuggestions::kTrackFlightType =
__anon2dbc54b70302() 41     *[]() { return new std::string("track_flight"); }();
42 const std::string& ActionsSuggestions::kOpenUrlType =
__anon2dbc54b70402() 43     *[]() { return new std::string("open_url"); }();
44 const std::string& ActionsSuggestions::kSendSmsType =
__anon2dbc54b70502() 45     *[]() { return new std::string("send_sms"); }();
46 const std::string& ActionsSuggestions::kCallPhoneType =
__anon2dbc54b70602() 47     *[]() { return new std::string("call_phone"); }();
48 const std::string& ActionsSuggestions::kSendEmailType =
__anon2dbc54b70702() 49     *[]() { return new std::string("send_email"); }();
50 const std::string& ActionsSuggestions::kShareLocation =
__anon2dbc54b70802() 51     *[]() { return new std::string("share_location"); }();
52 
53 namespace {
54 
LoadAndVerifyModel(const uint8_t * addr,int size)55 const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
56   flatbuffers::Verifier verifier(addr, size);
57   if (VerifyActionsModelBuffer(verifier)) {
58     return GetActionsModel(addr);
59   } else {
60     return nullptr;
61   }
62 }
63 
64 template <typename T>
ValueOrDefault(const flatbuffers::Table * values,const int32 field_offset,const T default_value)65 T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
66                  const T default_value) {
67   if (values == nullptr) {
68     return default_value;
69   }
70   return values->GetField<T>(field_offset, default_value);
71 }
72 
73 // Returns number of (tail) messages of a conversation to consider.
NumMessagesToConsider(const Conversation & conversation,const int max_conversation_history_length)74 int NumMessagesToConsider(const Conversation& conversation,
75                           const int max_conversation_history_length) {
76   return ((max_conversation_history_length < 0 ||
77            conversation.messages.size() < max_conversation_history_length)
78               ? conversation.messages.size()
79               : max_conversation_history_length);
80 }
81 
82 }  // namespace
83 
FromUnownedBuffer(const uint8_t * buffer,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)84 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
85     const uint8_t* buffer, const int size, const UniLib* unilib,
86     const std::string& triggering_preconditions_overlay) {
87   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
88   const ActionsModel* model = LoadAndVerifyModel(buffer, size);
89   if (model == nullptr) {
90     return nullptr;
91   }
92   actions->model_ = model;
93   actions->SetOrCreateUnilib(unilib);
94   actions->triggering_preconditions_overlay_buffer_ =
95       triggering_preconditions_overlay;
96   if (!actions->ValidateAndInitialize()) {
97     return nullptr;
98   }
99   return actions;
100 }
101 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,const UniLib * unilib,const std::string & triggering_preconditions_overlay)102 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
103     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
104     const std::string& triggering_preconditions_overlay) {
105   if (!mmap->handle().ok()) {
106     TC3_VLOG(1) << "Mmap failed.";
107     return nullptr;
108   }
109   const ActionsModel* model = LoadAndVerifyModel(
110       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
111       mmap->handle().num_bytes());
112   if (!model) {
113     TC3_LOG(ERROR) << "Model verification failed.";
114     return nullptr;
115   }
116   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
117   actions->model_ = model;
118   actions->mmap_ = std::move(mmap);
119   actions->SetOrCreateUnilib(unilib);
120   actions->triggering_preconditions_overlay_buffer_ =
121       triggering_preconditions_overlay;
122   if (!actions->ValidateAndInitialize()) {
123     return nullptr;
124   }
125   return actions;
126 }
127 
FromScopedMmap(std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)128 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
129     std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
130     std::unique_ptr<UniLib> unilib,
131     const std::string& triggering_preconditions_overlay) {
132   if (!mmap->handle().ok()) {
133     TC3_VLOG(1) << "Mmap failed.";
134     return nullptr;
135   }
136   const ActionsModel* model = LoadAndVerifyModel(
137       reinterpret_cast<const uint8_t*>(mmap->handle().start()),
138       mmap->handle().num_bytes());
139   if (!model) {
140     TC3_LOG(ERROR) << "Model verification failed.";
141     return nullptr;
142   }
143   auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
144   actions->model_ = model;
145   actions->mmap_ = std::move(mmap);
146   actions->owned_unilib_ = std::move(unilib);
147   actions->unilib_ = actions->owned_unilib_.get();
148   actions->triggering_preconditions_overlay_buffer_ =
149       triggering_preconditions_overlay;
150   if (!actions->ValidateAndInitialize()) {
151     return nullptr;
152   }
153   return actions;
154 }
155 
FromFileDescriptor(const int fd,const int offset,const int size,const UniLib * unilib,const std::string & triggering_preconditions_overlay)156 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
157     const int fd, const int offset, const int size, const UniLib* unilib,
158     const std::string& triggering_preconditions_overlay) {
159   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
160   if (offset >= 0 && size >= 0) {
161     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
162   } else {
163     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
164   }
165   return FromScopedMmap(std::move(mmap), unilib,
166                         triggering_preconditions_overlay);
167 }
168 
FromFileDescriptor(const int fd,const int offset,const int size,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)169 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
170     const int fd, const int offset, const int size,
171     std::unique_ptr<UniLib> unilib,
172     const std::string& triggering_preconditions_overlay) {
173   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
174   if (offset >= 0 && size >= 0) {
175     mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
176   } else {
177     mmap.reset(new libtextclassifier3::ScopedMmap(fd));
178   }
179   return FromScopedMmap(std::move(mmap), std::move(unilib),
180                         triggering_preconditions_overlay);
181 }
182 
FromFileDescriptor(const int fd,const UniLib * unilib,const std::string & triggering_preconditions_overlay)183 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
184     const int fd, const UniLib* unilib,
185     const std::string& triggering_preconditions_overlay) {
186   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
187       new libtextclassifier3::ScopedMmap(fd));
188   return FromScopedMmap(std::move(mmap), unilib,
189                         triggering_preconditions_overlay);
190 }
191 
FromFileDescriptor(const int fd,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)192 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
193     const int fd, std::unique_ptr<UniLib> unilib,
194     const std::string& triggering_preconditions_overlay) {
195   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
196       new libtextclassifier3::ScopedMmap(fd));
197   return FromScopedMmap(std::move(mmap), std::move(unilib),
198                         triggering_preconditions_overlay);
199 }
200 
FromPath(const std::string & path,const UniLib * unilib,const std::string & triggering_preconditions_overlay)201 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
202     const std::string& path, const UniLib* unilib,
203     const std::string& triggering_preconditions_overlay) {
204   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
205       new libtextclassifier3::ScopedMmap(path));
206   return FromScopedMmap(std::move(mmap), unilib,
207                         triggering_preconditions_overlay);
208 }
209 
FromPath(const std::string & path,std::unique_ptr<UniLib> unilib,const std::string & triggering_preconditions_overlay)210 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
211     const std::string& path, std::unique_ptr<UniLib> unilib,
212     const std::string& triggering_preconditions_overlay) {
213   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
214       new libtextclassifier3::ScopedMmap(path));
215   return FromScopedMmap(std::move(mmap), std::move(unilib),
216                         triggering_preconditions_overlay);
217 }
218 
SetOrCreateUnilib(const UniLib * unilib)219 void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
220   if (unilib != nullptr) {
221     unilib_ = unilib;
222   } else {
223     owned_unilib_.reset(new UniLib);
224     unilib_ = owned_unilib_.get();
225   }
226 }
227 
ValidateAndInitialize()228 bool ActionsSuggestions::ValidateAndInitialize() {
229   if (model_ == nullptr) {
230     TC3_LOG(ERROR) << "No model specified.";
231     return false;
232   }
233 
234   if (model_->smart_reply_action_type() == nullptr) {
235     TC3_LOG(ERROR) << "No smart reply action type specified.";
236     return false;
237   }
238 
239   if (!InitializeTriggeringPreconditions()) {
240     TC3_LOG(ERROR) << "Could not initialize preconditions.";
241     return false;
242   }
243 
244   if (model_->locales() &&
245       !ParseLocales(model_->locales()->c_str(), &locales_)) {
246     TC3_LOG(ERROR) << "Could not parse model supported locales.";
247     return false;
248   }
249 
250   if (model_->tflite_model_spec() != nullptr) {
251     model_executor_ = TfLiteModelExecutor::FromBuffer(
252         model_->tflite_model_spec()->tflite_model());
253     if (!model_executor_) {
254       TC3_LOG(ERROR) << "Could not initialize model executor.";
255       return false;
256     }
257   }
258 
259   if (model_->annotation_actions_spec() != nullptr &&
260       model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
261     for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
262          *model_->annotation_actions_spec()->annotation_mapping()) {
263       annotation_entity_types_.insert(mapping->annotation_collection()->str());
264     }
265   }
266 
267   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
268   if (!InitializeRules(decompressor.get())) {
269     TC3_LOG(ERROR) << "Could not initialize rules.";
270     return false;
271   }
272 
273   if (model_->actions_entity_data_schema() != nullptr) {
274     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
275         model_->actions_entity_data_schema()->Data(),
276         model_->actions_entity_data_schema()->size());
277     if (entity_data_schema_ == nullptr) {
278       TC3_LOG(ERROR) << "Could not load entity data schema data.";
279       return false;
280     }
281 
282     entity_data_builder_.reset(
283         new ReflectiveFlatbufferBuilder(entity_data_schema_));
284   } else {
285     entity_data_schema_ = nullptr;
286   }
287 
288   std::string actions_script;
289   if (GetUncompressedString(model_->lua_actions_script(),
290                             model_->compressed_lua_actions_script(),
291                             decompressor.get(), &actions_script) &&
292       !actions_script.empty()) {
293     if (!Compile(actions_script, &lua_bytecode_)) {
294       TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
295       return false;
296     }
297   }
298 
299   if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
300             model_->ranking_options(), decompressor.get(),
301             model_->smart_reply_action_type()->str()))) {
302     TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
303     return false;
304   }
305 
306   // Create feature processor if specified.
307   const ActionsTokenFeatureProcessorOptions* options =
308       model_->feature_processor_options();
309   if (options != nullptr) {
310     if (options->tokenizer_options() == nullptr) {
311       TC3_LOG(ERROR) << "No tokenizer options specified.";
312       return false;
313     }
314 
315     feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
316     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
317         options->embedding_model(), options->embedding_size(),
318         options->embedding_quantization_bits());
319 
320     if (embedding_executor_ == nullptr) {
321       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
322       return false;
323     }
324 
325     // Cache embedding of padding, start and end token.
326     if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
327         !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
328         !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
329       TC3_LOG(ERROR) << "Could not precompute token embeddings.";
330       return false;
331     }
332     token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
333   }
334 
335   // Create low confidence model if specified.
336   if (model_->low_confidence_ngram_model() != nullptr) {
337     ngram_model_ = NGramModel::Create(model_->low_confidence_ngram_model(),
338                                       feature_processor_ == nullptr
339                                           ? nullptr
340                                           : feature_processor_->tokenizer(),
341                                       unilib_);
342     if (ngram_model_ == nullptr) {
343       TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
344       return false;
345     }
346   }
347 
348   return true;
349 }
350 
InitializeTriggeringPreconditions()351 bool ActionsSuggestions::InitializeTriggeringPreconditions() {
352   triggering_preconditions_overlay_ =
353       LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
354           triggering_preconditions_overlay_buffer_);
355 
356   if (triggering_preconditions_overlay_ == nullptr &&
357       !triggering_preconditions_overlay_buffer_.empty()) {
358     TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
359     return false;
360   }
361   const flatbuffers::Table* overlay =
362       reinterpret_cast<const flatbuffers::Table*>(
363           triggering_preconditions_overlay_);
364   const TriggeringPreconditions* defaults = model_->preconditions();
365   if (defaults == nullptr) {
366     TC3_LOG(ERROR) << "No triggering conditions specified.";
367     return false;
368   }
369 
370   preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
371       overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
372       defaults->min_smart_reply_triggering_score());
373   preconditions_.max_sensitive_topic_score = ValueOrDefault(
374       overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
375       defaults->max_sensitive_topic_score());
376   preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
377       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
378       defaults->suppress_on_sensitive_topic());
379   preconditions_.min_input_length =
380       ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
381                      defaults->min_input_length());
382   preconditions_.max_input_length =
383       ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
384                      defaults->max_input_length());
385   preconditions_.min_locale_match_fraction = ValueOrDefault(
386       overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
387       defaults->min_locale_match_fraction());
388   preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
389       overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
390       defaults->handle_missing_locale_as_supported());
391   preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
392       overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
393       defaults->handle_unknown_locale_as_supported());
394   preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
395       overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
396       defaults->suppress_on_low_confidence_input());
397   preconditions_.diversification_distance_threshold = ValueOrDefault(
398       overlay, TriggeringPreconditions::VT_DIVERSIFICATION_DISTANCE_THRESHOLD,
399       defaults->diversification_distance_threshold());
400   preconditions_.confidence_threshold =
401       ValueOrDefault(overlay, TriggeringPreconditions::VT_CONFIDENCE_THRESHOLD,
402                      defaults->confidence_threshold());
403   preconditions_.empirical_probability_factor = ValueOrDefault(
404       overlay, TriggeringPreconditions::VT_EMPIRICAL_PROBABILITY_FACTOR,
405       defaults->empirical_probability_factor());
406   preconditions_.min_reply_score_threshold = ValueOrDefault(
407       overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
408       defaults->min_reply_score_threshold());
409 
410   return true;
411 }
412 
EmbedTokenId(const int32 token_id,std::vector<float> * embedding) const413 bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
414                                       std::vector<float>* embedding) const {
415   return feature_processor_->AppendFeatures(
416       {token_id},
417       /*dense_features=*/{}, embedding_executor_.get(), embedding);
418 }
419 
InitializeRules(ZlibDecompressor * decompressor)420 bool ActionsSuggestions::InitializeRules(ZlibDecompressor* decompressor) {
421   if (model_->rules() != nullptr) {
422     if (!InitializeRules(decompressor, model_->rules(), &rules_)) {
423       TC3_LOG(ERROR) << "Could not initialize action rules.";
424       return false;
425     }
426   }
427 
428   if (model_->low_confidence_rules() != nullptr) {
429     if (!InitializeRules(decompressor, model_->low_confidence_rules(),
430                          &low_confidence_rules_)) {
431       TC3_LOG(ERROR) << "Could not initialize low confidence rules.";
432       return false;
433     }
434   }
435 
436   // Extend by rules provided by the overwrite.
437   // NOTE: The rules from the original models are *not* cleared.
438   if (triggering_preconditions_overlay_ != nullptr &&
439       triggering_preconditions_overlay_->low_confidence_rules() != nullptr) {
440     // These rules are optionally compressed, but separately.
441     std::unique_ptr<ZlibDecompressor> overwrite_decompressor =
442         ZlibDecompressor::Instance();
443     if (overwrite_decompressor == nullptr) {
444       TC3_LOG(ERROR) << "Could not initialze decompressor for overwrite rules.";
445       return false;
446     }
447     if (!InitializeRules(
448             overwrite_decompressor.get(),
449             triggering_preconditions_overlay_->low_confidence_rules(),
450             &low_confidence_rules_)) {
451       TC3_LOG(ERROR)
452           << "Could not initialize low confidence rules from overwrite.";
453       return false;
454     }
455   }
456 
457   return true;
458 }
459 
InitializeRules(ZlibDecompressor * decompressor,const RulesModel * rules,std::vector<CompiledRule> * compiled_rules) const460 bool ActionsSuggestions::InitializeRules(
461     ZlibDecompressor* decompressor, const RulesModel* rules,
462     std::vector<CompiledRule>* compiled_rules) const {
463   for (const RulesModel_::Rule* rule : *rules->rule()) {
464     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
465         UncompressMakeRegexPattern(
466             *unilib_, rule->pattern(), rule->compressed_pattern(),
467             rules->lazy_regex_compilation(), decompressor);
468     if (compiled_pattern == nullptr) {
469       TC3_LOG(ERROR) << "Failed to load rule pattern.";
470       return false;
471     }
472 
473     // Check whether there is a check on the output.
474     std::unique_ptr<UniLib::RegexPattern> compiled_output_pattern;
475     if (rule->output_pattern() != nullptr ||
476         rule->compressed_output_pattern() != nullptr) {
477       compiled_output_pattern = UncompressMakeRegexPattern(
478           *unilib_, rule->output_pattern(), rule->compressed_output_pattern(),
479           rules->lazy_regex_compilation(), decompressor);
480       if (compiled_output_pattern == nullptr) {
481         TC3_LOG(ERROR) << "Failed to load rule output pattern.";
482         return false;
483       }
484     }
485 
486     compiled_rules->emplace_back(rule, std::move(compiled_pattern),
487                                  std::move(compiled_output_pattern));
488   }
489 
490   return true;
491 }
492 
IsLowConfidenceInput(const Conversation & conversation,const int num_messages,std::vector<int> * post_check_rules) const493 bool ActionsSuggestions::IsLowConfidenceInput(
494     const Conversation& conversation, const int num_messages,
495     std::vector<int>* post_check_rules) const {
496   for (int i = 1; i <= num_messages; i++) {
497     const std::string& message =
498         conversation.messages[conversation.messages.size() - i].text;
499     const UnicodeText message_unicode(
500         UTF8ToUnicodeText(message, /*do_copy=*/false));
501 
502     // Run ngram linear regression model.
503     if (ngram_model_ != nullptr) {
504       if (ngram_model_->Eval(message_unicode)) {
505         return true;
506       }
507     }
508 
509     // Run the regex based rules.
510     for (int low_confidence_rule = 0;
511          low_confidence_rule < low_confidence_rules_.size();
512          low_confidence_rule++) {
513       const CompiledRule& rule = low_confidence_rules_[low_confidence_rule];
514       const std::unique_ptr<UniLib::RegexMatcher> matcher =
515           rule.pattern->Matcher(message_unicode);
516       int status = UniLib::RegexMatcher::kNoError;
517       if (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
518         // Rule only applies to input-output pairs, so defer the check.
519         if (rule.output_pattern != nullptr) {
520           post_check_rules->push_back(low_confidence_rule);
521           continue;
522         }
523         return true;
524       }
525     }
526   }
527   return false;
528 }
529 
FilterConfidenceOutput(const std::vector<int> & post_check_rules,std::vector<ActionSuggestion> * actions) const530 bool ActionsSuggestions::FilterConfidenceOutput(
531     const std::vector<int>& post_check_rules,
532     std::vector<ActionSuggestion>* actions) const {
533   if (post_check_rules.empty() || actions->empty()) {
534     return true;
535   }
536   std::vector<ActionSuggestion> filtered_text_replies;
537   for (const ActionSuggestion& action : *actions) {
538     if (action.response_text.empty()) {
539       filtered_text_replies.push_back(action);
540       continue;
541     }
542     bool passes_post_check = true;
543     const UnicodeText text_reply_unicode(
544         UTF8ToUnicodeText(action.response_text, /*do_copy=*/false));
545     for (const int rule_id : post_check_rules) {
546       const std::unique_ptr<UniLib::RegexMatcher> matcher =
547           low_confidence_rules_[rule_id].output_pattern->Matcher(
548               text_reply_unicode);
549       if (matcher == nullptr) {
550         TC3_LOG(ERROR) << "Could not create matcher for post check rule.";
551         return false;
552       }
553       int status = UniLib::RegexMatcher::kNoError;
554       if (matcher->Find(&status) || status != UniLib::RegexMatcher::kNoError) {
555         passes_post_check = false;
556         break;
557       }
558     }
559     if (passes_post_check) {
560       filtered_text_replies.push_back(action);
561     }
562   }
563   *actions = std::move(filtered_text_replies);
564   return true;
565 }
566 
SuggestionFromSpec(const ActionSuggestionSpec * action,const std::string & default_type,const std::string & default_response_text,const std::string & default_serialized_entity_data,const float default_score,const float default_priority_score) const567 ActionSuggestion ActionsSuggestions::SuggestionFromSpec(
568     const ActionSuggestionSpec* action, const std::string& default_type,
569     const std::string& default_response_text,
570     const std::string& default_serialized_entity_data,
571     const float default_score, const float default_priority_score) const {
572   ActionSuggestion suggestion;
573   suggestion.score = action != nullptr ? action->score() : default_score;
574   suggestion.priority_score =
575       action != nullptr ? action->priority_score() : default_priority_score;
576   suggestion.type = action != nullptr && action->type() != nullptr
577                         ? action->type()->str()
578                         : default_type;
579   suggestion.response_text =
580       action != nullptr && action->response_text() != nullptr
581           ? action->response_text()->str()
582           : default_response_text;
583   suggestion.serialized_entity_data =
584       action != nullptr && action->serialized_entity_data() != nullptr
585           ? action->serialized_entity_data()->str()
586           : default_serialized_entity_data;
587   return suggestion;
588 }
589 
Tokenize(const std::vector<std::string> & context) const590 std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
591     const std::vector<std::string>& context) const {
592   std::vector<std::vector<Token>> tokens;
593   tokens.reserve(context.size());
594   for (const std::string& message : context) {
595     tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
596   }
597   return tokens;
598 }
599 
EmbedTokensPerMessage(const std::vector<std::vector<Token>> & tokens,std::vector<float> * embeddings,int * max_num_tokens_per_message) const600 bool ActionsSuggestions::EmbedTokensPerMessage(
601     const std::vector<std::vector<Token>>& tokens,
602     std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
603   const int num_messages = tokens.size();
604   *max_num_tokens_per_message = 0;
605   for (int i = 0; i < num_messages; i++) {
606     const int num_message_tokens = tokens[i].size();
607     if (num_message_tokens > *max_num_tokens_per_message) {
608       *max_num_tokens_per_message = num_message_tokens;
609     }
610   }
611 
612   if (model_->feature_processor_options()->min_num_tokens_per_message() >
613       *max_num_tokens_per_message) {
614     *max_num_tokens_per_message =
615         model_->feature_processor_options()->min_num_tokens_per_message();
616   }
617   if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
618       *max_num_tokens_per_message >
619           model_->feature_processor_options()->max_num_tokens_per_message()) {
620     *max_num_tokens_per_message =
621         model_->feature_processor_options()->max_num_tokens_per_message();
622   }
623 
624   // Embed all tokens and add paddings to pad tokens of each message to the
625   // maximum number of tokens in a message of the conversation.
626   // If a number of tokens is specified in the model config, tokens at the
627   // beginning of a message are dropped if they don't fit in the limit.
628   for (int i = 0; i < num_messages; i++) {
629     const int start =
630         std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
631     for (int pos = start; pos < tokens[i].size(); pos++) {
632       if (!feature_processor_->AppendTokenFeatures(
633               tokens[i][pos], embedding_executor_.get(), embeddings)) {
634         TC3_LOG(ERROR) << "Could not run token feature extractor.";
635         return false;
636       }
637     }
638     // Add padding.
639     for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
640       embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
641                          embedded_padding_token_.end());
642     }
643   }
644 
645   return true;
646 }
647 
EmbedAndFlattenTokens(const std::vector<std::vector<Token>> tokens,std::vector<float> * embeddings,int * total_token_count) const648 bool ActionsSuggestions::EmbedAndFlattenTokens(
649     const std::vector<std::vector<Token>> tokens,
650     std::vector<float>* embeddings, int* total_token_count) const {
651   const int num_messages = tokens.size();
652   int start_message = 0;
653   int message_token_offset = 0;
654 
655   // If a maximum model input length is specified, we need to check how
656   // much we need to trim at the start.
657   const int max_num_total_tokens =
658       model_->feature_processor_options()->max_num_total_tokens();
659   if (max_num_total_tokens > 0) {
660     int total_tokens = 0;
661     start_message = num_messages - 1;
662     for (; start_message >= 0; start_message--) {
663       // Tokens of the message + start and end token.
664       const int num_message_tokens = tokens[start_message].size() + 2;
665       total_tokens += num_message_tokens;
666 
667       // Check whether we exhausted the budget.
668       if (total_tokens >= max_num_total_tokens) {
669         message_token_offset = total_tokens - max_num_total_tokens;
670         break;
671       }
672     }
673   }
674 
675   // Add embeddings.
676   *total_token_count = 0;
677   for (int i = start_message; i < num_messages; i++) {
678     if (message_token_offset == 0) {
679       ++(*total_token_count);
680       // Add `start message` token.
681       embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
682                          embedded_start_token_.end());
683     }
684 
685     for (int pos = std::max(0, message_token_offset - 1);
686          pos < tokens[i].size(); pos++) {
687       ++(*total_token_count);
688       if (!feature_processor_->AppendTokenFeatures(
689               tokens[i][pos], embedding_executor_.get(), embeddings)) {
690         TC3_LOG(ERROR) << "Could not run token feature extractor.";
691         return false;
692       }
693     }
694 
695     // Add `end message` token.
696     ++(*total_token_count);
697     embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
698                        embedded_end_token_.end());
699 
700     // Reset for the subsequent messages.
701     message_token_offset = 0;
702   }
703 
704   // Add optional padding.
705   const int min_num_total_tokens =
706       model_->feature_processor_options()->min_num_total_tokens();
707   for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
708     embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
709                        embedded_padding_token_.end());
710   }
711 
712   return true;
713 }
714 
AllocateInput(const int conversation_length,const int max_tokens,const int total_token_count,tflite::Interpreter * interpreter) const715 bool ActionsSuggestions::AllocateInput(const int conversation_length,
716                                        const int max_tokens,
717                                        const int total_token_count,
718                                        tflite::Interpreter* interpreter) const {
719   if (model_->tflite_model_spec()->resize_inputs()) {
720     if (model_->tflite_model_spec()->input_context() >= 0) {
721       interpreter->ResizeInputTensor(
722           interpreter->inputs()[model_->tflite_model_spec()->input_context()],
723           {1, conversation_length});
724     }
725     if (model_->tflite_model_spec()->input_user_id() >= 0) {
726       interpreter->ResizeInputTensor(
727           interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
728           {1, conversation_length});
729     }
730     if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
731       interpreter->ResizeInputTensor(
732           interpreter
733               ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
734           {1, conversation_length});
735     }
736     if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
737       interpreter->ResizeInputTensor(
738           interpreter
739               ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
740           {conversation_length, 1});
741     }
742     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
743       interpreter->ResizeInputTensor(
744           interpreter
745               ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
746           {conversation_length, max_tokens, token_embedding_size_});
747     }
748     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
749       interpreter->ResizeInputTensor(
750           interpreter->inputs()[model_->tflite_model_spec()
751                                     ->input_flattened_token_embeddings()],
752           {1, total_token_count});
753     }
754   }
755 
756   return interpreter->AllocateTensors() == kTfLiteOk;
757 }
758 
SetupModelInput(const std::vector<std::string> & context,const std::vector<int> & user_ids,const std::vector<float> & time_diffs,const int num_suggestions,const float confidence_threshold,const float diversification_distance,const float empirical_probability_factor,tflite::Interpreter * interpreter) const759 bool ActionsSuggestions::SetupModelInput(
760     const std::vector<std::string>& context, const std::vector<int>& user_ids,
761     const std::vector<float>& time_diffs, const int num_suggestions,
762     const float confidence_threshold, const float diversification_distance,
763     const float empirical_probability_factor,
764     tflite::Interpreter* interpreter) const {
765   // Compute token embeddings.
766   std::vector<std::vector<Token>> tokens;
767   std::vector<float> token_embeddings;
768   std::vector<float> flattened_token_embeddings;
769   int max_tokens = 0;
770   int total_token_count = 0;
771   if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
772       model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
773       model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
774     if (feature_processor_ == nullptr) {
775       TC3_LOG(ERROR) << "No feature processor specified.";
776       return false;
777     }
778 
779     // Tokenize the messages in the conversation.
780     tokens = Tokenize(context);
781     if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
782       if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
783         TC3_LOG(ERROR) << "Could not extract token features.";
784         return false;
785       }
786     }
787     if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
788       if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
789                                  &total_token_count)) {
790         TC3_LOG(ERROR) << "Could not extract token features.";
791         return false;
792       }
793     }
794   }
795 
796   if (!AllocateInput(context.size(), max_tokens, total_token_count,
797                      interpreter)) {
798     TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
799     return false;
800   }
801   if (model_->tflite_model_spec()->input_context() >= 0) {
802     model_executor_->SetInput<std::string>(
803         model_->tflite_model_spec()->input_context(), context, interpreter);
804   }
805   if (model_->tflite_model_spec()->input_context_length() >= 0) {
806     model_executor_->SetInput<int>(
807         model_->tflite_model_spec()->input_context_length(), context.size(),
808         interpreter);
809   }
810   if (model_->tflite_model_spec()->input_user_id() >= 0) {
811     model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
812                                    user_ids, interpreter);
813   }
814   if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
815     model_executor_->SetInput<int>(
816         model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
817         interpreter);
818   }
819   if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
820     model_executor_->SetInput<float>(
821         model_->tflite_model_spec()->input_time_diffs(), time_diffs,
822         interpreter);
823   }
824   if (model_->tflite_model_spec()->input_diversification_distance() >= 0) {
825     model_executor_->SetInput<float>(
826         model_->tflite_model_spec()->input_diversification_distance(),
827         diversification_distance, interpreter);
828   }
829   if (model_->tflite_model_spec()->input_confidence_threshold() >= 0) {
830     model_executor_->SetInput<float>(
831         model_->tflite_model_spec()->input_confidence_threshold(),
832         confidence_threshold, interpreter);
833   }
834   if (model_->tflite_model_spec()->input_empirical_probability_factor() >= 0) {
835     model_executor_->SetInput<float>(
836         model_->tflite_model_spec()->input_empirical_probability_factor(),
837         confidence_threshold, interpreter);
838   }
839   if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
840     std::vector<int> num_tokens_per_message(tokens.size());
841     for (int i = 0; i < tokens.size(); i++) {
842       num_tokens_per_message[i] = tokens[i].size();
843     }
844     model_executor_->SetInput<int>(
845         model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
846         interpreter);
847   }
848   if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
849     model_executor_->SetInput<float>(
850         model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
851         interpreter);
852   }
853   if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
854     model_executor_->SetInput<float>(
855         model_->tflite_model_spec()->input_flattened_token_embeddings(),
856         flattened_token_embeddings, interpreter);
857   }
858   return true;
859 }
860 
ReadModelOutput(tflite::Interpreter * interpreter,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const861 bool ActionsSuggestions::ReadModelOutput(
862     tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
863     ActionsSuggestionsResponse* response) const {
864   // Read sensitivity and triggering score predictions.
865   if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
866     const TensorView<float>& triggering_score =
867         model_executor_->OutputView<float>(
868             model_->tflite_model_spec()->output_triggering_score(),
869             interpreter);
870     if (!triggering_score.is_valid() || triggering_score.size() == 0) {
871       TC3_LOG(ERROR) << "Could not compute triggering score.";
872       return false;
873     }
874     response->triggering_score = triggering_score.data()[0];
875     response->output_filtered_min_triggering_score =
876         (response->triggering_score <
877          preconditions_.min_smart_reply_triggering_score);
878   }
879   if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
880     const TensorView<float>& sensitive_topic_score =
881         model_executor_->OutputView<float>(
882             model_->tflite_model_spec()->output_sensitive_topic_score(),
883             interpreter);
884     if (!sensitive_topic_score.is_valid() ||
885         sensitive_topic_score.dim(0) != 1) {
886       TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
887       return false;
888     }
889     response->sensitivity_score = sensitive_topic_score.data()[0];
890     response->output_filtered_sensitivity =
891         (response->sensitivity_score >
892          preconditions_.max_sensitive_topic_score);
893   }
894 
895   // Suppress model outputs.
896   if (response->output_filtered_sensitivity) {
897     return true;
898   }
899 
900   // Read smart reply predictions.
901   std::vector<ActionSuggestion> text_replies;
902   if (!response->output_filtered_min_triggering_score &&
903       model_->tflite_model_spec()->output_replies() >= 0) {
904     const std::vector<tflite::StringRef> replies =
905         model_executor_->Output<tflite::StringRef>(
906             model_->tflite_model_spec()->output_replies(), interpreter);
907     TensorView<float> scores = model_executor_->OutputView<float>(
908         model_->tflite_model_spec()->output_replies_scores(), interpreter);
909     for (int i = 0; i < replies.size(); i++) {
910       if (replies[i].len == 0) continue;
911       const float score = scores.data()[i];
912       if (score < preconditions_.min_reply_score_threshold) {
913         continue;
914       }
915       response->actions.push_back({std::string(replies[i].str, replies[i].len),
916                                    model_->smart_reply_action_type()->str(),
917                                    score});
918     }
919   }
920 
921   // Read actions suggestions.
922   if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
923     const TensorView<float> actions_scores = model_executor_->OutputView<float>(
924         model_->tflite_model_spec()->output_actions_scores(), interpreter);
925     for (int i = 0; i < model_->action_type()->Length(); i++) {
926       const ActionTypeOptions* action_type = model_->action_type()->Get(i);
927       // Skip disabled action classes, such as the default other category.
928       if (!action_type->enabled()) {
929         continue;
930       }
931       const float score = actions_scores.data()[i];
932       if (score < action_type->min_triggering_score()) {
933         continue;
934       }
935       ActionSuggestion suggestion =
936           SuggestionFromSpec(action_type->action(),
937                              /*default_type=*/action_type->name()->str());
938       suggestion.score = score;
939       response->actions.push_back(suggestion);
940     }
941   }
942 
943   return true;
944 }
945 
SuggestActionsFromModel(const Conversation & conversation,const int num_messages,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response,std::unique_ptr<tflite::Interpreter> * interpreter) const946 bool ActionsSuggestions::SuggestActionsFromModel(
947     const Conversation& conversation, const int num_messages,
948     const ActionSuggestionOptions& options,
949     ActionsSuggestionsResponse* response,
950     std::unique_ptr<tflite::Interpreter>* interpreter) const {
951   TC3_CHECK_LE(num_messages, conversation.messages.size());
952 
953   if (!model_executor_) {
954     return true;
955   }
956   *interpreter = model_executor_->CreateInterpreter();
957 
958   if (!*interpreter) {
959     TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
960                       "actions suggestions model.";
961     return false;
962   }
963 
964   std::vector<std::string> context;
965   std::vector<int> user_ids;
966   std::vector<float> time_diffs;
967   context.reserve(num_messages);
968   user_ids.reserve(num_messages);
969   time_diffs.reserve(num_messages);
970 
971   // Gather last `num_messages` messages from the conversation.
972   int64 last_message_reference_time_ms_utc = 0;
973   const float second_in_ms = 1000;
974   for (int i = conversation.messages.size() - num_messages;
975        i < conversation.messages.size(); i++) {
976     const ConversationMessage& message = conversation.messages[i];
977     context.push_back(message.text);
978     user_ids.push_back(message.user_id);
979 
980     float time_diff_secs = 0;
981     if (message.reference_time_ms_utc != 0 &&
982         last_message_reference_time_ms_utc != 0) {
983       time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
984                                        last_message_reference_time_ms_utc) /
985                                           second_in_ms);
986     }
987     if (message.reference_time_ms_utc != 0) {
988       last_message_reference_time_ms_utc = message.reference_time_ms_utc;
989     }
990     time_diffs.push_back(time_diff_secs);
991   }
992 
993   if (!SetupModelInput(context, user_ids, time_diffs,
994                        /*num_suggestions=*/model_->num_smart_replies(),
995                        preconditions_.confidence_threshold,
996                        preconditions_.diversification_distance_threshold,
997                        preconditions_.empirical_probability_factor,
998                        interpreter->get())) {
999     TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1000     return false;
1001   }
1002 
1003   if ((*interpreter)->Invoke() != kTfLiteOk) {
1004     TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1005     return false;
1006   }
1007 
1008   return ReadModelOutput(interpreter->get(), options, response);
1009 }
1010 
AnnotationOptionsForMessage(const ConversationMessage & message) const1011 AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1012     const ConversationMessage& message) const {
1013   AnnotationOptions options;
1014   options.detected_text_language_tags = message.detected_text_language_tags;
1015   options.reference_time_ms_utc = message.reference_time_ms_utc;
1016   options.reference_timezone = message.reference_timezone;
1017   options.annotation_usecase =
1018       model_->annotation_actions_spec()->annotation_usecase();
1019   options.is_serialized_entity_data_enabled =
1020       model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1021   options.entity_types = annotation_entity_types_;
1022   return options;
1023 }
1024 
SuggestActionsFromAnnotations(const Conversation & conversation,const ActionSuggestionOptions & options,const Annotator * annotator,std::vector<ActionSuggestion> * actions) const1025 void ActionsSuggestions::SuggestActionsFromAnnotations(
1026     const Conversation& conversation, const ActionSuggestionOptions& options,
1027     const Annotator* annotator, std::vector<ActionSuggestion>* actions) const {
1028   if (model_->annotation_actions_spec() == nullptr ||
1029       model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1030       model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1031     return;
1032   }
1033 
1034   // Create actions based on the annotations.
1035   const int max_from_any_person =
1036       model_->annotation_actions_spec()->max_history_from_any_person();
1037   const int max_from_last_person =
1038       model_->annotation_actions_spec()->max_history_from_last_person();
1039   const int last_person = conversation.messages.back().user_id;
1040 
1041   int num_messages_last_person = 0;
1042   int num_messages_any_person = 0;
1043   bool all_from_last_person = true;
1044   for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1045        message_index--) {
1046     const ConversationMessage& message = conversation.messages[message_index];
1047     std::vector<AnnotatedSpan> annotations = message.annotations;
1048 
1049     // Update how many messages we have processed from the last person in the
1050     // conversation and from any person in the conversation.
1051     num_messages_any_person++;
1052     if (all_from_last_person && message.user_id == last_person) {
1053       num_messages_last_person++;
1054     } else {
1055       all_from_last_person = false;
1056     }
1057 
1058     if (num_messages_any_person > max_from_any_person &&
1059         (!all_from_last_person ||
1060          num_messages_last_person > max_from_last_person)) {
1061       break;
1062     }
1063 
1064     if (message.user_id == kLocalUserId) {
1065       if (model_->annotation_actions_spec()->only_until_last_sent()) {
1066         break;
1067       }
1068       if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1069         continue;
1070       }
1071     }
1072 
1073     if (annotations.empty() && annotator != nullptr) {
1074       annotations = annotator->Annotate(message.text,
1075                                         AnnotationOptionsForMessage(message));
1076     }
1077     std::vector<ActionSuggestionAnnotation> action_annotations;
1078     action_annotations.reserve(annotations.size());
1079     for (const AnnotatedSpan& annotation : annotations) {
1080       if (annotation.classification.empty()) {
1081         continue;
1082       }
1083 
1084       const ClassificationResult& classification_result =
1085           annotation.classification[0];
1086 
1087       ActionSuggestionAnnotation action_annotation;
1088       action_annotation.span = {
1089           message_index, annotation.span,
1090           UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1091               .UTF8Substring(annotation.span.first, annotation.span.second)};
1092       action_annotation.entity = classification_result;
1093       action_annotation.name = classification_result.collection;
1094       action_annotations.push_back(action_annotation);
1095     }
1096 
1097     if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1098       // Create actions only for deduplicated annotations.
1099       for (const int annotation_id :
1100            DeduplicateAnnotations(action_annotations)) {
1101         SuggestActionsFromAnnotation(
1102             message_index, action_annotations[annotation_id], actions);
1103       }
1104     } else {
1105       // Create actions for all annotations.
1106       for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1107         SuggestActionsFromAnnotation(message_index, annotation, actions);
1108       }
1109     }
1110   }
1111 }
1112 
SuggestActionsFromAnnotation(const int message_index,const ActionSuggestionAnnotation & annotation,std::vector<ActionSuggestion> * actions) const1113 void ActionsSuggestions::SuggestActionsFromAnnotation(
1114     const int message_index, const ActionSuggestionAnnotation& annotation,
1115     std::vector<ActionSuggestion>* actions) const {
1116   for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
1117        *model_->annotation_actions_spec()->annotation_mapping()) {
1118     if (annotation.entity.collection ==
1119         mapping->annotation_collection()->str()) {
1120       if (annotation.entity.score < mapping->min_annotation_score()) {
1121         continue;
1122       }
1123       ActionSuggestion suggestion = SuggestionFromSpec(mapping->action());
1124       if (mapping->use_annotation_score()) {
1125         suggestion.score = annotation.entity.score;
1126       }
1127 
1128       // Set annotation text as (additional) entity data field.
1129       if (mapping->entity_field() != nullptr) {
1130         std::unique_ptr<ReflectiveFlatbuffer> entity_data =
1131             entity_data_builder_->NewRoot();
1132         TC3_CHECK(entity_data != nullptr);
1133 
1134         // Merge existing static entity data.
1135         if (!suggestion.serialized_entity_data.empty()) {
1136           entity_data->MergeFromSerializedFlatbuffer(
1137               StringPiece(suggestion.serialized_entity_data.c_str(),
1138                           suggestion.serialized_entity_data.size()));
1139         }
1140 
1141         entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
1142         suggestion.serialized_entity_data = entity_data->Serialize();
1143       }
1144 
1145       suggestion.annotations = {annotation};
1146       actions->push_back(suggestion);
1147     }
1148   }
1149 }
1150 
DeduplicateAnnotations(const std::vector<ActionSuggestionAnnotation> & annotations) const1151 std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1152     const std::vector<ActionSuggestionAnnotation>& annotations) const {
1153   std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1154 
1155   for (int i = 0; i < annotations.size(); i++) {
1156     const std::pair<std::string, std::string> key = {annotations[i].name,
1157                                                      annotations[i].span.text};
1158     auto entry = deduplicated_annotations.find(key);
1159     if (entry != deduplicated_annotations.end()) {
1160       // Kepp the annotation with the higher score.
1161       if (annotations[entry->second].entity.score <
1162           annotations[i].entity.score) {
1163         entry->second = i;
1164       }
1165       continue;
1166     }
1167     deduplicated_annotations.insert(entry, {key, i});
1168   }
1169 
1170   std::vector<int> result;
1171   result.reserve(deduplicated_annotations.size());
1172   for (const auto& key_and_annotation : deduplicated_annotations) {
1173     result.push_back(key_and_annotation.second);
1174   }
1175   return result;
1176 }
1177 
FillAnnotationFromMatchGroup(const UniLib::RegexMatcher * matcher,const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup * group,const int message_index,ActionSuggestionAnnotation * annotation) const1178 bool ActionsSuggestions::FillAnnotationFromMatchGroup(
1179     const UniLib::RegexMatcher* matcher,
1180     const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
1181     const int message_index, ActionSuggestionAnnotation* annotation) const {
1182   if (group->annotation_name() != nullptr ||
1183       group->annotation_type() != nullptr) {
1184     int status = UniLib::RegexMatcher::kNoError;
1185     const CodepointSpan span = {matcher->Start(group->group_id(), &status),
1186                                 matcher->End(group->group_id(), &status)};
1187     std::string text =
1188         matcher->Group(group->group_id(), &status).ToUTF8String();
1189     if (status != UniLib::RegexMatcher::kNoError) {
1190       TC3_LOG(ERROR) << "Could not extract span from rule capturing group.";
1191       return false;
1192     }
1193 
1194     // The capturing group was not part of the match.
1195     if (span.first == kInvalidIndex || span.second == kInvalidIndex) {
1196       return false;
1197     }
1198     annotation->span.span = span;
1199     annotation->span.message_index = message_index;
1200     annotation->span.text = text;
1201     if (group->annotation_name() != nullptr) {
1202       annotation->name = group->annotation_name()->str();
1203     }
1204     if (group->annotation_type() != nullptr) {
1205       annotation->entity.collection = group->annotation_type()->str();
1206     }
1207   }
1208   return true;
1209 }
1210 
SuggestActionsFromRules(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1211 bool ActionsSuggestions::SuggestActionsFromRules(
1212     const Conversation& conversation,
1213     std::vector<ActionSuggestion>* actions) const {
1214   // Create actions based on rules checking the last message.
1215   const int message_index = conversation.messages.size() - 1;
1216   const std::string& message = conversation.messages.back().text;
1217   const UnicodeText message_unicode(
1218       UTF8ToUnicodeText(message, /*do_copy=*/false));
1219   for (const CompiledRule& rule : rules_) {
1220     const std::unique_ptr<UniLib::RegexMatcher> matcher =
1221         rule.pattern->Matcher(message_unicode);
1222     int status = UniLib::RegexMatcher::kNoError;
1223     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
1224       for (const RulesModel_::Rule_::RuleActionSpec* rule_action :
1225            *rule.rule->actions()) {
1226         const ActionSuggestionSpec* action = rule_action->action();
1227         std::vector<ActionSuggestionAnnotation> annotations;
1228 
1229         bool sets_entity_data = false;
1230         std::unique_ptr<ReflectiveFlatbuffer> entity_data =
1231             entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1232                                             : nullptr;
1233 
1234         // Set static entity data.
1235         if (action != nullptr && action->serialized_entity_data() != nullptr) {
1236           TC3_CHECK(entity_data != nullptr);
1237           sets_entity_data = true;
1238           entity_data->MergeFromSerializedFlatbuffer(
1239               StringPiece(action->serialized_entity_data()->c_str(),
1240                           action->serialized_entity_data()->size()));
1241         }
1242 
1243         // Add entity data from rule capturing groups.
1244         if (rule_action->capturing_group() != nullptr) {
1245           for (const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup*
1246                    group : *rule_action->capturing_group()) {
1247             if (group->entity_field() != nullptr) {
1248               TC3_CHECK(entity_data != nullptr);
1249               sets_entity_data = true;
1250               if (!SetFieldFromCapturingGroup(
1251                       group->group_id(), group->entity_field(), matcher.get(),
1252                       entity_data.get())) {
1253                 TC3_LOG(ERROR)
1254                     << "Could not set entity data from rule capturing group.";
1255                 return false;
1256               }
1257             }
1258 
1259             // Create a text annotation for the group span.
1260             ActionSuggestionAnnotation annotation;
1261             if (FillAnnotationFromMatchGroup(matcher.get(), group,
1262                                              message_index, &annotation)) {
1263               annotations.push_back(annotation);
1264             }
1265 
1266             // Create text reply.
1267             if (group->text_reply() != nullptr) {
1268               int status = UniLib::RegexMatcher::kNoError;
1269               const std::string group_text =
1270                   matcher->Group(group->group_id(), &status).ToUTF8String();
1271               if (status != UniLib::RegexMatcher::kNoError) {
1272                 TC3_LOG(ERROR) << "Could get text from capturing group.";
1273                 return false;
1274               }
1275               if (group_text.empty()) {
1276                 // The group was not part of the match, ignore and continue.
1277                 continue;
1278               }
1279               actions->push_back(SuggestionFromSpec(
1280                   group->text_reply(),
1281                   /*default_type=*/model_->smart_reply_action_type()->str(),
1282                   /*default_response_text=*/group_text));
1283             }
1284           }
1285         }
1286 
1287         if (action != nullptr) {
1288           ActionSuggestion suggestion = SuggestionFromSpec(action);
1289           suggestion.annotations = annotations;
1290           if (sets_entity_data) {
1291             suggestion.serialized_entity_data = entity_data->Serialize();
1292           }
1293           actions->push_back(suggestion);
1294         }
1295       }
1296     }
1297   }
1298   return true;
1299 }
1300 
SuggestActionsFromLua(const Conversation & conversation,const TfLiteModelExecutor * model_executor,const tflite::Interpreter * interpreter,const reflection::Schema * annotation_entity_data_schema,std::vector<ActionSuggestion> * actions) const1301 bool ActionsSuggestions::SuggestActionsFromLua(
1302     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1303     const tflite::Interpreter* interpreter,
1304     const reflection::Schema* annotation_entity_data_schema,
1305     std::vector<ActionSuggestion>* actions) const {
1306   if (lua_bytecode_.empty()) {
1307     return true;
1308   }
1309 
1310   auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1311       lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1312       interpreter, entity_data_schema_, annotation_entity_data_schema);
1313   if (lua_actions == nullptr) {
1314     TC3_LOG(ERROR) << "Could not create lua actions.";
1315     return false;
1316   }
1317   return lua_actions->SuggestActions(actions);
1318 }
1319 
GatherActionsSuggestions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options,ActionsSuggestionsResponse * response) const1320 bool ActionsSuggestions::GatherActionsSuggestions(
1321     const Conversation& conversation, const Annotator* annotator,
1322     const ActionSuggestionOptions& options,
1323     ActionsSuggestionsResponse* response) const {
1324   if (conversation.messages.empty()) {
1325     return true;
1326   }
1327 
1328   const int num_messages = NumMessagesToConsider(
1329       conversation, model_->max_conversation_history_length());
1330 
1331   if (num_messages <= 0) {
1332     TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1333     return false;
1334   }
1335 
1336   SuggestActionsFromAnnotations(conversation, options, annotator,
1337                                 &response->actions);
1338 
1339   int input_text_length = 0;
1340   int num_matching_locales = 0;
1341   for (int i = conversation.messages.size() - num_messages;
1342        i < conversation.messages.size(); i++) {
1343     input_text_length += conversation.messages[i].text.length();
1344     std::vector<Locale> message_languages;
1345     if (!ParseLocales(conversation.messages[i].detected_text_language_tags,
1346                       &message_languages)) {
1347       continue;
1348     }
1349     if (Locale::IsAnyLocaleSupported(
1350             message_languages, locales_,
1351             preconditions_.handle_unknown_locale_as_supported)) {
1352       ++num_matching_locales;
1353     }
1354   }
1355 
1356   // Bail out if we are provided with too few or too much input.
1357   if (input_text_length < preconditions_.min_input_length ||
1358       (preconditions_.max_input_length >= 0 &&
1359        input_text_length > preconditions_.max_input_length)) {
1360     TC3_LOG(INFO) << "Too much or not enough input for inference.";
1361     return response;
1362   }
1363 
1364   // Bail out if the text does not look like it can be handled by the model.
1365   const float matching_fraction =
1366       static_cast<float>(num_matching_locales) / num_messages;
1367   if (matching_fraction < preconditions_.min_locale_match_fraction) {
1368     TC3_LOG(INFO) << "Not enough locale matches.";
1369     response->output_filtered_locale_mismatch = true;
1370     return true;
1371   }
1372 
1373   std::vector<int> post_check_rules;
1374   if (preconditions_.suppress_on_low_confidence_input &&
1375       IsLowConfidenceInput(conversation, num_messages, &post_check_rules)) {
1376     response->output_filtered_low_confidence = true;
1377     return true;
1378   }
1379 
1380   std::unique_ptr<tflite::Interpreter> interpreter;
1381   if (!SuggestActionsFromModel(conversation, num_messages, options, response,
1382                                &interpreter)) {
1383     TC3_LOG(ERROR) << "Could not run model.";
1384     return false;
1385   }
1386 
1387   // Suppress all predictions if the conversation was deemed sensitive.
1388   if (preconditions_.suppress_on_sensitive_topic &&
1389       response->output_filtered_sensitivity) {
1390     return true;
1391   }
1392 
1393   if (!SuggestActionsFromLua(
1394           conversation, model_executor_.get(), interpreter.get(),
1395           annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1396           &response->actions)) {
1397     TC3_LOG(ERROR) << "Could not suggest actions from script.";
1398     return false;
1399   }
1400 
1401   if (!SuggestActionsFromRules(conversation, &response->actions)) {
1402     TC3_LOG(ERROR) << "Could not suggest actions from rules.";
1403     return false;
1404   }
1405 
1406   if (preconditions_.suppress_on_low_confidence_input &&
1407       !FilterConfidenceOutput(post_check_rules, &response->actions)) {
1408     TC3_LOG(ERROR) << "Could not post-check actions.";
1409     return false;
1410   }
1411 
1412   return true;
1413 }
1414 
SuggestActions(const Conversation & conversation,const Annotator * annotator,const ActionSuggestionOptions & options) const1415 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1416     const Conversation& conversation, const Annotator* annotator,
1417     const ActionSuggestionOptions& options) const {
1418   ActionsSuggestionsResponse response;
1419   if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1420     TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1421     response.actions.clear();
1422   } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1423                                    annotator != nullptr
1424                                        ? annotator->entity_data_schema()
1425                                        : nullptr)) {
1426     TC3_LOG(ERROR) << "Could not rank actions.";
1427     response.actions.clear();
1428   }
1429   return response;
1430 }
1431 
SuggestActions(const Conversation & conversation,const ActionSuggestionOptions & options) const1432 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1433     const Conversation& conversation,
1434     const ActionSuggestionOptions& options) const {
1435   return SuggestActions(conversation, /*annotator=*/nullptr, options);
1436 }
1437 
model() const1438 const ActionsModel* ActionsSuggestions::model() const { return model_; }
entity_data_schema() const1439 const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
1440   return entity_data_schema_;
1441 }
1442 
ViewActionsModel(const void * buffer,int size)1443 const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1444   if (buffer == nullptr) {
1445     return nullptr;
1446   }
1447   return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1448 }
1449 
1450 }  // namespace libtextclassifier3
1451