/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_q31.c
 * Description:  Q31 matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @addtogroup MatrixMult
  @{
 */

/**
  @brief         Q31 matrix multiplication.
  @param[in]     pSrcA      points to the first input matrix structure
  @param[in]     pSrcB      points to the second input matrix structure
  @param[out]    pDst       points to the output matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                   The function is implemented using an internal 64-bit accumulator.
                   The accumulator has a 2.62 format and maintains full precision of the intermediate
                   multiplication results but provides only a single guard bit. There is no saturation
                   on intermediate additions: if the accumulator overflows, it wraps around and
                   distorts the result. Since a total of numColsA additions are performed per output
                   element, the input signals should be scaled down by log2(numColsA) bits to avoid
                   intermediate overflow.
                   The 2.62 accumulator is right-shifted by 31 bits and saturated to 1.31 format to yield the final result.
  @remark
                   Refer to \ref arm_mat_mult_fast_q31() for a faster but less precise implementation of this function.
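  @par           Example
                   A minimal usage sketch (the sizes and data buffers here are illustrative,
                   not part of this file); per the scaling note above, the inputs should be
                   pre-scaled by log2(numColsA) bits:
  @code
      q31_t dataA[2 * 3], dataB[3 * 2], dataC[2 * 2];   /* caller-owned buffers */
      arm_matrix_instance_q31 A, B, C;

      arm_mat_init_q31(&A, 2, 3, dataA);                /* A is 2 x 3 */
      arm_mat_init_q31(&B, 3, 2, dataB);                /* B is 3 x 2 */
      arm_mat_init_q31(&C, 2, 2, dataC);                /* C = A * B is 2 x 2 */

      if (arm_mat_mult_q31(&A, &B, &C) != ARM_MATH_SUCCESS)
      {
          /* matrix dimensions were inconsistent */
      }
  @endcode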
 */
#if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4
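
/*
 * Fixed-point note for the small-matrix kernels below: vrmlaldavhq
 * effectively computes the 64-bit sum of Q31 x Q31 (2.62) products with a
 * rounding right shift by 8, so the subsequent asrl(acc, 23) completes the
 * total right shift of 31 bits (8 + 23) needed to bring the accumulator
 * back to 1.31 output format.
 */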

__STATIC_INLINE arm_status arm_mat_mult_q31_2x2_mve(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
  arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;   /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;   /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;    /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM2;
    q63_t        acc0, acc1;
    q31x4_t      vecB, vecA0, vecA1;
    /* enable predication to disable the upper half of the vector elements */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM2);

    /* element offsets {0, MATRIX_DIM2, ...}: the two active lanes select one column of B */
    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM2;

    pInB = pSrcB->pData;

    /* load 1st B column (partial load) */
    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    /* load A rows */
    vecA0 = vldrwq_s32(pInA0);
    vecA1 = vldrwq_s32(pInA1);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

    /* scale back to 1.31 (vrmlaldavhq already dropped 8 bits) */
    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}


__STATIC_INLINE arm_status arm_mat_mult_q31_3x3_mve(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
  arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;   /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;   /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;    /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM3;
    q31_t       *pInA2 = pInA1 + MATRIX_DIM3;
    q63_t        acc0, acc1, acc2;
    q31x4_t      vecB, vecA;
    /* enable predication to disable the last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    /* element offsets {0, MATRIX_DIM3, 2*MATRIX_DIM3, ...}: the three active lanes select one column of B */
    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM3;

    pInB = pSrcB->pData;

    /* load 1st B column (partial load) */
    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    /* scale back to 1.31 (vrmlaldavhq already dropped 8 bits) */
    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to last B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

__STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
  arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;   /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;   /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;    /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM4;
    q31_t       *pInA2 = pInA1 + MATRIX_DIM4;
    q31_t       *pInA3 = pInA2 + MATRIX_DIM4;
    q63_t        acc0, acc1, acc2, acc3;
    q31x4_t      vecB, vecA;

    /* element offsets {0, 4, 8, 12} select one full column of B; no predication needed */
    vecColBOffs = vidupq_u32((uint32_t)0, 4);

    pInB = pSrcB->pData;

    /* load 1st B column */
    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    /* scale back to 1.31 (vrmlaldavhq already dropped 8 bits) */
    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to last B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}
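
/*
 * General-size path: the routine below computes, per inner iteration, the
 * dot products of up to four A rows against one B column, gathering the
 * column elements with vldrwq_gather_shifted_offset and tail-predicating
 * the numColsA remainder; the leftover rows (numRowsA & 3) are handled by
 * a second, single-row pass.
 */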

arm_status arm_mat_mult_q31(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
        arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = (q31_t const *) pSrcB->pData;  /* input data matrix pointer B */
    q31_t const *pInA = (q31_t const *) pSrcA->pData;  /* input data matrix pointer A */
    q31_t      *pOut = pDst->pData;                    /* output data matrix pointer */
    q31_t      *px;                                    /* temporary output data matrix pointer */
    uint16_t    numRowsA = pSrcA->numRows;             /* number of rows of input matrix A */
    uint16_t    numColsB = pSrcB->numCols;             /* number of columns of input matrix B */
    uint16_t    numColsA = pSrcA->numCols;             /* number of columns of input matrix A */
    uint16_t    col, i = 0U, row = numRowsA;           /* loop counters */
    arm_status  status;                                /* status of matrix multiplication */
    uint32x4_t  vecOffs, vecColBOffs;
    uint32_t    blkCnt, rowCnt;                        /* loop counters */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* small square matrix specialized routines */
    if (numRowsA == numColsB && numColsB == numColsA)
    {
        if (numRowsA == 1)
        {
            q63_t sum = (q63_t) *pInA * *pInB;
            pOut[0] = (q31_t) (sum >> 31);
            return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_mult_q31_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_mult_q31_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
    }

    /* element offsets {0, numColsB, 2*numColsB, 3*numColsB} walk down one column of B */
    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * (uint32_t) (numColsB);

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop: process 4 rows of A per iteration
     */
    rowCnt = row >> 2;
    while (rowCnt > 0U)
    {
        /*
         * Output pointer is set to starting address of the rows being processed
         */
        px = pOut + i;
        i = i + 4 * numColsB;
        /*
         * For every row-wise pass, the column loop counter is initialized
         */
        col = numColsB;
        /*
         * For every row-wise pass, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (q31_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0U)
        {
            /*
             * compute 4 output elements of the current column (4 rows x 1 column);
             * numColsA MAC operations are performed per output element
             */
            q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            q31_t const *pInA0 = pInA;
            q31_t const *pInA1 = pInA0 + numColsA;
            q31_t const *pInA2 = pInA1 + numColsA;
            q31_t const *pInA3 = pInA2 + numColsA;
            q63_t        acc0, acc1, acc2, acc3;

            acc0 = 0LL;
            acc1 = 0LL;
            acc2 = 0LL;
            acc3 = 0LL;

            pSrcA0Vec = (q31_t const *) pInA0;
            pSrcA1Vec = (q31_t const *) pInA1;
            pSrcA2Vec = (q31_t const *) pInA2;
            pSrcA3Vec = (q31_t const *) pInA3;

            vecOffs = vecColBOffs;

            /* process 4 x 1 block output */
            blkCnt = numColsA >> 2;
            while (blkCnt > 0U)
            {
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /* move Matrix B read offsets 4 rows down */
                vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 4;
                acc1 = vrmlaldavhaq(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 4;
                acc2 = vrmlaldavhaq(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 4;
                acc3 = vrmlaldavhaq(acc3, vecA, vecB);
                blkCnt--;
            }

            /*
             * tail (merged through tail predication)
             */
            blkCnt = numColsA & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                /* predicated B lanes are zero, so the extra A lanes do not contribute */

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 4;
                acc1 = vrmlaldavhaq(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 4;
                acc2 = vrmlaldavhaq(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 4;
                acc3 = vrmlaldavhaq(acc3, vecA, vecB);
            }

            /* scale back to 1.31 (vrmlaldavhaq already dropped 8 bits) */
            acc0 = asrl(acc0, 23);
            acc1 = asrl(acc1, 23);
            acc2 = asrl(acc2, 23);
            acc3 = asrl(acc3, 23);

            px[0]            = (q31_t) acc0;
            px[1 * numColsB] = (q31_t) acc1;
            px[2 * numColsB] = (q31_t) acc2;
            px[3 * numColsB] = (q31_t) acc3;
            px++;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (q31_t const *) pSrcB->pData + (numColsB - col);
        }

        /*
         * Update the pointer pInA to point to the starting address of the next 4 rows
         */
        pInA += (numColsA * 4);
        /*
         * Decrement the row loop counter
         */
        rowCnt--;

    }
    rowCnt = row & 3;
    while (rowCnt > 0U)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i;
        i = i + numColsB;
        /*
         * For every row-wise pass, the column loop counter is initialized
         */
        col = numColsB;
        /*
         * For every row-wise pass, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (q31_t const *) pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0U)
        {
            /*
             * compute one output element of the current column;
             * numColsA MAC operations are performed per output element
             */
            q31_t const *pSrcA0Vec;
            q31_t const *pInA0 = pInA;
            q63_t        acc0;

            acc0 = 0LL;

            pSrcA0Vec = (q31_t const *) pInA0;

            vecOffs = vecColBOffs;

            /* process 1 x 1 block output */
            blkCnt = numColsA >> 2;
            while (blkCnt > 0U)
            {
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /* move Matrix B read offsets 4 rows down */
                vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);

                blkCnt--;
            }

            /*
             * tail (merged through tail predication)
             */
            blkCnt = numColsA & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                /* predicated B lanes are zero, so the extra A lanes do not contribute */

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);
            }

            /* scale back to 1.31 (vrmlaldavhaq already dropped 8 bits) */
            acc0 = asrl(acc0, 23);

            px[0] = (q31_t) acc0;
            px++;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (q31_t const *) pSrcB->pData + (numColsB - col);
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += numColsA;
        /*
         * Decrement the row loop counter
         */
        rowCnt--;
    }

    /*
     * Set status as ARM_MATH_SUCCESS
     */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}

#else
arm_status arm_mat_mult_q31(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
        arm_matrix_instance_q31 * pDst)
{
  q31_t *pIn1 = pSrcA->pData;                    /* Input data matrix pointer A */
  q31_t *pIn2 = pSrcB->pData;                    /* Input data matrix pointer B */
  q31_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A */
  q31_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B */
  q31_t *pOut = pDst->pData;                     /* Output data matrix pointer */
  q31_t *px;                                     /* Temporary output data matrix pointer */
  q63_t sum;                                     /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
  uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
  arm_status status;                             /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of row being processed */
      px = pOut + i;

      /* For every row-wise pass, the column loop counter is initialized */
      col = numColsB;

      /* For every row-wise pass, the pIn2 pointer is set to the starting address of the pSrcB data */
      pIn2 = pSrcB->pData;

      /* column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0;

        /* Initialize pointer pIn1 to point to the starting address of the row being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Loop unrolling: Compute 4 MACs at a time. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* Perform the multiply-accumulates */
          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Loop unrolling: Compute remaining MACs */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* Perform the multiply-accumulates */
          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }
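
        /* Worked example of the conversion below: with Q1.31 inputs
         * a = b = 0.5 (0x40000000), the product a * b is 0.25 in Q2.62
         * (0x1000000000000000); shifting the 2.62 accumulator right by 31
         * bits yields 0x20000000, i.e. 0.25 in Q1.31. The shift truncates:
         * this scalar path applies no rounding or saturation. */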

        /* Convert result from 2.62 to 1.31 format and store in destination buffer */
        *px++ = (q31_t) (sum >> 31);

        /* Decrement column loop counter */
        col--;

        /* Update pointer pIn2 to point to starting address of next column */
        pIn2 = pInB + (numColsB - col);

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#endif /* defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE) */

/**
  @} end of MatrixMult group
 */