// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMN"
$assert OP in ["ADD", "DIV", "MAX", "MIN", "MUL", "SUB"]
#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/vbinary.h>


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
$OP_FUNC = {
$  "ADD": lambda x, y: "%s + %s" % (x, y),
$  "DIV": lambda x, y: "%s / %s" % (x, y),
$  "MAX": lambda x, y: "%s(%s, %s)" % (MAX_F32, x, y),
$  "MIN": lambda x, y: "%s(%s, %s)" % (MIN_F32, x, y),
$  "MUL": lambda x, y: "%s * %s" % (x, y),
$  "SUB": lambda x, y: "%s - %s" % (x, y),
$}[OP]
// Elementwise f32 binary operation followed by output clamping.
//
// For every element i this computes
//   y[i] = min(max(op(a[i], b[i]), params->scalar.min), params->scalar.max)
// where op is the arithmetic (+, -, *, /) or min/max operation selected by the
// OP template parameter.
//
// Parameters:
//   n      - length of each array in BYTES, not elements; must be non-zero and
//            a multiple of sizeof(float) (both checked by assert).
//   a, b   - input arrays, each holding n / sizeof(float) floats.
//   y      - output array, receives n / sizeof(float) floats.
//   params - clamping bounds; only scalar.min and scalar.max are read.
//
// Generation notes: the main loop is unrolled BATCH_TILE times and leftover
// elements are processed one at a time. BATCH_TILE may be at most 24, the
// number of suffix characters available in ABC for naming the unrolled
// temporaries. When WASM is set, min/max use the WebAssembly builtins (and the
// kernel name gets the "wasm" suffix); otherwise math_min_f32 / math_max_f32.
void xnn_f32_v${OP.lower()}_ukernel__${"wasm" if WASM else "scalar"}_x${BATCH_TILE}(
    size_t n,
    const float* a,
    const float* b,
    float* y,
    const union xnn_f32_output_params params[restrict static 1])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

  // Clamp bounds are loaded once, outside all loops.
  const float vy_min = params->scalar.min;
  const float vy_max = params->scalar.max;

  $if BATCH_TILE > 1:
    // Main loop: process one full tile of BATCH_TILE elements per iteration.
    for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
      // Load the tile from both inputs before any output is written.
      $for N in range(BATCH_TILE):
        const float va${ABC[N]} = a[${N}];
      a += ${BATCH_TILE};

      $for N in range(BATCH_TILE):
        const float vb${ABC[N]} = b[${N}];
      b += ${BATCH_TILE};

      // Apply the binary operation.
      $for N in range(BATCH_TILE):
        float vy${ABC[N]} = ${OP_FUNC("va" + ABC[N], "vb" + ABC[N])};

      // Clamp: raise to the lower bound first, then cap at the upper bound.
      $for N in range(BATCH_TILE):
        vy${ABC[N]} = ${MAX_F32}(vy${ABC[N]}, vy_min);

      $for N in range(BATCH_TILE):
        vy${ABC[N]} = ${MIN_F32}(vy${ABC[N]}, vy_max);

      $for N in range(BATCH_TILE):
        y[${N}] = vy${ABC[N]};
      y += ${BATCH_TILE};
    }
    // Fewer than one full tile of elements remains.
    if XNN_UNLIKELY(n != 0) {
      $if BATCH_TILE > 2:
        // Process the leftover elements one at a time.
        do {
          const float va = *a++;
          const float vb = *b++;
          float vy = ${OP_FUNC("va", "vb")};
          vy = ${MAX_F32}(vy, vy_min);
          vy = ${MIN_F32}(vy, vy_max);
          *y++ = vy;
          n -= sizeof(float);
        } while (n != 0);
      $else:
        // With a tile of 2 and n a multiple of sizeof(float), exactly one
        // element can remain here, so no loop or pointer advance is needed.
        const float va = *a;
        const float vb = *b;
        float vy = ${OP_FUNC("va", "vb")};
        vy = ${MAX_F32}(vy, vy_min);
        vy = ${MIN_F32}(vy, vy_max);
        *y = vy;
    }
  $else:
    // No unrolling: process one element per iteration.
    for (; n >= sizeof(float); n -= sizeof(float)) {
      const float va = *a++;
      const float vb = *b++;
      float vy = ${OP_FUNC("va", "vb")};
      vy = ${MAX_F32}(vy, vy_min);
      vy = ${MIN_F32}(vy, vy_max);
      *y++ = vy;
    }
}