// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE >= 1
$assert ROW_TILE >= 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <xnnpack/math.h>
#include <xnnpack/vmulcaddc.h>


$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
// Per-channel multiply-add followed by clamping:
//
//   output[r][c] = min(max(input[r][c] * scale[c] + bias[c], params->scalar.min),
//                      params->scalar.max)
//
// rows           number of rows to process; must be non-zero.
// channels       length of one row in BYTES; a non-zero multiple of sizeof(float).
// input          first input row.
// input_stride   distance in bytes between the starts of consecutive input rows.
// weights        per-channel scales and biases, shared by all rows and packed in
//                groups of ${CHANNEL_TILE} channel(s): ${CHANNEL_TILE} scale(s) followed by
//                ${CHANNEL_TILE} bias(es).
// output         first output row.
// output_stride  distance in bytes between the starts of consecutive output rows.
// params         clamping bounds, read from params->scalar.min and params->scalar.max.
//
// Rows are processed ${ROW_TILE} at a time and channels ${CHANNEL_TILE} at a time.
void xnn_f32_vmulcaddc_minmax_ukernel_c${CHANNEL_TILE}__${"wasm" if WASM else "scalar"}_${ROW_TILE}x(
    size_t rows,
    size_t channels,
    const float*restrict input,
    size_t input_stride,
    const float*restrict weights,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(rows != 0);
  assert(channels != 0);
  assert(channels % sizeof(float) == 0);

  // Once a row has been traversed (its pointer advanced by `channels` bytes),
  // these increments move it to the same row of the next ${ROW_TILE}-row tile.
  const size_t input_increment = input_stride * ${ROW_TILE} - channels;
  const size_t output_increment = output_stride * ${ROW_TILE} - channels;

  const float* i0 = input;
  float* o0 = output;
  // One input/output pointer per row of the tile. When fewer than ${ROW_TILE} rows
  // are available, surplus pointers alias the previous row's pointer: that row is
  // simply recomputed and its output rewritten with identical values.
  // Row k is redirected when `rows < k + 1` (odd k) or `rows <= k` (even k); the
  // two forms are equivalent, and each pair of rows compares against one constant.
  // NOTE(review): presumably so the compiler can share one comparison; confirm.
  $for M in range(1, ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_stride);
    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = i${M-1};
        o${M} = o${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = i${M-1};
        o${M} = o${M-1};
      }

  const float vmin = params->scalar.min;
  const float vmax = params->scalar.max;
  do {
    // Every tile of rows starts again from the first channel's weights.
    const float* w = weights;
    size_t c = channels;
    $if CHANNEL_TILE > 1:
      // Full groups of ${CHANNEL_TILE} channels: w[0..${CHANNEL_TILE - 1}] hold the group's scales
      // and w[${CHANNEL_TILE}..${2 * CHANNEL_TILE - 1}] its biases.
      for (; c >= ${CHANNEL_TILE} * sizeof(float); c -= ${CHANNEL_TILE} * sizeof(float)) {
        $for C in range(CHANNEL_TILE):
          const float vscale${ABC[C]} = w[${C}];

        $for M in range(ROW_TILE):
          $for C in range(CHANNEL_TILE):
            float vacc${M}x${ABC[C]} = i${M}[${C}];
          i${M} += ${CHANNEL_TILE};

        $for C in range(CHANNEL_TILE):
          const float vbias${ABC[C]} = w[${C + CHANNEL_TILE}];

        $for M in range(ROW_TILE):
          $for C in range(CHANNEL_TILE):
            vacc${M}x${ABC[C]} = vacc${M}x${ABC[C]} * vscale${ABC[C]} + vbias${ABC[C]};

        // Clamp: apply the lower bound first, then the upper bound.
        $for M in range(ROW_TILE):
          $for C in range(CHANNEL_TILE):
            vacc${M}x${ABC[C]} = ${MAX_F32}(vacc${M}x${ABC[C]}, vmin);

        $for M in range(ROW_TILE):
          $for C in range(CHANNEL_TILE):
            vacc${M}x${ABC[C]} = ${MIN_F32}(vacc${M}x${ABC[C]}, vmax);

        $for M in range(ROW_TILE):
          $for C in range(CHANNEL_TILE):
            o${M}[${C}] = vacc${M}x${ABC[C]};
          o${M} += ${CHANNEL_TILE};

        w += ${CHANNEL_TILE * 2};
      }
      // Remaining channels (fewer than ${CHANNEL_TILE}), one per iteration. The trailing
      // group keeps the full-group layout, so each bias lies ${CHANNEL_TILE} floats after
      // its scale; w has already been advanced past the scale, hence the
      // w[${CHANNEL_TILE - 1}] offset.
      // NOTE(review): assumes the weight packing reserves a full ${CHANNEL_TILE}-wide scale
      // slot in the last group; confirm against the packing routine.
      if XNN_UNLIKELY(c != 0) {
        do {
          const float vscale = *w++;

          $for M in range(ROW_TILE):
            float vacc${M} = *i${M}++;

          const float vbias = w[${CHANNEL_TILE - 1}];

          $for M in range(ROW_TILE):
            vacc${M} = vacc${M} * vscale + vbias;

          $for M in range(ROW_TILE):
            vacc${M} = ${MAX_F32}(vacc${M}, vmin);

          $for M in range(ROW_TILE):
            vacc${M} = ${MIN_F32}(vacc${M}, vmax);

          $for M in range(ROW_TILE):
            *o${M}++ = vacc${M};

          c -= sizeof(float);
        } while (c != 0);
      }
    $else:
      // One channel per iteration; weights are interleaved as (scale, bias) pairs.
      do {
        const float vscale = w[0];

        $for M in range(ROW_TILE):
          float vacc${M} = *i${M}++;

        const float vbias = w[1];

        $for M in range(ROW_TILE):
          vacc${M} = vacc${M} * vscale + vbias;

        $for M in range(ROW_TILE):
          vacc${M} = ${MAX_F32}(vacc${M}, vmin);

        $for M in range(ROW_TILE):
          vacc${M} = ${MIN_F32}(vacc${M}, vmax);

        $for M in range(ROW_TILE):
          *o${M}++ = vacc${M};

        w += 2;
        c -= sizeof(float);
      } while (c != 0);
    // Move every row pointer to the same row of the next tile, then re-alias the
    // pointers of rows that will not exist there. This uses the same paired
    // `<` / `<=` comparisons as the pointer setup before the loop, offset by
    // ${ROW_TILE} because `rows` has not been decremented yet.
    $for M in range(ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M} + input_increment);
      o${M} = (float*) ((uintptr_t) o${M} + output_increment);
      $if M % 2 == 1:
        if XNN_UNPREDICTABLE(rows < ${ROW_TILE + M + 1}) {
          i${M} = i${M-1};
          o${M} = o${M-1};
        }
      $elif M != 0:
        if XNN_UNPREDICTABLE(rows <= ${ROW_TILE + M}) {
          i${M} = i${M-1};
          o${M} = o${M-1};
        }
    // NOTE(review): doz() presumably returns max(rows - ${ROW_TILE}, 0) (difference or
    // zero) from xnnpack/math.h; confirm. The loop ends once every row is processed.
    rows = doz(rows, ${ROW_TILE});
  } while (rows != 0);
}