• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 4 == 0
7$assert CHANNEL_TILE >= 4
8$assert ROW_TILE >= 1
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <wasm_simd128.h>
13
14#include <xnnpack/math.h>
15#include <xnnpack/vmulcaddc.h>
16
17
18void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__wasmsimd_${"x86" if X86 else "arm"}_${ROW_TILE}x(
19    size_t rows,
20    size_t channels,
21    const float*restrict input,
22    size_t input_stride,
23    const float*restrict weights,
24    float*restrict output,
25    size_t output_stride,
26    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
27{
28  assert(rows != 0);
29  assert(channels != 0);
30  assert(channels % sizeof(float) == 0);
31
32  const float* i0 = input;
33  float* o0 = output;
34  $for M in range(1, ROW_TILE):
35    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
36    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
37
38  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
39  const size_t output_increment = output_stride * ${ROW_TILE} - channels;
40
41  const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
42  const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
43  do {
44    $for M in range(1, ROW_TILE):
45      $if M % 2 == 0:
46        if XNN_UNPREDICTABLE(rows <= ${M}) {
47          i${M} = i${M-1};
48          o${M} = o${M-1};
49        }
50      $else:
51        if XNN_UNPREDICTABLE(rows < ${M+1}) {
52          i${M} = i${M-1};
53          o${M} = o${M-1};
54        }
55
56    const float* w = weights;
57    size_t c = channels;
58    for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
59      const v128_t vscale${ABC[0:4]} = wasm_v128_load(w);
60      $for C in range(4, CHANNEL_TILE, 4):
61        const v128_t vscale${ABC[C:C+4]} = wasm_v128_load(w + ${C});
62
63      $for M in range(ROW_TILE):
64        v128_t vacc${M}x${ABC[0:4]} = wasm_v128_load(i${M});
65        $for C in range(4, CHANNEL_TILE, 4):
66          v128_t vacc${M}x${ABC[C:C+4]} = wasm_v128_load(i${M} + ${C});
67        i${M} += ${CHANNEL_TILE};
68
69      $for C in range(0, CHANNEL_TILE, 4):
70        const v128_t vbias${ABC[C:C+4]} = wasm_v128_load(w + ${C + CHANNEL_TILE});
71
72      $for M in range(ROW_TILE):
73        $for C in range(0, CHANNEL_TILE, 4):
74          vacc${M}x${ABC[C:C+4]} = wasm_f32x4_add(vbias${ABC[C:C+4]}, wasm_f32x4_mul(vscale${ABC[C:C+4]}, vacc${M}x${ABC[C:C+4]}));
75
76      $if X86:
77        $for M in range(ROW_TILE):
78          $for C in range(0, CHANNEL_TILE, 4):
79            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_pmax(vmin, vacc${M}x${ABC[C:C+4]});
80
81        $for M in range(ROW_TILE):
82          $for C in range(0, CHANNEL_TILE, 4):
83            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_pmin(vmax, vacc${M}x${ABC[C:C+4]});
84      $else:
85        $for M in range(ROW_TILE):
86          $for C in range(0, CHANNEL_TILE, 4):
87            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_max(vacc${M}x${ABC[C:C+4]}, vmin);
88
89        $for M in range(ROW_TILE):
90          $for C in range(0, CHANNEL_TILE, 4):
91            vacc${M}x${ABC[C:C+4]} = wasm_f32x4_min(vacc${M}x${ABC[C:C+4]}, vmax);
92
93      $for M in range(ROW_TILE):
94        wasm_v128_store(o${M}, vacc${M}x${ABC[0:4]});
95        $for C in range(4, CHANNEL_TILE, 4):
96          wasm_v128_store(o${M} + ${C}, vacc${M}x${ABC[C:C+4]});
97        o${M} += ${CHANNEL_TILE};
98
99      w += ${CHANNEL_TILE * 2};
100    }
101    $if CHANNEL_TILE > 4:
102      for (; c >= 4 * sizeof(float); c -= 4 * sizeof(float)) {
103        const v128_t vscale = wasm_v128_load(w);
104
105        $for M in range(ROW_TILE):
106          v128_t vacc${M} = wasm_v128_load(i${M});
107          i${M} += 4;
108
109        const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});
110
111        $for M in range(ROW_TILE):
112          vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
113
114        $if X86:
115          $for M in range(ROW_TILE):
116            vacc${M} = wasm_f32x4_pmax(vmin, vacc${M});
117
118          $for M in range(ROW_TILE):
119            vacc${M} = wasm_f32x4_pmin(vmax, vacc${M});
120        $else:
121          $for M in range(ROW_TILE):
122            vacc${M} = wasm_f32x4_max(vacc${M}, vmin);
123
124          $for M in range(ROW_TILE):
125            vacc${M} = wasm_f32x4_min(vacc${M}, vmax);
126
127        $for M in range(ROW_TILE):
128          wasm_v128_store(o${M}, vacc${M});
129          o${M} += 4;
130
131        w += 4;
132      }
133    if XNN_UNLIKELY(c != 0) {
134      const v128_t vscale = wasm_v128_load(w);
135
136      $for M in range(ROW_TILE):
137        v128_t vacc${M} = wasm_v128_load(i${M});
138        i${M} = (const float*) ((uintptr_t) i${M} + c);
139
140      const v128_t vbias = wasm_v128_load(w + ${CHANNEL_TILE});
141
142      $for M in range(ROW_TILE):
143        vacc${M} = wasm_f32x4_add(vbias, wasm_f32x4_mul(vscale, vacc${M}));
144
145      $if X86:
146        $for M in range(ROW_TILE):
147          vacc${M} = wasm_f32x4_pmax(vmin, vacc${M});
148
149        $for M in range(ROW_TILE):
150          vacc${M} = wasm_f32x4_pmin(vmax, vacc${M});
151      $else:
152        $for M in range(ROW_TILE):
153          vacc${M} = wasm_f32x4_max(vacc${M}, vmin);
154
155        $for M in range(ROW_TILE):
156          vacc${M} = wasm_f32x4_min(vacc${M}, vmax);
157
158      if (c & (2 * sizeof(float))) {
159        $for M in range(ROW_TILE):
160          *((double*) o${M}) = wasm_f64x2_extract_lane(vacc${M}, 0);
161
162        $for M in range(ROW_TILE):
163          vacc${M} = wasm_v32x4_shuffle(vacc${M}, vacc${M}, 2, 3, 2, 3);
164
165        $for M in range(ROW_TILE):
166          o${M} += 2;
167      }
168      if (c & (1 * sizeof(float))) {
169        $for M in range(ROW_TILE):
170          *o${M}++ = wasm_f32x4_extract_lane(vacc${M}, 0);
171      }
172    }
173    $for M in range(ROW_TILE):
174      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
175      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
176    rows = doz(rows, ${ROW_TILE});
177  } while (rows != 0);
178}
179