/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Declares the base state struct and the base entry points of the fp32 Winograd
 * convolution kernel.
 *
 * NOTE(review): "CONVOLLUTION" in the include-guard name is a misspelling. It is
 * deliberately left unchanged because other code could test this macro.
 */
#ifndef NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_
#define NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_

#include "nnacl/op_base.h"
#include "nnacl/tensor_c.h"
#include "nnacl/kernel.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/kernel/convolution_base.h"
#include "nnacl/fp32/winograd_utils.h"

/* Element count of matrix_g_ and matrix_gt_ below. 64 equals
 * CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE squared (8 * 8). Presumably this sizes the
 * matrices for the largest supported input unit; TODO confirm. */
#define CONVOLUTION_WINOGRAD_MATRIX_SIZE 64
/* Number of entries in tmp_buffer_address_list_ below. */
#define CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE 5
/* Not referenced elsewhere in this header. Presumably the maximum Winograd input
 * tile edge length; confirm against the implementation. */
#define CONVOLUTION_WINOGRAD_INPUT_UNIT_SIZE 8

/* Element type of tmp_buffer_address_list_: a plain float pointer. */
typedef float *TmpBufferAddress;

/* State of a Winograd convolution kernel instance. */
typedef struct ConvolutionWinogradBaseStruct {
  /* Base convolution state. It must stay the FIRST member. C guarantees that a
   * pointer to a struct, suitably converted, points to its first member. This
   * lets a ConvolutionWinogradBaseStruct * be passed where a
   * ConvolutionBaseStruct * is expected. The ConvolutionBaseStruct-taking
   * functions declared below presumably rely on this; verify before reordering. */
  ConvolutionBaseStruct conv_;

  /* Winograd tile geometry. Presumably input_unit_ == output_unit_ + kernel_unit_ - 1,
   * as in F(m, r); TODO confirm in ConvWinoBaseConfigInputOutput. */
  int kernel_unit_;
  int input_unit_;
  int output_unit_;
  /* Presumably the output-channel blocking factor. Confirm. */
  int oc_block_;
  /* Presumably tiles per processing step and per-tile scratch sizing. Confirm. */
  int tile_num_;
  int tmp_data_tile_;
  /* Working buffers. The names suggest temporary data, transformed input, GEMM
   * output, a column buffer and an optional input-transform buffer. Their
   * allocation and ownership are not visible in this header. */
  float *tmp_data_;
  float *trans_input_;
  float *gemm_out_;
  float *col_buffer_;
  float *opt_input_trans_;
  /* Fixed-size storage for a transform matrix and its transpose. The names
   * suggest the Winograd weight-transform matrix G and G^T; confirm. */
  float matrix_g_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
  float matrix_gt_[CONVOLUTION_WINOGRAD_MATRIX_SIZE];
  /* Table of scratch-buffer pointers. Presumably filled from the working
   * buffers above; confirm. */
  TmpBufferAddress tmp_buffer_address_list_[CONVOLUTION_WINOGRAD_TMP_BUFFER_SIZE];
  /* Transform function table. TransFuncList is presumably declared in
   * nnacl/fp32/winograd_utils.h. */
  TransFuncList transfer_functions_;

  /* Hook that configures the input/output tiling. Its signature matches
   * ConvWinoBaseConfigInputOutput below, which is presumably the default
   * implementation; confirm. */
  int (*config_input_output_)(struct ConvolutionWinogradBaseStruct *winograd);
} ConvolutionWinogradBaseStruct;

/* Base helpers. Their int returns are presumably NNACL status codes; confirm
 * against the implementation. */
void ConvWinoBasePackWeight(ConvolutionBaseStruct *conv);
int ConvWinoBaseConfigInputOutput(ConvolutionWinogradBaseStruct *winograd);
/* Runs the computation for one task. task_id presumably selects this task's
 * share of the work in a parallel launch; confirm. */
int ConvWinoBaseRunImpl(ConvolutionBaseStruct *conv, int task_id);
int ConvWinoBaseMallocWeightBiasData(ConvolutionBaseStruct *conv);

/* KernelBase lifecycle entry points. */
int ConvolutionWinogradBasePrepare(KernelBase *self);
int ConvolutionWinogradBaseResize(KernelBase *self);
int ConvolutionWinogradBaseRelease(KernelBase *self);
int ConvolutionWinogradBaseCompute(KernelBase *self);

/* Creates a Winograd base kernel for conv_param. Failure signalling and ownership
 * of the returned object are not visible here; presumably NULL on failure and
 * released through the kernel's Release path. Confirm. */
ConvolutionWinogradBaseStruct *CreateConvWinogradBase(ConvParameter *conv_param);

#endif  // NNACL_KERNEL_CONVOLLUTION_WINOGRAD_BASE_H_