1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <assert.h>
7
8 #include <xnnpack/avgpool.h>
9 #include <xnnpack/math.h>
10
11
xnn_f32_avgpool_ukernel_mp9p8q__wasm(size_t n,size_t ks,size_t kc,const float ** input,const float * zero,float * buffer,float * output,size_t input_increment,size_t output_increment,const union xnn_f32_avgpool_params params[restrict static1])12 void xnn_f32_avgpool_ukernel_mp9p8q__wasm(
13 size_t n,
14 size_t ks,
15 size_t kc,
16 const float** input,
17 const float* zero,
18 float* buffer,
19 float* output,
20 size_t input_increment,
21 size_t output_increment,
22 const union xnn_f32_avgpool_params params[restrict static 1])
23 {
24 assert(n != 0);
25 assert(ks > 9);
26 assert(kc != 0);
27
28 const float vmultiplier = params->scalar.multiplier;
29 const float voutput_min = params->scalar.output_min;
30 const float voutput_max = params->scalar.output_max;
31
32 do {
33 {
34 const float* i0 = *input++;
35 const float* i1 = *input++;
36 const float* i2 = *input++;
37 const float* i3 = *input++;
38 const float* i4 = *input++;
39 const float* i5 = *input++;
40 const float* i6 = *input++;
41 const float* i7 = *input++;
42 const float* i8 = *input++;
43
44 float* b = buffer;
45 size_t k = kc;
46 do {
47 const float vi0 = *i0++;
48 const float vi1 = *i1++;
49 const float vi2 = *i2++;
50 const float vi3 = *i3++;
51 const float vi4 = *i4++;
52 const float vi5 = *i5++;
53 const float vi6 = *i6++;
54 const float vi7 = *i7++;
55 const float vi8 = *i8++;
56
57 const float vsum01 = vi0 + vi1;
58 const float vsum23 = vi2 + vi3;
59 const float vsum45 = vi4 + vi5;
60 const float vsum67 = vi6 + vi7;
61 const float vsum018 = vsum01 + vi8;
62 const float vsum2345 = vsum23 + vsum45;
63 const float vsum01678 = vsum018 + vsum67;
64 const float vsum = vsum2345 + vsum01678;
65
66 *b++ = vsum;
67 } while (--k != 0);
68 }
69
70 size_t m = ks;
71 for (m -= 9; m > 8; m -= 8) {
72 const float* i0 = *input++;
73 const float* i1 = *input++;
74 const float* i2 = *input++;
75 const float* i3 = *input++;
76 const float* i4 = *input++;
77 const float* i5 = *input++;
78 const float* i6 = *input++;
79 const float* i7 = *input++;
80
81 float* b = buffer;
82 size_t k = kc;
83 do {
84 const float vi0 = *i0++;
85 const float vi1 = *i1++;
86 const float vi2 = *i2++;
87 const float vi3 = *i3++;
88 const float vi4 = *i4++;
89 const float vi5 = *i5++;
90 const float vi6 = *i6++;
91 const float vi7 = *i7++;
92 const float vacc = *b;
93
94 const float vsum01 = vi0 + vi1;
95 const float vsum23 = vi2 + vi3;
96 const float vsum45 = vi4 + vi5;
97 const float vsum67 = vi6 + vi7;
98 const float vsum01a = vsum01 + vacc;
99 const float vsum2345 = vsum23 + vsum45;
100 const float vsum0167a = vsum01a + vsum67;
101 const float vsum = vsum2345 + vsum0167a;
102
103 *b++ = vsum;
104 } while (--k != 0);
105 }
106
107 {
108 const float* i0 = input[0];
109 const float* i1 = input[1];
110 const float* i2 = input[2];
111 const float* i3 = input[3];
112 const float* i4 = input[4];
113 const float* i5 = input[5];
114 const float* i6 = input[6];
115 const float* i7 = input[7];
116 input = (const float**) ((uintptr_t) input + input_increment);
117 if (m < 2) {
118 i1 = zero;
119 }
120 if (m <= 2) {
121 i2 = zero;
122 }
123 if (m < 4) {
124 i3 = zero;
125 }
126 if (m <= 4) {
127 i4 = zero;
128 }
129 if (m < 6) {
130 i5 = zero;
131 }
132 if (m <= 6) {
133 i6 = zero;
134 }
135 if (m != 8) {
136 i7 = zero;
137 }
138
139 size_t k = kc;
140 float* b = buffer;
141 do {
142 const float vi0 = *i0++;
143 const float vi1 = *i1++;
144 const float vi2 = *i2++;
145 const float vi3 = *i3++;
146 const float vi4 = *i4++;
147 const float vi5 = *i5++;
148 const float vi6 = *i6++;
149 const float vi7 = *i7++;
150 const float vacc = *b++;
151
152 const float vsum01 = vi0 + vi1;
153 const float vsum23 = vi2 + vi3;
154 const float vsum45 = vi4 + vi5;
155 const float vsum67 = vi6 + vi7;
156 const float vsum01a = vsum01 + vacc;
157 const float vsum2345 = vsum23 + vsum45;
158 const float vsum0167a = vsum01a + vsum67;
159 const float vsum = vsum2345 + vsum0167a;
160
161 float vout = vsum * vmultiplier;
162 vout = __builtin_wasm_max_f32(vout, voutput_min);
163 vout = __builtin_wasm_min_f32(vout, voutput_max);
164
165 *output++ = vout;
166 } while (--k != 0);
167 }
168 output = (float*) ((uintptr_t) output + output_increment);
169 } while (--n != 0);
170 }
171