// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <xmmintrin.h>

#include <xnnpack/conv.h>
#include <xnnpack/math.h>

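// 3x3 convolution with stride 2 and padding 1 that reads interleaved HWC
// input with 3 channels and writes planar CHW output, computing 4 output
// channels at a time with SSE intrinsics. The "1x1" suffix is the tile size:
// one output row by one output pixel per inner-loop iteration.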
void xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__sse_1x1(
    size_t input_height,
    size_t input_width,
    size_t output_y_start,
    size_t output_y_end,
    const float* input,
    const float* zero,
    const float* weights,
    float* output,
    size_t input_padding_top,
    size_t output_channels,
    size_t output_height_stride,
    size_t output_channel_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_width != 0);
  assert(output_y_end > output_y_start);
  assert(input_padding_top <= 1);
  assert(output_channels != 0);

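  // Derived strides: input rows are dense HWC with 3 channels, so one row is
  // input_width * 3 floats. input_width_decrement rewinds the input pointers
  // by exactly what the main width loop advances them (the even part of the
  // width; a trailing odd pixel is handled without advancing the pointers).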
  const size_t input_height_stride = input_width * 3 /* channels */ * sizeof(float);
  const size_t input_width_decrement = round_down_po2(input_width, 2) * 3 /* channels */ * sizeof(float);
  const size_t output_width = (input_width + 1) / 2;
  const size_t output_channel_increment = output_channel_stride * 4 - output_width * sizeof(float);

  // Adjustment for padding processed below
  const float* i0 = (const float*) ((uintptr_t) input + input_height_stride * (output_y_start * 2 - input_padding_top));
  const float* i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
  const float* i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
  float* output0 = (float*) ((uintptr_t) output + output_height_stride * output_y_start);

  if XNN_UNPREDICTABLE(output_y_start < input_padding_top) {
    i0 = zero;
  }

  const __m128 vmin = _mm_load_ps(params->sse.min);
  const __m128 vmax = _mm_load_ps(params->sse.max);

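  // Each output row consumes three input rows (i0, i1, i2); rows that fall
  // into the top or bottom padding are redirected to the caller-provided
  // zero row.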
  for (size_t output_y = output_y_start; output_y < output_y_end; output_y += 1) {
    const size_t input_y2 = output_y * 2 + 2 - input_padding_top;
    if XNN_UNPREDICTABLE(input_y2 >= input_height) {
      i2 = zero;
    }

    const float* w = weights;
    size_t c = output_channels;
    float* o0c0 = output0;
    float* o0c1 = (float*) ((uintptr_t) o0c0 + output_channel_stride);
    float* o0c2 = (float*) ((uintptr_t) o0c1 + output_channel_stride);
    float* o0c3 = (float*) ((uintptr_t) o0c2 + output_channel_stride);
    do {
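      // If fewer than 4 output channels remain, alias the unused tail
      // pointers onto the last valid channel; the packed weights duplicate
      // that channel into the unused lanes, so the repeated stores rewrite
      // the same value and are harmless.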
      if XNN_UNPREDICTABLE(c < 2) {
        o0c1 = o0c0;
      }
      if XNN_UNPREDICTABLE(c <= 2) {
        o0c2 = o0c1;
      }
      if XNN_UNPREDICTABLE(c < 4) {
        o0c3 = o0c2;
      }

      // Left edge padding
      __m128 vi00c0 = _mm_setzero_ps();
      __m128 vi00c1 = _mm_setzero_ps();
      __m128 vi00c2 = _mm_setzero_ps();
      __m128 vi10c0 = _mm_setzero_ps();
      __m128 vi10c1 = _mm_setzero_ps();
      __m128 vi10c2 = _mm_setzero_ps();
      __m128 vi20c0 = _mm_setzero_ps();
      __m128 vi20c1 = _mm_setzero_ps();
      __m128 vi20c2 = _mm_setzero_ps();

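      // Main width loop: each iteration consumes two input pixels (stride 2)
      // and produces one output pixel for this group of 4 output channels.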
      size_t iw = input_width;
      for (; iw >= 2; iw -= 2) {
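        // The first 4 floats of each 112-float weight block are the biases
        // for this group of output channels.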
        __m128 voc0123 = _mm_loadu_ps(w);

        const __m128 vk00c0x0123 = _mm_load_ps(w + 4);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c0x0123, vi00c0));

        const __m128 vk10c0x0123 = _mm_load_ps(w + 8);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c0x0123, vi10c0));

        const __m128 vk20c0x0123 = _mm_load_ps(w + 12);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c0x0123, vi20c0));

        const __m128 vk00c1x0123 = _mm_load_ps(w + 16);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c1x0123, vi00c1));

        const __m128 vk10c1x0123 = _mm_load_ps(w + 20);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c1x0123, vi10c1));

        const __m128 vk20c1x0123 = _mm_load_ps(w + 24);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c1x0123, vi20c1));

        const __m128 vk00c2x0123 = _mm_load_ps(w + 28);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c2x0123, vi00c2));

        const __m128 vk10c2x0123 = _mm_load_ps(w + 32);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c2x0123, vi10c2));

        const __m128 vk20c2x0123 = _mm_load_ps(w + 36);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c2x0123, vi20c2));

        const __m128 vk01c0x0123 = _mm_load_ps(w + 40);
        const __m128 vi01c0 = _mm_load1_ps(i0);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c0x0123, vi01c0));

        const __m128 vk11c0x0123 = _mm_load_ps(w + 44);
        const __m128 vi11c0 = _mm_load1_ps(i1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c0x0123, vi11c0));

        const __m128 vk21c0x0123 = _mm_load_ps(w + 48);
        const __m128 vi21c0 = _mm_load1_ps(i2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c0x0123, vi21c0));

        const __m128 vk01c1x0123 = _mm_load_ps(w + 52);
        const __m128 vi01c1 = _mm_load1_ps(i0 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c1x0123, vi01c1));

        const __m128 vk11c1x0123 = _mm_load_ps(w + 56);
        const __m128 vi11c1 = _mm_load1_ps(i1 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c1x0123, vi11c1));

        const __m128 vk21c1x0123 = _mm_load_ps(w + 60);
        const __m128 vi21c1 = _mm_load1_ps(i2 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c1x0123, vi21c1));

        const __m128 vk01c2x0123 = _mm_load_ps(w + 64);
        const __m128 vi01c2 = _mm_load1_ps(i0 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c2x0123, vi01c2));

        const __m128 vk11c2x0123 = _mm_load_ps(w + 68);
        const __m128 vi11c2 = _mm_load1_ps(i1 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c2x0123, vi11c2));

        const __m128 vk21c2x0123 = _mm_load_ps(w + 72);
        const __m128 vi21c2 = _mm_load1_ps(i2 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c2x0123, vi21c2));

        const __m128 vk02c0x0123 = _mm_load_ps(w + 76);
        const __m128 vi02c0 = _mm_load1_ps(i0 + 3);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c0x0123, vi02c0));

        const __m128 vk12c0x0123 = _mm_load_ps(w + 80);
        const __m128 vi12c0 = _mm_load1_ps(i1 + 3);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c0x0123, vi12c0));

        const __m128 vk22c0x0123 = _mm_load_ps(w + 84);
        const __m128 vi22c0 = _mm_load1_ps(i2 + 3);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c0x0123, vi22c0));

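        // Stride 2: the x2 input column of this output pixel is the x0
        // column of the next one, so carry it over instead of reloading.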
        vi00c0 = vi02c0;
        vi10c0 = vi12c0;
        vi20c0 = vi22c0;

        const __m128 vk02c1x0123 = _mm_load_ps(w + 88);
        const __m128 vi02c1 = _mm_load1_ps(i0 + 4);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c1x0123, vi02c1));

        const __m128 vk12c1x0123 = _mm_load_ps(w + 92);
        const __m128 vi12c1 = _mm_load1_ps(i1 + 4);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c1x0123, vi12c1));

        const __m128 vk22c1x0123 = _mm_load_ps(w + 96);
        const __m128 vi22c1 = _mm_load1_ps(i2 + 4);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c1x0123, vi22c1));

        vi00c1 = vi02c1;
        vi10c1 = vi12c1;
        vi20c1 = vi22c1;

        const __m128 vk02c2x0123 = _mm_load_ps(w + 100);
        const __m128 vi02c2 = _mm_load1_ps(i0 + 5);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk02c2x0123, vi02c2));

        const __m128 vk12c2x0123 = _mm_load_ps(w + 104);
        const __m128 vi12c2 = _mm_load1_ps(i1 + 5);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk12c2x0123, vi12c2));

        const __m128 vk22c2x0123 = _mm_load_ps(w + 108);
        const __m128 vi22c2 = _mm_load1_ps(i2 + 5);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk22c2x0123, vi22c2));

        vi00c2 = vi02c2;
        vi10c2 = vi12c2;
        vi20c2 = vi22c2;

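        // Clamp the accumulators to the [min, max] output range.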
        voc0123 = _mm_min_ps(voc0123, vmax);
        voc0123 = _mm_max_ps(voc0123, vmin);

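        // Transpose on store: extract each lane and write it to its output
        // channel's planar (CHW) row.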
        _mm_store_ss(o0c0, voc0123); o0c0++;
        _mm_store_ss(o0c1, _mm_shuffle_ps(voc0123, voc0123, 1)); o0c1++;
        _mm_store_ss(o0c2, _mm_shuffle_ps(voc0123, voc0123, 2)); o0c2++;
        _mm_store_ss(o0c3, _mm_shuffle_ps(voc0123, voc0123, 3)); o0c3++;

        i0 += 6;
        i1 += 6;
        i2 += 6;
      }
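      // If the input width is odd, the last output pixel's rightmost kernel
      // column (x2) falls in the implicit right padding, so only the x0 and
      // x1 taps are accumulated.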
      assert(iw < 2);
      if XNN_UNLIKELY(iw != 0) {
        __m128 voc0123 = _mm_load_ps(w);

        const __m128 vk00c0x0123 = _mm_load_ps(w + 4);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c0x0123, vi00c0));

        const __m128 vk10c0x0123 = _mm_load_ps(w + 8);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c0x0123, vi10c0));

        const __m128 vk20c0x0123 = _mm_load_ps(w + 12);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c0x0123, vi20c0));

        const __m128 vk00c1x0123 = _mm_load_ps(w + 16);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c1x0123, vi00c1));

        const __m128 vk10c1x0123 = _mm_load_ps(w + 20);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c1x0123, vi10c1));

        const __m128 vk20c1x0123 = _mm_load_ps(w + 24);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c1x0123, vi20c1));

        const __m128 vk00c2x0123 = _mm_load_ps(w + 28);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk00c2x0123, vi00c2));

        const __m128 vk10c2x0123 = _mm_load_ps(w + 32);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk10c2x0123, vi10c2));

        const __m128 vk20c2x0123 = _mm_load_ps(w + 36);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk20c2x0123, vi20c2));

        const __m128 vk01c0x0123 = _mm_load_ps(w + 40);
        const __m128 vi01c0 = _mm_load1_ps(i0);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c0x0123, vi01c0));

        const __m128 vk11c0x0123 = _mm_load_ps(w + 44);
        const __m128 vi11c0 = _mm_load1_ps(i1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c0x0123, vi11c0));

        const __m128 vk21c0x0123 = _mm_load_ps(w + 48);
        const __m128 vi21c0 = _mm_load1_ps(i2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c0x0123, vi21c0));

        const __m128 vk01c1x0123 = _mm_load_ps(w + 52);
        const __m128 vi01c1 = _mm_load1_ps(i0 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c1x0123, vi01c1));

        const __m128 vk11c1x0123 = _mm_load_ps(w + 56);
        const __m128 vi11c1 = _mm_load1_ps(i1 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c1x0123, vi11c1));

        const __m128 vk21c1x0123 = _mm_load_ps(w + 60);
        const __m128 vi21c1 = _mm_load1_ps(i2 + 1);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c1x0123, vi21c1));

        const __m128 vk01c2x0123 = _mm_load_ps(w + 64);
        const __m128 vi01c2 = _mm_load1_ps(i0 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk01c2x0123, vi01c2));

        const __m128 vk11c2x0123 = _mm_load_ps(w + 68);
        const __m128 vi11c2 = _mm_load1_ps(i1 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk11c2x0123, vi11c2));

        const __m128 vk21c2x0123 = _mm_load_ps(w + 72);
        const __m128 vi21c2 = _mm_load1_ps(i2 + 2);
        voc0123 = _mm_add_ps(voc0123, _mm_mul_ps(vk21c2x0123, vi21c2));

        voc0123 = _mm_min_ps(voc0123, vmax);
        voc0123 = _mm_max_ps(voc0123, vmin);

        _mm_store_ss(o0c0, voc0123); o0c0++;
        _mm_store_ss(o0c1, _mm_shuffle_ps(voc0123, voc0123, 1)); o0c1++;
        _mm_store_ss(o0c2, _mm_shuffle_ps(voc0123, voc0123, 2)); o0c2++;
        _mm_store_ss(o0c3, _mm_shuffle_ps(voc0123, voc0123, 3)); o0c3++;
      }
      // Move output pointers back to the position of the first pixel in a row,
      // and forward to the next block of output channels.
      o0c0 = (float*) ((uintptr_t) o0c0 + output_channel_increment);
      o0c1 = (float*) ((uintptr_t) o0c1 + output_channel_increment);
      o0c2 = (float*) ((uintptr_t) o0c2 + output_channel_increment);
      o0c3 = (float*) ((uintptr_t) o0c3 + output_channel_increment);
      // Revert input pointers to the position of the first pixel in a row
      i0 = (const float*) ((uintptr_t) i0 - input_width_decrement);
      i1 = (const float*) ((uintptr_t) i1 - input_width_decrement);
      i2 = (const float*) ((uintptr_t) i2 - input_width_decrement);
      // Move to the block of weights for the next 4 output channels
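      // (112 floats = 4 biases + 3x3 kernel * 3 input channels * 4 output
      // channels).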
      w += 112;
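      // doz() is saturating subtraction: c counts down in steps of 4 without
      // wrapping below zero.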
      c = doz(c, 4);
    } while (c != 0);
    // Move output pointers forward to the next row
    output0 = (float*) ((uintptr_t) output0 + output_height_stride);
    // Move input pointers forward to the next row
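    // (height stride 2: this row's bottom input row becomes the next row's
    // top input row)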
    i0 = i2;
    i1 = (const float*) ((uintptr_t) i0 + input_height_stride);
    i2 = (const float*) ((uintptr_t) i1 + input_height_stride);
  }
}
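
// Illustrative invocation (a sketch, not part of this file): `packed_weights`
// is assumed to hold bias+kernel data in the 112-floats-per-4-output-channels
// layout consumed above, `zero_row` to point at least input_width * 3 zeroed
// floats, and the output tensor to be dense CHW with
// output_height = (input_height + 1) / 2 and output_width = (input_width + 1) / 2:
//
//   xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__sse_1x1(
//       input_height, input_width,
//       /*output_y_start=*/0, /*output_y_end=*/output_height,
//       input, zero_row, packed_weights, output,
//       /*input_padding_top=*/1, output_channels,
//       /*output_height_stride=*/output_width * sizeof(float),
//       /*output_channel_stride=*/output_height * output_width * sizeof(float),
//       &params);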