1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef aMINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_STORAGE_ABSTRACT_EMBEDDING_STORAGE_H_ 18 #define aMINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_STORAGE_ABSTRACT_EMBEDDING_STORAGE_H_ 19 20 #include <memory> 21 #include <vector> 22 23 #include "include/backend/distributed/persistent/storage/storage.h" 24 #include "include/backend/device_address.h" 25 26 namespace mindspore { 27 namespace distributed { 28 namespace storage { 29 using mindspore::device::DeviceAddress; 30 constexpr size_t kDefaultSliceSizeInMB = 1024; 31 32 /** 33 * @brief AbstractEmbeddingStorage is encapsulated within the Huge Embedding Table's lookup and update interface. It 34 * supports embeddingstorage query and modification of Embeddings, interaction between the host cache(for hot spot data) 35 * and persistent storage (for non-hot spot data), and preferential access to Embeddings in the host cache. If the 36 * corresponding element cannot be found in the host cache, then read the element from the persistent storage. 37 * Otherwise, if the host cache has insufficient space, the expired elements will automatically be evicted the to the 38 * persistent storage. 39 */ 40 class AbstractEmbeddingStorage { 41 public: 42 AbstractEmbeddingStorage() = default; 43 virtual ~AbstractEmbeddingStorage() = default; 44 45 /** 46 * @brief Initialize the embedding storage. 47 * @param[in] `device_address`: The device address of the Embedding Table corresponding to the 48 * AbstractEmbeddingStorage. 49 */ 50 virtual void Initialize(const DeviceAddress *device_address) = 0; 51 52 /** 53 * @brief Finalize the AbstractEmbeddingStorage, release allocated resource. 54 */ 55 virtual void Finalize() = 0; 56 57 /** 58 * @brief Batch embeddings lookup operation. 59 * Query Embeddings in the host cache first, if the corresponding element cannot be found in the host cache, then read 60 * the element from the persistent storage and insert host cache. 61 * Access an element of the cache generally affects the location or order of the elements in the cache, depending 62 * on different cache strategies. 63 * @param[in] `keys`: All keys which need to query, containing data pointer and data buffer length. 64 * @param[out] `values`: The output embeddings, containing data pointer and data buffer length. 65 * @return Whether the function was successfully executed. 66 */ 67 virtual bool Get(const ConstDataWithLen &keys, const DataWithLen &values) = 0; 68 69 /** 70 * @brief Batch embeddings update/insert operation. 71 * Update/Insert Embeddings in the host cache first, if the host cache has insufficient space, the expired elements 72 * will automatically be evicted the to the persistent storage. 73 * Update or Insert an element of the cache generally affects the location or order of the elements in the cache, 74 * depending on different cache strategies. 75 * @param[in] `keys`: All keys whose emebedding need to update, containing data pointer and data buffer length. 76 * @param[in] `values`: Embeddings corresponding to all keys need to be updated, containing data pointer and data 77 * buffer length. 78 * @return Whether the function was successfully executed. 79 */ 80 virtual bool Put(const ConstDataWithLen &keys, const ConstDataWithLen &values) = 0; 81 82 /** 83 * @brief To export a slice from the storage, the size is specified by the parameter 'slice_size_in_mega_bytes' in MB. 84 * The default value of 1024 means that the value of the exported slice occupies 1024MB of memory. 85 * @param[in] `incremental`: Determine whether export in incremental or full manner, true 86 * for incremental export, false for full export 87 * @param[out] `last_slice`: A bool is returned to indicate whether the slice by export is the last slice, that is, 88 * the export is complete. 89 * @param[in] `slice_size_in_mega_bytes`: Assign host memory in MB that the value of the exported slice occupies, 90 * default 1024MB. 91 * @return The byte sequence of export data. 92 */ 93 virtual std::vector<std::shared_ptr<std::vector<char>>> ExportSlice( 94 bool incremental, bool *last_slice, size_t slice_size_in_mega_bytes = kDefaultSliceSizeInMB) = 0; 95 }; 96 } // namespace storage 97 } // namespace distributed 98 } // namespace mindspore 99 #endif // aMINDSPORE_CCSRC_DISTRIBUTED_EMBEDDING_CACHE_EMBEDDING_STORAGE_ABSTRACT_EMBEDDING_STORAGE_H_ 100