• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$ABC = "0123456789ABCDEFGHIJKLMN"
10#include <assert.h>
11
12#include <psimd.h>
13
14#include <xnnpack/math.h>
15#include <xnnpack/vmulcaddc.h>
16
17
// Template for a per-channel multiply-add-clamp kernel. For each of `rows`
// rows and each channel k it computes
//   output[r][k] = clamp(input[r][k] * scale[k] + bias[k],
//                        params->scalar.min, params->scalar.max)
// working on ROW_TILE rows and CHANNEL_TILE channels per step with psimd
// 4-lane float vectors.
//
// Parameters:
//   rows          - number of rows to process; must be non-zero.
//   channels      - row length in BYTES (a non-zero multiple of sizeof(float)),
//                   not in elements.
//   input         - first input row.
//   input_stride  - byte distance between consecutive input rows.
//   weights       - packed per-channel coefficients. For every group of
//                   CHANNEL_TILE channels the kernel reads CHANNEL_TILE scales
//                   followed by CHANNEL_TILE biases; the final (partial) group
//                   is read with the same CHANNEL_TILE distance between a
//                   channel's scale and its bias.
//   output        - first output row.
//   output_stride - byte distance between consecutive output rows.
//   params        - clamping bounds, read from params->scalar.min / .max.
//
// NOTE(review): the channel tail (fewer than 4 floats left) still issues full
// 4-float loads from every input row and from the weights, i.e. it reads past
// the last channel. Callers presumably guarantee those bytes are readable
// (e.g. by padding the packed weights) — confirm.
void xnn_f32_vmulcaddc_ukernel_c${CHANNEL_TILE}__psimd_${ROW_TILE}x(
    size_t rows,
    size_t channels,
    const float*restrict input,
    size_t input_stride,
    const float*restrict weights,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(rows != 0);
  assert(channels != 0);
  assert(channels % sizeof(float) == 0);

  // One input/output pointer per row of the tile. When fewer than ROW_TILE
  // rows exist, the surplus pointers alias the previous row, so the kernel
  // recomputes and rewrites that row instead of touching memory past the last
  // row. An odd row M tests `rows < M+1` and the next (even) row tests
  // `rows <= M+1`, so each pair compares against the same constant
  // (presumably so the compiler can share a single comparison).
  const float* i0 = input;
  float* o0 = output;
  $for M in range(1, ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = i${M-1};
        o${M} = o${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = i${M-1};
        o${M} = o${M-1};
      }

  // Byte offsets that move each row pointer from the end of its current row
  // (it has already advanced by `channels` bytes) to the same row of the next
  // tile, ROW_TILE rows further down.
  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
  const size_t output_increment = output_stride * ${ROW_TILE} - channels;

  // Clamping bounds broadcast to all four lanes.
  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
  do {
    const float* w = weights;
    size_t c = channels;
    // Main loop: CHANNEL_TILE channels per iteration for every row of the tile.
    // Scales are read from w[0, CHANNEL_TILE), biases from
    // w[CHANNEL_TILE, 2 * CHANNEL_TILE).
    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
      const psimd_f32 vscale${ABC[0:4]} = psimd_load_f32(w);
      $for C in range(4, CHANNEL_TILE, 4):
        const psimd_f32 vscale${ABC[C:C+4]} = psimd_load_f32(w + ${C});

      // Load the inputs of every row and advance the input pointers.
      $for M in range(ROW_TILE):
        psimd_f32 vacc${M}x${ABC[0:4]} = psimd_load_f32(i${M});
        $for C in range(4, CHANNEL_TILE, 4):
          psimd_f32 vacc${M}x${ABC[C:C+4]} = psimd_load_f32(i${M} + ${C});
        i${M} += ${CHANNEL_TILE};

      $for C in range(0, CHANNEL_TILE, 4):
        const psimd_f32 vbias${ABC[C:C+4]} = psimd_load_f32(w + ${C + CHANNEL_TILE});

      // vacc = vbias + vscale * vacc (psimd_qfma_f32 is psimd's quasi-fused
      // multiply-add).
      $for M in range(ROW_TILE):
        $for C in range(0, CHANNEL_TILE, 4):
          vacc${M}x${ABC[C:C+4]} = psimd_qfma_f32(vbias${ABC[C:C+4]}, vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]});

      // Clamp to [min, max].
      $for M in range(ROW_TILE):
        $for C in range(0, CHANNEL_TILE, 4):
          vacc${M}x${ABC[C:C+4]} = psimd_max_f32(vacc${M}x${ABC[C:C+4]}, vmin);

      $for M in range(ROW_TILE):
        $for C in range(0, CHANNEL_TILE, 4):
          vacc${M}x${ABC[C:C+4]} = psimd_min_f32(vacc${M}x${ABC[C:C+4]}, vmax);

      // Store full vectors and advance the output pointers.
      $for M in range(ROW_TILE):
        psimd_store_f32(o${M}, vacc${M}x${ABC[0:4]});
        $for C in range(4, CHANNEL_TILE, 4):
          psimd_store_f32(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
        o${M} += ${CHANNEL_TILE};

      // Skip past this group's scales and biases.
      w += ${CHANNEL_TILE * 2};
    }
    $if CHANNEL_TILE > 4:
      // Whole 4-channel chunks left in the last (partial) group. `w` advances
      // by 4 only, so each bias is still CHANNEL_TILE floats past its scale.
      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
        const psimd_f32 vscale0123 = psimd_load_f32(w);

        $for M in range(ROW_TILE):
          psimd_f32 vacc${M}x0123 = psimd_load_f32(i${M});
          i${M} += 4;

        const psimd_f32 vbias0123 = psimd_load_f32(w + ${CHANNEL_TILE});

        $for M in range(ROW_TILE):
          vacc${M}x0123 = psimd_qfma_f32(vbias0123, vscale0123, vacc${M}x0123);

        $for M in range(ROW_TILE):
          vacc${M}x0123 = psimd_max_f32(vacc${M}x0123, vmin);

        $for M in range(ROW_TILE):
          vacc${M}x0123 = psimd_min_f32(vacc${M}x0123, vmax);

        $for M in range(ROW_TILE):
          psimd_store_f32(o${M}, vacc${M}x0123);
          o${M} += 4;

        w += 4;
      }
    // Tail: 1-3 channels left (c is in bytes). Full 4-lane vectors are loaded
    // and computed; only the valid lanes are stored below.
    if XNN_UNLIKELY(c != 0) {
      const psimd_f32 vscale0123 = psimd_load_f32(w);

      // Full 4-float loads (over-read, see the function comment); the input
      // pointers advance by exactly the remaining c bytes.
      $for M in range(ROW_TILE):
        psimd_f32 vacc${M}x0123 = psimd_load_f32(i${M});
        i${M} = (const float*) ((uintptr_t) i${M} + c);

      const psimd_f32 vbias0123 = psimd_load_f32(w + ${CHANNEL_TILE});

      $for M in range(ROW_TILE):
        vacc${M}x0123 = psimd_qfma_f32(vbias0123, vscale0123, vacc${M}x0123);

      $for M in range(ROW_TILE):
        vacc${M}x0123 = psimd_max_f32(vacc${M}x0123, vmin);

      $for M in range(ROW_TILE):
        vacc${M}x0123 = psimd_min_f32(vacc${M}x0123, vmax);

      // Two lanes left: store lanes 0-1, then move lanes 2-3 down so a
      // following single-lane store writes lane 2.
      if (c & (2 * sizeof(float))) {
        $for M in range(ROW_TILE):
          psimd_store2_f32(o${M}, vacc${M}x0123);

        $for M in range(ROW_TILE):
          vacc${M}x0123 = psimd_concat_hi_f32(vacc${M}x0123, vacc${M}x0123);

        $for M in range(ROW_TILE):
          o${M} += 2;
      }
      if (c & (1 * sizeof(float))) {
        $for M in range(ROW_TILE):
          psimd_store1_f32(o${M}, vacc${M}x0123);

        $for M in range(ROW_TILE):
          o${M} += 1;
      }
    }
    // Move every row pointer to the next tile of rows and, as in the setup
    // above, alias rows beyond the end to the previous row. `rows` has not
    // been decremented yet, so row M of the next tile exists only when
    // rows > ROW_TILE + M. Row 0 is never aliased.
    $for M in range(ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
      $if M % 2 == 1:
        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) {
          i${M} = i${M-1};
          o${M} = o${M-1};
        }
      $elif M != 0:
        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
          i${M} = i${M-1};
          o${M} = o${M-1};
        }
    // doz() comes from <xnnpack/math.h>; presumably a saturating
    // max(rows - ROW_TILE, 0), so the loop ends after the last (possibly
    // partial) tile.
    rows = doz(rows, ${ROW_TILE});
  } while (rows != 0);
}
167