• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_
18 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_
19 #include <memory>
20 #include <map>
21 #include <string>
22 #include <vector>
23 #include "include/api/kernel.h"
24 #include "include/api/status.h"
25 #include "include/api/data_type.h"
26 #include "src/extendrt/delegate/parameter_cache/embedding_cache.h"
27 #include "src/extendrt/delegate/parameter_cache/load_host_cache_model.h"
28 #include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h"
29 
30 namespace mindspore {
31 namespace cache {
32 class EmbeddingCacheManager {
33  public:
EmbeddingCacheManager()34   EmbeddingCacheManager() {
35     rank_id_ = lite::GetRankID();
36     rank_group_size_ = lite::GetGPUGroupSize();
37   }
38   Status Init(const std::string &cache_model_path, size_t vocab_size, size_t device_cache_size);
39   Status Init(DelegateModel<schema::Primitive> *model, size_t vocab_size, size_t device_cache_size);
40   bool CheckIsCacheKernel(kernel::Kernel *kernel);
41   Status InitCacheKernel(kernel::Kernel *kernel, uint32_t device_id, const void *context);
42   bool IsCacheTensor(mindspore::MSTensor tensor);
43   int CacheHandle(const std::string &tensor_name, mindspore::MSTensor model_input_tensor, void *device_addr);
44   Status SetDeviceCacheAddr(const std::string &tensor_name, void *device_mem_addr, size_t size);
45   std::vector<int64_t> GetCacheShape(mindspore::MSTensor tensor);
46   size_t GetCacheDataSize(mindspore::MSTensor tensor);
47 
48  private:
49   std::map<std::string, std::shared_ptr<EmbeddingCache>> caches_;
50   std::vector<int> hash_indices_;
51   int rank_id_{0};
52   int rank_group_size_{1};
53 
54   std::shared_ptr<HostCacheModel> host_cache_model_;
55   size_t vocab_size_;
56   size_t device_cache_size_;
57 };
58 }  // namespace cache
59 }  // namespace mindspore
60 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_PARAMETER_CACHE_EMBEDDING_CACHE_MANAGER_H_
61