/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
#include "nnacl/intrinsics/ms_simd_instructions.h"
#include "nnacl/fp32/common_func_fp32.h"
19
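/* Multiply-accumulate helper: broadcast each scalar v_i across a vector and
 * accumulate weight * v_i into the matching destination register, i.e.
 * dst_i += weight * v_i. */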
static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4,
                                              const __m128 weight, const float v1, const float v2, const float v3,
                                              const float v4) {
  *dst1 = _mm_add_ps(*dst1, _mm_mul_ps(weight, _mm_set_ps1(v1)));
  *dst2 = _mm_add_ps(*dst2, _mm_mul_ps(weight, _mm_set_ps1(v2)));
  *dst3 = _mm_add_ps(*dst3, _mm_mul_ps(weight, _mm_set_ps1(v3)));
  *dst4 = _mm_add_ps(*dst4, _mm_mul_ps(weight, _mm_set_ps1(v4)));
}

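/* Load four consecutive C4 vectors (16 floats) from src into SSE registers. */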
static inline void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4,
                                              const float *src) {
  *src1 = _mm_loadu_ps(src);
  *src2 = _mm_loadu_ps(src + 4);
  *src3 = _mm_loadu_ps(src + 8);
  *src4 = _mm_loadu_ps(src + 12);
}

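/* Tiled matrix multiply on C4-packed data: for each of the oc4 output-channel
 * blocks, a tile of eight C4 source vectors per input-channel block is
 * multiplied by a 4x4 weight block and accumulated over all ic4 input-channel
 * blocks into eight C4 output vectors. dst advances by cal_num floats between
 * consecutive output-channel blocks. */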
void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) {
  const float *src_tmp = src;
  for (int i = 0; i < oc4; ++i) {
    float *dst_tmp = dst;
    src = src_tmp;
    size_t ic4_tmp = ic4 - 1;
    /* Load the first four source vectors of the tile. */
    __m128 src1 = _mm_loadu_ps(src);
    __m128 src2 = _mm_loadu_ps(src + 4);
    __m128 src3 = _mm_loadu_ps(src + 8);
    __m128 src4 = _mm_loadu_ps(src + 12);
    src += 16;
    /* Load the 4x4 weight block for the first input-channel block. */
    __m128 weight_data[4];
    weight_data[0] = _mm_loadu_ps(weight);
    weight_data[1] = _mm_loadu_ps(weight + 4);
    weight_data[2] = _mm_loadu_ps(weight + 8);
    weight_data[3] = _mm_loadu_ps(weight + 12);
    weight += 16;
    /* Start the accumulators with the first weight vector times element 0 of
     * each source vector. */
    __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
    __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
    __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
    __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
    /* Accumulate the remaining three elements of each source vector. */
    for (int j = 1; j < 4; ++j) {
      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j),
                                 MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
    }
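    /* Second half of the tile: process the next four source vectors into dst5..dst8. */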
    TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
    src += 16;
    __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
    __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
    __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
    __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
    for (int j = 1; j < 4; ++j) {
      TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
                                 MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
    }
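    /* Remaining input-channel blocks: a software-pipelined loop that overlaps
     * weight/source loads with the multiply-accumulates. The first block's
     * loads are peeled before the loop, and the trailing accumulations are
     * completed in the epilogue after it. */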
    if (ic4_tmp != 0) {
      ic4_tmp -= 1;
      TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
      src += 16;
      weight_data[0] = _mm_loadu_ps(weight);
      weight_data[1] = _mm_loadu_ps(weight + 4);
      weight += 8;

      dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
      dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
      for (; ic4_tmp != 0; ic4_tmp -= 1) {
        dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
        dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));

        TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
                                   MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

        weight_data[2] = _mm_loadu_ps(weight);
        weight_data[3] = _mm_loadu_ps(weight + 4);
        weight += 8;

        TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
                                   MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

        dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3))));
        dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3))));
        src1 = _mm_loadu_ps(src);
        src2 = _mm_loadu_ps(src + 4);
        dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3))));
        dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3))));
        src3 = _mm_loadu_ps(src + 8);
        src4 = _mm_loadu_ps(src + 12);
        src += 16;

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0),
                                   MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0));

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1),
                                   MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2),
                                   MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

        weight_data[0] = _mm_loadu_ps(weight);
        weight_data[1] = _mm_loadu_ps(weight + 4);
        weight += 8;

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3),
                                   MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));
        TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
        src += 16;

        dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
        dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
      }
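      /* Epilogue: finish the trailing accumulations for the last input-channel block. */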
      dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
      dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
                                 MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

      weight_data[2] = _mm_loadu_ps(weight);
      weight_data[3] = _mm_loadu_ps(weight + 4);
      weight += 8;

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
                                 MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3),
                                 MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));

      TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
      src += 16;
      for (int j = 0; j < 4; ++j) {
        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
                                   MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
      }
    }
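    /* Write back the eight accumulated C4 output vectors for this block. */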
    _mm_storeu_ps(dst, dst1);
    _mm_storeu_ps(dst + 4, dst2);
    _mm_storeu_ps(dst + 8, dst3);
    _mm_storeu_ps(dst + 12, dst4);
    _mm_storeu_ps(dst + 16, dst5);
    _mm_storeu_ps(dst + 20, dst6);
    _mm_storeu_ps(dst + 24, dst7);
    _mm_storeu_ps(dst + 28, dst8);
    dst = dst_tmp + cal_num;
  }
}
#endif