// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert CHANNEL_TILE >= 1
7$assert ROW_TILE >= 1
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
9#include <assert.h>
10
11#include <xnnpack/math.h>
12#include <xnnpack/vmulcaddc.h>
13
14
15$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
16$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
17void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${"wasm" if WASM else "scalar"}_${ROW_TILE}x(
18    size_t rows,
19    size_t channels,
20    const float*restrict input,
21    size_t input_stride,
22    const float*restrict weights,
23    float*restrict output,
24    size_t output_stride,
25    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(rows != 0);
28  assert(channels != 0);
29  assert(channels % sizeof(float) == 0);
30
31  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
32  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
33
34  const float* i0 = input;
35  float* o0 = output;
36  $for M in range(1, ROW_TILE):
37    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
38    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
39
40  const float vmin = params->scalar.min;
41  const float vmax = params->scalar.max;
42  do {
43    $for M in range(1, ROW_TILE):
44      $if M % 2 == 0:
45        if XNN_UNPREDICTABLE(rows <= ${M}) {
46          i${M} = i${M-1};
47          o${M} = o${M-1};
48        }
49      $else:
50        if XNN_UNPREDICTABLE(rows < ${M+1}) {
51          i${M} = i${M-1};
52          o${M} = o${M-1};
53        }
54
55    const float* w = weights;
56    size_t c = channels;
57    $if CHANNEL_TILE > 1:
58      for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
59        $for C in range(CHANNEL_TILE):
60          const float vscale${ABC[C]} = w[${C}];
61
62        $for M in range(ROW_TILE):
63          $for C in range(CHANNEL_TILE):
64            float vacc${M}x${ABC[C]} = i${M}[${C}];
65          i${M} += ${CHANNEL_TILE};
66
67        $for C in range(CHANNEL_TILE):
68          const float vbias${ABC[C]} = w[${C + CHANNEL_TILE}];
69
70        $for M in range(ROW_TILE):
71          $for C in range(CHANNEL_TILE):
72            vacc${M}x${ABC[C]} = vacc${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]};
73
74        $for M in range(ROW_TILE):
75          $for C in range(CHANNEL_TILE):
76            vacc${M}x${ABC[C]} = ${MAX_F32}(vacc${M}x${ABC[C]}, vmin);
77
78        $for M in range(ROW_TILE):
79          $for C in range(CHANNEL_TILE):
80            vacc${M}x${ABC[C]} = ${MIN_F32}(vacc${M}x${ABC[C]}, vmax);
81
82        $for M in range(ROW_TILE):
83          $for C in range(CHANNEL_TILE):
84            o${M}[${C}] = vacc${M}x${ABC[C]};
85          o${M} += ${CHANNEL_TILE};
86
87        w += ${CHANNEL_TILE * 2};
88      }
89      if XNN_UNLIKELY(c != 0) {
90        do {
91          const float vscale = *w++;
92
93          $for M in range(ROW_TILE):
94            float vacc${M} = *i${M}++;
95
96          const float vbias = w[${CHANNEL_TILE - 1}];
97
98          $for M in range(ROW_TILE):
99            vacc${M} = vacc${M} * vscale + vbias;
100
101          $for M in range(ROW_TILE):
102            vacc${M} = ${MAX_F32}(vacc${M}, vmin);
103
104          $for M in range(ROW_TILE):
105            vacc${M} = ${MIN_F32}(vacc${M}, vmax);
106
107          $for M in range(ROW_TILE):
108            *o${M}++ = vacc${M};
109
110          c -= sizeof(float);
111        } while (c != 0);
112      }
113    $else:
114      do {
115        const float vscale = w[0];
116
117        $for M in range(ROW_TILE):
118          float vacc${M} = *i${M}++;
119
120        const float vbias = w[1];
121
122        $for M in range(ROW_TILE):
123          vacc${M} = vacc${M} * vscale + vbias;
124
125        $for M in range(ROW_TILE):
126          vacc${M} = ${MAX_F32}(vacc${M}, vmin);
127
128        $for M in range(ROW_TILE):
129          vacc${M} = ${MIN_F32}(vacc${M}, vmax);
130
131        $for M in range(ROW_TILE):
132          *o${M}++ = vacc${M};
133
134        w += 2;
135        c -= sizeof(float);
136      } while (c != 0);
137    $for M in range(ROW_TILE):
138      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
139      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
140    rows = doz(rows, ${ROW_TILE});
141  } while (rows != 0);
142}
143