• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 
8 #include <xnnpack/common.h>
9 #include <xnnpack/scalar-utils.h>
10 #include <xnnpack/vadd.h>
11 
12 
xnn_q8_vadd_ukernel__scalar(size_t n,const uint8_t * a,const uint8_t * b,uint8_t * y,const union xnn_q8_add_params params[restrict static1])13 void xnn_q8_vadd_ukernel__scalar(
14     size_t n,
15     const uint8_t* a,
16     const uint8_t* b,
17     uint8_t* y,
18     const union xnn_q8_add_params params[restrict static 1])
19 {
20   assert(n != 0);
21 
22   const int32_t vzero_point_product = params->scalar.zero_point_product;
23   const uint32_t va_multiplier = params->scalar.a_multiplier;
24   const uint32_t vb_multiplier = params->scalar.b_multiplier;
25   const uint32_t vshift = params->scalar.shift;
26   const int32_t vremainder_mask = params->scalar.remainder_mask;
27   const int32_t vremainder_threshold = params->scalar.remainder_threshold;
28   const int32_t vy_zero_point = params->scalar.y_zero_point;
29   const int32_t vy_max = params->scalar.y_max;
30   const int32_t vy_min = params->scalar.y_min;
31 
32   do {
33     const int32_t va = (int32_t) (uint32_t) *a++;
34     const int32_t vb = (int32_t) (uint32_t) *b++;
35 
36     // Multiply by factors.
37     const int32_t va_product = va * va_multiplier;
38     const int32_t vb_product = vb * vb_multiplier;
39 
40     // Accumulate products.
41     const int32_t vacc = vzero_point_product + va_product + vb_product;
42 
43     // Shift right and round.
44     const int32_t vremainder = (vacc & vremainder_mask) - (int32_t) (vacc < 0);
45     int32_t vy = asr_s32(vacc, vshift) + (int32_t) (vremainder > vremainder_threshold);
46 
47     // Pack, saturate, and add output zero point.
48     vy += vy_zero_point;
49     vy = vy < vy_min ? vy_min : vy;
50     vy = vy > vy_max ? vy_max : vy;
51 
52     *y++ = vy;
53 
54     n -= sizeof(uint8_t);
55   } while (n != 0);
56 }
57