/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>
#include <string>
#include <vector>

#include "utils/base/statusor.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
#endif
#include "actions/ngram-model.h"
#include "actions/tflite-sensitive-model.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {

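// Verifies that `addr` points to a well-formed ActionsModel flatbuffer of
// `size` bytes and returns it, or nullptr if verification fails.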
const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

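// Reads a scalar field at `field_offset` from the (possibly null) overlay
// table, falling back to `default_value` when the table is absent or the
// field is unset.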
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

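// Truncates `inputs` to `max_length` elements, or pads it at the end with
// `pad_value` if it is shorter.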
template <typename T>
std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
                                           const int max_length,
                                           const T pad_value) {
  if (inputs.size() >= max_length) {
    return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
  } else {
    std::vector<T> result;
    result.reserve(max_length);
    result.insert(result.begin(), inputs.begin(), inputs.end());
    result.insert(result.end(), max_length - inputs.size(), pad_value);
    return result;
  }
}

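// Feeds `param_value` to the interpreter input at `param_index`, accepting
// either a std::vector<T> or a scalar T inside the variant.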
template <typename T>
void SetVectorOrScalarAsModelInput(
    const int param_index, const Variant& param_value,
    tflite::Interpreter* interpreter,
    const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
  if (param_value.Has<std::vector<T>>()) {
    model_executor->SetInput<T>(
        param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
  } else if (param_value.Has<T>()) {
    model_executor->SetInput<T>(param_index, param_value.Value<T>(),
                                interpreter);
  } else {
    TC3_LOG(ERROR) << "Variant type error!";
  }
}
}  // namespace

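// A minimal usage sketch for the factory functions below; the model path is
// hypothetical:
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath("/path/to/actions_model.fb",
//                                    /*unilib=*/nullptr,
//                                    /*triggering_preconditions_overlay=*/"");
//   if (actions == nullptr) {
//     // Loading, verification or initialization failed.
//   }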
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

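// Uses the caller-provided UniLib when non-null; otherwise creates a default
// instance and keeps ownership of it.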
void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new MutableFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize the regular expressions model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Set up the grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

#if !defined(TC3_DISABLE_LUA)
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }
#endif  // TC3_DISABLE_LUA

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create the feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create the low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    sensitive_model_ = NGramSensitiveModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }
  if (model_->low_confidence_tflite_model() != nullptr) {
    sensitive_model_ =
        TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
      return false;
    }
  }

  return true;
}

bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load the triggering preconditions overrides.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

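// Embeds the tokens of each message and pads every message to a common token
// count, yielding a flattened [num_messages, max_num_tokens_per_message,
// token_embedding_size] input in `embeddings`.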
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens and pad the tokens of each message to the maximum number
  // of tokens in a message of the conversation.
  // If a maximum number of tokens per message is specified in the model
  // config, tokens at the beginning of a message are dropped if they don't
  // fit in the limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

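// Embeds all messages into a single flat token sequence, bracketing each
// message with the start/end tokens, trimming from the front when the model
// specifies a maximum total token count and padding up to the minimum.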
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

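// Resizes the declared model inputs to the conversation dimensions when the
// model spec requests dynamic resizing, then allocates the tensors.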
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

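// Fills every input the model spec declares: conversation context, user ids,
// time diffs, token embeddings and any additional named parameters passed in
// via ActionSuggestionOptions::model_parameters.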
bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(),
          PadOrTruncateToTargetLength(
              context, model_->tflite_model_spec()->input_length_to_pad(),
              std::string("")),
          interpreter);
    } else {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(), context, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(),
          PadOrTruncateToTargetLength(
              user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
          interpreter);
    } else {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<float>(param_index,
                                                 param_value_it->second,
                                                 interpreter, model_executor_);
          } else {
            model_executor_->SetInput<float>(param_index, kDefaultFloat,
                                             interpreter);
          }
          break;
        case kTfLiteInt32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<int32_t>(
                param_index, param_value_it->second, interpreter,
                model_executor_);
          } else {
            model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
                                               interpreter);
          }
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

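// Converts the text reply outputs of the model into smart reply actions,
// dropping empty replies, replies scoring below the minimum reply score
// threshold and replies on the blocklist.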
void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type, float priority_score,
    const absl::flat_hash_set<std::string>& blocklist,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);

  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    std::string response_text(replies[i].str, replies[i].len);
    if (blocklist.contains(response_text)) {
      continue;
    }

    response->actions.push_back({response_text, type, score, priority_score});
  }
}

void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on the in-graph thresholding logic, so at this point the results
  // have already been ranked according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    std::unique_ptr<MutableFlatbuffer> entity_data =
        entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                        : nullptr;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->is_sensitive = (response->sensitivity_score >
                              preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->is_sensitive) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    absl::flat_hash_set<std::string> empty_blocklist;
    PopulateTextReplies(interpreter,
                        model_->tflite_model_spec()->output_replies(),
                        model_->tflite_model_spec()->output_replies_scores(),
                        model_->smart_reply_action_type()->str(),
                        /*priority_score=*/0.0, empty_blocklist, response);
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create an action from the model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the results.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      absl::flat_hash_set<std::string> response_text_blocklist;
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, using the default "
                                "smart_reply_action_type!";
          }
          if (task_spec) {
            if (task_spec->response_text_blocklist()) {
              for (const auto& val : *task_spec->response_text_blocklist()) {
                response_text_blocklist.insert(val->str());
              }
            }
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              task_spec ? task_spec->priority_score() : 0.0,
              response_text_blocklist, response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (sensitive_model_ != nullptr &&
      sensitive_model_->EvalConversation(conversation, num_messages).first) {
    response->is_sensitive = true;
    return true;
  }

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(), options,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    std::vector<ActionSuggestion>* actions) const {
  TC3_ASSIGN_OR_RETURN(
      std::vector<ActionSuggestion> new_actions,
      conversation_intent_detection_->SuggestActions(conversation, options));
  for (auto& action : new_actions) {
    actions->push_back(std::move(action));
  }
  return Status::OK;
}

AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Run annotator on the messages of a conversation.
Conversation ActionsSuggestions::AnnotateConversation(
    const Conversation& conversation, const Annotator* annotator) const {
  if (annotator == nullptr) {
    return conversation;
  }
  const int num_messages_grammar =
      ((model_->rules() && model_->rules()->grammar_rules() &&
        model_->rules()
            ->grammar_rules()
            ->rules()
            ->nonterminals()
            ->annotation_nt())
           ? 1
           : 0);
  const int num_messages_mapping =
      (model_->annotation_actions_spec()
           ? std::max(model_->annotation_actions_spec()
                          ->max_history_from_any_person(),
                      model_->annotation_actions_spec()
                          ->max_history_from_last_person())
           : 0);
  const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
  if (num_messages == 0) {
    // No annotations are used.
    return conversation;
  }
  Conversation annotated_conversation = conversation;
  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
       i < num_messages && message_index >= 0; i++, message_index--) {
    ConversationMessage* message =
        &annotated_conversation.messages[message_index];
    if (message->annotations.empty()) {
      message->annotations = annotator->Annotate(
          message->text, AnnotationOptionsForMessage(*message));
      ConvertDatetimeToTime(&message->annotations);
    }
  }
  return annotated_conversation;
}

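// Creates actions from the annotations of the most recent messages, honoring
// the per-person history limits of the annotation actions spec.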
void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation,
    std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
       message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(std::move(action_annotation));
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }

      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        TC3_CHECK_NE(entity_data, nullptr);

        UnicodeText normalized_annotation_text =
            UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);

        // Apply normalization if specified.
        if (mapping->normalization_options() != nullptr) {
          normalized_annotation_text =
              NormalizeText(*unilib_, mapping->normalization_options(),
                            normalized_annotation_text);
        }

        entity_data->ParseAndSet(mapping->entity_field(),
                                 normalized_annotation_text.ToUTF8String());
      }

      ActionSuggestion suggestion;
      FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }
      suggestion.annotations = {annotation};
      actions->push_back(std::move(suggestion));
    }
  }
}

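// Returns the indices of the annotations to keep: among annotations sharing
// the same (name, span text) pair, only the highest-scoring one survives.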
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

#if !defined(TC3_DISABLE_LUA)
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}
#else
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  return true;
}
#endif

bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  // Run the annotator against the messages.
  const Conversation annotated_conversation =
      AnnotateConversation(conversation, annotator);

  const int num_messages = NumMessagesToConsider(
      annotated_conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);

  if (grammar_actions_ != nullptr &&
      !grammar_actions_->SuggestActions(annotated_conversation,
                                        &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
    return false;
  }

  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = annotated_conversation.messages.size() - num_messages;
       i < annotated_conversation.messages.size(); i++) {
    input_text_length += annotated_conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(
            annotated_conversation.messages[i].detected_text_language_tags,
            &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

  std::vector<const UniLib::RegexPattern*> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input) {
    if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
                                             num_messages, &post_check_rules)) {
      response->output_filtered_low_confidence = true;
      return true;
    }
  }

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
                               response, &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // SuggestActionsFromModel also detects whether the conversation is
  // sensitive, either with the old ngram model or with the new model.
  // Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
    return true;
  }

  if (conversation_intent_detection_) {
    // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
    auto actions = SuggestActionsFromConversationIntentDetection(
        annotated_conversation, options, &response->actions);
    if (!actions.ok()) {
      TC3_LOG(ERROR) << "Could not run conversation intent detection: "
                     << actions.error_message();
      return false;
    }
  }

  if (!SuggestActionsFromLua(
          annotated_conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

  if (!regex_actions_->SuggestActions(annotated_conversation,
                                      entity_data_builder_.get(),
                                      &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
    return false;
  }

  if (preconditions_.suppress_on_low_confidence_input &&
      !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                              &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

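// A minimal calling sketch; the message content is made up, and only the
// ConversationMessage fields used elsewhere in this file (user_id, text) are
// set:
//
//   Conversation conversation;
//   ConversationMessage message;
//   message.user_id = 1;
//   message.text = "Lunch tomorrow?";
//   conversation.messages.push_back(message);
//   const ActionsSuggestionsResponse response =
//       actions->SuggestActions(conversation);
//   for (const ActionSuggestion& action : response.actions) {
//     // Inspect action.type, action.response_text and action.score.
//   }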
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;

  // Assert that the messages are sorted correctly.
  for (int i = 1; i < conversation.messages.size(); i++) {
    if (conversation.messages[i].reference_time_ms_utc <
        conversation.messages[i - 1].reference_time_ms_utc) {
      TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
      return response;
    }
  }

  // Check that the messages are valid UTF-8.
  for (const ConversationMessage& message : conversation.messages) {
    if (message.text.size() > std::numeric_limits<int>::max()) {
      TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
      return {};
    }

    if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
            message.text.data(), message.text.size(), /*do_copy=*/false))) {
      TC3_LOG(ERROR) << "Not valid UTF-8 provided.";
      return response;
    }
  }

  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

bool ActionsSuggestions::InitializeConversationIntentDetection(
    const std::string& serialized_config) {
  auto conversation_intent_detection =
      std::make_unique<ConversationIntentDetection>();
  if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
    TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
    return false;
  }
  conversation_intent_detection_ = std::move(conversation_intent_detection);
  return true;
}

}  // namespace libtextclassifier3