• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel.h"
18 #include "nnacl/tensor_c.h"
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/init_exec_env.h"
21 
22 static KernelCreator g_kernelCreatorRegistry[PrimType_MAX][16];
23 
/**
 * Register the creator function for an (operator type, data type) pair.
 *
 * The creator is stored in g_kernelCreatorRegistry[opType][REGIST_DT(dataType)],
 * replacing whatever was stored for that pair before. A request whose operator
 * type or data-type slot falls outside the table is ignored rather than written
 * out of bounds.
 *
 * @param opType   operator type; must satisfy PrimType_MIN <= opType < PrimType_MAX
 * @param dataType data type id; REGIST_DT(dataType) must select an existing slot
 * @param creator  creator function to store for the pair
 */
void RegKernelCreator(int opType, int dataType, KernelCreator creator) {
  /* Same operator range that lookups accept. */
  if (opType < PrimType_MIN || opType >= PrimType_MAX) {
    return;
  }

  /* Derive the slot count from the table itself so the bound cannot drift. */
  const int slot_count = (int)(sizeof(g_kernelCreatorRegistry[0]) / sizeof(g_kernelCreatorRegistry[0][0]));
  const int slot = REGIST_DT(dataType);
  if (slot < 0 || slot >= slot_count) {
    return;
  }

  g_kernelCreatorRegistry[opType][slot] = creator;
}
27 
/**
 * One-time kernel registration for MSVC builds.
 *
 * When _MSC_VER is defined, the first call passes the creator registry to
 * init_vs_kernels(); later calls do nothing. On every other compiler the
 * function body is empty.
 */
void Init_MSC_VER_kernels(void) {
#ifdef _MSC_VER
  /* The VS environment does not support automatic registration, so the
   * kernels are registered here on the first call. */
  static bool inited = false;
  if (inited) {
    return;
  }
  init_vs_kernels(g_kernelCreatorRegistry);
  inited = true;
#endif
}
40 
checkOpValid(int opType)41 bool checkOpValid(int opType) {
42   if (opType < PrimType_MIN || opType >= PrimType_MAX) {
43     return false;
44   }
45   return true;
46 }
47 
SupportKernelC(int opType,int dataType)48 bool SupportKernelC(int opType, int dataType) {
49   Init_MSC_VER_kernels();
50   const int length = 16;
51   if (REGIST_DT(dataType) < 0 || REGIST_DT(dataType) >= length) {
52     return false;
53   }
54   if (!checkOpValid(opType)) {
55     return false;
56   }
57   KernelCreator creator = g_kernelCreatorRegistry[opType][REGIST_DT(dataType)];
58   return creator != NULL;
59 }
60 
/**
 * Default thread-count policy: keep the requested count, but never go below one.
 *
 * @param type   not used by this implementation
 * @param load   not used by this implementation
 * @param store  not used by this implementation
 * @param unit   not used by this implementation
 * @param thread requested thread count
 * @return thread when it is positive, otherwise 1
 */
int DefaultThreadUpdate(int32_t type, int64_t load, int64_t store, int64_t unit, int thread) {
  if (thread <= 0) {
    return 1;
  }
  return thread;
}
64 
NNACLKernelInferShape(struct KernelBase * self)65 int NNACLKernelInferShape(struct KernelBase *self) { return NNACL_ERR; }
66 
/**
 * Sanity-check a kernel instance after its fields have been filled in.
 *
 * @param kernel_base kernel to validate
 * @return NNACL_OK when all of the following hold, NNACL_ERR otherwise:
 *         - kernel_base is not NULL
 *         - param_ is not NULL
 *         - thread_nr_ is in the range [1, MAX_THREAD_NUM]
 *         - in_ is not NULL and in_size_ is non-zero
 *         - out_ is not NULL and out_size_ is non-zero
 */
int NNACLCheckKernelBase(KernelBase *kernel_base) {
  /* Reject a missing kernel before anything dereferences it. */
  if (kernel_base == NULL) {
    return NNACL_ERR;
  }

  /* Any result CheckExecEnv() may produce is not inspected here. */
  CheckExecEnv(kernel_base);

  if (kernel_base->param_ == NULL) {
    return NNACL_ERR;
  }

  if (kernel_base->thread_nr_ <= 0 || kernel_base->thread_nr_ > MAX_THREAD_NUM) {
    return NNACL_ERR;
  }

  if (kernel_base->in_size_ == 0 || kernel_base->in_ == NULL) {
    return NNACL_ERR;
  }
  if (kernel_base->out_size_ == 0 || kernel_base->out_ == NULL) {
    return NNACL_ERR;
  }
  return NNACL_OK;
}
86 
CreateKernel(OpParameter * param,TensorC ** ins,size_t in_size,TensorC ** outs,size_t out_size,int data_type,ExecEnv * env)87 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
88                          int data_type, ExecEnv *env) {
89   Init_MSC_VER_kernels();
90   if (param == NULL) {
91     return NULL;
92   }
93   if (!checkOpValid(param->type_)) {
94     return NULL;
95   }
96 
97   KernelCreator creator = g_kernelCreatorRegistry[param->type_][REGIST_DT(data_type)];
98   if (creator == NULL) {
99     return NULL;
100   }
101 
102   KernelBase *kernel_base = creator(param, data_type);
103   if (kernel_base == NULL) {
104     return NULL;
105   }
106 
107   kernel_base->InferShape = NNACLKernelInferShape;
108   kernel_base->UpdateThread = DefaultThreadUpdate;
109   kernel_base->env_ = env;
110   kernel_base->param_ = param;
111   kernel_base->thread_nr_ = param->thread_num_;
112   kernel_base->train_session_ = param->is_train_session_;
113   kernel_base->in_ = ins;
114   kernel_base->in_size_ = in_size;
115   kernel_base->out_ = outs;
116   kernel_base->out_size_ = out_size;
117 
118   int ret = NNACLCheckKernelBase(kernel_base);
119   if (ret != NNACL_OK) {
120     return NULL;
121   }
122 
123   return kernel_base;
124 }
125