// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <string.h>

#include <xnnpack/scalar-utils.h>
#include <xnnpack/dwconv.h>


xnn_qu8_dwconv_minmax_ukernel_up1x9__scalar(size_t channels,size_t output_width,const uint8_t ** input,const void * weights,uint8_t * output,size_t input_stride,size_t output_increment,size_t input_offset,const uint8_t * zero,const union xnn_qu8_gemm_params params[restrict XNN_MIN_ELEMENTS (1)])10 void xnn_qu8_dwconv_minmax_ukernel_up1x9__scalar(
11     size_t channels,
12     size_t output_width,
13     const uint8_t** input,
14     const void* weights,
15     uint8_t* output,
16     size_t input_stride,
17     size_t output_increment,
18     size_t input_offset,
19     const uint8_t* zero,
20     const union xnn_qu8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)])
21 {
22   const int32_t vkernel_zero_point = params->scalar.kernel_zero_point;
23   const int32_t vmultiplier = params->scalar.multiplier;
24   const int32_t vq31rounding = INT32_C(0x40000000);
25   const int32_t vremainder_mask = params->scalar.remainder_mask;
26   const uint32_t vshift = params->scalar.shift;
27   const int32_t vremainder_threshold = params->scalar.remainder_threshold;
28   const int32_t vout_min = params->scalar.output_min_less_zero_point;
29   const int32_t vout_max = params->scalar.output_max_less_zero_point;
30   const int32_t voutput_zero_point = params->scalar.output_zero_point;
31   do {
32     const uint8_t* i0 = input[0];
33     if XNN_UNPREDICTABLE(i0 != zero) {
34       i0 = (const uint8_t*) ((uintptr_t) i0 + input_offset);
35     }
36     const uint8_t* i1 = input[1];
37     if XNN_UNPREDICTABLE(i1 != zero) {
38       i1 = (const uint8_t*) ((uintptr_t) i1 + input_offset);
39     }
40     const uint8_t* i2 = input[2];
41     if XNN_UNPREDICTABLE(i2 != zero) {
42       i2 = (const uint8_t*) ((uintptr_t) i2 + input_offset);
43     }
44     const uint8_t* i3 = input[3];
45     if XNN_UNPREDICTABLE(i3 != zero) {
46       i3 = (const uint8_t*) ((uintptr_t) i3 + input_offset);
47     }
48     const uint8_t* i4 = input[4];
49     if XNN_UNPREDICTABLE(i4 != zero) {
50       i4 = (const uint8_t*) ((uintptr_t) i4 + input_offset);
51     }
52     const uint8_t* i5 = input[5];
53     if XNN_UNPREDICTABLE(i5 != zero) {
54       i5 = (const uint8_t*) ((uintptr_t) i5 + input_offset);
55     }
56     const uint8_t* i6 = input[6];
57     if XNN_UNPREDICTABLE(i6 != zero) {
58       i6 = (const uint8_t*) ((uintptr_t) i6 + input_offset);
59     }
60     const uint8_t* i7 = input[7];
61     if XNN_UNPREDICTABLE(i7 != zero) {
62       i7 = (const uint8_t*) ((uintptr_t) i7 + input_offset);
63     }
64     const uint8_t* i8 = input[8];
65     if XNN_UNPREDICTABLE(i8 != zero) {
66       i8 = (const uint8_t*) ((uintptr_t) i8 + input_offset);
67     }
68 
69     input = (const uint8_t**) ((uintptr_t) input + input_stride);
70 
71     size_t c = channels;
72     const void* w = weights;
73     do {
74       int32_t vacc = *((const int32_t*) w);
75 
76       const int32_t vi0 = (int32_t) (uint32_t) *i0++;
77       const uint32_t vk0 = (uint32_t) ((const uint8_t*) w)[4];
78       const int32_t vxk0 = (int32_t) vk0 - vkernel_zero_point;
79       vacc += vi0 * vxk0;
80 
81       const int32_t vi1 = (int32_t) (uint32_t) *i1++;
82       const uint32_t vk1 = (uint32_t) ((const uint8_t*) w)[5];
83       const int32_t vxk1 = (int32_t) vk1 - vkernel_zero_point;
84       vacc += vi1 * vxk1;
85 
86       const int32_t vi2 = (int32_t) (uint32_t) *i2++;
87       const uint32_t vk2 = (uint32_t) ((const uint8_t*) w)[6];
88       const int32_t vxk2 = (int32_t) vk2 - vkernel_zero_point;
89       vacc += vi2 * vxk2;
90 
91       const int32_t vi3 = (int32_t) (uint32_t) *i3++;
92       const uint32_t vk3 = (uint32_t) ((const uint8_t*) w)[7];
93       const int32_t vxk3 = (int32_t) vk3 - vkernel_zero_point;
94       vacc += vi3 * vxk3;
95 
96       const int32_t vi4 = (int32_t) (uint32_t) *i4++;
97       const uint32_t vk4 = (uint32_t) ((const uint8_t*) w)[8];
98       const int32_t vxk4 = (int32_t) vk4 - vkernel_zero_point;
99       vacc += vi4 * vxk4;
100 
101       const int32_t vi5 = (int32_t) (uint32_t) *i5++;
102       const uint32_t vk5 = (uint32_t) ((const uint8_t*) w)[9];
103       const int32_t vxk5 = (int32_t) vk5 - vkernel_zero_point;
104       vacc += vi5 * vxk5;
105 
106       const int32_t vi6 = (int32_t) (uint32_t) *i6++;
107       const uint32_t vk6 = (uint32_t) ((const uint8_t*) w)[10];
108       const int32_t vxk6 = (int32_t) vk6 - vkernel_zero_point;
109       vacc += vi6 * vxk6;
110 
111       const int32_t vi7 = (int32_t) (uint32_t) *i7++;
112       const uint32_t vk7 = (uint32_t) ((const uint8_t*) w)[11];
113       const int32_t vxk7 = (int32_t) vk7 - vkernel_zero_point;
114       vacc += vi7 * vxk7;
115 
116       const int32_t vi8 = (int32_t) (uint32_t) *i8++;
117       const uint32_t vk8 = (uint32_t) ((const uint8_t*) w)[12];
118       const int32_t vxk8 = (int32_t) vk8 - vkernel_zero_point;
119       vacc += vi8 * vxk8;
120 
121       w = (const void*) ((uintptr_t) w + sizeof(int32_t) + 9 * sizeof(uint8_t));
122 
123       const int64_t vproduct = (int64_t) vacc * (int64_t) vmultiplier;
124       const int32_t vq31product = (int32_t) (uint32_t) ((uint64_t) (vproduct + (int64_t) vq31rounding) >> 31);
125       const int32_t vremainder = (vq31product & vremainder_mask) - (int32_t) (vq31product < 0);
126       int32_t vout = asr_s32(vq31product, vshift) + (int32_t) (vremainder > vremainder_threshold);
127       vout = vout < vout_min ? vout_min : vout;
128       vout = vout > vout_max ? vout_max : vout;
129       vout += voutput_zero_point;
130 
131       *output++ = vout;
132     } while (--c != 0);
133 
134     output = (uint8_t*) ((uintptr_t) output + output_increment);
135   } while (--output_width != 0);
136 }
137