/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef ENABLE_AVX
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/op_base.h"

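// Computes a 4 (row) x 8 (col) output tile: each depth step loads one 8-float weight vector,
// broadcasts four consecutive src values and accumulates them with FMA.
// col, row and stride are unused here; the signature only has to match the DeconvAvxKernel pointer type.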
void Deconv4X8AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 tmp = _mm256_set1_ps(*src);
    __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM));
    weight += C8NUM;
    __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM));
    __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM));
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res4 = _mm256_fmadd_ps(tmp1, w0, res4);
    src += C4NUM;
    res7 = _mm256_fmadd_ps(tmp2, w0, res7);
    res10 = _mm256_fmadd_ps(tmp3, w0, res10);
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);
}

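// Computes a 4 (row) x 16 (col) output tile with two weight vectors (w0, w1) and eight accumulators;
// the first 8-column block is stored at dst, the second at dst + stride.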
void Deconv4X16AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res2 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res5 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res8 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();
  __m256 res11 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 w1 = _mm256_loadu_ps(weight + C8NUM);
    weight += C16NUM;
    __m256 tmp = _mm256_set1_ps(*src);
    __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM));
    __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM));
    __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM));
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res2 = _mm256_fmadd_ps(tmp, w1, res2);
    src += C4NUM;
    res4 = _mm256_fmadd_ps(tmp1, w0, res4);
    res5 = _mm256_fmadd_ps(tmp1, w1, res5);
    res7 = _mm256_fmadd_ps(tmp2, w0, res7);
    res8 = _mm256_fmadd_ps(tmp2, w1, res8);
    res10 = _mm256_fmadd_ps(tmp3, w0, res10);
    res11 = _mm256_fmadd_ps(tmp3, w1, res11);
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);

  _mm256_storeu_ps(dst + stride, res2);
  _mm256_storeu_ps(dst + stride + C8NUM, res5);
  _mm256_storeu_ps(dst + stride + C16NUM, res8);
  _mm256_storeu_ps(dst + stride + C24NUM, res11);
}

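// Computes a 4 (row) x 24 (col) output tile with three weight vectors and twelve accumulators;
// the three 8-column blocks are stored at dst, dst + stride and dst + C2NUM * stride.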
void Deconv4X24AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res2 = _mm256_setzero_ps();
  __m256 res3 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res5 = _mm256_setzero_ps();
  __m256 res6 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res8 = _mm256_setzero_ps();
  __m256 res9 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();
  __m256 res11 = _mm256_setzero_ps();
  __m256 res12 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 w1 = _mm256_loadu_ps(weight + C8NUM);
    __m256 w2 = _mm256_loadu_ps(weight + C16NUM);
    __m256 tmp = _mm256_set1_ps(*src);
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res2 = _mm256_fmadd_ps(tmp, w1, res2);
    res3 = _mm256_fmadd_ps(tmp, w2, res3);
    tmp = _mm256_set1_ps(*(src + C1NUM));
    res4 = _mm256_fmadd_ps(tmp, w0, res4);
    res5 = _mm256_fmadd_ps(tmp, w1, res5);
    res6 = _mm256_fmadd_ps(tmp, w2, res6);
    tmp = _mm256_set1_ps(*(src + C2NUM));
    res7 = _mm256_fmadd_ps(tmp, w0, res7);
    res8 = _mm256_fmadd_ps(tmp, w1, res8);
    res9 = _mm256_fmadd_ps(tmp, w2, res9);
    tmp = _mm256_set1_ps(*(src + C3NUM));
    res10 = _mm256_fmadd_ps(tmp, w0, res10);
    res11 = _mm256_fmadd_ps(tmp, w1, res11);
    res12 = _mm256_fmadd_ps(tmp, w2, res12);
    weight += C24NUM;
    src += C4NUM;
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);

  _mm256_storeu_ps(dst + stride, res2);
  _mm256_storeu_ps(dst + stride + C8NUM, res5);
  _mm256_storeu_ps(dst + stride + C16NUM, res8);
  _mm256_storeu_ps(dst + stride + C24NUM, res11);

  _mm256_storeu_ps(dst + C2NUM * stride, res3);
  _mm256_storeu_ps(dst + C2NUM * stride + C8NUM, res6);
  _mm256_storeu_ps(dst + C2NUM * stride + C16NUM, res9);
  _mm256_storeu_ps(dst + C2NUM * stride + C24NUM, res12);
}

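// Tiled matmul for deconvolution: col is split into 8-wide blocks, up to three blocks (col_num) are
// handled per step, and the matching fixed-size kernel is dispatched for every plane and every group
// of C4NUM rows.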
void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, const int plane) {
  NNACL_CHECK_ZERO_RETURN(plane);
  int col_num = 0;
  int col_block = UP_DIV(col / plane, C8NUM);
  DeconvAvxKernel kernel[3] = {Deconv4X8AvxKernel, Deconv4X16AvxKernel, Deconv4X24AvxKernel};
  for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) {
    col_num = MSMIN(C3NUM, col_block - col_tmp);
    for (int p = 0; p < plane; ++p) {
      for (int r = 0; r < row; r += C4NUM) {
        kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth,
                            c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth,
                            row * C8NUM * plane);
      }
    }
  }
}

#ifdef ENABLE_DEBUG
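// Debug-only reference kernel: handles any col (up to C3NUM 8-wide blocks) and row (up to C4NUM)
// with plain loops, producing the same layout as the fixed-size kernels above.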
void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth,
                            int stride) {
  __m256 res[C12NUM];
  __m256 w[C3NUM];
  for (int i = 0; i < C12NUM; ++i) {
    res[i] = _mm256_setzero_ps();
  }
  for (int d = 0; d < depth; ++d) {
    for (int c = 0; c < col; ++c) {
      w[c] = _mm256_loadu_ps(weight + c * C8NUM);
    }
    weight += col * C8NUM;
    for (int r = 0; r < row; ++r) {  // row <= C4NUM
      __m256 tmp = _mm256_set1_ps(*src);
      for (int c = 0; c < col; ++c) {  // col <= C3NUM blocks of C8NUM
        res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]);
      }
      src += 1;
    }
  }
  // write
  for (int i = 0; i < col; ++i) {
    for (int j = 0; j < row; ++j) {
      _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]);
    }
    dst += stride;
  }
}
#endif
#endif