1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46
47 namespace libtextclassifier3 {
48
49 // Class for predicting actions following a conversation.
50 class ActionsSuggestions {
51 public:
52 // Creates ActionsSuggestions from given data buffer with model.
53 static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
54 const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
55 const std::string& triggering_preconditions_overlay = "");
56
57 // Creates ActionsSuggestions from model in the ScopedMmap object and takes
58 // ownership of it.
59 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
60 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
61 const UniLib* unilib = nullptr,
62 const std::string& triggering_preconditions_overlay = "");
63 // Same as above, but also takes ownership of the unilib.
64 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
65 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
66 std::unique_ptr<UniLib> unilib,
67 const std::string& triggering_preconditions_overlay);
68
69 // Creates ActionsSuggestions from model given as a file descriptor, offset
70 // and size in it. If offset and size are less than 0, will ignore them and
71 // will just use the fd.
72 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
73 const int fd, const int offset, const int size,
74 const UniLib* unilib = nullptr,
75 const std::string& triggering_preconditions_overlay = "");
76 // Same as above, but also takes ownership of the unilib.
77 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
78 const int fd, const int offset, const int size,
79 std::unique_ptr<UniLib> unilib,
80 const std::string& triggering_preconditions_overlay = "");
81
82 // Creates ActionsSuggestions from model given as a file descriptor.
83 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
84 const int fd, const UniLib* unilib = nullptr,
85 const std::string& triggering_preconditions_overlay = "");
86 // Same as above, but also takes ownership of the unilib.
87 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
88 const int fd, std::unique_ptr<UniLib> unilib,
89 const std::string& triggering_preconditions_overlay);
90
91 // Creates ActionsSuggestions from model given as a POSIX path.
92 static std::unique_ptr<ActionsSuggestions> FromPath(
93 const std::string& path, const UniLib* unilib = nullptr,
94 const std::string& triggering_preconditions_overlay = "");
95 // Same as above, but also takes ownership of unilib.
96 static std::unique_ptr<ActionsSuggestions> FromPath(
97 const std::string& path, std::unique_ptr<UniLib> unilib,
98 const std::string& triggering_preconditions_overlay);
99
100 ActionsSuggestionsResponse SuggestActions(
101 const Conversation& conversation,
102 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
103
104 ActionsSuggestionsResponse SuggestActions(
105 const Conversation& conversation, const Annotator* annotator,
106 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
107
108 bool InitializeConversationIntentDetection(
109 const std::string& serialized_config);
110
111 const ActionsModel* model() const;
112 const reflection::Schema* entity_data_schema() const;
113
114 static constexpr int kLocalUserId = 0;
115
116 protected:
117 // Exposed for testing.
118 bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
119
120 // Embeds the tokens per message separately. Each message is padded to the
121 // maximum length with the padding token.
122 bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
123 std::vector<float>* embeddings,
124 int* max_num_tokens_per_message) const;
125
126 // Concatenates the embedded message tokens - separated by start and end
127 // token between messages.
128 // If the total token count is greater than the maximum length, tokens at the
129 // start are dropped to fit into the limit.
130 // If the total token count is smaller than the minimum length, padding tokens
131 // are added to the end.
132 // Messages are assumed to be ordered by recency - most recent is last.
133 bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
134 std::vector<float>* embeddings,
135 int* total_token_count) const;
136
137 const ActionsModel* model_;
138
139 // Feature extractor and options.
140 std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
141 std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
142 std::vector<float> embedded_padding_token_;
143 std::vector<float> embedded_start_token_;
144 std::vector<float> embedded_end_token_;
145 int token_embedding_size_;
146
147 private:
148 // Checks that model contains all required fields, and initializes internal
149 // datastructures.
150 bool ValidateAndInitialize();
151
152 void SetOrCreateUnilib(const UniLib* unilib);
153
154 // Prepare preconditions.
155 // Takes values from flag provided data, but falls back to model provided
156 // values for parameters that are not explicitly provided.
157 bool InitializeTriggeringPreconditions();
158
159 // Tokenizes a conversation and produces the tokens per message.
160 std::vector<std::vector<Token>> Tokenize(
161 const std::vector<std::string>& context) const;
162
163 bool AllocateInput(const int conversation_length, const int max_tokens,
164 const int total_token_count,
165 tflite::Interpreter* interpreter) const;
166
167 bool SetupModelInput(const std::vector<std::string>& context,
168 const std::vector<int>& user_ids,
169 const std::vector<float>& time_diffs,
170 const int num_suggestions,
171 const ActionSuggestionOptions& options,
172 tflite::Interpreter* interpreter) const;
173
174 void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
175 ActionSuggestion* suggestion) const;
176
177 void PopulateTextReplies(const tflite::Interpreter* interpreter,
178 int suggestion_index, int score_index,
179 const std::string& type,
180 ActionsSuggestionsResponse* response) const;
181
182 void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
183 int suggestion_index, int score_index,
184 const ActionSuggestionSpec* task_spec,
185 ActionsSuggestionsResponse* response) const;
186
187 bool ReadModelOutput(tflite::Interpreter* interpreter,
188 const ActionSuggestionOptions& options,
189 ActionsSuggestionsResponse* response) const;
190
191 bool SuggestActionsFromModel(
192 const Conversation& conversation, const int num_messages,
193 const ActionSuggestionOptions& options,
194 ActionsSuggestionsResponse* response,
195 std::unique_ptr<tflite::Interpreter>* interpreter) const;
196
197 Status SuggestActionsFromConversationIntentDetection(
198 const Conversation& conversation, const ActionSuggestionOptions& options,
199 std::vector<ActionSuggestion>* actions) const;
200
201 // Creates options for annotation of a message.
202 AnnotationOptions AnnotationOptionsForMessage(
203 const ConversationMessage& message) const;
204
205 void SuggestActionsFromAnnotations(
206 const Conversation& conversation,
207 std::vector<ActionSuggestion>* actions) const;
208
209 void SuggestActionsFromAnnotation(
210 const int message_index, const ActionSuggestionAnnotation& annotation,
211 std::vector<ActionSuggestion>* actions) const;
212
213 // Run annotator on the messages of a conversation.
214 Conversation AnnotateConversation(const Conversation& conversation,
215 const Annotator* annotator) const;
216
217 // Deduplicates equivalent annotations - annotations that have the same type
218 // and same span text.
219 // Returns the indices of the deduplicated annotations.
220 std::vector<int> DeduplicateAnnotations(
221 const std::vector<ActionSuggestionAnnotation>& annotations) const;
222
223 bool SuggestActionsFromLua(
224 const Conversation& conversation,
225 const TfLiteModelExecutor* model_executor,
226 const tflite::Interpreter* interpreter,
227 const reflection::Schema* annotation_entity_data_schema,
228 std::vector<ActionSuggestion>* actions) const;
229
230 bool GatherActionsSuggestions(const Conversation& conversation,
231 const Annotator* annotator,
232 const ActionSuggestionOptions& options,
233 ActionsSuggestionsResponse* response) const;
234
235 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
236
237 // Tensorflow Lite models.
238 std::unique_ptr<const TfLiteModelExecutor> model_executor_;
239
240 // Regex rules model.
241 std::unique_ptr<RegexActions> regex_actions_;
242
243 // The grammar rules model.
244 std::unique_ptr<GrammarActions> grammar_actions_;
245
246 std::unique_ptr<UniLib> owned_unilib_;
247 const UniLib* unilib_;
248
249 // Locales supported by the model.
250 std::vector<Locale> locales_;
251
252 // Annotation entities used by the model.
253 std::unordered_set<std::string> annotation_entity_types_;
254
255 // Builder for creating extra data.
256 const reflection::Schema* entity_data_schema_;
257 std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
258 std::unique_ptr<ActionsSuggestionsRanker> ranker_;
259
260 std::string lua_bytecode_;
261
262 // Triggering preconditions. These parameters can be backed by the model and
263 // (partially) be provided by flags.
264 TriggeringPreconditionsT preconditions_;
265 std::string triggering_preconditions_overlay_buffer_;
266 const TriggeringPreconditions* triggering_preconditions_overlay_;
267
268 // Low confidence input ngram classifier.
269 std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
270
271 // Conversation intent detection model for additional actions.
272 std::unique_ptr<const ConversationIntentDetection>
273 conversation_intent_detection_;
274 };
275
276 // Interprets the buffer as a Model flatbuffer and returns it for reading.
277 const ActionsModel* ViewActionsModel(const void* buffer, int size);
278
279 // Opens model from given path and runs a function, passing the loaded Model
280 // flatbuffer as an argument.
281 //
282 // This is mainly useful if we don't want to pay the cost for the model
283 // initialization because we'll be only reading some flatbuffer values from the
284 // file.
285 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)286 ReturnType VisitActionsModel(const std::string& path, Func function) {
287 ScopedMmap mmap(path);
288 if (!mmap.handle().ok()) {
289 function(/*model=*/nullptr);
290 }
291 const ActionsModel* model =
292 ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
293 return function(model);
294 }
295
// Names of the action types that can be suggested.
//
// Every accessor hands back a reference to a string that is created on first
// use and intentionally never destroyed, so the reference stays valid for the
// remaining lifetime of the process.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kType = new std::string("view_calendar");
    return *kType;
  }
  static const std::string& ViewMap() {
    static const std::string* const kType = new std::string("view_map");
    return *kType;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kType = new std::string("track_flight");
    return *kType;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kType = new std::string("open_url");
    return *kType;
  }
  static const std::string& SendSms() {
    static const std::string* const kType = new std::string("send_sms");
    return *kType;
  }
  static const std::string& CallPhone() {
    static const std::string* const kType = new std::string("call_phone");
    return *kType;
  }
  static const std::string& SendEmail() {
    static const std::string* const kType = new std::string("send_email");
    return *kType;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kType = new std::string("share_location");
    return *kType;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kType = new std::string("create_reminder");
    return *kType;
  }
  static const std::string& TextReply() {
    static const std::string* const kType = new std::string("text_reply");
    return *kType;
  }
  static const std::string& AddContact() {
    static const std::string* const kType = new std::string("add_contact");
    return *kType;
  }
  static const std::string& Copy() {
    static const std::string* const kType = new std::string("copy");
    return *kType;
  }
};
361
362 } // namespace libtextclassifier3
363
364 #endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
365