• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* ----------------------------------------------------------------------
2  * Project:      CMSIS DSP Library
3  * Title:        arm_mat_vec_mult_q7.c
4  * Description:  Q7 matrix and vector multiplication
5  *
6  * $Date:        23 April 2021
7  *
8  * $Revision:    V1.9.0
9  *
10  * Target Processor: Cortex-M and Cortex-A cores
11  * -------------------------------------------------------------------- */
12 /*
13  * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14  *
15  * SPDX-License-Identifier: Apache-2.0
16  *
17  * Licensed under the Apache License, Version 2.0 (the License); you may
18  * not use this file except in compliance with the License.
19  * You may obtain a copy of the License at
20  *
21  * www.apache.org/licenses/LICENSE-2.0
22  *
23  * Unless required by applicable law or agreed to in writing, software
24  * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26  * See the License for the specific language governing permissions and
27  * limitations under the License.
28  */
29 
30 #include "dsp/matrix_functions.h"
31 
32 /**
33  * @ingroup groupMatrix
34  */
35 
36 
37 
38 /**
39  * @addtogroup MatrixVectMult
40  * @{
41  */
42 
43 /**
44  * @brief Q7 matrix and vector multiplication.
45  * @param[in]       *pSrcMat points to the input matrix structure
46  * @param[in]       *pVec points to the input vector
47  * @param[out]      *pDst points to the output vector
48  */
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
50 
51 #include "arm_helium_utils.h"
52 
/*
 * Helium (MVE) implementation of the Q7 matrix-by-vector product.
 *
 * The matrix is read row by row with a stride of numCols elements. For each
 * row the products of the row elements and the vector elements are summed in
 * a 32-bit accumulator; the sum is shifted right by 7 bits and saturated to
 * 8 bits before being stored in the output vector.
 *
 * pSrcMat  : matrix instance (numRows, numCols, pData are read)
 * pSrcVec  : input vector, numCols elements are read
 * pDstVec  : output vector, numRows elements are written
 *
 * Rows are processed four at a time, then as one pair, then as a single row.
 * Columns are consumed 16 per vector operation. The remaining
 * (numCols & 0xF) columns are loaded with a tail predicate on BOTH the
 * vector and the matrix rows, so no byte beyond the end of the vector or of
 * the matrix buffer is ever read.
 */
void arm_mat_vec_mult_q7(
    const arm_matrix_instance_q7 * pSrcMat,
    const q7_t     *pSrcVec,
    q7_t           *pDstVec)
{
    const q7_t *pMatSrc = pSrcMat->pData;   /* first row not yet processed */
    uint32_t     numRows = pSrcMat->numRows;
    uint32_t     numCols = pSrcMat->numCols;
    q7_t       *px = pDstVec;               /* next output element */
    int32_t      row = (int32_t) numRows;   /* rows still to process */
    uint32_t     blkCnt;                    /* loop counter */

    /*
     * compute 4 x 32-bit accumulators (4 rows) per loop
     */
    while (row >= 4)
    {
        const q7_t  *pMat0Vec = pMatSrc;
        const q7_t  *pMat1Vec = pMat0Vec + numCols;
        const q7_t  *pMat2Vec = pMat1Vec + numCols;
        const q7_t  *pMat3Vec = pMat2Vec + numCols;
        const q7_t  *pVec = pSrcVec;        /* vector is re-read from its start for every row group */
        q31_t        acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
        q7x16_t      vecMatA0, vecMatA1, vecMatA2, vecMatA3, vecIn;

        /* full 16-column blocks */
        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 16;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 16;
            vecMatA3 = vld1q(pMat3Vec);
            pMat3Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
            acc2 = vmladavaq(acc2, vecIn, vecMatA2);
            acc3 = vmladavaq(acc3, vecIn, vecMatA3);

            blkCnt--;
        }

        /*
         * tail: fewer than 16 columns left.
         * Every load is predicated; inactive lanes are zeroed and therefore
         * add nothing to the accumulators.
         */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecMatA1 = vldrbq_z_s8(pMat1Vec, p0);
            vecMatA2 = vldrbq_z_s8(pMat2Vec, p0);
            vecMatA3 = vldrbq_z_s8(pMat3Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
            acc2 = vmladavaq(acc2, vecIn, vecMatA2);
            acc3 = vmladavaq(acc3, vecIn, vecMatA3);
        }

        /* scale the sums down by 7 bits and saturate to 8 bits */
        *px++ = (q7_t) __SSAT(acc0 >> 7, 8);
        *px++ = (q7_t) __SSAT(acc1 >> 7, 8);
        *px++ = (q7_t) __SSAT(acc2 >> 7, 8);
        *px++ = (q7_t) __SSAT(acc3 >> 7, 8);

        /* advance to the next group of four rows */
        pMatSrc += numCols * 4;
        row -= 4;
    }

    /*
     * process a remaining pair of rows, if any
     */
    if (row >= 2)
    {
        const q7_t  *pMat0Vec = pMatSrc;
        const q7_t  *pMat1Vec = pMat0Vec + numCols;
        const q7_t  *pVec = pSrcVec;
        q31_t        acc0 = 0, acc1 = 0;
        q7x16_t      vecMatA0, vecMatA1, vecIn;

        /* full 16-column blocks */
        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);

            blkCnt--;
        }

        /* tail: predicated loads only, see the four-row loop */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecMatA1 = vldrbq_z_s8(pMat1Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            acc1 = vmladavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = (q7_t) __SSAT(acc0 >> 7, 8);
        *px++ = (q7_t) __SSAT(acc1 >> 7, 8);

        /* advance past the two rows just processed */
        pMatSrc += numCols * 2;
        row -= 2;
    }

    /*
     * process the last single row, if any
     */
    if (row >= 1)
    {
        const q7_t  *pMat0Vec = pMatSrc;
        const q7_t  *pVec = pSrcVec;
        q31_t        acc0 = 0;
        q7x16_t      vecMatA0, vecIn;

        /* full 16-column blocks */
        blkCnt = numCols >> 4;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 16;
            vecIn = vld1q(pVec);
            pVec += 16;

            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
            blkCnt--;
        }

        /* tail: predicated loads only, see the four-row loop */
        blkCnt = numCols & 0xF;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp8q(blkCnt);

            vecMatA0 = vldrbq_z_s8(pMat0Vec, p0);
            vecIn = vldrbq_z_s8(pVec, p0);
            acc0 = vmladavaq(acc0, vecIn, vecMatA0);
        }

        *px++ = (q7_t) __SSAT(acc0 >> 7, 8);
    }
}
277 
278 #else
/*
 * Scalar / packed-SIMD implementation of the Q7 matrix-by-vector product.
 *
 * The matrix is read row by row with a stride of numCols elements. Each row
 * is multiplied element-wise with the vector and summed in a 32-bit
 * accumulator; the sum is shifted right by 7 bits and saturated to 8 bits
 * before being written to the output.
 *
 * pSrcMat : matrix instance (numRows, numCols, pData are read)
 * pVec    : input vector, numCols elements are read
 * pDst    : output vector, numRows elements are written
 */
void arm_mat_vec_mult_q7(const arm_matrix_instance_q7 *pSrcMat, const q7_t *pVec, q7_t *pDst)
{
    const uint32_t nRows = pSrcMat->numRows;
    const uint32_t nCols = pSrcMat->numCols;
    const q7_t *pRowBase = pSrcMat->pData;  /* start of the next unprocessed row */
    q7_t *pOut = pDst;                      /* next output element */
    uint32_t rowsLeft, colsLeft;            /* loop counters */

    /* ---- Groups of four rows ------------------------------------------- */
    for (rowsLeft = nRows >> 2; rowsLeft > 0u; rowsLeft--)
    {
        /* One read pointer per row of the group */
        const q7_t *pRow0 = pRowBase;
        const q7_t *pRow1 = pRow0 + nCols;
        const q7_t *pRow2 = pRow1 + nCols;
        const q7_t *pRow3 = pRow2 + nCols;

        /* The vector is re-read from its start for every group */
        const q7_t *pX = pVec;

        q31_t acc0 = 0;
        q31_t acc1 = 0;
        q31_t acc2 = 0;
        q31_t acc3 = 0;

        /* Four columns per iteration: one packed 4 x q7 word per operand */
        for (colsLeft = nCols >> 2; colsLeft > 0u; colsLeft--)
        {
            q31_t packed, xLo, xHi, aLo, aHi;

            /* Split the vector word into two words of sign-extended 16-bit
             * lanes: xLo from the unrotated word, xHi from the word rotated
             * right by 8 bits */
            packed = read_q7x4_ia ((q7_t **) &pX);
            xHi = __SXTB16(__ROR(packed, 8));
            xLo = __SXTB16(packed);

            /* Same split for each row, followed by two dual multiply-accumulates */
            packed = read_q7x4_ia ((q7_t **) &pRow0);
            aHi = __SXTB16(__ROR(packed, 8));
            aLo = __SXTB16(packed);
            acc0 = __SMLAD(aLo, xLo, acc0);
            acc0 = __SMLAD(aHi, xHi, acc0);

            packed = read_q7x4_ia ((q7_t **) &pRow1);
            aHi = __SXTB16(__ROR(packed, 8));
            aLo = __SXTB16(packed);
            acc1 = __SMLAD(aLo, xLo, acc1);
            acc1 = __SMLAD(aHi, xHi, acc1);

            packed = read_q7x4_ia ((q7_t **) &pRow2);
            aHi = __SXTB16(__ROR(packed, 8));
            aLo = __SXTB16(packed);
            acc2 = __SMLAD(aLo, xLo, acc2);
            acc2 = __SMLAD(aHi, xHi, acc2);

            packed = read_q7x4_ia ((q7_t **) &pRow3);
            aHi = __SXTB16(__ROR(packed, 8));
            aLo = __SXTB16(packed);
            acc3 = __SMLAD(aLo, xLo, acc3);
            acc3 = __SMLAD(aHi, xHi, acc3);
        }

        /* Leftover columns (numCols not a multiple of 4), one at a time */
        for (colsLeft = nCols & 3u; colsLeft > 0u; colsLeft--)
        {
            const q31_t x = *pX++;

            acc0 += *pRow0++ * x;
            acc1 += *pRow1++ * x;
            acc2 += *pRow2++ * x;
            acc3 += *pRow3++ * x;
        }

        /* Scale down by 7 bits, saturate to 8 bits, and store */
        *pOut++ = (q7_t)(__SSAT((acc0 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc1 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc2 >> 7), 8));
        *pOut++ = (q7_t)(__SSAT((acc3 >> 7), 8));

        /* Step over the four rows just consumed */
        pRowBase += nCols * 4;
    }

    /* ---- Remaining rows (numRows not a multiple of 4), one at a time ---- */
    for (rowsLeft = nRows & 3u; rowsLeft > 0u; rowsLeft--)
    {
        const q7_t *pRow = pRowBase;
        const q7_t *pX = pVec;
        q31_t acc = 0;

        /* Four columns per iteration, same packed scheme as above */
        for (colsLeft = nCols >> 2; colsLeft > 0u; colsLeft--)
        {
            q31_t packed, xLo, xHi, aLo, aHi;

            packed = read_q7x4_ia ((q7_t **) &pX);
            xHi = __SXTB16(__ROR(packed, 8));
            xLo = __SXTB16(packed);

            packed = read_q7x4_ia ((q7_t **) &pRow);
            aHi = __SXTB16(__ROR(packed, 8));
            aLo = __SXTB16(packed);

            acc = __SMLAD(aLo, xLo, acc);
            acc = __SMLAD(aHi, xHi, acc);
        }

        /* Leftover columns of this row */
        for (colsLeft = nCols & 3u; colsLeft > 0u; colsLeft--)
        {
            acc += *pRow++ * *pX++;
        }

        *pOut++ = (q7_t)(__SSAT((acc >> 7), 8));

        /* Step over the row just consumed */
        pRowBase += nCols;
    }
}
416 #endif /* defined(ARM_MATH_MVEI) */
417 
418 /**
419  * @} end of MatrixVectMult group
420  */
421