/* ---------------------------------------------------------------------- * Project: CMSIS DSP Library * Title: arm_mat_vec_mult_q7.c * Description: Q7 matrix and vector multiplication * * $Date: 23 April 2021 * * $Revision: V1.9.0 * * Target Processor: Cortex-M and Cortex-A cores * -------------------------------------------------------------------- */ /* * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved. * * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the License); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an AS IS BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "dsp/matrix_functions.h" /** * @ingroup groupMatrix */ /** * @addtogroup MatrixVectMult * @{ */ /** * @brief Q7 matrix and vector multiplication. * @param[in] *pSrcMat points to the input matrix structure * @param[in] *pVec points to the input vector * @param[out] *pDst points to the output vector */ #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE) #include "arm_helium_utils.h" void arm_mat_vec_mult_q7( const arm_matrix_instance_q7 * pSrcMat, const q7_t *pSrcVec, q7_t *pDstVec) { const q7_t *pMatSrc = pSrcMat->pData; const q7_t *pMat0, *pMat1; uint32_t numRows = pSrcMat->numRows; uint32_t numCols = pSrcMat->numCols; q7_t *px; int32_t row; uint16_t blkCnt; /* loop counters */ row = numRows; px = pDstVec; /* * compute 4x64-bit accumulators per loop */ while (row >= 4) { q7_t const *pMat0Vec, *pMat1Vec, *pMat2Vec, *pMat3Vec, *pVec; const q7_t *pMat2, *pMat3; q7_t const *pSrcVecPtr = pSrcVec; q31_t acc0, acc1, acc2, acc3; q7x16_t vecMatA0, vecMatA1, vecMatA2, vecMatA3, vecIn; pVec = pSrcVec; /* * Initialize the pointer pIn1 to point to the starting address of the column being processed */ pMat0 = pMatSrc; pMat1 = pMat0 + numCols; pMat2 = pMat1 + numCols; pMat3 = pMat2 + numCols; acc0 = 0L; acc1 = 0L; acc2 = 0L; acc3 = 0L; pMat0Vec = pMat0; pMat1Vec = pMat1; pMat2Vec = pMat2; pMat3Vec = pMat3; pVec = pSrcVecPtr; blkCnt = numCols >> 4; while (blkCnt > 0U) { vecMatA0 = vld1q(pMat0Vec); pMat0Vec += 16; vecMatA1 = vld1q(pMat1Vec); pMat1Vec += 16; vecMatA2 = vld1q(pMat2Vec); pMat2Vec += 16; vecMatA3 = vld1q(pMat3Vec); pMat3Vec += 16; vecIn = vld1q(pVec); pVec += 16; acc0 = vmladavaq(acc0, vecIn, vecMatA0); acc1 = vmladavaq(acc1, vecIn, vecMatA1); acc2 = vmladavaq(acc2, vecIn, vecMatA2); acc3 = vmladavaq(acc3, vecIn, vecMatA3); blkCnt--; } /* * tail * (will be merged thru tail predication) */ blkCnt = numCols & 0xF; if (blkCnt > 0U) { mve_pred16_t p0 = vctp8q(blkCnt); vecMatA0 = vld1q(pMat0Vec); vecMatA1 = vld1q(pMat1Vec); vecMatA2 = vld1q(pMat2Vec); vecMatA3 = vld1q(pMat3Vec); vecIn = vldrbq_z_s8(pVec, p0); acc0 = vmladavaq(acc0, vecIn, vecMatA0); acc1 = vmladavaq(acc1, vecIn, vecMatA1); acc2 = vmladavaq(acc2, vecIn, vecMatA2); acc3 = vmladavaq(acc3, vecIn, vecMatA3); } *px++ = __SSAT(acc0 >> 7, 8); *px++ = __SSAT(acc1 >> 7, 8); *px++ = __SSAT(acc2 >> 7, 8); *px++ = __SSAT(acc3 >> 7, 8); pMatSrc += numCols * 4; /* * Decrement the row loop counter */ row -= 4; } /* * process any remaining rows pair */ if (row >= 2) { q7_t const *pMat0Vec, *pMat1Vec, *pVec; q7_t const *pSrcVecPtr = pSrcVec; q31_t acc0, acc1; q7x16_t vecMatA0, vecMatA1, vecIn; /* * For every row wise process, the pInVec pointer is set * to the starting address of the vector */ pVec = pSrcVec; /* * Initialize the pointer pIn1 to point to the starting address of the column being processed */ pMat0 = pMatSrc; pMat1 = pMat0 + numCols; acc0 = 0; acc1 = 0; pMat0Vec = pMat0; pMat1Vec = pMat1; pVec = pSrcVecPtr; blkCnt = numCols >> 4; while (blkCnt > 0U) { vecMatA0 = vld1q(pMat0Vec); pMat0Vec += 16; vecMatA1 = vld1q(pMat1Vec); pMat1Vec += 16; vecIn = vld1q(pVec); pVec += 16; acc0 = vmladavaq(acc0, vecIn, vecMatA0); acc1 = vmladavaq(acc1, vecIn, vecMatA1); blkCnt--; } /* * tail * (will be merged thru tail predication) */ blkCnt = numCols & 0xF; if (blkCnt > 0U) { mve_pred16_t p0 = vctp8q(blkCnt); vecMatA0 = vld1q(pMat0Vec); vecMatA1 = vld1q(pMat1Vec); vecIn = vldrbq_z_s8(pVec, p0); acc0 = vmladavaq(acc0, vecIn, vecMatA0); acc1 = vmladavaq(acc1, vecIn, vecMatA1); } *px++ = __SSAT(acc0 >> 7, 8); *px++ = __SSAT(acc1 >> 7, 8); pMatSrc += numCols * 2; /* * Decrement the row loop counter */ row -= 2; } if (row >= 1) { q7_t const *pMat0Vec, *pVec; q7_t const *pSrcVecPtr = pSrcVec; q31_t acc0; q7x16_t vecMatA0, vecIn; /* * For every row wise process, the pInVec pointer is set * to the starting address of the vector */ pVec = pSrcVec; /* * Initialize the pointer pIn1 to point to the starting address of the column being processed */ pMat0 = pMatSrc; acc0 = 0LL; pMat0Vec = pMat0; pVec = pSrcVecPtr; blkCnt = numCols >> 4; while (blkCnt > 0U) { vecMatA0 = vld1q(pMat0Vec); pMat0Vec += 16; vecIn = vld1q(pVec); pVec += 16; acc0 = vmladavaq(acc0, vecIn, vecMatA0); blkCnt--; } /* * tail * (will be merged thru tail predication) */ blkCnt = numCols & 0xF; if (blkCnt > 0U) { mve_pred16_t p0 = vctp8q(blkCnt); vecMatA0 = vld1q(pMat0Vec); vecIn = vldrbq_z_s8(pVec, p0); acc0 = vmladavaq(acc0, vecIn, vecMatA0); } *px++ = __SSAT(acc0 >> 7, 8); } } #else void arm_mat_vec_mult_q7(const arm_matrix_instance_q7 *pSrcMat, const q7_t *pVec, q7_t *pDst) { uint32_t numRows = pSrcMat->numRows; uint32_t numCols = pSrcMat->numCols; const q7_t *pSrcA = pSrcMat->pData; const q7_t *pInA1; /* input data matrix pointer of Q7 type */ const q7_t *pInA2; /* input data matrix pointer of Q7 type */ const q7_t *pInA3; /* input data matrix pointer of Q7 type */ const q7_t *pInA4; /* input data matrix pointer of Q7 type */ const q7_t *pInVec; /* input data vector pointer of Q7 type */ q7_t *px; /* output data pointer */ uint32_t i, row, colCnt; /* loop counters */ q31_t matData, matData2, vecData, vecData2; /* Process 4 rows at a time */ row = numRows >> 2; i = 0u; px = pDst; /* The following loop performs the dot-product of each row in pSrcA with the vector */ while (row > 0) { /* For every row wise process, the pInVec pointer is set ** to the starting address of the vector */ pInVec = pVec; /* Initialize accumulators */ q31_t sum1 = 0; q31_t sum2 = 0; q31_t sum3 = 0; q31_t sum4 = 0; /* Loop unrolling: process 4 columns per iteration */ colCnt = numCols >> 2; /* Initialize row pointers so we can track 4 rows at once */ pInA1 = pSrcA + i; pInA2 = pInA1 + numCols; pInA3 = pInA2 + numCols; pInA4 = pInA3 + numCols; // Inner loop: matrix-vector multiplication while (colCnt > 0u) { // Read 4 values from vector vecData = read_q7x4_ia ((q7_t **) &pInVec); vecData2 = __SXTB16(__ROR(vecData, 8)); vecData = __SXTB16(vecData); // Read 16 values from the matrix - 4 values from each of 4 rows, and do multiply accumulate matData = read_q7x4_ia ((q7_t **) &pInA1); matData2 = __SXTB16(__ROR(matData, 8)); matData = __SXTB16(matData); sum1 = __SMLAD(matData, vecData, sum1); sum1 = __SMLAD(matData2, vecData2, sum1); matData = read_q7x4_ia ((q7_t **) &pInA2); matData2 = __SXTB16(__ROR(matData, 8)); matData = __SXTB16(matData); sum2 = __SMLAD(matData, vecData, sum2); sum2 = __SMLAD(matData2, vecData2, sum2); matData = read_q7x4_ia ((q7_t **) &pInA3); matData2 = __SXTB16(__ROR(matData, 8)); matData = __SXTB16(matData); sum3 = __SMLAD(matData, vecData, sum3); sum3 = __SMLAD(matData2, vecData2, sum3); matData = read_q7x4_ia ((q7_t **) &pInA4); matData2 = __SXTB16(__ROR(matData, 8)); matData = __SXTB16(matData); sum4 = __SMLAD(matData, vecData, sum4); sum4 = __SMLAD(matData2, vecData2, sum4); // Decrement the loop counter colCnt--; } /* process any remaining columns */ colCnt = numCols & 3u; while (colCnt > 0) { vecData = *pInVec++; sum1 += *pInA1++ * vecData; sum2 += *pInA2++ * vecData; sum3 += *pInA3++ * vecData; sum4 += *pInA4++ * vecData; colCnt--; } /* Saturate and store the result in the destination buffer */ *px++ = (q7_t)(__SSAT((sum1 >> 7), 8)); *px++ = (q7_t)(__SSAT((sum2 >> 7), 8)); *px++ = (q7_t)(__SSAT((sum3 >> 7), 8)); *px++ = (q7_t)(__SSAT((sum4 >> 7), 8)); i = i + numCols * 4; /* Decrement the row loop counter */ row--; } /* process any remaining rows */ row = numRows & 3u; while (row > 0) { q31_t sum = 0; pInVec = pVec; pInA1 = pSrcA + i; // loop unrolling - process 4 elements at a time colCnt = numCols >> 2; while (colCnt > 0) { vecData = read_q7x4_ia ((q7_t **) &pInVec); vecData2 = __SXTB16(__ROR(vecData, 8)); vecData = __SXTB16(vecData); matData = read_q7x4_ia ((q7_t **) &pInA1); matData2 = __SXTB16(__ROR(matData, 8)); matData = __SXTB16(matData); sum = __SMLAD(matData, vecData, sum); sum = __SMLAD(matData2, vecData2, sum); colCnt--; } // process remainder of row colCnt = numCols & 3u; while (colCnt > 0) { sum += *pInA1++ * *pInVec++; colCnt--; } *px++ = (q7_t)(__SSAT((sum >> 7), 8)); i = i + numCols; row--; } } #endif /* defined(ARM_MATH_MVEI) */ /** * @} end of MatrixMult group */