• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <arm_neon.h>
9 
10 #include <xnnpack/maxpool.h>
11 
12 
xnn_f16_maxpool_minmax_ukernel_9p8x__neonfp16arith_c8(size_t output_pixels,size_t kernel_elements,size_t channels,const void ** input,size_t input_offset,void * output,size_t input_increment,size_t output_increment,const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])13 void xnn_f16_maxpool_minmax_ukernel_9p8x__neonfp16arith_c8(
14     size_t output_pixels,
15     size_t kernel_elements,
16     size_t channels,
17     const void** input,
18     size_t input_offset,
19     void* output,
20     size_t input_increment,
21     size_t output_increment,
22     const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
23 {
24   assert(output_pixels != 0);
25   assert(kernel_elements != 0);
26   assert(channels != 0);
27 
28   const float16x8_t voutput_min = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.min));
29   const float16x8_t voutput_max = vreinterpretq_f16_u16(vld1q_dup_u16(&params->neon.max));
30   do {
31     __fp16* o = output;
32     {
33       const __fp16* i0 = *input++;
34       const __fp16* i1 = *input++;
35       const __fp16* i2 = *input++;
36       const __fp16* i3 = *input++;
37       const __fp16* i4 = *input++;
38       const __fp16* i5 = *input++;
39       const __fp16* i6 = *input++;
40       const __fp16* i7 = *input++;
41       const __fp16* i8 = *input++;
42       i0 = (const __fp16*) ((uintptr_t) i0 + input_offset);
43       i1 = (const __fp16*) ((uintptr_t) i1 + input_offset);
44       i2 = (const __fp16*) ((uintptr_t) i2 + input_offset);
45       i3 = (const __fp16*) ((uintptr_t) i3 + input_offset);
46       i4 = (const __fp16*) ((uintptr_t) i4 + input_offset);
47       i5 = (const __fp16*) ((uintptr_t) i5 + input_offset);
48       i6 = (const __fp16*) ((uintptr_t) i6 + input_offset);
49       i7 = (const __fp16*) ((uintptr_t) i7 + input_offset);
50       i8 = (const __fp16*) ((uintptr_t) i8 + input_offset);
51       if (kernel_elements < 2) {
52         i1 = i0;
53       }
54       if (kernel_elements <= 2) {
55         i2 = i0;
56       }
57       if (kernel_elements < 4) {
58         i3 = i0;
59       }
60       if (kernel_elements <= 4) {
61         i4 = i0;
62       }
63       if (kernel_elements < 6) {
64         i5 = i0;
65       }
66       if (kernel_elements <= 6) {
67         i6 = i0;
68       }
69       if (kernel_elements < 8) {
70         i7 = i0;
71       }
72       if (kernel_elements <= 8) {
73         i8 = i0;
74       }
75 
76       size_t c = channels;
77       for (; c >= 8; c -= 8) {
78         const float16x8_t vi0 = vld1q_f16(i0); i0 += 8;
79         const float16x8_t vi1 = vld1q_f16(i1); i1 += 8;
80         const float16x8_t vi2 = vld1q_f16(i2); i2 += 8;
81         const float16x8_t vi3 = vld1q_f16(i3); i3 += 8;
82         const float16x8_t vi4 = vld1q_f16(i4); i4 += 8;
83         const float16x8_t vi5 = vld1q_f16(i5); i5 += 8;
84         const float16x8_t vi6 = vld1q_f16(i6); i6 += 8;
85         const float16x8_t vi7 = vld1q_f16(i7); i7 += 8;
86         const float16x8_t vi8 = vld1q_f16(i8); i8 += 8;
87 
88         const float16x8_t vmax018 = vmaxq_f16(vmaxq_f16(vi0, vi1), vi8);
89         const float16x8_t vmax23 = vmaxq_f16(vi2, vi3);
90         const float16x8_t vmax45 = vmaxq_f16(vi4, vi5);
91         const float16x8_t vmax67 = vmaxq_f16(vi6, vi7);
92 
93         const float16x8_t vmax2345 = vmaxq_f16(vmax23, vmax45);
94         const float16x8_t vmax01678 = vmaxq_f16(vmax018, vmax67);
95         const float16x8_t vmax = vmaxq_f16(vmax2345, vmax01678);
96         const float16x8_t vout = vmaxq_f16(vminq_f16(vmax, voutput_max), voutput_min);
97 
98         vst1q_f16(o, vout); o += 8;
99       }
100       if (c != 0) {
101         const float16x8_t vi0 = vld1q_f16(i0); i0 += 8;
102         const float16x8_t vi1 = vld1q_f16(i1); i1 += 8;
103         const float16x8_t vi2 = vld1q_f16(i2); i2 += 8;
104         const float16x8_t vi3 = vld1q_f16(i3); i3 += 8;
105         const float16x8_t vi4 = vld1q_f16(i4); i4 += 8;
106         const float16x8_t vi5 = vld1q_f16(i5); i5 += 8;
107         const float16x8_t vi6 = vld1q_f16(i6); i6 += 8;
108         const float16x8_t vi7 = vld1q_f16(i7); i7 += 8;
109         const float16x8_t vi8 = vld1q_f16(i8); i8 += 8;
110 
111         const float16x8_t vmax018 = vmaxq_f16(vmaxq_f16(vi0, vi1), vi8);
112         const float16x8_t vmax23 = vmaxq_f16(vi2, vi3);
113         const float16x8_t vmax45 = vmaxq_f16(vi4, vi5);
114         const float16x8_t vmax67 = vmaxq_f16(vi6, vi7);
115 
116         const float16x8_t vmax2345 = vmaxq_f16(vmax23, vmax45);
117         const float16x8_t vmax01678 = vmaxq_f16(vmax018, vmax67);
118         const float16x8_t vmax = vmaxq_f16(vmax2345, vmax01678);
119         float16x8_t vout = vmaxq_f16(vminq_f16(vmax, voutput_max), voutput_min);
120 
121         float16x4_t vout_lo = vget_low_f16(vout);
122         if (c & 4) {
123           vst1_f16(o, vout_lo); o += 4;
124           vout_lo = vget_high_f16(vout);
125         }
126         if (c & 2) {
127           vst1_lane_u32((void*) o, vreinterpret_u32_f16(vout_lo), 0); o += 2;
128           vout_lo = vext_f16(vout_lo, vout_lo, 2);
129         }
130         if (c & 1) {
131           vst1_lane_f16(o, vout_lo, 0); o += 1;
132         }
133       }
134     }
135 
136     for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 8) {
137       const __fp16* i0 = *input++;
138       const __fp16* i1 = *input++;
139       const __fp16* i2 = *input++;
140       const __fp16* i3 = *input++;
141       const __fp16* i4 = *input++;
142       const __fp16* i5 = *input++;
143       const __fp16* i6 = *input++;
144       const __fp16* i7 = *input++;
145       i0 = (const __fp16*) ((uintptr_t) i0 + input_offset);
146       i1 = (const __fp16*) ((uintptr_t) i1 + input_offset);
147       i2 = (const __fp16*) ((uintptr_t) i2 + input_offset);
148       i3 = (const __fp16*) ((uintptr_t) i3 + input_offset);
149       i4 = (const __fp16*) ((uintptr_t) i4 + input_offset);
150       i5 = (const __fp16*) ((uintptr_t) i5 + input_offset);
151       i6 = (const __fp16*) ((uintptr_t) i6 + input_offset);
152       i7 = (const __fp16*) ((uintptr_t) i7 + input_offset);
153       if (k < 2) {
154         i1 = i0;
155       }
156       if (k <= 2) {
157         i2 = i0;
158       }
159       if (k < 4) {
160         i3 = i0;
161       }
162       if (k <= 4) {
163         i4 = i0;
164       }
165       if (k < 6) {
166         i5 = i0;
167       }
168       if (k <= 6) {
169         i6 = i0;
170       }
171       if (k < 8) {
172         i7 = i0;
173       }
174 
175       o = output;
176       size_t c = channels;
177       for (; c >= 8; c -= 8) {
178         const float16x8_t vi0 = vld1q_f16(i0); i0 += 8;
179         const float16x8_t vi1 = vld1q_f16(i1); i1 += 8;
180         const float16x8_t vi2 = vld1q_f16(i2); i2 += 8;
181         const float16x8_t vi3 = vld1q_f16(i3); i3 += 8;
182         const float16x8_t vi4 = vld1q_f16(i4); i4 += 8;
183         const float16x8_t vi5 = vld1q_f16(i5); i5 += 8;
184         const float16x8_t vi6 = vld1q_f16(i6); i6 += 8;
185         const float16x8_t vi7 = vld1q_f16(i7); i7 += 8;
186         const float16x8_t vo = vld1q_f16(o);
187 
188         const float16x8_t vmax01 = vmaxq_f16(vmaxq_f16(vi0, vi1), vo);
189         const float16x8_t vmax23 = vmaxq_f16(vi2, vi3);
190         const float16x8_t vmax45 = vmaxq_f16(vi4, vi5);
191         const float16x8_t vmax67 = vmaxq_f16(vi6, vi7);
192 
193         const float16x8_t vmax2345 = vmaxq_f16(vmax23, vmax45);
194         const float16x8_t vmax0167 = vmaxq_f16(vmax01, vmax67);
195         const float16x8_t vmax = vmaxq_f16(vmax2345, vmax0167);
196         const float16x8_t vout = vmaxq_f16(vminq_f16(vmax, voutput_max), voutput_min);
197 
198         vst1q_f16(o, vout); o += 8;
199       }
200       if (c != 0) {
201         const float16x8_t vi0 = vld1q_f16(i0);
202         const float16x8_t vi1 = vld1q_f16(i1);
203         const float16x8_t vi2 = vld1q_f16(i2);
204         const float16x8_t vi3 = vld1q_f16(i3);
205         const float16x8_t vi4 = vld1q_f16(i4);
206         const float16x8_t vi5 = vld1q_f16(i5);
207         const float16x8_t vi6 = vld1q_f16(i6);
208         const float16x8_t vi7 = vld1q_f16(i7);
209         const float16x8_t vo = vld1q_f16(o);
210 
211         const float16x8_t vmax01 = vmaxq_f16(vmaxq_f16(vi0, vi1), vo);
212         const float16x8_t vmax23 = vmaxq_f16(vi2, vi3);
213         const float16x8_t vmax45 = vmaxq_f16(vi4, vi5);
214         const float16x8_t vmax67 = vmaxq_f16(vi6, vi7);
215 
216         const float16x8_t vmax2345 = vmaxq_f16(vmax23, vmax45);
217         const float16x8_t vmax0167 = vmaxq_f16(vmax01, vmax67);
218         const float16x8_t vmax = vmaxq_f16(vmax2345, vmax0167);
219         float16x8_t vout = vmaxq_f16(vminq_f16(vmax, voutput_max), voutput_min);
220 
221         float16x4_t vout_lo = vget_low_f16(vout);
222         if (c & 4) {
223           vst1_f16(o, vout_lo); o += 4;
224           vout_lo = vget_high_f16(vout);
225         }
226         if (c & 2) {
227           vst1_lane_u32((void*) o, vreinterpret_u32_f16(vout_lo), 0); o += 2;
228           vout_lo = vext_f16(vout_lo, vout_lo, 2);
229         }
230         if (c & 1) {
231           vst1_lane_f16(o, vout_lo, 0); o += 1;
232         }
233       }
234     }
235     input = (const void**) ((uintptr_t) input + input_increment);
236     output = (__fp16*) ((uintptr_t) o + output_increment);
237   } while (--output_pixels != 0);
238 }
239