• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #if defined(ENABLE_SSE) && !defined(ENABLE_AVX)
17 #include "nnacl/intrinsics/ms_simd_instructions.h"
18 #include "nnacl/fp32/common_func_fp32.h"
19 
/*
 * Adds one weight vector, scaled by four different scalars, into four
 * accumulators:
 *   *dstN = *dstN + weight * {vN, vN, vN, vN}    for N = 1..4.
 * Accumulators are updated one after another in the order dst1, dst2, dst3,
 * dst4, each with a separate multiply and add (no fused multiply-add).
 */
static inline void TiledC4MatmulFp32_Transfer(__m128 *dst1, __m128 *dst2, __m128 *dst3, __m128 *dst4,
                                              const __m128 weight, const float v1, const float v2, const float v3,
                                              const float v4) {
  __m128 scaled;

  scaled = _mm_mul_ps(weight, _mm_set1_ps(v1));
  *dst1 = _mm_add_ps(*dst1, scaled);

  scaled = _mm_mul_ps(weight, _mm_set1_ps(v2));
  *dst2 = _mm_add_ps(*dst2, scaled);

  scaled = _mm_mul_ps(weight, _mm_set1_ps(v3));
  *dst3 = _mm_add_ps(*dst3, scaled);

  scaled = _mm_mul_ps(weight, _mm_set1_ps(v4));
  *dst4 = _mm_add_ps(*dst4, scaled);
}
28 
/*
 * Loads 16 consecutive floats starting at src into four vectors using
 * unaligned loads:
 *   *src1 <- src[0..3], *src2 <- src[4..7], *src3 <- src[8..11], *src4 <- src[12..15].
 */
static inline void TiledC4MatmulFp32_LoadData(__m128 *src1, __m128 *src2, __m128 *src3, __m128 *src4,
                                              const float *src) {
  const float *cursor = src;
  *src1 = _mm_loadu_ps(cursor);
  cursor += 4;
  *src2 = _mm_loadu_ps(cursor);
  cursor += 4;
  *src3 = _mm_loadu_ps(cursor);
  cursor += 4;
  *src4 = _mm_loadu_ps(cursor);
}
36 
/*
 * Tiled C4 matrix multiply, SSE path.
 *
 * For each of the oc4 output-channel blocks o, computes eight output vectors
 * of four floats (dst1..dst8), stores them as 32 consecutive floats at dst,
 * then advances dst by cal_num floats.
 *
 * Output vector p (0..7), lane c, of block o is
 *   sum over input blocks b in [0, ic4) and k in [0, 4) of
 *     src[32*b + 4*p + k] * weight[16*(o*ic4 + b) + 4*k + c]
 * Each accumulator is summed strictly in order (b ascending, then k
 * ascending), one multiply followed by one add per term.
 *
 * src is re-read from its start for every output-channel block. weight is
 * consumed sequentially (16*ic4 floats per block) and never rewound.
 *
 * The arithmetic is hand software-pipelined: loads belonging to the next
 * step are interleaved with arithmetic of the current step (presumably to
 * overlap load latency). The statement order inside the pipeline is
 * deliberate; keep it when editing.
 *
 * @param dst     output; 32 floats written per output-channel block.
 * @param src     input tile; 32*ic4 floats read (once per output block).
 * @param weight  packed weights; 16*ic4*oc4 floats read.
 * @param cal_num distance, in floats, between consecutive output blocks in dst.
 * @param ic4     number of 4-wide input-channel blocks. If 0, every output
 *                vector is zero (empty sum) and src/weight are not read.
 * @param oc4     number of 4-wide output-channel blocks.
 */
void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) {
  if (ic4 == 0) {
    // The pipelined body below always consumes one input block and then
    // computes ic4 - 1, which would wrap to SIZE_MAX and run far past the
    // buffers. With no input channels the reduction is empty, so write zeros.
    const __m128 zero = _mm_setzero_ps();
    for (size_t i = 0; i < oc4; ++i) {
      float *out = dst + i * cal_num;
      for (int p = 0; p < 8; ++p) {
        _mm_storeu_ps(out + 4 * p, zero);
      }
    }
    return;
  }

  const float *src_tmp = src;
  for (size_t i = 0; i < oc4; ++i) {
    src = src_tmp;             // every output-channel block re-reads the same input tile
    size_t ic4_tmp = ic4 - 1;  // input blocks remaining after block 0 (ic4 >= 1 here)

    // ---- Input block 0: initialise all eight accumulators. ----
    __m128 src1 = _mm_loadu_ps(src);
    __m128 src2 = _mm_loadu_ps(src + 4);
    __m128 src3 = _mm_loadu_ps(src + 8);
    __m128 src4 = _mm_loadu_ps(src + 12);
    src += 16;
    __m128 weight_data[4];
    weight_data[0] = _mm_loadu_ps(weight);
    weight_data[1] = _mm_loadu_ps(weight + 4);
    weight_data[2] = _mm_loadu_ps(weight + 8);
    weight_data[3] = _mm_loadu_ps(weight + 12);
    weight += 16;
    __m128 dst1 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
    __m128 dst2 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
    __m128 dst3 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
    __m128 dst4 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
    for (int j = 1; j < 4; ++j) {
      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[j], MS_F32X4_GETI(src1, j),
                                 MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
    }
    // Vectors 4..7 of block 0 reuse the same 16 weights.
    TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
    src += 16;
    __m128 dst5 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0)));
    __m128 dst6 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0)));
    __m128 dst7 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0)));
    __m128 dst8 = _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0)));
    for (int j = 1; j < 4; ++j) {
      TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
                                 MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
    }

    if (ic4_tmp != 0) {
      // ---- Input blocks 1 .. ic4-1, software-pipelined. ----
      // Prologue: fetch vectors 0..3 and the first half of the weights of block 1.
      ic4_tmp -= 1;
      TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
      src += 16;
      weight_data[0] = _mm_loadu_ps(weight);
      weight_data[1] = _mm_loadu_ps(weight + 4);
      weight += 8;

      dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
      dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
      // Steady state: each iteration finishes the current block and starts the next one.
      for (; ic4_tmp != 0; ic4_tmp -= 1) {
        dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
        dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));

        TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
                                   MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

        // Second half of the current block's weights.
        weight_data[2] = _mm_loadu_ps(weight);
        weight_data[3] = _mm_loadu_ps(weight + 4);
        weight += 8;

        TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
                                   MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

        // Finish vectors 0..3 while fetching vectors 4..7 of the current block.
        // src1/src2 are reloaded only after their lane 3 has been used above.
        dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src1, 3))));
        dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src2, 3))));
        src1 = _mm_loadu_ps(src);
        src2 = _mm_loadu_ps(src + 4);
        dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src3, 3))));
        dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[3], _mm_set_ps1(MS_F32X4_GETI(src4, 3))));
        src3 = _mm_loadu_ps(src + 8);
        src4 = _mm_loadu_ps(src + 12);
        src += 16;

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[0], MS_F32X4_GETI(src1, 0),
                                   MS_F32X4_GETI(src2, 0), MS_F32X4_GETI(src3, 0), MS_F32X4_GETI(src4, 0));

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[1], MS_F32X4_GETI(src1, 1),
                                   MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[2], MS_F32X4_GETI(src1, 2),
                                   MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

        // weight_data[0..1] are no longer needed for this block: prefetch the next block's first half.
        weight_data[0] = _mm_loadu_ps(weight);
        weight_data[1] = _mm_loadu_ps(weight + 4);
        weight += 8;

        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[3], MS_F32X4_GETI(src1, 3),
                                   MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));
        // Vectors 0..3 of the next block.
        TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
        src += 16;

        dst1 = _mm_add_ps(dst1, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src1, 0))));
        dst2 = _mm_add_ps(dst2, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src2, 0))));
      }
      // Epilogue: finish the last input block without prefetching another one.
      dst3 = _mm_add_ps(dst3, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src3, 0))));
      dst4 = _mm_add_ps(dst4, _mm_mul_ps(weight_data[0], _mm_set_ps1(MS_F32X4_GETI(src4, 0))));

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[1], MS_F32X4_GETI(src1, 1),
                                 MS_F32X4_GETI(src2, 1), MS_F32X4_GETI(src3, 1), MS_F32X4_GETI(src4, 1));

      weight_data[2] = _mm_loadu_ps(weight);
      weight_data[3] = _mm_loadu_ps(weight + 4);
      weight += 8;

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[2], MS_F32X4_GETI(src1, 2),
                                 MS_F32X4_GETI(src2, 2), MS_F32X4_GETI(src3, 2), MS_F32X4_GETI(src4, 2));

      TiledC4MatmulFp32_Transfer(&dst1, &dst2, &dst3, &dst4, weight_data[3], MS_F32X4_GETI(src1, 3),
                                 MS_F32X4_GETI(src2, 3), MS_F32X4_GETI(src3, 3), MS_F32X4_GETI(src4, 3));

      TiledC4MatmulFp32_LoadData(&src1, &src2, &src3, &src4, src);
      src += 16;
      for (int j = 0; j < 4; ++j) {
        TiledC4MatmulFp32_Transfer(&dst5, &dst6, &dst7, &dst8, weight_data[j], MS_F32X4_GETI(src1, j),
                                   MS_F32X4_GETI(src2, j), MS_F32X4_GETI(src3, j), MS_F32X4_GETI(src4, j));
      }
    }

    _mm_storeu_ps(dst, dst1);
    _mm_storeu_ps(dst + 4, dst2);
    _mm_storeu_ps(dst + 8, dst3);
    _mm_storeu_ps(dst + 12, dst4);
    _mm_storeu_ps(dst + 16, dst5);
    _mm_storeu_ps(dst + 20, dst6);
    _mm_storeu_ps(dst + 24, dst7);
    _mm_storeu_ps(dst + 28, dst8);
    dst += cal_num;
  }
}
161 #endif
162