1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/scoring/priority-queue-scored-document-hits-ranker.h"
16 
17 #include <vector>
18 
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include "icing/scoring/scored-document-hit.h"
22 #include "icing/testing/common-matchers.h"
23 
24 namespace icing {
25 namespace lib {
26 
27 namespace {
28 
29 using ::testing::ElementsAre;
30 using ::testing::Eq;
31 using ::testing::IsEmpty;
32 using ::testing::SizeIs;
33 
34 class Converter {
35  public:
operator ()(ScoredDocumentHit hit) const36   JoinedScoredDocumentHit operator()(ScoredDocumentHit hit) const {
37     return converter_(std::move(hit));
38   }
39 
40  private:
41   ScoredDocumentHit::Converter converter_;
42 } converter;
43 
PopAll(PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> & ranker)44 std::vector<JoinedScoredDocumentHit> PopAll(
45     PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>& ranker) {
46   std::vector<JoinedScoredDocumentHit> hits;
47   while (!ranker.empty()) {
48     hits.push_back(ranker.PopNext());
49   }
50   return hits;
51 }
52 
// size() and empty() must track the number of hits left as they are popped one
// at a time from a ranker built with three hits.
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldGetCorrectSizeAndEmpty) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2}, /*is_descending=*/true);
  EXPECT_THAT(ranker.size(), Eq(3));
  EXPECT_FALSE(ranker.empty());

  // Pop the three hits; after each pop the size drops by one, and the ranker
  // reports empty only once nothing is left.
  for (int expected_size = 2; expected_size >= 0; --expected_size) {
    ranker.PopNext();
    EXPECT_THAT(ranker.size(), Eq(expected_size));
    EXPECT_THAT(ranker.empty(), Eq(expected_size == 0));
  }
}
79 
// Five hits that all carry score 1 are handed over in shuffled order. With
// is_descending=true the expected pop order is document id 4 down to 0.
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInDescendingOrder) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2, doc4, doc3}, /*is_descending=*/true);

  EXPECT_THAT(ranker, SizeIs(5));
  EXPECT_THAT(PopAll(ranker),
              ElementsAre(EqualsJoinedScoredDocumentHit(converter(doc4)),
                          EqualsJoinedScoredDocumentHit(converter(doc3)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc1)),
                          EqualsJoinedScoredDocumentHit(converter(doc0))));
}
106 
// Same shuffled input as the descending case, but with is_descending=false
// the expected pop order is document id 0 up to 4.
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2, doc4, doc3}, /*is_descending=*/false);

  EXPECT_THAT(ranker, SizeIs(5));
  EXPECT_THAT(PopAll(ranker),
              ElementsAre(EqualsJoinedScoredDocumentHit(converter(doc0)),
                          EqualsJoinedScoredDocumentHit(converter(doc1)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc3)),
                          EqualsJoinedScoredDocumentHit(converter(doc4))));
}
133 
// The input repeats some hits (doc 2 three times, doc 4 twice). Every copy
// must be kept, so eight hits come back, in descending order with the
// duplicates adjacent to each other.
TEST(PriorityQueueScoredDocumentHitsRankerTest,
     ShouldRankDuplicateScoredDocumentHits) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc2, doc4, doc1, doc0, doc2, doc2, doc4, doc3},
                /*is_descending=*/true);

  EXPECT_THAT(ranker, SizeIs(8));
  EXPECT_THAT(PopAll(ranker),
              ElementsAre(EqualsJoinedScoredDocumentHit(converter(doc4)),
                          EqualsJoinedScoredDocumentHit(converter(doc4)),
                          EqualsJoinedScoredDocumentHit(converter(doc3)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc1)),
                          EqualsJoinedScoredDocumentHit(converter(doc0))));
}
165 
// A ranker constructed from an empty list of hits must report itself empty.
TEST(PriorityQueueScoredDocumentHitsRankerTest,
     ShouldRankEmptyScoredDocumentHits) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  Ranker ranker(/*scored_document_hits=*/{}, /*is_descending=*/true);

  EXPECT_THAT(ranker, IsEmpty());
}
173 
// Truncating a five-hit descending ranker to three must leave exactly the
// three hits that would have been popped first (docs 4, 3, 2).
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToNewSize) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2, doc4, doc3}, /*is_descending=*/true);
  // Precondition: all five hits were accepted.
  ASSERT_THAT(ranker, SizeIs(5));

  ranker.TruncateHitsTo(/*new_size=*/3);

  EXPECT_THAT(ranker, SizeIs(3));
  EXPECT_THAT(PopAll(ranker),
              ElementsAre(EqualsJoinedScoredDocumentHit(converter(doc4)),
                          EqualsJoinedScoredDocumentHit(converter(doc3)),
                          EqualsJoinedScoredDocumentHit(converter(doc2))));
}
200 
// Truncating to a size of zero must discard every hit.
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2, doc4, doc3}, /*is_descending=*/true);
  // Precondition: all five hits were accepted.
  ASSERT_THAT(ranker, SizeIs(5));

  ranker.TruncateHitsTo(/*new_size=*/0);

  EXPECT_THAT(ranker, IsEmpty());
}
221 
// A negative new size must be ignored: the ranker keeps all five hits and
// still pops them in the usual descending order.
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) {
  using Ranker = PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>;

  const ScoredDocumentHit doc0(/*document_id=*/0, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc1(/*document_id=*/1, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc2(/*document_id=*/2, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc3(/*document_id=*/3, kSectionIdMaskNone,
                               /*score=*/1);
  const ScoredDocumentHit doc4(/*document_id=*/4, kSectionIdMaskNone,
                               /*score=*/1);

  Ranker ranker({doc1, doc0, doc2, doc4, doc3}, /*is_descending=*/true);
  // Precondition: all five hits were accepted.
  ASSERT_THAT(ranker, SizeIs(5));

  ranker.TruncateHitsTo(/*new_size=*/-1);

  // Neither the size nor the contents may change.
  EXPECT_THAT(ranker, SizeIs(5));
  EXPECT_THAT(PopAll(ranker),
              ElementsAre(EqualsJoinedScoredDocumentHit(converter(doc4)),
                          EqualsJoinedScoredDocumentHit(converter(doc3)),
                          EqualsJoinedScoredDocumentHit(converter(doc2)),
                          EqualsJoinedScoredDocumentHit(converter(doc1)),
                          EqualsJoinedScoredDocumentHit(converter(doc0))));
}
251 
252 }  // namespace
253 
254 }  // namespace lib
255 }  // namespace icing
256