/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46 #include "absl/container/flat_hash_set.h"
47 
48 namespace libtextclassifier3 {
49 
50 // Class for predicting actions following a conversation.
51 class ActionsSuggestions {
52  public:
53   // Creates ActionsSuggestions from given data buffer with model.
54   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
55       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
56       const std::string& triggering_preconditions_overlay = "");
57 
58   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
59   // ownership of it.
60   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
61       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
62       const UniLib* unilib = nullptr,
63       const std::string& triggering_preconditions_overlay = "");
64   // Same as above, but also takes ownership of the unilib.
65   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
66       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
67       std::unique_ptr<UniLib> unilib,
68       const std::string& triggering_preconditions_overlay);
69 
70   // Creates ActionsSuggestions from model given as a file descriptor, offset
71   // and size in it. If offset and size are less than 0, will ignore them and
72   // will just use the fd.
73   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
74       const int fd, const int offset, const int size,
75       const UniLib* unilib = nullptr,
76       const std::string& triggering_preconditions_overlay = "");
77   // Same as above, but also takes ownership of the unilib.
78   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
79       const int fd, const int offset, const int size,
80       std::unique_ptr<UniLib> unilib,
81       const std::string& triggering_preconditions_overlay = "");
82 
83   // Creates ActionsSuggestions from model given as a file descriptor.
84   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
85       const int fd, const UniLib* unilib = nullptr,
86       const std::string& triggering_preconditions_overlay = "");
87   // Same as above, but also takes ownership of the unilib.
88   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
89       const int fd, std::unique_ptr<UniLib> unilib,
90       const std::string& triggering_preconditions_overlay);
91 
92   // Creates ActionsSuggestions from model given as a POSIX path.
93   static std::unique_ptr<ActionsSuggestions> FromPath(
94       const std::string& path, const UniLib* unilib = nullptr,
95       const std::string& triggering_preconditions_overlay = "");
96   // Same as above, but also takes ownership of unilib.
97   static std::unique_ptr<ActionsSuggestions> FromPath(
98       const std::string& path, std::unique_ptr<UniLib> unilib,
99       const std::string& triggering_preconditions_overlay);
100 
101   ActionsSuggestionsResponse SuggestActions(
102       const Conversation& conversation,
103       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
104 
105   ActionsSuggestionsResponse SuggestActions(
106       const Conversation& conversation, const Annotator* annotator,
107       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
108 
109   bool InitializeConversationIntentDetection(
110       const std::string& serialized_config);
111 
112   const ActionsModel* model() const;
113   const reflection::Schema* entity_data_schema() const;
114 
115   static constexpr int kLocalUserId = 0;
116 
117  protected:
118   // Exposed for testing.
119   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
120 
121   // Embeds the tokens per message separately. Each message is padded to the
122   // maximum length with the padding token.
123   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
124                              std::vector<float>* embeddings,
125                              int* max_num_tokens_per_message) const;
126 
127   // Concatenates the embedded message tokens - separated by start and end
128   // token between messages.
129   // If the total token count is greater than the maximum length, tokens at the
130   // start are dropped to fit into the limit.
131   // If the total token count is smaller than the minimum length, padding tokens
132   // are added to the end.
133   // Messages are assumed to be ordered by recency - most recent is last.
134   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
135                              std::vector<float>* embeddings,
136                              int* total_token_count) const;
137 
138   const ActionsModel* model_;
139 
140   // Feature extractor and options.
141   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
142   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
143   std::vector<float> embedded_padding_token_;
144   std::vector<float> embedded_start_token_;
145   std::vector<float> embedded_end_token_;
146   int token_embedding_size_;
147 
148  private:
149   // Checks that model contains all required fields, and initializes internal
150   // datastructures.
151   bool ValidateAndInitialize();
152 
153   void SetOrCreateUnilib(const UniLib* unilib);
154 
155   // Prepare preconditions.
156   // Takes values from flag provided data, but falls back to model provided
157   // values for parameters that are not explicitly provided.
158   bool InitializeTriggeringPreconditions();
159 
160   // Tokenizes a conversation and produces the tokens per message.
161   std::vector<std::vector<Token>> Tokenize(
162       const std::vector<std::string>& context) const;
163 
164   bool AllocateInput(const int conversation_length, const int max_tokens,
165                      const int total_token_count,
166                      tflite::Interpreter* interpreter) const;
167 
168   bool SetupModelInput(const std::vector<std::string>& context,
169                        const std::vector<int>& user_ids,
170                        const std::vector<float>& time_diffs,
171                        const int num_suggestions,
172                        const ActionSuggestionOptions& options,
173                        tflite::Interpreter* interpreter) const;
174 
175   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
176                                             ActionSuggestion* suggestion) const;
177 
178   void PopulateTextReplies(const tflite::Interpreter* interpreter,
179                            int suggestion_index, int score_index,
180                            const std::string& type, float priority_score,
181                            const absl::flat_hash_set<std::string>& blocklist,
182                            ActionsSuggestionsResponse* response) const;
183 
184   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
185                                 int suggestion_index, int score_index,
186                                 const ActionSuggestionSpec* task_spec,
187                                 ActionsSuggestionsResponse* response) const;
188 
189   bool ReadModelOutput(tflite::Interpreter* interpreter,
190                        const ActionSuggestionOptions& options,
191                        ActionsSuggestionsResponse* response) const;
192 
193   bool SuggestActionsFromModel(
194       const Conversation& conversation, const int num_messages,
195       const ActionSuggestionOptions& options,
196       ActionsSuggestionsResponse* response,
197       std::unique_ptr<tflite::Interpreter>* interpreter) const;
198 
199   Status SuggestActionsFromConversationIntentDetection(
200       const Conversation& conversation, const ActionSuggestionOptions& options,
201       std::vector<ActionSuggestion>* actions) const;
202 
203   // Creates options for annotation of a message.
204   AnnotationOptions AnnotationOptionsForMessage(
205       const ConversationMessage& message) const;
206 
207   void SuggestActionsFromAnnotations(
208       const Conversation& conversation,
209       std::vector<ActionSuggestion>* actions) const;
210 
211   void SuggestActionsFromAnnotation(
212       const int message_index, const ActionSuggestionAnnotation& annotation,
213       std::vector<ActionSuggestion>* actions) const;
214 
215   // Run annotator on the messages of a conversation.
216   Conversation AnnotateConversation(const Conversation& conversation,
217                                     const Annotator* annotator) const;
218 
219   // Deduplicates equivalent annotations - annotations that have the same type
220   // and same span text.
221   // Returns the indices of the deduplicated annotations.
222   std::vector<int> DeduplicateAnnotations(
223       const std::vector<ActionSuggestionAnnotation>& annotations) const;
224 
225   bool SuggestActionsFromLua(
226       const Conversation& conversation,
227       const TfLiteModelExecutor* model_executor,
228       const tflite::Interpreter* interpreter,
229       const reflection::Schema* annotation_entity_data_schema,
230       std::vector<ActionSuggestion>* actions) const;
231 
232   bool GatherActionsSuggestions(const Conversation& conversation,
233                                 const Annotator* annotator,
234                                 const ActionSuggestionOptions& options,
235                                 ActionsSuggestionsResponse* response) const;
236 
237   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
238 
239   // Tensorflow Lite models.
240   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
241 
242   // Regex rules model.
243   std::unique_ptr<RegexActions> regex_actions_;
244 
245   // The grammar rules model.
246   std::unique_ptr<GrammarActions> grammar_actions_;
247 
248   std::unique_ptr<UniLib> owned_unilib_;
249   const UniLib* unilib_;
250 
251   // Locales supported by the model.
252   std::vector<Locale> locales_;
253 
254   // Annotation entities used by the model.
255   std::unordered_set<std::string> annotation_entity_types_;
256 
257   // Builder for creating extra data.
258   const reflection::Schema* entity_data_schema_;
259   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
260   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
261 
262   std::string lua_bytecode_;
263 
264   // Triggering preconditions. These parameters can be backed by the model and
265   // (partially) be provided by flags.
266   TriggeringPreconditionsT preconditions_;
267   std::string triggering_preconditions_overlay_buffer_;
268   const TriggeringPreconditions* triggering_preconditions_overlay_;
269 
270   // Low confidence input ngram classifier.
271   std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
272 
273   // Conversation intent detection model for additional actions.
274   std::unique_ptr<const ConversationIntentDetection>
275       conversation_intent_detection_;
276 };
277 
278 // Interprets the buffer as a Model flatbuffer and returns it for reading.
279 const ActionsModel* ViewActionsModel(const void* buffer, int size);
280 
281 // Opens model from given path and runs a function, passing the loaded Model
282 // flatbuffer as an argument.
283 //
284 // This is mainly useful if we don't want to pay the cost for the model
285 // initialization because we'll be only reading some flatbuffer values from the
286 // file.
287 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)288 ReturnType VisitActionsModel(const std::string& path, Func function) {
289   ScopedMmap mmap(path);
290   if (!mmap.handle().ok()) {
291     function(/*model=*/nullptr);
292   }
293   const ActionsModel* model =
294       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
295   return function(model);
296 }
297 
// String identifiers of the action types, exposed as static accessors.
//
// Every accessor returns a reference to a std::string that is allocated on
// the heap the first time the accessor runs and is never deleted, so the
// returned reference stays valid for the rest of the process (presumably to
// sidestep static destruction order problems).
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const value = new std::string("view_calendar");
    return *value;
  }
  static const std::string& ViewMap() {
    static const std::string* const value = new std::string("view_map");
    return *value;
  }
  static const std::string& TrackFlight() {
    static const std::string* const value = new std::string("track_flight");
    return *value;
  }
  static const std::string& OpenUrl() {
    static const std::string* const value = new std::string("open_url");
    return *value;
  }
  static const std::string& SendSms() {
    static const std::string* const value = new std::string("send_sms");
    return *value;
  }
  static const std::string& CallPhone() {
    static const std::string* const value = new std::string("call_phone");
    return *value;
  }
  static const std::string& SendEmail() {
    static const std::string* const value = new std::string("send_email");
    return *value;
  }
  static const std::string& ShareLocation() {
    static const std::string* const value = new std::string("share_location");
    return *value;
  }
  static const std::string& CreateReminder() {
    static const std::string* const value = new std::string("create_reminder");
    return *value;
  }
  static const std::string& TextReply() {
    static const std::string* const value = new std::string("text_reply");
    return *value;
  }
  static const std::string& AddContact() {
    static const std::string* const value = new std::string("add_contact");
    return *value;
  }
  static const std::string& Copy() {
    static const std::string* const value = new std::string("copy");
    return *value;
  }
};
363 
364 }  // namespace libtextclassifier3
365 
366 #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
367