1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "actions/lua-ranker.h"
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/lua-utils.h"
28
29 namespace libtextclassifier3 {
30 namespace {
31
SortByScoreAndType(std::vector<ActionSuggestion> * actions)32 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
33 std::sort(actions->begin(), actions->end(),
34 [](const ActionSuggestion& a, const ActionSuggestion& b) {
35 return a.score > b.score ||
36 (a.score >= b.score && a.type < b.type);
37 });
38 }
39
// Generic three-way comparison built on operator< and operator>.
// Returns -1 if `left` orders before `right`, 1 if it orders after, and 0
// when neither holds.
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
50
51 template <>
Compare(const std::string & left,const std::string & right)52 int Compare(const std::string& left, const std::string& right) {
53 return left.compare(right);
54 }
55
56 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)57 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
58 if (const int value = Compare(span.message_index, other.message_index)) {
59 return value;
60 }
61 if (const int value = Compare(span.span.first, other.span.first)) {
62 return value;
63 }
64 if (const int value = Compare(span.span.second, other.span.second)) {
65 return value;
66 }
67 return 0;
68 }
69
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)70 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
71 return Compare(span, other) == 0;
72 }
73
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)74 bool TextSpansIntersect(const MessageTextSpan& span,
75 const MessageTextSpan& other) {
76 return span.message_index == other.message_index &&
77 SpansOverlap(span.span, other.span);
78 }
79
80 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)81 int Compare(const ActionSuggestionAnnotation& annotation,
82 const ActionSuggestionAnnotation& other) {
83 if (const int value = Compare(annotation.span, other.span)) {
84 return value;
85 }
86 if (const int value = Compare(annotation.name, other.name)) {
87 return value;
88 }
89 if (const int value =
90 Compare(annotation.entity.collection, other.entity.collection)) {
91 return value;
92 }
93 return 0;
94 }
95
96 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)97 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
98 const ActionSuggestionAnnotation& other) {
99 return Compare(annotation, other) == 0;
100 }
101
102 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)103 int CompareAnnotationsOnly(const ActionSuggestion& action,
104 const ActionSuggestion& other) {
105 if (const int value =
106 Compare(action.annotations.size(), other.annotations.size())) {
107 return value;
108 }
109 for (int i = 0; i < action.annotations.size(); i++) {
110 if (const int value =
111 Compare(action.annotations[i], other.annotations[i])) {
112 return value;
113 }
114 }
115 return 0;
116 }
117
118 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)119 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
120 const ActionSuggestion& other) {
121 return CompareAnnotationsOnly(action, other) == 0;
122 }
123
124 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)125 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
126 if (const int value = Compare(action.type, other.type)) {
127 return value;
128 }
129 if (const int value = Compare(action.response_text, other.response_text)) {
130 return value;
131 }
132 if (const int value = Compare(action.serialized_entity_data,
133 other.serialized_entity_data)) {
134 return value;
135 }
136 return CompareAnnotationsOnly(action, other);
137 }
138
139 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)140 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
141 const ActionSuggestion& other) {
142 return Compare(action, other) == 0;
143 }
144
145 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)146 bool IsAnyActionEquivalent(const ActionSuggestion& action,
147 const std::vector<ActionSuggestion>& actions) {
148 for (const ActionSuggestion& other : actions) {
149 if (IsEquivalentActionSuggestion(action, other)) {
150 return true;
151 }
152 }
153 return false;
154 }
155
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)156 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
157 const ActionSuggestionAnnotation& other) {
158 // Two annotations are conflicting if they are different but refer to
159 // overlapping spans in the conversation.
160 return (!IsEquivalentActionAnnotation(annotation, other) &&
161 TextSpansIntersect(annotation.span, other.span));
162 }
163
164 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)165 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
166 const ActionSuggestion& other) {
167 // Actions are considered conflicting, iff they refer to the same text span,
168 // but were not generated from the same annotation.
169 if (action.annotations.empty() || other.annotations.empty()) {
170 return false;
171 }
172 for (const ActionSuggestionAnnotation& annotation : action.annotations) {
173 for (const ActionSuggestionAnnotation& other_annotation :
174 other.annotations) {
175 if (IsConflicting(annotation, other_annotation)) {
176 return true;
177 }
178 }
179 }
180 return false;
181 }
182
183 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)184 bool IsAnyActionConflicting(const ActionSuggestion& action,
185 const std::vector<ActionSuggestion>& actions) {
186 for (const ActionSuggestion& other : actions) {
187 if (IsConflictingActionSuggestion(action, other)) {
188 return true;
189 }
190 }
191 return false;
192 }
193
194 } // namespace
195
196 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)197 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
198 const RankingOptions* options, ZlibDecompressor* decompressor,
199 const std::string& smart_reply_action_type) {
200 auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
201 new ActionsSuggestionsRanker(options, smart_reply_action_type));
202
203 if (!ranker->InitializeAndValidate(decompressor)) {
204 TC3_LOG(ERROR) << "Could not initialize action ranker.";
205 return nullptr;
206 }
207
208 return ranker;
209 }
210
InitializeAndValidate(ZlibDecompressor * decompressor)211 bool ActionsSuggestionsRanker::InitializeAndValidate(
212 ZlibDecompressor* decompressor) {
213 if (options_ == nullptr) {
214 TC3_LOG(ERROR) << "No ranking options specified.";
215 return false;
216 }
217
218 std::string lua_ranking_script;
219 if (GetUncompressedString(options_->lua_ranking_script(),
220 options_->compressed_lua_ranking_script(),
221 decompressor, &lua_ranking_script) &&
222 !lua_ranking_script.empty()) {
223 if (!Compile(lua_ranking_script, &lua_bytecode_)) {
224 TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
225 return false;
226 }
227 }
228
229 return true;
230 }
231
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const232 bool ActionsSuggestionsRanker::RankActions(
233 const Conversation& conversation, ActionsSuggestionsResponse* response,
234 const reflection::Schema* entity_data_schema,
235 const reflection::Schema* annotations_entity_data_schema) const {
236 if (options_->deduplicate_suggestions() ||
237 options_->deduplicate_suggestions_by_span()) {
238 // First order suggestions by priority score for deduplication.
239 std::sort(
240 response->actions.begin(), response->actions.end(),
241 [](const ActionSuggestion& a, const ActionSuggestion& b) {
242 return a.priority_score > b.priority_score ||
243 (a.priority_score >= b.priority_score && a.score > b.score);
244 });
245
246 // Deduplicate, keeping the higher score actions.
247 if (options_->deduplicate_suggestions()) {
248 std::vector<ActionSuggestion> deduplicated_actions;
249 for (const ActionSuggestion& candidate : response->actions) {
250 // Check whether we already have an equivalent action.
251 if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
252 deduplicated_actions.push_back(std::move(candidate));
253 }
254 }
255 response->actions = std::move(deduplicated_actions);
256 }
257
258 // Resolve conflicts between conflicting actions referring to the same
259 // text span.
260 if (options_->deduplicate_suggestions_by_span()) {
261 std::vector<ActionSuggestion> deduplicated_actions;
262 for (const ActionSuggestion& candidate : response->actions) {
263 // Check whether we already have a conflicting action.
264 if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
265 deduplicated_actions.push_back(std::move(candidate));
266 }
267 }
268 response->actions = std::move(deduplicated_actions);
269 }
270 }
271
272 // Suppress smart replies if actions are present.
273 if (options_->suppress_smart_replies_with_actions()) {
274 std::vector<ActionSuggestion> non_smart_reply_actions;
275 for (const ActionSuggestion& action : response->actions) {
276 if (action.type != smart_reply_action_type_) {
277 non_smart_reply_actions.push_back(std::move(action));
278 }
279 }
280 response->actions = std::move(non_smart_reply_actions);
281 }
282
283 // Group by annotation if specified.
284 if (options_->group_by_annotations()) {
285 auto group_id = std::map<
286 ActionSuggestion, int,
287 std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
288 [](const ActionSuggestion& action, const ActionSuggestion& other) {
289 return (CompareAnnotationsOnly(action, other) < 0);
290 }};
291 typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
292 std::vector<ActionSuggestionGroup> groups;
293
294 // Group actions by the annotation set they are based of.
295 for (const ActionSuggestion& action : response->actions) {
296 // Treat actions with no annotations idependently.
297 if (action.annotations.empty()) {
298 groups.emplace_back(1, action);
299 continue;
300 }
301
302 auto it = group_id.find(action);
303 if (it != group_id.end()) {
304 groups[it->second].push_back(action);
305 } else {
306 group_id[action] = groups.size();
307 groups.emplace_back(1, action);
308 }
309 }
310
311 // Sort within each group by score.
312 for (std::vector<ActionSuggestion>& group : groups) {
313 SortByScoreAndType(&group);
314 }
315
316 // Sort groups by maximum score.
317 std::sort(groups.begin(), groups.end(),
318 [](const std::vector<ActionSuggestion>& a,
319 const std::vector<ActionSuggestion>& b) {
320 return a.begin()->score > b.begin()->score ||
321 (a.begin()->score >= b.begin()->score &&
322 a.begin()->type < b.begin()->type);
323 });
324
325 // Flatten result.
326 const size_t num_actions = response->actions.size();
327 response->actions.clear();
328 response->actions.reserve(num_actions);
329 for (const std::vector<ActionSuggestion>& actions : groups) {
330 response->actions.insert(response->actions.end(), actions.begin(),
331 actions.end());
332 }
333
334 } else {
335 // Order suggestions independently by score.
336 SortByScoreAndType(&response->actions);
337 }
338
339 // Run lua ranking snippet, if provided.
340 if (!lua_bytecode_.empty()) {
341 auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
342 conversation, lua_bytecode_, entity_data_schema,
343 annotations_entity_data_schema, response);
344 if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
345 TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
346 return false;
347 }
348 }
349
350 return true;
351 }
352
353 } // namespace libtextclassifier3
354