/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>

#include "actions/lua-actions.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

const std::string& ActionsSuggestions::kViewCalendarType =
    *[]() { return new std::string("view_calendar"); }();
const std::string& ActionsSuggestions::kViewMapType =
    *[]() { return new std::string("view_map"); }();
const std::string& ActionsSuggestions::kTrackFlightType =
    *[]() { return new std::string("track_flight"); }();
const std::string& ActionsSuggestions::kOpenUrlType =
    *[]() { return new std::string("open_url"); }();
const std::string& ActionsSuggestions::kSendSmsType =
    *[]() { return new std::string("send_sms"); }();
const std::string& ActionsSuggestions::kCallPhoneType =
    *[]() { return new std::string("call_phone"); }();
const std::string& ActionsSuggestions::kSendEmailType =
    *[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
    *[]() { return new std::string("share_location"); }();

// Name for a datetime annotation that only includes time but no date.
const std::string& kTimeAnnotation =
    *[]() { return new std::string("time"); }();

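// Defaults used in SetupModelInput (below) for additional model input
// parameters that the caller did not supply.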
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {

const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

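// Reads a scalar field from a flatbuffers table by vtable offset, falling
// back to `default_value` if the table is null or the field is not set. This
// is what lets the triggering preconditions overlay selectively override the
// model defaults below.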
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

}  // namespace

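// Example usage (a minimal sketch; the model path is hypothetical, and
// passing a null UniLib makes the loader create one internally):
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath(
//           "/path/to/actions_suggestions.model", /*unilib=*/nullptr,
//           /*triggering_preconditions_overlay=*/"");
//   if (actions == nullptr) {
//     // Mmap, model verification, or initialization failed.
//   }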
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize the regular expressions model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Set up the grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create a feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create the low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    ngram_model_ = NGramModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (ngram_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }

  return true;
}

bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overrides.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens, padding each message to the maximum number of tokens of
  // any message in the conversation.
  // If a maximum number of tokens per message is specified in the model
  // config, tokens at the beginning of a message are dropped if they don't
  // fit within the limit.
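  // Worked example (hypothetical sizes): with messages of 3 and 5 tokens and
  // *max_num_tokens_per_message resolved to 4, the 5-token message drops its
  // first token (start = 5 - 4 = 1), while the 3-token message is padded
  // with one copy of the cached padding-token embedding.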
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
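  // Worked example (hypothetical sizes): with three messages of 4, 6, and 8
  // tokens, each message accounts for its tokens plus a start and an end
  // token (6, 8, and 10). With max_num_total_tokens = 12, scanning backwards
  // exhausts the budget at start_message = 1 (10 + 8 = 18 >= 12), so
  // message_token_offset = 18 - 12 = 6 leading positions of that message are
  // skipped.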
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    model_executor_->SetInput<std::string>(
        model_->tflite_model_spec()->input_context(), context, interpreter);
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    model_executor_->SetInput<int>(model_->tflite_model_spec()->input_user_id(),
                                   user_ids, interpreter);
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          model_executor_->SetInput<float>(
              param_index,
              has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
              interpreter);
          break;
        case kTfLiteInt32:
          model_executor_->SetInput<int32_t>(
              param_index,
              has_value ? param_value_it->second.Value<int>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    response->actions.push_back(
        {std::string(replies[i].str, replies[i].len), type, score});
  }
}

void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on in-graph thresholding logic, so at this point the results have
  // been ranked properly according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    std::unique_ptr<ReflectiveFlatbuffer> entity_data =
        entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                        : nullptr;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->output_filtered_sensitivity =
        (response->sensitivity_score >
         preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->output_filtered_sensitivity) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    PopulateTextReplies(interpreter,
                        model_->tflite_model_spec()->output_replies(),
                        model_->tflite_model_spec()->output_replies_scores(),
                        model_->smart_reply_action_type()->str(), response);
  }

  // Read action suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default "other" category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create an action from the model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      std::unique_ptr<ReflectiveFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the results properly.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, using the default "
                                "smart_reply_action_type!";
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather the last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(), options,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Runs the annotator on the messages of a conversation.
Conversation ActionsSuggestions::AnnotateConversation(
    const Conversation& conversation, const Annotator* annotator) const {
  if (annotator == nullptr) {
    return conversation;
  }
  const int num_messages_grammar =
      ((model_->rules() && model_->rules()->grammar_rules() &&
        model_->rules()
            ->grammar_rules()
            ->rules()
            ->nonterminals()
            ->annotation_nt())
           ? 1
           : 0);
  const int num_messages_mapping =
      (model_->annotation_actions_spec()
           ? std::max(model_->annotation_actions_spec()
                          ->max_history_from_any_person(),
                      model_->annotation_actions_spec()
                          ->max_history_from_last_person())
           : 0);
  const int num_messages =
      std::max(num_messages_grammar, num_messages_mapping);
  if (num_messages == 0) {
    // No annotations are used.
    return conversation;
  }
  Conversation annotated_conversation = conversation;
  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
       i < num_messages && message_index >= 0; i++, message_index--) {
    ConversationMessage* message =
        &annotated_conversation.messages[message_index];
    if (message->annotations.empty()) {
      message->annotations = annotator->Annotate(
          message->text, AnnotationOptionsForMessage(*message));
      // Use a distinct index variable to avoid shadowing the outer loop's `i`.
      for (int j = 0; j < message->annotations.size(); j++) {
        ClassificationResult* classification =
            &message->annotations[j].classification.front();

        // Specialize a datetime annotation to a time annotation if no date
        // component is present.
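        // For example (hypothetical input), "at 6pm" yields only HOUR and
        // MINUTE components and is relabeled with the "time" collection,
        // while "tomorrow at 6pm" also carries a date component (ordered
        // before HOUR in ComponentType) and keeps the datetime collection.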
        if (classification->collection == Collections::DateTime() &&
            classification->datetime_parse_result.IsSet()) {
          bool has_only_time = true;
          for (const DatetimeComponent& component :
               classification->datetime_parse_result.datetime_components) {
            if (component.component_type !=
                    DatetimeComponent::ComponentType::UNSPECIFIED &&
                component.component_type <
                    DatetimeComponent::ComponentType::HOUR) {
              has_only_time = false;
              break;
            }
          }
          if (has_only_time) {
            classification->collection = kTimeAnnotation;
          }
        }
      }
    }
  }
  return annotated_conversation;
}

void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation,
    std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1;
       message_index >= 0; message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(std::move(action_annotation));
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }

      std::unique_ptr<ReflectiveFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        TC3_CHECK_NE(entity_data, nullptr);

        UnicodeText normalized_annotation_text =
            UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);

        // Apply normalization if specified.
        if (mapping->normalization_options() != nullptr) {
          normalized_annotation_text =
              NormalizeText(*unilib_, mapping->normalization_options(),
                            normalized_annotation_text);
        }

        entity_data->ParseAndSet(mapping->entity_field(),
                                 normalized_annotation_text.ToUTF8String());
      }

      ActionSuggestion suggestion;
      FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }
      suggestion.annotations = {annotation};
      actions->push_back(std::move(suggestion));
    }
  }
}

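// Deduplicates annotations by (name, span text): for example, the same phone
// number annotated twice in a message yields a single entry, keeping the
// duplicate with the highest annotation score.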
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}

bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  // Run the annotator against the messages.
  const Conversation annotated_conversation =
      AnnotateConversation(conversation, annotator);

  const int num_messages = NumMessagesToConsider(
      annotated_conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);

  if (grammar_actions_ != nullptr &&
      !grammar_actions_->SuggestActions(annotated_conversation,
                                        &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
    return false;
  }

  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = annotated_conversation.messages.size() - num_messages;
       i < annotated_conversation.messages.size(); i++) {
    input_text_length += annotated_conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(
            annotated_conversation.messages[i].detected_text_language_tags,
            &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

  std::vector<const UniLib::RegexPattern*> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input) {
    if ((ngram_model_ != nullptr &&
         ngram_model_->EvalConversation(annotated_conversation,
                                        num_messages)) ||
        regex_actions_->IsLowConfidenceInput(annotated_conversation,
                                             num_messages, &post_check_rules)) {
      response->output_filtered_low_confidence = true;
      return true;
    }
  }

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
                               response, &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic &&
      response->output_filtered_sensitivity) {
    return true;
  }

  if (!SuggestActionsFromLua(
          annotated_conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

  if (!regex_actions_->SuggestActions(annotated_conversation,
                                      entity_data_builder_.get(),
                                      &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
    return false;
  }

  if (preconditions_.suppress_on_low_confidence_input &&
      !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                              &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

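// Example end-to-end call (a minimal sketch; the conversation text is
// hypothetical and default-constructed options are assumed):
//
//   ConversationMessage message;
//   message.user_id = 1;
//   message.text = "Meet me at 6pm at the office.";
//   Conversation conversation;
//   conversation.messages.push_back(message);
//   const ActionsSuggestionsResponse response =
//       actions->SuggestActions(conversation, ActionSuggestionOptions());
//   for (const ActionSuggestion& action : response.actions) {
//     // Inspect e.g. action.type, action.response_text, action.score.
//   }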
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;

  // Check that the messages are sorted most recent last; violations are
  // logged but tolerated.
  for (int i = 1; i < conversation.messages.size(); i++) {
    if (conversation.messages[i].reference_time_ms_utc <
        conversation.messages[i - 1].reference_time_ms_utc) {
      TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
    }
  }

  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

}  // namespace libtextclassifier3