1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <arm_neon.h>
9
10 #include <xnnpack/gavgpool.h>
11 #include <xnnpack/math.h>
12
13
xnn_f32_gavgpool_minmax_ukernel_7p7x__neon_c4(size_t rows,size_t channels,const float * input,size_t input_stride,const float * zero,float * buffer,float * output,const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS (1)])14 void xnn_f32_gavgpool_minmax_ukernel_7p7x__neon_c4(
15 size_t rows,
16 size_t channels,
17 const float* input,
18 size_t input_stride,
19 const float* zero,
20 float* buffer,
21 float* output,
22 const union xnn_f32_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
23 {
24 assert(rows > 7);
25 assert(channels != 0);
26
27 const float* i0 = input;
28 const float* i1 = (const float*) ((uintptr_t) i0 + input_stride);
29 const float* i2 = (const float*) ((uintptr_t) i1 + input_stride);
30 const float* i3 = (const float*) ((uintptr_t) i2 + input_stride);
31 const float* i4 = (const float*) ((uintptr_t) i3 + input_stride);
32 const float* i5 = (const float*) ((uintptr_t) i4 + input_stride);
33 const float* i6 = (const float*) ((uintptr_t) i5 + input_stride);
34 const size_t packed_channels = round_up_po2(channels, 4);
35 const size_t input_increment = 7 * input_stride - packed_channels * sizeof(float);
36
37 float* b = buffer;
38 for (size_t c = 0; c < channels; c += 4) {
39 const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
40 const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
41 const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
42 const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
43 const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
44 const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
45 const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
46
47 const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
48 const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
49 const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
50
51 const float32x4_t vsum016 = vaddq_f32(vsum01, vi6);
52 const float32x4_t vsum2345 = vaddq_f32(vsum23, vsum45);
53
54 const float32x4_t vsum = vaddq_f32(vsum016, vsum2345);
55
56 vst1q_f32(b, vsum); b += 4;
57 }
58 for (rows -= 7; rows > 7; rows -= 7) {
59 b = buffer;
60
61 i0 = (const float*) ((uintptr_t) i0 + input_increment);
62 i1 = (const float*) ((uintptr_t) i1 + input_increment);
63 i2 = (const float*) ((uintptr_t) i2 + input_increment);
64 i3 = (const float*) ((uintptr_t) i3 + input_increment);
65 i4 = (const float*) ((uintptr_t) i4 + input_increment);
66 i5 = (const float*) ((uintptr_t) i5 + input_increment);
67 i6 = (const float*) ((uintptr_t) i6 + input_increment);
68
69 for (size_t c = 0; c < channels; c += 4) {
70 const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
71 const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
72 const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
73 const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
74 const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
75 const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
76 const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
77 const float32x4_t vacc = vld1q_f32(b);
78
79 const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
80 const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
81 const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
82 const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
83
84 const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
85 const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
86
87 const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
88
89 vst1q_f32(b, vsum); b += 4;
90 }
91 }
92
93 i0 = (const float*) ((uintptr_t) i0 + input_increment);
94 i1 = (const float*) ((uintptr_t) i1 + input_increment);
95 if (rows < 2) {
96 i1 = zero;
97 }
98 i2 = (const float*) ((uintptr_t) i2 + input_increment);
99 if (rows <= 2) {
100 i2 = zero;
101 }
102 i3 = (const float*) ((uintptr_t) i3 + input_increment);
103 if (rows < 4) {
104 i3 = zero;
105 }
106 i4 = (const float*) ((uintptr_t) i4 + input_increment);
107 if (rows <= 4) {
108 i4 = zero;
109 }
110 i5 = (const float*) ((uintptr_t) i5 + input_increment);
111 if (rows < 6) {
112 i5 = zero;
113 }
114 i6 = (const float*) ((uintptr_t) i6 + input_increment);
115 if (rows <= 6) {
116 i6 = zero;
117 }
118 const float32x4_t vscale = vld1q_dup_f32(¶ms->scalar.scale);
119 const float32x4_t vmin = vld1q_dup_f32(¶ms->scalar.min);
120 const float32x4_t vmax = vld1q_dup_f32(¶ms->scalar.max);
121
122 b = buffer;
123 while (channels >= 4) {
124 const float32x4_t vi0 = vld1q_f32(i0); i0 += 4;
125 const float32x4_t vi1 = vld1q_f32(i1); i1 += 4;
126 const float32x4_t vi2 = vld1q_f32(i2); i2 += 4;
127 const float32x4_t vi3 = vld1q_f32(i3); i3 += 4;
128 const float32x4_t vi4 = vld1q_f32(i4); i4 += 4;
129 const float32x4_t vi5 = vld1q_f32(i5); i5 += 4;
130 const float32x4_t vi6 = vld1q_f32(i6); i6 += 4;
131 const float32x4_t vacc = vld1q_f32(b); b += 4;
132
133 const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
134 const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
135 const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
136 const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
137
138 const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
139 const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
140
141 const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
142
143 float32x4_t vout = vmulq_f32(vsum, vscale);
144 vout = vmaxq_f32(vout, vmin);
145 vout = vminq_f32(vout, vmax);
146
147 vst1q_f32(output, vout); output += 4;
148
149 channels -= 4;
150 }
151 if (channels != 0) {
152 const float32x4_t vi0 = vld1q_f32(i0);
153 const float32x4_t vi1 = vld1q_f32(i1);
154 const float32x4_t vi2 = vld1q_f32(i2);
155 const float32x4_t vi3 = vld1q_f32(i3);
156 const float32x4_t vi4 = vld1q_f32(i4);
157 const float32x4_t vi5 = vld1q_f32(i5);
158 const float32x4_t vi6 = vld1q_f32(i6);
159 const float32x4_t vacc = vld1q_f32(b);
160
161 const float32x4_t vsum01 = vaddq_f32(vi0, vi1);
162 const float32x4_t vsum23 = vaddq_f32(vi2, vi3);
163 const float32x4_t vsum45 = vaddq_f32(vi4, vi5);
164 const float32x4_t vsum6a = vaddq_f32(vi6, vacc);
165
166 const float32x4_t vsum0123 = vaddq_f32(vsum01, vsum23);
167 const float32x4_t vsum456a = vaddq_f32(vsum45, vsum6a);
168
169 const float32x4_t vsum = vaddq_f32(vsum0123, vsum456a);
170
171 float32x4_t vout = vmulq_f32(vsum, vscale);
172 vout = vmaxq_f32(vout, vmin);
173 vout = vminq_f32(vout, vmax);
174
175 float32x2_t vout_lo = vget_low_f32(vout);
176 if (channels & 2) {
177 vst1_f32(output, vout_lo); output += 2;
178 vout_lo = vget_high_f32(vout);
179 }
180 if (channels & 1) {
181 vst1_lane_f32(output, vout_lo, 0);
182 }
183 }
184 }
185