1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE % 4 == 0 7$assert BATCH_TILE >= 4 8$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 9$assert OP in ["ADD", "DIV", "RDIV", "MAX", "MIN", "MUL", "SUB", "RSUB", "SQRDIFF"] 10$assert ACTIVATION in ["LINEAR", "MINMAX", "RELU"] 11#include <assert.h> 12 13#include <wasm_simd128.h> 14 15#include <xnnpack/common.h> 16#include <xnnpack/vbinary.h> 17 18 19$WASM_F32X4_OP = { 20$ "ADD": lambda x: "wasm_f32x4_add(%s, vb)" % x, 21$ "DIV": lambda x: "wasm_f32x4_div(%s, vb)" % x, 22$ "RDIV": lambda x: "wasm_f32x4_div(vb, %s)" % x, 23$ "MAX": lambda x: ("wasm_f32x4_pmax(vb, %s)" if X86 else "wasm_f32x4_max(%s, vb)") % x, 24$ "MIN": lambda x: ("wasm_f32x4_pmin(vb, %s)" if X86 else "wasm_f32x4_min(%s, vb)") % x, 25$ "MUL": lambda x: "wasm_f32x4_mul(%s, vb)" % x, 26$ "SUB": lambda x: "wasm_f32x4_sub(%s, vb)" % x, 27$ "RSUB": lambda x: "wasm_f32x4_sub(vb, %s)" % x, 28$ "SQRDIFF": lambda x: "wasm_f32x4_sub(%s, vb)" % x, 29$}[OP] 30$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"] 31$ARCH_SUFFIX = "" if ACTIVATION in ["LINEAR", "RELU"] and OP not in ["MIN", "MAX"] else "_x86" if X86 else "_arm" 32$ACTIVATION_SUFFIX = {"LINEAR": ""}.get(ACTIVATION, "_" + ACTIVATION.lower()) 33$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION] 34void xnn_f32_v${OP.lower()}c${ACTIVATION_SUFFIX}_ukernel__wasmsimd${ARCH_SUFFIX}_x${BATCH_TILE}( 35 size_t n, 36 const float* a, 37 const float* b, 38 float* y, 39 const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 40{ 41 assert(n != 0); 42 assert(n % sizeof(float) == 0); 43 assert(a != NULL); 44 assert(b != NULL); 45 assert(y != NULL); 46 47 $if ACTIVATION == "MINMAX": 48 const v128_t vy_min = wasm_v128_load64_splat(params->wasmsimd.min); 49 const v128_t vy_max = wasm_v128_load64_splat(params->wasmsimd.max); 50 $elif ACTIVATION == "RELU": 51 const v128_t vzero = wasm_i32x4_const_splat(0); 52 const v128_t vb = wasm_v128_load32_splat(b); 53 for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) { 54 const v128_t va${ABC[0:4]} = wasm_v128_load(a); 55 $for N in range(4, BATCH_TILE, 4): 56 const v128_t va${ABC[N:N+4]} = wasm_v128_load(a + ${N}); 57 a += ${BATCH_TILE}; 58 59 $for N in range(0, BATCH_TILE, 4): 60 v128_t vy${ABC[N:N+4]} = ${WASM_F32X4_OP("va" + ABC[N:N+4])}; 61 62 $if OP == "SQRDIFF": 63 $for N in range(0, BATCH_TILE, 4): 64 vy${ABC[N:N+4]} = wasm_f32x4_mul(vy${ABC[N:N+4]}, vy${ABC[N:N+4]}); 65 66 $if ACTIVATION == "MINMAX": 67 $if X86: 68 $for N in range(0, BATCH_TILE, 4): 69 vy${ABC[N:N+4]} = wasm_f32x4_pmax(vy_min, vy${ABC[N:N+4]}); 70 71 $for N in range(0, BATCH_TILE, 4): 72 vy${ABC[N:N+4]} = wasm_f32x4_pmin(vy_max, vy${ABC[N:N+4]}); 73 $else: 74 $for N in range(0, BATCH_TILE, 4): 75 vy${ABC[N:N+4]} = wasm_f32x4_max(vy${ABC[N:N+4]}, vy_min); 76 77 $for N in range(0, BATCH_TILE, 4): 78 vy${ABC[N:N+4]} = wasm_f32x4_min(vy${ABC[N:N+4]}, vy_max); 79 $elif ACTIVATION == "RELU": 80 $for N in range(0, BATCH_TILE, 4): 81 vy${ABC[N:N+4]} = wasm_i32x4_max(vy${ABC[N:N+4]}, vzero); 82 83 wasm_v128_store(y, vy${ABC[0:4]}); 84 $for N in range(4, BATCH_TILE, 4): 85 wasm_v128_store(y + ${N}, vy${ABC[N:N+4]}); 86 y += ${BATCH_TILE}; 87 } 88 $if BATCH_TILE > 4: 89 for (; n >= 4 * sizeof(float); n -= 4 * sizeof(float)) { 90 const v128_t va = wasm_v128_load(a); 91 a += 4; 92 93 v128_t vy = ${WASM_F32X4_OP("va")}; 94 $if OP == "SQRDIFF": 95 vy = wasm_f32x4_mul(vy, vy); 96 97 $if ACTIVATION == "MINMAX": 98 $if X86: 99 vy = wasm_f32x4_pmax(vy_min, vy); 100 vy = wasm_f32x4_pmin(vy_max, vy); 101 $else: 102 vy = wasm_f32x4_max(vy, vy_min); 103 vy = wasm_f32x4_min(vy, vy_max); 104 $elif ACTIVATION == "RELU": 105 vy = wasm_i32x4_max(vy, vzero); 106 107 wasm_v128_store(y, vy); 108 y += 4; 109 } 110 if XNN_UNLIKELY(n != 0) { 111 const v128_t va = wasm_v128_load(a); 112 113 v128_t vy = ${WASM_F32X4_OP("va")}; 114 $if OP == "SQRDIFF": 115 vy = wasm_f32x4_mul(vy, vy); 116 117 $if ACTIVATION == "MINMAX": 118 $if X86: 119 vy = wasm_f32x4_pmax(vy_min, vy); 120 vy = wasm_f32x4_pmin(vy_max, vy); 121 $else: 122 vy = wasm_f32x4_max(vy, vy_min); 123 vy = wasm_f32x4_min(vy, vy_max); 124 $elif ACTIVATION == "RELU": 125 vy = wasm_i32x4_max(vy, vzero); 126 127 if (n & (2 * sizeof(float))) { 128 *((double*) y) = wasm_f64x2_extract_lane(vy, 0); 129 vy = wasm_v32x4_shuffle(vy, vy, 2, 3, 2, 3); 130 y += 2; 131 } 132 if (n & (1 * sizeof(float))) { 133 *y = wasm_f32x4_extract_lane(vy, 0); 134 } 135 } 136} 137