• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "actions/actions_model_generated.h"
28 #include "actions/feature-processor.h"
29 #include "actions/grammar-actions.h"
30 #include "actions/ngram-model.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/types.h"
34 #include "annotator/annotator.h"
35 #include "annotator/model-executor.h"
36 #include "annotator/types.h"
37 #include "utils/flatbuffers.h"
38 #include "utils/i18n/locale.h"
39 #include "utils/memory/mmap.h"
40 #include "utils/tflite-model-executor.h"
41 #include "utils/utf8/unilib.h"
42 #include "utils/variant.h"
43 #include "utils/zlib/zlib.h"
44 
45 namespace libtextclassifier3 {
46 
47 // Options for suggesting actions.
48 struct ActionSuggestionOptions {
DefaultActionSuggestionOptions49   static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
50   std::unordered_map<std::string, Variant> model_parameters;
51 };
52 
53 // Class for predicting actions following a conversation.
54 class ActionsSuggestions {
55  public:
56   // Creates ActionsSuggestions from given data buffer with model.
57   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
58       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
59       const std::string& triggering_preconditions_overlay = "");
60 
61   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
62   // ownership of it.
63   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
64       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
65       const UniLib* unilib = nullptr,
66       const std::string& triggering_preconditions_overlay = "");
67   // Same as above, but also takes ownership of the unilib.
68   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
69       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
70       std::unique_ptr<UniLib> unilib,
71       const std::string& triggering_preconditions_overlay);
72 
73   // Creates ActionsSuggestions from model given as a file descriptor, offset
74   // and size in it. If offset and size are less than 0, will ignore them and
75   // will just use the fd.
76   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
77       const int fd, const int offset, const int size,
78       const UniLib* unilib = nullptr,
79       const std::string& triggering_preconditions_overlay = "");
80   // Same as above, but also takes ownership of the unilib.
81   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
82       const int fd, const int offset, const int size,
83       std::unique_ptr<UniLib> unilib,
84       const std::string& triggering_preconditions_overlay = "");
85 
86   // Creates ActionsSuggestions from model given as a file descriptor.
87   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
88       const int fd, const UniLib* unilib = nullptr,
89       const std::string& triggering_preconditions_overlay = "");
90   // Same as above, but also takes ownership of the unilib.
91   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
92       const int fd, std::unique_ptr<UniLib> unilib,
93       const std::string& triggering_preconditions_overlay);
94 
95   // Creates ActionsSuggestions from model given as a POSIX path.
96   static std::unique_ptr<ActionsSuggestions> FromPath(
97       const std::string& path, const UniLib* unilib = nullptr,
98       const std::string& triggering_preconditions_overlay = "");
99   // Same as above, but also takes ownership of unilib.
100   static std::unique_ptr<ActionsSuggestions> FromPath(
101       const std::string& path, std::unique_ptr<UniLib> unilib,
102       const std::string& triggering_preconditions_overlay);
103 
104   ActionsSuggestionsResponse SuggestActions(
105       const Conversation& conversation,
106       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
107 
108   ActionsSuggestionsResponse SuggestActions(
109       const Conversation& conversation, const Annotator* annotator,
110       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
111 
112   const ActionsModel* model() const;
113   const reflection::Schema* entity_data_schema() const;
114 
115   static constexpr int kLocalUserId = 0;
116 
117   // Should be in sync with those defined in Android.
118   // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
119   static const std::string& kViewCalendarType;
120   static const std::string& kViewMapType;
121   static const std::string& kTrackFlightType;
122   static const std::string& kOpenUrlType;
123   static const std::string& kSendSmsType;
124   static const std::string& kCallPhoneType;
125   static const std::string& kSendEmailType;
126   static const std::string& kShareLocation;
127 
128  protected:
129   // Exposed for testing.
130   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
131 
132   // Embeds the tokens per message separately. Each message is padded to the
133   // maximum length with the padding token.
134   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
135                              std::vector<float>* embeddings,
136                              int* max_num_tokens_per_message) const;
137 
138   // Concatenates the embedded message tokens - separated by start and end
139   // token between messages.
140   // If the total token count is greater than the maximum length, tokens at the
141   // start are dropped to fit into the limit.
142   // If the total token count is smaller than the minimum length, padding tokens
143   // are added to the end.
144   // Messages are assumed to be ordered by recency - most recent is last.
145   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
146                              std::vector<float>* embeddings,
147                              int* total_token_count) const;
148 
149   const ActionsModel* model_;
150 
151   // Feature extractor and options.
152   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
153   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
154   std::vector<float> embedded_padding_token_;
155   std::vector<float> embedded_start_token_;
156   std::vector<float> embedded_end_token_;
157   int token_embedding_size_;
158 
159  private:
160   // Checks that model contains all required fields, and initializes internal
161   // datastructures.
162   bool ValidateAndInitialize();
163 
164   void SetOrCreateUnilib(const UniLib* unilib);
165 
166   // Prepare preconditions.
167   // Takes values from flag provided data, but falls back to model provided
168   // values for parameters that are not explicitly provided.
169   bool InitializeTriggeringPreconditions();
170 
171   // Tokenizes a conversation and produces the tokens per message.
172   std::vector<std::vector<Token>> Tokenize(
173       const std::vector<std::string>& context) const;
174 
175   bool AllocateInput(const int conversation_length, const int max_tokens,
176                      const int total_token_count,
177                      tflite::Interpreter* interpreter) const;
178 
179   bool SetupModelInput(const std::vector<std::string>& context,
180                        const std::vector<int>& user_ids,
181                        const std::vector<float>& time_diffs,
182                        const int num_suggestions,
183                        const ActionSuggestionOptions& options,
184                        tflite::Interpreter* interpreter) const;
185 
186   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
187                                             ActionSuggestion* suggestion) const;
188 
189   void PopulateTextReplies(const tflite::Interpreter* interpreter,
190                            int suggestion_index, int score_index,
191                            const std::string& type,
192                            ActionsSuggestionsResponse* response) const;
193 
194   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
195                                 int suggestion_index, int score_index,
196                                 const ActionSuggestionSpec* task_spec,
197                                 ActionsSuggestionsResponse* response) const;
198 
199   bool ReadModelOutput(tflite::Interpreter* interpreter,
200                        const ActionSuggestionOptions& options,
201                        ActionsSuggestionsResponse* response) const;
202 
203   bool SuggestActionsFromModel(
204       const Conversation& conversation, const int num_messages,
205       const ActionSuggestionOptions& options,
206       ActionsSuggestionsResponse* response,
207       std::unique_ptr<tflite::Interpreter>* interpreter) const;
208 
209   // Creates options for annotation of a message.
210   AnnotationOptions AnnotationOptionsForMessage(
211       const ConversationMessage& message) const;
212 
213   void SuggestActionsFromAnnotations(
214       const Conversation& conversation,
215       std::vector<ActionSuggestion>* actions) const;
216 
217   void SuggestActionsFromAnnotation(
218       const int message_index, const ActionSuggestionAnnotation& annotation,
219       std::vector<ActionSuggestion>* actions) const;
220 
221   // Run annotator on the messages of a conversation.
222   Conversation AnnotateConversation(const Conversation& conversation,
223                                     const Annotator* annotator) const;
224 
225   // Deduplicates equivalent annotations - annotations that have the same type
226   // and same span text.
227   // Returns the indices of the deduplicated annotations.
228   std::vector<int> DeduplicateAnnotations(
229       const std::vector<ActionSuggestionAnnotation>& annotations) const;
230 
231   bool SuggestActionsFromLua(
232       const Conversation& conversation,
233       const TfLiteModelExecutor* model_executor,
234       const tflite::Interpreter* interpreter,
235       const reflection::Schema* annotation_entity_data_schema,
236       std::vector<ActionSuggestion>* actions) const;
237 
238   bool GatherActionsSuggestions(const Conversation& conversation,
239                                 const Annotator* annotator,
240                                 const ActionSuggestionOptions& options,
241                                 ActionsSuggestionsResponse* response) const;
242 
243   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
244 
245   // Tensorflow Lite models.
246   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
247 
248   // Regex rules model.
249   std::unique_ptr<RegexActions> regex_actions_;
250 
251   // The grammar rules model.
252   std::unique_ptr<GrammarActions> grammar_actions_;
253 
254   std::unique_ptr<UniLib> owned_unilib_;
255   const UniLib* unilib_;
256 
257   // Locales supported by the model.
258   std::vector<Locale> locales_;
259 
260   // Annotation entities used by the model.
261   std::unordered_set<std::string> annotation_entity_types_;
262 
263   // Builder for creating extra data.
264   const reflection::Schema* entity_data_schema_;
265   std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
266   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
267 
268   std::string lua_bytecode_;
269 
270   // Triggering preconditions. These parameters can be backed by the model and
271   // (partially) be provided by flags.
272   TriggeringPreconditionsT preconditions_;
273   std::string triggering_preconditions_overlay_buffer_;
274   const TriggeringPreconditions* triggering_preconditions_overlay_;
275 
276   // Low confidence input ngram classifier.
277   std::unique_ptr<const NGramModel> ngram_model_;
278 };
279 
280 // Interprets the buffer as a Model flatbuffer and returns it for reading.
281 const ActionsModel* ViewActionsModel(const void* buffer, int size);
282 
283 // Opens model from given path and runs a function, passing the loaded Model
284 // flatbuffer as an argument.
285 //
286 // This is mainly useful if we don't want to pay the cost for the model
287 // initialization because we'll be only reading some flatbuffer values from the
288 // file.
289 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)290 ReturnType VisitActionsModel(const std::string& path, Func function) {
291   ScopedMmap mmap(path);
292   if (!mmap.handle().ok()) {
293     function(/*model=*/nullptr);
294   }
295   const ActionsModel* model =
296       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
297   return function(model);
298 }
299 
300 }  // namespace libtextclassifier3
301 
302 #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
303