1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <xmmintrin.h>
9
10 #include <xnnpack/conv.h>
11 #include <xnnpack/math.h>
12
13
xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__sse_1x1(size_t input_height,size_t input_width,size_t output_y_start,size_t output_y_end,const float * input,const float * zero,const float * weights,float * output,size_t input_padding_top,size_t output_channels,size_t output_height_stride,size_t output_channel_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])14 void xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__sse_1x1(
15 size_t input_height,
16 size_t input_width,
17 size_t output_y_start,
18 size_t output_y_end,
19 const float* input,
20 const float* zero,
21 const float* weights,
22 float* output,
23 size_t input_padding_top,
24 size_t output_channels,
25 size_t output_height_stride,
26 size_t output_channel_stride,
27 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
28 {
29 assert(input_width != 0);
30 assert(output_y_end > output_y_start);
31 assert(input_padding_top <= 1);
32 assert(output_channels != 0);
33
34 const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
35 const size_t input_width_decrement = round_down_po2(input_width, 2) * 3 /* channels */ * sizeof(float);
36 const size_t output_width = (input_width + 1) / 2;
37 const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float);
38
39 // Adjustment for padding processed below
40 const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
41 const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
42 const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
43 float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);
44
45 if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
46 i0 = zero;
47 }
48
49 const __m128 vmin = _mm_load_ps(params->sse.min);
50 const __m128 vmax = _mm_load_ps(params->sse.max);
51
52 for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 1) {
53 const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
54 if XNN_UNPREDICTABLE(input_y2 >= input_height) {
55 i2 = zero;
56 }
57
58 const float* w = weights;
59 size_t c = output_channels;
60 float* o0c0 = output0;
61 float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride);
62 float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride);
63 float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride);
64 do {
65 if XNN_UNPREDICTABLE(c < 2) {
66 o0c1 = o0c0;
67 }
68 if XNN_UNPREDICTABLE(c <= 2) {
69 o0c2 = o0c1;
70 }
71 if XNN_UNPREDICTABLE(c < 4) {
72 o0c3 = o0c2;
73 }
74
75 // Left edge padding
76 __m128 vi00c0 = _mm_setzero_ps();
77 __m128 vi00c1 = _mm_setzero_ps();
78 __m128 vi00c2 = _mm_setzero_ps();
79 __m128 vi10c0 = _mm_setzero_ps();
80 __m128 vi10c1 = _mm_setzero_ps();
81 __m128 vi10c2 = _mm_setzero_ps();
82 __m128 vi20c0 = _mm_setzero_ps();
83 __m128 vi20c1 = _mm_setzero_ps();
84 __m128 vi20c2 = _mm_setzero_ps();
85
86 size_t iw = input_width;
87 for (; iw >= 2; iw -= 2) {
88 __m128 voc0123 = _mm_loadu_ps(w);
89
90 const __m128 vk00c0x0123 = _mm_load_ps(w + 4);
91 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c0x0123, vi00c0));
92
93 const __m128 vk10c0x0123 = _mm_load_ps(w + 8);
94 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c0x0123, vi10c0));
95
96 const __m128 vk20c0x0123 = _mm_load_ps(w + 12);
97 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c0x0123, vi20c0));
98
99 const __m128 vk00c1x0123 = _mm_load_ps(w + 16);
100 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c1x0123, vi00c1));
101
102 const __m128 vk10c1x0123 = _mm_load_ps(w + 20);
103 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c1x0123, vi10c1));
104
105 const __m128 vk20c1x0123 = _mm_load_ps(w + 24);
106 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c1x0123, vi20c1));
107
108 const __m128 vk00c2x0123 = _mm_load_ps(w + 28);
109 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c2x0123, vi00c2));
110
111 const __m128 vk10c2x0123 = _mm_load_ps(w + 32);
112 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c2x0123, vi10c2));
113
114 const __m128 vk20c2x0123 = _mm_load_ps(w + 36);
115 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c2x0123, vi20c2));
116
117 const __m128 vk01c0x0123 = _mm_load_ps(w + 40);
118 const __m128 vi01c0 = _mm_load1_ps(i0);
119 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c0x0123, vi01c0));
120
121 const __m128 vk11c0x0123 = _mm_load_ps(w + 44);
122 const __m128 vi11c0 = _mm_load1_ps(i1);
123 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c0x0123, vi11c0));
124
125 const __m128 vk21c0x0123 = _mm_load_ps(w + 48);
126 const __m128 vi21c0 = _mm_load1_ps(i2);
127 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c0x0123, vi21c0));
128
129 const __m128 vk01c1x0123 = _mm_load_ps(w + 52);
130 const __m128 vi01c1 = _mm_load1_ps(i0 + 1);
131 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c1x0123, vi01c1));
132
133 const __m128 vk11c1x0123 = _mm_load_ps(w + 56);
134 const __m128 vi11c1 = _mm_load1_ps(i1 + 1);
135 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c1x0123, vi11c1));
136
137 const __m128 vk21c1x0123 = _mm_load_ps(w + 60);
138 const __m128 vi21c1 = _mm_load1_ps(i2 + 1);
139 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c1x0123, vi21c1));
140
141 const __m128 vk01c2x0123 = _mm_load_ps(w + 64);
142 const __m128 vi01c2 = _mm_load1_ps(i0 + 2);
143 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c2x0123, vi01c2));
144
145 const __m128 vk11c2x0123 = _mm_load_ps(w + 68);
146 const __m128 vi11c2 = _mm_load1_ps(i1 + 2);
147 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c2x0123, vi11c2));
148
149 const __m128 vk21c2x0123 = _mm_load_ps(w + 72);
150 const __m128 vi21c2 = _mm_load1_ps(i2 + 2);
151 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c2x0123, vi21c2));
152
153 const __m128 vk02c0x0123 = _mm_load_ps(w + 76);
154 const __m128 vi02c0 = _mm_load1_ps(i0 + 3);
155 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c0x0123, vi02c0));
156
157 const __m128 vk12c0x0123 = _mm_load_ps(w + 80);
158 const __m128 vi12c0 = _mm_load1_ps(i1 + 3);
159 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c0x0123, vi12c0));
160
161 const __m128 vk22c0x0123 = _mm_load_ps(w + 84);
162 const __m128 vi22c0 = _mm_load1_ps(i2 + 3);
163 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c0x0123, vi22c0));
164
165 vi00c0 = vi02c0;
166 vi10c0 = vi12c0;
167 vi20c0 = vi22c0;
168
169 const __m128 vk02c1x0123 = _mm_load_ps(w + 88);
170 const __m128 vi02c1 = _mm_load1_ps(i0 + 4);
171 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c1x0123, vi02c1));
172
173 const __m128 vk12c1x0123 = _mm_load_ps(w + 92);
174 const __m128 vi12c1 = _mm_load1_ps(i1 + 4);
175 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c1x0123, vi12c1));
176
177 const __m128 vk22c1x0123 = _mm_load_ps(w + 96);
178 const __m128 vi22c1 = _mm_load1_ps(i2 + 4);
179 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c1x0123, vi22c1));
180
181 vi00c1 = vi02c1;
182 vi10c1 = vi12c1;
183 vi20c1 = vi22c1;
184
185 const __m128 vk02c2x0123 = _mm_load_ps(w + 100);
186 const __m128 vi02c2 = _mm_load1_ps(i0 + 5);
187 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c2x0123, vi02c2));
188
189 const __m128 vk12c2x0123 = _mm_load_ps(w + 104);
190 const __m128 vi12c2 = _mm_load1_ps(i1 + 5);
191 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c2x0123, vi12c2));
192
193 const __m128 vk22c2x0123 = _mm_load_ps(w + 108);
194 const __m128 vi22c2 = _mm_load1_ps(i2 + 5);
195 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c2x0123, vi22c2));
196
197 vi00c2 = vi02c2;
198 vi10c2 = vi12c2;
199 vi20c2 = vi22c2;
200
201 voc0123 = _mm_min_ps(voc0123, vmax);
202 voc0123 = _mm_max_ps(voc0123, vmin);
203
204 _mm_store_ss(o0c0, voc0123); o0c0++;
205 _mm_store_ss(o0c1, _mm_shuffle_ps(voc0123, voc0123, 1)); o0c1++;
206 _mm_store_ss(o0c2, _mm_shuffle_ps(voc0123, voc0123, 2)); o0c2++;
207 _mm_store_ss(o0c3, _mm_shuffle_ps(voc0123, voc0123, 3)); o0c3++;
208
209 i0 += 6;
210 i1 += 6;
211 i2 += 6;
212 }
213 assert(iw < 2);
214 if XNN_UNLIKELY(iw != 0) {
215 __m128 voc0123 = _mm_load_ps(w);
216
217 const __m128 vk00c0x0123 = _mm_load_ps(w + 4);
218 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c0x0123, vi00c0));
219
220 const __m128 vk10c0x0123 = _mm_load_ps(w + 8);
221 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c0x0123, vi10c0));
222
223 const __m128 vk20c0x0123 = _mm_load_ps(w + 12);
224 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c0x0123, vi20c0));
225
226 const __m128 vk00c1x0123 = _mm_load_ps(w + 16);
227 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c1x0123, vi00c1));
228
229 const __m128 vk10c1x0123 = _mm_load_ps(w + 20);
230 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c1x0123, vi10c1));
231
232 const __m128 vk20c1x0123 = _mm_load_ps(w + 24);
233 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c1x0123, vi20c1));
234
235 const __m128 vk00c2x0123 = _mm_load_ps(w + 28);
236 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c2x0123, vi00c2));
237
238 const __m128 vk10c2x0123 = _mm_load_ps(w + 32);
239 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c2x0123, vi10c2));
240
241 const __m128 vk20c2x0123 = _mm_load_ps(w + 36);
242 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c2x0123, vi20c2));
243
244 const __m128 vk01c0x0123 = _mm_load_ps(w + 40);
245 const __m128 vi01c0 = _mm_load1_ps(i0);
246 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c0x0123, vi01c0));
247
248 const __m128 vk11c0x0123 = _mm_load_ps(w + 44);
249 const __m128 vi11c0 = _mm_load1_ps(i1);
250 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c0x0123, vi11c0));
251
252 const __m128 vk21c0x0123 = _mm_load_ps(w + 48);
253 const __m128 vi21c0 = _mm_load1_ps(i2);
254 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c0x0123, vi21c0));
255
256 const __m128 vk01c1x0123 = _mm_load_ps(w + 52);
257 const __m128 vi01c1 = _mm_load1_ps(i0 + 1);
258 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c1x0123, vi01c1));
259
260 const __m128 vk11c1x0123 = _mm_load_ps(w + 56);
261 const __m128 vi11c1 = _mm_load1_ps(i1 + 1);
262 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c1x0123, vi11c1));
263
264 const __m128 vk21c1x0123 = _mm_load_ps(w + 60);
265 const __m128 vi21c1 = _mm_load1_ps(i2 + 1);
266 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c1x0123, vi21c1));
267
268 const __m128 vk01c2x0123 = _mm_load_ps(w + 64);
269 const __m128 vi01c2 = _mm_load1_ps(i0 + 2);
270 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c2x0123, vi01c2));
271
272 const __m128 vk11c2x0123 = _mm_load_ps(w + 68);
273 const __m128 vi11c2 = _mm_load1_ps(i1 + 2);
274 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c2x0123, vi11c2));
275
276 const __m128 vk21c2x0123 = _mm_load_ps(w + 72);
277 const __m128 vi21c2 = _mm_load1_ps(i2 + 2);
278 voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c2x0123, vi21c2));
279
280 voc0123 = _mm_min_ps(voc0123, vmax);
281 voc0123 = _mm_max_ps(voc0123, vmin);
282
283 _mm_store_ss(o0c0, voc0123); o0c0++;
284 _mm_store_ss(o0c1, _mm_shuffle_ps(voc0123, voc0123, 1)); o0c1++;
285 _mm_store_ss(o0c2, _mm_shuffle_ps(voc0123, voc0123, 2)); o0c2++;
286 _mm_store_ss(o0c3, _mm_shuffle_ps(voc0123, voc0123, 3)); o0c3++;
287 }
288 // Move output pointers back to the position of the first pixel in a row,
289 // and forward to the next block of output channels.
290 o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment);
291 o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment);
292 o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment);
293 o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment);
294 // Revert input pointers to the position of the first pixel in a row
295 i0 = (const float*) ((uintptr_t) i0 - input_width_decrement);
296 i1 = (const float*) ((uintptr_t) i1 - input_width_decrement);
297 i2 = (const float*) ((uintptr_t) i2 - input_width_decrement);
298 // Move to the block of weights for the next 4 output channels
299 w += 112;
300 c = doz(c, 4);
301 } while (c != 0);
302 // Move output pointers forward to the next row
303 output0 = (float*) ((uintptr_t) output0 + output_height_stride);
304 // Move input pointers forward to the next row
305 i0 = i2;
306 i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
307 i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
308 }
309 }
310