1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "actions/actions_model_generated.h"
28 #include "actions/feature-processor.h"
29 #include "actions/grammar-actions.h"
30 #include "actions/ngram-model.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/types.h"
34 #include "annotator/annotator.h"
35 #include "annotator/model-executor.h"
36 #include "annotator/types.h"
37 #include "utils/flatbuffers.h"
38 #include "utils/i18n/locale.h"
39 #include "utils/memory/mmap.h"
40 #include "utils/tflite-model-executor.h"
41 #include "utils/utf8/unilib.h"
42 #include "utils/variant.h"
43 #include "utils/zlib/zlib.h"
44
45 namespace libtextclassifier3 {
46
47 // Options for suggesting actions.
48 struct ActionSuggestionOptions {
DefaultActionSuggestionOptions49 static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
50 std::unordered_map<std::string, Variant> model_parameters;
51 };
52
53 // Class for predicting actions following a conversation.
54 class ActionsSuggestions {
55 public:
56 // Creates ActionsSuggestions from given data buffer with model.
57 static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
58 const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
59 const std::string& triggering_preconditions_overlay = "");
60
61 // Creates ActionsSuggestions from model in the ScopedMmap object and takes
62 // ownership of it.
63 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
64 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
65 const UniLib* unilib = nullptr,
66 const std::string& triggering_preconditions_overlay = "");
67 // Same as above, but also takes ownership of the unilib.
68 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
69 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
70 std::unique_ptr<UniLib> unilib,
71 const std::string& triggering_preconditions_overlay);
72
73 // Creates ActionsSuggestions from model given as a file descriptor, offset
74 // and size in it. If offset and size are less than 0, will ignore them and
75 // will just use the fd.
76 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
77 const int fd, const int offset, const int size,
78 const UniLib* unilib = nullptr,
79 const std::string& triggering_preconditions_overlay = "");
80 // Same as above, but also takes ownership of the unilib.
81 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
82 const int fd, const int offset, const int size,
83 std::unique_ptr<UniLib> unilib,
84 const std::string& triggering_preconditions_overlay = "");
85
86 // Creates ActionsSuggestions from model given as a file descriptor.
87 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
88 const int fd, const UniLib* unilib = nullptr,
89 const std::string& triggering_preconditions_overlay = "");
90 // Same as above, but also takes ownership of the unilib.
91 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
92 const int fd, std::unique_ptr<UniLib> unilib,
93 const std::string& triggering_preconditions_overlay);
94
95 // Creates ActionsSuggestions from model given as a POSIX path.
96 static std::unique_ptr<ActionsSuggestions> FromPath(
97 const std::string& path, const UniLib* unilib = nullptr,
98 const std::string& triggering_preconditions_overlay = "");
99 // Same as above, but also takes ownership of unilib.
100 static std::unique_ptr<ActionsSuggestions> FromPath(
101 const std::string& path, std::unique_ptr<UniLib> unilib,
102 const std::string& triggering_preconditions_overlay);
103
104 ActionsSuggestionsResponse SuggestActions(
105 const Conversation& conversation,
106 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
107
108 ActionsSuggestionsResponse SuggestActions(
109 const Conversation& conversation, const Annotator* annotator,
110 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
111
112 const ActionsModel* model() const;
113 const reflection::Schema* entity_data_schema() const;
114
115 static constexpr int kLocalUserId = 0;
116
117 // Should be in sync with those defined in Android.
118 // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
119 static const std::string& kViewCalendarType;
120 static const std::string& kViewMapType;
121 static const std::string& kTrackFlightType;
122 static const std::string& kOpenUrlType;
123 static const std::string& kSendSmsType;
124 static const std::string& kCallPhoneType;
125 static const std::string& kSendEmailType;
126 static const std::string& kShareLocation;
127
128 protected:
129 // Exposed for testing.
130 bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
131
132 // Embeds the tokens per message separately. Each message is padded to the
133 // maximum length with the padding token.
134 bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
135 std::vector<float>* embeddings,
136 int* max_num_tokens_per_message) const;
137
138 // Concatenates the embedded message tokens - separated by start and end
139 // token between messages.
140 // If the total token count is greater than the maximum length, tokens at the
141 // start are dropped to fit into the limit.
142 // If the total token count is smaller than the minimum length, padding tokens
143 // are added to the end.
144 // Messages are assumed to be ordered by recency - most recent is last.
145 bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
146 std::vector<float>* embeddings,
147 int* total_token_count) const;
148
149 const ActionsModel* model_;
150
151 // Feature extractor and options.
152 std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
153 std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
154 std::vector<float> embedded_padding_token_;
155 std::vector<float> embedded_start_token_;
156 std::vector<float> embedded_end_token_;
157 int token_embedding_size_;
158
159 private:
160 // Checks that model contains all required fields, and initializes internal
161 // datastructures.
162 bool ValidateAndInitialize();
163
164 void SetOrCreateUnilib(const UniLib* unilib);
165
166 // Prepare preconditions.
167 // Takes values from flag provided data, but falls back to model provided
168 // values for parameters that are not explicitly provided.
169 bool InitializeTriggeringPreconditions();
170
171 // Tokenizes a conversation and produces the tokens per message.
172 std::vector<std::vector<Token>> Tokenize(
173 const std::vector<std::string>& context) const;
174
175 bool AllocateInput(const int conversation_length, const int max_tokens,
176 const int total_token_count,
177 tflite::Interpreter* interpreter) const;
178
179 bool SetupModelInput(const std::vector<std::string>& context,
180 const std::vector<int>& user_ids,
181 const std::vector<float>& time_diffs,
182 const int num_suggestions,
183 const ActionSuggestionOptions& options,
184 tflite::Interpreter* interpreter) const;
185
186 void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
187 ActionSuggestion* suggestion) const;
188
189 void PopulateTextReplies(const tflite::Interpreter* interpreter,
190 int suggestion_index, int score_index,
191 const std::string& type,
192 ActionsSuggestionsResponse* response) const;
193
194 void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
195 int suggestion_index, int score_index,
196 const ActionSuggestionSpec* task_spec,
197 ActionsSuggestionsResponse* response) const;
198
199 bool ReadModelOutput(tflite::Interpreter* interpreter,
200 const ActionSuggestionOptions& options,
201 ActionsSuggestionsResponse* response) const;
202
203 bool SuggestActionsFromModel(
204 const Conversation& conversation, const int num_messages,
205 const ActionSuggestionOptions& options,
206 ActionsSuggestionsResponse* response,
207 std::unique_ptr<tflite::Interpreter>* interpreter) const;
208
209 // Creates options for annotation of a message.
210 AnnotationOptions AnnotationOptionsForMessage(
211 const ConversationMessage& message) const;
212
213 void SuggestActionsFromAnnotations(
214 const Conversation& conversation,
215 std::vector<ActionSuggestion>* actions) const;
216
217 void SuggestActionsFromAnnotation(
218 const int message_index, const ActionSuggestionAnnotation& annotation,
219 std::vector<ActionSuggestion>* actions) const;
220
221 // Run annotator on the messages of a conversation.
222 Conversation AnnotateConversation(const Conversation& conversation,
223 const Annotator* annotator) const;
224
225 // Deduplicates equivalent annotations - annotations that have the same type
226 // and same span text.
227 // Returns the indices of the deduplicated annotations.
228 std::vector<int> DeduplicateAnnotations(
229 const std::vector<ActionSuggestionAnnotation>& annotations) const;
230
231 bool SuggestActionsFromLua(
232 const Conversation& conversation,
233 const TfLiteModelExecutor* model_executor,
234 const tflite::Interpreter* interpreter,
235 const reflection::Schema* annotation_entity_data_schema,
236 std::vector<ActionSuggestion>* actions) const;
237
238 bool GatherActionsSuggestions(const Conversation& conversation,
239 const Annotator* annotator,
240 const ActionSuggestionOptions& options,
241 ActionsSuggestionsResponse* response) const;
242
243 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
244
245 // Tensorflow Lite models.
246 std::unique_ptr<const TfLiteModelExecutor> model_executor_;
247
248 // Regex rules model.
249 std::unique_ptr<RegexActions> regex_actions_;
250
251 // The grammar rules model.
252 std::unique_ptr<GrammarActions> grammar_actions_;
253
254 std::unique_ptr<UniLib> owned_unilib_;
255 const UniLib* unilib_;
256
257 // Locales supported by the model.
258 std::vector<Locale> locales_;
259
260 // Annotation entities used by the model.
261 std::unordered_set<std::string> annotation_entity_types_;
262
263 // Builder for creating extra data.
264 const reflection::Schema* entity_data_schema_;
265 std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
266 std::unique_ptr<ActionsSuggestionsRanker> ranker_;
267
268 std::string lua_bytecode_;
269
270 // Triggering preconditions. These parameters can be backed by the model and
271 // (partially) be provided by flags.
272 TriggeringPreconditionsT preconditions_;
273 std::string triggering_preconditions_overlay_buffer_;
274 const TriggeringPreconditions* triggering_preconditions_overlay_;
275
276 // Low confidence input ngram classifier.
277 std::unique_ptr<const NGramModel> ngram_model_;
278 };
279
280 // Interprets the buffer as a Model flatbuffer and returns it for reading.
281 const ActionsModel* ViewActionsModel(const void* buffer, int size);
282
283 // Opens model from given path and runs a function, passing the loaded Model
284 // flatbuffer as an argument.
285 //
286 // This is mainly useful if we don't want to pay the cost for the model
287 // initialization because we'll be only reading some flatbuffer values from the
288 // file.
289 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)290 ReturnType VisitActionsModel(const std::string& path, Func function) {
291 ScopedMmap mmap(path);
292 if (!mmap.handle().ok()) {
293 function(/*model=*/nullptr);
294 }
295 const ActionsModel* model =
296 ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
297 return function(model);
298 }
299
300 } // namespace libtextclassifier3
301
302 #endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
303