1 /* ----------------------------------------------------------------------
2 * Project: CMSIS DSP Library
3 * Title: arm_mat_vec_mult_q31.c
4 * Description: Q31 matrix and vector multiplication
5 *
6 * $Date: 23 April 2021
7 *
8 * $Revision: V1.9.0
9 *
10 * Target Processor: Cortex-M and Cortex-A cores
11 * -------------------------------------------------------------------- */
12 /*
13 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
14 *
15 * SPDX-License-Identifier: Apache-2.0
16 *
17 * Licensed under the Apache License, Version 2.0 (the License); you may
18 * not use this file except in compliance with the License.
19 * You may obtain a copy of the License at
20 *
21 * www.apache.org/licenses/LICENSE-2.0
22 *
23 * Unless required by applicable law or agreed to in writing, software
24 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
25 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26 * See the License for the specific language governing permissions and
27 * limitations under the License.
28 */
29
30 #include "dsp/matrix_functions.h"
31
32 /**
33 * @ingroup groupMatrix
34 */
35
36
37
38 /**
39 * @addtogroup MatrixVectMult
40 * @{
41 */
42
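/**
 * @brief         Q31 matrix and vector multiplication.
 * @param[in]     *pSrcMat  points to the input matrix structure (row-major, numRows x numCols)
 * @param[in]     *pVec     points to the input vector (numCols elements)
 * @param[out]    *pDst     points to the output vector (numRows elements)
 *
 * Each output element is the dot product of one matrix row with the vector,
 * accumulated in 64 bits, shifted right by 31 bits and truncated (not saturated) to Q31.
 */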
49 #if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)
/*
 * MVE (Helium) implementation.
 *
 * Rows are processed three at a time, then a remaining pair, then a single
 * remaining row. Each row is reduced against the vector with 64-bit
 * multiply-accumulate (vmlaldavaq) over blocks of 4 columns; the final
 * numCols % 4 columns are handled with tail-predicated loads.
 *
 * In the tail, the matrix-row loads are zero-predicated just like the vector
 * load: an unpredicated 4-lane load of the last row's tail would read up to
 * three q31_t elements past the end of pSrcMat->pData. Inactive lanes are zero
 * on both operands, so they contribute nothing to the accumulators.
 */
void arm_mat_vec_mult_q31(
    const arm_matrix_instance_q31 * pSrcMat,
    const q31_t *pSrcVec,
    q31_t *pDstVec)
{
    const q31_t *pMatSrc = pSrcMat->pData;  /* first element of the current row group */
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    q31_t *px = pDstVec;
    int32_t row = (int32_t) numRows;
    uint32_t blkCnt;                        /* column block counter */

    /*
     * compute 3 x 64-bit accumulators per loop
     */
    while (row >= 3)
    {
        /* three consecutive rows; the vector restarts for every row group */
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pMat1Vec = pMat0Vec + numCols;
        const q31_t *pMat2Vec = pMat1Vec + numCols;
        const q31_t *pVec = pSrcVec;
        q63_t acc0 = 0LL;
        q63_t acc1 = 0LL;
        q63_t acc2 = 0LL;
        q31x4_t vecMatA0, vecMatA1, vecMatA2, vecIn;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 4;
            vecMatA2 = vld1q(pMat2Vec);
            pMat2Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);

            blkCnt--;
        }

        /*
         * tail: remaining numCols % 4 columns, all loads zero-predicated
         */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecMatA1 = vldrwq_z_s32(pMat1Vec, p0);
            vecMatA2 = vldrwq_z_s32(pMat2Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
            acc2 = vmlaldavaq(acc2, vecIn, vecMatA2);
        }

        /* back to Q31: arithmetic shift by 31, truncated to 32 bits */
        *px++ = asrl(acc0, 31);
        *px++ = asrl(acc1, 31);
        *px++ = asrl(acc2, 31);

        pMatSrc += numCols * 3;
        row -= 3;
    }

    /*
     * process any remaining rows pair
     */
    if (row >= 2)
    {
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pMat1Vec = pMat0Vec + numCols;
        const q31_t *pVec = pSrcVec;
        q63_t acc0 = 0LL;
        q63_t acc1 = 0LL;
        q31x4_t vecMatA0, vecMatA1, vecIn;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecMatA1 = vld1q(pMat1Vec);
            pMat1Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);

            blkCnt--;
        }

        /*
         * tail: remaining numCols % 4 columns, all loads zero-predicated
         */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecMatA1 = vldrwq_z_s32(pMat1Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);

            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            acc1 = vmlaldavaq(acc1, vecIn, vecMatA1);
        }

        *px++ = asrl(acc0, 31);
        *px++ = asrl(acc1, 31);

        pMatSrc += numCols * 2;
        row -= 2;
    }

    /*
     * process the last remaining row
     */
    if (row >= 1)
    {
        const q31_t *pMat0Vec = pMatSrc;
        const q31_t *pVec = pSrcVec;
        q63_t acc0 = 0LL;
        q31x4_t vecMatA0, vecIn;

        blkCnt = numCols >> 2;
        while (blkCnt > 0U)
        {
            vecMatA0 = vld1q(pMat0Vec);
            pMat0Vec += 4;
            vecIn = vld1q(pVec);
            pVec += 4;
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
            blkCnt--;
        }

        /*
         * tail: remaining numCols % 4 columns, all loads zero-predicated
         */
        blkCnt = numCols & 3U;
        if (blkCnt > 0U)
        {
            mve_pred16_t p0 = vctp32q(blkCnt);

            vecMatA0 = vldrwq_z_s32(pMat0Vec, p0);
            vecIn = vldrwq_z_s32(pVec, p0);
            acc0 = vmlaldavaq(acc0, vecIn, vecMatA0);
        }

        *px++ = asrl(acc0, 31);
    }
}
265 #else
/*
 * Generic C implementation.
 *
 * Rows are processed four at a time (one column per inner-loop iteration,
 * shared by four 64-bit accumulators), then any remaining numRows % 4 rows one
 * at a time (two columns per inner-loop iteration plus a one-column
 * remainder). Each result is the 64-bit sum shifted right by 31 and truncated
 * (not saturated) to q31_t.
 *
 * The matrix is walked by advancing a row pointer rather than a flat element
 * index: numRows * numCols can exceed 65535, which overflowed the former
 * 16-bit index and made later rows read the wrong matrix data.
 */
void arm_mat_vec_mult_q31(const arm_matrix_instance_q31 *pSrcMat, const q31_t *pVec, q31_t *pDst)
{
    uint32_t numRows = pSrcMat->numRows;
    uint32_t numCols = pSrcMat->numCols;
    const q31_t *pRow = pSrcMat->pData;   /* first element of the current row (group) */
    q31_t *px = pDst;                     /* output write pointer */
    uint32_t row, colCnt;                 /* loop counters */

    /* Process 4 rows at a time */
    row = numRows >> 2;
    while (row > 0U)
    {
        const q31_t *pInA1 = pRow;
        const q31_t *pInA2 = pInA1 + numCols;
        const q31_t *pInA3 = pInA2 + numCols;
        const q31_t *pInA4 = pInA3 + numCols;
        const q31_t *pInVec = pVec;       /* the vector restarts for every row group */
        q63_t sum1 = 0;
        q63_t sum2 = 0;
        q63_t sum3 = 0;
        q63_t sum4 = 0;

        /* One column per iteration: one vector element times one element of each of the 4 rows */
        colCnt = numCols;
        while (colCnt > 0U)
        {
            q31_t vecData = *pInVec++;

            sum1 += (q63_t) *pInA1++ * vecData;
            sum2 += (q63_t) *pInA2++ * vecData;
            sum3 += (q63_t) *pInA3++ * vecData;
            sum4 += (q63_t) *pInA4++ * vecData;

            colCnt--;
        }

        /* Shift the 64-bit sums back to Q31 and store (truncation, no saturation) */
        *px++ = (q31_t) (sum1 >> 31);
        *px++ = (q31_t) (sum2 >> 31);
        *px++ = (q31_t) (sum3 >> 31);
        *px++ = (q31_t) (sum4 >> 31);

        pRow += 4U * numCols;
        row--;
    }

    /* process any remaining rows */
    row = numRows & 3U;
    while (row > 0U)
    {
        const q31_t *pInA1 = pRow;
        const q31_t *pInVec = pVec;
        q63_t sum = 0;

        /* Two columns per iteration */
        colCnt = numCols >> 1;
        while (colCnt > 0U)
        {
            q31_t vecData = *pInVec++;
            q31_t vecData2 = *pInVec++;
            q31_t matData = *pInA1++;
            q31_t matData2 = *pInA1++;

            sum += (q63_t) matData * vecData;
            sum += (q63_t) matData2 * vecData2;
            colCnt--;
        }

        /* Odd column count: one column left over */
        if ((numCols & 1U) != 0U)
        {
            sum += (q63_t) *pInA1 * *pInVec;
        }

        *px++ = (q31_t) (sum >> 31);
        pRow += numCols;
        row--;
    }
}
372 #endif /* defined(ARM_MATH_MVEI) */
373
374 /**
 * @} end of MatrixVectMult group
376 */
377