• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_q31.c
4  * Description:  Q31 matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions.h"
31 
32 /**
33  * @ingroup groupMatrix
34  */
35 
36 
37 
38 /**
39  * @addtogroup MatrixVectMult
40  * @{
41  */
42 
43 /**
44  * @brief Q31 matrix and vector multiplication.
45  * @param[in]       *pSrcMat points to the input matrix structure
46  * @param[in]       *pVec points to the input vector
47  * @param[out]      *pDst points to the output vector
48  */
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
/*
 * Helium (MVE) implementation of the Q31 matrix x vector product.
 *
 * Every output element is the dot product of one matrix row with the input
 * vector, accumulated in a 64-bit accumulator, shifted right by 31 bits and
 * stored to the destination (no saturation is applied here).
 *
 * Rows are processed three at a time so that each loaded chunk of the input
 * vector feeds three accumulators; a remaining pair of rows and/or a single
 * remaining row are handled afterwards.
 *
 * pSrcMat : matrix instance; numRows, numCols and pData are read
 * pSrcVec : input vector, numCols elements are read
 * pDstVec : output vector, numRows elements are written
 */
void arm_mat_vec_mult_q31(
    const arm_matrix_instance_q31 * pSrcMat,
    const q31_t     *pSrcVec,
    q31_t           *pDstVec)
{
    const q31_t *pMatSrc = pSrcMat->pData;   /* start of the next unprocessed row */
    uint32_t     numRows = pSrcMat->numRows;
    uint32_t     numCols = pSrcMat->numCols;
    q31_t       *px = pDstVec;               /* output write pointer */
    int32_t      row = numRows;              /* rows still to be processed */
    uint32_t     blkCnt;                     /* loop counter */

    /*
     * compute 3x64-bit accumulators per loop
     */
    while (row >= 3)
    {
        /* One read pointer per row, each starting at the first element of its row */
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pMat1Vec = pMat0Vec + numCols;
        const q31_t *pMat2Vec = pMat1Vec + numCols;
        /* The input vector is re-read from its start for every group of rows */
        const q31_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q63_t        acc2 = 0LL;
        q31x4_t      vecMatA0, vecMatA1, vecMatA2, vecIn;

        /* Full blocks of 4 columns */
        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 4;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);

            blkCnt--;
        }

        /*
         * tail: 1 to 3 leftover columns.
         * Matrix and vector loads share the same predicate so that no element
         * past the end of a row (or of the matrix buffer) is read; the
         * inactive lanes are loaded as zero and contribute nothing.
         */
        blkCnt = numCols & 3;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecMatA1 = vldrwq_z_s32(pMat1Vec, p0);
            vecMatA2 = vldrwq_z_s32(pMat2Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
        }

        *px++ = asrl(acc0, 31);
        *px++ = asrl(acc1, 31);
        *px++ = asrl(acc2, 31);

        /* Advance to the next group of three rows */
        pMatSrc += numCols * 3;
        row -= 3;
    }

    /*
     * process a remaining pair of rows
     */
    if (row >= 2)
    {
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pMat1Vec = pMat0Vec + numCols;
        const q31_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q63_t        acc1 = 0LL;
        q31x4_t      vecMatA0, vecMatA1, vecIn;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);

            blkCnt--;
        }

        /* tail: predicated loads, see the three-row loop above */
        blkCnt = numCols & 3;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecMatA1 = vldrwq_z_s32(pMat1Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = asrl(acc0, 31);
        *px++ = asrl(acc1, 31);

        pMatSrc += numCols * 2;
        row -= 2;
    }

    /*
     * process a remaining single row
     */
    if (row >= 1)
    {
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pVec = pSrcVec;
        q63_t        acc0 = 0LL;
        q31x4_t      vecMatA0, vecIn;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            blkCnt--;
        }

        /* tail: predicated loads, see the three-row loop above */
        blkCnt = numCols & 3;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
        }

        *px++ = asrl(acc0, 31);
    }
}
265 #else
/*
 * Scalar implementation of the Q31 matrix x vector product.
 *
 * Every output element is the dot product of one matrix row with the input
 * vector, accumulated in a 64-bit accumulator, shifted right by 31 bits and
 * truncated to q31_t (no saturation is applied here).
 *
 * pSrcMat : matrix instance; numRows, numCols and pData are read
 * pVec    : input vector, numCols elements are read
 * pDst    : output vector, numRows elements are written
 */
void arm_mat_vec_mult_q31(const arm_matrix_instance_q31 *pSrcMat, const q31_t *pVec, q31_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const q31_t *pSrcA = pSrcMat->pData;
    const q31_t *pInA1;      /* read pointer into matrix row 1 of the current group */
    const q31_t *pInA2;      /* read pointer into matrix row 2 of the current group */
    const q31_t *pInA3;      /* read pointer into matrix row 3 of the current group */
    const q31_t *pInA4;      /* read pointer into matrix row 4 of the current group */
    const q31_t *pInVec;     /* read pointer into the input vector */
    q31_t *px = pDst;        /* output write pointer */
    /*
     * Element offset of the current row inside pSrcA. It must be 32 bits
     * wide: it grows by numCols per processed row, so a 16-bit offset wraps
     * as soon as the matrix holds more than 65535 elements.
     */
    uint32_t i = 0u;
    uint32_t row, colCnt;    /* loop counters */
    q31_t matData, matData2, vecData, vecData2;

    /* Process 4 rows at a time so each vector element is loaded once per 4 rows */
    row = numRows >> 2;
    while (row > 0u) {
        /* One accumulator per row of the group */
        q63_t sum1 = 0;
        q63_t sum2 = 0;
        q63_t sum3 = 0;
        q63_t sum4 = 0;

        /* The input vector is re-read from its start for every group of rows */
        pInVec = pVec;

        /* Point at the first element of each of the 4 rows */
        pInA1 = pSrcA + i;
        pInA2 = pInA1 + numCols;
        pInA3 = pInA2 + numCols;
        pInA4 = pInA3 + numCols;

        /* One column per iteration */
        colCnt = numCols;
        while (colCnt > 0u) {
            /* Read one value from the vector */
            vecData = *pInVec++;

            /* Read one value from each of the 4 rows and multiply-accumulate */
            matData = *pInA1++;
            sum1 += (q63_t)matData * vecData;
            matData = *pInA2++;
            sum2 += (q63_t)matData * vecData;
            matData = *pInA3++;
            sum3 += (q63_t)matData * vecData;
            matData = *pInA4++;
            sum4 += (q63_t)matData * vecData;

            colCnt--;
        }

        /* Scale the 64-bit sums back down and store them (truncating, not saturating) */
        *px++ = (q31_t)(sum1 >> 31);
        *px++ = (q31_t)(sum2 >> 31);
        *px++ = (q31_t)(sum3 >> 31);
        *px++ = (q31_t)(sum4 >> 31);

        /* Skip the 4 rows just processed */
        i += numCols * 4u;
        row--;
    }

    /* Process the remaining 0 to 3 rows one at a time */
    row = numRows & 3u;
    while (row > 0u) {
        q63_t sum = 0;

        pInVec = pVec;
        pInA1 = pSrcA + i;

        /* Two columns per iteration */
        colCnt = numCols >> 1;
        while (colCnt > 0u) {
            vecData = *pInVec++;
            vecData2 = *pInVec++;
            matData = *pInA1++;
            matData2 = *pInA1++;
            sum += (q63_t)matData * vecData;
            sum += (q63_t)matData2 * vecData2;
            colCnt--;
        }

        /* Odd trailing column, if any */
        colCnt = numCols & 1u;
        while (colCnt > 0u) {
            sum += (q63_t)*pInA1++ * *pInVec++;
            colCnt--;
        }

        *px++ = (q31_t)(sum >> 31);
        i += numCols;
        row--;
    }
}
372 #endif /* defined(ARM_MATH_MVEI) */
373 
/**
 * @} end of MatrixVectMult group
 */
377