1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/actions-suggestions.h"
18
19 #include <memory>
20 #include <vector>
21
22 #include "utils/base/statusor.h"
23
24 #if !defined(TC3_DISABLE_LUA)
25 #include "actions/lua-actions.h"
26 #endif
27 #include "actions/ngram-model.h"
28 #include "actions/tflite-sensitive-model.h"
29 #include "actions/types.h"
30 #include "actions/utils.h"
31 #include "actions/zlib-utils.h"
32 #include "annotator/collections.h"
33 #include "utils/base/logging.h"
34 #if !defined(TC3_DISABLE_LUA)
35 #include "utils/lua-utils.h"
36 #endif
37 #include "utils/normalization.h"
38 #include "utils/optional.h"
39 #include "utils/strings/split.h"
40 #include "utils/strings/stringpiece.h"
41 #include "utils/strings/utf8.h"
42 #include "utils/utf8/unicodetext.h"
43 #include "tensorflow/lite/string_util.h"
44
45 namespace libtextclassifier3 {
46
47 constexpr float kDefaultFloat = 0.0;
48 constexpr bool kDefaultBool = false;
49 constexpr int kDefaultInt = 1;
50
51 namespace {
52
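// Verifies the flatbuffer in the given buffer and returns the ActionsModel it
// contains, or nullptr if verification fails.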
53 const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
54 flatbuffers::Verifier verifier(addr, size);
55 if (VerifyActionsModelBuffer(verifier)) {
56 return GetActionsModel(addr);
57 } else {
58 return nullptr;
59 }
60 }
61
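// Reads a scalar field from the (possibly null) overlay table, falling back to
// the given default when the table is missing or the field is not set.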
62 template <typename T>
63 T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
64 const T default_value) {
65 if (values == nullptr) {
66 return default_value;
67 }
68 return values->GetField<T>(field_offset, default_value);
69 }
70
71 // Returns the number of trailing messages of the conversation to consider.
72 int NumMessagesToConsider(const Conversation& conversation,
73 const int max_conversation_history_length) {
74 return ((max_conversation_history_length < 0 ||
75 conversation.messages.size() < max_conversation_history_length)
76 ? conversation.messages.size()
77 : max_conversation_history_length);
78 }
79
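// Truncates the input vector to max_length elements, or pads it with pad_value
// until it reaches that length.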
80 template <typename T>
81 std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
82 const int max_length,
83 const T pad_value) {
84 if (inputs.size() >= max_length) {
85 return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
86 } else {
87 std::vector<T> result;
88 result.reserve(max_length);
89 result.insert(result.begin(), inputs.begin(), inputs.end());
90 result.insert(result.end(), max_length - inputs.size(), pad_value);
91 return result;
92 }
93 }
94
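// Sets a model input that may be provided either as a std::vector<T> or as a
// scalar T; logs an error for any other variant type.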
95 template <typename T>
96 void SetVectorOrScalarAsModelInput(
97 const int param_index, const Variant& param_value,
98 tflite::Interpreter* interpreter,
99 const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
100 if (param_value.Has<std::vector<T>>()) {
101 model_executor->SetInput<T>(
102 param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
103 } else if (param_value.Has<T>()) {
104 model_executor->SetInput<T>(param_index, param_value.Value<T>(),
105 interpreter);
106 } else {
107 TC3_LOG(ERROR) << "Variant type error!";
108 }
109 }
110 } // namespace
111
112 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
113 const uint8_t* buffer, const int size, const UniLib* unilib,
114 const std::string& triggering_preconditions_overlay) {
115 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
116 const ActionsModel* model = LoadAndVerifyModel(buffer, size);
117 if (model == nullptr) {
118 return nullptr;
119 }
120 actions->model_ = model;
121 actions->SetOrCreateUnilib(unilib);
122 actions->triggering_preconditions_overlay_buffer_ =
123 triggering_preconditions_overlay;
124 if (!actions->ValidateAndInitialize()) {
125 return nullptr;
126 }
127 return actions;
128 }
129
130 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
131 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
132 const std::string& triggering_preconditions_overlay) {
133 if (!mmap->handle().ok()) {
134 TC3_VLOG(1) << "Mmap failed.";
135 return nullptr;
136 }
137 const ActionsModel* model = LoadAndVerifyModel(
138 reinterpret_cast<const uint8_t*>(mmap->handle().start()),
139 mmap->handle().num_bytes());
140 if (!model) {
141 TC3_LOG(ERROR) << "Model verification failed.";
142 return nullptr;
143 }
144 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
145 actions->model_ = model;
146 actions->mmap_ = std::move(mmap);
147 actions->SetOrCreateUnilib(unilib);
148 actions->triggering_preconditions_overlay_buffer_ =
149 triggering_preconditions_overlay;
150 if (!actions->ValidateAndInitialize()) {
151 return nullptr;
152 }
153 return actions;
154 }
155
156 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
157 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
158 std::unique_ptr<UniLib> unilib,
159 const std::string& triggering_preconditions_overlay) {
160 if (!mmap->handle().ok()) {
161 TC3_VLOG(1) << "Mmap failed.";
162 return nullptr;
163 }
164 const ActionsModel* model = LoadAndVerifyModel(
165 reinterpret_cast<const uint8_t*>(mmap->handle().start()),
166 mmap->handle().num_bytes());
167 if (!model) {
168 TC3_LOG(ERROR) << "Model verification failed.";
169 return nullptr;
170 }
171 auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
172 actions->model_ = model;
173 actions->mmap_ = std::move(mmap);
174 actions->owned_unilib_ = std::move(unilib);
175 actions->unilib_ = actions->owned_unilib_.get();
176 actions->triggering_preconditions_overlay_buffer_ =
177 triggering_preconditions_overlay;
178 if (!actions->ValidateAndInitialize()) {
179 return nullptr;
180 }
181 return actions;
182 }
183
184 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
185 const int fd, const int offset, const int size, const UniLib* unilib,
186 const std::string& triggering_preconditions_overlay) {
187 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
188 if (offset >= 0 && size >= 0) {
189 mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
190 } else {
191 mmap.reset(new libtextclassifier3::ScopedMmap(fd));
192 }
193 return FromScopedMmap(std::move(mmap), unilib,
194 triggering_preconditions_overlay);
195 }
196
197 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
198 const int fd, const int offset, const int size,
199 std::unique_ptr<UniLib> unilib,
200 const std::string& triggering_preconditions_overlay) {
201 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
202 if (offset >= 0 && size >= 0) {
203 mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
204 } else {
205 mmap.reset(new libtextclassifier3::ScopedMmap(fd));
206 }
207 return FromScopedMmap(std::move(mmap), std::move(unilib),
208 triggering_preconditions_overlay);
209 }
210
211 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
212 const int fd, const UniLib* unilib,
213 const std::string& triggering_preconditions_overlay) {
214 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
215 new libtextclassifier3::ScopedMmap(fd));
216 return FromScopedMmap(std::move(mmap), unilib,
217 triggering_preconditions_overlay);
218 }
219
220 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
221 const int fd, std::unique_ptr<UniLib> unilib,
222 const std::string& triggering_preconditions_overlay) {
223 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
224 new libtextclassifier3::ScopedMmap(fd));
225 return FromScopedMmap(std::move(mmap), std::move(unilib),
226 triggering_preconditions_overlay);
227 }
228
229 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
230 const std::string& path, const UniLib* unilib,
231 const std::string& triggering_preconditions_overlay) {
232 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
233 new libtextclassifier3::ScopedMmap(path));
234 return FromScopedMmap(std::move(mmap), unilib,
235 triggering_preconditions_overlay);
236 }
237
238 std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
239 const std::string& path, std::unique_ptr<UniLib> unilib,
240 const std::string& triggering_preconditions_overlay) {
241 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
242 new libtextclassifier3::ScopedMmap(path));
243 return FromScopedMmap(std::move(mmap), std::move(unilib),
244 triggering_preconditions_overlay);
245 }
246
247 void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
248 if (unilib != nullptr) {
249 unilib_ = unilib;
250 } else {
251 owned_unilib_.reset(new UniLib);
252 unilib_ = owned_unilib_.get();
253 }
254 }
255
256 bool ActionsSuggestions::ValidateAndInitialize() {
257 if (model_ == nullptr) {
258 TC3_LOG(ERROR) << "No model specified.";
259 return false;
260 }
261
262 if (model_->smart_reply_action_type() == nullptr) {
263 TC3_LOG(ERROR) << "No smart reply action type specified.";
264 return false;
265 }
266
267 if (!InitializeTriggeringPreconditions()) {
268 TC3_LOG(ERROR) << "Could not initialize preconditions.";
269 return false;
270 }
271
272 if (model_->locales() &&
273 !ParseLocales(model_->locales()->c_str(), &locales_)) {
274 TC3_LOG(ERROR) << "Could not parse model supported locales.";
275 return false;
276 }
277
278 if (model_->tflite_model_spec() != nullptr) {
279 model_executor_ = TfLiteModelExecutor::FromBuffer(
280 model_->tflite_model_spec()->tflite_model());
281 if (!model_executor_) {
282 TC3_LOG(ERROR) << "Could not initialize model executor.";
283 return false;
284 }
285 }
286
287 // Gather annotation entities for the rules.
288 if (model_->annotation_actions_spec() != nullptr &&
289 model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
290 for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
291 *model_->annotation_actions_spec()->annotation_mapping()) {
292 annotation_entity_types_.insert(mapping->annotation_collection()->str());
293 }
294 }
295
296 if (model_->actions_entity_data_schema() != nullptr) {
297 entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
298 model_->actions_entity_data_schema()->Data(),
299 model_->actions_entity_data_schema()->size());
300 if (entity_data_schema_ == nullptr) {
301 TC3_LOG(ERROR) << "Could not load entity data schema data.";
302 return false;
303 }
304
305 entity_data_builder_.reset(
306 new MutableFlatbufferBuilder(entity_data_schema_));
307 } else {
308 entity_data_schema_ = nullptr;
309 }
310
311 // Initialize regular expressions model.
312 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
313 regex_actions_.reset(
314 new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
315 if (!regex_actions_->InitializeRules(
316 model_->rules(), model_->low_confidence_rules(),
317 triggering_preconditions_overlay_, decompressor.get())) {
318 TC3_LOG(ERROR) << "Could not initialize regex rules.";
319 return false;
320 }
321
322 // Setup grammar model.
323 if (model_->rules() != nullptr &&
324 model_->rules()->grammar_rules() != nullptr) {
325 grammar_actions_.reset(new GrammarActions(
326 unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
327 model_->smart_reply_action_type()->str()));
328
329 // Gather annotation entities for the grammars.
330 if (auto annotation_nt = model_->rules()
331 ->grammar_rules()
332 ->rules()
333 ->nonterminals()
334 ->annotation_nt()) {
335 for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
336 *annotation_nt) {
337 annotation_entity_types_.insert(entry->key()->str());
338 }
339 }
340 }
341
342 #if !defined(TC3_DISABLE_LUA)
343 std::string actions_script;
344 if (GetUncompressedString(model_->lua_actions_script(),
345 model_->compressed_lua_actions_script(),
346 decompressor.get(), &actions_script) &&
347 !actions_script.empty()) {
348 if (!Compile(actions_script, &lua_bytecode_)) {
349 TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
350 return false;
351 }
352 }
353 #endif // TC3_DISABLE_LUA
354
355 if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
356 model_->ranking_options(), decompressor.get(),
357 model_->smart_reply_action_type()->str()))) {
358 TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
359 return false;
360 }
361
362 // Create feature processor if specified.
363 const ActionsTokenFeatureProcessorOptions* options =
364 model_->feature_processor_options();
365 if (options != nullptr) {
366 if (options->tokenizer_options() == nullptr) {
367 TC3_LOG(ERROR) << "No tokenizer options specified.";
368 return false;
369 }
370
371 feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
372 embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
373 options->embedding_model(), options->embedding_size(),
374 options->embedding_quantization_bits());
375
376 if (embedding_executor_ == nullptr) {
377 TC3_LOG(ERROR) << "Could not initialize embedding executor.";
378 return false;
379 }
380
381 // Cache the embeddings of the padding, start and end tokens.
382 if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
383 !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
384 !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
385 TC3_LOG(ERROR) << "Could not precompute token embeddings.";
386 return false;
387 }
388 token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
389 }
390
391 // Create low confidence model if specified.
392 if (model_->low_confidence_ngram_model() != nullptr) {
393 sensitive_model_ = NGramSensitiveModel::Create(
394 unilib_, model_->low_confidence_ngram_model(),
395 feature_processor_ == nullptr ? nullptr
396 : feature_processor_->tokenizer());
397 if (sensitive_model_ == nullptr) {
398 TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
399 return false;
400 }
401 }
402 if (model_->low_confidence_tflite_model() != nullptr) {
403 sensitive_model_ =
404 TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
405 if (sensitive_model_ == nullptr) {
406 TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
407 return false;
408 }
409 }
410
411 return true;
412 }
413
414 bool ActionsSuggestions::InitializeTriggeringPreconditions() {
415 triggering_preconditions_overlay_ =
416 LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
417 triggering_preconditions_overlay_buffer_);
418
419 if (triggering_preconditions_overlay_ == nullptr &&
420 !triggering_preconditions_overlay_buffer_.empty()) {
421 TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
422 return false;
423 }
424 const flatbuffers::Table* overlay =
425 reinterpret_cast<const flatbuffers::Table*>(
426 triggering_preconditions_overlay_);
427 const TriggeringPreconditions* defaults = model_->preconditions();
428 if (defaults == nullptr) {
429 TC3_LOG(ERROR) << "No triggering conditions specified.";
430 return false;
431 }
432
433 preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
434 overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
435 defaults->min_smart_reply_triggering_score());
436 preconditions_.max_sensitive_topic_score = ValueOrDefault(
437 overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
438 defaults->max_sensitive_topic_score());
439 preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
440 overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
441 defaults->suppress_on_sensitive_topic());
442 preconditions_.min_input_length =
443 ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
444 defaults->min_input_length());
445 preconditions_.max_input_length =
446 ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
447 defaults->max_input_length());
448 preconditions_.min_locale_match_fraction = ValueOrDefault(
449 overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
450 defaults->min_locale_match_fraction());
451 preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
452 overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
453 defaults->handle_missing_locale_as_supported());
454 preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
455 overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
456 defaults->handle_unknown_locale_as_supported());
457 preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
458 overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
459 defaults->suppress_on_low_confidence_input());
460 preconditions_.min_reply_score_threshold = ValueOrDefault(
461 overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
462 defaults->min_reply_score_threshold());
463
464 return true;
465 }
466
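// Computes the embedding for a single token id; used to precompute the
// padding, start and end token embeddings.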
467 bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
468 std::vector<float>* embedding) const {
469 return feature_processor_->AppendFeatures(
470 {token_id},
471 /*dense_features=*/{}, embedding_executor_.get(), embedding);
472 }
473
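// Tokenizes each message of the conversation with the feature processor's
// tokenizer.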
474 std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
475 const std::vector<std::string>& context) const {
476 std::vector<std::vector<Token>> tokens;
477 tokens.reserve(context.size());
478 for (const std::string& message : context) {
479 tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
480 }
481 return tokens;
482 }
483
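// Embeds the tokens of each message and pads every message to the same token
// count; the common per-message length (bounded by the configured minimum and
// maximum) is returned in max_num_tokens_per_message.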
484 bool ActionsSuggestions::EmbedTokensPerMessage(
485 const std::vector<std::vector<Token>>& tokens,
486 std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
487 const int num_messages = tokens.size();
488 *max_num_tokens_per_message = 0;
489 for (int i = 0; i < num_messages; i++) {
490 const int num_message_tokens = tokens[i].size();
491 if (num_message_tokens > *max_num_tokens_per_message) {
492 *max_num_tokens_per_message = num_message_tokens;
493 }
494 }
495
496 if (model_->feature_processor_options()->min_num_tokens_per_message() >
497 *max_num_tokens_per_message) {
498 *max_num_tokens_per_message =
499 model_->feature_processor_options()->min_num_tokens_per_message();
500 }
501 if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
502 *max_num_tokens_per_message >
503 model_->feature_processor_options()->max_num_tokens_per_message()) {
504 *max_num_tokens_per_message =
505 model_->feature_processor_options()->max_num_tokens_per_message();
506 }
507
508 // Embed all tokens and pad each message with padding tokens up to the
509 // maximum number of tokens of any message in the conversation.
510 // If a maximum number of tokens is specified in the model config, tokens at
511 // the beginning of a message are dropped if they don't fit within the limit.
512 for (int i = 0; i < num_messages; i++) {
513 const int start =
514 std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
515 for (int pos = start; pos < tokens[i].size(); pos++) {
516 if (!feature_processor_->AppendTokenFeatures(
517 tokens[i][pos], embedding_executor_.get(), embeddings)) {
518 TC3_LOG(ERROR) << "Could not run token feature extractor.";
519 return false;
520 }
521 }
522 // Add padding.
523 for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
524 embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
525 embedded_padding_token_.end());
526 }
527 }
528
529 return true;
530 }
531
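// Embeds the conversation as a single flattened token sequence, wrapping each
// message in start/end tokens, trimming from the front if a maximum total
// token count is configured, and padding up to the configured minimum.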
532 bool ActionsSuggestions::EmbedAndFlattenTokens(
533 const std::vector<std::vector<Token>>& tokens,
534 std::vector<float>* embeddings, int* total_token_count) const {
535 const int num_messages = tokens.size();
536 int start_message = 0;
537 int message_token_offset = 0;
538
539 // If a maximum model input length is specified, we need to check how
540 // much we need to trim at the start.
541 const int max_num_total_tokens =
542 model_->feature_processor_options()->max_num_total_tokens();
543 if (max_num_total_tokens > 0) {
544 int total_tokens = 0;
545 start_message = num_messages - 1;
546 for (; start_message >= 0; start_message--) {
547 // Tokens of the message + start and end token.
548 const int num_message_tokens = tokens[start_message].size() + 2;
549 total_tokens += num_message_tokens;
550
551 // Check whether we exhausted the budget.
552 if (total_tokens >= max_num_total_tokens) {
553 message_token_offset = total_tokens - max_num_total_tokens;
554 break;
555 }
556 }
557 }
558
559 // Add embeddings.
560 *total_token_count = 0;
561 for (int i = start_message; i < num_messages; i++) {
562 if (message_token_offset == 0) {
563 ++(*total_token_count);
564 // Add `start message` token.
565 embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
566 embedded_start_token_.end());
567 }
568
569 for (int pos = std::max(0, message_token_offset - 1);
570 pos < tokens[i].size(); pos++) {
571 ++(*total_token_count);
572 if (!feature_processor_->AppendTokenFeatures(
573 tokens[i][pos], embedding_executor_.get(), embeddings)) {
574 TC3_LOG(ERROR) << "Could not run token feature extractor.";
575 return false;
576 }
577 }
578
579 // Add `end message` token.
580 ++(*total_token_count);
581 embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
582 embedded_end_token_.end());
583
584 // Reset for the subsequent messages.
585 message_token_offset = 0;
586 }
587
588 // Add optional padding.
589 const int min_num_total_tokens =
590 model_->feature_processor_options()->min_num_total_tokens();
591 for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
592 embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
593 embedded_padding_token_.end());
594 }
595
596 return true;
597 }
598
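// Resizes the input tensors to the conversation length and token counts (if
// the model spec requests resizing) and allocates the interpreter's tensors.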
599 bool ActionsSuggestions::AllocateInput(const int conversation_length,
600 const int max_tokens,
601 const int total_token_count,
602 tflite::Interpreter* interpreter) const {
603 if (model_->tflite_model_spec()->resize_inputs()) {
604 if (model_->tflite_model_spec()->input_context() >= 0) {
605 interpreter->ResizeInputTensor(
606 interpreter->inputs()[model_->tflite_model_spec()->input_context()],
607 {1, conversation_length});
608 }
609 if (model_->tflite_model_spec()->input_user_id() >= 0) {
610 interpreter->ResizeInputTensor(
611 interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
612 {1, conversation_length});
613 }
614 if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
615 interpreter->ResizeInputTensor(
616 interpreter
617 ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
618 {1, conversation_length});
619 }
620 if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
621 interpreter->ResizeInputTensor(
622 interpreter
623 ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
624 {conversation_length, 1});
625 }
626 if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
627 interpreter->ResizeInputTensor(
628 interpreter
629 ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
630 {conversation_length, max_tokens, token_embedding_size_});
631 }
632 if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
633 interpreter->ResizeInputTensor(
634 interpreter->inputs()[model_->tflite_model_spec()
635 ->input_flattened_token_embeddings()],
636 {1, total_token_count});
637 }
638 }
639
640 return interpreter->AllocateTensors() == kTfLiteOk;
641 }
642
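// Fills all configured model inputs (context, user ids, time diffs, token
// embeddings and any additional named parameters from the options) before the
// interpreter is invoked.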
643 bool ActionsSuggestions::SetupModelInput(
644 const std::vector<std::string>& context, const std::vector<int>& user_ids,
645 const std::vector<float>& time_diffs, const int num_suggestions,
646 const ActionSuggestionOptions& options,
647 tflite::Interpreter* interpreter) const {
648 // Compute token embeddings.
649 std::vector<std::vector<Token>> tokens;
650 std::vector<float> token_embeddings;
651 std::vector<float> flattened_token_embeddings;
652 int max_tokens = 0;
653 int total_token_count = 0;
654 if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
655 model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
656 model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
657 if (feature_processor_ == nullptr) {
658 TC3_LOG(ERROR) << "No feature processor specified.";
659 return false;
660 }
661
662 // Tokenize the messages in the conversation.
663 tokens = Tokenize(context);
664 if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
665 if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
666 TC3_LOG(ERROR) << "Could not extract token features.";
667 return false;
668 }
669 }
670 if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
671 if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
672 &total_token_count)) {
673 TC3_LOG(ERROR) << "Could not extract token features.";
674 return false;
675 }
676 }
677 }
678
679 if (!AllocateInput(context.size(), max_tokens, total_token_count,
680 interpreter)) {
681 TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
682 return false;
683 }
684 if (model_->tflite_model_spec()->input_context() >= 0) {
685 if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
686 model_executor_->SetInput<std::string>(
687 model_->tflite_model_spec()->input_context(),
688 PadOrTruncateToTargetLength(
689 context, model_->tflite_model_spec()->input_length_to_pad(),
690 std::string("")),
691 interpreter);
692 } else {
693 model_executor_->SetInput<std::string>(
694 model_->tflite_model_spec()->input_context(), context, interpreter);
695 }
696 }
697 if (model_->tflite_model_spec()->input_context_length() >= 0) {
698 model_executor_->SetInput<int>(
699 model_->tflite_model_spec()->input_context_length(), context.size(),
700 interpreter);
701 }
702 if (model_->tflite_model_spec()->input_user_id() >= 0) {
703 if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
704 model_executor_->SetInput<int>(
705 model_->tflite_model_spec()->input_user_id(),
706 PadOrTruncateToTargetLength(
707 user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
708 interpreter);
709 } else {
710 model_executor_->SetInput<int>(
711 model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
712 }
713 }
714 if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
715 model_executor_->SetInput<int>(
716 model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
717 interpreter);
718 }
719 if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
720 model_executor_->SetInput<float>(
721 model_->tflite_model_spec()->input_time_diffs(), time_diffs,
722 interpreter);
723 }
724 if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
725 std::vector<int> num_tokens_per_message(tokens.size());
726 for (int i = 0; i < tokens.size(); i++) {
727 num_tokens_per_message[i] = tokens[i].size();
728 }
729 model_executor_->SetInput<int>(
730 model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
731 interpreter);
732 }
733 if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
734 model_executor_->SetInput<float>(
735 model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
736 interpreter);
737 }
738 if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
739 model_executor_->SetInput<float>(
740 model_->tflite_model_spec()->input_flattened_token_embeddings(),
741 flattened_token_embeddings, interpreter);
742 }
743 // Set up additional input parameters.
744 if (const auto* input_name_index =
745 model_->tflite_model_spec()->input_name_index()) {
746 const std::unordered_map<std::string, Variant>& model_parameters =
747 options.model_parameters;
748 for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
749 *input_name_index) {
750 const std::string param_name = entry->key()->str();
751 const int param_index = entry->value();
752 const TfLiteType param_type =
753 interpreter->tensor(interpreter->inputs()[param_index])->type;
754 const auto param_value_it = model_parameters.find(param_name);
755 const bool has_value = param_value_it != model_parameters.end();
756 switch (param_type) {
757 case kTfLiteFloat32:
758 if (has_value) {
759 SetVectorOrScalarAsModelInput<float>(param_index,
760 param_value_it->second,
761 interpreter, model_executor_);
762 } else {
763 model_executor_->SetInput<float>(param_index, kDefaultFloat,
764 interpreter);
765 }
766 break;
767 case kTfLiteInt32:
768 if (has_value) {
769 SetVectorOrScalarAsModelInput<int32_t>(
770 param_index, param_value_it->second, interpreter,
771 model_executor_);
772 } else {
773 model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
774 interpreter);
775 }
776 break;
777 case kTfLiteInt64:
778 model_executor_->SetInput<int64_t>(
779 param_index,
780 has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
781 interpreter);
782 break;
783 case kTfLiteUInt8:
784 model_executor_->SetInput<uint8_t>(
785 param_index,
786 has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
787 interpreter);
788 break;
789 case kTfLiteInt8:
790 model_executor_->SetInput<int8_t>(
791 param_index,
792 has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
793 interpreter);
794 break;
795 case kTfLiteBool:
796 model_executor_->SetInput<bool>(
797 param_index,
798 has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
799 interpreter);
800 break;
801 default:
802 TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
803 << param_name;
804 }
805 }
806 }
807 return true;
808 }
809
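// Converts text reply outputs and their scores into actions of the given type,
// skipping empty replies and replies below the minimum reply score threshold.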
810 void ActionsSuggestions::PopulateTextReplies(
811 const tflite::Interpreter* interpreter, int suggestion_index,
812 int score_index, const std::string& type,
813 ActionsSuggestionsResponse* response) const {
814 const std::vector<tflite::StringRef> replies =
815 model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
816 const TensorView<float> scores =
817 model_executor_->OutputView<float>(score_index, interpreter);
818 for (int i = 0; i < replies.size(); i++) {
819 if (replies[i].len == 0) {
820 continue;
821 }
822 const float score = scores.data()[i];
823 if (score < preconditions_.min_reply_score_threshold) {
824 continue;
825 }
826 response->actions.push_back(
827 {std::string(replies[i].str, replies[i].len), type, score});
828 }
829 }
830
831 void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
832 const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
833 std::unique_ptr<MutableFlatbuffer> entity_data =
834 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
835 : nullptr;
836 FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
837 }
838
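// Reads the binary intent triggering prediction and, if it triggered, adds an
// action built from the task spec with the predicted score.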
839 void ActionsSuggestions::PopulateIntentTriggering(
840 const tflite::Interpreter* interpreter, int suggestion_index,
841 int score_index, const ActionSuggestionSpec* task_spec,
842 ActionsSuggestionsResponse* response) const {
843 if (!task_spec || task_spec->type()->size() == 0) {
844 TC3_LOG(ERROR)
845 << "Task type for intent (action) triggering cannot be empty!";
846 return;
847 }
848 const TensorView<bool> intent_prediction =
849 model_executor_->OutputView<bool>(suggestion_index, interpreter);
850 const TensorView<float> intent_scores =
851 model_executor_->OutputView<float>(score_index, interpreter);
852 // Two results corresponding to the binary triggering case.
853 TC3_CHECK_EQ(intent_prediction.size(), 2);
854 TC3_CHECK_EQ(intent_scores.size(), 2);
855 // We rely on the in-graph thresholding logic, so at this point the results
856 // have already been ranked according to the threshold.
857 const bool triggering = intent_prediction.data()[0];
858 const float trigger_score = intent_scores.data()[0];
859
860 if (triggering) {
861 ActionSuggestion suggestion;
862 std::unique_ptr<MutableFlatbuffer> entity_data =
863 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
864 : nullptr;
865 FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
866 suggestion.score = trigger_score;
867 response->actions.push_back(std::move(suggestion));
868 }
869 }
870
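// Reads the triggering and sensitivity scores, smart reply predictions,
// per-class action scores and multi-task predictions from the interpreter's
// outputs into the response.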
871 bool ActionsSuggestions::ReadModelOutput(
872 tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
873 ActionsSuggestionsResponse* response) const {
874 // Read sensitivity and triggering score predictions.
875 if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
876 const TensorView<float> triggering_score =
877 model_executor_->OutputView<float>(
878 model_->tflite_model_spec()->output_triggering_score(),
879 interpreter);
880 if (!triggering_score.is_valid() || triggering_score.size() == 0) {
881 TC3_LOG(ERROR) << "Could not compute triggering score.";
882 return false;
883 }
884 response->triggering_score = triggering_score.data()[0];
885 response->output_filtered_min_triggering_score =
886 (response->triggering_score <
887 preconditions_.min_smart_reply_triggering_score);
888 }
889 if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
890 const TensorView<float> sensitive_topic_score =
891 model_executor_->OutputView<float>(
892 model_->tflite_model_spec()->output_sensitive_topic_score(),
893 interpreter);
894 if (!sensitive_topic_score.is_valid() ||
895 sensitive_topic_score.dim(0) != 1) {
896 TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
897 return false;
898 }
899 response->sensitivity_score = sensitive_topic_score.data()[0];
900 response->is_sensitive = (response->sensitivity_score >
901 preconditions_.max_sensitive_topic_score);
902 }
903
904 // Suppress model outputs.
905 if (response->is_sensitive) {
906 return true;
907 }
908
909 // Read smart reply predictions.
910 if (!response->output_filtered_min_triggering_score &&
911 model_->tflite_model_spec()->output_replies() >= 0) {
912 PopulateTextReplies(interpreter,
913 model_->tflite_model_spec()->output_replies(),
914 model_->tflite_model_spec()->output_replies_scores(),
915 model_->smart_reply_action_type()->str(), response);
916 }
917
918 // Read actions suggestions.
919 if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
920 const TensorView<float> actions_scores = model_executor_->OutputView<float>(
921 model_->tflite_model_spec()->output_actions_scores(), interpreter);
922 for (int i = 0; i < model_->action_type()->size(); i++) {
923 const ActionTypeOptions* action_type = model_->action_type()->Get(i);
924 // Skip disabled action classes, such as the default other category.
925 if (!action_type->enabled()) {
926 continue;
927 }
928 const float score = actions_scores.data()[i];
929 if (score < action_type->min_triggering_score()) {
930 continue;
931 }
932
933 // Create action from model output.
934 ActionSuggestion suggestion;
935 suggestion.type = action_type->name()->str();
936 std::unique_ptr<MutableFlatbuffer> entity_data =
937 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
938 : nullptr;
939 FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
940 suggestion.score = score;
941 response->actions.push_back(std::move(suggestion));
942 }
943 }
944
945 // Read multi-task predictions and construct the result properly.
946 if (const auto* prediction_metadata =
947 model_->tflite_model_spec()->prediction_metadata()) {
948 for (const PredictionMetadata* metadata : *prediction_metadata) {
949 const ActionSuggestionSpec* task_spec = metadata->task_spec();
950 const int suggestions_index = metadata->output_suggestions();
951 const int suggestions_scores_index =
952 metadata->output_suggestions_scores();
953 switch (metadata->prediction_type()) {
954 case PredictionType_NEXT_MESSAGE_PREDICTION:
955 if (!task_spec || task_spec->type()->size() == 0) {
956 TC3_LOG(WARNING) << "Task type not provided, using default "
957 "smart_reply_action_type!";
958 }
959 PopulateTextReplies(
960 interpreter, suggestions_index, suggestions_scores_index,
961 task_spec ? task_spec->type()->str()
962 : model_->smart_reply_action_type()->str(),
963 response);
964 break;
965 case PredictionType_INTENT_TRIGGERING:
966 PopulateIntentTriggering(interpreter, suggestions_index,
967 suggestions_scores_index, task_spec,
968 response);
969 break;
970 default:
971 TC3_LOG(ERROR) << "Unsupported prediction type!";
972 return false;
973 }
974 }
975 }
976
977 return true;
978 }
979
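// Checks the conversation with the sensitive model and, if it is not deemed
// sensitive, sets up and runs the TensorFlow Lite model on the last
// num_messages messages of the conversation.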
980 bool ActionsSuggestions::SuggestActionsFromModel(
981 const Conversation& conversation, const int num_messages,
982 const ActionSuggestionOptions& options,
983 ActionsSuggestionsResponse* response,
984 std::unique_ptr<tflite::Interpreter>* interpreter) const {
985 TC3_CHECK_LE(num_messages, conversation.messages.size());
986
987 if (sensitive_model_ != nullptr &&
988 sensitive_model_->EvalConversation(conversation, num_messages).first) {
989 response->is_sensitive = true;
990 return true;
991 }
992
993 if (!model_executor_) {
994 return true;
995 }
996 *interpreter = model_executor_->CreateInterpreter();
997
998 if (!*interpreter) {
999 TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
1000 "actions suggestions model.";
1001 return false;
1002 }
1003
1004 std::vector<std::string> context;
1005 std::vector<int> user_ids;
1006 std::vector<float> time_diffs;
1007 context.reserve(num_messages);
1008 user_ids.reserve(num_messages);
1009 time_diffs.reserve(num_messages);
1010
1011 // Gather last `num_messages` messages from the conversation.
1012 int64 last_message_reference_time_ms_utc = 0;
1013 const float second_in_ms = 1000;
1014 for (int i = conversation.messages.size() - num_messages;
1015 i < conversation.messages.size(); i++) {
1016 const ConversationMessage& message = conversation.messages[i];
1017 context.push_back(message.text);
1018 user_ids.push_back(message.user_id);
1019
1020 float time_diff_secs = 0;
1021 if (message.reference_time_ms_utc != 0 &&
1022 last_message_reference_time_ms_utc != 0) {
1023 time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
1024 last_message_reference_time_ms_utc) /
1025 second_in_ms);
1026 }
1027 if (message.reference_time_ms_utc != 0) {
1028 last_message_reference_time_ms_utc = message.reference_time_ms_utc;
1029 }
1030 time_diffs.push_back(time_diff_secs);
1031 }
1032
1033 if (!SetupModelInput(context, user_ids, time_diffs,
1034 /*num_suggestions=*/model_->num_smart_replies(), options,
1035 interpreter->get())) {
1036 TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
1037 return false;
1038 }
1039
1040 if ((*interpreter)->Invoke() != kTfLiteOk) {
1041 TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
1042 return false;
1043 }
1044
1045 return ReadModelOutput(interpreter->get(), options, response);
1046 }
1047
1048 Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
1049 const Conversation& conversation, const ActionSuggestionOptions& options,
1050 std::vector<ActionSuggestion>* actions) const {
1051 TC3_ASSIGN_OR_RETURN(
1052 std::vector<ActionSuggestion> new_actions,
1053 conversation_intent_detection_->SuggestActions(conversation, options));
1054 for (auto& action : new_actions) {
1055 actions->push_back(std::move(action));
1056 }
1057 return Status::OK;
1058 }
1059
1060 AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
1061 const ConversationMessage& message) const {
1062 AnnotationOptions options;
1063 options.detected_text_language_tags = message.detected_text_language_tags;
1064 options.reference_time_ms_utc = message.reference_time_ms_utc;
1065 options.reference_timezone = message.reference_timezone;
1066 options.annotation_usecase =
1067 model_->annotation_actions_spec()->annotation_usecase();
1068 options.is_serialized_entity_data_enabled =
1069 model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
1070 options.entity_types = annotation_entity_types_;
1071 return options;
1072 }
1073
1074 // Run annotator on the messages of a conversation.
1075 Conversation ActionsSuggestions::AnnotateConversation(
1076 const Conversation& conversation, const Annotator* annotator) const {
1077 if (annotator == nullptr) {
1078 return conversation;
1079 }
1080 const int num_messages_grammar =
1081 ((model_->rules() && model_->rules()->grammar_rules() &&
1082 model_->rules()
1083 ->grammar_rules()
1084 ->rules()
1085 ->nonterminals()
1086 ->annotation_nt())
1087 ? 1
1088 : 0);
1089 const int num_messages_mapping =
1090 (model_->annotation_actions_spec()
1091 ? std::max(model_->annotation_actions_spec()
1092 ->max_history_from_any_person(),
1093 model_->annotation_actions_spec()
1094 ->max_history_from_last_person())
1095 : 0);
1096 const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
1097 if (num_messages == 0) {
1098 // No annotations are used.
1099 return conversation;
1100 }
1101 Conversation annotated_conversation = conversation;
1102 for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
1103 i < num_messages && message_index >= 0; i++, message_index--) {
1104 ConversationMessage* message =
1105 &annotated_conversation.messages[message_index];
1106 if (message->annotations.empty()) {
1107 message->annotations = annotator->Annotate(
1108 message->text, AnnotationOptionsForMessage(*message));
1109 ConvertDatetimeToTime(&message->annotations);
1110 }
1111 }
1112 return annotated_conversation;
1113 }
1114
1115 void ActionsSuggestions::SuggestActionsFromAnnotations(
1116 const Conversation& conversation,
1117 std::vector<ActionSuggestion>* actions) const {
1118 if (model_->annotation_actions_spec() == nullptr ||
1119 model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1120 model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1121 return;
1122 }
1123
1124 // Create actions based on the annotations.
1125 const int max_from_any_person =
1126 model_->annotation_actions_spec()->max_history_from_any_person();
1127 const int max_from_last_person =
1128 model_->annotation_actions_spec()->max_history_from_last_person();
1129 const int last_person = conversation.messages.back().user_id;
1130
1131 int num_messages_last_person = 0;
1132 int num_messages_any_person = 0;
1133 bool all_from_last_person = true;
1134 for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1135 message_index--) {
1136 const ConversationMessage& message = conversation.messages[message_index];
1137 std::vector<AnnotatedSpan> annotations = message.annotations;
1138
1139 // Update how many messages we have processed from the last person in the
1140 // conversation and from any person in the conversation.
1141 num_messages_any_person++;
1142 if (all_from_last_person && message.user_id == last_person) {
1143 num_messages_last_person++;
1144 } else {
1145 all_from_last_person = false;
1146 }
1147
1148 if (num_messages_any_person > max_from_any_person &&
1149 (!all_from_last_person ||
1150 num_messages_last_person > max_from_last_person)) {
1151 break;
1152 }
1153
1154 if (message.user_id == kLocalUserId) {
1155 if (model_->annotation_actions_spec()->only_until_last_sent()) {
1156 break;
1157 }
1158 if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1159 continue;
1160 }
1161 }
1162
1163 std::vector<ActionSuggestionAnnotation> action_annotations;
1164 action_annotations.reserve(annotations.size());
1165 for (const AnnotatedSpan& annotation : annotations) {
1166 if (annotation.classification.empty()) {
1167 continue;
1168 }
1169
1170 const ClassificationResult& classification_result =
1171 annotation.classification[0];
1172
1173 ActionSuggestionAnnotation action_annotation;
1174 action_annotation.span = {
1175 message_index, annotation.span,
1176 UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1177 .UTF8Substring(annotation.span.first, annotation.span.second)};
1178 action_annotation.entity = classification_result;
1179 action_annotation.name = classification_result.collection;
1180 action_annotations.push_back(std::move(action_annotation));
1181 }
1182
1183 if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1184 // Create actions only for deduplicated annotations.
1185 for (const int annotation_id :
1186 DeduplicateAnnotations(action_annotations)) {
1187 SuggestActionsFromAnnotation(
1188 message_index, action_annotations[annotation_id], actions);
1189 }
1190 } else {
1191 // Create actions for all annotations.
1192 for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1193 SuggestActionsFromAnnotation(message_index, annotation, actions);
1194 }
1195 }
1196 }
1197 }
1198
1199 void ActionsSuggestions::SuggestActionsFromAnnotation(
1200 const int message_index, const ActionSuggestionAnnotation& annotation,
1201 std::vector<ActionSuggestion>* actions) const {
1202 for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
1203 *model_->annotation_actions_spec()->annotation_mapping()) {
1204 if (annotation.entity.collection ==
1205 mapping->annotation_collection()->str()) {
1206 if (annotation.entity.score < mapping->min_annotation_score()) {
1207 continue;
1208 }
1209
1210 std::unique_ptr<MutableFlatbuffer> entity_data =
1211 entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1212 : nullptr;
1213
1214 // Set annotation text as (additional) entity data field.
1215 if (mapping->entity_field() != nullptr) {
1216 TC3_CHECK_NE(entity_data, nullptr);
1217
1218 UnicodeText normalized_annotation_text =
1219 UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
1220
1221 // Apply normalization if specified.
1222 if (mapping->normalization_options() != nullptr) {
1223 normalized_annotation_text =
1224 NormalizeText(*unilib_, mapping->normalization_options(),
1225 normalized_annotation_text);
1226 }
1227
1228 entity_data->ParseAndSet(mapping->entity_field(),
1229 normalized_annotation_text.ToUTF8String());
1230 }
1231
1232 ActionSuggestion suggestion;
1233 FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
1234 if (mapping->use_annotation_score()) {
1235 suggestion.score = annotation.entity.score;
1236 }
1237 suggestion.annotations = {annotation};
1238 actions->push_back(std::move(suggestion));
1239 }
1240 }
1241 }
1242
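// Returns the indices of the annotations to keep, retaining only the
// highest-scoring annotation for each (name, span text) pair.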
1243 std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1244 const std::vector<ActionSuggestionAnnotation>& annotations) const {
1245 std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1246
1247 for (int i = 0; i < annotations.size(); i++) {
1248 const std::pair<std::string, std::string> key = {annotations[i].name,
1249 annotations[i].span.text};
1250 auto entry = deduplicated_annotations.find(key);
1251 if (entry != deduplicated_annotations.end()) {
1252 // Keep the annotation with the higher score.
1253 if (annotations[entry->second].entity.score <
1254 annotations[i].entity.score) {
1255 entry->second = i;
1256 }
1257 continue;
1258 }
1259 deduplicated_annotations.insert(entry, {key, i});
1260 }
1261
1262 std::vector<int> result;
1263 result.reserve(deduplicated_annotations.size());
1264 for (const auto& key_and_annotation : deduplicated_annotations) {
1265 result.push_back(key_and_annotation.second);
1266 }
1267 return result;
1268 }
1269
1270 #if !defined(TC3_DISABLE_LUA)
1271 bool ActionsSuggestions::SuggestActionsFromLua(
1272 const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1273 const tflite::Interpreter* interpreter,
1274 const reflection::Schema* annotation_entity_data_schema,
1275 std::vector<ActionSuggestion>* actions) const {
1276 if (lua_bytecode_.empty()) {
1277 return true;
1278 }
1279
1280 auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1281 lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1282 interpreter, entity_data_schema_, annotation_entity_data_schema);
1283 if (lua_actions == nullptr) {
1284 TC3_LOG(ERROR) << "Could not create lua actions.";
1285 return false;
1286 }
1287 return lua_actions->SuggestActions(actions);
1288 }
1289 #else
1290 bool ActionsSuggestions::SuggestActionsFromLua(
1291 const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1292 const tflite::Interpreter* interpreter,
1293 const reflection::Schema* annotation_entity_data_schema,
1294 std::vector<ActionSuggestion>* actions) const {
1295 return true;
1296 }
1297 #endif
1298
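// Core suggestion pipeline: annotates the conversation, enforces the
// triggering preconditions, and gathers actions from annotations, grammar
// rules, the TFLite model, conversation intent detection, Lua actions and
// regex rules.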
1299 bool ActionsSuggestions::GatherActionsSuggestions(
1300 const Conversation& conversation, const Annotator* annotator,
1301 const ActionSuggestionOptions& options,
1302 ActionsSuggestionsResponse* response) const {
1303 if (conversation.messages.empty()) {
1304 return true;
1305 }
1306
1307 // Run annotator against messages.
1308 const Conversation annotated_conversation =
1309 AnnotateConversation(conversation, annotator);
1310
1311 const int num_messages = NumMessagesToConsider(
1312 annotated_conversation, model_->max_conversation_history_length());
1313
1314 if (num_messages <= 0) {
1315 TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1316 return false;
1317 }
1318
1319 SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
1320
1321 if (grammar_actions_ != nullptr &&
1322 !grammar_actions_->SuggestActions(annotated_conversation,
1323 &response->actions)) {
1324 TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
1325 return false;
1326 }
1327
1328 int input_text_length = 0;
1329 int num_matching_locales = 0;
1330 for (int i = annotated_conversation.messages.size() - num_messages;
1331 i < annotated_conversation.messages.size(); i++) {
1332 input_text_length += annotated_conversation.messages[i].text.length();
1333 std::vector<Locale> message_languages;
1334 if (!ParseLocales(
1335 annotated_conversation.messages[i].detected_text_language_tags,
1336 &message_languages)) {
1337 continue;
1338 }
1339 if (Locale::IsAnyLocaleSupported(
1340 message_languages, locales_,
1341 preconditions_.handle_unknown_locale_as_supported)) {
1342 ++num_matching_locales;
1343 }
1344 }
1345
1346 // Bail out if we are provided with too little or too much input.
1347 if (input_text_length < preconditions_.min_input_length ||
1348 (preconditions_.max_input_length >= 0 &&
1349 input_text_length > preconditions_.max_input_length)) {
1350 TC3_LOG(INFO) << "Too much or not enough input for inference.";
1351 return true;
1352 }
1353
1354 // Bail out if the text does not look like it can be handled by the model.
1355 const float matching_fraction =
1356 static_cast<float>(num_matching_locales) / num_messages;
1357 if (matching_fraction < preconditions_.min_locale_match_fraction) {
1358 TC3_LOG(INFO) << "Not enough locale matches.";
1359 response->output_filtered_locale_mismatch = true;
1360 return true;
1361 }
1362
1363 std::vector<const UniLib::RegexPattern*> post_check_rules;
1364 if (preconditions_.suppress_on_low_confidence_input) {
1365 if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
1366 num_messages, &post_check_rules)) {
1367 response->output_filtered_low_confidence = true;
1368 return true;
1369 }
1370 }
1371
1372 std::unique_ptr<tflite::Interpreter> interpreter;
1373 if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
1374 response, &interpreter)) {
1375 TC3_LOG(ERROR) << "Could not run model.";
1376 return false;
1377 }
1378
1379 // SuggestActionsFromModel also detects if the conversation is sensitive,
1380 // either by using the old ngram model or the new model.
1381 // Suppress all predictions if the conversation was deemed sensitive.
1382 if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
1383 return true;
1384 }
1385
1386 if (conversation_intent_detection_) {
1387 // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
1388 auto actions = SuggestActionsFromConversationIntentDetection(
1389 annotated_conversation, options, &response->actions);
1390 if (!actions.ok()) {
1391 TC3_LOG(ERROR) << "Could not run conversation intent detection: "
1392 << actions.error_message();
1393 return false;
1394 }
1395 }
1396
1397 if (!SuggestActionsFromLua(
1398 annotated_conversation, model_executor_.get(), interpreter.get(),
1399 annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1400 &response->actions)) {
1401 TC3_LOG(ERROR) << "Could not suggest actions from script.";
1402 return false;
1403 }
1404
1405 if (!regex_actions_->SuggestActions(annotated_conversation,
1406 entity_data_builder_.get(),
1407 &response->actions)) {
1408 TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
1409 return false;
1410 }
1411
1412 if (preconditions_.suppress_on_low_confidence_input &&
1413 !regex_actions_->FilterConfidenceOutput(post_check_rules,
1414 &response->actions)) {
1415 TC3_LOG(ERROR) << "Could not post-check actions.";
1416 return false;
1417 }
1418
1419 return true;
1420 }
1421
1422 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1423 const Conversation& conversation, const Annotator* annotator,
1424 const ActionSuggestionOptions& options) const {
1425 ActionsSuggestionsResponse response;
1426
1427 // Assert that messages are sorted correctly.
1428 for (int i = 1; i < conversation.messages.size(); i++) {
1429 if (conversation.messages[i].reference_time_ms_utc <
1430 conversation.messages[i - 1].reference_time_ms_utc) {
1431 TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
1432 return response;
1433 }
1434 }
1435
1436 // Check that messages are valid utf8.
1437 for (const ConversationMessage& message : conversation.messages) {
1438 if (message.text.size() > std::numeric_limits<int>::max()) {
1439 TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
1440 return {};
1441 }
1442
1443 if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
1444 message.text.data(), message.text.size(), /*do_copy=*/false))) {
1445 TC3_LOG(ERROR) << "Not valid utf8 provided.";
1446 return response;
1447 }
1448 }
1449
1450 if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1451 TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1452 response.actions.clear();
1453 } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1454 annotator != nullptr
1455 ? annotator->entity_data_schema()
1456 : nullptr)) {
1457 TC3_LOG(ERROR) << "Could not rank actions.";
1458 response.actions.clear();
1459 }
1460 return response;
1461 }
1462
1463 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1464 const Conversation& conversation,
1465 const ActionSuggestionOptions& options) const {
1466 return SuggestActions(conversation, /*annotator=*/nullptr, options);
1467 }
1468
1469 const ActionsModel* ActionsSuggestions::model() const { return model_; }
1470 const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
1471 return entity_data_schema_;
1472 }
1473
1474 const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1475 if (buffer == nullptr) {
1476 return nullptr;
1477 }
1478 return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1479 }
1480
1481 bool ActionsSuggestions::InitializeConversationIntentDetection(
1482 const std::string& serialized_config) {
1483 auto conversation_intent_detection =
1484 std::make_unique<ConversationIntentDetection>();
1485 if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
1486 TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
1487 return false;
1488 }
1489 conversation_intent_detection_ = std::move(conversation_intent_detection);
1490 return true;
1491 }
1492
1493 } // namespace libtextclassifier3
1494