1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
#include <string.h>

#include <xnnpack/scalar-utils.h>
#include <xnnpack/dwconv.h>
8
9
xnn_qu8_dwconv_minmax_ukernel_up1x9__scalar(size_t channels,size_t output_width,const uint8_t ** input,const void * weights,uint8_t * output,size_t input_stride,size_t output_increment,size_t input_offset,const uint8_t * zero,const union xnn_qu8_gemm_params params[restrict XNN_MIN_ELEMENTS (1)])10 void xnn_qu8_dwconv_minmax_ukernel_up1x9__scalar(
11 size_t channels,
12 size_t output_width,
13 const uint8_t** input,
14 const void* weights,
15 uint8_t* output,
16 size_t input_stride,
17 size_t output_increment,
18 size_t input_offset,
19 const uint8_t* zero,
20 const union xnn_qu8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)])
21 {
22 const int32_t vkernel_zero_point = params->scalar.kernel_zero_point;
23 const int32_t vmultiplier = params->scalar.multiplier;
24 const int32_t vq31rounding = INT32_C(0x40000000);
25 const int32_t vremainder_mask = params->scalar.remainder_mask;
26 const uint32_t vshift = params->scalar.shift;
27 const int32_t vremainder_threshold = params->scalar.remainder_threshold;
28 const int32_t vout_min = params->scalar.output_min_less_zero_point;
29 const int32_t vout_max = params->scalar.output_max_less_zero_point;
30 const int32_t voutput_zero_point = params->scalar.output_zero_point;
31 do {
32 const uint8_t* i0 = input[0];
33 if XNN_UNPREDICTABLE(i0 != zero) {
34 i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
35 }
36 const uint8_t* i1 = input[1];
37 if XNN_UNPREDICTABLE(i1 != zero) {
38 i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
39 }
40 const uint8_t* i2 = input[2];
41 if XNN_UNPREDICTABLE(i2 != zero) {
42 i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
43 }
44 const uint8_t* i3 = input[3];
45 if XNN_UNPREDICTABLE(i3 != zero) {
46 i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
47 }
48 const uint8_t* i4 = input[4];
49 if XNN_UNPREDICTABLE(i4 != zero) {
50 i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
51 }
52 const uint8_t* i5 = input[5];
53 if XNN_UNPREDICTABLE(i5 != zero) {
54 i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
55 }
56 const uint8_t* i6 = input[6];
57 if XNN_UNPREDICTABLE(i6 != zero) {
58 i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
59 }
60 const uint8_t* i7 = input[7];
61 if XNN_UNPREDICTABLE(i7 != zero) {
62 i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
63 }
64 const uint8_t* i8 = input[8];
65 if XNN_UNPREDICTABLE(i8 != zero) {
66 i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
67 }
68
69 input = (const uint8_t**) ((uintptr_t) input + input_stride);
70
71 size_t c = channels;
72 const void* w = weights;
73 do {
74 int32_t vacc = *((const int32_t*) w);
75
76 const int32_t vi0 = (int32_t) (uint32_t) *i0++;
77 const uint32_t vk0 = (uint32_t) ((const uint8_t*) w)[4];
78 const int32_t vxk0 = (int32_t) vk0 - vkernel_zero_point;
79 vacc += vi0 * vxk0;
80
81 const int32_t vi1 = (int32_t) (uint32_t) *i1++;
82 const uint32_t vk1 = (uint32_t) ((const uint8_t*) w)[5];
83 const int32_t vxk1 = (int32_t) vk1 - vkernel_zero_point;
84 vacc += vi1 * vxk1;
85
86 const int32_t vi2 = (int32_t) (uint32_t) *i2++;
87 const uint32_t vk2 = (uint32_t) ((const uint8_t*) w)[6];
88 const int32_t vxk2 = (int32_t) vk2 - vkernel_zero_point;
89 vacc += vi2 * vxk2;
90
91 const int32_t vi3 = (int32_t) (uint32_t) *i3++;
92 const uint32_t vk3 = (uint32_t) ((const uint8_t*) w)[7];
93 const int32_t vxk3 = (int32_t) vk3 - vkernel_zero_point;
94 vacc += vi3 * vxk3;
95
96 const int32_t vi4 = (int32_t) (uint32_t) *i4++;
97 const uint32_t vk4 = (uint32_t) ((const uint8_t*) w)[8];
98 const int32_t vxk4 = (int32_t) vk4 - vkernel_zero_point;
99 vacc += vi4 * vxk4;
100
101 const int32_t vi5 = (int32_t) (uint32_t) *i5++;
102 const uint32_t vk5 = (uint32_t) ((const uint8_t*) w)[9];
103 const int32_t vxk5 = (int32_t) vk5 - vkernel_zero_point;
104 vacc += vi5 * vxk5;
105
106 const int32_t vi6 = (int32_t) (uint32_t) *i6++;
107 const uint32_t vk6 = (uint32_t) ((const uint8_t*) w)[10];
108 const int32_t vxk6 = (int32_t) vk6 - vkernel_zero_point;
109 vacc += vi6 * vxk6;
110
111 const int32_t vi7 = (int32_t) (uint32_t) *i7++;
112 const uint32_t vk7 = (uint32_t) ((const uint8_t*) w)[11];
113 const int32_t vxk7 = (int32_t) vk7 - vkernel_zero_point;
114 vacc += vi7 * vxk7;
115
116 const int32_t vi8 = (int32_t) (uint32_t) *i8++;
117 const uint32_t vk8 = (uint32_t) ((const uint8_t*) w)[12];
118 const int32_t vxk8 = (int32_t) vk8 - vkernel_zero_point;
119 vacc += vi8 * vxk8;
120
121 w = (const void*) ((uintptr_t) w + sizeof(int32_t) + 9 * sizeof(uint8_t));
122
123 const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
124 const int32_t vq31product = (int32_t) (uint32_t) ((uint64_t) (vproduct + (int64_t) vq31rounding) >> 31);
125 const int32_t vremainder = (vq31product & vremainder_mask) - (int32_t) (vq31product < 0);
126 int32_t vout = asr_s32(vq31product, vshift) + (int32_t) (vremainder > vremainder_threshold);
127 vout = vout < vout_min ? vout_min : vout;
128 vout = vout > vout_max ? vout_max : vout;
129 vout += voutput_zero_point;
130
131 *output++ = vout;
132 } while (--c != 0);
133
134 output = (uint8_t*) ((uintptr_t) output + output_increment);
135 } while (--output_width != 0);
136 }
137