/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef ENABLE_AVX
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/op_base.h"

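// Computes one 4 (row) x 8 (col) fp32 output tile. Each depth step broadcasts
// four consecutive src values and FMA-accumulates them against a single
// 8-float weight vector. col, row, and stride are unused here; the parameters
// are kept so this kernel matches the shared DeconvAvxKernel function-pointer
// signature.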
void Deconv4X8AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 tmp = _mm256_set1_ps(*src);
    __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM));
    weight += C8NUM;
    __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM));
    __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM));
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res4 = _mm256_fmadd_ps(tmp1, w0, res4);
    src += C4NUM;
    res7 = _mm256_fmadd_ps(tmp2, w0, res7);
    res10 = _mm256_fmadd_ps(tmp3, w0, res10);
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);
}

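// Computes one 4 (row) x 16 (col) fp32 output tile with two 8-float weight
// vectors per depth step; the second block of eight columns is stored at
// dst + stride. col and row are unused (see Deconv4X8AvxKernel).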
void Deconv4X16AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res2 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res5 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res8 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();
  __m256 res11 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 w1 = _mm256_loadu_ps(weight + C8NUM);
    weight += C16NUM;
    __m256 tmp = _mm256_set1_ps(*src);
    __m256 tmp1 = _mm256_set1_ps(*(src + C1NUM));
    __m256 tmp2 = _mm256_set1_ps(*(src + C2NUM));
    __m256 tmp3 = _mm256_set1_ps(*(src + C3NUM));
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res2 = _mm256_fmadd_ps(tmp, w1, res2);
    src += C4NUM;
    res4 = _mm256_fmadd_ps(tmp1, w0, res4);
    res5 = _mm256_fmadd_ps(tmp1, w1, res5);
    res7 = _mm256_fmadd_ps(tmp2, w0, res7);
    res8 = _mm256_fmadd_ps(tmp2, w1, res8);
    res10 = _mm256_fmadd_ps(tmp3, w0, res10);
    res11 = _mm256_fmadd_ps(tmp3, w1, res11);
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);

  _mm256_storeu_ps(dst + stride, res2);
  _mm256_storeu_ps(dst + stride + C8NUM, res5);
  _mm256_storeu_ps(dst + stride + C16NUM, res8);
  _mm256_storeu_ps(dst + stride + C24NUM, res11);
}

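// Computes one 4 (row) x 24 (col) fp32 output tile with three 8-float weight
// vectors per depth step; the three column blocks are stored at dst,
// dst + stride, and dst + C2NUM * stride. col and row are unused.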
void Deconv4X24AvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth, int stride) {
  __m256 res1 = _mm256_setzero_ps();
  __m256 res2 = _mm256_setzero_ps();
  __m256 res3 = _mm256_setzero_ps();
  __m256 res4 = _mm256_setzero_ps();
  __m256 res5 = _mm256_setzero_ps();
  __m256 res6 = _mm256_setzero_ps();
  __m256 res7 = _mm256_setzero_ps();
  __m256 res8 = _mm256_setzero_ps();
  __m256 res9 = _mm256_setzero_ps();
  __m256 res10 = _mm256_setzero_ps();
  __m256 res11 = _mm256_setzero_ps();
  __m256 res12 = _mm256_setzero_ps();

  for (int d = 0; d < depth; ++d) {
    __m256 w0 = _mm256_loadu_ps(weight);
    __m256 w1 = _mm256_loadu_ps(weight + C8NUM);
    __m256 w2 = _mm256_loadu_ps(weight + C16NUM);
    __m256 tmp = _mm256_set1_ps(*src);
    res1 = _mm256_fmadd_ps(tmp, w0, res1);
    res2 = _mm256_fmadd_ps(tmp, w1, res2);
    res3 = _mm256_fmadd_ps(tmp, w2, res3);
    tmp = _mm256_set1_ps(*(src + C1NUM));
    res4 = _mm256_fmadd_ps(tmp, w0, res4);
    res5 = _mm256_fmadd_ps(tmp, w1, res5);
    res6 = _mm256_fmadd_ps(tmp, w2, res6);
    tmp = _mm256_set1_ps(*(src + C2NUM));
    res7 = _mm256_fmadd_ps(tmp, w0, res7);
    res8 = _mm256_fmadd_ps(tmp, w1, res8);
    res9 = _mm256_fmadd_ps(tmp, w2, res9);
    tmp = _mm256_set1_ps(*(src + C3NUM));
    res10 = _mm256_fmadd_ps(tmp, w0, res10);
    res11 = _mm256_fmadd_ps(tmp, w1, res11);
    res12 = _mm256_fmadd_ps(tmp, w2, res12);
    weight += C24NUM;
    src += C4NUM;
  }
  // write
  _mm256_storeu_ps(dst, res1);
  _mm256_storeu_ps(dst + C8NUM, res4);
  _mm256_storeu_ps(dst + C16NUM, res7);
  _mm256_storeu_ps(dst + C24NUM, res10);

  _mm256_storeu_ps(dst + stride, res2);
  _mm256_storeu_ps(dst + stride + C8NUM, res5);
  _mm256_storeu_ps(dst + stride + C16NUM, res8);
  _mm256_storeu_ps(dst + stride + C24NUM, res11);

  _mm256_storeu_ps(dst + C2NUM * stride, res3);
  _mm256_storeu_ps(dst + C2NUM * stride + C8NUM, res6);
  _mm256_storeu_ps(dst + C2NUM * stride + C16NUM, res9);
  _mm256_storeu_ps(dst + C2NUM * stride + C24NUM, res12);
}

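// Tiled matmul driver for deconvolution: the output channels (col / plane,
// rounded up to C8NUM blocks) are processed in groups of up to C3NUM blocks,
// and for each plane and each group of C4NUM rows the widest matching kernel
// above is dispatched. c is laid out as [col_block][plane][row][C8NUM], so the
// stride argument row * C8NUM * plane moves a kernel to its next column block.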
void DeconvMatmulAvx(const float *a, const float *b, float *c, int depth, int row, int col, const int plane) {
  NNACL_CHECK_ZERO_RETURN(plane);
  int col_num = 0;
  int col_block = UP_DIV(col / plane, C8NUM);
  DeconvAvxKernel kernel[3] = {Deconv4X8AvxKernel, Deconv4X16AvxKernel, Deconv4X24AvxKernel};
  for (int col_tmp = 0; col_tmp < col_block; col_tmp += col_num) {
    col_num = MSMIN(C3NUM, col_block - col_tmp);
    for (int p = 0; p < plane; ++p) {
      for (int r = 0; r < row; r += C4NUM) {
        kernel[col_num - 1](a + r * depth, b + (col_tmp * plane + p * col_num) * C8NUM * depth,
                            c + (col_tmp * plane + p) * C8NUM * row + r * C8NUM, col_num, C4NUM, depth,
                            row * C8NUM * plane);
      }
    }
  }
}

#ifdef ENABLE_DEBUG
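// Generic col x row counterpart of the fixed-size kernels above, with
// col <= C3NUM weight vectors and row <= C4NUM. Compiled only when
// ENABLE_DEBUG is defined, so it can serve as a reference implementation.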
void DeconvColXRowAvxKernel(const float *src, const float *weight, float *dst, int col, int row, int depth,
                            int stride) {
  __m256 res[C12NUM];
  __m256 w[C3NUM];
  for (int i = 0; i < C12NUM; ++i) {
    res[i] = _mm256_setzero_ps();
  }
  for (int d = 0; d < depth; ++d) {
    for (int c = 0; c < col; ++c) {
      w[c] = _mm256_loadu_ps(weight + c * C8NUM);
    }
    weight += col * C8NUM;
    for (int r = 0; r < row; ++r) {  // row <= C4NUM
      __m256 tmp = _mm256_set1_ps(*src);
      for (int c = 0; c < col; ++c) {  // col <= C3NUM blocks of C8NUM floats
        res[r * col + c] = _mm256_fmadd_ps(tmp, w[c], res[r * col + c]);
      }
      src += 1;
    }
  }
  // write
  for (int i = 0; i < col; ++i) {
    for (int j = 0; j < row; ++j) {
      _mm256_storeu_ps(dst + j * C8NUM, res[j * col + i]);
    }
    dst += stride;
  }
}
#endif
#endif