• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "actions/lua-ranker.h"
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/lua-utils.h"
28 
29 namespace libtextclassifier3 {
30 namespace {
31 
SortByScoreAndType(std::vector<ActionSuggestion> * actions)32 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
33   std::sort(actions->begin(), actions->end(),
34             [](const ActionSuggestion& a, const ActionSuggestion& b) {
35               return a.score > b.score ||
36                      (a.score >= b.score && a.type < b.type);
37             });
38 }
39 
// Generic three-way comparison.
// Returns -1 if `left` orders before `right`, 1 if it orders after, and 0 if
// neither is smaller. T must provide operator< and operator>.
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
50 
51 template <>
Compare(const std::string & left,const std::string & right)52 int Compare(const std::string& left, const std::string& right) {
53   return left.compare(right);
54 }
55 
56 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)57 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
58   if (const int value = Compare(span.message_index, other.message_index)) {
59     return value;
60   }
61   if (const int value = Compare(span.span.first, other.span.first)) {
62     return value;
63   }
64   if (const int value = Compare(span.span.second, other.span.second)) {
65     return value;
66   }
67   return 0;
68 }
69 
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)70 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
71   return Compare(span, other) == 0;
72 }
73 
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)74 bool TextSpansIntersect(const MessageTextSpan& span,
75                         const MessageTextSpan& other) {
76   return span.message_index == other.message_index &&
77          SpansOverlap(span.span, other.span);
78 }
79 
80 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)81 int Compare(const ActionSuggestionAnnotation& annotation,
82             const ActionSuggestionAnnotation& other) {
83   if (const int value = Compare(annotation.span, other.span)) {
84     return value;
85   }
86   if (const int value = Compare(annotation.name, other.name)) {
87     return value;
88   }
89   if (const int value =
90           Compare(annotation.entity.collection, other.entity.collection)) {
91     return value;
92   }
93   return 0;
94 }
95 
96 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)97 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
98                                   const ActionSuggestionAnnotation& other) {
99   return Compare(annotation, other) == 0;
100 }
101 
102 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)103 int CompareAnnotationsOnly(const ActionSuggestion& action,
104                            const ActionSuggestion& other) {
105   if (const int value =
106           Compare(action.annotations.size(), other.annotations.size())) {
107     return value;
108   }
109   for (int i = 0; i < action.annotations.size(); i++) {
110     if (const int value =
111             Compare(action.annotations[i], other.annotations[i])) {
112       return value;
113     }
114   }
115   return 0;
116 }
117 
118 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)119 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
120                                const ActionSuggestion& other) {
121   return CompareAnnotationsOnly(action, other) == 0;
122 }
123 
124 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)125 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
126   if (const int value = Compare(action.type, other.type)) {
127     return value;
128   }
129   if (const int value = Compare(action.response_text, other.response_text)) {
130     return value;
131   }
132   if (const int value = Compare(action.serialized_entity_data,
133                                 other.serialized_entity_data)) {
134     return value;
135   }
136   return CompareAnnotationsOnly(action, other);
137 }
138 
139 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)140 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
141                                   const ActionSuggestion& other) {
142   return Compare(action, other) == 0;
143 }
144 
145 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)146 bool IsAnyActionEquivalent(const ActionSuggestion& action,
147                            const std::vector<ActionSuggestion>& actions) {
148   for (const ActionSuggestion& other : actions) {
149     if (IsEquivalentActionSuggestion(action, other)) {
150       return true;
151     }
152   }
153   return false;
154 }
155 
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)156 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
157                    const ActionSuggestionAnnotation& other) {
158   // Two annotations are conflicting if they are different but refer to
159   // overlapping spans in the conversation.
160   return (!IsEquivalentActionAnnotation(annotation, other) &&
161           TextSpansIntersect(annotation.span, other.span));
162 }
163 
164 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)165 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
166                                    const ActionSuggestion& other) {
167   // Actions are considered conflicting, iff they refer to the same text span,
168   // but were not generated from the same annotation.
169   if (action.annotations.empty() || other.annotations.empty()) {
170     return false;
171   }
172   for (const ActionSuggestionAnnotation& annotation : action.annotations) {
173     for (const ActionSuggestionAnnotation& other_annotation :
174          other.annotations) {
175       if (IsConflicting(annotation, other_annotation)) {
176         return true;
177       }
178     }
179   }
180   return false;
181 }
182 
183 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)184 bool IsAnyActionConflicting(const ActionSuggestion& action,
185                             const std::vector<ActionSuggestion>& actions) {
186   for (const ActionSuggestion& other : actions) {
187     if (IsConflictingActionSuggestion(action, other)) {
188       return true;
189     }
190   }
191   return false;
192 }
193 
194 }  // namespace
195 
196 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)197 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
198     const RankingOptions* options, ZlibDecompressor* decompressor,
199     const std::string& smart_reply_action_type) {
200   auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
201       new ActionsSuggestionsRanker(options, smart_reply_action_type));
202 
203   if (!ranker->InitializeAndValidate(decompressor)) {
204     TC3_LOG(ERROR) << "Could not initialize action ranker.";
205     return nullptr;
206   }
207 
208   return ranker;
209 }
210 
InitializeAndValidate(ZlibDecompressor * decompressor)211 bool ActionsSuggestionsRanker::InitializeAndValidate(
212     ZlibDecompressor* decompressor) {
213   if (options_ == nullptr) {
214     TC3_LOG(ERROR) << "No ranking options specified.";
215     return false;
216   }
217 
218   std::string lua_ranking_script;
219   if (GetUncompressedString(options_->lua_ranking_script(),
220                             options_->compressed_lua_ranking_script(),
221                             decompressor, &lua_ranking_script) &&
222       !lua_ranking_script.empty()) {
223     if (!Compile(lua_ranking_script, &lua_bytecode_)) {
224       TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
225       return false;
226     }
227   }
228 
229   return true;
230 }
231 
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const232 bool ActionsSuggestionsRanker::RankActions(
233     const Conversation& conversation, ActionsSuggestionsResponse* response,
234     const reflection::Schema* entity_data_schema,
235     const reflection::Schema* annotations_entity_data_schema) const {
236   if (options_->deduplicate_suggestions() ||
237       options_->deduplicate_suggestions_by_span()) {
238     // First order suggestions by priority score for deduplication.
239     std::sort(
240         response->actions.begin(), response->actions.end(),
241         [](const ActionSuggestion& a, const ActionSuggestion& b) {
242           return a.priority_score > b.priority_score ||
243                  (a.priority_score >= b.priority_score && a.score > b.score);
244         });
245 
246     // Deduplicate, keeping the higher score actions.
247     if (options_->deduplicate_suggestions()) {
248       std::vector<ActionSuggestion> deduplicated_actions;
249       for (const ActionSuggestion& candidate : response->actions) {
250         // Check whether we already have an equivalent action.
251         if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
252           deduplicated_actions.push_back(std::move(candidate));
253         }
254       }
255       response->actions = std::move(deduplicated_actions);
256     }
257 
258     // Resolve conflicts between conflicting actions referring to the same
259     // text span.
260     if (options_->deduplicate_suggestions_by_span()) {
261       std::vector<ActionSuggestion> deduplicated_actions;
262       for (const ActionSuggestion& candidate : response->actions) {
263         // Check whether we already have a conflicting action.
264         if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
265           deduplicated_actions.push_back(std::move(candidate));
266         }
267       }
268       response->actions = std::move(deduplicated_actions);
269     }
270   }
271 
272   // Suppress smart replies if actions are present.
273   if (options_->suppress_smart_replies_with_actions()) {
274     std::vector<ActionSuggestion> non_smart_reply_actions;
275     for (const ActionSuggestion& action : response->actions) {
276       if (action.type != smart_reply_action_type_) {
277         non_smart_reply_actions.push_back(std::move(action));
278       }
279     }
280     response->actions = std::move(non_smart_reply_actions);
281   }
282 
283   // Group by annotation if specified.
284   if (options_->group_by_annotations()) {
285     auto group_id = std::map<
286         ActionSuggestion, int,
287         std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
288         [](const ActionSuggestion& action, const ActionSuggestion& other) {
289           return (CompareAnnotationsOnly(action, other) < 0);
290         }};
291     typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
292     std::vector<ActionSuggestionGroup> groups;
293 
294     // Group actions by the annotation set they are based of.
295     for (const ActionSuggestion& action : response->actions) {
296       // Treat actions with no annotations idependently.
297       if (action.annotations.empty()) {
298         groups.emplace_back(1, action);
299         continue;
300       }
301 
302       auto it = group_id.find(action);
303       if (it != group_id.end()) {
304         groups[it->second].push_back(action);
305       } else {
306         group_id[action] = groups.size();
307         groups.emplace_back(1, action);
308       }
309     }
310 
311     // Sort within each group by score.
312     for (std::vector<ActionSuggestion>& group : groups) {
313       SortByScoreAndType(&group);
314     }
315 
316     // Sort groups by maximum score.
317     std::sort(groups.begin(), groups.end(),
318               [](const std::vector<ActionSuggestion>& a,
319                  const std::vector<ActionSuggestion>& b) {
320                 return a.begin()->score > b.begin()->score ||
321                        (a.begin()->score >= b.begin()->score &&
322                         a.begin()->type < b.begin()->type);
323               });
324 
325     // Flatten result.
326     const size_t num_actions = response->actions.size();
327     response->actions.clear();
328     response->actions.reserve(num_actions);
329     for (const std::vector<ActionSuggestion>& actions : groups) {
330       response->actions.insert(response->actions.end(), actions.begin(),
331                                actions.end());
332     }
333 
334   } else {
335     // Order suggestions independently by score.
336     SortByScoreAndType(&response->actions);
337   }
338 
339   // Run lua ranking snippet, if provided.
340   if (!lua_bytecode_.empty()) {
341     auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
342         conversation, lua_bytecode_, entity_data_schema,
343         annotations_entity_data_schema, response);
344     if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
345       TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
346       return false;
347     }
348   }
349 
350   return true;
351 }
352 
353 }  // namespace libtextclassifier3
354