• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
19 
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "transform/acl_ir/op_api_convert.h"
24 
25 namespace mindspore::transform {
26 typedef aclOpExecutor *(*GetExecCache)(uint64_t, uint64_t *);
27 typedef void (*InitCacheThreadLocal)();
28 typedef void (*SetHashKey)(uint64_t);
29 typedef bool (*CanUseCache)(const char *);
30 
31 constexpr int g_hash_buf_size = 8192;
32 constexpr int g_hash_buf_max_size = g_hash_buf_size + 1024;
33 extern thread_local char g_hash_buf[g_hash_buf_size];
34 extern thread_local int g_hash_offset;
35 
MemcpyToBuf(const void * data_expression,size_t size_expression)36 inline void MemcpyToBuf(const void *data_expression, size_t size_expression) {
37   if (size_expression == 0) {
38     return;
39   }
40   if (g_hash_offset + size_expression >= g_hash_buf_size) {
41     g_hash_offset = g_hash_buf_max_size;
42     return;
43   }
44   auto ret = memcpy_sp(g_hash_buf + g_hash_offset, g_hash_buf_size - g_hash_offset, data_expression, size_expression);
45   if (ret != EOK) {
46     MS_LOG(EXCEPTION) << "Failed to memcpy!";
47   }
48   g_hash_offset += size_expression;
49 }
50 
51 void GatherInfo(mindspore::kernel::KernelTensor *);
52 void GatherInfo(const std::pair<mindspore::kernel::KernelTensor *, bool> &);
53 void GatherInfo(const std::vector<mindspore::kernel::KernelTensor *> &);
54 void GatherInfo(const device::DeviceAddressPtr &);
55 
56 void GatherInfo(const mindspore::tensor::BaseTensorPtr &);
57 void GatherInfo(const std::optional<tensor::BaseTensorPtr> &);
58 void GatherInfo(const std::vector<tensor::BaseTensorPtr> &);
59 void GatherInfo(const mindspore::tensor::TensorPtr &);
60 void GatherInfo(const std::optional<tensor::TensorPtr> &);
61 void GatherInfo(const std::vector<tensor::TensorPtr> &);
62 
/// Fallback overload: appends the raw object bytes of `value` (all sizeof(T) of them) to the
/// hash buffer. Types that need content-based hashing get a dedicated overload instead.
template <typename T>
void GatherInfo(const T &value) {
  constexpr size_t byte_count = sizeof(T);
  MemcpyToBuf(&value, byte_count);
}
67 
// Forward declaration of the std::vector overload defined further down in this header.
// Inside a template, unqualified calls only see overloads declared before the template's
// definition (argument-dependent lookup does not search this namespace for std::vector of
// std/fundamental element types). Without this declaration, an engaged
// std::optional<std::vector<U>> would fall through to the generic `const T &` overload and
// hash the vector object's internal pointers instead of its elements.
template <typename T>
void GatherInfo(const std::vector<T> &values);

/// Hashes the contained value of an engaged optional; a disengaged optional contributes no bytes.
/// Taken by const reference so neither the optional nor its payload is copied on every call.
template <typename T>
void GatherInfo(const std::optional<T> &value) {
  if (value.has_value()) {
    GatherInfo(value.value());
  }
}
74 
75 void GatherInfo(const string &);
76 void GatherInfo(const std::optional<string> &);
77 
78 void GatherInfo(const ScalarPtr &);
79 void GatherInfo(const std::optional<ScalarPtr> &);
80 
81 void GatherInfo(const TypePtr &);
82 void GatherInfo(const std::optional<TypePtr> &);
83 
/**
 * Hashes the elements of a vector, in order.
 *
 * Trivially copyable element types (except bool) are appended as one contiguous block of raw
 * bytes. Every other element type is hashed one element at a time through the best-matching
 * GatherInfo overload. For std::string, ScalarPtr, TypePtr and nested vectors this is a
 * content-based overload, so their internal heap pointers never end up in the hash. Element types
 * with no dedicated overload still contribute their raw object bytes.
 *
 * std::vector<bool> is a bit-packed specialization without data(), so it also takes the
 * per-element path.
 */
template <typename T>
void GatherInfo(const std::vector<T> &values) {
  if constexpr (std::is_trivially_copyable_v<T> && !std::is_same_v<T, bool>) {
    MemcpyToBuf(values.data(), values.size() * sizeof(T));
  } else {
    for (const auto &value : values) {
      GatherInfo(value);
    }
  }
}
88 
GatherInfo(TypeId type_id)89 inline void GatherInfo(TypeId type_id) { MemcpyToBuf(&type_id, sizeof(int)); }
90 
/// Zero-argument overload; only declared here. The variadic overload below calls it once its
/// trailing argument pack is empty.
void GatherInfo();

/// Hashes several arguments in order: the head through its own overload, then the remaining ones.
template <typename T, typename... Args>
void GatherInfo(const T &arg, const Args &... args) {
  GatherInfo(arg);
  if constexpr (sizeof...(Args) == 0) {
    GatherInfo();
  } else {
    GatherInfo(args...);
  }
}
98 
99 void RefreshAddr(mindspore::kernel::KernelTensor *);
100 void RefreshAddr(const std::pair<mindspore::kernel::KernelTensor *, bool> &);
RefreshAddr(const std::vector<mindspore::kernel::KernelTensor * > & tensor_list)101 inline void RefreshAddr(const std::vector<mindspore::kernel::KernelTensor *> &tensor_list) {
102   for (auto tensor : tensor_list) {
103     RefreshAddr(tensor);
104   }
105 }
106 
/// Catch-all overload: arguments without a more specific RefreshAddr overload are ignored.
template <typename Value>
void RefreshAddr(const Value & /* value */) {}

/// Zero-argument overload; does nothing.
inline void RefreshAddr() {}

/// Applies the matching single-argument RefreshAddr overload to every argument, left to right.
template <typename T, typename... Args>
void RefreshAddr(const T &arg, const Args &... args) {
  RefreshAddr(arg);
  (RefreshAddr(args), ...);
}
117 
// Presumably hashes the bytes gathered into g_hash_buf so far; HitCache, HitCacheSingle and
// CalcOpApiHash call it right after GatherInfo fills the buffer. Only declared here; confirm
// against the definition.
uint64_t calc_hash_id();
// Presumably hashes `len` bytes starting at `key` using `seed`. Only declared here; confirm against
// the definition. (Top-level `const` on by-value parameters is not part of the signature.)
uint64_t gen_hash(const void *key, int len, uint32_t seed = 0xdeadb0d7);
120 
121 template <typename... Args>
HitCache(const char * aclnn_api,aclOpExecutor ** executor,uint64_t * workspace_size,const Args &...args)122 bool HitCache(const char *aclnn_api, aclOpExecutor **executor, uint64_t *workspace_size, const Args &... args) {
123   static const auto get_exec_cache = transform::GetOpApiFunc("PTAGetExecCache");
124   static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");
125   static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");
126   static const auto can_use_cache = transform::GetOpApiFunc("CanUsePTACache");
127   GetExecCache get_exec_cache_func = reinterpret_cast<GetExecCache>(get_exec_cache);
128   InitCacheThreadLocal init_cache_thread_local_func = reinterpret_cast<InitCacheThreadLocal>(init_cache_thread_local);
129   SetHashKey set_hash_key_func = reinterpret_cast<SetHashKey>(set_hash_key);
130   CanUseCache can_use_cache_func = reinterpret_cast<CanUseCache>(can_use_cache);
131   bool has_func = get_exec_cache_func && init_cache_thread_local_func && set_hash_key_func;
132   bool can_use = can_use_cache_func && can_use_cache_func(aclnn_api);
133   if (!has_func || !can_use) {
134     return false;
135   }
136   init_cache_thread_local_func();
137   g_hash_offset = 0;
138   GatherInfo(std::string(aclnn_api), args...);
139   uint64_t hash_id = calc_hash_id();
140   set_hash_key_func(hash_id);
141   *executor = get_exec_cache_func(hash_id, workspace_size);
142   if (*executor == nullptr) {
143     return false;
144   }
145   return true;
146 }
147 
148 template <typename... Args>
CalcOpApiHash(const std::string & arg,const Args &...args)149 uint64_t CalcOpApiHash(const std::string &arg, const Args &... args) {
150   g_hash_offset = 0;
151   GatherInfo(arg, args...);
152   return calc_hash_id();
153 }
154 
155 template <typename... Args>
HitCacheSingle(const char * aclnn_api,aclOpExecutor ** executor,uint64_t * workspace_size,uint64_t * hash_id,const Args &...args)156 bool HitCacheSingle(const char *aclnn_api, aclOpExecutor **executor, uint64_t *workspace_size, uint64_t *hash_id,
157                     const Args &... args) {
158   static const auto get_exec_cache = transform::GetOpApiFunc("PTAGetExecCache");
159   static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");
160   static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");
161   static const auto can_use_cache = transform::GetOpApiFunc("CanUsePTACache");
162   GetExecCache get_exec_cache_func = reinterpret_cast<GetExecCache>(get_exec_cache);
163   InitCacheThreadLocal init_cache_thread_local_func = reinterpret_cast<InitCacheThreadLocal>(init_cache_thread_local);
164   SetHashKey set_hash_key_func = reinterpret_cast<SetHashKey>(set_hash_key);
165   CanUseCache can_use_cache_func = reinterpret_cast<CanUseCache>(can_use_cache);
166   bool has_func = get_exec_cache_func && init_cache_thread_local_func && set_hash_key_func;
167   bool can_use = can_use_cache_func && can_use_cache_func(aclnn_api);
168   if (!has_func || !can_use) {
169     return false;
170   }
171   init_cache_thread_local_func();
172   g_hash_offset = 0;
173 
174   if (*hash_id == 0) {
175     GatherInfo(std::string(aclnn_api), args...);
176     *hash_id = calc_hash_id();
177   } else {
178     RefreshAddr(args...);
179   }
180 
181   set_hash_key_func(*hash_id);
182   *executor = get_exec_cache_func(*hash_id, workspace_size);
183   if (*executor == nullptr) {
184     return false;
185   }
186   return true;
187 }
188 }  // namespace mindspore::transform
189 #endif  // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_CACHE_H_
190