// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <immintrin.h>

#include <xnnpack/maxpool.h>
11
12
xnn_f16_maxpool_minmax_ukernel_9p8x__f16c_c8(size_t output_pixels,size_t kernel_elements,size_t channels,const void ** input,size_t input_offset,void * output,size_t input_increment,size_t output_increment,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])13 void xnn_f16_maxpool_minmax_ukernel_9p8x__f16c_c8(
14 size_t output_pixels,
15 size_t kernel_elements,
16 size_t channels,
17 const void** input,
18 size_t input_offset,
19 void* output,
20 size_t input_increment,
21 size_t output_increment,
22 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
23 {
24 assert(output_pixels != 0);
25 assert(kernel_elements != 0);
26 assert(channels != 0);
27
28 const __m256 voutput_min = _mm256_load_ps(params->avx.min);
29 const __m256 voutput_max = _mm256_load_ps(params->avx.max);
30 do {
31 uint16_t* o = output;
32 {
33 const uint16_t* i0 = *input++;
34 const uint16_t* i1 = *input++;
35 const uint16_t* i2 = *input++;
36 const uint16_t* i3 = *input++;
37 const uint16_t* i4 = *input++;
38 const uint16_t* i5 = *input++;
39 const uint16_t* i6 = *input++;
40 const uint16_t* i7 = *input++;
41 const uint16_t* i8 = *input++;
42 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
43 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
44 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
45 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
46 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
47 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
48 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
49 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
50 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
51 if (kernel_elements < 2) {
52 i1 = i0;
53 }
54 if (kernel_elements <= 2) {
55 i2 = i0;
56 }
57 if (kernel_elements < 4) {
58 i3 = i0;
59 }
60 if (kernel_elements <= 4) {
61 i4 = i0;
62 }
63 if (kernel_elements < 6) {
64 i5 = i0;
65 }
66 if (kernel_elements <= 6) {
67 i6 = i0;
68 }
69 if (kernel_elements < 8) {
70 i7 = i0;
71 }
72 if (kernel_elements <= 8) {
73 i8 = i0;
74 }
75
76 size_t c = channels;
77 for (; c >= 8; c -= 8) {
78 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
79 i0 += 8;
80 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
81 i1 += 8;
82 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
83 i2 += 8;
84 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
85 i3 += 8;
86 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
87 i4 += 8;
88 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
89 i5 += 8;
90 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
91 i6 += 8;
92 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
93 i7 += 8;
94 const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
95 i8 += 8;
96
97 const __m256 vmax018 = _mm256_max_ps(_mm256_max_ps(vi0, vi1), vi8);
98 const __m256 vmax23 = _mm256_max_ps(vi2, vi3);
99 const __m256 vmax45 = _mm256_max_ps(vi4, vi5);
100 const __m256 vmax67 = _mm256_max_ps(vi6, vi7);
101
102 const __m256 vmax2345 = _mm256_max_ps(vmax23, vmax45);
103 const __m256 vmax01678 = _mm256_max_ps(vmax018, vmax67);
104 const __m256 vmax = _mm256_max_ps(vmax2345, vmax01678);
105 const __m256 vout = _mm256_max_ps(_mm256_min_ps(vmax, voutput_max), voutput_min);
106
107 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
108 o += 8;
109 }
110 if (c != 0) {
111 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
112 i0 += 8;
113 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
114 i1 += 8;
115 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
116 i2 += 8;
117 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
118 i3 += 8;
119 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
120 i4 += 8;
121 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
122 i5 += 8;
123 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
124 i6 += 8;
125 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
126 i7 += 8;
127 const __m256 vi8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
128 i8 += 8;
129
130 const __m256 vmax018 = _mm256_max_ps(_mm256_max_ps(vi0, vi1), vi8);
131 const __m256 vmax23 = _mm256_max_ps(vi2, vi3);
132 const __m256 vmax45 = _mm256_max_ps(vi4, vi5);
133 const __m256 vmax67 = _mm256_max_ps(vi6, vi7);
134
135 const __m256 vmax2345 = _mm256_max_ps(vmax23, vmax45);
136 const __m256 vmax01678 = _mm256_max_ps(vmax018, vmax67);
137 const __m256 vmax = _mm256_max_ps(vmax2345, vmax01678);
138 __m256 vout = _mm256_max_ps(_mm256_min_ps(vmax, voutput_max), voutput_min);
139
140 __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
141 if (c & 4) {
142 _mm_storel_epi64((__m128i*) o, vh);
143 vh = _mm_unpackhi_epi64(vh, vh);
144 o += 4;
145 }
146 if (c & 2) {
147 *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vh);
148 vh = _mm_srli_epi64(vh, 32);
149 o += 2;
150 }
151 if (c & 1) {
152 *o = _mm_extract_epi16(vh, 0);
153 o += 1;
154 }
155 }
156 }
157
158 for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 8) {
159 const uint16_t* i0 = *input++;
160 const uint16_t* i1 = *input++;
161 const uint16_t* i2 = *input++;
162 const uint16_t* i3 = *input++;
163 const uint16_t* i4 = *input++;
164 const uint16_t* i5 = *input++;
165 const uint16_t* i6 = *input++;
166 const uint16_t* i7 = *input++;
167 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
168 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
169 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
170 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
171 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
172 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
173 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
174 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
175 if (k < 2) {
176 i1 = i0;
177 }
178 if (k <= 2) {
179 i2 = i0;
180 }
181 if (k < 4) {
182 i3 = i0;
183 }
184 if (k <= 4) {
185 i4 = i0;
186 }
187 if (k < 6) {
188 i5 = i0;
189 }
190 if (k <= 6) {
191 i6 = i0;
192 }
193 if (k < 8) {
194 i7 = i0;
195 }
196
197 o = output;
198 size_t c = channels;
199 for (; c >= 8; c -= 8) {
200 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
201 i0 += 8;
202 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
203 i1 += 8;
204 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
205 i2 += 8;
206 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
207 i3 += 8;
208 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
209 i4 += 8;
210 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
211 i5 += 8;
212 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
213 i6 += 8;
214 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
215 i7 += 8;
216 const __m256 vo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o));
217
218 const __m256 vmax01 = _mm256_max_ps(_mm256_max_ps(vi0, vi1), vo);
219 const __m256 vmax23 = _mm256_max_ps(vi2, vi3);
220 const __m256 vmax45 = _mm256_max_ps(vi4, vi5);
221 const __m256 vmax67 = _mm256_max_ps(vi6, vi7);
222
223 const __m256 vmax2345 = _mm256_max_ps(vmax23, vmax45);
224 const __m256 vmax0167 = _mm256_max_ps(vmax01, vmax67);
225 const __m256 vmax = _mm256_max_ps(vmax2345, vmax0167);
226 const __m256 vout = _mm256_max_ps(_mm256_min_ps(vmax, voutput_max), voutput_min);
227
228 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC));
229 o += 8;
230 }
231 if (c != 0) {
232 const __m256 vi0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
233 const __m256 vi1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
234 const __m256 vi2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
235 const __m256 vi3 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
236 const __m256 vi4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
237 const __m256 vi5 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
238 const __m256 vi6 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
239 const __m256 vi7 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
240 const __m256 vo = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) o));
241
242 const __m256 vmax01 = _mm256_max_ps(_mm256_max_ps(vi0, vi1), vo);
243 const __m256 vmax23 = _mm256_max_ps(vi2, vi3);
244 const __m256 vmax45 = _mm256_max_ps(vi4, vi5);
245 const __m256 vmax67 = _mm256_max_ps(vi6, vi7);
246
247 const __m256 vmax2345 = _mm256_max_ps(vmax23, vmax45);
248 const __m256 vmax0167 = _mm256_max_ps(vmax01, vmax67);
249 const __m256 vmax = _mm256_max_ps(vmax2345, vmax0167);
250 __m256 vout = _mm256_max_ps(_mm256_min_ps(vmax, voutput_max), voutput_min);
251
252 __m128i vh = _mm256_cvtps_ph(vout, _MM_FROUND_NO_EXC);
253 if (c & 4) {
254 _mm_storel_epi64((__m128i*) o, vh);
255 vh = _mm_unpackhi_epi64(vh, vh);
256 o += 4;
257 }
258 if (c & 2) {
259 *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vh);
260 vh = _mm_srli_epi64(vh, 32);
261 o += 2;
262 }
263 if (c & 1) {
264 *o = _mm_extract_epi16(vh, 0);
265 o += 1;
266 }
267 }
268 }
269 input = (const void**) ((uintptr_t) input + input_increment);
270 output = (uint16_t*) ((uintptr_t) o + output_increment);
271 } while (--output_pixels != 0);
272 }
273