• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #include <assert.h>
10 
11 #include <xnnpack/scalar-utils.h>
12 #include <xnnpack/avgpool.h>
13 
14 
xnn_q8_avgpool_ukernel_mp9p8q__scalar(size_t n,size_t ks,size_t kc,const uint8_t ** input,const uint8_t * zero,int32_t * buffer,uint8_t * output,size_t input_increment,size_t output_increment,const union xnn_q8_avgpool_params params[restrict static1])15 void xnn_q8_avgpool_ukernel_mp9p8q__scalar(
16     size_t n,
17     size_t ks,
18     size_t kc,
19     const uint8_t** input,
20     const uint8_t* zero,
21     int32_t* buffer,
22     uint8_t* output,
23     size_t input_increment,
24     size_t output_increment,
25     const union xnn_q8_avgpool_params params[restrict static 1])
26 {
27   assert(n != 0);
28   assert(ks > 9);
29   assert(kc != 0);
30 
31   const int32_t vbias = params->scalar.bias;
32   const int32_t vmultiplier = params->scalar.multiplier;
33   const int64_t vrounding = params->scalar.rounding;
34   const uint32_t vshift = params->scalar.right_shift;
35   const int32_t voutput_min = params->scalar.output_min_less_zero_point;
36   const int32_t voutput_max = params->scalar.output_max_less_zero_point;
37   const int32_t voutput_zero_point = params->scalar.output_zero_point;
38   do {
39     // First pass.
40     {
41       const uint8_t* i0 = *input++;
42       const uint8_t* i1 = *input++;
43       const uint8_t* i2 = *input++;
44       const uint8_t* i3 = *input++;
45       const uint8_t* i4 = *input++;
46       const uint8_t* i5 = *input++;
47       const uint8_t* i6 = *input++;
48       const uint8_t* i7 = *input++;
49       const uint8_t* i8 = *input++;
50 
51       int32_t* b = buffer;
52       size_t k = kc;
53       do {
54         const uint32_t vi0 = (uint32_t) *i0++;
55         const uint32_t vi1 = (uint32_t) *i1++;
56         const uint32_t vi2 = (uint32_t) *i2++;
57         const uint32_t vi3 = (uint32_t) *i3++;
58         const uint32_t vi4 = (uint32_t) *i4++;
59         const uint32_t vi5 = (uint32_t) *i5++;
60         const uint32_t vi6 = (uint32_t) *i6++;
61         const uint32_t vi7 = (uint32_t) *i7++;
62         const uint32_t vi8 = (uint32_t) *i8++;
63 
64         const uint32_t vsum01 = vi0 + vi1;
65         const uint32_t vsum23 = vi2 + vi3;
66         const uint32_t vsum45 = vi4 + vi5;
67         const uint32_t vsum67 = vi6 + vi7;
68         const uint32_t vsum018 = vsum01 + vi8;
69         const uint32_t vsum2345 = vsum23 + vsum45;
70         const uint32_t vsum01678 = vsum018 + vsum67;
71         int32_t vacc = vbias + (int32_t) vsum2345;
72         vacc += (int32_t) vsum01678;
73         *b++ = vacc;
74       } while (--k != 0);
75     }
76 
77     size_t m = ks;
78     // Intermediate passes.
79     for (m -= 9; m > 8; m -= 8) {
80       const uint8_t* i0 = *input++;
81       const uint8_t* i1 = *input++;
82       const uint8_t* i2 = *input++;
83       const uint8_t* i3 = *input++;
84       const uint8_t* i4 = *input++;
85       const uint8_t* i5 = *input++;
86       const uint8_t* i6 = *input++;
87       const uint8_t* i7 = *input++;
88 
89       int32_t* b = buffer;
90       size_t k = kc;
91       do {
92         int32_t vacc = *b;
93 
94         const uint32_t vi0 = (uint32_t) *i0++;
95         const uint32_t vi1 = (uint32_t) *i1++;
96         const uint32_t vi2 = (uint32_t) *i2++;
97         const uint32_t vi3 = (uint32_t) *i3++;
98         const uint32_t vi4 = (uint32_t) *i4++;
99         const uint32_t vi5 = (uint32_t) *i5++;
100         const uint32_t vi6 = (uint32_t) *i6++;
101         const uint32_t vi7 = (uint32_t) *i7++;
102 
103         const uint32_t vsum01 = vi0 + vi1;
104         const uint32_t vsum23 = vi2 + vi3;
105         const uint32_t vsum45 = vi4 + vi5;
106         const uint32_t vsum67 = vi6 + vi7;
107         const uint32_t vsum0123 = vsum01 + vsum23;
108         const uint32_t vsum4567 = vsum45 + vsum67;
109         vacc += (int32_t) vsum0123;
110         vacc += (int32_t) vsum4567;
111 
112         *b++ = vacc;
113       } while (--k != 0);
114     }
115 
116     // Last pass.
117     {
118       const uint8_t* i0 = input[0];
119       const uint8_t* i1 = input[1];
120       const uint8_t* i2 = input[2];
121       const uint8_t* i3 = input[3];
122       const uint8_t* i4 = input[4];
123       const uint8_t* i5 = input[5];
124       const uint8_t* i6 = input[6];
125       const uint8_t* i7 = input[7];
126       input = (const uint8_t**) ((uintptr_t) input + input_increment);
127       if (m < 2) {
128         i1 = zero;
129       }
130       if (m <= 2) {
131         i2 = zero;
132       }
133       if (m < 4) {
134         i3 = zero;
135       }
136       if (m <= 4) {
137         i4 = zero;
138       }
139       if (m < 6) {
140         i5 = zero;
141       }
142       if (m <= 6) {
143         i6 = zero;
144       }
145       if (m != 8) {
146         i7 = zero;
147       }
148 
149       size_t k = kc;
150       int32_t* b = buffer;
151       do {
152         int32_t vacc = *b++;
153 
154         const uint32_t vi0 = (uint32_t) *i0++;
155         const uint32_t vi1 = (uint32_t) *i1++;
156         const uint32_t vi2 = (uint32_t) *i2++;
157         const uint32_t vi3 = (uint32_t) *i3++;
158         const uint32_t vi4 = (uint32_t) *i4++;
159         const uint32_t vi5 = (uint32_t) *i5++;
160         const uint32_t vi6 = (uint32_t) *i6++;
161         const uint32_t vi7 = (uint32_t) *i7++;
162 
163         const uint32_t vsum01 = vi0 + vi1;
164         const uint32_t vsum23 = vi2 + vi3;
165         const uint32_t vsum45 = vi4 + vi5;
166         const uint32_t vsum67 = vi6 + vi7;
167         const uint32_t vsum0123 = vsum01 + vsum23;
168         const uint32_t vsum4567 = vsum45 + vsum67;
169         vacc += (int32_t) vsum0123;
170         vacc += (int32_t) vsum4567;
171 
172         const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
173         const int64_t vadjusted_product = vproduct - (int64_t) (vacc < 0);
174         int32_t vout = (int32_t) asr_s64(vadjusted_product + vrounding, vshift);
175         vout = vout < voutput_min ? voutput_min : vout;
176         vout = vout > voutput_max ? voutput_max : vout;
177         vout += voutput_zero_point;
178 
179         *output++ = (uint8_t) vout;
180       } while (--k != 0);
181     }
182     output = (uint8_t*) ((uintptr_t) output + output_increment);
183   } while (--n != 0);
184 }
185