1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/index/embed/embedding-index.h"
16
17 #include <algorithm>
18 #include <cstdint>
19 #include <cstring>
20 #include <memory>
21 #include <string>
22 #include <string_view>
23 #include <utility>
24 #include <vector>
25
26 #include "icing/text_classifier/lib3/utils/base/status.h"
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/absl_ports/canonical_errors.h"
29 #include "icing/absl_ports/str_cat.h"
30 #include "icing/file/destructible-directory.h"
31 #include "icing/file/file-backed-vector.h"
32 #include "icing/file/filesystem.h"
33 #include "icing/file/memory-mapped-file.h"
34 #include "icing/file/posting_list/flash-index-storage.h"
35 #include "icing/file/posting_list/posting-list-identifier.h"
36 #include "icing/index/embed/embedding-hit.h"
37 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
38 #include "icing/index/hit/hit.h"
39 #include "icing/store/document-id.h"
40 #include "icing/store/dynamic-trie-key-mapper.h"
41 #include "icing/store/key-mapper.h"
42 #include "icing/util/crc32.h"
43 #include "icing/util/encode-util.h"
44 #include "icing/util/logging.h"
45 #include "icing/util/status-macros.h"
46
47 namespace icing {
48 namespace lib {
49
50 namespace {
51
// Maximum size handed to the embedding hit list key mapper on creation:
// 128 MiB.
constexpr uint32_t kEmbeddingHitListMapperMaxSize = 128u * 1024u * 1024u;

// encode_util::EncodeIntToCString returns at most 5 bytes for a uint32_t, so
// every encoded dimension is padded to exactly this many bytes.
constexpr uint32_t kEncodedDimensionLength = 5u;
58
GetMetadataFilePath(std::string_view working_path)59 std::string GetMetadataFilePath(std::string_view working_path) {
60 return absl_ports::StrCat(working_path, "/metadata");
61 }
62
GetFlashIndexStorageFilePath(std::string_view working_path)63 std::string GetFlashIndexStorageFilePath(std::string_view working_path) {
64 return absl_ports::StrCat(working_path, "/flash_index_storage");
65 }
66
GetEmbeddingHitListMapperPath(std::string_view working_path)67 std::string GetEmbeddingHitListMapperPath(std::string_view working_path) {
68 return absl_ports::StrCat(working_path, "/embedding_hit_list_mapper");
69 }
70
GetEmbeddingVectorsFilePath(std::string_view working_path)71 std::string GetEmbeddingVectorsFilePath(std::string_view working_path) {
72 return absl_ports::StrCat(working_path, "/embedding_vectors");
73 }
74
75 // An injective function that maps the ordered pair (dimension, model_signature)
76 // to a string, which is used to form a key for embedding_posting_list_mapper_.
GetPostingListKey(uint32_t dimension,std::string_view model_signature)77 std::string GetPostingListKey(uint32_t dimension,
78 std::string_view model_signature) {
79 std::string encoded_dimension_str =
80 encode_util::EncodeIntToCString(dimension);
81 // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes.
82 while (encoded_dimension_str.size() < kEncodedDimensionLength) {
83 // C string cannot contain 0 bytes, so we append it using 1, just like what
84 // we do in encode_util::EncodeIntToCString.
85 //
86 // The reason that this works is because DecodeIntToString decodes a byte
87 // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded
88 // dimension that is less than 5 bytes, it means that the dimension contains
89 // unencoded leading 0x00. So here we're explicitly encoding those bytes as
90 // 0x01.
91 encoded_dimension_str.push_back(1);
92 }
93 return absl_ports::StrCat(encoded_dimension_str, model_signature);
94 }
95
GetPostingListKey(const PropertyProto::VectorProto & vector)96 std::string GetPostingListKey(const PropertyProto::VectorProto& vector) {
97 return GetPostingListKey(vector.values_size(), vector.model_signature());
98 }
99
100 } // namespace
101
102 libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>>
Create(const Filesystem * filesystem,std::string working_path)103 EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) {
104 ICING_RETURN_ERROR_IF_NULL(filesystem);
105
106 std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>(
107 new EmbeddingIndex(*filesystem, std::move(working_path)));
108 ICING_RETURN_IF_ERROR(index->Initialize());
109 return index;
110 }
111
CreateStorageDataIfNonEmpty()112 libtextclassifier3::Status EmbeddingIndex::CreateStorageDataIfNonEmpty() {
113 if (is_empty()) {
114 return libtextclassifier3::Status::OK;
115 }
116
117 ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage,
118 FlashIndexStorage::Create(
119 GetFlashIndexStorageFilePath(working_path_),
120 &filesystem_, posting_list_hit_serializer_.get()));
121 flash_index_storage_ =
122 std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
123
124 ICING_ASSIGN_OR_RETURN(
125 embedding_posting_list_mapper_,
126 DynamicTrieKeyMapper<PostingListIdentifier>::Create(
127 filesystem_, GetEmbeddingHitListMapperPath(working_path_),
128 kEmbeddingHitListMapperMaxSize));
129
130 ICING_ASSIGN_OR_RETURN(
131 embedding_vectors_,
132 FileBackedVector<float>::Create(
133 filesystem_, GetEmbeddingVectorsFilePath(working_path_),
134 MemoryMappedFile::READ_WRITE_AUTO_SYNC));
135
136 return libtextclassifier3::Status::OK;
137 }
138
MarkIndexNonEmpty()139 libtextclassifier3::Status EmbeddingIndex::MarkIndexNonEmpty() {
140 if (!is_empty()) {
141 return libtextclassifier3::Status::OK;
142 }
143 info().is_empty = false;
144 return CreateStorageDataIfNonEmpty();
145 }
146
Initialize()147 libtextclassifier3::Status EmbeddingIndex::Initialize() {
148 bool is_new = false;
149 if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) {
150 // Create working directory.
151 if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) {
152 return absl_ports::InternalError(
153 absl_ports::StrCat("Failed to create directory: ", working_path_));
154 }
155 is_new = true;
156 }
157
158 ICING_ASSIGN_OR_RETURN(
159 MemoryMappedFile metadata_mmapped_file,
160 MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_),
161 MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
162 /*max_file_size=*/kMetadataFileSize,
163 /*pre_mapping_file_offset=*/0,
164 /*pre_mapping_mmap_size=*/kMetadataFileSize));
165 metadata_mmapped_file_ =
166 std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file));
167
168 if (is_new) {
169 ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary(
170 /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize));
171 info().magic = Info::kMagic;
172 info().last_added_document_id = kInvalidDocumentId;
173 info().is_empty = true;
174 memset(Info().padding_, 0, Info::kPaddingSize);
175 ICING_RETURN_IF_ERROR(InitializeNewStorage());
176 } else {
177 if (metadata_mmapped_file_->available_size() != kMetadataFileSize) {
178 return absl_ports::FailedPreconditionError(
179 "Incorrect metadata file size");
180 }
181 if (info().magic != Info::kMagic) {
182 return absl_ports::FailedPreconditionError("Incorrect magic value");
183 }
184 ICING_RETURN_IF_ERROR(CreateStorageDataIfNonEmpty());
185 ICING_RETURN_IF_ERROR(InitializeExistingStorage());
186 }
187 return libtextclassifier3::Status::OK;
188 }
189
Clear()190 libtextclassifier3::Status EmbeddingIndex::Clear() {
191 pending_embedding_hits_.clear();
192 metadata_mmapped_file_.reset();
193 flash_index_storage_.reset();
194 embedding_posting_list_mapper_.reset();
195 embedding_vectors_.reset();
196 if (filesystem_.DirectoryExists(working_path_.c_str())) {
197 ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_));
198 }
199 is_initialized_ = false;
200 return Initialize();
201 }
202
203 libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
GetAccessor(uint32_t dimension,std::string_view model_signature) const204 EmbeddingIndex::GetAccessor(uint32_t dimension,
205 std::string_view model_signature) const {
206 if (dimension == 0) {
207 return absl_ports::InvalidArgumentError("Dimension is 0");
208 }
209 if (is_empty()) {
210 return absl_ports::NotFoundError("EmbeddingIndex is empty");
211 }
212
213 std::string key = GetPostingListKey(dimension, model_signature);
214 ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id,
215 embedding_posting_list_mapper_->Get(key));
216 return PostingListEmbeddingHitAccessor::CreateFromExisting(
217 flash_index_storage_.get(), posting_list_hit_serializer_.get(),
218 posting_list_id);
219 }
220
BufferEmbedding(const BasicHit & basic_hit,const PropertyProto::VectorProto & vector)221 libtextclassifier3::Status EmbeddingIndex::BufferEmbedding(
222 const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) {
223 if (vector.values_size() == 0) {
224 return absl_ports::InvalidArgumentError("Vector dimension is 0");
225 }
226 ICING_RETURN_IF_ERROR(MarkIndexNonEmpty());
227
228 uint32_t location = embedding_vectors_->num_elements();
229 uint32_t dimension = vector.values_size();
230 std::string key = GetPostingListKey(vector);
231
232 // Buffer the embedding hit.
233 pending_embedding_hits_.push_back(
234 {std::move(key), EmbeddingHit(basic_hit, location)});
235
236 // Put vector
237 ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr,
238 embedding_vectors_->Allocate(dimension));
239 mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension);
240
241 return libtextclassifier3::Status::OK;
242 }
243
CommitBufferToIndex()244 libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() {
245 if (pending_embedding_hits_.empty()) {
246 return libtextclassifier3::Status::OK;
247 }
248 ICING_RETURN_IF_ERROR(MarkIndexNonEmpty());
249
250 std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end());
251 auto iter_curr_key = pending_embedding_hits_.rbegin();
252 while (iter_curr_key != pending_embedding_hits_.rend()) {
253 // In order to batch putting embedding hits with the same key (dimension,
254 // model_signature) to the same posting list, we find the range
255 // [iter_curr_key, iter_next_key) of embedding hits with the same key and
256 // put them into their corresponding posting list together.
257 auto iter_next_key = iter_curr_key;
258 while (iter_next_key != pending_embedding_hits_.rend() &&
259 iter_next_key->first == iter_curr_key->first) {
260 iter_next_key++;
261 }
262
263 const std::string& key = iter_curr_key->first;
264 libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or =
265 embedding_posting_list_mapper_->Get(key);
266 std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
267 if (posting_list_id_or.ok()) {
268 // Existing posting list.
269 ICING_ASSIGN_OR_RETURN(
270 pl_accessor,
271 PostingListEmbeddingHitAccessor::CreateFromExisting(
272 flash_index_storage_.get(), posting_list_hit_serializer_.get(),
273 posting_list_id_or.ValueOrDie()));
274 } else if (absl_ports::IsNotFound(posting_list_id_or.status())) {
275 // New posting list.
276 ICING_ASSIGN_OR_RETURN(
277 pl_accessor,
278 PostingListEmbeddingHitAccessor::Create(
279 flash_index_storage_.get(), posting_list_hit_serializer_.get()));
280 } else {
281 // Errors
282 return std::move(posting_list_id_or).status();
283 }
284
285 // Adding the embedding hits.
286 for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) {
287 ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second));
288 }
289
290 // Finalize this posting list and add the posting list id in
291 // embedding_posting_list_mapper_.
292 PostingListEmbeddingHitAccessor::FinalizeResult result =
293 std::move(*pl_accessor).Finalize();
294 if (!result.id.is_valid()) {
295 return absl_ports::InternalError("Failed to finalize posting list");
296 }
297 ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id));
298
299 // Advance to the next key.
300 iter_curr_key = iter_next_key;
301 }
302 pending_embedding_hits_.clear();
303 return libtextclassifier3::Status::OK;
304 }
305
TransferIndex(const std::vector<DocumentId> & document_id_old_to_new,EmbeddingIndex * new_index) const306 libtextclassifier3::Status EmbeddingIndex::TransferIndex(
307 const std::vector<DocumentId>& document_id_old_to_new,
308 EmbeddingIndex* new_index) const {
309 if (is_empty()) {
310 return absl_ports::FailedPreconditionError("EmbeddingIndex is empty");
311 }
312
313 std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr =
314 embedding_posting_list_mapper_->GetIterator();
315 while (itr->Advance()) {
316 std::string_view key = itr->GetKey();
317 // This should never happen unless there is an inconsistency, or the index
318 // is corrupted.
319 if (key.size() < kEncodedDimensionLength) {
320 return absl_ports::InternalError(
321 "Got invalid key from embedding posting list mapper.");
322 }
323 uint32_t dimension = encode_util::DecodeIntFromCString(
324 std::string_view(key.begin(), kEncodedDimensionLength));
325
326 // Transfer hits
327 std::vector<EmbeddingHit> new_hits;
328 ICING_ASSIGN_OR_RETURN(
329 std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor,
330 PostingListEmbeddingHitAccessor::CreateFromExisting(
331 flash_index_storage_.get(), posting_list_hit_serializer_.get(),
332 /*existing_posting_list_id=*/itr->GetValue()));
333 while (true) {
334 ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
335 old_pl_accessor->GetNextHitsBatch());
336 if (batch.empty()) {
337 break;
338 }
339 for (EmbeddingHit& old_hit : batch) {
340 // Safety checks to add robustness to the codebase, so to make sure
341 // that we never access invalid memory, in case that hit from the
342 // posting list is corrupted.
343 ICING_ASSIGN_OR_RETURN(const float* old_vector,
344 GetEmbeddingVector(old_hit, dimension));
345 if (old_hit.basic_hit().document_id() < 0 ||
346 old_hit.basic_hit().document_id() >=
347 document_id_old_to_new.size()) {
348 return absl_ports::InternalError(
349 "Embedding hit document id is out of bound. The provided map is "
350 "too small, or the index may have been corrupted.");
351 }
352
353 // Construct transferred hit
354 DocumentId new_document_id =
355 document_id_old_to_new[old_hit.basic_hit().document_id()];
356 if (new_document_id == kInvalidDocumentId) {
357 continue;
358 }
359 ICING_RETURN_IF_ERROR(new_index->MarkIndexNonEmpty());
360 uint32_t new_location = new_index->embedding_vectors_->num_elements();
361 new_hits.push_back(EmbeddingHit(
362 BasicHit(old_hit.basic_hit().section_id(), new_document_id),
363 new_location));
364
365 // Copy the embedding vector of the hit to the new index.
366 ICING_ASSIGN_OR_RETURN(
367 FileBackedVector<float>::MutableArrayView mutable_arr,
368 new_index->embedding_vectors_->Allocate(dimension));
369 mutable_arr.SetArray(/*idx=*/0, old_vector, dimension);
370 }
371 }
372 // No hit needs to be added to the new index.
373 if (new_hits.empty()) {
374 continue;
375 }
376 // Add transferred hits to the new index.
377 ICING_ASSIGN_OR_RETURN(
378 std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum,
379 PostingListEmbeddingHitAccessor::Create(
380 new_index->flash_index_storage_.get(),
381 new_index->posting_list_hit_serializer_.get()));
382 for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend();
383 ++new_hit_itr) {
384 ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr));
385 }
386 PostingListEmbeddingHitAccessor::FinalizeResult result =
387 std::move(*hit_accum).Finalize();
388 if (!result.id.is_valid()) {
389 return absl_ports::InternalError("Failed to finalize posting list");
390 }
391 ICING_RETURN_IF_ERROR(
392 new_index->embedding_posting_list_mapper_->Put(key, result.id));
393 }
394 return libtextclassifier3::Status::OK;
395 }
396
Optimize(const std::vector<DocumentId> & document_id_old_to_new,DocumentId new_last_added_document_id)397 libtextclassifier3::Status EmbeddingIndex::Optimize(
398 const std::vector<DocumentId>& document_id_old_to_new,
399 DocumentId new_last_added_document_id) {
400 if (is_empty()) {
401 info().last_added_document_id = new_last_added_document_id;
402 return libtextclassifier3::Status::OK;
403 }
404
405 // This is just for completeness, but this should never be necessary, since we
406 // should never have pending hits at the time when Optimize is run.
407 ICING_RETURN_IF_ERROR(CommitBufferToIndex());
408
409 std::string temporary_index_working_path = working_path_ + "_temp";
410 if (!filesystem_.DeleteDirectoryRecursively(
411 temporary_index_working_path.c_str())) {
412 ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path;
413 return absl_ports::InternalError(
414 "Unable to delete temp directory to prepare to build new index.");
415 }
416
417 DestructibleDirectory temporary_index_dir(
418 &filesystem_, std::move(temporary_index_working_path));
419 if (!temporary_index_dir.is_valid()) {
420 return absl_ports::InternalError(
421 "Unable to create temp directory to build new index.");
422 }
423
424 {
425 ICING_ASSIGN_OR_RETURN(
426 std::unique_ptr<EmbeddingIndex> new_index,
427 EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir()));
428 ICING_RETURN_IF_ERROR(
429 TransferIndex(document_id_old_to_new, new_index.get()));
430 new_index->set_last_added_document_id(new_last_added_document_id);
431 ICING_RETURN_IF_ERROR(new_index->PersistToDisk());
432 }
433
434 // Destruct current storage instances to safely swap directories.
435 metadata_mmapped_file_.reset();
436 flash_index_storage_.reset();
437 embedding_posting_list_mapper_.reset();
438 embedding_vectors_.reset();
439
440 if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(),
441 working_path_.c_str())) {
442 return absl_ports::InternalError(
443 "Unable to apply new index due to failed swap!");
444 }
445
446 // Reinitialize the index.
447 is_initialized_ = false;
448 return Initialize();
449 }
450
PersistMetadataToDisk(bool force)451 libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) {
452 return metadata_mmapped_file_->PersistToDisk();
453 }
454
PersistStoragesToDisk(bool force)455 libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) {
456 if (is_empty()) {
457 return libtextclassifier3::Status::OK;
458 }
459
460 if (!flash_index_storage_->PersistToDisk()) {
461 return absl_ports::InternalError("Fail to persist flash index to disk");
462 }
463 ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk());
464 ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk());
465 return libtextclassifier3::Status::OK;
466 }
467
ComputeInfoChecksum(bool force)468 libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum(
469 bool force) {
470 return info().ComputeChecksum();
471 }
472
ComputeStoragesChecksum(bool force)473 libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum(
474 bool force) {
475 if (is_empty()) {
476 return Crc32(0);
477 }
478 ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc,
479 embedding_posting_list_mapper_->ComputeChecksum());
480 ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc,
481 embedding_vectors_->ComputeChecksum());
482 return Crc32(embedding_posting_list_mapper_crc.Get() ^
483 embedding_vectors_crc.Get());
484 }
485
486 } // namespace lib
487 } // namespace icing
488