/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>
#include <string>
#include <vector>

#include "utils/base/statusor.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
#endif
#include "actions/ngram-model.h"
#include "actions/tflite-sensitive-model.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {

const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

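// Reads a scalar field from a raw flatbuffer table, falling back to
// `default_value` when the table is absent. Used below to apply the optional
// triggering-preconditions overlay on top of the model defaults.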
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
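// E.g., for a conversation with 5 messages and a maximum history length of 3,
// this returns 3; a negative maximum means "no limit" and returns 5.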
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

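// Pads or truncates `inputs` to exactly `max_length` elements, e.g.
// ({0, 1}, /*max_length=*/4, /*pad_value=*/-1) -> {0, 1, -1, -1} and
// ({0, 1, 2, 3}, /*max_length=*/2, /*pad_value=*/-1) -> {0, 1}.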
template <typename T>
std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
                                           const int max_length,
                                           const T pad_value) {
  if (inputs.size() >= max_length) {
    return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
  } else {
    std::vector<T> result;
    result.reserve(max_length);
    result.insert(result.begin(), inputs.begin(), inputs.end());
    result.insert(result.end(), max_length - inputs.size(), pad_value);
    return result;
  }
}

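// Sets the interpreter input at `param_index` from a parameter value that may
// hold either a std::vector<T> or a scalar T.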
template <typename T>
void SetVectorOrScalarAsModelInput(
    const int param_index, const Variant& param_value,
    tflite::Interpreter* interpreter,
    const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
  if (param_value.Has<std::vector<T>>()) {
    model_executor->SetInput<T>(
        param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
  } else if (param_value.Has<T>()) {
    model_executor->SetInput<T>(param_index, param_value.Value<T>(),
                                interpreter);
  } else {
    TC3_LOG(ERROR) << "Variant type error!";
  }
}
}  // namespace

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}
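
// A minimal usage sketch (hypothetical caller, not part of this file; the
// model path is illustrative):
//
//   std::unique_ptr<ActionsSuggestions> suggestions =
//       ActionsSuggestions::FromPath("/path/to/actions_suggestions.model",
//                                    /*unilib=*/nullptr,
//                                    /*triggering_preconditions_overlay=*/"");
//   // nullptr is returned if loading, verification or initialization failed;
//   // a null `unilib` makes ActionsSuggestions create its own UniLib.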

void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new MutableFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize regular expressions model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Set up the grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

#if !defined(TC3_DISABLE_LUA)
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }
#endif  // TC3_DISABLE_LUA

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create the feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start, and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create the low-confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    sensitive_model_ = NGramSensitiveModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }
  if (model_->low_confidence_tflite_model() != nullptr) {
    sensitive_model_ =
        TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
      return false;
    }
  }

  return true;
}

bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overrides.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens, padding each message to the maximum number of tokens in
  // any message of the conversation.
  // If a maximum number of tokens is specified in the model config, tokens at
  // the beginning of a message are dropped if they don't fit within the limit.
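  // E.g., with *max_num_tokens_per_message == 3, a message tokenized as
  // [a, b, c, d] contributes embeddings for [b, c, d], while a message [a]
  // contributes [a, <pad>, <pad>].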
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
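  // E.g., with max_num_total_tokens == 10 and two messages of 5 and 6 tokens
  // (7 and 8 including the start/end markers), only the last token of the
  // first message is kept, so the flattened sequence fits the budget exactly.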
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(),
          PadOrTruncateToTargetLength(
              context, model_->tflite_model_spec()->input_length_to_pad(),
              std::string("")),
          interpreter);
    } else {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(), context, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(),
          PadOrTruncateToTargetLength(
              user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
          interpreter);
    } else {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<float>(param_index,
                                                 param_value_it->second,
                                                 interpreter, model_executor_);
          } else {
            model_executor_->SetInput<float>(param_index, kDefaultFloat,
                                             interpreter);
          }
          break;
        case kTfLiteInt32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<int32_t>(
                param_index, param_value_it->second, interpreter,
                model_executor_);
          } else {
            model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
                                               interpreter);
          }
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type, float priority_score,
    const absl::flat_hash_set<std::string>& blocklist,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);

  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    std::string response_text(replies[i].str, replies[i].len);
    if (blocklist.contains(response_text)) {
      continue;
    }

    response->actions.push_back({response_text, type, score, priority_score});
  }
}

void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on the in-graph thresholding logic, so at this point the results
  // have already been ranked according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    std::unique_ptr<MutableFlatbuffer> entity_data =
        entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                        : nullptr;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->is_sensitive = (response->sensitivity_score >
                              preconditions_.max_sensitive_topic_score);
  }

  // Suppress all model outputs if the conversation is classified as sensitive.
  if (response->is_sensitive) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    absl::flat_hash_set<std::string> empty_blocklist;
    PopulateTextReplies(interpreter,
                        model_->tflite_model_spec()->output_replies(),
                        model_->tflite_model_spec()->output_replies_scores(),
                        model_->smart_reply_action_type()->str(),
                        /* priority_score */ 0.0, empty_blocklist, response);
  }

  // Read action suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create an action from the model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the results accordingly.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      absl::flat_hash_set<std::string> response_text_blocklist;
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, using the default "
                                "smart_reply_action_type!";
          }
          if (task_spec) {
            if (task_spec->response_text_blocklist()) {
              for (const auto& val : *task_spec->response_text_blocklist()) {
                response_text_blocklist.insert(val->str());
              }
            }
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              task_spec ? task_spec->priority_score() : 0.0,
              response_text_blocklist, response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (sensitive_model_ != nullptr &&
      sensitive_model_->EvalConversation(conversation, num_messages).first) {
    response->is_sensitive = true;
    return true;
  }

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
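  // Time diffs are in seconds and clamped at 0: e.g., reference times of
  // 10000 ms and 13500 ms yield a diff of 3.5 s, while messages without a
  // reference time (or the first timestamped message) get 0.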
1033   for (int i = conversation.messages.size() - num_messages;
1034        i < conversation.messages.size(); i++) {
1035     const ConversationMessage& message = conversation.messages[i];
1036     context.push_back(message.text);
1037     user_ids.push_back(message.user_id);
1038 
1039     float time_diff_secs = 0;
1040     if (message.reference_time_ms_utc != 0 &&
1041         last_message_reference_time_ms_utc != 0) {
1042       time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
1043                                        last_message_reference_time_ms_utc) /
1044                                           second_in_ms);
1045     }
1046     if (message.reference_time_ms_utc != 0) {
1047       last_message_reference_time_ms_utc = message.reference_time_ms_utc;
1048     }
1049     time_diffs.push_back(time_diff_secs);
1050   }
1051 
1052   if (!SetupModelInput(context, user_ids, time_diffs,
1053                        /*num_suggestions=*/model_->num_smart_replies(), options,
1054                        interpreter->get())) {
1055     TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1056     return false;
1057   }
1058 
1059   if ((*interpreter)->Invoke() != kTfLiteOk) {
1060     TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1061     return false;
1062   }
1063 
1064   return ReadModelOutput(interpreter->get(), options, response);
1065 }
1066 
SuggestActionsFromConversationIntentDetection(const Conversation & conversation,const ActionSuggestionOptions & options,std::vector<ActionSuggestion> * actions) const1067 Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
1068     const Conversation& conversation, const ActionSuggestionOptions& options,
1069     std::vector<ActionSuggestion>* actions) const {
1070   TC3_ASSIGN_OR_RETURN(
1071       std::vector<ActionSuggestion> new_actions,
1072       conversation_intent_detection_->SuggestActions(conversation, options));
1073   for (auto& action : new_actions) {
1074     actions->push_back(std::move(action));
1075   }
1076   return Status::OK;
1077 }
1078 
AnnotationOptionsForMessage(const ConversationMessage & message) const1079 AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1080     const ConversationMessage& message) const {
1081   AnnotationOptions options;
1082   options.detected_text_language_tags = message.detected_text_language_tags;
1083   options.reference_time_ms_utc = message.reference_time_ms_utc;
1084   options.reference_timezone = message.reference_timezone;
1085   options.annotation_usecase =
1086       model_->annotation_actions_spec()->annotation_usecase();
1087   options.is_serialized_entity_data_enabled =
1088       model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1089   options.entity_types = annotation_entity_types_;
1090   return options;
1091 }
1092 
1093 // Run annotator on the messages of a conversation.
AnnotateConversation(const Conversation & conversation,const Annotator * annotator) const1094 Conversation ActionsSuggestions::AnnotateConversation(
1095     const Conversation& conversation, const Annotator* annotator) const {
1096   if (annotator == nullptr) {
1097     return conversation;
1098   }
1099   const int num_messages_grammar =
1100       ((model_->rules() && model_->rules()->grammar_rules() &&
1101         model_->rules()
1102             ->grammar_rules()
1103             ->rules()
1104             ->nonterminals()
1105             ->annotation_nt())
1106            ? 1
1107            : 0);
1108   const int num_messages_mapping =
1109       (model_->annotation_actions_spec()
1110            ? std::max(model_->annotation_actions_spec()
1111                           ->max_history_from_any_person(),
1112                       model_->annotation_actions_spec()
1113                           ->max_history_from_last_person())
1114            : 0);
1115   const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
1116   if (num_messages == 0) {
1117     // No annotations are used.
1118     return conversation;
1119   }
1120   Conversation annotated_conversation = conversation;
1121   for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
1122        i < num_messages && message_index >= 0; i++, message_index--) {
1123     ConversationMessage* message =
1124         &annotated_conversation.messages[message_index];
1125     if (message->annotations.empty()) {
1126       message->annotations = annotator->Annotate(
1127           message->text, AnnotationOptionsForMessage(*message));
1128       ConvertDatetimeToTime(&message->annotations);
1129     }
1130   }
1131   return annotated_conversation;
1132 }
1133 
SuggestActionsFromAnnotations(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1134 void ActionsSuggestions::SuggestActionsFromAnnotations(
1135     const Conversation& conversation,
1136     std::vector<ActionSuggestion>* actions) const {
1137   if (model_->annotation_actions_spec() == nullptr ||
1138       model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1139       model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1140     return;
1141   }
1142 
1143   // Create actions based on the annotations.
1144   const int max_from_any_person =
1145       model_->annotation_actions_spec()->max_history_from_any_person();
1146   const int max_from_last_person =
1147       model_->annotation_actions_spec()->max_history_from_last_person();
1148   const int last_person = conversation.messages.back().user_id;
1149 
1150   int num_messages_last_person = 0;
1151   int num_messages_any_person = 0;
1152   bool all_from_last_person = true;
1153   for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1154        message_index--) {
1155     const ConversationMessage& message = conversation.messages[message_index];
1156     std::vector<AnnotatedSpan> annotations = message.annotations;
1157 
1158     // Update how many messages we have processed from the last person in the
1159     // conversation and from any person in the conversation.
1160     num_messages_any_person++;
1161     if (all_from_last_person && message.user_id == last_person) {
1162       num_messages_last_person++;
1163     } else {
1164       all_from_last_person = false;
1165     }
1166 
1167     if (num_messages_any_person > max_from_any_person &&
1168         (!all_from_last_person ||
1169          num_messages_last_person > max_from_last_person)) {
1170       break;
1171     }
1172 
1173     if (message.user_id == kLocalUserId) {
1174       if (model_->annotation_actions_spec()->only_until_last_sent()) {
1175         break;
1176       }
1177       if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1178         continue;
1179       }
1180     }
1181 
1182     std::vector<ActionSuggestionAnnotation> action_annotations;
1183     action_annotations.reserve(annotations.size());
1184     for (const AnnotatedSpan& annotation : annotations) {
1185       if (annotation.classification.empty()) {
1186         continue;
1187       }
1188 
1189       const ClassificationResult& classification_result =
1190           annotation.classification[0];
1191 
1192       ActionSuggestionAnnotation action_annotation;
1193       action_annotation.span = {
1194           message_index, annotation.span,
1195           UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1196               .UTF8Substring(annotation.span.first, annotation.span.second)};
1197       action_annotation.entity = classification_result;
1198       action_annotation.name = classification_result.collection;
1199       action_annotations.push_back(std::move(action_annotation));
1200     }
1201 
1202     if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1203       // Create actions only for deduplicated annotations.
1204       for (const int annotation_id :
1205            DeduplicateAnnotations(action_annotations)) {
1206         SuggestActionsFromAnnotation(
1207             message_index, action_annotations[annotation_id], actions);
1208       }
1209     } else {
1210       // Create actions for all annotations.
1211       for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1212         SuggestActionsFromAnnotation(message_index, annotation, actions);
1213       }
1214     }
1215   }
1216 }
1217 
void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }

      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        TC3_CHECK_NE(entity_data, nullptr);

        UnicodeText normalized_annotation_text =
            UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);

        // Apply normalization if specified.
        if (mapping->normalization_options() != nullptr) {
          normalized_annotation_text =
              NormalizeText(*unilib_, mapping->normalization_options(),
                            normalized_annotation_text);
        }

        entity_data->ParseAndSet(mapping->entity_field(),
                                 normalized_annotation_text.ToUTF8String());
      }

      ActionSuggestion suggestion;
      FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }
      suggestion.annotations = {annotation};
      actions->push_back(std::move(suggestion));
    }
  }
}

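// Deduplicates annotations by (name, span text), keeping for each key the
// index of the annotation with the highest score. Returns indices into the
// input vector.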
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

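// Runs the model's Lua snippet (if any) against the conversation to produce
// additional actions. When Lua support is compiled out, this is a no-op that
// reports success.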
#if !defined(TC3_DISABLE_LUA)
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}
#else
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  return true;
}
#endif

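// Collects actions from all sources in order: annotator-based suggestions,
// grammar rules, the TFLite model, conversation intent detection, Lua
// actions, and regex rules, applying the model preconditions (input length,
// locale match fraction, low-confidence and sensitive-topic suppression)
// along the way. Returns false if any stage fails.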
bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  // Run annotator against messages.
  const Conversation annotated_conversation =
      AnnotateConversation(conversation, annotator);

  const int num_messages = NumMessagesToConsider(
      annotated_conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);

  if (grammar_actions_ != nullptr &&
      !grammar_actions_->SuggestActions(annotated_conversation,
                                        &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
    return false;
  }

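  // Accumulate the total input length over the considered messages and count
  // how many of them are in a supported locale.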
  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = annotated_conversation.messages.size() - num_messages;
       i < annotated_conversation.messages.size(); i++) {
    input_text_length += annotated_conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(
            annotated_conversation.messages[i].detected_text_language_tags,
            &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

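  // Suppress everything if the input already looks low-confidence; rules that
  // can only be checked against the produced output are collected in
  // post_check_rules and applied to the final suggestions further below.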
  std::vector<const UniLib::RegexPattern*> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input) {
    if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
                                             num_messages, &post_check_rules)) {
      response->output_filtered_low_confidence = true;
      return true;
    }
  }

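  // Run the actions TFLite model. The interpreter is kept alive so that the
  // Lua actions below can inspect its outputs.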
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
                               response, &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // SuggestActionsFromModel also detects whether the conversation is
  // sensitive, either via the old ngram model or the new TFLite sensitive
  // model. Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
    return true;
  }

  if (conversation_intent_detection_) {
    // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
    auto actions = SuggestActionsFromConversationIntentDetection(
        annotated_conversation, options, &response->actions);
    if (!actions.ok()) {
      TC3_LOG(ERROR) << "Could not run conversation intent detection: "
                     << actions.error_message();
      return false;
    }
  }

  if (!SuggestActionsFromLua(
          annotated_conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

  if (!regex_actions_->SuggestActions(annotated_conversation,
                                      entity_data_builder_.get(),
                                      &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
    return false;
  }

  if (preconditions_.suppress_on_low_confidence_input &&
      !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                              &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

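// Entry point: validates the conversation (messages sorted by time, text of
// sane length and valid UTF-8), gathers suggestions from all sources and
// ranks them; on failure the list of actions is cleared.
//
// Usage sketch (assumes an already-initialized ActionsSuggestions instance
// `actions_suggestions`; the message fields shown are illustrative):
//   ConversationMessage message;
//   message.user_id = 1;
//   message.text = "Where are you?";
//   Conversation conversation;
//   conversation.messages.push_back(message);
//   const ActionsSuggestionsResponse response =
//       actions_suggestions->SuggestActions(
//           conversation, ActionSuggestionOptions::Default());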
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;

  // Assert that messages are sorted correctly.
  for (int i = 1; i < conversation.messages.size(); i++) {
    if (conversation.messages[i].reference_time_ms_utc <
        conversation.messages[i - 1].reference_time_ms_utc) {
      TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
      return response;
    }
  }

  // Check that messages are valid UTF-8.
  for (const ConversationMessage& message : conversation.messages) {
    if (message.text.size() > std::numeric_limits<int>::max()) {
      TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
      return {};
    }

    if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
            message.text.data(), message.text.size(), /*do_copy=*/false))) {
      TC3_LOG(ERROR) << "Input is not valid UTF-8.";
      return response;
    }
  }

  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

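// Interprets the buffer as an ActionsModel flatbuffer and verifies its
// integrity; returns nullptr if the buffer is null or fails verification.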
const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

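// Sets up the conversation intent detection component from its serialized
// config; on failure the component is left unset and false is returned.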
bool ActionsSuggestions::InitializeConversationIntentDetection(
    const std::string& serialized_config) {
  auto conversation_intent_detection =
      std::make_unique<ConversationIntentDetection>();
  if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
    TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
    return false;
  }
  conversation_intent_detection_ = std::move(conversation_intent_detection);
  return true;
}

}  // namespace libtextclassifier3