• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef NNACL_KERNEL_H_
17 #define NNACL_KERNEL_H_
18 
19 #include "nnacl/op_base.h"
20 #include "nnacl/infer/common_infer.h"
21 
/*
 * Execution environment handed to kernels: two opaque handles (an allocator
 * and a thread pool) plus the callbacks that operate on such handles.
 */
typedef struct ExecEnv ExecEnv;

struct ExecEnv {
  /* Opaque allocator handle; Alloc/Free take a `void *allocator` first
   * argument (presumably this handle — confirm against callers). */
  void *allocator_;
  /* Opaque thread-pool handle; ParallelLaunch takes a `void *thread_pool`
   * first argument (presumably this handle — confirm against callers). */
  void *thread_pool_;
  void *(*Alloc)(void *allocator, size_t sz);
  void (*Free)(void *allocator, void *ptr);
  int (*ParallelLaunch)(void *thread_pool, void *task, void *param, int task_num);
};
29 
30 typedef struct KernelBase {
31   int (*Release)(struct KernelBase *self);
32   int (*Prepare)(struct KernelBase *self);
33   int (*Compute)(struct KernelBase *self);
34   int (*Resize)(struct KernelBase *self);
35   int (*InferShape)(struct KernelBase *self);
36   int (*UpdateThread)(int32_t type, int64_t load, int64_t store, int64_t unit, int thread);
37   OpParameter *param_;
38   int thread_nr_;
39   ExecEnv *env_;
40   TensorC **in_;
41   size_t in_size_;
42   TensorC **out_;
43   size_t out_size_;
44   bool train_session_;
45   void *workspace_; /* only used in train */
46   int work_size_;   /* only used in train */
47 } KernelBase;
48 
/*
 * REG_KERNEL_CREATOR(op, data_type, func)
 *
 * Emits a function named Reg<op><data_type>Creator, marked
 * __attribute__((constructor(102))), whose body calls
 * RegKernelCreator(op, data_type, func). A constructor function runs
 * automatically before main(), so the creator is registered at load time.
 * `op` and `data_type` are token-pasted into the function name, so they must
 * form a valid identifier when concatenated.
 *
 * Priority 102: GCC reserves priorities 0-100 for the implementation, and
 * lower priorities run earlier.
 *
 * The generated function is declared with `(void)` so it is a real prototype
 * in C (an empty `()` declares a function with unspecified parameters).
 *
 * MSVC has no constructor attribute. On that compiler the macro expands to
 * nothing, so no kernel is registered through it.
 */
#ifdef _MSC_VER
#define REG_KERNEL_CREATOR(op, data_type, func)
#else
#define REG_KERNEL_CREATOR(op, data_type, func)                               \
  __attribute__((constructor(102))) void Reg##op##data_type##Creator(void) { \
    RegKernelCreator(op, data_type, func);                                    \
  }
#endif
55 
/*
 * Converts a data-type id (presumably a kNumberType* value — confirm) to an
 * offset from kNumberTypeBegin + 1, so kNumberTypeBegin + 1 maps to 0.
 * The argument is parenthesized so that an expression argument such as
 * `cond ? a : b` is evaluated as a whole before the subtraction.
 */
#define REGIST_DT(DataType) ((DataType) - kNumberTypeBegin - 1)
57 typedef KernelBase *(*KernelCreator)(OpParameter *param, int data_type);
58 void RegKernelCreator(int opType, int dataType, KernelCreator func);
59 
60 #ifdef __cplusplus
61 extern "C" {
62 #endif
63 KernelBase *CreateKernel(OpParameter *param, TensorC **ins, size_t in_size, TensorC **outs, size_t out_size,
64                          int data_type, ExecEnv *env);
65 bool SupportKernelC(int opType, int dataType);
66 #ifdef __cplusplus
67 }
68 #endif
69 #endif
70