• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
// Template preconditions: CHANNEL_TILE must be a positive multiple of 4
// (the kernel below works on 128-bit vectors of four floats) and ROW_TILE
// must be at least 1.
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
// ABC supplies the characters used to build the lane-group suffix of
// generated variable names (sliced as ABC[C:C+4] in the kernel below).
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <wasm_simd128.h>
13
14#include <xnnpack/math.h>
15#include <xnnpack/vmulcaddc.h>
16
17
// Per-channel multiply-add with clamping. For every row and every channel:
//   output = input * scale + bias, clamped to [params->scalar.min, params->scalar.max]
// Each pass of the outer loop processes ROW_TILE rows; each pass of the main
// inner loop processes CHANNEL_TILE channels of every one of those rows.
//
// Parameters:
//   rows          - number of rows to process; asserted non-zero.
//   channels      - size of one row in BYTES; asserted to be a non-zero
//                   multiple of sizeof(float).
//   input         - first input row.
//   input_stride  - distance in bytes between consecutive input rows.
//   weights       - packed scale/bias values. For each group of CHANNEL_TILE
//                   channels the kernel reads CHANNEL_TILE scales followed by
//                   CHANNEL_TILE biases.
//   output        - first output row.
//   output_stride - distance in bytes between consecutive output rows.
//   params        - supplies the scalar.min / scalar.max clamping bounds.
18void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__wasmsimd_${"x86" if X86 else "arm"}_${ROW_TILE}x(
19    size_t rows,
20    size_t channels,
21    const float*restrict input,
22    size_t input_stride,
23    const float*restrict weights,
24    float*restrict output,
25    size_t output_stride,
26    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
27{
28  assert(rows != 0);
29  assert(channels != 0);
30  assert(channels % sizeof(float) == 0);
31
  // Set up one input/output pointer pair per row of the tile. When fewer rows
  // than ROW_TILE exist, the surplus pointers alias the previous row, so those
  // lanes re-read and re-write that row instead of touching memory past the
  // last row. The odd/even branches below test the same condition (rows <= M).
32  const float* i0 = input;
33  float* o0 = output;
34  $for M in range(1, ROW_TILE):
35    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
36    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
37    $if M % 2 == 0:
38      if XNN_UNPREDICTABLE(rows <= ${M}) {
39        i${M} = i${M-1};
40        o${M} = o${M-1};
41      }
42    $else:
43      if XNN_UNPREDICTABLE(rows < ${M+1}) {
44        i${M} = i${M-1};
45        o${M} = o${M-1};
46      }
47
  // By the end of a pass every pointer has advanced by `channels` bytes, so
  // these byte increments move it to the same column ROW_TILE rows further on.
48  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
49  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
50
  // Broadcast the clamping bounds to all four lanes.
51  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);
52  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
53  do {
    // The same weights are reused for every group of rows; c counts the
    // remaining channels of the current rows in bytes.
54    const float* w = weights;
55    size_t c = channels;
    // Main loop: CHANNEL_TILE channels per iteration.
56    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
      // Scales occupy the first CHANNEL_TILE floats of the packed group.
57      const v128_t vscale${ABC[0:4]} = wasm_v128_load(w);
58      $for C in range(4, CHANNEL_TILE, 4):
59        const v128_t vscale${ABC[C:C+4]} = wasm_v128_load(w + ${C});
60
61      $for M in range(ROW_TILE):
62        v128_t vacc${M}x${ABC[0:4]} = wasm_v128_load(i${M});
63        $for C in range(4, CHANNEL_TILE, 4):
64          v128_t vacc${M}x${ABC[C:C+4]} = wasm_v128_load(i${M} + ${C});
65        i${M} += ${CHANNEL_TILE};
66
      // Biases occupy the next CHANNEL_TILE floats of the packed group.
67      $for C in range(0, CHANNEL_TILE, 4):
68        const v128_t vbias${ABC[C:C+4]} = wasm_v128_load(w + ${C + CHANNEL_TILE});
69
70      $for M in range(ROW_TILE):
71        $for C in range(0, CHANNEL_TILE, 4):
72          vacc${M}x${ABC[C:C+4]} = wasm_f32x4_add(vbias${ABC[C:C+4]}, wasm_f32x4_mul(vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}));
73
      // Clamp to [vmin, vmax]. The X86 variant builds the result from a
      // comparison mask and a bitselect (vmin where acc < vmin, then acc where
      // acc <= vmax and vmax otherwise); the other variant uses f32x4 max/min.
74      $if X86:
75        $for M in range(ROW_TILE):
76          $for C in range(0, CHANNEL_TILE, 4):
77            vacc${M}x${ABC[C:C+4]} = wasm_v128_bitselect(vmin, vacc${M}x${ABC[C:C+4]}, wasm_f32x4_lt(vacc${M}x${ABC[C:C+4]}, vmin));
78
79        $for M in range(ROW_TILE):
80          $for C in range(0, CHANNEL_TILE, 4):
81            vacc${M}x${ABC[C:C+4]} = wasm_v128_bitselect(vacc${M}x${ABC[C:C+4]}, vmax, wasm_f32x4_le(vacc${M}x${ABC[C:C+4]}, vmax));
82      $else:
83        $for M in range(ROW_TILE):
84          $for C in range(0, CHANNEL_TILE, 4):
85            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_max(vacc${M}x${ABC[C:C+4]}, vmin);
86
87        $for M in range(ROW_TILE):
88          $for C in range(0, CHANNEL_TILE, 4):
89            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_min(vacc${M}x${ABC[C:C+4]}, vmax);
90
91      $for M in range(ROW_TILE):
92        wasm_v128_store(o${M}, vacc${M}x${ABC[0:4]});
93        $for C in range(4, CHANNEL_TILE, 4):
94          wasm_v128_store(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
95        o${M} += ${CHANNEL_TILE};
96
      // Skip both the scale half and the bias half of the packed group.
97      w += ${CHANNEL_TILE * 2};
98    }
    // Secondary loop, emitted only when CHANNEL_TILE > 4: handles the leftover
    // channels four at a time.
99    $if CHANNEL_TILE > 4:
100      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
101        const v128_t vscale = wasm_v128_load(w);
102
103        $for M in range(ROW_TILE):
104          v128_t vacc${M} = wasm_v128_load(i${M});
105          i${M} += 4;
106
        // The bias matching the scale at w is read CHANNEL_TILE floats later.
107        const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});
108
109        $for M in range(ROW_TILE):
110          vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
111
112        $if X86:
113          $for M in range(ROW_TILE):
114            vacc${M} = wasm_v128_bitselect(vmin, vacc${M}, wasm_f32x4_lt(vacc${M}, vmin));
115
116          $for M in range(ROW_TILE):
117            vacc${M} = wasm_v128_bitselect(vacc${M}, vmax, wasm_f32x4_le(vacc${M}, vmax));
118        $else:
119          $for M in range(ROW_TILE):
120            vacc${M} = wasm_f32x4_max(vacc${M}, vmin);
121
122          $for M in range(ROW_TILE):
123            vacc${M} = wasm_f32x4_min(vacc${M}, vmax);
124
125        $for M in range(ROW_TILE):
126          wasm_v128_store(o${M}, vacc${M});
127          o${M} += 4;
128
        // Advance by four scales only; w stays inside the current packed group.
129        w += 4;
130      }
    // Tail: between one and three channels remain.
    // NOTE(review): the loads below are full 16-byte loads of input and weights
    // even though fewer than four floats remain; presumably the caller
    // guarantees that the extra bytes are readable -- confirm.
131    if XNN_UNLIKELY(c != 0) {
132      const v128_t vscale = wasm_v128_load(w);
133
134      $for M in range(ROW_TILE):
135        v128_t vacc${M} = wasm_v128_load(i${M});
136        i${M} = (const float*) ((uintptr_t) i${M} + c);
137
138      const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});
139
140      $for M in range(ROW_TILE):
141        vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
142
143      $if X86:
144        $for M in range(ROW_TILE):
145          vacc${M} = wasm_v128_bitselect(vmin, vacc${M}, wasm_f32x4_lt(vacc${M}, vmin));
146
147        $for M in range(ROW_TILE):
148          vacc${M} = wasm_v128_bitselect(vacc${M}, vmax, wasm_f32x4_le(vacc${M}, vmax));
149      $else:
150        $for M in range(ROW_TILE):
151          vacc${M} = wasm_f32x4_max(vacc${M}, vmin);
152
153        $for M in range(ROW_TILE):
154          vacc${M} = wasm_f32x4_min(vacc${M}, vmax);
155
      // Only the valid results are stored. If two or more channels remain, the
      // low two floats are written with one 64-bit store (lane 0 viewed as a
      // double), then lanes 2..3 are shuffled down for the single-float case.
156      if (c & (2 * sizeof(float))) {
157        $for M in range(ROW_TILE):
158          *((double*) o${M}) = wasm_f64x2_extract_lane(vacc${M}, 0);
159
160        $for M in range(ROW_TILE):
161          vacc${M} = wasm_v32x4_shuffle(vacc${M}, vacc${M}, 2, 3, 2, 3);
162
163        $for M in range(ROW_TILE):
164          o${M} += 2;
165      }
166      if (c & (1 * sizeof(float))) {
167        $for M in range(ROW_TILE):
168          *o${M}++ = wasm_f32x4_extract_lane(vacc${M}, 0);
169      }
170    }
    // Move every pointer to the next group of ROW_TILE rows and re-apply the
    // aliasing of surplus rows. `rows` still holds the count from before this
    // pass (it is reduced just below), hence the ROW_TILE + M thresholds.
171    $for M in range(ROW_TILE):
172      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
173      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
174      $if M % 2 == 1:
175        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) {
176          i${M} = i${M-1};
177          o${M} = o${M-1};
178        }
179      $elif M != 0:
180        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
181          i${M} = i${M-1};
182          o${M} = o${M-1};
183        }
    // doz() is not defined in this file; presumably it is a saturating
    // "difference or zero" helper from xnnpack/math.h, so that rows becomes 0
    // when fewer than ROW_TILE rows were left -- confirm.
184    rows = doz(rows, ${ROW_TILE});
185  } while (rows != 0);
186}
187