/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_ACLNN_KERNEL_MOD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_ACLNN_KERNEL_MOD_H_
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "ops/base_operator.h"
#include "plugin/device/ascend/kernel/opapi/aclnn_kernel_mod.h"
#include "transform/acl_ir/acl_convert.h"

namespace mindspore {
namespace kernel {
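// Named constants for the total number of tensor arguments (inputs + outputs)
// a custom aclnn kernel may take; they select the template instantiation in
// GetCustomAclNNKernelMod below.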
constexpr size_t kTensorNum1 = 1;
constexpr size_t kTensorNum2 = 2;
constexpr size_t kTensorNum3 = 3;
constexpr size_t kTensorNum4 = 4;
constexpr size_t kTensorNum5 = 5;
constexpr size_t kTensorNum6 = 6;
constexpr size_t kTensorNum7 = 7;
constexpr size_t kTensorNum8 = 8;
constexpr size_t kTensorNum9 = 9;
constexpr size_t kTensorNum10 = 10;
constexpr size_t kTensorNum11 = 11;
constexpr size_t kTensorNum12 = 12;

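// CustomAclnnKernelMod adapts a user-registered aclnn custom operator to the
// AclnnKernelMod interface. The template parameter N is the total number of
// input and output tensors, which fixes the arity of the argument tuple
// passed to the GEN_EXECUTOR_* macros.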
template <size_t N>
class CustomAclnnKernelMod : public AclnnKernelMod {
 public:
  explicit CustomAclnnKernelMod(std::string op_type) : AclnnKernelMod(std::move(op_type)) {}
  ~CustomAclnnKernelMod() = default;
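  // Computes (or retrieves from cache) the aclnn executor for the given
  // inputs/outputs and records the required workspace size. The first time an
  // argument signature is seen, a full executor is generated via
  // GEN_EXECUTOR_CUST; afterwards the hash-keyed GEN_EXECUTOR_BOOST fast path
  // is used.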
  void GetWorkSpaceInfo(const std::vector<KernelTensor *> &inputs,
                        const std::vector<KernelTensor *> &outputs) override {
    const auto &res_tuple = this->GetKernelTuple<N>(inputs, outputs);
    std::apply(
      [this](const auto &... args) {
        hash_id_ = transform::CalcOpApiHash(op_type_, args...);
        if (cache_hash_.count(hash_id_) == 0) {
          const bool use_huge_pages = false;
          auto return_value = GEN_EXECUTOR_CUST(op_type_, use_huge_pages, args...);
          UpdateWorkspace(return_value);
        } else {
          auto return_value = GEN_EXECUTOR_BOOST(op_type_, hash_id_, args...);
          UpdateWorkspace(return_value);
        }
      },
      res_tuple);
  }
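  // Rebuilds the executor for the current inputs/outputs and launches the
  // operator asynchronously on the given stream.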
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
    ParseGenExecutor(GenExecutor(inputs, outputs));
    RunOp(stream_ptr, workspace);
    return true;
  }

 private:
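  // Packs the kernel tensors into an N-tuple and builds the executor through
  // the hash-keyed GEN_EXECUTOR_BOOST path.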
  template <typename... Ts>
  auto GenExecutor(const std::vector<Ts> &... vecs) {
    const auto &op_type = this->op_type_;
    const auto &hash_id = this->hash_id_;
    const auto &res_tuple = this->GetKernelTuple<N>(vecs...);
    auto executor_info = std::apply(
      [&op_type, &hash_id](const auto &... args) { return GEN_EXECUTOR_BOOST(op_type, hash_id, args...); }, res_tuple);
    return executor_info;
  }

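  // Dispatches the asynchronous aclnn call, passing the workspace buffer when
  // one is required and validating that its size matches what
  // GetWorkSpaceInfo reported.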
  void RunOp(void *stream_ptr, const std::vector<KernelTensor *> &workspace) {
    if (workspace_size_list_.empty()) {
      RUN_OP_API_ASYNC(op_type_, nullptr, 0, executor_, stream_ptr, release_func_);
    } else {
      if (workspace.empty()) {
        MS_LOG(EXCEPTION) << "Failed to allocate workspace tensor!";
      }
      auto workspace_tensor = workspace[0];
      if (workspace_tensor->size() != workspace_size_list_[0]) {
        MS_LOG(EXCEPTION) << "Please check the 'GetWorkSpaceInfo' and 'Launch' functions. Expected workspace size is "
                          << workspace_size_list_[0] << ", but got " << workspace_tensor->size();
      }
      RUN_OP_API_ASYNC(op_type_, workspace_tensor->device_ptr(), workspace_size_list_[0], executor_, stream_ptr,
                       release_func_);
    }
  }
};

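// Creates the CustomAclnnKernelMod instantiation matching the total
// input/output tensor count of the given custom-op node. The aclnn op type is
// read from the node's "reg_op_name" attribute. Returns nullptr if the
// argument count is outside the supported range of 1 to 12.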
inline std::shared_ptr<AclnnKernelMod> GetCustomAclNNKernelMod(const AnfNodePtr &anf_node) {
  auto primitive = GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_type = GetValue<std::string>(primitive->GetAttr("reg_op_name"));
  auto arg_num = AnfUtils::GetInputTensorNum(anf_node) + AnfUtils::GetOutputTensorNum(anf_node);
  MS_LOG(INFO) << "Kernel " << anf_node->fullname_with_scope() << " is a custom op, op type: " << op_type
               << ", arg num: " << arg_num;
  switch (arg_num) {
    case kTensorNum1:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum1>>(op_type);
    case kTensorNum2:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum2>>(op_type);
    case kTensorNum3:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum3>>(op_type);
    case kTensorNum4:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum4>>(op_type);
    case kTensorNum5:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum5>>(op_type);
    case kTensorNum6:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum6>>(op_type);
    case kTensorNum7:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum7>>(op_type);
    case kTensorNum8:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum8>>(op_type);
    case kTensorNum9:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum9>>(op_type);
    case kTensorNum10:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum10>>(op_type);
    case kTensorNum11:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum11>>(op_type);
    case kTensorNum12:
      return std::make_shared<CustomAclnnKernelMod<kTensorNum12>>(op_type);
    default:
      MS_LOG(ERROR) << "Aclnn custom ops only support argument counts between 1 and 12, but got: " << arg_num;
  }
  return nullptr;
}

}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CUSTOM_ACLNN_KERNEL_MOD_H_