• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright (C) 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
16 #define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
17 
18 #include <memory>
19 #include <string>
20 #include <string_view>
21 #include <utility>
22 #include <vector>
23 
24 #include "icing/text_classifier/lib3/utils/base/status.h"
25 #include "icing/text_classifier/lib3/utils/base/statusor.h"
26 #include "icing/absl_ports/canonical_errors.h"
27 #include "icing/index/embed/embedding-hit.h"
28 #include "icing/index/embed/embedding-index.h"
29 #include "icing/index/embed/embedding-query-results.h"
30 #include "icing/index/embed/embedding-scorer.h"
31 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
32 #include "icing/index/iterator/doc-hit-info-iterator.h"
33 #include "icing/index/iterator/section-restrict-data.h"
34 #include "icing/proto/search.pb.h"
35 #include "icing/schema/section.h"
36 
37 namespace icing {
38 namespace lib {
39 
40 class DocHitInfoIteratorEmbedding : public DocHitInfoLeafIterator {
41  public:
42   // Create a DocHitInfoIterator for iterating through all docs which have an
43   // embedding matched with the provided query with a score in the range of
44   // [score_low, score_high], using the provided metric_type.
45   //
46   // The iterator will store the matched embedding scores in score_map to
47   // prepare for scoring.
48   //
49   // The iterator will handle the section restriction logic internally by the
50   // provided section_restrict_data.
51   //
52   // Returns:
53   //   - a DocHitInfoIteratorEmbedding instance on success.
54   //   - Any error from posting lists.
55   static libtextclassifier3::StatusOr<
56       std::unique_ptr<DocHitInfoIteratorEmbedding>>
57   Create(const PropertyProto::VectorProto* query,
58          std::unique_ptr<SectionRestrictData> section_restrict_data,
59          SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
60          double score_low, double score_high,
61          EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
62          const EmbeddingIndex* embedding_index);
63 
64   libtextclassifier3::Status Advance() override;
65 
66   // The iterator will internally handle the section restriction logic by itself
67   // to have better control, so that it is able to filter out embedding hits
68   // from unwanted sections to avoid retrieving unnecessary vectors and
69   // calculate scores for them.
full_section_restriction_applied()70   bool full_section_restriction_applied() const override { return true; }
71 
TrimRightMostNode()72   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
73     return absl_ports::InvalidArgumentError(
74         "Query suggestions for the semanticSearch function are not supported");
75   }
76 
GetCallStats()77   CallStats GetCallStats() const override {
78     return CallStats(
79         /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
80         /*num_leaf_advance_calls_main_index_in=*/0,
81         /*num_leaf_advance_calls_integer_index_in=*/0,
82         /*num_leaf_advance_calls_no_index_in=*/0,
83         /*num_blocks_inspected_in=*/0);
84   }
85 
ToString()86   std::string ToString() const override { return "embedding_iterator"; }
87 
88   // PopulateMatchedTermsStats is not applicable to embedding search.
PopulateMatchedTermsStats(std::vector<TermMatchInfo> * matched_terms_stats,SectionIdMask filtering_section_mask)89   void PopulateMatchedTermsStats(
90       std::vector<TermMatchInfo>* matched_terms_stats,
91       SectionIdMask filtering_section_mask) const override {}
92 
93  private:
DocHitInfoIteratorEmbedding(const PropertyProto::VectorProto * query,std::unique_ptr<SectionRestrictData> section_restrict_data,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,std::unique_ptr<EmbeddingScorer> embedding_scorer,double score_low,double score_high,EmbeddingQueryResults::EmbeddingQueryScoreMap * score_map,const EmbeddingIndex * embedding_index,std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)94   explicit DocHitInfoIteratorEmbedding(
95       const PropertyProto::VectorProto* query,
96       std::unique_ptr<SectionRestrictData> section_restrict_data,
97       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
98       std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
99       double score_high,
100       EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
101       const EmbeddingIndex* embedding_index,
102       std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor)
103       : query_(*query),
104         section_restrict_data_(std::move(section_restrict_data)),
105         metric_type_(metric_type),
106         embedding_scorer_(std::move(embedding_scorer)),
107         score_low_(score_low),
108         score_high_(score_high),
109         score_map_(*score_map),
110         embedding_index_(*embedding_index),
111         posting_list_accessor_(std::move(posting_list_accessor)),
112         cached_embedding_hits_idx_(0),
113         current_allowed_sections_mask_(kSectionIdMaskAll),
114         no_more_hit_(false),
115         num_advance_calls_(0) {}
116 
117   // Advance to the next embedding hit of the current document. If the current
118   // document id is kInvalidDocumentId, the method will advance to the first
119   // embedding hit of the next document and update doc_hit_info_.
120   //
121   // This method also properly updates cached_embedding_hits_,
122   // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
123   // no_more_hit_ to reflect the current state.
124   //
125   // Returns:
126   //   - a const pointer to the next embedding hit on success.
127   //   - nullptr, if there is no more hit for the current document, or no more
128   //     hit in general if the current document id is kInvalidDocumentId.
129   //   - Any error from posting lists.
130   libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
131 
132   // Similar to Advance(), this method advances the iterator to the next
133   // document, but it does not guarantee that the next document will have
134   // a matched embedding hit within the score range.
135   //
136   // Returns:
137   //   - OK, if it is able to advance to a new document_id.
138   //   - RESOUCE_EXHAUSTED, if we have run out of document_ids to iterate over.
139   //   - Any error from posting lists.
140   libtextclassifier3::Status AdvanceToNextUnfilteredDocument();
141 
142   // Query information
143   const PropertyProto::VectorProto& query_;                     // Does not own
144   std::unique_ptr<SectionRestrictData> section_restrict_data_;  // Nullable.
145 
146   // Scoring arguments
147   SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
148   std::unique_ptr<EmbeddingScorer> embedding_scorer_;
149   double score_low_;
150   double score_high_;
151 
152   // Score map
153   EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_;  // Does not own
154 
155   // Access to embeddings index data
156   const EmbeddingIndex& embedding_index_;
157   std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
158 
159   // Cached data from the embeddings index
160   std::vector<EmbeddingHit> cached_embedding_hits_;
161   int cached_embedding_hits_idx_;
162   SectionIdMask current_allowed_sections_mask_;
163   bool no_more_hit_;
164 
165   int num_advance_calls_;
166 };
167 
168 }  // namespace lib
169 }  // namespace icing
170 
171 #endif  // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
172