1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ 18 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ 19 #include <cmath> 20 #include <algorithm> 21 #include <memory> 22 #include "include/api/status.h" 23 #include "include/api/data_type.h" 24 #include "src/common/log_adapter.h" 25 #include "src/extendrt/delegate/parameter_cache/cache_algorithm.h" 26 #include "src/extendrt/delegate/parameter_cache/cache_mem_base.h" 27 28 namespace mindspore { 29 namespace cache { 30 class EmbeddingCache { 31 public: EmbeddingCache(size_t vocab_size,size_t device_cache_size,size_t batch_elements,int rank_id,int rank_group_size)32 EmbeddingCache(size_t vocab_size, size_t device_cache_size, size_t batch_elements, int rank_id, int rank_group_size) 33 : vocab_size_(vocab_size), 34 device_cache_size_(device_cache_size), 35 batch_elements_(batch_elements), 36 rank_id_(rank_id), 37 rank_group_size_(rank_group_size) { 38 MS_ASSERT(rank_group_size_ != 0); 39 auto local_shard_size = static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_)); 40 min_host_index_ = local_shard_size * rank_id_; 41 max_host_index_ = std::min(min_host_index_ + local_shard_size, static_cast<int>(vocab_size_)); 42 host_cache_size_ = max_host_index_ - min_host_index_; 43 44 MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_ 45 << ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_ 46 << ", index begin:" << min_host_index_ << ", index end:" << max_host_index_; 47 } 48 49 ~EmbeddingCache(); 50 Status Init(uint32_t device_id, const void *context, mindspore::MSTensor host_cache_tensor, 51 mindspore::MSTensor device_tensor); 52 Status SetHostCacheAddr(void *addr, size_t size); 53 Status SetDeviceCacheAddr(void *host_mem_addr, size_t size); 54 Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *hash_index); GetDeviceStartIndex()55 size_t GetDeviceStartIndex() { return device_start_index_; } 56 57 private: 58 Status Init(mindspore::MSTensor host_cache_tensor, mindspore::MSTensor device_tensor); 59 Status MallocCacheMemory(); 60 61 private: 62 std::shared_ptr<cache::CacheMemBase> device_cache_{nullptr}; 63 std::shared_ptr<CacheAlgorithm> cache_{nullptr}; 64 65 size_t vocab_size_{0}; // total size 66 size_t host_cache_size_{0}; // local host size 67 size_t device_cache_size_{0}; // local device cache size 68 size_t device_start_index_{0}; 69 size_t embedding_size_{0}; 70 size_t batch_elements_{0}; 71 72 DataType data_type_{DataType::kNumberTypeFloat32}; 73 size_t sizeof_data_type_{0}; 74 75 void *device_addr_{nullptr}; // hash_info.device_address.addr 76 void *host_addr_{nullptr}; 77 78 int *hash_swap_index_addr_; // embedding_device_cache_->hash_swap_index_addr_ 79 void *hash_swap_value_addr_; 80 void *hash_swap_value_device_addr_; 81 82 int rank_id_; 83 int rank_group_size_; 84 int min_host_index_{0}; 85 int max_host_index_{0}; 86 }; 87 } // namespace cache 88 } // namespace mindspore 89 #endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_ 90