• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46 
47 namespace libtextclassifier3 {
48 
49 // Class for predicting actions following a conversation.
50 class ActionsSuggestions {
51  public:
52   // Creates ActionsSuggestions from given data buffer with model.
53   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
54       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
55       const std::string& triggering_preconditions_overlay = "");
56 
57   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
58   // ownership of it.
59   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
60       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
61       const UniLib* unilib = nullptr,
62       const std::string& triggering_preconditions_overlay = "");
63   // Same as above, but also takes ownership of the unilib.
64   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
65       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
66       std::unique_ptr<UniLib> unilib,
67       const std::string& triggering_preconditions_overlay);
68 
69   // Creates ActionsSuggestions from model given as a file descriptor, offset
70   // and size in it. If offset and size are less than 0, will ignore them and
71   // will just use the fd.
72   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
73       const int fd, const int offset, const int size,
74       const UniLib* unilib = nullptr,
75       const std::string& triggering_preconditions_overlay = "");
76   // Same as above, but also takes ownership of the unilib.
77   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
78       const int fd, const int offset, const int size,
79       std::unique_ptr<UniLib> unilib,
80       const std::string& triggering_preconditions_overlay = "");
81 
82   // Creates ActionsSuggestions from model given as a file descriptor.
83   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
84       const int fd, const UniLib* unilib = nullptr,
85       const std::string& triggering_preconditions_overlay = "");
86   // Same as above, but also takes ownership of the unilib.
87   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
88       const int fd, std::unique_ptr<UniLib> unilib,
89       const std::string& triggering_preconditions_overlay);
90 
91   // Creates ActionsSuggestions from model given as a POSIX path.
92   static std::unique_ptr<ActionsSuggestions> FromPath(
93       const std::string& path, const UniLib* unilib = nullptr,
94       const std::string& triggering_preconditions_overlay = "");
95   // Same as above, but also takes ownership of unilib.
96   static std::unique_ptr<ActionsSuggestions> FromPath(
97       const std::string& path, std::unique_ptr<UniLib> unilib,
98       const std::string& triggering_preconditions_overlay);
99 
100   ActionsSuggestionsResponse SuggestActions(
101       const Conversation& conversation,
102       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
103 
104   ActionsSuggestionsResponse SuggestActions(
105       const Conversation& conversation, const Annotator* annotator,
106       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
107 
108   bool InitializeConversationIntentDetection(
109       const std::string& serialized_config);
110 
111   const ActionsModel* model() const;
112   const reflection::Schema* entity_data_schema() const;
113 
114   static constexpr int kLocalUserId = 0;
115 
116  protected:
117   // Exposed for testing.
118   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
119 
120   // Embeds the tokens per message separately. Each message is padded to the
121   // maximum length with the padding token.
122   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
123                              std::vector<float>* embeddings,
124                              int* max_num_tokens_per_message) const;
125 
126   // Concatenates the embedded message tokens - separated by start and end
127   // token between messages.
128   // If the total token count is greater than the maximum length, tokens at the
129   // start are dropped to fit into the limit.
130   // If the total token count is smaller than the minimum length, padding tokens
131   // are added to the end.
132   // Messages are assumed to be ordered by recency - most recent is last.
133   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
134                              std::vector<float>* embeddings,
135                              int* total_token_count) const;
136 
137   const ActionsModel* model_;
138 
139   // Feature extractor and options.
140   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
141   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
142   std::vector<float> embedded_padding_token_;
143   std::vector<float> embedded_start_token_;
144   std::vector<float> embedded_end_token_;
145   int token_embedding_size_;
146 
147  private:
148   // Checks that model contains all required fields, and initializes internal
149   // datastructures.
150   bool ValidateAndInitialize();
151 
152   void SetOrCreateUnilib(const UniLib* unilib);
153 
154   // Prepare preconditions.
155   // Takes values from flag provided data, but falls back to model provided
156   // values for parameters that are not explicitly provided.
157   bool InitializeTriggeringPreconditions();
158 
159   // Tokenizes a conversation and produces the tokens per message.
160   std::vector<std::vector<Token>> Tokenize(
161       const std::vector<std::string>& context) const;
162 
163   bool AllocateInput(const int conversation_length, const int max_tokens,
164                      const int total_token_count,
165                      tflite::Interpreter* interpreter) const;
166 
167   bool SetupModelInput(const std::vector<std::string>& context,
168                        const std::vector<int>& user_ids,
169                        const std::vector<float>& time_diffs,
170                        const int num_suggestions,
171                        const ActionSuggestionOptions& options,
172                        tflite::Interpreter* interpreter) const;
173 
174   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
175                                             ActionSuggestion* suggestion) const;
176 
177   void PopulateTextReplies(const tflite::Interpreter* interpreter,
178                            int suggestion_index, int score_index,
179                            const std::string& type,
180                            ActionsSuggestionsResponse* response) const;
181 
182   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
183                                 int suggestion_index, int score_index,
184                                 const ActionSuggestionSpec* task_spec,
185                                 ActionsSuggestionsResponse* response) const;
186 
187   bool ReadModelOutput(tflite::Interpreter* interpreter,
188                        const ActionSuggestionOptions& options,
189                        ActionsSuggestionsResponse* response) const;
190 
191   bool SuggestActionsFromModel(
192       const Conversation& conversation, const int num_messages,
193       const ActionSuggestionOptions& options,
194       ActionsSuggestionsResponse* response,
195       std::unique_ptr<tflite::Interpreter>* interpreter) const;
196 
197   Status SuggestActionsFromConversationIntentDetection(
198       const Conversation& conversation, const ActionSuggestionOptions& options,
199       std::vector<ActionSuggestion>* actions) const;
200 
201   // Creates options for annotation of a message.
202   AnnotationOptions AnnotationOptionsForMessage(
203       const ConversationMessage& message) const;
204 
205   void SuggestActionsFromAnnotations(
206       const Conversation& conversation,
207       std::vector<ActionSuggestion>* actions) const;
208 
209   void SuggestActionsFromAnnotation(
210       const int message_index, const ActionSuggestionAnnotation& annotation,
211       std::vector<ActionSuggestion>* actions) const;
212 
213   // Run annotator on the messages of a conversation.
214   Conversation AnnotateConversation(const Conversation& conversation,
215                                     const Annotator* annotator) const;
216 
217   // Deduplicates equivalent annotations - annotations that have the same type
218   // and same span text.
219   // Returns the indices of the deduplicated annotations.
220   std::vector<int> DeduplicateAnnotations(
221       const std::vector<ActionSuggestionAnnotation>& annotations) const;
222 
223   bool SuggestActionsFromLua(
224       const Conversation& conversation,
225       const TfLiteModelExecutor* model_executor,
226       const tflite::Interpreter* interpreter,
227       const reflection::Schema* annotation_entity_data_schema,
228       std::vector<ActionSuggestion>* actions) const;
229 
230   bool GatherActionsSuggestions(const Conversation& conversation,
231                                 const Annotator* annotator,
232                                 const ActionSuggestionOptions& options,
233                                 ActionsSuggestionsResponse* response) const;
234 
235   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
236 
237   // Tensorflow Lite models.
238   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
239 
240   // Regex rules model.
241   std::unique_ptr<RegexActions> regex_actions_;
242 
243   // The grammar rules model.
244   std::unique_ptr<GrammarActions> grammar_actions_;
245 
246   std::unique_ptr<UniLib> owned_unilib_;
247   const UniLib* unilib_;
248 
249   // Locales supported by the model.
250   std::vector<Locale> locales_;
251 
252   // Annotation entities used by the model.
253   std::unordered_set<std::string> annotation_entity_types_;
254 
255   // Builder for creating extra data.
256   const reflection::Schema* entity_data_schema_;
257   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
258   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
259 
260   std::string lua_bytecode_;
261 
262   // Triggering preconditions. These parameters can be backed by the model and
263   // (partially) be provided by flags.
264   TriggeringPreconditionsT preconditions_;
265   std::string triggering_preconditions_overlay_buffer_;
266   const TriggeringPreconditions* triggering_preconditions_overlay_;
267 
268   // Low confidence input ngram classifier.
269   std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
270 
271   // Conversation intent detection model for additional actions.
272   std::unique_ptr<const ConversationIntentDetection>
273       conversation_intent_detection_;
274 };
275 
276 // Interprets the buffer as a Model flatbuffer and returns it for reading.
277 const ActionsModel* ViewActionsModel(const void* buffer, int size);
278 
279 // Opens model from given path and runs a function, passing the loaded Model
280 // flatbuffer as an argument.
281 //
282 // This is mainly useful if we don't want to pay the cost for the model
283 // initialization because we'll be only reading some flatbuffer values from the
284 // file.
285 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)286 ReturnType VisitActionsModel(const std::string& path, Func function) {
287   ScopedMmap mmap(path);
288   if (!mmap.handle().ok()) {
289     function(/*model=*/nullptr);
290   }
291   const ActionsModel* model =
292       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
293   return function(model);
294 }
295 
// Names of the action types produced by ActionsSuggestions.
//
// Each accessor returns a reference to a string that is allocated once, on
// first use, and never freed. It therefore remains valid for the whole
// program, including during static destruction. Repeated calls return the
// same object.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kValue = new std::string("view_calendar");
    return *kValue;
  }
  static const std::string& ViewMap() {
    static const std::string* const kValue = new std::string("view_map");
    return *kValue;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kValue = new std::string("track_flight");
    return *kValue;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kValue = new std::string("open_url");
    return *kValue;
  }
  static const std::string& SendSms() {
    static const std::string* const kValue = new std::string("send_sms");
    return *kValue;
  }
  static const std::string& CallPhone() {
    static const std::string* const kValue = new std::string("call_phone");
    return *kValue;
  }
  static const std::string& SendEmail() {
    static const std::string* const kValue = new std::string("send_email");
    return *kValue;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kValue = new std::string("share_location");
    return *kValue;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kValue =
        new std::string("create_reminder");
    return *kValue;
  }
  static const std::string& TextReply() {
    static const std::string* const kValue = new std::string("text_reply");
    return *kValue;
  }
  static const std::string& AddContact() {
    static const std::string* const kValue = new std::string("add_contact");
    return *kValue;
  }
  static const std::string& Copy() {
    static const std::string* const kValue = new std::string("copy");
    return *kValue;
  }
};
361 
362 }  // namespace libtextclassifier3
363 
364 #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
365