1 //===--- Iterator.cpp - Query Symbol Retrieval ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "Iterator.h"
10 #include "llvm/Support/Casting.h"
11 #include <algorithm>
12 #include <cassert>
13 #include <numeric>
14
15 namespace clang {
16 namespace clangd {
17 namespace dex {
18 namespace {
19
20 /// Implements Iterator over the intersection of other iterators.
21 ///
22 /// AndIterator iterates through common items among all children. It becomes
23 /// exhausted as soon as any child becomes exhausted. After each mutation, the
24 /// iterator restores the invariant: all children must point to the same item.
25 class AndIterator : public Iterator {
26 public:
AndIterator(std::vector<std::unique_ptr<Iterator>> AllChildren)27 explicit AndIterator(std::vector<std::unique_ptr<Iterator>> AllChildren)
28 : Iterator(Kind::And), Children(std::move(AllChildren)) {
29 assert(!Children.empty() && "AND iterator should have at least one child.");
30 // Establish invariants.
31 for (const auto &Child : Children)
32 ReachedEnd |= Child->reachedEnd();
33 sync();
34 // When children are sorted by the estimateSize(), sync() calls are more
35 // effective. Each sync() starts with the first child and makes sure all
36 // children point to the same element. If any child is "above" the previous
37 // ones, the algorithm resets and and advances the children to the next
38 // highest element starting from the front. When child iterators in the
39 // beginning have smaller estimated size, the sync() will have less restarts
40 // and become more effective.
41 llvm::sort(Children, [](const std::unique_ptr<Iterator> &LHS,
42 const std::unique_ptr<Iterator> &RHS) {
43 return LHS->estimateSize() < RHS->estimateSize();
44 });
45 }
46
reachedEnd() const47 bool reachedEnd() const override { return ReachedEnd; }
48
49 /// Advances all children to the next common item.
advance()50 void advance() override {
51 assert(!reachedEnd() && "AND iterator can't advance() at the end.");
52 Children.front()->advance();
53 sync();
54 }
55
56 /// Advances all children to the next common item with DocumentID >= ID.
advanceTo(DocID ID)57 void advanceTo(DocID ID) override {
58 assert(!reachedEnd() && "AND iterator can't advanceTo() at the end.");
59 Children.front()->advanceTo(ID);
60 sync();
61 }
62
peek() const63 DocID peek() const override { return Children.front()->peek(); }
64
consume()65 float consume() override {
66 assert(!reachedEnd() && "AND iterator can't consume() at the end.");
67 float Boost = 1;
68 for (const auto &Child : Children)
69 Boost *= Child->consume();
70 return Boost;
71 }
72
estimateSize() const73 size_t estimateSize() const override {
74 return Children.front()->estimateSize();
75 }
76
77 private:
dump(llvm::raw_ostream & OS) const78 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
79 OS << "(& ";
80 auto Separator = "";
81 for (const auto &Child : Children) {
82 OS << Separator << *Child;
83 Separator = " ";
84 }
85 OS << ')';
86 return OS;
87 }
88
89 /// Restores class invariants: each child will point to the same element after
90 /// sync.
sync()91 void sync() {
92 ReachedEnd |= Children.front()->reachedEnd();
93 if (ReachedEnd)
94 return;
95 auto SyncID = Children.front()->peek();
96 // Indicates whether any child needs to be advanced to new SyncID.
97 bool NeedsAdvance = false;
98 do {
99 NeedsAdvance = false;
100 for (auto &Child : Children) {
101 Child->advanceTo(SyncID);
102 ReachedEnd |= Child->reachedEnd();
103 // If any child reaches end And iterator can not match any other items.
104 // In this case, just terminate the process.
105 if (ReachedEnd)
106 return;
107 // If any child goes beyond given ID (i.e. ID is not the common item),
108 // all children should be advanced to the next common item.
109 if (Child->peek() > SyncID) {
110 SyncID = Child->peek();
111 NeedsAdvance = true;
112 }
113 }
114 } while (NeedsAdvance);
115 }
116
117 /// AndIterator owns its children and ensures that all of them point to the
118 /// same element. As soon as one child gets exhausted, AndIterator can no
119 /// longer advance and has reached its end.
120 std::vector<std::unique_ptr<Iterator>> Children;
121 /// Indicates whether any child is exhausted. It is cheaper to maintain and
122 /// update the field, rather than traversing the whole subtree in each
123 /// reachedEnd() call.
124 bool ReachedEnd = false;
125 friend Corpus; // For optimizations.
126 };
127
128 /// Implements Iterator over the union of other iterators.
129 ///
130 /// OrIterator iterates through all items which can be pointed to by at least
131 /// one child. To preserve the sorted order, this iterator always advances the
132 /// child with smallest Child->peek() value. OrIterator becomes exhausted as
133 /// soon as all of its children are exhausted.
134 class OrIterator : public Iterator {
135 public:
OrIterator(std::vector<std::unique_ptr<Iterator>> AllChildren)136 explicit OrIterator(std::vector<std::unique_ptr<Iterator>> AllChildren)
137 : Iterator(Kind::Or), Children(std::move(AllChildren)) {
138 assert(!Children.empty() && "OR iterator should have at least one child.");
139 }
140
141 /// Returns true if all children are exhausted.
reachedEnd() const142 bool reachedEnd() const override {
143 for (const auto &Child : Children)
144 if (!Child->reachedEnd())
145 return false;
146 return true;
147 }
148
149 /// Moves each child pointing to the smallest DocID to the next item.
advance()150 void advance() override {
151 assert(!reachedEnd() && "OR iterator can't advance() at the end.");
152 const auto SmallestID = peek();
153 for (const auto &Child : Children)
154 if (!Child->reachedEnd() && Child->peek() == SmallestID)
155 Child->advance();
156 }
157
158 /// Advances each child to the next existing element with DocumentID >= ID.
advanceTo(DocID ID)159 void advanceTo(DocID ID) override {
160 assert(!reachedEnd() && "OR iterator can't advanceTo() at the end.");
161 for (const auto &Child : Children)
162 if (!Child->reachedEnd())
163 Child->advanceTo(ID);
164 }
165
166 /// Returns the element under cursor of the child with smallest Child->peek()
167 /// value.
peek() const168 DocID peek() const override {
169 assert(!reachedEnd() && "OR iterator can't peek() at the end.");
170 DocID Result = std::numeric_limits<DocID>::max();
171
172 for (const auto &Child : Children)
173 if (!Child->reachedEnd())
174 Result = std::min(Result, Child->peek());
175
176 return Result;
177 }
178
179 // Returns the maximum boosting score among all Children when iterator
180 // points to the current ID.
consume()181 float consume() override {
182 assert(!reachedEnd() && "OR iterator can't consume() at the end.");
183 const DocID ID = peek();
184 float Boost = 1;
185 for (const auto &Child : Children)
186 if (!Child->reachedEnd() && Child->peek() == ID)
187 Boost = std::max(Boost, Child->consume());
188 return Boost;
189 }
190
estimateSize() const191 size_t estimateSize() const override {
192 size_t Size = 0;
193 for (const auto &Child : Children)
194 Size = std::max(Size, Child->estimateSize());
195 return Size;
196 }
197
198 private:
dump(llvm::raw_ostream & OS) const199 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
200 OS << "(| ";
201 auto Separator = "";
202 for (const auto &Child : Children) {
203 OS << Separator << *Child;
204 Separator = " ";
205 }
206 OS << ')';
207 return OS;
208 }
209
210 // FIXME(kbobyrev): Would storing Children in min-heap be faster?
211 std::vector<std::unique_ptr<Iterator>> Children;
212 friend Corpus; // For optimizations.
213 };
214
215 /// TrueIterator handles PostingLists which contain all items of the index. It
216 /// stores size of the virtual posting list, and all operations are performed
217 /// in O(1).
218 class TrueIterator : public Iterator {
219 public:
TrueIterator(DocID Size)220 explicit TrueIterator(DocID Size) : Iterator(Kind::True), Size(Size) {}
221
reachedEnd() const222 bool reachedEnd() const override { return Index >= Size; }
223
advance()224 void advance() override {
225 assert(!reachedEnd() && "TRUE iterator can't advance() at the end.");
226 ++Index;
227 }
228
advanceTo(DocID ID)229 void advanceTo(DocID ID) override {
230 assert(!reachedEnd() && "TRUE iterator can't advanceTo() at the end.");
231 Index = std::min(ID, Size);
232 }
233
peek() const234 DocID peek() const override {
235 assert(!reachedEnd() && "TRUE iterator can't peek() at the end.");
236 return Index;
237 }
238
consume()239 float consume() override {
240 assert(!reachedEnd() && "TRUE iterator can't consume() at the end.");
241 return 1;
242 }
243
estimateSize() const244 size_t estimateSize() const override { return Size; }
245
246 private:
dump(llvm::raw_ostream & OS) const247 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
248 return OS << "true";
249 }
250
251 DocID Index = 0;
252 /// Size of the underlying virtual PostingList.
253 DocID Size;
254 };
255
256 /// FalseIterator yields no results.
257 class FalseIterator : public Iterator {
258 public:
FalseIterator()259 FalseIterator() : Iterator(Kind::False) {}
reachedEnd() const260 bool reachedEnd() const override { return true; }
advance()261 void advance() override { assert(false); }
advanceTo(DocID ID)262 void advanceTo(DocID ID) override { assert(false); }
peek() const263 DocID peek() const override {
264 assert(false);
265 return 0;
266 }
consume()267 float consume() override {
268 assert(false);
269 return 1;
270 }
estimateSize() const271 size_t estimateSize() const override { return 0; }
272
273 private:
dump(llvm::raw_ostream & OS) const274 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
275 return OS << "false";
276 }
277 };
278
279 /// Boost iterator is a wrapper around its child which multiplies scores of
280 /// each retrieved item by a given factor.
281 class BoostIterator : public Iterator {
282 public:
BoostIterator(std::unique_ptr<Iterator> Child,float Factor)283 BoostIterator(std::unique_ptr<Iterator> Child, float Factor)
284 : Child(std::move(Child)), Factor(Factor) {}
285
reachedEnd() const286 bool reachedEnd() const override { return Child->reachedEnd(); }
287
advance()288 void advance() override { Child->advance(); }
289
advanceTo(DocID ID)290 void advanceTo(DocID ID) override { Child->advanceTo(ID); }
291
peek() const292 DocID peek() const override { return Child->peek(); }
293
consume()294 float consume() override { return Child->consume() * Factor; }
295
estimateSize() const296 size_t estimateSize() const override { return Child->estimateSize(); }
297
298 private:
dump(llvm::raw_ostream & OS) const299 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
300 return OS << "(* " << Factor << ' ' << *Child << ')';
301 }
302
303 std::unique_ptr<Iterator> Child;
304 float Factor;
305 };
306
307 /// This iterator limits the number of items retrieved from the child iterator
308 /// on top of the query tree. To ensure that query tree with LIMIT iterators
309 /// inside works correctly, users have to call Root->consume(Root->peek()) each
310 /// time item is retrieved at the root of query tree.
311 class LimitIterator : public Iterator {
312 public:
LimitIterator(std::unique_ptr<Iterator> Child,size_t Limit)313 LimitIterator(std::unique_ptr<Iterator> Child, size_t Limit)
314 : Child(std::move(Child)), Limit(Limit), ItemsLeft(Limit) {}
315
reachedEnd() const316 bool reachedEnd() const override {
317 return ItemsLeft == 0 || Child->reachedEnd();
318 }
319
advance()320 void advance() override { Child->advance(); }
321
advanceTo(DocID ID)322 void advanceTo(DocID ID) override { Child->advanceTo(ID); }
323
peek() const324 DocID peek() const override { return Child->peek(); }
325
326 /// Decreases the limit in case the element consumed at top of the query tree
327 /// comes from the underlying iterator.
consume()328 float consume() override {
329 assert(!reachedEnd() && "LimitIterator can't consume() at the end.");
330 --ItemsLeft;
331 return Child->consume();
332 }
333
estimateSize() const334 size_t estimateSize() const override {
335 return std::min(Child->estimateSize(), Limit);
336 }
337
338 private:
dump(llvm::raw_ostream & OS) const339 llvm::raw_ostream &dump(llvm::raw_ostream &OS) const override {
340 return OS << "(LIMIT " << Limit << " " << *Child << ')';
341 }
342
343 std::unique_ptr<Iterator> Child;
344 size_t Limit;
345 size_t ItemsLeft;
346 };
347
348 } // end namespace
349
consume(Iterator & It)350 std::vector<std::pair<DocID, float>> consume(Iterator &It) {
351 std::vector<std::pair<DocID, float>> Result;
352 for (; !It.reachedEnd(); It.advance())
353 Result.emplace_back(It.peek(), It.consume());
354 return Result;
355 }
356
357 std::unique_ptr<Iterator>
intersect(std::vector<std::unique_ptr<Iterator>> Children) const358 Corpus::intersect(std::vector<std::unique_ptr<Iterator>> Children) const {
359 std::vector<std::unique_ptr<Iterator>> RealChildren;
360 for (auto &Child : Children) {
361 switch (Child->kind()) {
362 case Iterator::Kind::True:
363 break; // No effect, drop the iterator.
364 case Iterator::Kind::False:
365 return std::move(Child); // Intersection is empty.
366 case Iterator::Kind::And: {
367 // Inline nested AND into parent AND.
368 auto &NewChildren = static_cast<AndIterator *>(Child.get())->Children;
369 std::move(NewChildren.begin(), NewChildren.end(),
370 std::back_inserter(RealChildren));
371 break;
372 }
373 default:
374 RealChildren.push_back(std::move(Child));
375 }
376 }
377 switch (RealChildren.size()) {
378 case 0:
379 return all();
380 case 1:
381 return std::move(RealChildren.front());
382 default:
383 return std::make_unique<AndIterator>(std::move(RealChildren));
384 }
385 }
386
387 std::unique_ptr<Iterator>
unionOf(std::vector<std::unique_ptr<Iterator>> Children) const388 Corpus::unionOf(std::vector<std::unique_ptr<Iterator>> Children) const {
389 std::vector<std::unique_ptr<Iterator>> RealChildren;
390 for (auto &Child : Children) {
391 switch (Child->kind()) {
392 case Iterator::Kind::False:
393 break; // No effect, drop the iterator.
394 case Iterator::Kind::Or: {
395 // Inline nested OR into parent OR.
396 auto &NewChildren = static_cast<OrIterator *>(Child.get())->Children;
397 std::move(NewChildren.begin(), NewChildren.end(),
398 std::back_inserter(RealChildren));
399 break;
400 }
401 case Iterator::Kind::True:
402 // Don't return all(), which would discard sibling boosts.
403 default:
404 RealChildren.push_back(std::move(Child));
405 }
406 }
407 switch (RealChildren.size()) {
408 case 0:
409 return none();
410 case 1:
411 return std::move(RealChildren.front());
412 default:
413 return std::make_unique<OrIterator>(std::move(RealChildren));
414 }
415 }
416
all() const417 std::unique_ptr<Iterator> Corpus::all() const {
418 return std::make_unique<TrueIterator>(Size);
419 }
420
none() const421 std::unique_ptr<Iterator> Corpus::none() const {
422 return std::make_unique<FalseIterator>();
423 }
424
boost(std::unique_ptr<Iterator> Child,float Factor) const425 std::unique_ptr<Iterator> Corpus::boost(std::unique_ptr<Iterator> Child,
426 float Factor) const {
427 if (Factor == 1)
428 return Child;
429 if (Child->kind() == Iterator::Kind::False)
430 return Child;
431 return std::make_unique<BoostIterator>(std::move(Child), Factor);
432 }
433
limit(std::unique_ptr<Iterator> Child,size_t Limit) const434 std::unique_ptr<Iterator> Corpus::limit(std::unique_ptr<Iterator> Child,
435 size_t Limit) const {
436 if (Child->kind() == Iterator::Kind::False)
437 return Child;
438 return std::make_unique<LimitIterator>(std::move(Child), Limit);
439 }
440
441 } // namespace dex
442 } // namespace clangd
443 } // namespace clang
444