• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
16 #define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <utility>
23 #include <vector>
24 
25 #include "icing/text_classifier/lib3/utils/base/status.h"
26 #include "icing/text_classifier/lib3/utils/base/statusor.h"
27 #include "icing/absl_ports/canonical_errors.h"
28 #include "icing/index/embed/embedding-hit.h"
29 #include "icing/index/embed/embedding-index.h"
30 #include "icing/index/embed/embedding-query-results.h"
31 #include "icing/index/embed/embedding-scorer.h"
32 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
33 #include "icing/index/iterator/doc-hit-info-iterator.h"
34 #include "icing/index/iterator/section-restrict-data.h"
35 #include "icing/proto/search.pb.h"
36 #include "icing/schema/schema-store.h"
37 #include "icing/schema/section.h"
38 #include "icing/store/document-filter-data.h"
39 #include "icing/store/document-store.h"
40 
41 namespace icing {
42 namespace lib {
43 
44 class DocHitInfoIteratorEmbedding
45     : public DocHitInfoIteratorHandlingSectionRestrict {
46  public:
47   // Create a DocHitInfoIterator for iterating through all docs which have an
48   // embedding matched with the provided query with a score in the range of
49   // [score_low, score_high], using the provided metric_type.
50   //
51   // The iterator will store the matched embedding scores in info_map to
52   // prepare for scoring and snippeting.
53   //
54   // The iterator will handle the section restriction logic internally with the
55   // help of DocHitInfoIteratorHandlingSectionRestrict.
56   //
57   // Returns:
58   //   - a DocHitInfoIteratorEmbedding instance on success.
59   //   - Any error from posting lists.
60   static libtextclassifier3::StatusOr<
61       std::unique_ptr<DocHitInfoIteratorEmbedding>>
62   Create(const PropertyProto::VectorProto* query,
63          SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
64          double score_low, double score_high, bool get_embedding_match_info,
65          EmbeddingQueryResults::EmbeddingQueryMatchInfoMap* info_map,
66          const EmbeddingIndex* embedding_index,
67          const DocumentStore* document_store, const SchemaStore* schema_store,
68          int64_t current_time_ms);
69 
70   libtextclassifier3::Status Advance() override;
71 
TrimRightMostNode()72   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
73     return absl_ports::InvalidArgumentError(
74         "Query suggestions for the semanticSearch function are not supported");
75   }
76 
GetCallStats()77   CallStats GetCallStats() const override {
78     return CallStats(
79         /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
80         /*num_leaf_advance_calls_main_index_in=*/0,
81         /*num_leaf_advance_calls_integer_index_in=*/0,
82         /*num_leaf_advance_calls_no_index_in=*/0,
83         /*num_blocks_inspected_in=*/0);
84   }
85 
ToString()86   std::string ToString() const override { return "embedding_iterator"; }
87 
88   // PopulateMatchedTermsStats is not applicable to embedding search.
PopulateMatchedTermsStats(std::vector<TermMatchInfo> * matched_terms_stats,SectionIdMask filtering_section_mask)89   void PopulateMatchedTermsStats(
90       std::vector<TermMatchInfo>* matched_terms_stats,
91       SectionIdMask filtering_section_mask) const override {}
92 
93  private:
DocHitInfoIteratorEmbedding(const PropertyProto::VectorProto * query,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,std::unique_ptr<EmbeddingScorer> embedding_scorer,double score_low,double score_high,bool get_embedding_match_info,EmbeddingQueryResults::EmbeddingQueryMatchInfoMap * info_map,const EmbeddingIndex * embedding_index,std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms)94   explicit DocHitInfoIteratorEmbedding(
95       const PropertyProto::VectorProto* query,
96       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
97       std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
98       double score_high, bool get_embedding_match_info,
99       EmbeddingQueryResults::EmbeddingQueryMatchInfoMap* info_map,
100       const EmbeddingIndex* embedding_index,
101       std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor,
102       const DocumentStore* document_store, const SchemaStore* schema_store,
103       int64_t current_time_ms)
104       : query_(*query),
105         metric_type_(metric_type),
106         embedding_scorer_(std::move(embedding_scorer)),
107         score_low_(score_low),
108         score_high_(score_high),
109         get_embedding_match_info_(get_embedding_match_info),
110         info_map_(*info_map),
111         embedding_index_(*embedding_index),
112         posting_list_accessor_(std::move(posting_list_accessor)),
113         cached_embedding_hits_idx_(0),
114         current_allowed_sections_mask_(kSectionIdMaskAll),
115         no_more_hit_(false),
116         schema_type_id_(kInvalidSchemaTypeId),
117         document_store_(*document_store),
118         schema_store_(*schema_store),
119         current_time_ms_(current_time_ms),
120         num_advance_calls_(0) {}
121 
122   // Advance to the next embedding hit of the current document. If the current
123   // document id is kInvalidDocumentId, the method will advance to the first
124   // embedding hit of the next document and update doc_hit_info_.
125   //
126   // This method also properly updates cached_embedding_hits_,
127   // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
128   // no_more_hit_ to reflect the current state.
129   //
130   // Returns:
131   //   - a const pointer to the next embedding hit on success.
132   //   - nullptr, if there is no more hit for the current document, or no more
133   //     hit in general if the current document id is kInvalidDocumentId.
134   //   - Any error from posting lists.
135   libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
136 
137   // Similar to Advance(), this method advances the iterator to the next
138   // document, but it does not guarantee that the next document will have
139   // a matched embedding hit within the score range.
140   //
141   // Returns:
142   //   - OK, if it is able to advance to a new document_id.
143   //   - RESOUCE_EXHAUSTED, if we have run out of document_ids to iterate over.
144   //   - Any error from posting lists.
145   libtextclassifier3::Status AdvanceToNextUnfilteredDocument();
146 
147   // Query information
148   const PropertyProto::VectorProto& query_;  // Does not own
149 
150   // Scoring arguments
151   SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
152   std::unique_ptr<EmbeddingScorer> embedding_scorer_;
153   double score_low_;
154   double score_high_;
155 
156   // Snippet arguments
157   bool get_embedding_match_info_;
158 
159   // MatchInfo map
160   EmbeddingQueryResults::EmbeddingQueryMatchInfoMap& info_map_;  // Does not own
161 
162   // Access to embeddings index data
163   const EmbeddingIndex& embedding_index_;
164   std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
165 
166   // Cached data from the embeddings index
167   std::vector<EmbeddingHit> cached_embedding_hits_;
168   int cached_embedding_hits_idx_;
169   SectionIdMask current_allowed_sections_mask_;
170   bool no_more_hit_;
171   SchemaTypeId schema_type_id_;  // The schema type id for the current document.
172 
173   const DocumentStore& document_store_;
174   const SchemaStore& schema_store_;
175   int64_t current_time_ms_;
176   int num_advance_calls_;
177 };
178 
179 }  // namespace lib
180 }  // namespace icing
181 
182 #endif  // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
183