1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46 #include "absl/container/flat_hash_set.h"
47
48 namespace libtextclassifier3 {
49
50 // Class for predicting actions following a conversation.
51 class ActionsSuggestions {
52 public:
53 // Creates ActionsSuggestions from given data buffer with model.
54 static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
55 const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
56 const std::string& triggering_preconditions_overlay = "");
57
58 // Creates ActionsSuggestions from model in the ScopedMmap object and takes
59 // ownership of it.
60 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
61 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
62 const UniLib* unilib = nullptr,
63 const std::string& triggering_preconditions_overlay = "");
64 // Same as above, but also takes ownership of the unilib.
65 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
66 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
67 std::unique_ptr<UniLib> unilib,
68 const std::string& triggering_preconditions_overlay);
69
70 // Creates ActionsSuggestions from model given as a file descriptor, offset
71 // and size in it. If offset and size are less than 0, will ignore them and
72 // will just use the fd.
73 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
74 const int fd, const int offset, const int size,
75 const UniLib* unilib = nullptr,
76 const std::string& triggering_preconditions_overlay = "");
77 // Same as above, but also takes ownership of the unilib.
78 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
79 const int fd, const int offset, const int size,
80 std::unique_ptr<UniLib> unilib,
81 const std::string& triggering_preconditions_overlay = "");
82
83 // Creates ActionsSuggestions from model given as a file descriptor.
84 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
85 const int fd, const UniLib* unilib = nullptr,
86 const std::string& triggering_preconditions_overlay = "");
87 // Same as above, but also takes ownership of the unilib.
88 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
89 const int fd, std::unique_ptr<UniLib> unilib,
90 const std::string& triggering_preconditions_overlay);
91
92 // Creates ActionsSuggestions from model given as a POSIX path.
93 static std::unique_ptr<ActionsSuggestions> FromPath(
94 const std::string& path, const UniLib* unilib = nullptr,
95 const std::string& triggering_preconditions_overlay = "");
96 // Same as above, but also takes ownership of unilib.
97 static std::unique_ptr<ActionsSuggestions> FromPath(
98 const std::string& path, std::unique_ptr<UniLib> unilib,
99 const std::string& triggering_preconditions_overlay);
100
101 ActionsSuggestionsResponse SuggestActions(
102 const Conversation& conversation,
103 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
104
105 ActionsSuggestionsResponse SuggestActions(
106 const Conversation& conversation, const Annotator* annotator,
107 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
108
109 bool InitializeConversationIntentDetection(
110 const std::string& serialized_config);
111
112 const ActionsModel* model() const;
113 const reflection::Schema* entity_data_schema() const;
114
115 static constexpr int kLocalUserId = 0;
116
117 protected:
118 // Exposed for testing.
119 bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
120
121 // Embeds the tokens per message separately. Each message is padded to the
122 // maximum length with the padding token.
123 bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
124 std::vector<float>* embeddings,
125 int* max_num_tokens_per_message) const;
126
127 // Concatenates the embedded message tokens - separated by start and end
128 // token between messages.
129 // If the total token count is greater than the maximum length, tokens at the
130 // start are dropped to fit into the limit.
131 // If the total token count is smaller than the minimum length, padding tokens
132 // are added to the end.
133 // Messages are assumed to be ordered by recency - most recent is last.
134 bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
135 std::vector<float>* embeddings,
136 int* total_token_count) const;
137
138 const ActionsModel* model_;
139
140 // Feature extractor and options.
141 std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
142 std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
143 std::vector<float> embedded_padding_token_;
144 std::vector<float> embedded_start_token_;
145 std::vector<float> embedded_end_token_;
146 int token_embedding_size_;
147
148 private:
149 // Checks that model contains all required fields, and initializes internal
150 // datastructures.
151 bool ValidateAndInitialize();
152
153 void SetOrCreateUnilib(const UniLib* unilib);
154
155 // Prepare preconditions.
156 // Takes values from flag provided data, but falls back to model provided
157 // values for parameters that are not explicitly provided.
158 bool InitializeTriggeringPreconditions();
159
160 // Tokenizes a conversation and produces the tokens per message.
161 std::vector<std::vector<Token>> Tokenize(
162 const std::vector<std::string>& context) const;
163
164 bool AllocateInput(const int conversation_length, const int max_tokens,
165 const int total_token_count,
166 tflite::Interpreter* interpreter) const;
167
168 bool SetupModelInput(const std::vector<std::string>& context,
169 const std::vector<int>& user_ids,
170 const std::vector<float>& time_diffs,
171 const int num_suggestions,
172 const ActionSuggestionOptions& options,
173 tflite::Interpreter* interpreter) const;
174
175 void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
176 ActionSuggestion* suggestion) const;
177
178 void PopulateTextReplies(const tflite::Interpreter* interpreter,
179 int suggestion_index, int score_index,
180 const std::string& type, float priority_score,
181 const absl::flat_hash_set<std::string>& blocklist,
182 ActionsSuggestionsResponse* response) const;
183
184 void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
185 int suggestion_index, int score_index,
186 const ActionSuggestionSpec* task_spec,
187 ActionsSuggestionsResponse* response) const;
188
189 bool ReadModelOutput(tflite::Interpreter* interpreter,
190 const ActionSuggestionOptions& options,
191 ActionsSuggestionsResponse* response) const;
192
193 bool SuggestActionsFromModel(
194 const Conversation& conversation, const int num_messages,
195 const ActionSuggestionOptions& options,
196 ActionsSuggestionsResponse* response,
197 std::unique_ptr<tflite::Interpreter>* interpreter) const;
198
199 Status SuggestActionsFromConversationIntentDetection(
200 const Conversation& conversation, const ActionSuggestionOptions& options,
201 std::vector<ActionSuggestion>* actions) const;
202
203 // Creates options for annotation of a message.
204 AnnotationOptions AnnotationOptionsForMessage(
205 const ConversationMessage& message) const;
206
207 void SuggestActionsFromAnnotations(
208 const Conversation& conversation,
209 std::vector<ActionSuggestion>* actions) const;
210
211 void SuggestActionsFromAnnotation(
212 const int message_index, const ActionSuggestionAnnotation& annotation,
213 std::vector<ActionSuggestion>* actions) const;
214
215 // Run annotator on the messages of a conversation.
216 Conversation AnnotateConversation(const Conversation& conversation,
217 const Annotator* annotator) const;
218
219 // Deduplicates equivalent annotations - annotations that have the same type
220 // and same span text.
221 // Returns the indices of the deduplicated annotations.
222 std::vector<int> DeduplicateAnnotations(
223 const std::vector<ActionSuggestionAnnotation>& annotations) const;
224
225 bool SuggestActionsFromLua(
226 const Conversation& conversation,
227 const TfLiteModelExecutor* model_executor,
228 const tflite::Interpreter* interpreter,
229 const reflection::Schema* annotation_entity_data_schema,
230 std::vector<ActionSuggestion>* actions) const;
231
232 bool GatherActionsSuggestions(const Conversation& conversation,
233 const Annotator* annotator,
234 const ActionSuggestionOptions& options,
235 ActionsSuggestionsResponse* response) const;
236
237 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
238
239 // Tensorflow Lite models.
240 std::unique_ptr<const TfLiteModelExecutor> model_executor_;
241
242 // Regex rules model.
243 std::unique_ptr<RegexActions> regex_actions_;
244
245 // The grammar rules model.
246 std::unique_ptr<GrammarActions> grammar_actions_;
247
248 std::unique_ptr<UniLib> owned_unilib_;
249 const UniLib* unilib_;
250
251 // Locales supported by the model.
252 std::vector<Locale> locales_;
253
254 // Annotation entities used by the model.
255 std::unordered_set<std::string> annotation_entity_types_;
256
257 // Builder for creating extra data.
258 const reflection::Schema* entity_data_schema_;
259 std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
260 std::unique_ptr<ActionsSuggestionsRanker> ranker_;
261
262 std::string lua_bytecode_;
263
264 // Triggering preconditions. These parameters can be backed by the model and
265 // (partially) be provided by flags.
266 TriggeringPreconditionsT preconditions_;
267 std::string triggering_preconditions_overlay_buffer_;
268 const TriggeringPreconditions* triggering_preconditions_overlay_;
269
270 // Low confidence input ngram classifier.
271 std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
272
273 // Conversation intent detection model for additional actions.
274 std::unique_ptr<const ConversationIntentDetection>
275 conversation_intent_detection_;
276 };
277
278 // Interprets the buffer as a Model flatbuffer and returns it for reading.
279 const ActionsModel* ViewActionsModel(const void* buffer, int size);
280
281 // Opens model from given path and runs a function, passing the loaded Model
282 // flatbuffer as an argument.
283 //
284 // This is mainly useful if we don't want to pay the cost for the model
285 // initialization because we'll be only reading some flatbuffer values from the
286 // file.
287 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)288 ReturnType VisitActionsModel(const std::string& path, Func function) {
289 ScopedMmap mmap(path);
290 if (!mmap.handle().ok()) {
291 function(/*model=*/nullptr);
292 }
293 const ActionsModel* model =
294 ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
295 return function(model);
296 }
297
// String constants naming the action suggestion types.
//
// Every accessor creates its string on first call, keeps it on the heap and
// never deletes it, so the returned reference stays valid for the rest of the
// program.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kType = new std::string("view_calendar");
    return *kType;
  }
  static const std::string& ViewMap() {
    static const std::string* const kType = new std::string("view_map");
    return *kType;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kType = new std::string("track_flight");
    return *kType;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kType = new std::string("open_url");
    return *kType;
  }
  static const std::string& SendSms() {
    static const std::string* const kType = new std::string("send_sms");
    return *kType;
  }
  static const std::string& CallPhone() {
    static const std::string* const kType = new std::string("call_phone");
    return *kType;
  }
  static const std::string& SendEmail() {
    static const std::string* const kType = new std::string("send_email");
    return *kType;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kType = new std::string("share_location");
    return *kType;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kType = new std::string("create_reminder");
    return *kType;
  }
  static const std::string& TextReply() {
    static const std::string* const kType = new std::string("text_reply");
    return *kType;
  }
  static const std::string& AddContact() {
    static const std::string* const kType = new std::string("add_contact");
    return *kType;
  }
  static const std::string& Copy() {
    static const std::string* const kType = new std::string("copy");
    return *kType;
  }
};
363
364 } // namespace libtextclassifier3
365
366 #endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
367