// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert CHANNEL_TILE >= 1
7$assert ROW_TILE >= 1
8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include <xnnpack/math.h>
#include <xnnpack/vmulcaddc.h>


15$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
16$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
17void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${"wasm" if WASM else "scalar"}_${ROW_TILE}x(
18    size_t rows,
19    size_t channels,
20    const float*restrict input,
21    size_t input_stride,
22    const float*restrict weights,
23    float*restrict output,
24    size_t output_stride,
25    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
26{
27  assert(rows != 0);
28  assert(channels != 0);
29  assert(channels % sizeof(float) == 0);
30
31  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
32  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
33
34  const float* i0 = input;
35  float* o0 = output;
36  $for M in range(1, ROW_TILE):
37    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
38    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
39    $if M % 2 == 0:
40      if XNN_UNPREDICTABLE(rows <= ${M}) {
41        i${M} = i${M-1};
42        o${M} = o${M-1};
43      }
44    $else:
45      if XNN_UNPREDICTABLE(rows < ${M+1}) {
46        i${M} = i${M-1};
47        o${M} = o${M-1};
48      }
49
50  const float vmin = params->scalar.min;
51  const float vmax = params->scalar.max;
52  do {
53    const float* w = weights;
54    size_t c = channels;
55    $if CHANNEL_TILE > 1:
56      for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
57        $for C in range(CHANNEL_TILE):
58          const float vscale${ABC[C]} = w[${C}];
59
60        $for M in range(ROW_TILE):
61          $for C in range(CHANNEL_TILE):
62            float vacc${M}x${ABC[C]} = i${M}[${C}];
63          i${M} += ${CHANNEL_TILE};
64
65        $for C in range(CHANNEL_TILE):
66          const float vbias${ABC[C]} = w[${C + CHANNEL_TILE}];
67
68        $for M in range(ROW_TILE):
69          $for C in range(CHANNEL_TILE):
70            vacc${M}x${ABC[C]} = vacc${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]};
71
72        $for M in range(ROW_TILE):
73          $for C in range(CHANNEL_TILE):
74            vacc${M}x${ABC[C]} = ${MAX_F32}(vacc${M}x${ABC[C]}, vmin);
75
76        $for M in range(ROW_TILE):
77          $for C in range(CHANNEL_TILE):
78            vacc${M}x${ABC[C]} = ${MIN_F32}(vacc${M}x${ABC[C]}, vmax);
79
80        $for M in range(ROW_TILE):
81          $for C in range(CHANNEL_TILE):
82            o${M}[${C}] = vacc${M}x${ABC[C]};
83          o${M} += ${CHANNEL_TILE};
84
85        w += ${CHANNEL_TILE * 2};
86      }
87      if XNN_UNLIKELY(c != 0) {
88        do {
89          const float vscale = *w++;
90
91          $for M in range(ROW_TILE):
92            float vacc${M} = *i${M}++;
93
94          const float vbias = w[${CHANNEL_TILE - 1}];
95
96          $for M in range(ROW_TILE):
97            vacc${M} = vacc${M} * vscale + vbias;
98
99          $for M in range(ROW_TILE):
100            vacc${M} = ${MAX_F32}(vacc${M}, vmin);
101
102          $for M in range(ROW_TILE):
103            vacc${M} = ${MIN_F32}(vacc${M}, vmax);
104
105          $for M in range(ROW_TILE):
106            *o${M}++ = vacc${M};
107
108          c -= sizeof(float);
109        } while (c != 0);
110      }
111    $else:
112      do {
113        const float vscale = w[0];
114
115        $for M in range(ROW_TILE):
116          float vacc${M} = *i${M}++;
117
118        const float vbias = w[1];
119
120        $for M in range(ROW_TILE):
121          vacc${M} = vacc${M} * vscale + vbias;
122
123        $for M in range(ROW_TILE):
124          vacc${M} = ${MAX_F32}(vacc${M}, vmin);
125
126        $for M in range(ROW_TILE):
127          vacc${M} = ${MIN_F32}(vacc${M}, vmax);
128
129        $for M in range(ROW_TILE):
130          *o${M}++ = vacc${M};
131
132        w += 2;
133        c -= sizeof(float);
134      } while (c != 0);
135    $for M in range(ROW_TILE):
136      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
137      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
138      $if M % 2 == 1:
139        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) {
140          i${M} = i${M-1};
141          o${M} = o${M-1};
142        }
143      $elif M != 0:
144        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
145          i${M} = i${M-1};
146          o${M} = o${M-1};
147        }
148    rows = doz(rows, ${ROW_TILE});
149  } while (rows != 0);
150}
151