• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE > 0
7$assert ROW_TILE >= 1
8$ABC = "0123456789ABCDEFGHIJKLMN"
9#include <assert.h>
10
11#include <math.h>
12
13#include <xnnpack/math.h>
14#include <xnnpack/prelu.h>
15
16
17$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
18$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
19void xnn_f32_prelu_ukernel__${"wasm" if WASM else "scalar"}_${ROW_TILE}x${CHANNEL_TILE}(
20    size_t rows,
21    size_t channels,
22    const float*restrict input,
23    size_t input_stride,
24    const float*restrict weights,
25    float*restrict output,
26    size_t output_stride,
27    const union xnn_f32_output_params params[restrict static 1])
28{
29  assert(rows != 0);
30  assert(channels != 0);
31  assert(channels % sizeof(float) == 0);
32
33  const float* i0 = input;
34  float* o0 = output;
35  $for M in range(1, ROW_TILE):
36    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
37    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
38    $if M % 2 == 0:
39      if XNN_UNPREDICTABLE(rows <= ${M}) {
40        i${M} = i${M-1};
41        o${M} = o${M-1};
42      }
43    $else:
44      if XNN_UNPREDICTABLE(rows < ${M+1}) {
45        i${M} = i${M-1};
46        o${M} = o${M-1};
47      }
48
49  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
50  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
51
52  const float vmin = params->scalar.min;
53  const float vmax = params->scalar.max;
54  do {
55    const float* w = weights;
56    size_t c = channels;
57    $if CHANNEL_TILE > 1:
58      for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
59        $for C in range(CHANNEL_TILE):
60          const float vw${ABC[C]} = w[${C}];
61
62        $for M in range(ROW_TILE):
63          $for C in range(CHANNEL_TILE):
64            const float vi${M}x${ABC[C]} = i${M}[${C}];
65          i${M} += ${CHANNEL_TILE};
66
67        $for M in range(ROW_TILE):
68          $for C in range(CHANNEL_TILE):
69            float vacc${M}x${ABC[C]} = signbit(vi${M}x${ABC[C]}) ? vi${M}x${ABC[C]} * vw${ABC[C]} : vi${M}x${ABC[C]};
70
71        $for M in range(ROW_TILE):
72          $for C in range(CHANNEL_TILE):
73            vacc${M}x${ABC[C]} = ${MAX_F32}(vacc${M}x${ABC[C]}, vmin);
74
75        $for M in range(ROW_TILE):
76          $for C in range(CHANNEL_TILE):
77            vacc${M}x${ABC[C]} = ${MIN_F32}(vacc${M}x${ABC[C]}, vmax);
78
79        $for M in range(ROW_TILE):
80          $for C in range(CHANNEL_TILE):
81            o${M}[${C}] = vacc${M}x${ABC[C]};
82          o${M} += ${CHANNEL_TILE};
83
84        w += ${CHANNEL_TILE};
85      }
86      for (; c != 0; c -= sizeof(float)) {
87        const float vw = *w++;
88
89        $for M in range(ROW_TILE):
90          const float vi${M} = *i${M}++;
91
92        $for M in range(ROW_TILE):
93          float vacc${M} = signbit(vi${M}) ? vi${M} * vw : vi${M};
94
95        $for M in range(ROW_TILE):
96          vacc${M} = ${MAX_F32}(vacc${M}, vmin);
97
98        $for M in range(ROW_TILE):
99          vacc${M} = ${MIN_F32}(vacc${M}, vmax);
100
101        $for M in range(ROW_TILE):
102          *o${M}++ = vacc${M};
103      }
104    $else:
105      do {
106        const float vw = *w++;
107
108        $for M in range(ROW_TILE):
109          const float vi${M} = *i${M}++;
110
111        $for M in range(ROW_TILE):
112          $for C in range(CHANNEL_TILE):
113            float vacc${M} = signbit(vi${M}) ? vi${M} * vw : vi${M};
114
115        $for M in range(ROW_TILE):
116          $for C in range(CHANNEL_TILE):
117            vacc${M} = ${MAX_F32}(vacc${M}, vmin);
118
119        $for M in range(ROW_TILE):
120          $for C in range(CHANNEL_TILE):
121            vacc${M} = ${MIN_F32}(vacc${M}, vmax);
122
123        $for M in range(ROW_TILE):
124          *o${M}++ = vacc${M};
125
126        c -= sizeof(float);
127      } while (c != 0);
128    $for M in range(ROW_TILE):
129      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
130      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
131      $if M % 2 == 1:
132        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M+1}) {
133          i${M} = i${M-1};
134          o${M} = o${M-1};
135        }
136      $elif M != 0:
137        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
138          i${M} = i${M-1};
139          o${M} = o${M-1};
140        }
141    rows = doz(rows, ${ROW_TILE});
142  } while (rows != 0);
143}
144