• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$ABC = "0123456789ABCDEFGHIJKLMN"
10#include <assert.h>
11
12#include <psimd.h>
13
14#include <xnnpack/math.h>
15#include <xnnpack/prelu.h>
16
17
18void xnn_f32_prelu_ukernel__psimd_${ROW_TILE}x${CHANNEL_TILE}(
19    size_t rows,
20    size_t channels,
21    const float*restrict input,
22    size_t input_stride,
23    const float*restrict weights,
24    float*restrict output,
25    size_t output_stride,
26    const union xnn_f32_output_params params[restrict static 1])
27{
28  assert(rows != 0);
29  assert(channels != 0);
30  assert(channels % sizeof(float) == 0);
31
32  const float* i0 = input;
33  float* o0 = output;
34  $for M in range(1, ROW_TILE):
35    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
36    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
37    $if M % 2 == 0:
38      if XNN_UNPREDICTABLE(rows <= ${M}) {
39        i${M} = i${M-1};
40        o${M} = o${M-1};
41      }
42    $else:
43      if XNN_UNPREDICTABLE(rows < ${M+1}) {
44        i${M} = i${M-1};
45        o${M} = o${M-1};
46      }
47
48  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
49  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
50
51  const psimd_f32 vmin = psimd_load_splat_f32(&params->scalar.min);
52  const psimd_f32 vmax = psimd_load_splat_f32(&params->scalar.max);
53  do {
54    const float* w = weights;
55    size_t c = channels;
56    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
57      const psimd_f32 vw${ABC[0:4]} = psimd_load_f32(w);
58      $for C in range(4, CHANNEL_TILE, 4):
59        const psimd_f32 vw${ABC[C:C+4]} = psimd_load_f32(w + ${C});
60      w += ${CHANNEL_TILE};
61
62      $for M in range(ROW_TILE):
63        const psimd_f32 vi${M}x${ABC[0:4]} = psimd_load_f32(i${M});
64        $for C in range(4, CHANNEL_TILE, 4):
65          const psimd_f32 vi${M}x${ABC[C:C+4]} = psimd_load_f32(i${M} + ${C});
66        i${M} += ${CHANNEL_TILE};
67
68      $for M in range(ROW_TILE):
69        $for C in range(0, CHANNEL_TILE, 4):
70          psimd_f32 vacc${M}x${ABC[C:C+4]} = psimd_mul_f32(vi${M}x${ABC[C:C+4]}, vw${ABC[C:C+4]});
71
72      $for M in range(ROW_TILE):
73        $for C in range(0, CHANNEL_TILE, 4):
74          vacc${M}x${ABC[C:C+4]} = psimd_signblend_f32(vi${M}x${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}, vi${M}x${ABC[C:C+4]});
75
76      $for M in range(ROW_TILE):
77        $for C in range(0, CHANNEL_TILE, 4):
78          vacc${M}x${ABC[C:C+4]} = psimd_max_f32(vacc${M}x${ABC[C:C+4]}, vmin);
79
80      $for M in range(ROW_TILE):
81        $for C in range(0, CHANNEL_TILE, 4):
82          vacc${M}x${ABC[C:C+4]} = psimd_min_f32(vacc${M}x${ABC[C:C+4]}, vmax);
83
84      $for M in range(ROW_TILE):
85        psimd_store_f32(o${M}, vacc${M}x${ABC[0:4]});
86        $for C in range(4, CHANNEL_TILE, 4):
87          psimd_store_f32(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
88        o${M} += ${CHANNEL_TILE};
89    }
90    $if CHANNEL_TILE > 4:
91      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
92        const psimd_f32 vw0123 = psimd_load_f32(w);
93        w += 4;
94
95        $for M in range(ROW_TILE):
96          const psimd_f32 vi${M}x0123 = psimd_load_f32(i${M});
97          i${M} += 4;
98
99        $for M in range(ROW_TILE):
100          psimd_f32 vacc${M}x0123 = psimd_mul_f32(vi${M}x0123, vw0123);
101
102        $for M in range(ROW_TILE):
103          vacc${M}x0123 = psimd_signblend_f32(vi${M}x0123, vacc${M}x0123, vi${M}x0123);
104
105        $for M in range(ROW_TILE):
106          vacc${M}x0123 = psimd_max_f32(vacc${M}x0123, vmin);
107
108        $for M in range(ROW_TILE):
109          vacc${M}x0123 = psimd_min_f32(vacc${M}x0123, vmax);
110
111        $for M in range(ROW_TILE):
112          psimd_store_f32(o${M}, vacc${M}x0123);
113          o${M} += 4;
114      }
115    if XNN_UNLIKELY(c != 0) {
116      const psimd_f32 vw0123 = psimd_load_f32(w);
117      w = (const float*) ((uintptr_t) w + c);
118
119      $for M in range(ROW_TILE):
120        const psimd_f32 vi${M}x0123 = psimd_load_f32(i${M});
121        i${M} = (const float*) ((uintptr_t) i${M} + c);
122
123      $for M in range(ROW_TILE):
124        psimd_f32 vacc${M}x0123 = psimd_mul_f32(vi${M}x0123, vw0123);
125
126      $for M in range(ROW_TILE):
127        vacc${M}x0123 = psimd_signblend_f32(vi${M}x0123, vacc${M}x0123, vi${M}x0123);
128
129      $for M in range(ROW_TILE):
130        vacc${M}x0123 = psimd_max_f32(vacc${M}x0123, vmin);
131
132      $for M in range(ROW_TILE):
133        vacc${M}x0123 = psimd_min_f32(vacc${M}x0123, vmax);
134
135      if (c & (2 * sizeof(float))) {
136        $for M in range(ROW_TILE):
137          psimd_store2_f32(o${M}, vacc${M}x0123);
138
139        $for M in range(ROW_TILE):
140          vacc${M}x0123 = psimd_concat_hi_f32(vacc${M}x0123, vacc${M}x0123);
141
142        $for M in range(ROW_TILE):
143          o${M} += 2;
144      }
145      if (c & (1 * sizeof(float))) {
146        $for M in range(ROW_TILE):
147          psimd_store1_f32(o${M}, vacc${M}x0123);
148
149        $for M in range(ROW_TILE):
150          o${M} += 1;
151      }
152    }
153    $for M in range(ROW_TILE):
154      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
155      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
156      $if M % 2 == 1:
157        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M+1}) {
158          i${M} = i${M-1};
159          o${M} = o${M-1};
160        }
161      $elif M != 0:
162        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
163          i${M} = i${M-1};
164          o${M} = o${M-1};
165        }
166    rows = doz(rows, ${ROW_TILE});
167  } while (rows != 0);
168}
169