• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NNACL_KERNEL_CONVOLLUTION_BASE_H_
18 #define NNACL_KERNEL_CONVOLLUTION_BASE_H_
19 
20 #include "nnacl/op_base.h"
21 #include "nnacl/tensor_c.h"
22 #include "nnacl/kernel.h"
23 #include "nnacl/conv_parameter.h"
24 #include "nnacl/tensor_c_utils.h"
25 
/* Minimum block value for the convolution kernels.
 * NOTE(review): the exact meaning (e.g. a minimum parallel split or tile size) is not visible in this
 * header -- confirm at the use sites before changing the value. */
#define ConvMinBlock 1
27 
28 typedef struct ConvolutionBaseStruct {
29   KernelBase base_;
30   ConvComputeParam compute_;
31   bool weight_is_packed_;
32   bool is_repack_;
33   bool infershape_done_;
34   bool use_batch_cut_flag_;
35   FormatC out_format_;
36 
37   void *packed_weight_;
38   void *bias_data_;
39   void *origin_weight_;  // do not Free
40   void *origin_bias_;    // do not Free
41 
42   void (*init_global_variable_)(struct ConvolutionBaseStruct *conv_im2col);
43   int (*malloc_weight_bias_)(struct ConvolutionBaseStruct *conv_base);
44   void (*pack_weight_)(struct ConvolutionBaseStruct *conv_base);
45   int (*run_impl_)(struct ConvolutionBaseStruct *conv, int task_id);
46 
47   bool is_sharing_pack_;
48   void *shaing_manager_;
49   void (*free_sharing_weight_)(void *manager, void *tensor_data);
50   void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed);
51 } ConvolutionBaseStruct;
52 
53 int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param);
54 int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv);
55 void ConvBaseRelease(ConvolutionBaseStruct *conv);
56 int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv);
57 int ConvBasePrepare(ConvolutionBaseStruct *conv);
58 int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv);
59 int ConvBaseRepackWeight(ConvolutionBaseStruct *conv);
60 void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv);
61 void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size);
62 
63 #endif  // NNACL_KERNEL_CONVOLLUTION_BASE_H_
64