1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <wasm_simd128.h>
9
10 #include <xnnpack/maxpool.h>
11
12
xnn_f32_maxpool_minmax_ukernel_9p8x__wasmsimd_x86_c4(size_t output_pixels,size_t kernel_elements,size_t channels,const float ** input,size_t input_offset,float * output,size_t input_increment,size_t output_increment,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])13 void xnn_f32_maxpool_minmax_ukernel_9p8x__wasmsimd_x86_c4(
14 size_t output_pixels,
15 size_t kernel_elements,
16 size_t channels,
17 const float** input,
18 size_t input_offset,
19 float* output,
20 size_t input_increment,
21 size_t output_increment,
22 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
23 {
24 assert(output_pixels != 0);
25 assert(kernel_elements != 0);
26 assert(channels != 0);
27
28 const v128_t voutput_max = wasm_v128_load64_splat(params->wasmsimd.max);
29 const v128_t voutput_min = wasm_v128_load64_splat(params->wasmsimd.min);
30 do {
31 float* o = output;
32 {
33 const float* i0 = *input++;
34 const float* i1 = *input++;
35 const float* i2 = *input++;
36 const float* i3 = *input++;
37 const float* i4 = *input++;
38 const float* i5 = *input++;
39 const float* i6 = *input++;
40 const float* i7 = *input++;
41 const float* i8 = *input++;
42 i0 = (const float*) ((uintptr_t) i0 + input_offset);
43 i1 = (const float*) ((uintptr_t) i1 + input_offset);
44 i2 = (const float*) ((uintptr_t) i2 + input_offset);
45 i3 = (const float*) ((uintptr_t) i3 + input_offset);
46 i4 = (const float*) ((uintptr_t) i4 + input_offset);
47 i5 = (const float*) ((uintptr_t) i5 + input_offset);
48 i6 = (const float*) ((uintptr_t) i6 + input_offset);
49 i7 = (const float*) ((uintptr_t) i7 + input_offset);
50 i8 = (const float*) ((uintptr_t) i8 + input_offset);
51 if (kernel_elements < 2) {
52 i1 = i0;
53 }
54 if (kernel_elements <= 2) {
55 i2 = i0;
56 }
57 if (kernel_elements < 4) {
58 i3 = i0;
59 }
60 if (kernel_elements <= 4) {
61 i4 = i0;
62 }
63 if (kernel_elements < 6) {
64 i5 = i0;
65 }
66 if (kernel_elements <= 6) {
67 i6 = i0;
68 }
69 if (kernel_elements < 8) {
70 i7 = i0;
71 }
72 if (kernel_elements <= 8) {
73 i8 = i0;
74 }
75
76 size_t c = channels;
77 for (; c >= 4; c -= 4) {
78 const v128_t vi0 = wasm_v128_load(i0);
79 i0 += 4;
80 const v128_t vi1 = wasm_v128_load(i1);
81 i1 += 4;
82 const v128_t vi2 = wasm_v128_load(i2);
83 i2 += 4;
84 const v128_t vi3 = wasm_v128_load(i3);
85 i3 += 4;
86 const v128_t vi4 = wasm_v128_load(i4);
87 i4 += 4;
88 const v128_t vi5 = wasm_v128_load(i5);
89 i5 += 4;
90 const v128_t vi6 = wasm_v128_load(i6);
91 i6 += 4;
92 const v128_t vi7 = wasm_v128_load(i7);
93 i7 += 4;
94 const v128_t vi8 = wasm_v128_load(i8);
95 i8 += 4;
96
97 const v128_t vmax01 = wasm_f32x4_pmax(vi1, vi0);
98 const v128_t vmax23 = wasm_f32x4_pmax(vi3, vi2);
99 const v128_t vmax45 = wasm_f32x4_pmax(vi5, vi4);
100 const v128_t vmax018 = wasm_f32x4_pmax(vi8, vmax01);
101 const v128_t vmax67 = wasm_f32x4_pmax(vi7, vi6);
102
103 const v128_t vmax2345 = wasm_f32x4_pmax(vmax45, vmax23);
104 const v128_t vmax01678 = wasm_f32x4_pmax(vmax67, vmax018);
105 const v128_t vmax = wasm_f32x4_pmax(vmax2345, vmax01678);
106
107 v128_t vout = wasm_f32x4_pmax(voutput_min, vmax);
108 vout = wasm_f32x4_pmin(voutput_max, vout);
109
110 wasm_v128_store(o, vout);
111 o += 4;
112 }
113 if (c != 0) {
114 const v128_t vi0 = wasm_v128_load(i0);
115 i0 += 4;
116 const v128_t vi1 = wasm_v128_load(i1);
117 i1 += 4;
118 const v128_t vi2 = wasm_v128_load(i2);
119 i2 += 4;
120 const v128_t vi3 = wasm_v128_load(i3);
121 i3 += 4;
122 const v128_t vi4 = wasm_v128_load(i4);
123 i4 += 4;
124 const v128_t vi5 = wasm_v128_load(i5);
125 i5 += 4;
126 const v128_t vi6 = wasm_v128_load(i6);
127 i6 += 4;
128 const v128_t vi7 = wasm_v128_load(i7);
129 i7 += 4;
130 const v128_t vi8 = wasm_v128_load(i8);
131 i8 += 4;
132
133 const v128_t vmax01 = wasm_f32x4_pmax(vi1, vi0);
134 const v128_t vmax23 = wasm_f32x4_pmax(vi3, vi2);
135 const v128_t vmax45 = wasm_f32x4_pmax(vi5, vi4);
136 const v128_t vmax018 = wasm_f32x4_pmax(vi8, vmax01);
137 const v128_t vmax67 = wasm_f32x4_pmax(vi7, vi6);
138
139 const v128_t vmax2345 = wasm_f32x4_pmax(vmax45, vmax23);
140 const v128_t vmax01678 = wasm_f32x4_pmax(vmax67, vmax018);
141 const v128_t vmax = wasm_f32x4_pmax(vmax2345, vmax01678);
142
143 v128_t vout = wasm_f32x4_pmax(voutput_min, vmax);
144 vout = wasm_f32x4_pmin(voutput_max, vout);
145
146 if (c & 2) {
147 *((double*) o) = wasm_f64x2_extract_lane(vout, 0);
148 vout = wasm_v32x4_shuffle(vout, vout, 2, 3, 2, 3);
149 o += 2;
150 }
151 if (c & 1) {
152 *o++ = wasm_f32x4_extract_lane(vout, 0);
153 }
154 }
155 }
156
157 for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 8) {
158 const float* i0 = *input++;
159 const float* i1 = *input++;
160 const float* i2 = *input++;
161 const float* i3 = *input++;
162 const float* i4 = *input++;
163 const float* i5 = *input++;
164 const float* i6 = *input++;
165 const float* i7 = *input++;
166 i0 = (const float*) ((uintptr_t) i0 + input_offset);
167 i1 = (const float*) ((uintptr_t) i1 + input_offset);
168 i2 = (const float*) ((uintptr_t) i2 + input_offset);
169 i3 = (const float*) ((uintptr_t) i3 + input_offset);
170 i4 = (const float*) ((uintptr_t) i4 + input_offset);
171 i5 = (const float*) ((uintptr_t) i5 + input_offset);
172 i6 = (const float*) ((uintptr_t) i6 + input_offset);
173 i7 = (const float*) ((uintptr_t) i7 + input_offset);
174 if (k < 2) {
175 i1 = i0;
176 }
177 if (k <= 2) {
178 i2 = i0;
179 }
180 if (k < 4) {
181 i3 = i0;
182 }
183 if (k <= 4) {
184 i4 = i0;
185 }
186 if (k < 6) {
187 i5 = i0;
188 }
189 if (k <= 6) {
190 i6 = i0;
191 }
192 if (k < 8) {
193 i7 = i0;
194 }
195
196 o = output;
197 size_t c = channels;
198 for (; c >= 4; c -= 4) {
199 const v128_t vi0 = wasm_v128_load(i0);
200 i0 += 4;
201 const v128_t vi1 = wasm_v128_load(i1);
202 i1 += 4;
203 const v128_t vi2 = wasm_v128_load(i2);
204 i2 += 4;
205 const v128_t vi3 = wasm_v128_load(i3);
206 i3 += 4;
207 const v128_t vi4 = wasm_v128_load(i4);
208 i4 += 4;
209 const v128_t vi5 = wasm_v128_load(i5);
210 i5 += 4;
211 const v128_t vi6 = wasm_v128_load(i6);
212 i6 += 4;
213 const v128_t vi7 = wasm_v128_load(i7);
214 i7 += 4;
215 const v128_t vo = wasm_v128_load(o);
216
217 const v128_t vmax01 = wasm_f32x4_pmax(vi1, vi0);
218 const v128_t vmax23 = wasm_f32x4_pmax(vi3, vi2);
219 const v128_t vmax45 = wasm_f32x4_pmax(vi5, vi4);
220 const v128_t vmax01o = wasm_f32x4_pmax(vo, vmax01);
221 const v128_t vmax67 = wasm_f32x4_pmax(vi7, vi6);
222
223 const v128_t vmax2345 = wasm_f32x4_pmax(vmax45, vmax23);
224 const v128_t vmax0167 = wasm_f32x4_pmax(vmax67, vmax01o);
225 const v128_t vmax = wasm_f32x4_pmax(vmax2345, vmax0167);
226
227 v128_t vout = wasm_f32x4_pmax(voutput_min, vmax);
228 vout = wasm_f32x4_pmin(voutput_max, vout);
229
230 wasm_v128_store(o, vout);
231 o += 4;
232 }
233 if (c != 0) {
234 const v128_t vi0 = wasm_v128_load(i0);
235 const v128_t vi1 = wasm_v128_load(i1);
236 const v128_t vi2 = wasm_v128_load(i2);
237 const v128_t vi3 = wasm_v128_load(i3);
238 const v128_t vi4 = wasm_v128_load(i4);
239 const v128_t vi5 = wasm_v128_load(i5);
240 const v128_t vi6 = wasm_v128_load(i6);
241 const v128_t vi7 = wasm_v128_load(i7);
242 const v128_t vo = wasm_v128_load(o);
243
244 const v128_t vmax01 = wasm_f32x4_pmax(vi1, vi0);
245 const v128_t vmax23 = wasm_f32x4_pmax(vi3, vi2);
246 const v128_t vmax45 = wasm_f32x4_pmax(vi5, vi4);
247 const v128_t vmax01o = wasm_f32x4_pmax(vo, vmax01);
248 const v128_t vmax67 = wasm_f32x4_pmax(vi7, vi6);
249
250 const v128_t vmax2345 = wasm_f32x4_pmax(vmax45, vmax23);
251 const v128_t vmax0167 = wasm_f32x4_pmax(vmax67, vmax01o);
252 const v128_t vmax = wasm_f32x4_pmax(vmax2345, vmax0167);
253
254 v128_t vout = wasm_f32x4_pmax(voutput_min, vmax);
255 vout = wasm_f32x4_pmin(voutput_max, vout);
256
257 if (c & 2) {
258 *((double*) o) = wasm_f64x2_extract_lane(vout, 0);
259 vout = wasm_v32x4_shuffle(vout, vout, 2, 3, 2, 3);
260 o += 2;
261 }
262 if (c & 1) {
263 *o++ = wasm_f32x4_extract_lane(vout, 0);
264 }
265 }
266 }
267 input = (const float**) ((uintptr_t) input + input_increment);
268 output = (float*) ((uintptr_t) o + output_increment);
269 } while (--output_pixels != 0);
270 }
271