/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_SSE
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/intrinsics/sse/sse_common.h"
#include "nnacl/base/minimal_filtering_generator.h"

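// Winograd transform helper: matix_a is treated as an (m x k) matrix whose
// elements are in_channel-wide vectors, matrix_b as a scalar (k x n) matrix,
// and the (m x n) product is written to matrix_c with in_channel contiguous
// floats per output element. Channels are processed 16 and then 8 at a time
// in SSE registers, with a scalar loop for the remainder; c4_channel is
// unused in this implementation.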
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
                            int in_channel, int c4_channel) {
  const float *src1 = matix_a;
  int c16 = DOWN_DIV(in_channel, C16NUM) * C16NUM;
  int c8 = DOWN_DIV(in_channel, C8NUM) * C8NUM;
  for (int i = 0; i < m; ++i) {
    const float *src1_n = src1;
    const float *src2_n = matrix_b;
    for (int j = 0; j < n; ++j) {
      const float *src1_j = src1_n;
      int y = 0;
      // process 16 channels at a time
      for (; y < c16; y += C16NUM) {
        __m128 dst1 = _mm_setzero_ps();
        __m128 dst2 = _mm_setzero_ps();
        __m128 dst3 = _mm_setzero_ps();
        __m128 dst4 = _mm_setzero_ps();
        const float *src2_y = src2_n;
        for (int z = 0; z < k; ++z) {
          __m128 ma1 = _mm_loadu_ps(src1_j);
          __m128 ma2 = _mm_loadu_ps(src1_j + 4);
          __m128 ma3 = _mm_loadu_ps(src1_j + 8);
          __m128 ma4 = _mm_loadu_ps(src1_j + 12);

          __m128 mb = _mm_load_ps1(src2_y);
          __m128 tmp1 = _mm_mul_ps(ma1, mb);
          __m128 tmp2 = _mm_mul_ps(ma2, mb);
          __m128 tmp3 = _mm_mul_ps(ma3, mb);
          __m128 tmp4 = _mm_mul_ps(ma4, mb);
          dst1 = _mm_add_ps(dst1, tmp1);
          dst2 = _mm_add_ps(dst2, tmp2);
          dst3 = _mm_add_ps(dst3, tmp3);
          dst4 = _mm_add_ps(dst4, tmp4);
          src1_j += in_channel;
          src2_y += n;
        }
        _mm_storeu_ps(matrix_c, dst1);
        _mm_storeu_ps(matrix_c + 4, dst2);
        _mm_storeu_ps(matrix_c + 8, dst3);
        _mm_storeu_ps(matrix_c + 12, dst4);
        src1_j -= in_channel * k;
        src1_j += C16NUM;
        matrix_c += C16NUM;
      }
      // process 8 channels at a time
      for (; y < c8; y += C8NUM) {
        __m128 dst1 = _mm_setzero_ps();
        __m128 dst2 = _mm_setzero_ps();
        const float *src2_y = src2_n;
        for (int z = 0; z < k; ++z) {
          __m128 ma1 = _mm_loadu_ps(src1_j);
          __m128 ma2 = _mm_loadu_ps(src1_j + 4);

          __m128 mb = _mm_load_ps1(src2_y);
          __m128 tmp1 = _mm_mul_ps(ma1, mb);
          __m128 tmp2 = _mm_mul_ps(ma2, mb);
          dst1 = _mm_add_ps(dst1, tmp1);
          dst2 = _mm_add_ps(dst2, tmp2);
          src1_j += in_channel;
          src2_y += n;
        }
        _mm_storeu_ps(matrix_c, dst1);
        _mm_storeu_ps(matrix_c + 4, dst2);
        src1_j -= in_channel * k;
        src1_j += C8NUM;
        matrix_c += C8NUM;
      }
      // remaining channels
      for (; y < in_channel; ++y) {
        float tmp = 0;
        for (int z = 0; z < k; ++z) {
          tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
        }
        *matrix_c++ = tmp;
      }
      src2_n += 1;
    }
    src1 += k * in_channel;
  }
}

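// 4x8 SSE matmul micro-kernel: a is expected packed four rows per depth step
// and b eight columns per depth step (the pointers advance by C4NUM and C8NUM
// each iteration). Each pass accumulates a 4x8 block of a*b in eight __m128
// registers, optionally adds bias and applies the activation, then stores the
// block according to write_mode: Winograd tile (OutType_TileC8), C8-packed
// (OutType_C8), or plain rows with a column-count-specific tail writer.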
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                         int col, int stride, int write_mode) {
  int C8Steps = row * C8NUM, WinoSteps1 = stride * col, WinoSteps2 = stride * C8NUM;
  for (int r = row; r > 0; r -= C4NUM) {
    const float *srcb_d = b, *bias_d = bias;
    float *dst = NULL;
    for (int cc = col; cc > 0; cc -= C8NUM) {
      if (write_mode != 0) {  // writec8
        dst = c;
      }
      const float *srca_d = a;
      __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps(), dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps();
      __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps(), dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps();
      for (int d = depth; d > 0; --d) {
        __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4);
        __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1);
        __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1);
        __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2);
        a1 = _mm_load_ps1(srca_d + 2);
        dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2);
        a2 = _mm_load_ps1(srca_d + 3);
        dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4);
        tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1);
        tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2);
        dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2);
        dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4);
        srcb_d += C8NUM, srca_d += C4NUM;
      }

      if (bias != NULL) {
        DoBiasBlock8(bias_d, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8);
        bias_d += C8NUM;
      }

      ActBlock8(&dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, act_type);

      if (write_mode == OutType_TileC8) {  // WriteWino
        c = dst + WinoSteps2;
        _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
        dst += WinoSteps1;
        _mm_storeu_ps(dst, dst3), _mm_storeu_ps(dst + 4, dst4);
        dst += WinoSteps1;
        _mm_storeu_ps(dst, dst5), _mm_storeu_ps(dst + 4, dst6);
        dst += WinoSteps1;
        _mm_storeu_ps(dst, dst7), _mm_storeu_ps(dst + 4, dst8);
      } else if (write_mode == OutType_C8) {  // WriteC8
        _mm_storeu_ps(c, dst1), _mm_storeu_ps(c + 4, dst2);
        _mm_storeu_ps(c + 8, dst3), _mm_storeu_ps(c + 12, dst4);
        _mm_storeu_ps(c + 16, dst5), _mm_storeu_ps(c + 20, dst6);
        _mm_storeu_ps(c + 24, dst7), _mm_storeu_ps(c + 28, dst8);
        c += C8Steps;
      } else {
        switch (cc) {
          case 1:  // write1
            c = dst + 1;
            WriteCol1(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 1, r);
            break;
          case 2:  // write2
            c = dst + 2;
            WriteCol2Opt(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, r);
            break;
          case 3:  // write3
            c = dst + 3;
            _mm_store_ss(dst, dst1);
            dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
            _mm_store_ss(dst + 1, dst1);
            dst1 = _mm_shuffle_ps(dst1, dst1, _MM_SHUFFLE(0, 3, 2, 1));
            _mm_store_ss(dst + 2, dst1);
            WriteCol3(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 3, r);
            break;
          case 4:  // write4
            c = dst + 4;
            WriteCol4(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 4, r);
            break;
          case 5:  // write5
            c = dst + 5;
            WriteCol5(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 5, r);
            break;
          case 6:  // write6
            c = dst + 6;
            WriteCol6(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 6, r);
            break;
          case 7:  // write7
            c = dst + 7;
            WriteCol7(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 7, r);
            break;
          default:  // write8
            c = dst + C8NUM;
            WriteCol8(&dst, &dst1, &dst2, &dst3, &dst4, &dst5, &dst6, &dst7, &dst8, stride, 8, r);
            break;
        }
      }
      if (cc <= C8NUM) break;  // write end
    }
    a += C4NUM * depth;
    if (write_mode == OutType_C8) c += 32;
    if (write_mode == OutType_TileC8) c = dst + WinoSteps2;
    if (write_mode == OutType_Nhwc) c = dst - col;
    if (r <= C4NUM) break;
  }
}

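// Deconvolution variant of the 4x8 micro-kernel above: same packing and
// accumulation scheme, but without bias or activation, and every finished
// 4x8 block is stored as 32 contiguous floats in c.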
void DeconvMatmulFloatSse(const float *a, const float *b, float *c, int depth, int row, int col) {
  for (int col_tmp = col; col_tmp > 0; col_tmp -= C8NUM) {
    const float *srca_d = a;
    float *dst = c;
    for (int r = row; r > 0; r -= C4NUM) {
      const float *srcb_d = b;
      __m128 dst1 = _mm_setzero_ps(), dst2 = _mm_setzero_ps();
      __m128 dst3 = _mm_setzero_ps(), dst4 = _mm_setzero_ps();
      __m128 dst5 = _mm_setzero_ps(), dst6 = _mm_setzero_ps();
      __m128 dst7 = _mm_setzero_ps(), dst8 = _mm_setzero_ps();
      for (int d = 0; d < depth; d++) {
        __m128 b1 = _mm_loadu_ps(srcb_d), b2 = _mm_loadu_ps(srcb_d + 4);
        __m128 a1 = _mm_load_ps1(srca_d), a2 = _mm_load_ps1(srca_d + 1);
        __m128 tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1);
        __m128 tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2);
        a1 = _mm_load_ps1(srca_d + 2);
        dst1 = _mm_add_ps(dst1, tmp1), dst2 = _mm_add_ps(dst2, tmp2);
        a2 = _mm_load_ps1(srca_d + 3);
        dst3 = _mm_add_ps(dst3, tmp3), dst4 = _mm_add_ps(dst4, tmp4);
        tmp1 = _mm_mul_ps(b1, a1), tmp2 = _mm_mul_ps(b2, a1);
        tmp3 = _mm_mul_ps(b1, a2), tmp4 = _mm_mul_ps(b2, a2);
        dst5 = _mm_add_ps(dst5, tmp1), dst6 = _mm_add_ps(dst6, tmp2);
        dst7 = _mm_add_ps(dst7, tmp3), dst8 = _mm_add_ps(dst8, tmp4);
        srcb_d += C8NUM, srca_d += C4NUM;
      }
      _mm_storeu_ps(dst, dst1), _mm_storeu_ps(dst + 4, dst2);
      _mm_storeu_ps(dst + 8, dst3), _mm_storeu_ps(dst + 12, dst4);
      _mm_storeu_ps(dst + 16, dst5), _mm_storeu_ps(dst + 20, dst6);
      _mm_storeu_ps(dst + 24, dst7), _mm_storeu_ps(dst + 28, dst8);
      dst += 32;
      c = dst;
    }
    b += depth * C8NUM;
  }
}
#endif