• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp32/conv_sw_arm64_fp32.h"
18 #include "nnacl/fp32/conv_sw.h"
19 
CheckArm64UseSWConv(const ConvParameter * conv_param)20 bool CheckArm64UseSWConv(const ConvParameter *conv_param) {
21   if (conv_param->kernel_h_ == 1 && conv_param->kernel_w_ == 1) {
22     return false;
23   }
24   if (conv_param->input_channel_ > C128NUM) {
25     return false;
26   }
27   if (conv_param->kernel_h_ > C5NUM || conv_param->kernel_w_ > C5NUM) {
28     return false;
29   }
30   if (conv_param->dilation_h_ != 1 || conv_param->dilation_w_ != 1) {
31     return false;
32   }
33   if (conv_param->stride_w_ > C3NUM) {
34     return false;
35   }
36   if (conv_param->input_h_ / conv_param->kernel_h_ < C48NUM || conv_param->input_w_ / conv_param->kernel_w_ < C48NUM) {
37     return false;
38   }
39   return true;
40 }
41 
/*
 * Pointer type matching the signature of every SWConv<rows>x<channels>Kernel routine declared in
 * this file, so those routines can be stored in a table and invoked through it.
 *
 * Argument notes, taken from the two call sites in this file (COMPUTE_CORE and OUTER_COMPUTE):
 *   - act_flag receives act_type;
 *   - oc_algin receives the output block step (out_block_step / sw_param->out_block_step_),
 *     despite its name;
 *   - kw_remainder is 0 in COMPUTE_CORE and derived from the clipped kernel width in OUTER_COMPUTE.
 */
typedef void (*SWConvKernel)(float *dst, const float *src, const float *weight, const float *bias,
                             size_t kernel_h, size_t kernel_w, size_t act_flag,
                             size_t oc_algin, size_t ic_algin,
                             size_t in_kw_step, size_t in_kh_step, size_t in_sw_step,
                             size_t kw_remainder, size_t write_mode);
45 
/*
 * Prototypes of the ten SWConv<rows>x<channels>Kernel routines (rows 1..5, channels 8 or 16).
 * Only declarations appear in this file; the definitions are provided elsewhere
 * (presumably hand-written Arm64 assembly -- TODO confirm).
 *
 * All ten share one signature, so it is spelled out once in a local helper macro that is
 * removed again right after the last declaration.
 */
#define DECLARE_SW_CONV_KERNEL(kernel_name)                                                               \
  void kernel_name(float *dst, const float *src, const float *weight, const float *bias, size_t kernel_h, \
                   size_t kernel_w, size_t act_flag, size_t oc_algin, size_t ic_algin, size_t in_kw_step,  \
                   size_t in_kh_step, size_t in_sw_step, size_t kw_remainder, size_t write_mode)

// 8-channel variants.
DECLARE_SW_CONV_KERNEL(SWConv1x8Kernel);
DECLARE_SW_CONV_KERNEL(SWConv2x8Kernel);
DECLARE_SW_CONV_KERNEL(SWConv3x8Kernel);
DECLARE_SW_CONV_KERNEL(SWConv4x8Kernel);
DECLARE_SW_CONV_KERNEL(SWConv5x8Kernel);

// 16-channel variants.
DECLARE_SW_CONV_KERNEL(SWConv1x16Kernel);
DECLARE_SW_CONV_KERNEL(SWConv2x16Kernel);
DECLARE_SW_CONV_KERNEL(SWConv3x16Kernel);
DECLARE_SW_CONV_KERNEL(SWConv4x16Kernel);
DECLARE_SW_CONV_KERNEL(SWConv5x16Kernel);

#undef DECLARE_SW_CONV_KERNEL
85 
86 #define ROW_NUM_LIST const int ow_block_num[2] = {5, 5};
87 #define KERNEL_LIST                                                                        \
88   const SWConvKernel kernel[2][5] = {                                                      \
89     {SWConv1x8Kernel, SWConv2x8Kernel, SWConv3x8Kernel, SWConv4x8Kernel, SWConv5x8Kernel}, \
90     {SWConv1x16Kernel, SWConv2x16Kernel, SWConv3x16Kernel, SWConv4x16Kernel, SWConv5x16Kernel}};
91 #define COMPUTE_CORE                                                                                              \
92   kernel[oc_block - 1][ow_block - 1](dst_oc + ow * out_w_step, src_w, weight, bias, kernel_h, kernel_w, act_type, \
93                                      out_block_step, ic_algin, in_kw_step, in_kh_step, in_sw_step, 0, write_mode);
94 #define OUTER_COMPUTE                                                                                                 \
95   kernel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, act_type,                 \
96          sw_param->out_block_step_, sw_param->ic_align_, sw_param->in_kw_step_, sw_param->in_kh_step_,                \
97          sw_param->in_sw_step_, (conv_param->kernel_w_ - end_kw + start_kw) * C8NUM * oc_block * sw_param->ic_align_, \
98          write_mode);
99 GenerateConvSWFunc(Arm64, C2NUM, ROW_NUM_LIST, KERNEL_LIST, COMPUTE_CORE, OUTER_COMPUTE);
100