/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>

#include "actions/lua-actions.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

const std::string& ActionsSuggestions::kViewCalendarType =
    *[]() { return new std::string("view_calendar"); }();
const std::string& ActionsSuggestions::kViewMapType =
    *[]() { return new std::string("view_map"); }();
const std::string& ActionsSuggestions::kTrackFlightType =
    *[]() { return new std::string("track_flight"); }();
const std::string& ActionsSuggestions::kOpenUrlType =
    *[]() { return new std::string("open_url"); }();
const std::string& ActionsSuggestions::kSendSmsType =
    *[]() { return new std::string("send_sms"); }();
const std::string& ActionsSuggestions::kCallPhoneType =
    *[]() { return new std::string("call_phone"); }();
const std::string& ActionsSuggestions::kSendEmailType =
    *[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
    *[]() { return new std::string("share_location"); }();

// Name for a datetime annotation that only includes time but no date.
const std::string& kTimeAnnotation =
    *[]() { return new std::string("time"); }();

constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {

const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

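// Reads a scalar field from the given flatbuffer table, falling back to
// `default_value` when the table itself is null. Used below to overlay
// precondition values on top of the model defaults.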
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
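// For example, with max_conversation_history_length == 3 and a conversation
// of five messages, only the last three messages are considered; a negative
// limit means the full history is used.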
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

}  // namespace

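// Creates an instance from a model buffer that remains owned by the caller;
// the buffer must outlive the returned instance, as the loaded model points
// into it.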
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

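// Illustrative usage sketch of the factory methods above (the model path and
// overlay value here are hypothetical):
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath(
//           "/path/to/actions_model.fb", /*unilib=*/nullptr,
//           /*triggering_preconditions_overlay=*/"");
//   if (actions == nullptr) {
//     // The model failed to load, verify or initialize.
//   }

// Uses the caller-provided UniLib if one is given; otherwise creates and
// owns a default instance.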
void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

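// Validates the loaded model and initializes every component it specifies:
// triggering preconditions, the TensorFlow Lite executor, the entity data
// schema, regex and grammar rules, the Lua actions snippet, the ranker, the
// feature processor and the low-confidence ngram model.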
bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize the regular expression model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Set up the grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    ngram_model_ = NGramModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (ngram_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }

  return true;
}

bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

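// Computes the embedding for a single token id; used above to precompute the
// padding, start and end token embeddings.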
bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

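// Embeds the tokens of all messages into one flat buffer, padding each
// message to the maximum number of tokens in any message of the conversation
// (clamped by the model's min/max tokens-per-message settings).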
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens, padding each message to the maximum number of tokens
  // in any message of the conversation. If a maximum number of tokens per
  // message is specified in the model config, tokens at the beginning of a
  // message are dropped if they don't fit within the limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

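// Embeds the tokens of all messages as a single flat sequence, wrapping each
// message in start/end tokens. If the model specifies a total token budget,
// the oldest messages (and the leading tokens of the first kept message) are
// trimmed; if a minimum is specified, the sequence is padded at the end.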
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

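// Resizes the model input tensors to the dimensions of the current
// conversation (if the model spec allows resizing) and allocates them.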
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

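// Fills all model inputs declared by the model spec: conversation text, user
// ids, time diffs and token embeddings, plus any additional named parameters
// passed in through the options.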
bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    model_executor_->SetInput<std::string>(
        model_->tflite_model_spec()->input_context(), context, interpreter);
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
                                   user_ids, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          model_executor_->SetInput<float>(
              param_index,
              has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
              interpreter);
          break;
        case kTfLiteInt32:
          model_executor_->SetInput<int32_t>(
              param_index,
              has_value ? param_value_it->second.Value<int>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

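// Converts the text reply outputs of the model into action suggestions of
// the given type, dropping empty replies and replies that score below the
// minimum reply score threshold.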
void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    response->actions.push_back(
        {std::string(replies[i].str, replies[i].len), type, score});
  }
}

void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on the in-graph thresholding logic, so at this point the results
  // have already been ranked properly according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    std::unique_ptr<ReflectiveFlatbuffer> entity_data =
        entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                        : nullptr;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->output_filtered_sensitivity =
        (response->sensitivity_score >
         preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->output_filtered_sensitivity) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    PopulateTextReplies(interpreter,
                        model_->tflite_model_spec()->output_replies(),
                        model_->tflite_model_spec()->output_replies_scores(),
                        model_->smart_reply_action_type()->str(), response);
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create action from model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      std::unique_ptr<ReflectiveFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the result properly.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, use default "
                                "smart_reply_action_type!";
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

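// Runs the TensorFlow Lite model on the last `num_messages` messages of the
// conversation and reads the predictions back into the response. The
// inter-message time differences, in seconds, are computed as an additional
// model feature.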
bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(), options,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Run annotator on the messages of a conversation.
Conversation ActionsSuggestions::AnnotateConversation(
    const Conversation& conversation, const Annotator* annotator) const {
  if (annotator == nullptr) {
    return conversation;
  }
  const int num_messages_grammar =
      ((model_->rules() && model_->rules()->grammar_rules() &&
        model_->rules()
            ->grammar_rules()
            ->rules()
            ->nonterminals()
            ->annotation_nt())
           ? 1
           : 0);
  const int num_messages_mapping =
      (model_->annotation_actions_spec()
           ? std::max(model_->annotation_actions_spec()
                          ->max_history_from_any_person(),
                      model_->annotation_actions_spec()
                          ->max_history_from_last_person())
           : 0);
  const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
  if (num_messages == 0) {
    // No annotations are used.
    return conversation;
  }
  Conversation annotated_conversation = conversation;
  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
       i < num_messages && message_index >= 0; i++, message_index--) {
    ConversationMessage* message =
        &annotated_conversation.messages[message_index];
    if (message->annotations.empty()) {
      message->annotations = annotator->Annotate(
          message->text, AnnotationOptionsForMessage(*message));
      for (int i = 0; i < message->annotations.size(); i++) {
        ClassificationResult* classification =
            &message->annotations[i].classification.front();

        // Specialize datetime annotation to time annotation if no date
        // component is present.
        if (classification->collection == Collections::DateTime() &&
            classification->datetime_parse_result.IsSet()) {
          bool has_only_time = true;
          for (const DatetimeComponent& component :
               classification->datetime_parse_result.datetime_components) {
            if (component.component_type !=
                    DatetimeComponent::ComponentType::UNSPECIFIED &&
                component.component_type <
                    DatetimeComponent::ComponentType::HOUR) {
              has_only_time = false;
              break;
            }
          }
          if (has_only_time) {
            classification->collection = kTimeAnnotation;
          }
        }
      }
    }
  }
  return annotated_conversation;
}

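// Creates actions from the annotations of the most recent messages, walking
// the conversation backwards and honoring the per-person history limits of
// the annotation actions spec.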
void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation,
    std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
       message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(std::move(action_annotation));
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }

      std::unique_ptr<ReflectiveFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        TC3_CHECK_NE(entity_data, nullptr);

        UnicodeText normalized_annotation_text =
            UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);

        // Apply normalization if specified.
        if (mapping->normalization_options() != nullptr) {
          normalized_annotation_text =
              NormalizeText(*unilib_, mapping->normalization_options(),
                            normalized_annotation_text);
        }

        entity_data->ParseAndSet(mapping->entity_field(),
                                 normalized_annotation_text.ToUTF8String());
      }

      ActionSuggestion suggestion;
      FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }
      suggestion.annotations = {annotation};
      actions->push_back(std::move(suggestion));
    }
  }
}

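// Deduplicates annotations by (name, span text), keeping the instance with
// the highest entity score; returns the indices of the surviving
// annotations.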
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

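// Runs the precompiled Lua actions snippet, if any, giving it access to the
// conversation and to the model outputs through the interpreter.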
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}

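// Gathers action suggestions from all sources in order: annotations, grammar
// rules, the TensorFlow Lite model, the Lua script, and regex rules, gated by
// the model's preconditions (input length, locale match, low-confidence and
// sensitive-topic checks). Returns false only on hard errors; filtered input
// is reported through the response flags and still returns true.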
bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  // Run annotator against messages.
  const Conversation annotated_conversation =
      AnnotateConversation(conversation, annotator);

  const int num_messages = NumMessagesToConsider(
      annotated_conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);

  if (grammar_actions_ != nullptr &&
      !grammar_actions_->SuggestActions(annotated_conversation,
                                        &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
    return false;
  }

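  // Tally the total input text length and count how many of the considered
  // messages have a detected language that the model supports; both values
  // feed the precondition checks below.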
  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = annotated_conversation.messages.size() - num_messages;
       i < annotated_conversation.messages.size(); i++) {
    input_text_length += annotated_conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(
            annotated_conversation.messages[i].detected_text_language_tags,
            &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

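  // Suppress suggestions if either the ngram model or the regex rules flag
  // the input as low confidence. Regex rules that only apply as a post-check
  // are collected here and run against the final actions further below.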
  std::vector<const UniLib::RegexPattern*> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input) {
    if ((ngram_model_ != nullptr &&
         ngram_model_->EvalConversation(annotated_conversation,
                                        num_messages)) ||
        regex_actions_->IsLowConfidenceInput(annotated_conversation,
                                             num_messages, &post_check_rules)) {
      response->output_filtered_low_confidence = true;
      return true;
    }
  }

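  // Run the TensorFlow Lite model; the interpreter is kept alive so that the
  // Lua script below can read the model output from it.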
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
                               response, &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic &&
      response->output_filtered_sensitivity) {
    return true;
  }

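  // Let the model's optional Lua script add suggestions, using the annotator's
  // entity data schema for any annotation entity data it produces.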
  if (!SuggestActionsFromLua(
          annotated_conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

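  // Apply the regex rules from the model.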
  if (!regex_actions_->SuggestActions(annotated_conversation,
                                      entity_data_builder_.get(),
                                      &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
    return false;
  }

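  // Filter the gathered actions with the low-confidence post-check rules
  // collected above.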
  if (preconditions_.suppress_on_low_confidence_input &&
      !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                              &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

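// Public entry point: gathers suggestions from all sources and ranks them.
// Messages are expected to be sorted by reference time, most recent last; on
// any failure an empty set of actions is returned.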
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;

  // Check (non-fatally) that messages are sorted most recent last.
  for (int i = 1; i < conversation.messages.size(); i++) {
    if (conversation.messages[i].reference_time_ms_utc <
        conversation.messages[i - 1].reference_time_ms_utc) {
      TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
    }
  }

  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

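// Illustrative usage sketch only (assumptions: `actions_suggestions` is an
// already loaded ActionsSuggestions instance, and ConversationMessage exposes
// `user_id` and `text` fields as declared in actions/types.h):
//
//   Conversation conversation;
//   ConversationMessage message;
//   message.user_id = 1;
//   message.text = "Let's meet at the coffee shop tomorrow at 10am.";
//   conversation.messages.push_back(message);
//   const ActionsSuggestionsResponse response =
//       actions_suggestions->SuggestActions(conversation,
//                                           ActionSuggestionOptions());
//   for (const ActionSuggestion& action : response.actions) {
//     TC3_LOG(INFO) << "suggested action: " << action.type;
//   }
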
const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

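// Parses and verifies a serialized ActionsModel from `buffer` without taking
// ownership of it; returns nullptr if the buffer is null or fails
// verification.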
const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

}  // namespace libtextclassifier3