// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert ROW_TILE >= 1
$assert ACCUMULATORS >= 1
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/dwconv.h>
#include <xnnpack/math.h>


$ARCH_SUFFIX = "_x86" if X86 else "_arm"

void xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__wasmsimd${ARCH_SUFFIX}_loadsplat_${ROW_TILE}x4${"_acc%d" % ACCUMULATORS if ACCUMULATORS > 1 else ""}(
    size_t input_height,
    size_t input_width,
    const float* input,
    const float* weights,
    const float* zero,
    float* output,
    uint32_t padding_top,
    const union xnn_f32_chw_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(input_height != 0);
  assert(input_width != 0);
  assert(input_width % sizeof(float) == 0);
  assert(padding_top >= 0);
  assert(padding_top <= 1);

  const v128_t vmask_even = wasm_v128_load(params->scalar.mask_even);
  const v128_t vmask_odd = wasm_v128_load(params->scalar.mask_odd);
  const v128_t vmax = wasm_v32x4_load_splat(&params->scalar.max);
  const v128_t vmin = wasm_v32x4_load_splat(&params->scalar.min);

  const v128_t vw0123 = wasm_v128_load(weights);
  const v128_t vw4567 = wasm_v128_load(weights + 4);
  const v128_t vw89 = wasm_v64x2_load_splat(weights + 8);
  const v128_t vbias = wasm_v32x4_shuffle(vw0123, vw0123, 0, 0, 0, 0);
  const v128_t vk00 = wasm_v32x4_shuffle(vw0123, vw0123, 1, 1, 1, 1);
  const v128_t vk01 = wasm_v32x4_shuffle(vw0123, vw0123, 2, 2, 2, 2);
  const v128_t vk02 = wasm_v32x4_shuffle(vw0123, vw0123, 3, 3, 3, 3);
  const v128_t vk10 = wasm_v32x4_shuffle(vw4567, vw4567, 0, 0, 0, 0);
  const v128_t vk11 = wasm_v32x4_shuffle(vw4567, vw4567, 1, 1, 1, 1);
  const v128_t vk12 = wasm_v32x4_shuffle(vw4567, vw4567, 2, 2, 2, 2);
  const v128_t vk20 = wasm_v32x4_shuffle(vw4567, vw4567, 3, 3, 3, 3);
  const v128_t vk21 = wasm_v32x4_shuffle(vw89, vw89, 0, 0, 0, 0);
  const v128_t vk22 = wasm_v32x4_shuffle(vw89, vw89, 1, 1, 1, 1);

  const v128_t vzero = wasm_f32x4_splat(0.0f);

  const size_t input_decrement = round_down_po2(input_width, 4 /* SIMD output width */ * 2 /* subsampling */ * sizeof(float));
  $if ROW_TILE > 1:
    const size_t output_width = round_down_po2((input_width + (2 /* padding */ - 3 /* kernel size */ + 2 /* subsampling */) * sizeof(float)) / 2, sizeof(float));

  const float* i0 = (const float*) ((uintptr_t) input - ((-padding_top) & input_width));
  const float* i1 = (const float*) ((uintptr_t) i0 + input_width);
  if XNN_UNPREDICTABLE(padding_top != 0) {
    i0 = zero;
  }
  $for M in range(2, 1 + 2 * ROW_TILE):
    const float* i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

  float* o0 = output;
  $for M in range(1, ROW_TILE):
    float* o${M} = (float*) ((uintptr_t) o${M-1} + output_width);

  size_t padded_input_height = input_height + padding_top + 1 /* padding bottom */;
  size_t output_height = (padded_input_height - 3 /* kernel size */ + 2 /* subsampling */) / 2;
  do {
    $for M in range(2, 1 + 2 * ROW_TILE):
      if XNN_UNPREDICTABLE(padded_input_height < ${2 + M}) {
        i${M} = zero;
        $if M % 2 == 1:
          o${(M - 1) // 2} = o${(M - 1) // 2 - 1};
      }

    // vi*x1357 carries the odd pixels of the previous block; start with zeroes (left padding).
    $for M in range(1 + 2 * ROW_TILE):
      v128_t vi${M}x1357 = vzero;

    size_t w = input_width;
    for (; w >= 8 * sizeof(float); w -= 8 * sizeof(float)) {
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = vbias;

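      // Load 8 successive pixels from each input row and deinterleave them into
      // even (x8ACE) and odd (x9BDF) pixels: with subsampling 2, the even pixels
      // align with the middle filter column and the odd pixels with the right column.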
      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x89AB = wasm_v128_load(i${M});
        const v128_t vi${M}xCDEF = wasm_v128_load(i${M} + 4);
        i${M} += 8;

      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x8ACE = wasm_v32x4_shuffle(vi${M}x89AB, vi${M}xCDEF, 0, 2, 4, 6);
        const v128_t vi${M}x9BDF = wasm_v32x4_shuffle(vi${M}x89AB, vi${M}xCDEF, 1, 3, 5, 7);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 1:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${2*M}x8ACE, vk01);
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${2*M}x8ACE, vk01));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 2:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${2*M+1}x8ACE, vk11);
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${2*M+1}x8ACE, vk11));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 3:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${2*M+2}x8ACE, vk21);
        $else:
          vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x8ACE, vk21));

      // Combine the last odd pixel of the previous block (x1357) with the current
      // odd pixels to form the left-neighbor pixels (x7BDF) for the left filter column.
      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x7BDF = wasm_v32x4_shuffle(vi${M}x1357, vi${M}x9BDF, 3, 4, 5, 6);
        vi${M}x1357 = vi${M}x9BDF;

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 4:
          v128_t vo${M}p4 = wasm_f32x4_mul(vi${2*M}x7BDF, vk00);
        $else:
          vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M}x7BDF, vk00));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 5:
          v128_t vo${M}p5 = wasm_f32x4_mul(vi${2*M+1}x7BDF, vk10);
        $else:
          vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+1}x7BDF, vk10));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 6:
          v128_t vo${M}p6 = wasm_f32x4_mul(vi${2*M+2}x7BDF, vk20);
        $else:
          vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x7BDF, vk20));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M}x9BDF, vk02));

      $for M in range(ROW_TILE):
        vo${M}p${9 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${9 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+1}x9BDF, vk12));

      $for M in range(ROW_TILE):
        vo${M}p${10 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${10 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x9BDF, vk22));

      $if ACCUMULATORS > 1:
        // Pairwise-reduce the partial sums into accumulator 0.
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        // On x86, f32x4.min/max lower to long instruction sequences, so clamp with compare + bitselect.
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_v128_bitselect(vmin, vo${M}p0, wasm_f32x4_lt(vo${M}p0, vmin));
        $for M in range(ROW_TILE):
          vo${M} = wasm_v128_bitselect(vo${M}, vmax, wasm_f32x4_le(vo${M}, vmax));
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      $for M in reversed(range(ROW_TILE)):
        wasm_v128_store(o${M}, vo${M}); o${M} += 4;
    }
    // Last block has 0-7 pixels to process.
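    // The tail is computed with full-width loads; vmask_even/vmask_odd zero the
    // lanes that fall past the end of the row before they enter the accumulation.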
    assert(w < 8 * sizeof(float));
    if XNN_LIKELY(w != 0) {
      $for M in range(ROW_TILE):
        v128_t vo${M}p0 = vbias;

      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x89AB = wasm_v128_load(i${M});
        const v128_t vi${M}xCDEF = wasm_v128_load(i${M} + 4);

      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x8ACE = wasm_v128_and(vmask_even, wasm_v32x4_shuffle(vi${M}x89AB, vi${M}xCDEF, 0, 2, 4, 6));
        const v128_t vi${M}x9BDF = wasm_v128_and(vmask_odd, wasm_v32x4_shuffle(vi${M}x89AB, vi${M}xCDEF, 1, 3, 5, 7));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 1:
          v128_t vo${M}p1 = wasm_f32x4_mul(vi${2*M}x8ACE, vk01);
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${2*M}x8ACE, vk01));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 2:
          v128_t vo${M}p2 = wasm_f32x4_mul(vi${2*M+1}x8ACE, vk11);
        $else:
          vo${M}p0 = wasm_f32x4_add(vo${M}p0, wasm_f32x4_mul(vi${2*M+1}x8ACE, vk11));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 3:
          v128_t vo${M}p3 = wasm_f32x4_mul(vi${2*M+2}x8ACE, vk21);
        $else:
          vo${M}p${4 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${4 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x8ACE, vk21));

      $for M in range(1 + 2 * ROW_TILE):
        const v128_t vi${M}x7BDF = wasm_v32x4_shuffle(vi${M}x1357, vi${M}x9BDF, 3, 4, 5, 6);

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 4:
          v128_t vo${M}p4 = wasm_f32x4_mul(vi${2*M}x7BDF, vk00);
        $else:
          vo${M}p${5 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${5 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M}x7BDF, vk00));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 5:
          v128_t vo${M}p5 = wasm_f32x4_mul(vi${2*M+1}x7BDF, vk10);
        $else:
          vo${M}p${6 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${6 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+1}x7BDF, vk10));

      $for M in range(ROW_TILE):
        $if ACCUMULATORS > 6:
          v128_t vo${M}p6 = wasm_f32x4_mul(vi${2*M+2}x7BDF, vk20);
        $else:
          vo${M}p${7 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${7 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x7BDF, vk20));

      $for M in range(ROW_TILE):
        vo${M}p${8 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${8 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M}x9BDF, vk02));

      $for M in range(ROW_TILE):
        vo${M}p${9 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${9 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+1}x9BDF, vk12));

      $for M in range(ROW_TILE):
        vo${M}p${10 % ACCUMULATORS} = wasm_f32x4_add(vo${M}p${10 % ACCUMULATORS}, wasm_f32x4_mul(vi${2*M+2}x9BDF, vk22));

      $if ACCUMULATORS > 1:
        $ACC_SLICE = 1
        $while ACC_SLICE < ACCUMULATORS:
          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
            $if A + ACC_SLICE < ACCUMULATORS:
              $for M in range(ROW_TILE):
                vo${M}p${A} = wasm_f32x4_add(vo${M}p${A}, vo${M}p${A + ACC_SLICE});
          $ACC_SLICE *= 2

      $if X86:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_v128_bitselect(vmin, vo${M}p0, wasm_f32x4_lt(vo${M}p0, vmin));
        $for M in range(ROW_TILE):
          vo${M} = wasm_v128_bitselect(vo${M}, vmax, wasm_f32x4_le(vo${M}, vmax));
      $else:
        $for M in range(ROW_TILE):
          v128_t vo${M} = wasm_f32x4_max(vo${M}p0, vmin);
        $for M in range(ROW_TILE):
          vo${M} = wasm_f32x4_min(vo${M}, vmax);

      w += 1 * sizeof(float);
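      // After the increment, w holds (remaining pixels + 1) * sizeof(float); its bits
      // select how many of the 4 computed outputs are stored per row: ceil(remaining / 2).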
      if (w & (8 * sizeof(float))) {
        $for M in reversed(range(ROW_TILE)):
          wasm_v128_store(o${M}, vo${M}); o${M} += 4;
      } else {
        if (w & (4 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *((double*) o${M}) = wasm_f64x2_extract_lane(vo${M}, 0); o${M} += 2;

          $for M in range(ROW_TILE):
            vo${M} = wasm_v32x4_shuffle(vo${M}, vo${M}, 2, 3, 0, 1);
        }
        if (w & (2 * sizeof(float))) {
          $for M in reversed(range(ROW_TILE)):
            *o${M} = wasm_f32x4_extract_lane(vo${M}, 0); o${M} += 1;
        }
      }
    }

    i0 = (const float*) ((uintptr_t) i${2 * ROW_TILE} - input_decrement);
    $for M in range(1, 1 + 2 * ROW_TILE):
      i${M} = (const float*) ((uintptr_t) i${M-1} + input_width);

    $if ROW_TILE > 1:
      o0 = o${ROW_TILE - 1};
      $for M in range(1, ROW_TILE):
        o${M} = (float*) ((uintptr_t) o${M-1} + output_width);

    $if ROW_TILE > 1:
      output_height = doz(output_height, ${ROW_TILE});
      padded_input_height = doz(padded_input_height, ${ROW_TILE * 2});
    $else:
      output_height -= 1;
      padded_input_height -= 2;
  } while (output_height != 0);
}