1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE % 4 == 0 7$assert BATCH_TILE >= 4 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB", "SQRDIFF"] 10$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"] 11#include <assert.h> 12 13#include <wasm_simd128.h> 14 15#include <xnnpack/common.h> 16#include <xnnpack/vbinary.h> 17 18 19$WASM_F32X4_OP = { 20$ "ADD": "wasm_f32x4_add", 21$ "DIV": "wasm_f32x4_div", 22$ "MAX": "wasm_f32x4_pmax" if X86 else "wasm_f32x4_max", 23$ "MIN": "wasm_f32x4_pmin" if X86 else "wasm_f32x4_min", 24$ "MUL": "wasm_f32x4_mul", 25$ "SUB": "wasm_f32x4_sub", 26$ "SQRDIFF": "wasm_f32x4_sub", 27$}[OP] 28$ARCH_SUFFIX = "" if ACTIVATION in ["LINEAR", "RELU"] and OP not in ["MIN", "MAX"] else "_x86" if X86 else "_arm" 29$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower()) 30$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION] 31void xnn_f32_v${OP.lower()}${ACTIVATION_SUFFIX}_ukernel__wasmsimd${ARCH_SUFFIX}_x${BATCH_TILE}( 32 size_t n, 33 const float* a, 34 const float* b, 35 float* y, 36 const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 37{ 38 assert(n != 0); 39 assert(n % sizeof(float) == 0); 40 assert(a != NULL); 41 assert(b != NULL); 42 assert(y != NULL); 43 44 $if ACTIVATION == "MINMAX": 45 const v128_t vy_min = wasm_v128_load64_splat(params->wasmsimd.min); 46 const v128_t vy_max = wasm_v128_load64_splat(params->wasmsimd.max); 47 $elif ACTIVATION == "RELU": 48 const v128_t vzero = wasm_i32x4_const_splat(0); 49 50 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 51 const v128_t va${ABC[0:4]} = wasm_v128_load(a); 52 $for N in range(4, BATCH_TILE, 4): 53 const v128_t va${ABC[N:N+4]} = wasm_v128_load(a + ${N}); 54 a += ${BATCH_TILE}; 55 56 const v128_t vb${ABC[0:4]} = wasm_v128_load(b); 57 $for N in range(4, BATCH_TILE, 4): 58 const v128_t vb${ABC[N:N+4]} = wasm_v128_load(b + ${N}); 59 b += ${BATCH_TILE}; 60 61 $if OP == "MIN" and X86: 62 $for N in range(0, BATCH_TILE, 4): 63 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_lt(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 64 65 $for N in range(0, BATCH_TILE, 4): 66 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(va${ABC[N:N+4]}, vb${ABC[N:N+4]}, vm${ABC[N:N+4]}); 67 $elif OP == "MAX" and X86: 68 $for N in range(0, BATCH_TILE, 4): 69 const v128_t vm${ABC[N:N+4]} = wasm_f32x4_le(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 70 71 $for N in range(0, BATCH_TILE, 4): 72 v128_t vy${ABC[N:N+4]} = wasm_v128_bitselect(vb${ABC[N:N+4]}, va${ABC[N:N+4]}, vm${ABC[N:N+4]}); 73 $else: 74 $for N in range(0, BATCH_TILE, 4): 75 v128_t vy${ABC[N:N+4]} = ${WASM_F32X4_OP}(va${ABC[N:N+4]}, vb${ABC[N:N+4]}); 76 77 $if OP == "SQRDIFF": 78 $for N in range(0, BATCH_TILE, 4): 79 vy${ABC[N:N+4]} = wasm_f32x4_mul(vy${ABC[N:N+4]}, vy${ABC[N:N+4]}); 80 81 $if ACTIVATION == "MINMAX": 82 $if X86: 83 $for N in range(0, BATCH_TILE, 4): 84 vy${ABC[N:N+4]} = wasm_f32x4_pmax(vy_min, vy${ABC[N:N+4]}); 85 86 $for N in range(0, BATCH_TILE, 4): 87 vy${ABC[N:N+4]} = wasm_f32x4_pmin(vy_max, vy${ABC[N:N+4]}); 88 $else: 89 $for N in range(0, BATCH_TILE, 4): 90 vy${ABC[N:N+4]} = wasm_f32x4_max(vy${ABC[N:N+4]}, vy_min); 91 92 $for N in range(0, BATCH_TILE, 4): 93 vy${ABC[N:N+4]} = wasm_f32x4_min(vy${ABC[N:N+4]}, vy_max); 94 $elif ACTIVATION == "RELU": 95 $for N in range(0, BATCH_TILE, 4): 96 vy${ABC[N:N+4]} = wasm_i32x4_max(vy${ABC[N:N+4]}, vzero); 97 98 wasm_v128_store(y, vy${ABC[0:4]}); 99 $for N in range(4, BATCH_TILE, 4): 100 wasm_v128_store(y + ${N}, vy${ABC[N:N+4]}); 101 y += ${BATCH_TILE}; 102 } 103 $if BATCH_TILE > 4: 104 for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) { 105 const v128_t va = wasm_v128_load(a); 106 a += 4; 107 108 const v128_t vb = wasm_v128_load(b); 109 b += 4; 110 111 $if OP == "MIN" and X86: 112 const v128_t vm = wasm_f32x4_lt(va, vb); 113 v128_t vy = wasm_v128_bitselect(va, vb, vm); 114 $elif OP == "MAX" and X86: 115 const v128_t vm = wasm_f32x4_le(va, vb); 116 v128_t vy = wasm_v128_bitselect(vb, va, vm); 117 $else: 118 v128_t vy = ${WASM_F32X4_OP}(va, vb); 119 $if OP == "SQRDIFF": 120 vy = wasm_f32x4_mul(vy, vy); 121 122 $if ACTIVATION == "MINMAX": 123 $if X86: 124 vy = wasm_f32x4_pmax(vy_min, vy); 125 vy = wasm_f32x4_pmin(vy_max, vy); 126 $else: 127 vy = wasm_f32x4_max(vy, vy_min); 128 vy = wasm_f32x4_min(vy, vy_max); 129 $elif ACTIVATION == "RELU": 130 vy = wasm_i32x4_max(vy, vzero); 131 132 wasm_v128_store(y, vy); 133 y += 4; 134 } 135 if XNN_UNLIKELY(n != 0) { 136 const v128_t va = wasm_v128_load(a); 137 const v128_t vb = wasm_v128_load(b); 138 139 $if OP == "MIN" and X86: 140 const v128_t vm = wasm_f32x4_lt(va, vb); 141 v128_t vy = wasm_v128_bitselect(va, vb, vm); 142 $elif OP == "MAX" and X86: 143 const v128_t vm = wasm_f32x4_le(va, vb); 144 v128_t vy = wasm_v128_bitselect(vb, va, vm); 145 $else: 146 v128_t vy = ${WASM_F32X4_OP}(va, vb); 147 $if OP == "SQRDIFF": 148 vy = wasm_f32x4_mul(vy, vy); 149 150 $if ACTIVATION == "MINMAX": 151 $if X86: 152 vy = wasm_f32x4_pmax(vy_min, vy); 153 vy = wasm_f32x4_pmin(vy_max, vy); 154 $else: 155 vy = wasm_f32x4_max(vy, vy_min); 156 vy = wasm_f32x4_min(vy, vy_max); 157 $elif ACTIVATION == "RELU": 158 vy = wasm_i32x4_max(vy, vzero); 159 160 if (n & (2 * sizeof(float))) { 161 *((double*) y) = wasm_f64x2_extract_lane(vy, 0); 162 vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3); 163 y += 2; 164 } 165 if (n & (1 * sizeof(float))) { 166 *y = wasm_f32x4_extract_lane(vy, 0); 167 } 168 } 169} 170