1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
19
#include <dlfcn.h>
#include <functional>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
#include "acl/acl_base.h"
#include "acl/acl.h"
#include "transform/acl_ir/op_api_convert.h"
#include "transform/acl_ir/op_api_cache.h"
#include "transform/acl_ir/op_api_util.h"
#include "transform/acl_ir/acl_allocator.h"
#include "transform/symbol/acl_rt_symbol.h"
#include "transform/symbol/acl_symbol.h"
#include "transform/symbol/symbol_utils.h"
35
36 namespace mindspore {
37 namespace transform {
38 using InitHugeMemThreadLocal = std::function<int(void *, bool)>;
39 using UnInitHugeMemThreadLocal = std::function<void(void *, bool)>;
40 using ReleaseHugeMem = std::function<void(void *, bool)>;
41 using ReleaseCallBack = std::function<void()>;
42 using RunApiFunc = int (*)(void *, uint64_t, transform::aclOpExecutor *, const aclrtStream);
43
44 class OpApiDefaultResource {
45 public:
46 static OpApiDefaultResource &GetInstance();
47
48 InitHugeMemThreadLocal init_mem_func();
49 UnInitHugeMemThreadLocal uninit_mem_func();
50 ReleaseHugeMem release_mem_func();
51
52 private:
53 OpApiDefaultResource() = default;
54 ~OpApiDefaultResource() = default;
55
56 InitHugeMemThreadLocal init_mem_func_{nullptr};
57 UnInitHugeMemThreadLocal uninit_mem_func_{nullptr};
58 ReleaseHugeMem release_mem_func_{nullptr};
59 };
60
61 template <typename Tuple>
62 class OpApiParams {
63 public:
OpApiParams(Tuple && converted_params)64 explicit OpApiParams(Tuple &&converted_params) : converted_params_(std::move(converted_params)) {}
OpApiParams(Tuple && converted_params,bool mem_clear)65 explicit OpApiParams(Tuple &&converted_params, bool mem_clear)
66 : converted_params_(std::move(converted_params)), mem_clear_(mem_clear) {}
OpApiParams(OpApiParams && other)67 explicit OpApiParams(OpApiParams &&other) : converted_params_(std::move(other.converted_params_)) {
68 need_free_ = other.need_free_;
69 mem_clear_ = other.mem_clear_;
70 other.need_free_ = false;
71 other.mem_clear_ = false;
72 }
73 OpApiParams &operator=(OpApiParams &&other) {
74 if (this == &other) {
75 return *this;
76 }
77
78 if (need_free_) {
79 ReleaseConvertTypes(converted_params_);
80 }
81
82 converted_params_ = std::move(other.converted_params_);
83 need_free_ = other.need_free_;
84 mem_clear_ = other.mem_clear_;
85 other.need_free_ = false;
86 other.mem_clear_ = false;
87 return *this;
88 }
89
90 OpApiParams() = delete;
91 OpApiParams(const OpApiParams &other) = delete;
92 OpApiParams &operator=(const OpApiParams &other) = delete;
93
~OpApiParams()94 ~OpApiParams() {
95 if (need_free_) {
96 ReleaseConvertTypes(converted_params_);
97 }
98 if (mem_clear_) {
99 auto release_mem_func = transform::OpApiDefaultResource::GetInstance().release_mem_func();
100 if (release_mem_func) {
101 release_mem_func(nullptr, false);
102 }
103 auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();
104 if (uninit_mem_func) {
105 uninit_mem_func(nullptr, false);
106 }
107 }
108 }
109
converted_params()110 const Tuple &converted_params() const { return converted_params_; }
111
112 template <size_t i>
get()113 auto get() {
114 return std::get<i>(converted_params_);
115 }
116
117 private:
118 Tuple converted_params_;
119 bool need_free_{true};
120 bool mem_clear_{false};
121 };
122
// Invokes `f` with the elements of tuple `t` expanded as positional arguments (element I -> argument I)
// and returns its result. Used to call aclnn entry points with the tuple built by ConvertTypes.
//
// The tuple is taken by forwarding reference and its elements are passed as lvalues, so the tuple is not
// copied. The former by-value signature copied it once per overload, i.e. twice per call. Elements keep
// the caller's constness, which is fine for the by-value parameters of the C entry points.
template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple &&t, std::index_sequence<I...>) {
  return f(std::get<I>(t)...);
}

template <typename Function, typename Tuple>
auto call(Function f, Tuple &&t) {
  static constexpr auto size = std::tuple_size<std::decay_t<Tuple>>::value;
  return call(f, t, std::make_index_sequence<size>{});
}
133
134 // Get output shape from acl tensor.
135 ShapeVector UpdateOutputShape(const aclTensor *tensor);
136 void LoadOpApiLib();
137 void AclnnInit();
138 void AclnnFinalize();
139
140 template <typename T>
141 class ReleaseCall {
142 public:
ReleaseCall(T && param)143 explicit ReleaseCall(T &¶m) : converted_params_(param) {}
operator()144 void operator()() {
145 ReleaseConvertTypes(converted_params_);
146 auto release_mem_func = transform::OpApiDefaultResource::GetInstance().release_mem_func();
147 if (release_mem_func) {
148 release_mem_func(nullptr, false);
149 }
150 }
151
152 private:
153 T converted_params_;
154 };
155
// Interns API-name strings and hands out C-string pointers that stay valid for the pool's lifetime.
//
// std::unordered_map never relocates its elements on rehash, so pointers returned earlier remain valid
// as new names are added. There is no locking: concurrent get() calls on the same pool are a data race.
class ApiCachePool {
 public:
  ApiCachePool() = default;
  ~ApiCachePool() = default;

  // Returns a stable pointer to the cached copy of `str`, inserting it on first use.
  // try_emplace performs one hash lookup for both the hit and the miss path and leaves an existing
  // entry untouched.
  const char *get(const std::string &str) {
    const auto iter = pool_.try_emplace(str, str).first;
    return iter->second.c_str();
  }

 private:
  std::unordered_map<std::string, std::string> pool_;
};
176
// For normal generate executor.
//
// Expands to an immediately invoked generic lambda returning
// std::tuple<uint64_t workspace_size, aclOpExecutor *executor, std::function<void()> release_func>.
//  - If HitCache() reports a hit, the cached workspace size/executor are returned and release_func is empty.
//  - Otherwise the arguments are converted with ConvertTypes and `<aclnn_api>GetWorkspaceSize` is called.
//    release_func then owns the converted parameters and should be invoked after the launch
//    (RUN_OP_API_ASYNC does so).
// If GetWorkspaceSize returns non-zero, the converted parameters are released (together with the
// release_mem/uninit_mem hooks) before MS_LOG(EXCEPTION) is raised, so the error path does not leak.
// NOTE(review): the workspace-size symbol is cached in a function-local static per expansion site, so
// one expansion must always be used with the same `aclnn_api` — confirm at call sites.
#define GEN_EXECUTOR(aclnn_api, ...)                                                                              \
  [](const std::string &api_str, const std::string &workspace_api_name, const auto &... args) -> auto {           \
    static transform::ApiCachePool api_cache_pool;                                                                \
    const char *api_name = api_cache_pool.get(api_str);                                                           \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());         \
    if (get_workspace_size_func_ptr == nullptr) {                                                                 \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!"; \
    }                                                                                                             \
    uint64_t workspace_size = 0;                                                                                  \
    transform::aclOpExecutor *executor = nullptr;                                                                 \
    std::function<void()> release_func = nullptr;                                                                 \
    uint64_t *workspace_size_addr = &workspace_size;                                                              \
    transform::aclOpExecutor **executor_addr = &executor;                                                         \
    if (HitCache(api_name, executor_addr, workspace_size_addr, args...)) {                                        \
      return std::make_tuple(workspace_size, executor, release_func);                                             \
    }                                                                                                             \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                          \
    if (init_mem_func) {                                                                                          \
      init_mem_func(nullptr, false);                                                                              \
    }                                                                                                             \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                 \
    static auto get_workspace_size_func =                                                                         \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                               \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                           \
    auto release_call = transform::ReleaseCall(std::move(converted_params));                                      \
    auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();                      \
    if (workspace_status != 0) {                                                                                  \
      release_call();                                                                                             \
      if (uninit_mem_func) {                                                                                      \
        uninit_mem_func(nullptr, false);                                                                          \
      }                                                                                                           \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                   \
    }                                                                                                             \
    release_func = std::function<void()>(release_call);                                                           \
    if (uninit_mem_func) {                                                                                        \
      uninit_mem_func(nullptr, false);                                                                            \
    }                                                                                                             \
    return std::make_tuple(workspace_size, executor, release_func);                                               \
  }                                                                                                               \
  (aclnn_api, aclnn_api + "GetWorkspaceSize", __VA_ARGS__)
214
// For custom generate executor.
//
// Expands to an immediately invoked generic lambda returning
// std::tuple<uint64_t workspace_size, aclOpExecutor *executor, OpApiParams<...> params, bool use_huge_pages>.
// The returned OpApiParams owns the converted arguments. The huge-memory init hook is called only when
// `use_huge_pages` is true, and no uninit happens on success; presumably the caller uses the returned
// flag to decide on that cleanup — TODO confirm at call sites.
// If GetWorkspaceSize returns non-zero, the converted arguments (and the huge-memory state, when it was
// initialized here) are released by OpApiParams' destructor while MS_LOG(EXCEPTION) unwinds.
// The lambda parameter is intentionally not named `use_huge_pages`: the preprocessor would substitute the
// macro argument into the parameter list (e.g. `bool true`), so only bare identifiers used to compile.
#define GEN_EXECUTOR_CUST(aclnn_api, use_huge_pages, ...)                                                         \
  [](const std::string &workspace_api_name, bool enable_huge_pages, auto &... args) -> auto {                      \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());          \
    if (get_workspace_size_func_ptr == nullptr) {                                                                  \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!";  \
    }                                                                                                              \
    static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");              \
    static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");                                    \
    transform::InitCacheThreadLocal init_cache_thread_local_func =                                                 \
      reinterpret_cast<transform::InitCacheThreadLocal>(init_cache_thread_local);                                  \
    transform::SetHashKey set_hash_key_func = reinterpret_cast<transform::SetHashKey>(set_hash_key);               \
    if (init_cache_thread_local_func && set_hash_key_func) {                                                       \
      init_cache_thread_local_func();                                                                              \
      set_hash_key_func(0);                                                                                        \
    }                                                                                                              \
    uint64_t workspace_size = 0;                                                                                   \
    uint64_t *workspace_size_addr = &workspace_size;                                                               \
    transform::aclOpExecutor *executor = nullptr;                                                                  \
    transform::aclOpExecutor **executor_addr = &executor;                                                          \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                           \
    const bool huge_mem_initialized = enable_huge_pages && static_cast<bool>(init_mem_func);                       \
    if (huge_mem_initialized) {                                                                                    \
      init_mem_func(nullptr, false);                                                                               \
    }                                                                                                              \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                  \
    static auto get_workspace_size_func =                                                                          \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                                \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                            \
    transform::OpApiParams<decltype(converted_params)> op_api_params(                                              \
      std::move(converted_params), workspace_status != 0 && huge_mem_initialized);                                 \
    if (workspace_status != 0) {                                                                                   \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                    \
    }                                                                                                              \
    return std::make_tuple(workspace_size, executor, std::move(op_api_params), enable_huge_pages);                 \
  }                                                                                                                \
  (aclnn_api + "GetWorkspaceSize", use_huge_pages, __VA_ARGS__)
251
// For speed up generate executor.
//
// Expands to an immediately invoked generic lambda returning
// std::tuple<uint64_t workspace_size, aclOpExecutor *executor, std::function<void()> release_func,
//            uint64_t hash_id, bool cache_hit>.
// The incoming `hash_id` is copied and passed by pointer to HitCacheSingle, which may update it; the
// possibly updated value is returned. On a cache hit, release_func is empty and cache_hit is true.
// Otherwise the arguments are converted, `<aclnn_api>GetWorkspaceSize` is called, and release_func owns
// the converted parameters. If that call fails, they are released (together with the release_mem/uninit_mem
// hooks) before MS_LOG(EXCEPTION) is raised.
// The lambda parameter is intentionally not named `hash_id`: the preprocessor would substitute the macro
// argument into the parameter list, so only bare identifiers used to compile.
// NOTE(review): the workspace-size symbol is cached in a function-local static per expansion site.
#define GEN_EXECUTOR_BOOST(aclnn_api, hash_id, ...)                                                               \
  [](const std::string &api_str, const std::string &workspace_api_name, uint64_t input_hash_id,                   \
     const auto &... args) -> auto {                                                                               \
    static transform::ApiCachePool api_cache_pool;                                                                 \
    const char *api_name = api_cache_pool.get(api_str);                                                            \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());          \
    if (get_workspace_size_func_ptr == nullptr) {                                                                  \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!";  \
    }                                                                                                              \
    uint64_t workspace_size = 0;                                                                                   \
    transform::aclOpExecutor *executor = nullptr;                                                                  \
    std::function<void()> release_func = nullptr;                                                                  \
    uint64_t *workspace_size_addr = &workspace_size;                                                               \
    transform::aclOpExecutor **executor_addr = &executor;                                                          \
    uint64_t new_hash_id = input_hash_id;                                                                          \
    if (HitCacheSingle(api_name, executor_addr, workspace_size_addr, &new_hash_id, args...)) {                     \
      return std::make_tuple(workspace_size, executor, release_func, new_hash_id, true);                           \
    }                                                                                                              \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                           \
    if (init_mem_func) {                                                                                           \
      init_mem_func(nullptr, false);                                                                               \
    }                                                                                                              \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                  \
    static auto get_workspace_size_func =                                                                          \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                                \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                            \
    auto release_call = transform::ReleaseCall(std::move(converted_params));                                       \
    auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();                       \
    if (workspace_status != 0) {                                                                                   \
      release_call();                                                                                              \
      if (uninit_mem_func) {                                                                                       \
        uninit_mem_func(nullptr, false);                                                                           \
      }                                                                                                            \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                    \
    }                                                                                                              \
    release_func = std::function<void()>(release_call);                                                            \
    if (uninit_mem_func) {                                                                                         \
      uninit_mem_func(nullptr, false);                                                                             \
    }                                                                                                              \
    return std::make_tuple(workspace_size, executor, release_func, new_hash_id, false);                            \
  }                                                                                                                \
  (aclnn_api, aclnn_api + "GetWorkspaceSize", hash_id, __VA_ARGS__)
291
// Async run op.
//
// Launches the aclnn operator named by `aclnn_api` (string-like, must provide c_str()) on `acl_stream`
// without waiting for completion, then invokes `release_func` if it is not null.
// MS_LOG(EXCEPTION) is raised if the symbol is missing from the op-API library or the launch returns a
// non-zero status; in that case `release_func` is not invoked.
// NOTE(review): the symbol is resolved once per expansion site (function-local static), so one expansion
// must always be used with the same `aclnn_api` — confirm at call sites.
#define RUN_OP_API_ASYNC(aclnn_api, workspace_addr, workspace_size, executor, acl_stream, release_func)         \
  do {                                                                                                           \
    static const auto api_symbol = transform::GetOpApiFunc(aclnn_api.c_str());                                  \
    if (api_symbol == nullptr) {                                                                                 \
      MS_LOG(EXCEPTION) << aclnn_api << " not in " << transform::GetOpApiLibName() << ", please check!";         \
    }                                                                                                            \
    const auto launch_api = reinterpret_cast<transform::RunApiFunc>(api_symbol);                                 \
    const auto launch_status = launch_api(workspace_addr, workspace_size, executor, acl_stream);                 \
    if (launch_status != 0) {                                                                                    \
      MS_LOG(EXCEPTION) << "Call " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);    \
    }                                                                                                            \
    if (release_func != nullptr) {                                                                               \
      release_func();                                                                                            \
    }                                                                                                            \
  } while (false)
308
// Sync run op.
//
// Launches the aclnn operator named by `aclnn_api` (string-like, must provide c_str()) on `acl_stream` and
// then blocks in aclrtSynchronizeStream until the stream has finished. There is no release callback here.
// MS_LOG(EXCEPTION) is raised if the symbol is missing, the launch returns non-zero, or the stream
// synchronization fails.
// NOTE(review): the symbol is resolved once per expansion site (function-local static), so one expansion
// must always be used with the same `aclnn_api` — confirm at call sites.
#define RUN_OP_API_SYNC(aclnn_api, workspace_addr, workspace_size, executor, acl_stream)                             \
  do {                                                                                                                \
    static const auto api_symbol = transform::GetOpApiFunc(aclnn_api.c_str());                                       \
    if (api_symbol == nullptr) {                                                                                      \
      MS_LOG(EXCEPTION) << aclnn_api << " not in " << transform::GetOpApiLibName() << ", please check!";              \
    }                                                                                                                 \
    const auto launch_status =                                                                                        \
      reinterpret_cast<transform::RunApiFunc>(api_symbol)(workspace_addr, workspace_size, executor, acl_stream);      \
    if (launch_status != 0) {                                                                                         \
      MS_LOG(EXCEPTION) << "Call " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);         \
    }                                                                                                                 \
    const auto sync_status = CALL_ASCEND_API(aclrtSynchronizeStream, acl_stream);                                     \
    if (sync_status != 0) {                                                                                           \
      MS_LOG(EXCEPTION) << "Sync stream " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);  \
    }                                                                                                                 \
  } while (false)
326 } // namespace transform
327 } // namespace mindspore
328
329 #endif // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
330