• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_
18 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_
19 #include <cmath>
20 #include <algorithm>
21 #include <memory>
22 #include "include/api/status.h"
23 #include "include/api/data_type.h"
24 #include "src/common/log_adapter.h"
25 #include "src/extendrt/delegate/parameter_cache/cache_algorithm.h"
26 #include "src/extendrt/delegate/parameter_cache/cache_mem_base.h"
27 
28 namespace mindspore {
29 namespace cache {
30 class EmbeddingCache {
31  public:
EmbeddingCache(size_t vocab_size,size_t device_cache_size,size_t batch_elements,int rank_id,int rank_group_size)32   EmbeddingCache(size_t vocab_size, size_t device_cache_size, size_t batch_elements, int rank_id, int rank_group_size)
33       : vocab_size_(vocab_size),
34         device_cache_size_(device_cache_size),
35         batch_elements_(batch_elements),
36         rank_id_(rank_id),
37         rank_group_size_(rank_group_size) {
38     MS_ASSERT(rank_group_size_ != 0);
39     auto local_shard_size = static_cast<int>(std::ceil(static_cast<float>(vocab_size_) / rank_group_size_));
40     min_host_index_ = local_shard_size * rank_id_;
41     max_host_index_ = std::min(min_host_index_ + local_shard_size, static_cast<int>(vocab_size_));
42     host_cache_size_ = max_host_index_ - min_host_index_;
43 
44     MS_LOG(INFO) << "rank_group_size_ num:" << rank_group_size_ << ", rank id:" << rank_id_
45                  << ", vocab_size_:" << vocab_size_ << ", host_cache_size_:" << host_cache_size_
46                  << ", index begin:" << min_host_index_ << ", index end:" << max_host_index_;
47   }
48 
49   ~EmbeddingCache();
50   Status Init(uint32_t device_id, const void *context, mindspore::MSTensor host_cache_tensor,
51               mindspore::MSTensor device_tensor);
52   Status SetHostCacheAddr(void *addr, size_t size);
53   Status SetDeviceCacheAddr(void *host_mem_addr, size_t size);
54   Status CheckCacheHit(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
GetDeviceStartIndex()55   size_t GetDeviceStartIndex() { return device_start_index_; }
56 
57  private:
58   Status Init(mindspore::MSTensor host_cache_tensor, mindspore::MSTensor device_tensor);
59   Status MallocCacheMemory();
60 
61  private:
62   std::shared_ptr<cache::CacheMemBase> device_cache_{nullptr};
63   std::shared_ptr<CacheAlgorithm> cache_{nullptr};
64 
65   size_t vocab_size_{0};         // total size
66   size_t host_cache_size_{0};    // local host size
67   size_t device_cache_size_{0};  // local device cache size
68   size_t device_start_index_{0};
69   size_t embedding_size_{0};
70   size_t batch_elements_{0};
71 
72   DataType data_type_{DataType::kNumberTypeFloat32};
73   size_t sizeof_data_type_{0};
74 
75   void *device_addr_{nullptr};  // hash_info.device_address.addr
76   void *host_addr_{nullptr};
77 
78   int *hash_swap_index_addr_;  // embedding_device_cache_->hash_swap_index_addr_
79   void *hash_swap_value_addr_;
80   void *hash_swap_value_device_addr_;
81 
82   int rank_id_;
83   int rank_group_size_;
84   int min_host_index_{0};
85   int max_host_index_{0};
86 };
87 }  // namespace cache
88 }  // namespace mindspore
89 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_H_
90