• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 
11 #include <arm_neon.h>
12 
13 #include <xnnpack/maxpool.h>
14 
15 
// Lane-wise unsigned maximum of the 16 bytes at each of nine input pointers.
// Loads exactly 16 bytes from every pointer; callers pass the same pointer in
// several slots to cover fewer than nine kernel elements.
static inline uint8x16_t u8_maxpool_9p8x_c16_max9(
    const uint8_t* i0, const uint8_t* i1, const uint8_t* i2,
    const uint8_t* i3, const uint8_t* i4, const uint8_t* i5,
    const uint8_t* i6, const uint8_t* i7, const uint8_t* i8)
{
  const uint8x16_t vi0 = vld1q_u8(i0);
  const uint8x16_t vi1 = vld1q_u8(i1);
  const uint8x16_t vi2 = vld1q_u8(i2);
  const uint8x16_t vi3 = vld1q_u8(i3);
  const uint8x16_t vi4 = vld1q_u8(i4);
  const uint8x16_t vi5 = vld1q_u8(i5);
  const uint8x16_t vi6 = vld1q_u8(i6);
  const uint8x16_t vi7 = vld1q_u8(i7);
  const uint8x16_t vi8 = vld1q_u8(i8);

  // Pairwise reduction tree rather than a linear chain of vmaxq_u8 calls.
  const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8);
  const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
  const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
  const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);

  const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
  const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67);
  return vmaxq_u8(vmax2345, vmax01678);
}

// Lane-wise unsigned maximum of the 16 bytes at each of eight input pointers
// and the 16 bytes of the partial result already stored at `acc`.
static inline uint8x16_t u8_maxpool_9p8x_c16_max8_acc(
    const uint8_t* i0, const uint8_t* i1, const uint8_t* i2,
    const uint8_t* i3, const uint8_t* i4, const uint8_t* i5,
    const uint8_t* i6, const uint8_t* i7, const uint8_t* acc)
{
  const uint8x16_t vi0 = vld1q_u8(i0);
  const uint8x16_t vi1 = vld1q_u8(i1);
  const uint8x16_t vi2 = vld1q_u8(i2);
  const uint8x16_t vi3 = vld1q_u8(i3);
  const uint8x16_t vi4 = vld1q_u8(i4);
  const uint8x16_t vi5 = vld1q_u8(i5);
  const uint8x16_t vi6 = vld1q_u8(i6);
  const uint8x16_t vi7 = vld1q_u8(i7);
  const uint8x16_t vacc = vld1q_u8(acc);

  const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vacc);
  const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3);
  const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5);
  const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7);

  const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45);
  const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67);
  return vmaxq_u8(vmax2345, vmax0167);
}

// Clamps every lane of `v` to [vmin, vmax] (min against the upper bound
// first, then max against the lower bound).
static inline uint8x16_t u8_maxpool_9p8x_c16_clamp(
    uint8x16_t v, uint8x16_t vmin, uint8x16_t vmax)
{
  return vmaxq_u8(vminq_u8(v, vmax), vmin);
}

// Stores the first `c` bytes of `vout` to `o` (callers guarantee 0 < c < 16)
// using 8/4/2/1-byte pieces, and returns `o + c`.
static inline uint8_t* u8_maxpool_9p8x_c16_store_tail(
    uint8_t* o, uint8x16_t vout, size_t c)
{
  uint8x8_t vout_lo = vget_low_u8(vout);
  if (c & 8) {
    vst1_u8(o, vout_lo); o += 8;
    vout_lo = vget_high_u8(vout);
  }
  if (c & 4) {
    // o may be unaligned; tell the compiler not to assume 4-byte alignment.
    vst1_lane_u32(__builtin_assume_aligned(o, 1), vreinterpret_u32_u8(vout_lo), 0); o += 4;
    // Shift the stored bytes out so lane 0 holds the next byte to write.
    vout_lo = vext_u8(vout_lo, vout_lo, 4);
  }
  if (c & 2) {
    vst1_lane_u16(__builtin_assume_aligned(o, 1), vreinterpret_u16_u8(vout_lo), 0); o += 2;
    vout_lo = vext_u8(vout_lo, vout_lo, 2);
  }
  if (c & 1) {
    vst1_lane_u8(o, vout_lo, 0); o += 1;
  }
  return o;
}

// Max-pooling microkernel for uint8 data, 16 channels per NEON vector.
//
// For each of `output_pixels` output pixels, computes the per-channel maximum
// over `kernel_elements` input pointers and clamps it to
// [params->neon.min, params->neon.max]:
//   - the first pass reduces up to 9 input pointers and writes `channels`
//     bytes to the output;
//   - each following pass reduces up to 8 more input pointers together with
//     the partial result already in the output, and overwrites it.
// Kernel slots beyond the remaining element count alias the pass's first
// pointer, which leaves the maximum unchanged.
//
// Parameters:
//   output_pixels     number of output pixels; must be non-zero.
//   kernel_elements   number of input pointers reduced per pixel; non-zero.
//   channels          bytes produced per output pixel; non-zero.
//   input             list of input pointers. The first pass reads 9 entries
//                     and each later pass reads 8, even when fewer are used.
//                     `input_offset` bytes are added to every entry before it
//                     is dereferenced.
//   output            destination of the first output pixel.
//   input_increment   bytes added to `input` after each pixel, on top of the
//                     entries already consumed by the passes.
//   output_increment  bytes skipped after the `channels` bytes of each pixel.
//   params            clamping bounds (neon.min / neon.max).
//
// When `channels` is not a multiple of 16, the remainder path loads a full
// 16 bytes from every input pointer (and, after the first pass, from the
// output) but stores only the remaining bytes.
// NOTE(review): this relies on the caller's buffers tolerating such
// over-reads; confirm against the library's buffer-allocation contract.
void xnn_u8_maxpool_ukernel_9p8x__neon_c16(
    size_t output_pixels,
    size_t kernel_elements,
    size_t channels,
    const uint8_t** input,
    size_t input_offset,
    uint8_t* output,
    size_t input_increment,
    size_t output_increment,
    const union xnn_u8_output_params params[restrict static 1])
{
  assert(output_pixels != 0);
  assert(kernel_elements != 0);
  assert(channels != 0);

  const uint8x16_t voutput_max = vld1q_dup_u8(&params->neon.max);
  const uint8x16_t voutput_min = vld1q_dup_u8(&params->neon.min);
  do {
    uint8_t* o = output;
    {
      // First pass: up to 9 kernel elements, result written straight to output.
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      const uint8_t* i8 = *input++;
      i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
      // Slots past kernel_elements alias i0; max(x, x) == x keeps the result exact.
      if (kernel_elements < 2) {
        i1 = i0;
      }
      if (kernel_elements < 3) {
        i2 = i0;
      }
      if (kernel_elements < 4) {
        i3 = i0;
      }
      if (kernel_elements < 5) {
        i4 = i0;
      }
      if (kernel_elements < 6) {
        i5 = i0;
      }
      if (kernel_elements < 7) {
        i6 = i0;
      }
      if (kernel_elements < 8) {
        i7 = i0;
      }
      if (kernel_elements < 9) {
        i8 = i0;
      }

      size_t c = channels;
      for (; c >= 16; c -= 16) {
        const uint8x16_t vmax = u8_maxpool_9p8x_c16_max9(i0, i1, i2, i3, i4, i5, i6, i7, i8);
        i0 += 16; i1 += 16; i2 += 16;
        i3 += 16; i4 += 16; i5 += 16;
        i6 += 16; i7 += 16; i8 += 16;

        vst1q_u8(o, u8_maxpool_9p8x_c16_clamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
      if (c != 0) {
        const uint8x16_t vmax = u8_maxpool_9p8x_c16_max9(i0, i1, i2, i3, i4, i5, i6, i7, i8);
        o = u8_maxpool_9p8x_c16_store_tail(
            o, u8_maxpool_9p8x_c16_clamp(vmax, voutput_min, voutput_max), c);
      }
    }

    // Remaining passes: 8 more kernel elements each, folded into the partial
    // result already stored in output. `k` counts elements still to reduce.
    for (ptrdiff_t k = (ptrdiff_t) kernel_elements - 9; k > 0; k -= 8) {
      const uint8_t* i0 = *input++;
      const uint8_t* i1 = *input++;
      const uint8_t* i2 = *input++;
      const uint8_t* i3 = *input++;
      const uint8_t* i4 = *input++;
      const uint8_t* i5 = *input++;
      const uint8_t* i6 = *input++;
      const uint8_t* i7 = *input++;
      i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
      i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
      i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
      i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
      i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
      i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
      i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
      i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
      if (k < 2) {
        i1 = i0;
      }
      if (k < 3) {
        i2 = i0;
      }
      if (k < 4) {
        i3 = i0;
      }
      if (k < 5) {
        i4 = i0;
      }
      if (k < 6) {
        i5 = i0;
      }
      if (k < 7) {
        i6 = i0;
      }
      if (k < 8) {
        i7 = i0;
      }

      o = output;
      size_t c = channels;
      for (; c >= 16; c -= 16) {
        const uint8x16_t vmax = u8_maxpool_9p8x_c16_max8_acc(i0, i1, i2, i3, i4, i5, i6, i7, o);
        i0 += 16; i1 += 16; i2 += 16; i3 += 16;
        i4 += 16; i5 += 16; i6 += 16; i7 += 16;

        vst1q_u8(o, u8_maxpool_9p8x_c16_clamp(vmax, voutput_min, voutput_max));
        o += 16;
      }
      if (c != 0) {
        const uint8x16_t vmax = u8_maxpool_9p8x_c16_max8_acc(i0, i1, i2, i3, i4, i5, i6, i7, o);
        o = u8_maxpool_9p8x_c16_store_tail(
            o, u8_maxpool_9p8x_c16_clamp(vmax, voutput_min, voutput_max), c);
      }
    }
    input = (const uint8_t**) ((uintptr_t) input + input_increment);
    // Whichever pass ran last, o == output + channels at this point.
    output = (uint8_t*) ((uintptr_t) o + output_increment);
  } while (--output_pixels != 0);
}
250