// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
$assert DATATYPE == "QC8" or REQUANTIZATION == "RNDNU"
// Parameter loading and requantization in this kernel are implemented only for
// QC8 with FP32 and for QS8/QU8 with RNDNU; other combinations are rejected above.

#include <xnnpack/assembly.h>

.syntax unified

$PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
$ISA = "neonv8" if ARMV8 else "neon"
$CPU = "a35" if ARMV8 else "a7"
$XMIN = "VMIN.U8" if DATATYPE == "QU8" else "VMIN.S8"
$XMAX = "VMAX.U8" if DATATYPE == "QU8" else "VMAX.S8"
$XXTL = "VMOVL.U8" if DATATYPE == "QU8" else "VMOVL.S8"
$SQXTXN = "VQMOVUN.S16" if DATATYPE == "QU8" else "VQMOVN.S16"
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
// void xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_1x8__aarch32_${ISA}_mlal_lane${"_prfm" if PREFETCH else ""}_cortex_${CPU}
//     size_t mr,                                     (r0)
//     size_t nc,                                      r1
//     size_t kc,                                     (r2) -> sp + 56 -> r5
//     size_t ks,                                     (r3) -> sp + 60 -> r14
//     const ${XINT8_T}**restrict a,            sp + 88  -> r2
//     const void*restrict w,                   sp + 92  -> r9
//     ${XINT8_T}*restrict c,                   sp + 96  -> r11
//     size_t cm_stride,                        sp + 100 -> r6
//     size_t cn_stride,                        sp + 104 -> r12
//     size_t a_offset,                         sp + 108 -> (r5)
//     const ${XINT8_T}* zero,                  sp + 112 -> r7
//     ${PARAMS_UNION}*params);                 sp + 116 -> (r5)
//
// All sp offsets are measured after the 88-byte register push in the prologue.
// A register in parentheses is either an incoming argument register that is not
// kept live (r0-r3) or the scratch register r5, into which the value is loaded
// only when needed.

// d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
// Based on cortex_a53 microkernel but with Neon loads

// Register usage
// A0   r3  d0-d1 q0

// B    r9  d8-d9 q4 q5

// C0  r11 d16-d17  q8  d18-d19  q9
//         q2, q3 acc2

// Params  d12-d13 (requantization params), d14 (QU8 kernel_zero_point only)
// Clamp   d24 (output_min), d25 (output_max)
// q1 is written only by the QC8 per-channel scale load. q0, q2 and q3 are
// reused as scratch during requantization, after q2/q3 are folded into q8/q9.

// Unused r4, r8, r10, d15, q10-q11, q13-q15
// (r0 = mr is ignored; r6 = cm_stride is loaded but never read since MR = 1)

$if REQUANTIZATION == "RNDNU" and DATATYPE != "QU8":
  // params structure is 16 bytes
  //  struct {
  //    int32_t right_pre_shift;    d12[0]
  //    int32_t multiplier;         d12[1]
  //    int32_t right_post_shift;   d13[0]
  //    int16_t output_zero_point;  d13[2]
  //    int8_t output_min;          d13[6]
  //    int8_t output_max;          d13[7]
  //  } rndnu_neon;
$elif REQUANTIZATION == "RNDNU" and DATATYPE == "QU8":
  // params structure is 20 bytes
  //  struct {
  //    uint8_t kernel_zero_point[4];  d14
  //    int32_t right_pre_shift;       d12[0]
  //    int32_t multiplier;            d12[1]
  //    int32_t right_post_shift;      d13[0]
  //    int16_t output_zero_point;     d13[2]
  //    uint8_t output_min;            d13[6]
  //    uint8_t output_max;            d13[7]
  //  } rndnu_neon;
$elif DATATYPE == "QC8" and not ARMV8:
  // params structure is 10 bytes
  // struct {
  //   float magic_bias;                           d12[0]
  //   int32_t magic_bias_less_output_zero_point;  d12[1]
  //   int8_t output_min;                          d13[6]
  //   int8_t output_max;                          d13[7]
  // } xnn_qs8_minmax_params.neon;
$else:
  // params structure is 4 bytes
  // struct {
  //   int16_t output_zero_point;  d13[2]
  //   int8_t output_min;          d13[6]
  //   int8_t output_max;          d13[7]
  // } xnn_qs8_minmax_params.neonv8;

BEGIN_FUNCTION xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_1x8__aarch32_${ISA}_mlal_lane${"_prfm" if PREFETCH else ""}_cortex_${CPU}
        # Push 88 bytes
        # r2, r3 will be reloaded in outer loop.
        // The pushed r2 (kc) and r3 (ks) slots are re-read from [sp, 56] and [sp, 60].
        PUSH        {r2, r3, r5, r6, r7, r9, r11, lr}     // +32
        $if DATATYPE == "QU8":
          VPUSH       {d8-d14}                            // +56 = 88
        $else:
          SUB         sp, sp, 8                           // +8 (pad so the total matches the QU8 layout)
          VPUSH       {d8-d13}                            // +48 = 88

        LDR         r2, [sp, 88]            // a
        LDR         r9, [sp, 92]            // w
        LDR         r11, [sp, 96]           // c
        LDR         r6, [sp, 100]           // cm_stride
        LDR         r12, [sp, 104]          // cn_stride
        LDR         r7, [sp, 112]           // zero
        LDR         r5, [sp, 116]           // params
        MOV         r14, r3                 // p = ks

        # Load params values
        $if DATATYPE == "QU8":
          VLD1.32     {d14[]}, [r5]!          // QU8 kernel_zero_point, replicated into both halves of d14
        $if REQUANTIZATION == "RNDNU":
          VLDM        r5, {d12-d13}           // RNDNU params
        $elif DATATYPE == "QC8" and ARMV8:
          VLD1.32     {d13[]}, [r5]           // QC8 neonv8 params (zero point, min, max) replicated into d13
        $elif DATATYPE == "QC8" and not ARMV8:
          VLDM        r5!, {d12}              // QC8 neon params: magic_bias, magic_bias_less_output_zero_point
          VLD1.16     {d13[]}, [r5]           // output_min/output_max pair replicated into every 16-bit lane

        $if PREFETCH:
          PLD         [r9,  64]               // Prefetch B
          PLD         [r9, 112]
          PLD         [r9, 192]
          PLD         [r9, 256]
          PLD         [r9, 320]
          PLD         [r9, 384]

        // Outer loop over 8-column tiles of the output.
        .p2align    3
0:
        # Load initial bias from w into accumulators
        VLDM        r9!, {d16-d19}          // Bias
        VMOV.I32    q2, 0                   // second set of C for pipelining VMLAL
        VMOV.I32    q3, 0

        // ks loop: one iteration per indirection pointer in a.
        .p2align    3
1:
        # Load next A pointer
        LDR         r3, [r2,  0]

        # Add a_offset
        LDR         r5, [sp, 108]           // a_offset
        ADD         r2, r2, 4
        CMP         r3,  r7                 // if a0 == zero
        ADD         r3,  r3, r5             // a0 += a_offset
        MOVEQ       r3,  r7                 //   a0 = zero, else += a0 + a_offset

        LDR         r5, [sp, 56]            // kc
        SUBS        r5, r5, 8               // kc - 8
        BLO         5f                      // less than 8 channels?

        // Prologue - load A0 and B0
        VLD1.8      {d0},  [r3]!            // A0
        SUBS        r5, r5, 8               // k = k - 8
        VLD1.8      {d8},  [r9]!            // B0
        BLO         3f                      // less than 16 channels? (only one block of 8) -> epilogue

        // Main loop - 8 bytes
        // 64 bytes for weights.
        // Each BLOCK n multiplies one 8-byte row of B (8 output columns) by
        // lane n of A. Even blocks accumulate into q8/q9 and odd blocks into
        // q2/q3; the next B row is loaded one block ahead of its use.

        .p2align    3
2:
        // Extend
        ${XXTL}    q0, d0
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14             // B - kernel_zero_point, widened to 16 bits
        $else:
          VMOVL.S8    q4, d8
        $if PREFETCH:
          PLD         [r9, 448]

        // BLOCK 0
        VLD1.8      {d10},  [r9]!           // B1
        VMLAL.S16   q8, d8, d0[0]
        VMLAL.S16   q9, d9, d0[0]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 1
        VLD1.8      {d8},  [r9]!            // B2
        VMLAL.S16   q2, d10, d0[1]
        VMLAL.S16   q3, d11, d0[1]
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 2
        VLD1.8      {d10},  [r9]!           // B3
        VMLAL.S16   q8, d8, d0[2]
        VMLAL.S16   q9, d9, d0[2]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 3
        VLD1.8      {d8},  [r9]!            // B4
        VMLAL.S16   q2, d10, d0[3]
        VMLAL.S16   q3, d11, d0[3]
        VLD1.8      {d0},  [r3]!            // A0 (next 8 bytes; d0 lanes are no longer needed, blocks 4-7 use d1)
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 4
        VLD1.8      {d10},  [r9]!           // B5
        VMLAL.S16   q8, d8, d1[0]
        VMLAL.S16   q9, d9, d1[0]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 5
        VLD1.8      {d8},  [r9]!            // B6
        VMLAL.S16   q2, d10, d1[1]
        VMLAL.S16   q3, d11, d1[1]
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 6
        VLD1.8      {d10},  [r9]!           // B7
        VMLAL.S16   q8, d8, d1[2]
        VMLAL.S16   q9, d9, d1[2]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10
        SUBS        r5, r5, 8

        // BLOCK 7
        VLD1.8      {d8},  [r9]!            // B0
        VMLAL.S16   q2, d10, d1[3]
        VMLAL.S16   q3, d11, d1[3]
        BHS         2b

        // Epilogue
        // Same as one main-loop iteration, but without loading further A/B data
        // after the last B row.

        .p2align    3
3:
        // Extend
        ${XXTL}    q0, d0
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        $if PREFETCH:
          PLD         [r9, 448]

        // BLOCK 0
        VLD1.8      {d10},  [r9]!           // B1
        VMLAL.S16   q8, d8, d0[0]
        VMLAL.S16   q9, d9, d0[0]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 1
        VLD1.8      {d8},  [r9]!            // B2
        VMLAL.S16   q2, d10, d0[1]
        VMLAL.S16   q3, d11, d0[1]
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 2
        VLD1.8      {d10},  [r9]!           // B3
        VMLAL.S16   q8, d8, d0[2]
        VMLAL.S16   q9, d9, d0[2]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 3
        VLD1.8      {d8},  [r9]!            // B4
        VMLAL.S16   q2, d10, d0[3]
        VMLAL.S16   q3, d11, d0[3]
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 4
        VLD1.8      {d10},  [r9]!           // B5
        VMLAL.S16   q8, d8, d1[0]
        VMLAL.S16   q9, d9, d1[0]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10

        // BLOCK 5
        VLD1.8      {d8},  [r9]!            // B6
        VMLAL.S16   q2, d10, d1[1]
        VMLAL.S16   q3, d11, d1[1]
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8

        // BLOCK 6
        VLD1.8      {d10},  [r9]!           // B7
        VMLAL.S16   q8, d8, d1[2]
        VMLAL.S16   q9, d9, d1[2]
        $if DATATYPE == "QU8":
          VSUBL.U8    q5, d10, d14
        $else:
          VMOVL.S8    q5, d10
        ADDS        r5, r5, 8               // r5 = remaining channels, 0 to 7

        VMLAL.S16   q2, d10, d1[3]
        VMLAL.S16   q3, d11, d1[3]

        # Is there a remainder?- 1-7 bytes of A
        BNE         6f

4:
        # ks loop
        SUBS        r14, r14, 4             // ks -= MR * sizeof(void*)
        BHI         1b

        LDR         r14, [sp, 60]           // p = ks

        // Fold the second accumulator set into q8/q9.
        VADD.S32    q8, q8, q2
        VADD.S32    q9, q9, q3

        $if REQUANTIZATION == "RNDNU":
          # RNDNU quantization
          VDUP.32     q0, d12[0]              // right_pre_shift

          VQSHL.S32   q8, q8, q0
          VQSHL.S32   q9, q9, q0

          VDUP.32     q2, d13[0]              // right_post_shift

          VQDMULH.S32 q8, q8, d12[1]     // multiplier
          VQDMULH.S32 q9, q9, d12[1]

          VRSHL.S32   q8, q8, q2
          VRSHL.S32   q9, q9, q2
        $elif DATATYPE == "QC8" and ARMV8:
          # QC8 FP32 quantization
          VLD1.8      {q0-q1},  [r9]!         // 8 float scales, one per output column, read from w

          VCVT.F32.S32 q8,  q8
          VCVT.F32.S32 q9,  q9

          VMUL.F32    q8,  q8, q0             // multiplier
          VMUL.F32    q9,  q9, q1

          VCVTN.S32.F32 q8,  q8
          VCVTN.S32.F32 q9,  q9
        $elif DATATYPE == "QC8" and not ARMV8:
          # QC8 FP32 quantization
          VLD1.8      {q0-q1},  [r9]!         // 8 float scales, one per output column, read from w

          VDUP.32     q2, d12[0]              // magic_bias
          VDUP.32     q3, d12[1]              // magic_bias_less_output_zero_point

          VCVT.F32.S32 q8,  q8
          VCVT.F32.S32 q9,  q9

          VMUL.F32    q8,  q8, q0             // multiplier
          VMUL.F32    q9,  q9, q1

          VADD.F32    q8,  q8, q2             // magic_bias
          VADD.F32    q9,  q9, q2

          VQSUB.S32   q8,  q8, q3             // magic_bias_less_output_zero_point
          VQSUB.S32   q9,  q9, q3

        $if DATATYPE != "QC8" or ARMV8:
          VDUP.16     q0, d13[2]              // output_zero_point

        VQMOVN.S32  d16, q8
        VQMOVN.S32  d17, q9

        $if DATATYPE != "QC8" or ARMV8:
          VQADD.S16   q8, q8, q0

        VDUP.8      d24, d13[6]             // output_min

        ${SQXTXN}  d0, q8

        VDUP.8      d25, d13[7]             // output_max

        ${XMAX}     d0, d0, d24

        SUBS        r1, r1, 8               // nc -= 8; flags are consumed by BLO and BHI below

        ${XMIN}     d0, d0, d25

        # Store full 1 x 8
        BLO         7f
        VST1.8      {d0}, [r11], r12        // store 8 bytes, then c += cn_stride
        SUB         r2, r2, r14             // a -= ks (SUB leaves the flags from SUBS r1 intact)
        BHI         0b

        $if DATATYPE == "QU8":
          VPOP        {d8-d14}
          ADD         sp, sp, 8               // skip r2, r3
        $else:
          VPOP        {d8-d13}
          ADD         sp, sp, 16              // skip pad of 8, r2, r3
        POP         {r5, r6, r7, r9, r11, pc}

        # Remainder- 1 to 7 bytes of A
        .p2align    3
5:
        AND         r5, r5, 7               // kc remainder 1 to 7
6:
        // Loads a full 8 bytes of A even when fewer remain.
        // NOTE(review): presumably relies on the caller permitting reads past
        // the end of A; confirm.
        VLD1.8      {d0},  [r3]
        VLD1.8      {d8},  [r9]!

        ${XXTL}    q0, d0
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d0[0]
        VMLAL.S16   q9, d9, d0[0]
        // Each CMP below feeds two branches. BLO exits when fewer than N
        // channels remain, and the later BEQ exits when exactly N remain.
        // The NEON loads, widenings and VMLALs in between do not modify the
        // ARM condition flags.
        CMP         r5, 2
        BLO         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d0[1]
        VMLAL.S16   q9, d9, d0[1]
        BEQ         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d0[2]
        VMLAL.S16   q9, d9, d0[2]
        CMP         r5, 4
        BLO         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d0[3]
        VMLAL.S16   q9, d9, d0[3]
        BEQ         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d1[0]
        VMLAL.S16   q9, d9, d1[0]
        CMP         r5, 6
        BLO         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d1[1]
        VMLAL.S16   q9, d9, d1[1]
        BEQ         4b

        VLD1.8      {d8},  [r9]!
        $if DATATYPE == "QU8":
          VSUBL.U8    q4, d8, d14
        $else:
          VMOVL.S8    q4, d8
        VMLAL.S16   q8, d8, d1[2]
        VMLAL.S16   q9, d9, d1[2]
        B           4b

        # Store odd width
        // r1 = nc - 8 here; its low 3 bits equal nc mod 8. After each partial
        // store, VEXT rotates the next unstored bytes down into d0.
        .p2align    3
7:
        TST         r1, 4
        BEQ         8f
        VST1.32     {d0[0]}, [r11]!
        VEXT.8      q0, q0, q0, 4
8:
        TST         r1, 2
        BEQ         9f
        VST1.16     {d0[0]}, [r11]!
        VEXT.8      q0, q0, q0, 2

9:
        TST         r1, 1
        BEQ         10f
        VST1.8      {d0[0]}, [r11]

10:
        $if DATATYPE == "QU8":
          VPOP        {d8-d14}
          ADD         sp, sp, 8               // skip r2, r3
        $else:
          VPOP        {d8-d13}
          ADD         sp, sp, 16              // skip pad of 8, r2, r3
        POP         {r5, r6, r7, r9, r11, pc}

END_FUNCTION xnn_${DATATYPE.lower()}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_1x8__aarch32_${ISA}_mlal_lane${"_prfm" if PREFETCH else ""}_cortex_${CPU}

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif