• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "actions/actions_model_generated.h"
27 #include "actions/feature-processor.h"
28 #include "actions/ngram-model.h"
29 #include "actions/ranker.h"
30 #include "actions/types.h"
31 #include "annotator/annotator.h"
32 #include "annotator/model-executor.h"
33 #include "annotator/types.h"
34 #include "utils/flatbuffers.h"
35 #include "utils/i18n/locale.h"
36 #include "utils/memory/mmap.h"
37 #include "utils/tflite-model-executor.h"
38 #include "utils/utf8/unilib.h"
39 #include "utils/variant.h"
40 #include "utils/zlib/zlib.h"
41 
42 namespace libtextclassifier3 {
43 
// Options controlling how actions are suggested.
//
// The struct currently declares no fields; Default() simply yields a
// default-constructed instance.
struct ActionSuggestionOptions {
  // Returns a default-constructed set of options.
  static ActionSuggestionOptions Default() {
    ActionSuggestionOptions defaults;
    return defaults;
  }
};
48 
49 // Class for predicting actions following a conversation.
50 class ActionsSuggestions {
51  public:
52   // Creates ActionsSuggestions from given data buffer with model.
53   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
54       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
55       const std::string& triggering_preconditions_overlay = "");
56 
57   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
58   // ownership of it.
59   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
60       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
61       const UniLib* unilib = nullptr,
62       const std::string& triggering_preconditions_overlay = "");
63   // Same as above, but also takes ownership of the unilib.
64   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
65       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
66       std::unique_ptr<UniLib> unilib,
67       const std::string& triggering_preconditions_overlay);
68 
69   // Creates ActionsSuggestions from model given as a file descriptor, offset
70   // and size in it. If offset and size are less than 0, will ignore them and
71   // will just use the fd.
72   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
73       const int fd, const int offset, const int size,
74       const UniLib* unilib = nullptr,
75       const std::string& triggering_preconditions_overlay = "");
76   // Same as above, but also takes ownership of the unilib.
77   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
78       const int fd, const int offset, const int size,
79       std::unique_ptr<UniLib> unilib,
80       const std::string& triggering_preconditions_overlay = "");
81 
82   // Creates ActionsSuggestions from model given as a file descriptor.
83   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
84       const int fd, const UniLib* unilib = nullptr,
85       const std::string& triggering_preconditions_overlay = "");
86   // Same as above, but also takes ownership of the unilib.
87   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
88       const int fd, std::unique_ptr<UniLib> unilib,
89       const std::string& triggering_preconditions_overlay);
90 
91   // Creates ActionsSuggestions from model given as a POSIX path.
92   static std::unique_ptr<ActionsSuggestions> FromPath(
93       const std::string& path, const UniLib* unilib = nullptr,
94       const std::string& triggering_preconditions_overlay = "");
95   // Same as above, but also takes ownership of unilib.
96   static std::unique_ptr<ActionsSuggestions> FromPath(
97       const std::string& path, std::unique_ptr<UniLib> unilib,
98       const std::string& triggering_preconditions_overlay);
99 
100   ActionsSuggestionsResponse SuggestActions(
101       const Conversation& conversation,
102       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
103 
104   ActionsSuggestionsResponse SuggestActions(
105       const Conversation& conversation, const Annotator* annotator,
106       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
107 
108   const ActionsModel* model() const;
109   const reflection::Schema* entity_data_schema() const;
110 
111   static const int kLocalUserId = 0;
112 
113   // Should be in sync with those defined in Android.
114   // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
115   static const std::string& kViewCalendarType;
116   static const std::string& kViewMapType;
117   static const std::string& kTrackFlightType;
118   static const std::string& kOpenUrlType;
119   static const std::string& kSendSmsType;
120   static const std::string& kCallPhoneType;
121   static const std::string& kSendEmailType;
122   static const std::string& kShareLocation;
123 
124  protected:
125   // Exposed for testing.
126   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
127 
128   // Embeds the tokens per message separately. Each message is padded to the
129   // maximum length with the padding token.
130   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
131                              std::vector<float>* embeddings,
132                              int* max_num_tokens_per_message) const;
133 
134   // Concatenates the embedded message tokens - separated by start and end
135   // token between messages.
136   // If the total token count is greater than the maximum length, tokens at the
137   // start are dropped to fit into the limit.
138   // If the total token count is smaller than the minimum length, padding tokens
139   // are added to the end.
140   // Messages are assumed to be ordered by recency - most recent is last.
141   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>> tokens,
142                              std::vector<float>* embeddings,
143                              int* total_token_count) const;
144 
145   const ActionsModel* model_;
146 
147   // Feature extractor and options.
148   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
149   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
150   std::vector<float> embedded_padding_token_;
151   std::vector<float> embedded_start_token_;
152   std::vector<float> embedded_end_token_;
153   int token_embedding_size_;
154 
155  private:
156   struct CompiledRule {
157     const RulesModel_::Rule* rule;
158     std::unique_ptr<UniLib::RegexPattern> pattern;
159     std::unique_ptr<UniLib::RegexPattern> output_pattern;
CompiledRuleCompiledRule160     CompiledRule(const RulesModel_::Rule* rule,
161                  std::unique_ptr<UniLib::RegexPattern> pattern,
162                  std::unique_ptr<UniLib::RegexPattern> output_pattern)
163         : rule(rule),
164           pattern(std::move(pattern)),
165           output_pattern(std::move(output_pattern)) {}
166   };
167 
168   // Checks that model contains all required fields, and initializes internal
169   // datastructures.
170   bool ValidateAndInitialize();
171 
172   void SetOrCreateUnilib(const UniLib* unilib);
173 
174   // Initializes regular expression rules.
175   bool InitializeRules(ZlibDecompressor* decompressor);
176   bool InitializeRules(ZlibDecompressor* decompressor, const RulesModel* rules,
177                        std::vector<CompiledRule>* compiled_rules) const;
178 
179   // Prepare preconditions.
180   // Takes values from flag provided data, but falls back to model provided
181   // values for parameters that are not explicitly provided.
182   bool InitializeTriggeringPreconditions();
183 
184   // Tokenizes a conversation and produces the tokens per message.
185   std::vector<std::vector<Token>> Tokenize(
186       const std::vector<std::string>& context) const;
187 
188   bool AllocateInput(const int conversation_length, const int max_tokens,
189                      const int total_token_count,
190                      tflite::Interpreter* interpreter) const;
191 
192   bool SetupModelInput(const std::vector<std::string>& context,
193                        const std::vector<int>& user_ids,
194                        const std::vector<float>& time_diffs,
195                        const int num_suggestions,
196                        const float confidence_threshold,
197                        const float diversification_distance,
198                        const float empirical_probability_factor,
199                        tflite::Interpreter* interpreter) const;
200   bool ReadModelOutput(tflite::Interpreter* interpreter,
201                        const ActionSuggestionOptions& options,
202                        ActionsSuggestionsResponse* response) const;
203 
204   bool SuggestActionsFromModel(
205       const Conversation& conversation, const int num_messages,
206       const ActionSuggestionOptions& options,
207       ActionsSuggestionsResponse* response,
208       std::unique_ptr<tflite::Interpreter>* interpreter) const;
209 
210   // Creates options for annotation of a message.
211   AnnotationOptions AnnotationOptionsForMessage(
212       const ConversationMessage& message) const;
213 
214   void SuggestActionsFromAnnotations(
215       const Conversation& conversation, const ActionSuggestionOptions& options,
216       const Annotator* annotator, std::vector<ActionSuggestion>* actions) const;
217 
218   void SuggestActionsFromAnnotation(
219       const int message_index, const ActionSuggestionAnnotation& annotation,
220       std::vector<ActionSuggestion>* actions) const;
221 
222   // Deduplicates equivalent annotations - annotations that have the same type
223   // and same span text.
224   // Returns the indices of the deduplicated annotations.
225   std::vector<int> DeduplicateAnnotations(
226       const std::vector<ActionSuggestionAnnotation>& annotations) const;
227 
228   bool SuggestActionsFromRules(const Conversation& conversation,
229                                std::vector<ActionSuggestion>* actions) const;
230 
231   bool SuggestActionsFromLua(
232       const Conversation& conversation,
233       const TfLiteModelExecutor* model_executor,
234       const tflite::Interpreter* interpreter,
235       const reflection::Schema* annotation_entity_data_schema,
236       std::vector<ActionSuggestion>* actions) const;
237 
238   bool GatherActionsSuggestions(const Conversation& conversation,
239                                 const Annotator* annotator,
240                                 const ActionSuggestionOptions& options,
241                                 ActionsSuggestionsResponse* response) const;
242 
243   // Checks whether the input triggers the low confidence checks.
244   bool IsLowConfidenceInput(const Conversation& conversation,
245                             const int num_messages,
246                             std::vector<int>* post_check_rules) const;
247   // Checks and filters suggestions triggering the low confidence post checks.
248   bool FilterConfidenceOutput(const std::vector<int>& post_check_rules,
249                               std::vector<ActionSuggestion>* actions) const;
250 
251   ActionSuggestion SuggestionFromSpec(
252       const ActionSuggestionSpec* action, const std::string& default_type = "",
253       const std::string& default_response_text = "",
254       const std::string& default_serialized_entity_data = "",
255       const float default_score = 0.0f,
256       const float default_priority_score = 0.0f) const;
257 
258   bool FillAnnotationFromMatchGroup(
259       const UniLib::RegexMatcher* matcher,
260       const RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroup* group,
261       const int message_index, ActionSuggestionAnnotation* annotation) const;
262 
263   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
264 
265   // Tensorflow Lite models.
266   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
267 
268   // Rules.
269   std::vector<CompiledRule> rules_, low_confidence_rules_;
270 
271   std::unique_ptr<UniLib> owned_unilib_;
272   const UniLib* unilib_;
273 
274   // Locales supported by the model.
275   std::vector<Locale> locales_;
276 
277   // Annotation entities used by the model.
278   std::unordered_set<std::string> annotation_entity_types_;
279 
280   // Builder for creating extra data.
281   const reflection::Schema* entity_data_schema_;
282   std::unique_ptr<ReflectiveFlatbufferBuilder> entity_data_builder_;
283   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
284 
285   std::string lua_bytecode_;
286 
287   // Triggering preconditions. These parameters can be backed by the model and
288   // (partially) be provided by flags.
289   TriggeringPreconditionsT preconditions_;
290   std::string triggering_preconditions_overlay_buffer_;
291   const TriggeringPreconditions* triggering_preconditions_overlay_;
292 
293   // Low confidence input ngram classifier.
294   std::unique_ptr<const NGramModel> ngram_model_;
295 };
296 
297 // Interprets the buffer as a Model flatbuffer and returns it for reading.
298 const ActionsModel* ViewActionsModel(const void* buffer, int size);
299 
300 // Opens model from given path and runs a function, passing the loaded Model
301 // flatbuffer as an argument.
302 //
303 // This is mainly useful if we don't want to pay the cost for the model
304 // initialization because we'll be only reading some flatbuffer values from the
305 // file.
306 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)307 ReturnType VisitActionsModel(const std::string& path, Func function) {
308   ScopedMmap mmap(path);
309   if (!mmap.handle().ok()) {
310     function(/*model=*/nullptr);
311   }
312   const ActionsModel* model =
313       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
314   return function(model);
315 }
316 
317 }  // namespace libtextclassifier3
318 
319 #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
320