1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
19
20 #include <string>
21 #include <vector>
22 #include <utility>
23 #include "transform/acl_ir/op_api_convert.h"
24
25 namespace mindspore::transform {
26 typedef aclOpExecutor *(*GetExecCache)(uint64_t, uint64_t *);
27 typedef void (*InitCacheThreadLocal)();
28 typedef void (*SetHashKey)(uint64_t);
29 typedef bool (*CanUseCache)(const char *);
30
31 constexpr int g_hash_buf_size = 8192;
32 constexpr int g_hash_buf_max_size = g_hash_buf_size + 1024;
33 extern thread_local char g_hash_buf[g_hash_buf_size];
34 extern thread_local int g_hash_offset;
35
MemcpyToBuf(const void * data_expression,size_t size_expression)36 inline void MemcpyToBuf(const void *data_expression, size_t size_expression) {
37 if (size_expression == 0) {
38 return;
39 }
40 if (g_hash_offset + size_expression >= g_hash_buf_size) {
41 g_hash_offset = g_hash_buf_max_size;
42 return;
43 }
44 auto ret = memcpy_sp(g_hash_buf + g_hash_offset, g_hash_buf_size - g_hash_offset, data_expression, size_expression);
45 if (ret != EOK) {
46 MS_LOG(EXCEPTION) << "Failed to memcpy!";
47 }
48 g_hash_offset += size_expression;
49 }
50
51 void GatherInfo(mindspore::kernel::KernelTensor *);
52 void GatherInfo(const std::pair<mindspore::kernel::KernelTensor *, bool> &);
53 void GatherInfo(const std::vector<mindspore::kernel::KernelTensor *> &);
54 void GatherInfo(const device::DeviceAddressPtr &);
55
56 void GatherInfo(const mindspore::tensor::BaseTensorPtr &);
57 void GatherInfo(const std::optional<tensor::BaseTensorPtr> &);
58 void GatherInfo(const std::vector<tensor::BaseTensorPtr> &);
59 void GatherInfo(const mindspore::tensor::TensorPtr &);
60 void GatherInfo(const std::optional<tensor::TensorPtr> &);
61 void GatherInfo(const std::vector<tensor::TensorPtr> &);
62
63 template <typename T>
GatherInfo(const T & value)64 void GatherInfo(const T &value) {
65 MemcpyToBuf(&value, sizeof(T));
66 }
67
// Optional wrapper: forwards the contained value to the matching GatherInfo overload.
// An empty optional contributes nothing to the hash buffer.
// The optional is taken by const reference so that neither it nor its payload is copied per call.
template <typename T>
void GatherInfo(const std::optional<T> &value) {
  if (!value.has_value()) {
    return;
  }
  GatherInfo(value.value());
}
74
75 void GatherInfo(const string &);
76 void GatherInfo(const std::optional<string> &);
77
78 void GatherInfo(const ScalarPtr &);
79 void GatherInfo(const std::optional<ScalarPtr> &);
80
81 void GatherInfo(const TypePtr &);
82 void GatherInfo(const std::optional<TypePtr> &);
83
// Vector overload: appends the contiguous element storage of `values` to the hash buffer in one copy.
// Only the element bytes are appended; the element count itself is not written.
template <typename T>
void GatherInfo(const std::vector<T> &values) {
  const size_t byte_count = values.size() * sizeof(T);
  MemcpyToBuf(values.data(), byte_count);
}
88
GatherInfo(TypeId type_id)89 inline void GatherInfo(TypeId type_id) { MemcpyToBuf(&type_id, sizeof(int)); }
90
// Zero-argument overload; only declared in this header.
void GatherInfo();

// Variadic entry point: gathers every argument in order, left to right, issuing one single-argument
// GatherInfo call per argument so that each one is dispatched to the overload matching its type.
template <typename T, typename... Args>
void GatherInfo(const T &arg, const Args &... args) {
  GatherInfo(arg);
  (GatherInfo(args), ...);
}
98
99 void RefreshAddr(mindspore::kernel::KernelTensor *);
100 void RefreshAddr(const std::pair<mindspore::kernel::KernelTensor *, bool> &);
RefreshAddr(const std::vector<mindspore::kernel::KernelTensor * > & tensor_list)101 inline void RefreshAddr(const std::vector<mindspore::kernel::KernelTensor *> &tensor_list) {
102 for (auto tensor : tensor_list) {
103 RefreshAddr(tensor);
104 }
105 }
106
// Catch-all overload: any argument type without a dedicated RefreshAddr overload is ignored.
template <typename Args>
void RefreshAddr(const Args & /* values */) {
  // Intentionally empty.
}
109
// Zero-argument overload: nothing to refresh.
inline void RefreshAddr() {
  // Intentionally empty.
}
111
// Variadic entry point: visits every argument in order, left to right, issuing one single-argument
// RefreshAddr call per argument so that each one is dispatched to the overload matching its type.
template <typename T, typename... Args>
void RefreshAddr(const T &arg, const Args &... args) {
  RefreshAddr(arg);
  (RefreshAddr(args), ...);
}
117
// Produces the 64-bit cache key. Only declared in this header.
// NOTE(review): presumably hashes the bytes accumulated in g_hash_buf up to g_hash_offset -- confirm
// against the definition.
uint64_t calc_hash_id();

// Hashes `len` bytes starting at `key`; `seed` defaults to 0xdeadb0d7. Only declared in this header.
uint64_t gen_hash(const void *key, int len, uint32_t seed = 0xdeadb0d7);
120
121 template <typename... Args>
HitCache(const char * aclnn_api,aclOpExecutor ** executor,uint64_t * workspace_size,const Args &...args)122 bool HitCache(const char *aclnn_api, aclOpExecutor **executor, uint64_t *workspace_size, const Args &... args) {
123 static const auto get_exec_cache = transform::GetOpApiFunc("PTAGetExecCache");
124 static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");
125 static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");
126 static const auto can_use_cache = transform::GetOpApiFunc("CanUsePTACache");
127 GetExecCache get_exec_cache_func = reinterpret_cast<GetExecCache>(get_exec_cache);
128 InitCacheThreadLocal init_cache_thread_local_func = reinterpret_cast<InitCacheThreadLocal>(init_cache_thread_local);
129 SetHashKey set_hash_key_func = reinterpret_cast<SetHashKey>(set_hash_key);
130 CanUseCache can_use_cache_func = reinterpret_cast<CanUseCache>(can_use_cache);
131 bool has_func = get_exec_cache_func && init_cache_thread_local_func && set_hash_key_func;
132 bool can_use = can_use_cache_func && can_use_cache_func(aclnn_api);
133 if (!has_func || !can_use) {
134 return false;
135 }
136 init_cache_thread_local_func();
137 g_hash_offset = 0;
138 GatherInfo(std::string(aclnn_api), args...);
139 uint64_t hash_id = calc_hash_id();
140 set_hash_key_func(hash_id);
141 *executor = get_exec_cache_func(hash_id, workspace_size);
142 if (*executor == nullptr) {
143 return false;
144 }
145 return true;
146 }
147
148 template <typename... Args>
CalcOpApiHash(const std::string & arg,const Args &...args)149 uint64_t CalcOpApiHash(const std::string &arg, const Args &... args) {
150 g_hash_offset = 0;
151 GatherInfo(arg, args...);
152 return calc_hash_id();
153 }
154
155 template <typename... Args>
HitCacheSingle(const char * aclnn_api,aclOpExecutor ** executor,uint64_t * workspace_size,uint64_t * hash_id,const Args &...args)156 bool HitCacheSingle(const char *aclnn_api, aclOpExecutor **executor, uint64_t *workspace_size, uint64_t *hash_id,
157 const Args &... args) {
158 static const auto get_exec_cache = transform::GetOpApiFunc("PTAGetExecCache");
159 static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");
160 static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");
161 static const auto can_use_cache = transform::GetOpApiFunc("CanUsePTACache");
162 GetExecCache get_exec_cache_func = reinterpret_cast<GetExecCache>(get_exec_cache);
163 InitCacheThreadLocal init_cache_thread_local_func = reinterpret_cast<InitCacheThreadLocal>(init_cache_thread_local);
164 SetHashKey set_hash_key_func = reinterpret_cast<SetHashKey>(set_hash_key);
165 CanUseCache can_use_cache_func = reinterpret_cast<CanUseCache>(can_use_cache);
166 bool has_func = get_exec_cache_func && init_cache_thread_local_func && set_hash_key_func;
167 bool can_use = can_use_cache_func && can_use_cache_func(aclnn_api);
168 if (!has_func || !can_use) {
169 return false;
170 }
171 init_cache_thread_local_func();
172 g_hash_offset = 0;
173
174 if (*hash_id == 0) {
175 GatherInfo(std::string(aclnn_api), args...);
176 *hash_id = calc_hash_id();
177 } else {
178 RefreshAddr(args...);
179 }
180
181 set_hash_key_func(*hash_id);
182 *executor = get_exec_cache_func(*hash_id, workspace_size);
183 if (*executor == nullptr) {
184 return false;
185 }
186 return true;
187 }
188 } // namespace mindspore::transform
189 #endif // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
190