• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/index/embed/embedding-index.h"
16 
17 #include <algorithm>
18 #include <cstdint>
19 #include <cstring>
20 #include <memory>
21 #include <string>
22 #include <string_view>
23 #include <utility>
24 #include <vector>
25 
26 #include "icing/text_classifier/lib3/utils/base/status.h"
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/absl_ports/canonical_errors.h"
29 #include "icing/absl_ports/str_cat.h"
30 #include "icing/file/destructible-directory.h"
31 #include "icing/file/file-backed-vector.h"
32 #include "icing/file/filesystem.h"
33 #include "icing/file/memory-mapped-file.h"
34 #include "icing/file/posting_list/flash-index-storage.h"
35 #include "icing/file/posting_list/posting-list-identifier.h"
36 #include "icing/index/embed/embedding-hit.h"
37 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
38 #include "icing/index/hit/hit.h"
39 #include "icing/store/document-id.h"
40 #include "icing/store/dynamic-trie-key-mapper.h"
41 #include "icing/store/key-mapper.h"
42 #include "icing/util/crc32.h"
43 #include "icing/util/encode-util.h"
44 #include "icing/util/logging.h"
45 #include "icing/util/status-macros.h"
46 
47 namespace icing {
48 namespace lib {
49 
50 namespace {
51 
// Size limit handed to the embedding hit list mapper on creation: 128 MiB.
constexpr uint32_t kEmbeddingHitListMapperMaxSize = 128u * 1024u * 1024u;

// encode_util::EncodeIntToCString returns at most 5 bytes for a uint32_t, so
// every encoded dimension is padded to exactly this many bytes.
constexpr uint32_t kEncodedDimensionLength = 5u;
58 
// Returns the path of the metadata file located under working_path.
std::string GetMetadataFilePath(std::string_view working_path) {
  std::string metadata_path(working_path);
  metadata_path += "/metadata";
  return metadata_path;
}
62 
// Returns the path of the flash index storage file located under
// working_path.
std::string GetFlashIndexStorageFilePath(std::string_view working_path) {
  std::string storage_path(working_path);
  storage_path += "/flash_index_storage";
  return storage_path;
}
66 
// Returns the path used by the embedding hit list mapper under working_path.
std::string GetEmbeddingHitListMapperPath(std::string_view working_path) {
  std::string mapper_path(working_path);
  mapper_path += "/embedding_hit_list_mapper";
  return mapper_path;
}
70 
// Returns the path of the embedding vectors file located under working_path.
std::string GetEmbeddingVectorsFilePath(std::string_view working_path) {
  std::string vectors_path(working_path);
  vectors_path += "/embedding_vectors";
  return vectors_path;
}
74 
75 // An injective function that maps the ordered pair (dimension, model_signature)
76 // to a string, which is used to form a key for embedding_posting_list_mapper_.
GetPostingListKey(uint32_t dimension,std::string_view model_signature)77 std::string GetPostingListKey(uint32_t dimension,
78                               std::string_view model_signature) {
79   std::string encoded_dimension_str =
80       encode_util::EncodeIntToCString(dimension);
81   // Make encoded_dimension_str to fixed kEncodedDimensionLength bytes.
82   while (encoded_dimension_str.size() < kEncodedDimensionLength) {
83     // C string cannot contain 0 bytes, so we append it using 1, just like what
84     // we do in encode_util::EncodeIntToCString.
85     //
86     // The reason that this works is because DecodeIntToString decodes a byte
87     // value of 0x01 as 0x00. When EncodeIntToCString returns an encoded
88     // dimension that is less than 5 bytes, it means that the dimension contains
89     // unencoded leading 0x00. So here we're explicitly encoding those bytes as
90     // 0x01.
91     encoded_dimension_str.push_back(1);
92   }
93   return absl_ports::StrCat(encoded_dimension_str, model_signature);
94 }
95 
GetPostingListKey(const PropertyProto::VectorProto & vector)96 std::string GetPostingListKey(const PropertyProto::VectorProto& vector) {
97   return GetPostingListKey(vector.values_size(), vector.model_signature());
98 }
99 
100 }  // namespace
101 
102 libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>>
Create(const Filesystem * filesystem,std::string working_path)103 EmbeddingIndex::Create(const Filesystem* filesystem, std::string working_path) {
104   ICING_RETURN_ERROR_IF_NULL(filesystem);
105 
106   std::unique_ptr<EmbeddingIndex> index = std::unique_ptr<EmbeddingIndex>(
107       new EmbeddingIndex(*filesystem, std::move(working_path)));
108   ICING_RETURN_IF_ERROR(index->Initialize());
109   return index;
110 }
111 
CreateStorageDataIfNonEmpty()112 libtextclassifier3::Status EmbeddingIndex::CreateStorageDataIfNonEmpty() {
113   if (is_empty()) {
114     return libtextclassifier3::Status::OK;
115   }
116 
117   ICING_ASSIGN_OR_RETURN(FlashIndexStorage flash_index_storage,
118                          FlashIndexStorage::Create(
119                              GetFlashIndexStorageFilePath(working_path_),
120                              &filesystem_, posting_list_hit_serializer_.get()));
121   flash_index_storage_ =
122       std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
123 
124   ICING_ASSIGN_OR_RETURN(
125       embedding_posting_list_mapper_,
126       DynamicTrieKeyMapper<PostingListIdentifier>::Create(
127           filesystem_, GetEmbeddingHitListMapperPath(working_path_),
128           kEmbeddingHitListMapperMaxSize));
129 
130   ICING_ASSIGN_OR_RETURN(
131       embedding_vectors_,
132       FileBackedVector<float>::Create(
133           filesystem_, GetEmbeddingVectorsFilePath(working_path_),
134           MemoryMappedFile::READ_WRITE_AUTO_SYNC));
135 
136   return libtextclassifier3::Status::OK;
137 }
138 
MarkIndexNonEmpty()139 libtextclassifier3::Status EmbeddingIndex::MarkIndexNonEmpty() {
140   if (!is_empty()) {
141     return libtextclassifier3::Status::OK;
142   }
143   info().is_empty = false;
144   return CreateStorageDataIfNonEmpty();
145 }
146 
Initialize()147 libtextclassifier3::Status EmbeddingIndex::Initialize() {
148   bool is_new = false;
149   if (!filesystem_.FileExists(GetMetadataFilePath(working_path_).c_str())) {
150     // Create working directory.
151     if (!filesystem_.CreateDirectoryRecursively(working_path_.c_str())) {
152       return absl_ports::InternalError(
153           absl_ports::StrCat("Failed to create directory: ", working_path_));
154     }
155     is_new = true;
156   }
157 
158   ICING_ASSIGN_OR_RETURN(
159       MemoryMappedFile metadata_mmapped_file,
160       MemoryMappedFile::Create(filesystem_, GetMetadataFilePath(working_path_),
161                                MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
162                                /*max_file_size=*/kMetadataFileSize,
163                                /*pre_mapping_file_offset=*/0,
164                                /*pre_mapping_mmap_size=*/kMetadataFileSize));
165   metadata_mmapped_file_ =
166       std::make_unique<MemoryMappedFile>(std::move(metadata_mmapped_file));
167 
168   if (is_new) {
169     ICING_RETURN_IF_ERROR(metadata_mmapped_file_->GrowAndRemapIfNecessary(
170         /*file_offset=*/0, /*mmap_size=*/kMetadataFileSize));
171     info().magic = Info::kMagic;
172     info().last_added_document_id = kInvalidDocumentId;
173     info().is_empty = true;
174     memset(Info().padding_, 0, Info::kPaddingSize);
175     ICING_RETURN_IF_ERROR(InitializeNewStorage());
176   } else {
177     if (metadata_mmapped_file_->available_size() != kMetadataFileSize) {
178       return absl_ports::FailedPreconditionError(
179           "Incorrect metadata file size");
180     }
181     if (info().magic != Info::kMagic) {
182       return absl_ports::FailedPreconditionError("Incorrect magic value");
183     }
184     ICING_RETURN_IF_ERROR(CreateStorageDataIfNonEmpty());
185     ICING_RETURN_IF_ERROR(InitializeExistingStorage());
186   }
187   return libtextclassifier3::Status::OK;
188 }
189 
Clear()190 libtextclassifier3::Status EmbeddingIndex::Clear() {
191   pending_embedding_hits_.clear();
192   metadata_mmapped_file_.reset();
193   flash_index_storage_.reset();
194   embedding_posting_list_mapper_.reset();
195   embedding_vectors_.reset();
196   if (filesystem_.DirectoryExists(working_path_.c_str())) {
197     ICING_RETURN_IF_ERROR(Discard(filesystem_, working_path_));
198   }
199   is_initialized_ = false;
200   return Initialize();
201 }
202 
203 libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
GetAccessor(uint32_t dimension,std::string_view model_signature) const204 EmbeddingIndex::GetAccessor(uint32_t dimension,
205                             std::string_view model_signature) const {
206   if (dimension == 0) {
207     return absl_ports::InvalidArgumentError("Dimension is 0");
208   }
209   if (is_empty()) {
210     return absl_ports::NotFoundError("EmbeddingIndex is empty");
211   }
212 
213   std::string key = GetPostingListKey(dimension, model_signature);
214   ICING_ASSIGN_OR_RETURN(PostingListIdentifier posting_list_id,
215                          embedding_posting_list_mapper_->Get(key));
216   return PostingListEmbeddingHitAccessor::CreateFromExisting(
217       flash_index_storage_.get(), posting_list_hit_serializer_.get(),
218       posting_list_id);
219 }
220 
BufferEmbedding(const BasicHit & basic_hit,const PropertyProto::VectorProto & vector)221 libtextclassifier3::Status EmbeddingIndex::BufferEmbedding(
222     const BasicHit& basic_hit, const PropertyProto::VectorProto& vector) {
223   if (vector.values_size() == 0) {
224     return absl_ports::InvalidArgumentError("Vector dimension is 0");
225   }
226   ICING_RETURN_IF_ERROR(MarkIndexNonEmpty());
227 
228   uint32_t location = embedding_vectors_->num_elements();
229   uint32_t dimension = vector.values_size();
230   std::string key = GetPostingListKey(vector);
231 
232   // Buffer the embedding hit.
233   pending_embedding_hits_.push_back(
234       {std::move(key), EmbeddingHit(basic_hit, location)});
235 
236   // Put vector
237   ICING_ASSIGN_OR_RETURN(FileBackedVector<float>::MutableArrayView mutable_arr,
238                          embedding_vectors_->Allocate(dimension));
239   mutable_arr.SetArray(/*idx=*/0, vector.values().data(), dimension);
240 
241   return libtextclassifier3::Status::OK;
242 }
243 
CommitBufferToIndex()244 libtextclassifier3::Status EmbeddingIndex::CommitBufferToIndex() {
245   if (pending_embedding_hits_.empty()) {
246     return libtextclassifier3::Status::OK;
247   }
248   ICING_RETURN_IF_ERROR(MarkIndexNonEmpty());
249 
250   std::sort(pending_embedding_hits_.begin(), pending_embedding_hits_.end());
251   auto iter_curr_key = pending_embedding_hits_.rbegin();
252   while (iter_curr_key != pending_embedding_hits_.rend()) {
253     // In order to batch putting embedding hits with the same key (dimension,
254     // model_signature) to the same posting list, we find the range
255     // [iter_curr_key, iter_next_key) of embedding hits with the same key and
256     // put them into their corresponding posting list together.
257     auto iter_next_key = iter_curr_key;
258     while (iter_next_key != pending_embedding_hits_.rend() &&
259            iter_next_key->first == iter_curr_key->first) {
260       iter_next_key++;
261     }
262 
263     const std::string& key = iter_curr_key->first;
264     libtextclassifier3::StatusOr<PostingListIdentifier> posting_list_id_or =
265         embedding_posting_list_mapper_->Get(key);
266     std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor;
267     if (posting_list_id_or.ok()) {
268       // Existing posting list.
269       ICING_ASSIGN_OR_RETURN(
270           pl_accessor,
271           PostingListEmbeddingHitAccessor::CreateFromExisting(
272               flash_index_storage_.get(), posting_list_hit_serializer_.get(),
273               posting_list_id_or.ValueOrDie()));
274     } else if (absl_ports::IsNotFound(posting_list_id_or.status())) {
275       // New posting list.
276       ICING_ASSIGN_OR_RETURN(
277           pl_accessor,
278           PostingListEmbeddingHitAccessor::Create(
279               flash_index_storage_.get(), posting_list_hit_serializer_.get()));
280     } else {
281       // Errors
282       return std::move(posting_list_id_or).status();
283     }
284 
285     // Adding the embedding hits.
286     for (auto iter = iter_curr_key; iter != iter_next_key; ++iter) {
287       ICING_RETURN_IF_ERROR(pl_accessor->PrependHit(iter->second));
288     }
289 
290     // Finalize this posting list and add the posting list id in
291     // embedding_posting_list_mapper_.
292     PostingListEmbeddingHitAccessor::FinalizeResult result =
293         std::move(*pl_accessor).Finalize();
294     if (!result.id.is_valid()) {
295       return absl_ports::InternalError("Failed to finalize posting list");
296     }
297     ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->Put(key, result.id));
298 
299     // Advance to the next key.
300     iter_curr_key = iter_next_key;
301   }
302   pending_embedding_hits_.clear();
303   return libtextclassifier3::Status::OK;
304 }
305 
TransferIndex(const std::vector<DocumentId> & document_id_old_to_new,EmbeddingIndex * new_index) const306 libtextclassifier3::Status EmbeddingIndex::TransferIndex(
307     const std::vector<DocumentId>& document_id_old_to_new,
308     EmbeddingIndex* new_index) const {
309   if (is_empty()) {
310     return absl_ports::FailedPreconditionError("EmbeddingIndex is empty");
311   }
312 
313   std::unique_ptr<KeyMapper<PostingListIdentifier>::Iterator> itr =
314       embedding_posting_list_mapper_->GetIterator();
315   while (itr->Advance()) {
316     std::string_view key = itr->GetKey();
317     // This should never happen unless there is an inconsistency, or the index
318     // is corrupted.
319     if (key.size() < kEncodedDimensionLength) {
320       return absl_ports::InternalError(
321           "Got invalid key from embedding posting list mapper.");
322     }
323     uint32_t dimension = encode_util::DecodeIntFromCString(
324         std::string_view(key.begin(), kEncodedDimensionLength));
325 
326     // Transfer hits
327     std::vector<EmbeddingHit> new_hits;
328     ICING_ASSIGN_OR_RETURN(
329         std::unique_ptr<PostingListEmbeddingHitAccessor> old_pl_accessor,
330         PostingListEmbeddingHitAccessor::CreateFromExisting(
331             flash_index_storage_.get(), posting_list_hit_serializer_.get(),
332             /*existing_posting_list_id=*/itr->GetValue()));
333     while (true) {
334       ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
335                              old_pl_accessor->GetNextHitsBatch());
336       if (batch.empty()) {
337         break;
338       }
339       for (EmbeddingHit& old_hit : batch) {
340         // Safety checks to add robustness to the codebase, so to make sure
341         // that we never access invalid memory, in case that hit from the
342         // posting list is corrupted.
343         ICING_ASSIGN_OR_RETURN(const float* old_vector,
344                                GetEmbeddingVector(old_hit, dimension));
345         if (old_hit.basic_hit().document_id() < 0 ||
346             old_hit.basic_hit().document_id() >=
347                 document_id_old_to_new.size()) {
348           return absl_ports::InternalError(
349               "Embedding hit document id is out of bound. The provided map is "
350               "too small, or the index may have been corrupted.");
351         }
352 
353         // Construct transferred hit
354         DocumentId new_document_id =
355             document_id_old_to_new[old_hit.basic_hit().document_id()];
356         if (new_document_id == kInvalidDocumentId) {
357           continue;
358         }
359         ICING_RETURN_IF_ERROR(new_index->MarkIndexNonEmpty());
360         uint32_t new_location = new_index->embedding_vectors_->num_elements();
361         new_hits.push_back(EmbeddingHit(
362             BasicHit(old_hit.basic_hit().section_id(), new_document_id),
363             new_location));
364 
365         // Copy the embedding vector of the hit to the new index.
366         ICING_ASSIGN_OR_RETURN(
367             FileBackedVector<float>::MutableArrayView mutable_arr,
368             new_index->embedding_vectors_->Allocate(dimension));
369         mutable_arr.SetArray(/*idx=*/0, old_vector, dimension);
370       }
371     }
372     // No hit needs to be added to the new index.
373     if (new_hits.empty()) {
374       continue;
375     }
376     // Add transferred hits to the new index.
377     ICING_ASSIGN_OR_RETURN(
378         std::unique_ptr<PostingListEmbeddingHitAccessor> hit_accum,
379         PostingListEmbeddingHitAccessor::Create(
380             new_index->flash_index_storage_.get(),
381             new_index->posting_list_hit_serializer_.get()));
382     for (auto new_hit_itr = new_hits.rbegin(); new_hit_itr != new_hits.rend();
383          ++new_hit_itr) {
384       ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*new_hit_itr));
385     }
386     PostingListEmbeddingHitAccessor::FinalizeResult result =
387         std::move(*hit_accum).Finalize();
388     if (!result.id.is_valid()) {
389       return absl_ports::InternalError("Failed to finalize posting list");
390     }
391     ICING_RETURN_IF_ERROR(
392         new_index->embedding_posting_list_mapper_->Put(key, result.id));
393   }
394   return libtextclassifier3::Status::OK;
395 }
396 
Optimize(const std::vector<DocumentId> & document_id_old_to_new,DocumentId new_last_added_document_id)397 libtextclassifier3::Status EmbeddingIndex::Optimize(
398     const std::vector<DocumentId>& document_id_old_to_new,
399     DocumentId new_last_added_document_id) {
400   if (is_empty()) {
401     info().last_added_document_id = new_last_added_document_id;
402     return libtextclassifier3::Status::OK;
403   }
404 
405   // This is just for completeness, but this should never be necessary, since we
406   // should never have pending hits at the time when Optimize is run.
407   ICING_RETURN_IF_ERROR(CommitBufferToIndex());
408 
409   std::string temporary_index_working_path = working_path_ + "_temp";
410   if (!filesystem_.DeleteDirectoryRecursively(
411           temporary_index_working_path.c_str())) {
412     ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_working_path;
413     return absl_ports::InternalError(
414         "Unable to delete temp directory to prepare to build new index.");
415   }
416 
417   DestructibleDirectory temporary_index_dir(
418       &filesystem_, std::move(temporary_index_working_path));
419   if (!temporary_index_dir.is_valid()) {
420     return absl_ports::InternalError(
421         "Unable to create temp directory to build new index.");
422   }
423 
424   {
425     ICING_ASSIGN_OR_RETURN(
426         std::unique_ptr<EmbeddingIndex> new_index,
427         EmbeddingIndex::Create(&filesystem_, temporary_index_dir.dir()));
428     ICING_RETURN_IF_ERROR(
429         TransferIndex(document_id_old_to_new, new_index.get()));
430     new_index->set_last_added_document_id(new_last_added_document_id);
431     ICING_RETURN_IF_ERROR(new_index->PersistToDisk());
432   }
433 
434   // Destruct current storage instances to safely swap directories.
435   metadata_mmapped_file_.reset();
436   flash_index_storage_.reset();
437   embedding_posting_list_mapper_.reset();
438   embedding_vectors_.reset();
439 
440   if (!filesystem_.SwapFiles(temporary_index_dir.dir().c_str(),
441                              working_path_.c_str())) {
442     return absl_ports::InternalError(
443         "Unable to apply new index due to failed swap!");
444   }
445 
446   // Reinitialize the index.
447   is_initialized_ = false;
448   return Initialize();
449 }
450 
PersistMetadataToDisk(bool force)451 libtextclassifier3::Status EmbeddingIndex::PersistMetadataToDisk(bool force) {
452   return metadata_mmapped_file_->PersistToDisk();
453 }
454 
PersistStoragesToDisk(bool force)455 libtextclassifier3::Status EmbeddingIndex::PersistStoragesToDisk(bool force) {
456   if (is_empty()) {
457     return libtextclassifier3::Status::OK;
458   }
459 
460   if (!flash_index_storage_->PersistToDisk()) {
461     return absl_ports::InternalError("Fail to persist flash index to disk");
462   }
463   ICING_RETURN_IF_ERROR(embedding_posting_list_mapper_->PersistToDisk());
464   ICING_RETURN_IF_ERROR(embedding_vectors_->PersistToDisk());
465   return libtextclassifier3::Status::OK;
466 }
467 
ComputeInfoChecksum(bool force)468 libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeInfoChecksum(
469     bool force) {
470   return info().ComputeChecksum();
471 }
472 
ComputeStoragesChecksum(bool force)473 libtextclassifier3::StatusOr<Crc32> EmbeddingIndex::ComputeStoragesChecksum(
474     bool force) {
475   if (is_empty()) {
476     return Crc32(0);
477   }
478   ICING_ASSIGN_OR_RETURN(Crc32 embedding_posting_list_mapper_crc,
479                          embedding_posting_list_mapper_->ComputeChecksum());
480   ICING_ASSIGN_OR_RETURN(Crc32 embedding_vectors_crc,
481                          embedding_vectors_->ComputeChecksum());
482   return Crc32(embedding_posting_list_mapper_crc.Get() ^
483                embedding_vectors_crc.Get());
484 }
485 
486 }  // namespace lib
487 }  // namespace icing
488