• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$ABC = "0123456789ABCDEFGHIJKLMN"
10#include <assert.h>
11
12$if BLEND:
13  #include <smmintrin.h>
14$else:
15  #include <emmintrin.h>
16
17#include <xnnpack/math.h>
18#include <xnnpack/prelu.h>
19
20
// PReLU micro-kernel template: SSE4.1 variant when BLEND is set, SSE2 otherwise.
// Each outer iteration handles ROW_TILE rows; each main-loop step handles
// CHANNEL_TILE floats of every row in the tile.
//
// Per element the kernel computes
//   y = (sign bit of x set) ? x * w[channel] : x
// and then clamps y with max(y, min) followed by min(y, max).
//
// Parameters:
//   rows          - number of rows to process; asserted non-zero.
//   channels      - row length in BYTES; asserted non-zero and a multiple of
//                   sizeof(float).
//   input         - first input row.
//   input_stride  - distance in bytes between consecutive input rows (added to
//                   the pointer through a uintptr_t cast).
//   weights       - per-channel slopes, reused for every row; read with
//                   _mm_load_ps, i.e. an aligned 16-byte load.
//   output        - first output row.
//   output_stride - distance in bytes between consecutive output rows.
//   params        - provides sse.min and sse.max, each read as four floats with
//                   an aligned load.
//
// When 1-3 floats remain in a row, full 4-float vectors are still loaded from
// the inputs and the weights, so up to 3 floats past the valid channels are read.
// NOTE(review): this presumably relies on the caller making that over-read safe
// (padding) - confirm against the callers.
21void xnn_f32_prelu_ukernel__${"sse41" if BLEND else "sse2"}_${ROW_TILE}x${CHANNEL_TILE}(
22    size_t rows,
23    size_t channels,
24    const float*restrict input,
25    size_t input_stride,
26    const float*restrict weights,
27    float*restrict output,
28    size_t output_stride,
29    const union xnn_f32_output_params params[restrict static 1])
30{
31  assert(rows != 0);
32  assert(channels != 0);
33  assert(channels % sizeof(float) == 0);
34
35  const float* i0 = input;
36  float* o0 = output;
  // One input/output pointer pair per row of the tile. A row whose index is not
  // below `rows` reuses the previous row's pointers, so that slot of the tile
  // simply recomputes and rewrites the previous row. The two spellings of the
  // test (rows <= M and rows < M+1) express the same condition.
37  $for M in range(1, ROW_TILE):
38    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
39    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
40    $if M % 2 == 0:
41      if XNN_UNPREDICTABLE(rows <= ${M}) {
42        i${M} = i${M-1};
43        o${M} = o${M-1};
44      }
45    $else:
46      if XNN_UNPREDICTABLE(rows < ${M+1}) {
47        i${M} = i${M-1};
48        o${M} = o${M-1};
49      }
50
  // Byte offsets applied after a row has been consumed (its pointer has already
  // moved forward by `channels` bytes) to reach the row ROW_TILE rows further on.
51  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
52  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
53
  // Clamp bounds: four floats each, read with aligned loads.
54  const __m128 vmin = _mm_load_ps(params->sse.min);
55  const __m128 vmax = _mm_load_ps(params->sse.max);
56  do {
    // Every tile of rows starts again at the first channel's weight;
    // `c` counts the bytes of the row that are still to be processed.
57    const float* w = weights;
58    size_t c = channels;
    // Main loop: CHANNEL_TILE floats of every row in the tile per iteration.
59    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
60      const __m128 vw${ABC[0:4]} = _mm_load_ps(w);
61      $for C in range(4, CHANNEL_TILE, 4):
62        const __m128 vw${ABC[C:C+4]} = _mm_load_ps(w + ${C});
63      w += ${CHANNEL_TILE};
64
65      $for M in range(ROW_TILE):
66        const __m128 vi${M}x${ABC[0:4]} = _mm_loadu_ps(i${M});
67        $for C in range(4, CHANNEL_TILE, 4):
68          const __m128 vi${M}x${ABC[C:C+4]} = _mm_loadu_ps(i${M} + ${C});
69        i${M} += ${CHANNEL_TILE};
70
      // x * w for every lane. The non-BLEND variant additionally builds a lane
      // mask via a signed 32-bit compare (0 > bits of x), which is all-ones
      // exactly where the sign bit of x is set.
71      $for M in range(ROW_TILE):
72        $for C in range(0, CHANNEL_TILE, 4):
73          const __m128 vprod${M}x${ABC[C:C+4]} = _mm_mul_ps(vi${M}x${ABC[C:C+4]}, vw${ABC[C:C+4]});
74          $if not BLEND:
75            const __m128 vmask${M}x${ABC[C:C+4]} = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vi${M}x${ABC[C:C+4]})));
76
      // Pick x * w where the sign bit of x is set and x elsewhere.
      // _mm_blendv_ps selects on the sign bit of its third operand; the
      // non-BLEND variant does the same with and/andnot/or on the mask.
77      $for M in range(ROW_TILE):
78        $for C in range(0, CHANNEL_TILE, 4):
79          $if BLEND:
80            __m128 vacc${M}x${ABC[C:C+4]} = _mm_blendv_ps(vi${M}x${ABC[C:C+4]}, vprod${M}x${ABC[C:C+4]}, vi${M}x${ABC[C:C+4]});
81          $else:
82            __m128 vacc${M}x${ABC[C:C+4]} = _mm_or_ps(_mm_and_ps(vprod${M}x${ABC[C:C+4]}, vmask${M}x${ABC[C:C+4]}), _mm_andnot_ps(vmask${M}x${ABC[C:C+4]}, vi${M}x${ABC[C:C+4]}));
83
      // Clamp from below with vmin, then from above with vmax.
84      $for M in range(ROW_TILE):
85        $for C in range(0, CHANNEL_TILE, 4):
86          vacc${M}x${ABC[C:C+4]} = _mm_max_ps(vacc${M}x${ABC[C:C+4]}, vmin);
87
88      $for M in range(ROW_TILE):
89        $for C in range(0, CHANNEL_TILE, 4):
90          vacc${M}x${ABC[C:C+4]} = _mm_min_ps(vacc${M}x${ABC[C:C+4]}, vmax);
91
92      $for M in range(ROW_TILE):
93        _mm_storeu_ps(o${M}, vacc${M}x${ABC[0:4]});
94        $for C in range(4, CHANNEL_TILE, 4):
95          _mm_storeu_ps(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
96        o${M} += ${CHANNEL_TILE};
97    }
98    $if CHANNEL_TILE > 4:
      // Generated only when CHANNEL_TILE > 4: process the remaining whole
      // groups of four floats with the same multiply/select/clamp sequence.
99      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
100        const __m128 vw0123 = _mm_load_ps(w);
101        w += 4;
102
103        $for M in range(ROW_TILE):
104          const __m128 vi${M}x0123 = _mm_loadu_ps(i${M});
105          i${M} += 4;
106
107        $for M in range(ROW_TILE):
108          const __m128 vprod${M}x0123 = _mm_mul_ps(vi${M}x0123, vw0123);
109          $if not BLEND:
110            const __m128 vmask${M}x0123 = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vi${M}x0123)));
111
112        $for M in range(ROW_TILE):
113          $if BLEND:
114            __m128 vacc${M}x0123 = _mm_blendv_ps(vi${M}x0123, vprod${M}x0123, vi${M}x0123);
115          $else:
116            __m128 vacc${M}x0123 = _mm_or_ps(_mm_and_ps(vprod${M}x0123, vmask${M}x0123), _mm_andnot_ps(vmask${M}x0123, vi${M}x0123));
117
118        $for M in range(ROW_TILE):
119          vacc${M}x0123 = _mm_max_ps(vacc${M}x0123, vmin);
120
121        $for M in range(ROW_TILE):
122          vacc${M}x0123 = _mm_min_ps(vacc${M}x0123, vmax);
123
124        $for M in range(ROW_TILE):
125          _mm_storeu_ps(o${M}, vacc${M}x0123);
126          o${M} += 4;
127      }
    // Tail: fewer than four floats remain. Full vectors are still loaded and
    // computed; only the valid lanes are stored further down. Pointers advance
    // by the remaining byte count `c`, not by a whole vector.
128    if XNN_UNLIKELY(c != 0) {
129      const __m128 vw0123 = _mm_load_ps(w);
      // `w` is re-initialised from `weights` at the top of the next outer
      // iteration, so this advance is not read again within this function.
130      w = (const float*) ((uintptr_t) w + c);
131
132      $for M in range(ROW_TILE):
133        const __m128 vi${M}x0123 = _mm_loadu_ps(i${M});
134        i${M} = (const float*) ((uintptr_t) i${M} + c);
135
136      $for M in range(ROW_TILE):
137        const __m128 vprod${M}x0123 = _mm_mul_ps(vi${M}x0123, vw0123);
138        $if not BLEND:
139          const __m128 vmask${M}x0123 = _mm_castsi128_ps(_mm_cmpgt_epi32(_mm_setzero_si128(), _mm_castps_si128(vi${M}x0123)));
140
141      $for M in range(ROW_TILE):
142        $if BLEND:
143          __m128 vacc${M}x0123 = _mm_blendv_ps(vi${M}x0123, vprod${M}x0123, vi${M}x0123);
144        $else:
145          __m128 vacc${M}x0123 = _mm_or_ps(_mm_and_ps(vprod${M}x0123, vmask${M}x0123), _mm_andnot_ps(vmask${M}x0123, vi${M}x0123));
146
147      $for M in range(ROW_TILE):
148        vacc${M}x0123 = _mm_max_ps(vacc${M}x0123, vmin);
149
150      $for M in range(ROW_TILE):
151        vacc${M}x0123 = _mm_min_ps(vacc${M}x0123, vmax);
152
      // Two floats pending: store the low two lanes, then move the high two
      // lanes down so that a following single-float store sees the third lane.
153      if (c & (2 * sizeof(float))) {
154        $for M in range(ROW_TILE):
155          _mm_storel_pi((__m64*) o${M}, vacc${M}x0123);
156
157        $for M in range(ROW_TILE):
158          vacc${M}x0123 = _mm_movehl_ps(vacc${M}x0123, vacc${M}x0123);
159
160        $for M in range(ROW_TILE):
161          o${M} += 2;
162      }
      // One float pending: store the lowest lane only.
163      if (c & (1 * sizeof(float))) {
164        $for M in range(ROW_TILE):
165          _mm_store_ss(o${M}, vacc${M}x0123);
166
167        $for M in range(ROW_TILE):
168          o${M} += 1;
169      }
170    }
    // Step every row pointer to the next tile, then re-apply the
    // reuse-the-previous-row rule for rows that will not exist in that tile.
    // The thresholds are offset by ROW_TILE because `rows` is only reduced
    // after these checks.
171    $for M in range(ROW_TILE):
172      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
173      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
174      $if M % 2 == 1:
175        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M+1}) {
176          i${M} = i${M-1};
177          o${M} = o${M-1};
178        }
179      $elif M != 0:
180        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
181          i${M} = i${M-1};
182          o${M} = o${M-1};
183        }
    // NOTE(review): doz() is not defined in this file; presumably a saturating
    // "difference or zero" subtraction from xnnpack/math.h, which would end the
    // loop once ROW_TILE or fewer rows were left - confirm.
184    rows = doz(rows, ${ROW_TILE});
185  } while (rows != 0);
186}
187