• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
16 #define ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
17 
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "icing/legacy/core/icing-packed-pod.h"
23 #include "icing/proto/search.pb.h"
24 #include "icing/schema/section.h"
25 #include "icing/store/document-id.h"
26 
27 namespace icing {
28 namespace lib {
29 
30 struct EmbeddingMatchInfos {
31   // A vector of semantic scores of matched embeddings.
32   std::vector<double> scores;
33 
34   struct EmbeddingMatchSectionInfo {
35     // The position of the matched embedding vector in a section relative to
36     // other vectors with the same (dimension, signature) combination. Note that
37     // this is not the universal position of the vector in the section.
38     //
39     // E.g. If a repeated vector property contains the following vectors:
40     // - vector1: [1, 2, 3] (signature = "signature1", dimension = 3)
41     // - vector2: [7, 8, 9] (signature = "signature1", dimension = 3)
42     // - vector3: [4, 5, 6, 8] (signature = "signature2", dimension = 4)
43     // - vector4: [10, 11, 12] (signature = "signature1", dimension = 3)
44     //
45     // Then the position values for each vector would be:
46     // - vector1: 0
47     // - vector2: 1
48     // - vector3: 0
49     // - vector4: 2
50     int position;
51 
52     // The section id of an embedding vector.
53     SectionId section_id;
54   } __attribute__((packed));
55   static_assert(sizeof(EmbeddingMatchSectionInfo) == 5, "");
56   static_assert(icing_is_packed_pod<EmbeddingMatchSectionInfo>::value,
57                 "go/icing-ubsan");
58 
59   // A vector of section infos on the matched embeddings. This will be nullptr
60   // if embedding match info is not enabled for this query.
61   //
62   // When non-null, section_infos must have a 1:1 mapping with the scores
63   // vector.
64   std::unique_ptr<std::vector<EmbeddingMatchSectionInfo>> section_infos;
65 
66   EmbeddingMatchInfos() = default;
67 
68   EmbeddingMatchInfos(const EmbeddingMatchInfos& other) = delete;
69   EmbeddingMatchInfos& operator=(const EmbeddingMatchInfos& other) = delete;
70 
71   // Appends a score to the scores vector.
AppendScoreEmbeddingMatchInfos72   void AppendScore(double score) { scores.push_back(score); }
73 
74   // Appends a section info to the section_infos vector, allocating if needed.
AppendSectionInfoEmbeddingMatchInfos75   void AppendSectionInfo(SectionId section_id, int position) {
76     if (!section_infos) {
77       section_infos =
78           std::make_unique<std::vector<EmbeddingMatchSectionInfo>>();
79     }
80     section_infos->push_back({.position = position, .section_id = section_id});
81   }
82 };
83 
84 // A class to store results generated from embedding queries.
85 struct EmbeddingQueryResults {
86   // Maps from DocumentId to matched embedding infos for that document.
87   // For each document, its embedding match info consists of two vectors:
88   // - The scores vector, which will be used in the advanced scoring language
89   //   to determine the results for the "this.matchedSemanticScores(...)"
90   //   function.
91   // - The section infos vector, which will be used to retrieve snippeting
92   //   MatchInfo for the embedding query.
93   using EmbeddingQueryMatchInfoMap =
94       std::unordered_map<DocumentId, EmbeddingMatchInfos>;
95 
96   // Maps from (query_vector_index, metric_type) to EmbeddingQueryMatchInfoMap.
97   std::unordered_map<
98       int, std::unordered_map<SearchSpecProto::EmbeddingQueryMetricType::Code,
99                               EmbeddingQueryMatchInfoMap>>
100       result_infos;
101 
102   // Get the MatchedInfo map for the given query_vector_index and metric_type.
103   // Returns nullptr if (query_vector_index, metric_type) does not exist in the
104   // result_scores map.
GetMatchInfoMapEmbeddingQueryResults105   const EmbeddingQueryMatchInfoMap* GetMatchInfoMap(
106       int query_vector_index,
107       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type) const {
108     // Check if a mapping exists for the query_vector_index
109     auto outer_it = result_infos.find(query_vector_index);
110     if (outer_it == result_infos.end()) {
111       return nullptr;
112     }
113     // Check if a mapping exists for the metric_type
114     auto inner_it = outer_it->second.find(metric_type);
115     if (inner_it == outer_it->second.end()) {
116       return nullptr;
117     }
118     return &inner_it->second;
119   }
120 
121   // Returns the matched infos for the given query_vector_index, metric_type,
122   // and doc_id. Returns nullptr if (query_vector_index, metric_type, doc_id)
123   // does not exist in the result_scores map.
GetMatchedInfosForDocumentEmbeddingQueryResults124   const EmbeddingMatchInfos* GetMatchedInfosForDocument(
125       int query_vector_index,
126       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
127       DocumentId doc_id) const {
128     const EmbeddingQueryMatchInfoMap* info_map =
129         GetMatchInfoMap(query_vector_index, metric_type);
130     if (info_map == nullptr) {
131       return nullptr;
132     }
133     // Check if the doc_id exists in the info_map
134     auto info_it = info_map->find(doc_id);
135     if (info_it == info_map->end()) {
136       return nullptr;
137     }
138     return &info_it->second;
139   }
140 
141   // Returns the matched scores for the given query_vector_index, metric_type,
142   // and doc_id. Returns nullptr if (query_vector_index, metric_type, doc_id)
143   // does not exist in the result_scores map.
GetMatchedScoresForDocumentEmbeddingQueryResults144   const std::vector<double>* GetMatchedScoresForDocument(
145       int query_vector_index,
146       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
147       DocumentId doc_id) const {
148     const EmbeddingMatchInfos* match_infos =
149         GetMatchedInfosForDocument(query_vector_index, metric_type, doc_id);
150     if (match_infos == nullptr) {
151       return nullptr;
152     }
153     return &match_infos->scores;
154   };
155 };
156 
157 }  // namespace lib
158 }  // namespace icing
159 
160 #endif  // ICING_INDEX_EMBED_EMBEDDING_QUERY_RESULTS_H_
161