1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/ranker.h"
18
19 #include <functional>
20 #include <set>
21 #include <vector>
22
23 #if !defined(TC3_DISABLE_LUA)
24 #include "actions/lua-ranker.h"
25 #endif
26 #include "actions/zlib-utils.h"
27 #include "annotator/types.h"
28 #include "utils/base/logging.h"
29 #if !defined(TC3_DISABLE_LUA)
30 #include "utils/lua-utils.h"
31 #endif
32
33 namespace libtextclassifier3 {
34 namespace {
35
SortByScoreAndType(std::vector<ActionSuggestion> * actions)36 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
37 std::sort(actions->begin(), actions->end(),
38 [](const ActionSuggestion& a, const ActionSuggestion& b) {
39 return a.score > b.score ||
40 (a.score >= b.score && a.type < b.type);
41 });
42 }
43
// Generic three-way comparison built from operator< and operator>.
// Returns -1 if `left` orders before `right`, 1 if it orders after, and 0 if
// neither relation holds.
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
54
55 template <>
Compare(const std::string & left,const std::string & right)56 int Compare(const std::string& left, const std::string& right) {
57 return left.compare(right);
58 }
59
60 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)61 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
62 if (const int value = Compare(span.message_index, other.message_index)) {
63 return value;
64 }
65 if (const int value = Compare(span.span.first, other.span.first)) {
66 return value;
67 }
68 if (const int value = Compare(span.span.second, other.span.second)) {
69 return value;
70 }
71 return 0;
72 }
73
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)74 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
75 return Compare(span, other) == 0;
76 }
77
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)78 bool TextSpansIntersect(const MessageTextSpan& span,
79 const MessageTextSpan& other) {
80 return span.message_index == other.message_index &&
81 SpansOverlap(span.span, other.span);
82 }
83
84 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)85 int Compare(const ActionSuggestionAnnotation& annotation,
86 const ActionSuggestionAnnotation& other) {
87 if (const int value = Compare(annotation.span, other.span)) {
88 return value;
89 }
90 if (const int value = Compare(annotation.name, other.name)) {
91 return value;
92 }
93 if (const int value =
94 Compare(annotation.entity.collection, other.entity.collection)) {
95 return value;
96 }
97 return 0;
98 }
99
100 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)101 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
102 const ActionSuggestionAnnotation& other) {
103 return Compare(annotation, other) == 0;
104 }
105
106 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)107 int CompareAnnotationsOnly(const ActionSuggestion& action,
108 const ActionSuggestion& other) {
109 if (const int value =
110 Compare(action.annotations.size(), other.annotations.size())) {
111 return value;
112 }
113 for (int i = 0; i < action.annotations.size(); i++) {
114 if (const int value =
115 Compare(action.annotations[i], other.annotations[i])) {
116 return value;
117 }
118 }
119 return 0;
120 }
121
122 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)123 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
124 const ActionSuggestion& other) {
125 return CompareAnnotationsOnly(action, other) == 0;
126 }
127
128 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)129 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
130 if (const int value = Compare(action.type, other.type)) {
131 return value;
132 }
133 if (const int value = Compare(action.response_text, other.response_text)) {
134 return value;
135 }
136 if (const int value = Compare(action.serialized_entity_data,
137 other.serialized_entity_data)) {
138 return value;
139 }
140 return CompareAnnotationsOnly(action, other);
141 }
142
143 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)144 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
145 const ActionSuggestion& other) {
146 return Compare(action, other) == 0;
147 }
148
149 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)150 bool IsAnyActionEquivalent(const ActionSuggestion& action,
151 const std::vector<ActionSuggestion>& actions) {
152 for (const ActionSuggestion& other : actions) {
153 if (IsEquivalentActionSuggestion(action, other)) {
154 return true;
155 }
156 }
157 return false;
158 }
159
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)160 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
161 const ActionSuggestionAnnotation& other) {
162 // Two annotations are conflicting if they are different but refer to
163 // overlapping spans in the conversation.
164 return (!IsEquivalentActionAnnotation(annotation, other) &&
165 TextSpansIntersect(annotation.span, other.span));
166 }
167
168 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)169 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
170 const ActionSuggestion& other) {
171 // Actions are considered conflicting, iff they refer to the same text span,
172 // but were not generated from the same annotation.
173 if (action.annotations.empty() || other.annotations.empty()) {
174 return false;
175 }
176 for (const ActionSuggestionAnnotation& annotation : action.annotations) {
177 for (const ActionSuggestionAnnotation& other_annotation :
178 other.annotations) {
179 if (IsConflicting(annotation, other_annotation)) {
180 return true;
181 }
182 }
183 }
184 return false;
185 }
186
187 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)188 bool IsAnyActionConflicting(const ActionSuggestion& action,
189 const std::vector<ActionSuggestion>& actions) {
190 for (const ActionSuggestion& other : actions) {
191 if (IsConflictingActionSuggestion(action, other)) {
192 return true;
193 }
194 }
195 return false;
196 }
197
198 } // namespace
199
200 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)201 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
202 const RankingOptions* options, ZlibDecompressor* decompressor,
203 const std::string& smart_reply_action_type) {
204 auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
205 new ActionsSuggestionsRanker(options, smart_reply_action_type));
206
207 if (!ranker->InitializeAndValidate(decompressor)) {
208 TC3_LOG(ERROR) << "Could not initialize action ranker.";
209 return nullptr;
210 }
211
212 return ranker;
213 }
214
InitializeAndValidate(ZlibDecompressor * decompressor)215 bool ActionsSuggestionsRanker::InitializeAndValidate(
216 ZlibDecompressor* decompressor) {
217 if (options_ == nullptr) {
218 TC3_LOG(ERROR) << "No ranking options specified.";
219 return false;
220 }
221
222 #if !defined(TC3_DISABLE_LUA)
223 std::string lua_ranking_script;
224 if (GetUncompressedString(options_->lua_ranking_script(),
225 options_->compressed_lua_ranking_script(),
226 decompressor, &lua_ranking_script) &&
227 !lua_ranking_script.empty()) {
228 if (!Compile(lua_ranking_script, &lua_bytecode_)) {
229 TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
230 return false;
231 }
232 }
233 #endif
234
235 return true;
236 }
237
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const238 bool ActionsSuggestionsRanker::RankActions(
239 const Conversation& conversation, ActionsSuggestionsResponse* response,
240 const reflection::Schema* entity_data_schema,
241 const reflection::Schema* annotations_entity_data_schema) const {
242 if (options_->deduplicate_suggestions() ||
243 options_->deduplicate_suggestions_by_span()) {
244 // First order suggestions by priority score for deduplication.
245 std::sort(
246 response->actions.begin(), response->actions.end(),
247 [](const ActionSuggestion& a, const ActionSuggestion& b) {
248 return a.priority_score > b.priority_score ||
249 (a.priority_score >= b.priority_score && a.score > b.score);
250 });
251
252 // Deduplicate, keeping the higher score actions.
253 if (options_->deduplicate_suggestions()) {
254 std::vector<ActionSuggestion> deduplicated_actions;
255 for (const ActionSuggestion& candidate : response->actions) {
256 // Check whether we already have an equivalent action.
257 if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
258 deduplicated_actions.push_back(std::move(candidate));
259 }
260 }
261 response->actions = std::move(deduplicated_actions);
262 }
263
264 // Resolve conflicts between conflicting actions referring to the same
265 // text span.
266 if (options_->deduplicate_suggestions_by_span()) {
267 std::vector<ActionSuggestion> deduplicated_actions;
268 for (const ActionSuggestion& candidate : response->actions) {
269 // Check whether we already have a conflicting action.
270 if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
271 deduplicated_actions.push_back(std::move(candidate));
272 }
273 }
274 response->actions = std::move(deduplicated_actions);
275 }
276 }
277
278 // Suppress smart replies if actions are present.
279 if (options_->suppress_smart_replies_with_actions()) {
280 std::vector<ActionSuggestion> non_smart_reply_actions;
281 for (const ActionSuggestion& action : response->actions) {
282 if (action.type != smart_reply_action_type_) {
283 non_smart_reply_actions.push_back(std::move(action));
284 }
285 }
286 response->actions = std::move(non_smart_reply_actions);
287 }
288
289 // Group by annotation if specified.
290 if (options_->group_by_annotations()) {
291 auto group_id = std::map<
292 ActionSuggestion, int,
293 std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
294 [](const ActionSuggestion& action, const ActionSuggestion& other) {
295 return (CompareAnnotationsOnly(action, other) < 0);
296 }};
297 typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
298 std::vector<ActionSuggestionGroup> groups;
299
300 // Group actions by the annotation set they are based of.
301 for (const ActionSuggestion& action : response->actions) {
302 // Treat actions with no annotations idependently.
303 if (action.annotations.empty()) {
304 groups.emplace_back(1, action);
305 continue;
306 }
307
308 auto it = group_id.find(action);
309 if (it != group_id.end()) {
310 groups[it->second].push_back(action);
311 } else {
312 group_id[action] = groups.size();
313 groups.emplace_back(1, action);
314 }
315 }
316
317 // Sort within each group by score.
318 for (std::vector<ActionSuggestion>& group : groups) {
319 SortByScoreAndType(&group);
320 }
321
322 // Sort groups by maximum score.
323 std::sort(groups.begin(), groups.end(),
324 [](const std::vector<ActionSuggestion>& a,
325 const std::vector<ActionSuggestion>& b) {
326 return a.begin()->score > b.begin()->score ||
327 (a.begin()->score >= b.begin()->score &&
328 a.begin()->type < b.begin()->type);
329 });
330
331 // Flatten result.
332 const size_t num_actions = response->actions.size();
333 response->actions.clear();
334 response->actions.reserve(num_actions);
335 for (const std::vector<ActionSuggestion>& actions : groups) {
336 response->actions.insert(response->actions.end(), actions.begin(),
337 actions.end());
338 }
339
340 } else {
341 // Order suggestions independently by score.
342 SortByScoreAndType(&response->actions);
343 }
344
345 #if !defined(TC3_DISABLE_LUA)
346 // Run lua ranking snippet, if provided.
347 if (!lua_bytecode_.empty()) {
348 auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
349 conversation, lua_bytecode_, entity_data_schema,
350 annotations_entity_data_schema, response);
351 if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
352 TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
353 return false;
354 }
355 }
356 #endif
357
358 return true;
359 }
360
361 } // namespace libtextclassifier3
362