1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel.h"
18 #include "nnacl/tensor_c.h"
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/init_exec_env.h"
21
22 static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16];
23
RegKernelCreator(int opType,int dataType,KernelCreator creator)24 void RegKernelCreator(int opType, int dataType, KernelCreator creator) {
25 g_kernelCreatorRegistry[opType][REGIST_DT(dataType)] = creator;
26 }
27
/* Populate the kernel creator registry on Visual Studio builds.
 *
 * Visual Studio builds do not support the automatic kernel registration used
 * on other toolchains, so the registry is filled explicitly via
 * init_vs_kernels() the first time this function is called. On non-MSVC
 * builds this function does nothing. */
void Init_MSC_VER_kernels(void) {
#ifdef _MSC_VER
  static bool vs_kernels_registered = false;
  if (!vs_kernels_registered) {
    init_vs_kernels(g_kernelCreatorRegistry);
    vs_kernels_registered = true;
  }
#endif
}
40
checkOpValid(int opType)41 bool checkOpValid(int opType) {
42 if (opType < PrimType_MIN || opType >= PrimType_MAX) {
43 return false;
44 }
45 return true;
46 }
47
SupportKernelC(int opType,int dataType)48 bool SupportKernelC(int opType, int dataType) {
49 Init_MSC_VER_kernels();
50 const int length = 16;
51 if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) {
52 return false;
53 }
54 if (!checkOpValid(opType)) {
55 return false;
56 }
57 KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)];
58 return creator != NULL;
59 }
60
/* Default UpdateThread hook installed by CreateKernel.
 *
 * Ignores the workload description (type, load, store, unit) and keeps the
 * requested thread count, falling back to a single thread when `thread` is
 * zero or negative. */
int DefaultThreadUpdate(int32_t type, int64_t load, int64_t store, int64_t unit, int thread) {
  (void)type;
  (void)load;
  (void)store;
  (void)unit;
  if (thread <= 0) {
    return 1;
  }
  return thread;
}
64
NNACLKernelInferShape(struct KernelBase * self)65 int NNACLKernelInferShape(struct KernelBase *self) { return NNACL_ERR; }
66
/* Validate the common KernelBase fields populated by CreateKernel.
 *
 * Returns NNACL_ERR if kernel_base is NULL. Otherwise it calls CheckExecEnv on
 * the kernel, then requires:
 *   - a non-NULL param_,
 *   - a thread count in (0, MAX_THREAD_NUM],
 *   - at least one input with a non-NULL input array,
 *   - at least one output with a non-NULL output array.
 * Returns NNACL_OK when every check passes, NNACL_ERR otherwise. */
int NNACLCheckKernelBase(KernelBase *kernel_base) {
  /* Guard before any dereference (CheckExecEnv and the field reads below). */
  if (kernel_base == NULL) {
    return NNACL_ERR;
  }

  CheckExecEnv(kernel_base);

  if (kernel_base->param_ == NULL) {
    return NNACL_ERR;
  }

  if (kernel_base->thread_nr_ <= 0 || kernel_base->thread_nr_ > MAX_THREAD_NUM) {
    return NNACL_ERR;
  }

  if (kernel_base->in_size_ == 0 || kernel_base->in_ == NULL) {
    return NNACL_ERR;
  }
  if (kernel_base->out_size_ == 0 || kernel_base->out_ == NULL) {
    return NNACL_ERR;
  }
  return NNACL_OK;
}
86
CreateKernel(OpParameter * param,TensorC ** ins,size_t in_size,TensorC ** outs,size_t out_size,int data_type,ExecEnv * env)87 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
88 int data_type, ExecEnv *env) {
89 Init_MSC_VER_kernels();
90 if (param == NULL) {
91 return NULL;
92 }
93 if (!checkOpValid(param->type_)) {
94 return NULL;
95 }
96
97 KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)];
98 if (creator == NULL) {
99 return NULL;
100 }
101
102 KernelBase *kernel_base = creator(param, data_type);
103 if (kernel_base == NULL) {
104 return NULL;
105 }
106
107 kernel_base->InferShape = NNACLKernelInferShape;
108 kernel_base->UpdateThread = DefaultThreadUpdate;
109 kernel_base->env_ = env;
110 kernel_base->param_ = param;
111 kernel_base->thread_nr_ = param->thread_num_;
112 kernel_base->train_session_ = param->is_train_session_;
113 kernel_base->in_ = ins;
114 kernel_base->in_size_ = in_size;
115 kernel_base->out_ = outs;
116 kernel_base->out_size_ = out_size;
117
118 int ret = NNACLCheckKernelBase(kernel_base);
119 if (ret != NNACL_OK) {
120 return NULL;
121 }
122
123 return kernel_base;
124 }
125