• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
19 
#include <dlfcn.h>

#include <functional>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "acl/acl_base.h"
#include "acl/acl.h"
#include "transform/acl_ir/op_api_convert.h"
#include "transform/acl_ir/op_api_cache.h"
#include "transform/acl_ir/op_api_util.h"
#include "transform/acl_ir/acl_allocator.h"
#include "transform/symbol/acl_rt_symbol.h"
#include "transform/symbol/acl_symbol.h"
#include "transform/symbol/symbol_utils.h"
35 
36 namespace mindspore {
37 namespace transform {
38 using InitHugeMemThreadLocal = std::function<int(void *, bool)>;
39 using UnInitHugeMemThreadLocal = std::function<void(void *, bool)>;
40 using ReleaseHugeMem = std::function<void(void *, bool)>;
41 using ReleaseCallBack = std::function<void()>;
42 using RunApiFunc = int (*)(void *, uint64_t, transform::aclOpExecutor *, const aclrtStream);
43 
44 class OpApiDefaultResource {
45  public:
46   static OpApiDefaultResource &GetInstance();
47 
48   InitHugeMemThreadLocal init_mem_func();
49   UnInitHugeMemThreadLocal uninit_mem_func();
50   ReleaseHugeMem release_mem_func();
51 
52  private:
53   OpApiDefaultResource() = default;
54   ~OpApiDefaultResource() = default;
55 
56   InitHugeMemThreadLocal init_mem_func_{nullptr};
57   UnInitHugeMemThreadLocal uninit_mem_func_{nullptr};
58   ReleaseHugeMem release_mem_func_{nullptr};
59 };
60 
61 template <typename Tuple>
62 class OpApiParams {
63  public:
OpApiParams(Tuple && converted_params)64   explicit OpApiParams(Tuple &&converted_params) : converted_params_(std::move(converted_params)) {}
OpApiParams(Tuple && converted_params,bool mem_clear)65   explicit OpApiParams(Tuple &&converted_params, bool mem_clear)
66       : converted_params_(std::move(converted_params)), mem_clear_(mem_clear) {}
OpApiParams(OpApiParams && other)67   explicit OpApiParams(OpApiParams &&other) : converted_params_(std::move(other.converted_params_)) {
68     need_free_ = other.need_free_;
69     mem_clear_ = other.mem_clear_;
70     other.need_free_ = false;
71     other.mem_clear_ = false;
72   }
73   OpApiParams &operator=(OpApiParams &&other) {
74     if (this == &other) {
75       return *this;
76     }
77 
78     if (need_free_) {
79       ReleaseConvertTypes(converted_params_);
80     }
81 
82     converted_params_ = std::move(other.converted_params_);
83     need_free_ = other.need_free_;
84     mem_clear_ = other.mem_clear_;
85     other.need_free_ = false;
86     other.mem_clear_ = false;
87     return *this;
88   }
89 
90   OpApiParams() = delete;
91   OpApiParams(const OpApiParams &other) = delete;
92   OpApiParams &operator=(const OpApiParams &other) = delete;
93 
~OpApiParams()94   ~OpApiParams() {
95     if (need_free_) {
96       ReleaseConvertTypes(converted_params_);
97     }
98     if (mem_clear_) {
99       auto release_mem_func = transform::OpApiDefaultResource::GetInstance().release_mem_func();
100       if (release_mem_func) {
101         release_mem_func(nullptr, false);
102       }
103       auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();
104       if (uninit_mem_func) {
105         uninit_mem_func(nullptr, false);
106       }
107     }
108   }
109 
converted_params()110   const Tuple &converted_params() const { return converted_params_; }
111 
112   template <size_t i>
get()113   auto get() {
114     return std::get<i>(converted_params_);
115   }
116 
117  private:
118   Tuple converted_params_;
119   bool need_free_{true};
120   bool mem_clear_{false};
121 };
122 
// Invokes f with the tuple elements selected by I... (in that order) as arguments and returns f's result.
// The tuple is taken by forwarding reference. The elements reach f as lvalues, as before, but no copy of the tuple
// is made; the previous by-value parameter copied the converted-params tuple on every launch.
template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple &&t, std::index_sequence<I...>) {
  return f(std::get<I>(t)...);
}

// Invokes f with every element of t, in order, as arguments (equivalent to std::apply for function pointers).
template <typename Function, typename Tuple>
auto call(Function f, Tuple &&t) {
  // Tuple may be deduced as a reference type, so strip it before asking for the size.
  static constexpr auto size = std::tuple_size<std::decay_t<Tuple>>::value;
  return call(f, std::forward<Tuple>(t), std::make_index_sequence<size>{});
}
133 
134 // Get output shape from acl tensor.
135 ShapeVector UpdateOutputShape(const aclTensor *tensor);
136 void LoadOpApiLib();
137 void AclnnInit();
138 void AclnnFinalize();
139 
140 template <typename T>
141 class ReleaseCall {
142  public:
ReleaseCall(T && param)143   explicit ReleaseCall(T &&param) : converted_params_(param) {}
operator()144   void operator()() {
145     ReleaseConvertTypes(converted_params_);
146     auto release_mem_func = transform::OpApiDefaultResource::GetInstance().release_mem_func();
147     if (release_mem_func) {
148       release_mem_func(nullptr, false);
149     }
150   }
151 
152  private:
153   T converted_params_;
154 };
155 
// Interns API-name strings so callers get a `const char *` that stays valid for the pool's lifetime.
// std::unordered_map never relocates an element on insert or rehash, and a stored std::string is never modified
// after insertion, so a returned c_str() remains valid until the pool itself is destroyed.
// NOTE(review): not synchronized. The GEN_EXECUTOR macros keep a function-local static pool, so concurrent callers
// of one expansion would race on it — confirm the call sites are single-threaded.
class ApiCachePool {
 public:
  ApiCachePool() = default;
  ~ApiCachePool() = default;

  // Returns a stable C string equal to `str`, storing a copy on first use. try_emplace hashes the key once and
  // constructs the mapped string only when the key is absent. A non-insertion just means the key is already cached,
  // so there is no failure path (the old find+emplace pair hashed twice and had an unreachable error branch).
  const char *get(const std::string &str) { return pool_.try_emplace(str, str).first->second.c_str(); }

 private:
  std::unordered_map<std::string, std::string> pool_;
};
176 
// For normal generate executor.
// Expands to an immediately-invoked generic lambda that prepares the aclnn kernel named by `aclnn_api`
// (a std::string) for the remaining macro arguments:
//   1. Interns `aclnn_api` in a static ApiCachePool and resolves `<aclnn_api>GetWorkspaceSize` through
//      transform::GetOpApiFunc, throwing if the op-API library does not export it.
//   2. If HitCache(...) returns true (it receives the addresses of workspace_size and executor), returns those values
//      with an empty release_func.
//   3. Otherwise it:
//      - calls the optional init huge-memory hook;
//      - converts the arguments with ConvertTypes;
//      - invokes GetWorkspaceSize, throwing on a nonzero status;
//      - wraps the converted parameters in a ReleaseCall;
//      - calls the optional uninit huge-memory hook.
// Result: std::tuple<uint64_t workspace_size, aclOpExecutor *executor, std::function<void()> release_func>.
// RUN_OP_API_ASYNC invokes release_func after launching the kernel.
// NOTE(review): the resolved function pointers are cached in function-local statics (one per expansion and argument-type
// instantiation), so `aclnn_api` is expected to be the same every time a given expansion runs — confirm with callers.
// NOTE(review): when GetWorkspaceSize fails, the exception is thrown before converted_params is released and before
// the uninit hook runs — confirm that this error-path leak is acceptable.
#define GEN_EXECUTOR(aclnn_api, ...)                                                                              \
  [](const std::string &api_str, const std::string &workspace_api_name, const auto &... args) -> auto {           \
    static transform::ApiCachePool api_cache_pool;                                                                \
    const char *api_name = api_cache_pool.get(api_str);                                                           \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());          \
    if (get_workspace_size_func_ptr == nullptr) {                                                                 \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!"; \
    }                                                                                                             \
    uint64_t workspace_size = 0;                                                                                  \
    transform::aclOpExecutor *executor = nullptr;                                                                 \
    std::function<void()> release_func = nullptr;                                                                 \
    uint64_t *workspace_size_addr = &workspace_size;                                                              \
    transform::aclOpExecutor **executor_addr = &executor;                                                         \
    if (HitCache(api_name, executor_addr, workspace_size_addr, args...)) {                                        \
      return std::make_tuple(workspace_size, executor, release_func);                                             \
    }                                                                                                             \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                          \
    if (init_mem_func) {                                                                                          \
      init_mem_func(nullptr, false);                                                                              \
    }                                                                                                             \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                 \
    static auto get_workspace_size_func =                                                                         \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                               \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                           \
    if (workspace_status != 0) {                                                                                  \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                   \
    }                                                                                                             \
    auto releas_call = transform::ReleaseCall(std::move(converted_params));                                       \
    release_func = std::function<void()>(releas_call);                                                            \
    auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();                      \
    if (uninit_mem_func) {                                                                                        \
      uninit_mem_func(nullptr, false);                                                                            \
    }                                                                                                             \
    return std::make_tuple(workspace_size, executor, release_func);                                               \
  }                                                                                                               \
  (aclnn_api, aclnn_api + "GetWorkspaceSize", __VA_ARGS__)
214 
// For custom generate executor.
// Expands to an immediately-invoked generic lambda that computes the workspace size and executor for the aclnn
// kernel named by `aclnn_api` (a std::string) applied to the remaining arguments. It does not consult HitCache.
//   - Throws if `<aclnn_api>GetWorkspaceSize` is not exported by the op-API library.
//   - If InitPTACacheThreadLocal and SetPTAHashKey are both exported, calls the former and then SetPTAHashKey(0).
//   - When `use_huge_pages` is true, calls the optional init huge-memory hook before converting the arguments.
//   - Throws if GetWorkspaceSize returns a nonzero status.
// Result: std::make_tuple(workspace_size, executor, OpApiParams<...>, use_huge_pages). The OpApiParams element owns
// the converted arguments and releases them when destroyed. The uninit huge-memory hook is NOT called here.
// The lambda parameter is named `enable_huge_pages`, not `use_huge_pages`. A parameter spelled like the macro
// parameter would itself be replaced by the caller's argument, so passing a literal or any non-identifier expression
// (e.g. `true`) would not compile.
// NOTE(review): the resolved function pointers are cached in function-local statics, so `aclnn_api` is expected to be
// the same every time a given expansion runs — confirm with callers.
#define GEN_EXECUTOR_CUST(aclnn_api, use_huge_pages, ...)                                                         \
  [](const std::string &workspace_api_name, bool enable_huge_pages, auto &... args) -> auto {                     \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());          \
    if (get_workspace_size_func_ptr == nullptr) {                                                                 \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!"; \
    }                                                                                                             \
    static const auto init_cache_thread_local = transform::GetOpApiFunc("InitPTACacheThreadLocal");               \
    static const auto set_hash_key = transform::GetOpApiFunc("SetPTAHashKey");                                    \
    transform::InitCacheThreadLocal init_cache_thread_local_func =                                                \
      reinterpret_cast<transform::InitCacheThreadLocal>(init_cache_thread_local);                                 \
    transform::SetHashKey set_hash_key_func = reinterpret_cast<transform::SetHashKey>(set_hash_key);              \
    if (init_cache_thread_local_func && set_hash_key_func) {                                                      \
      init_cache_thread_local_func();                                                                             \
      set_hash_key_func(0);                                                                                       \
    }                                                                                                             \
    uint64_t workspace_size = 0;                                                                                  \
    uint64_t *workspace_size_addr = &workspace_size;                                                              \
    transform::aclOpExecutor *executor = nullptr;                                                                 \
    transform::aclOpExecutor **executor_addr = &executor;                                                         \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                          \
    if (enable_huge_pages && init_mem_func) {                                                                     \
      init_mem_func(nullptr, false);                                                                              \
    }                                                                                                             \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                 \
    static auto get_workspace_size_func =                                                                         \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                               \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                           \
    if (workspace_status != 0) {                                                                                  \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                   \
    }                                                                                                             \
    return std::make_tuple(workspace_size, executor,                                                              \
                           transform::OpApiParams<decltype(converted_params)>(std::move(converted_params)),       \
                           enable_huge_pages);                                                                    \
  }                                                                                                               \
  (aclnn_api + "GetWorkspaceSize", use_huge_pages, __VA_ARGS__)
251 
// For speed up generate executor.
// Same flow as GEN_EXECUTOR, but the cache lookup goes through HitCacheSingle, which also receives the address of a
// copy of `hash_id` (new_hash_id) that it may update.
// Result: std::make_tuple(workspace_size, executor, release_func, new_hash_id, hit). `hit` is true (with an empty
// release_func) when HitCacheSingle returned true, and false after a fresh GetWorkspaceSize call.
// The lambda parameter is named `input_hash_id`, not `hash_id`. A parameter spelled like the macro parameter would
// itself be replaced by the caller's argument, so passing anything other than a bare identifier (e.g.
// `op->hash_id()` or a literal) would not compile.
// NOTE(review): the resolved function pointers are cached in function-local statics, so `aclnn_api` is expected to be
// the same every time a given expansion runs — confirm with callers.
// NOTE(review): when GetWorkspaceSize fails, the exception is thrown before converted_params is released and before
// the uninit hook runs.
#define GEN_EXECUTOR_BOOST(aclnn_api, hash_id, ...)                                                               \
  [](const std::string &api_str, const std::string &workspace_api_name, uint64_t input_hash_id,                   \
     const auto &... args) -> auto {                                                                              \
    static transform::ApiCachePool api_cache_pool;                                                                \
    const char *api_name = api_cache_pool.get(api_str);                                                           \
    static const auto get_workspace_size_func_ptr = transform::GetOpApiFunc(workspace_api_name.c_str());          \
    if (get_workspace_size_func_ptr == nullptr) {                                                                 \
      MS_LOG(EXCEPTION) << workspace_api_name << " not in " << transform::GetOpApiLibName() << ", please check!"; \
    }                                                                                                             \
    uint64_t workspace_size = 0;                                                                                  \
    transform::aclOpExecutor *executor = nullptr;                                                                 \
    std::function<void()> release_func = nullptr;                                                                 \
    uint64_t *workspace_size_addr = &workspace_size;                                                              \
    transform::aclOpExecutor **executor_addr = &executor;                                                         \
    uint64_t new_hash_id = input_hash_id;                                                                         \
    if (HitCacheSingle(api_name, executor_addr, workspace_size_addr, &new_hash_id, args...)) {                    \
      return std::make_tuple(workspace_size, executor, release_func, new_hash_id, true);                          \
    }                                                                                                             \
    auto init_mem_func = transform::OpApiDefaultResource::GetInstance().init_mem_func();                          \
    if (init_mem_func) {                                                                                          \
      init_mem_func(nullptr, false);                                                                              \
    }                                                                                                             \
    auto converted_params = transform::ConvertTypes(args..., workspace_size_addr, executor_addr);                 \
    static auto get_workspace_size_func =                                                                         \
      transform::ConvertToOpApiFunc(converted_params, get_workspace_size_func_ptr);                               \
    auto workspace_status = transform::call(get_workspace_size_func, converted_params);                           \
    if (workspace_status != 0) {                                                                                  \
      MS_LOG(EXCEPTION) << workspace_api_name << " call failed, please check!";                                   \
    }                                                                                                             \
    auto releas_call = transform::ReleaseCall(std::move(converted_params));                                       \
    release_func = std::function<void()>(releas_call);                                                            \
    auto uninit_mem_func = transform::OpApiDefaultResource::GetInstance().uninit_mem_func();                      \
    if (uninit_mem_func) {                                                                                        \
      uninit_mem_func(nullptr, false);                                                                            \
    }                                                                                                             \
    return std::make_tuple(workspace_size, executor, release_func, new_hash_id, false);                           \
  }                                                                                                               \
  (aclnn_api, aclnn_api + "GetWorkspaceSize", hash_id, __VA_ARGS__)
291 
// Async run op.
// Launches the aclnn entry point named by `aclnn_api` (a std::string) on `acl_stream`. It does not synchronize the
// stream (compare RUN_OP_API_SYNC). Steps:
//   1. Resolves the entry point through transform::GetOpApiFunc and caches it in a function-local static; throws if
//      it is missing.
//   2. Calls it as RunApiFunc(workspace_addr, workspace_size, executor, acl_stream); on a nonzero status, throws with
//      the text from aclGetRecentErrMsg.
//   3. Invokes `release_func` (e.g. the callback returned by GEN_EXECUTOR) if it is non-null.
// NOTE(review): because the resolved pointer is cached in a static, `aclnn_api` must evaluate to the same name every
// time a given expansion runs — confirm with callers.
// NOTE(review): when the launch fails, the exception is thrown before `release_func` runs, so whatever it owns is
// not released on that path.
#define RUN_OP_API_ASYNC(aclnn_api, workspace_addr, workspace_size, executor, acl_stream, release_func)       \
  do {                                                                                                        \
    static const auto op_api_func = transform::GetOpApiFunc(aclnn_api.c_str());                               \
    if (op_api_func == nullptr) {                                                                             \
      MS_LOG(EXCEPTION) << aclnn_api << " not in " << transform::GetOpApiLibName() << ", please check!";      \
    }                                                                                                         \
    auto run_api_func = reinterpret_cast<transform::RunApiFunc>(op_api_func);                                 \
    auto api_ret = run_api_func(workspace_addr, workspace_size, executor, acl_stream);                        \
    if (api_ret != 0) {                                                                                       \
      MS_LOG(EXCEPTION) << "Call " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg); \
    }                                                                                                         \
    if (release_func != nullptr) {                                                                            \
      release_func();                                                                                         \
    }                                                                                                         \
  } while (false)
308 
// Sync run op.
// Launches the aclnn entry point named by `aclnn_api` (a std::string) on `acl_stream`, then blocks in
// aclrtSynchronizeStream until the stream has finished. Steps:
//   1. Resolves the entry point through transform::GetOpApiFunc and caches it in a function-local static; throws if
//      it is missing.
//   2. Calls it as RunApiFunc(workspace_addr, workspace_size, executor, acl_stream); on a nonzero status, throws with
//      the text from aclGetRecentErrMsg.
//   3. Synchronizes `acl_stream`, throwing on a nonzero status.
// Unlike RUN_OP_API_ASYNC, this macro takes no release callback.
// NOTE(review): because the resolved pointer is cached in a static, `aclnn_api` must evaluate to the same name every
// time a given expansion runs — confirm with callers.
#define RUN_OP_API_SYNC(aclnn_api, workspace_addr, workspace_size, executor, acl_stream)                             \
  do {                                                                                                               \
    static const auto op_api_func = transform::GetOpApiFunc(aclnn_api.c_str());                                      \
    if (op_api_func == nullptr) {                                                                                    \
      MS_LOG(EXCEPTION) << aclnn_api << " not in " << transform::GetOpApiLibName() << ", please check!";             \
    }                                                                                                                \
    auto run_api_func = reinterpret_cast<transform::RunApiFunc>(op_api_func);                                        \
    auto api_ret = run_api_func(workspace_addr, workspace_size, executor, acl_stream);                               \
    if (api_ret != 0) {                                                                                              \
      MS_LOG(EXCEPTION) << "Call " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);        \
    }                                                                                                                \
    auto ret = CALL_ASCEND_API(aclrtSynchronizeStream, acl_stream);                                                  \
    if (ret != 0) {                                                                                                  \
      MS_LOG(EXCEPTION) << "Sync stream " << aclnn_api << " failed, detail:" << CALL_ASCEND_API(aclGetRecentErrMsg); \
    }                                                                                                                \
  } while (false)
326 }  // namespace transform
327 }  // namespace mindspore
328 
329 #endif  // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_OP_API_EXEC_H_
330