/* ----------------------------------------------------------------------
 * Project:      CMSIS DSP Library
 * Title:        arm_mat_mult_q31.c
 * Description:  Q31 matrix multiplication
 *
 * $Date:        23 April 2021
 * $Revision:    V1.9.0
 *
 * Target Processor: Cortex-M and Cortex-A cores
 * -------------------------------------------------------------------- */
/*
 * Copyright (C) 2010-2021 ARM Limited or its affiliates. All rights reserved.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the License); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an AS IS BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dsp/matrix_functions.h"

/**
  @ingroup groupMatrix
 */

/**
  @addtogroup MatrixMult
  @{
 */

/**
  @brief         Q31 matrix multiplication.
  @param[in]     pSrcA      points to the first input matrix structure
  @param[in]     pSrcB      points to the second input matrix structure
  @param[out]    pDst       points to output matrix structure
  @return        execution status
                   - \ref ARM_MATH_SUCCESS       : Operation successful
                   - \ref ARM_MATH_SIZE_MISMATCH : Matrix size check failed

  @par           Scaling and Overflow Behavior
                   The function is implemented using an internal 64-bit accumulator.
                   The accumulator has a 2.62 format and maintains full precision of the intermediate
                   multiplication results but provides only a single guard bit. There is no saturation
                   on intermediate additions. Thus, if the accumulator overflows it wraps around and
                   distorts the result. The input signals should be scaled down to avoid intermediate
                   overflows. The input is thus scaled down by log2(numColsA) bits
                   to avoid overflows, as a total of numColsA additions are performed internally.
                   The 2.62 accumulator is right shifted by 31 bits and saturated to 1.31 format to yield the final result.
  @remark
                   Refer to \ref arm_mat_mult_fast_q31() for a faster but less precise implementation of this function.
 */
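/*
 * Editor's usage sketch (not part of the original file): how a caller would
 * typically drive arm_mat_mult_q31(). The buffer names, dimensions and the
 * example_mat_mult_q31() wrapper are assumptions chosen for illustration only;
 * arm_mat_init_q31() and arm_shift_q31() are existing CMSIS-DSP helpers.
 * The block is guarded by #if 0 so it is not compiled into the library.
 */
#if 0
#include "dsp/matrix_functions.h"
#include "dsp/basic_math_functions.h"    /* for arm_shift_q31() */

static q31_t bufA[2 * 3];                /* 2x3 operand A, Q31 samples */
static q31_t bufB[3 * 2];                /* 3x2 operand B, Q31 samples */
static q31_t bufC[2 * 2];                /* 2x2 result buffer          */

static arm_status example_mat_mult_q31(void)
{
    arm_matrix_instance_q31 matA, matB, matC;

    arm_mat_init_q31(&matA, 2, 3, bufA); /* numRows, numCols, data */
    arm_mat_init_q31(&matB, 3, 2, bufB);
    arm_mat_init_q31(&matC, 2, 2, bufC);

    /* Pre-scale one operand by log2(numColsA) bits (ceil(log2(3)) = 2 here),
       as noted above, so the 2.62 accumulator cannot wrap. */
    arm_shift_q31(bufA, -2, bufA, 2 * 3);

    /* Returns ARM_MATH_SIZE_MISMATCH when ARM_MATH_MATRIX_CHECK is defined
       and the matrix dimensions are inconsistent. */
    return arm_mat_mult_q31(&matA, &matB, &matC);
}
#endif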
#if defined(ARM_MATH_MVEI) && !defined(ARM_MATH_AUTOVECTORIZE)

#define MATRIX_DIM2 2
#define MATRIX_DIM3 3
#define MATRIX_DIM4 4

__STATIC_INLINE arm_status arm_mat_mult_q31_2x2_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM2;
    q63_t        acc0, acc1;
    q31x4_t      vecB, vecA0, vecA1;
    /* enable predication to disable half of vector elements */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM2);

    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM2;

    pInB = pSrcB->pData;

    /* load 1st B column (partial load) */
    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    /* load A rows */
    vecA0 = vldrwq_s32(pInA0);
    vecA1 = vldrwq_s32(pInA1);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

    /* vrmlaldavhq accumulates the 2.62 products with their low 8 bits rounded off,
       so a further 23-bit shift brings each dot product back to 1.31 */
    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    acc0 = vrmlaldavhq(vecA0, vecB);
    acc1 = vrmlaldavhq(vecA1, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);

    pOut[0 * MATRIX_DIM2] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM2] = (q31_t) acc1;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}



__STATIC_INLINE arm_status arm_mat_mult_q31_3x3_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM3;
    q31_t       *pInA2 = pInA1 + MATRIX_DIM3;
    q63_t        acc0, acc1, acc2;
    q31x4_t      vecB, vecA;
    /* enable predication to disable last (4th) vector element */
    mve_pred16_t p0 = vctp32q(MATRIX_DIM3);

    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * MATRIX_DIM3;

    pInB = pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_z_s32(pInB, vecColBOffs, p0);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);

    pOut[0 * MATRIX_DIM3] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM3] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM3] = (q31_t) acc2;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

__STATIC_INLINE arm_status arm_mat_mult_q31_4x4_mve(
    const arm_matrix_instance_q31 * pSrcA,
    const arm_matrix_instance_q31 * pSrcB,
    arm_matrix_instance_q31 * pDst)
{
    q31_t       *pInB = pSrcB->pData;  /* input data matrix pointer B */
    q31_t       *pInA = pSrcA->pData;  /* input data matrix pointer A */
    q31_t       *pOut = pDst->pData;   /* output data matrix pointer */
    uint32x4_t   vecColBOffs;
    q31_t       *pInA0 = pInA;
    q31_t       *pInA1 = pInA0 + MATRIX_DIM4;
    q31_t       *pInA2 = pInA1 + MATRIX_DIM4;
    q31_t       *pInA3 = pInA2 + MATRIX_DIM4;
    q63_t        acc0, acc1, acc2, acc3;
    q31x4_t      vecB, vecA;

    vecColBOffs = vidupq_u32((uint32_t)0, 4);

    pInB = pSrcB->pData;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;

    pOut++;

    /* move to next B column */
    pInB = pInB + 1;

    vecB = vldrwq_gather_shifted_offset_s32(pInB, vecColBOffs);

    vecA = vldrwq_s32(pInA0);
    acc0 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA1);
    acc1 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA2);
    acc2 = vrmlaldavhq(vecA, vecB);
    vecA = vldrwq_s32(pInA3);
    acc3 = vrmlaldavhq(vecA, vecB);

    acc0 = asrl(acc0, 23);
    acc1 = asrl(acc1, 23);
    acc2 = asrl(acc2, 23);
    acc3 = asrl(acc3, 23);

    pOut[0 * MATRIX_DIM4] = (q31_t) acc0;
    pOut[1 * MATRIX_DIM4] = (q31_t) acc1;
    pOut[2 * MATRIX_DIM4] = (q31_t) acc2;
    pOut[3 * MATRIX_DIM4] = (q31_t) acc3;
    /*
     * Return to application
     */
    return (ARM_MATH_SUCCESS);
}

arm_status arm_mat_mult_q31(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
        arm_matrix_instance_q31 * pDst)
{
    q31_t const *pInB = (q31_t const *)pSrcB->pData;  /* input data matrix pointer B */
    q31_t const *pInA = (q31_t const *)pSrcA->pData;  /* input data matrix pointer A */
    q31_t      *pOut = pDst->pData;   /* output data matrix pointer */
    q31_t      *px;               /* Temporary output data matrix pointer */
    uint16_t    numRowsA = pSrcA->numRows;    /* number of rows of input matrix A    */
    uint16_t    numColsB = pSrcB->numCols;    /* number of columns of input matrix B */
    uint16_t    numColsA = pSrcA->numCols;    /* number of columns of input matrix A */
    uint16_t    col, i = 0U, row = numRowsA;  /* loop counters */
    arm_status  status;          /* status of matrix multiplication */
    uint32x4_t  vecOffs, vecColBOffs;
    uint32_t    blkCnt, rowCnt;           /* loop counters */

  #ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* small square matrix specialized routines */
    if (numRowsA == numColsB && numColsB == numColsA) {
        if (numRowsA == 1)
        {
          q63_t sum = (q63_t) *pInA * *pInB;
          pOut[0] = (q31_t)(sum >> 31);
          return (ARM_MATH_SUCCESS);
        }
        else if (numRowsA == 2)
            return arm_mat_mult_q31_2x2_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 3)
            return arm_mat_mult_q31_3x3_mve(pSrcA, pSrcB, pDst);
        else if (numRowsA == 4)
            return arm_mat_mult_q31_4x4_mve(pSrcA, pSrcB, pDst);
    }

    vecColBOffs = vidupq_u32((uint32_t)0, 1);
    vecColBOffs = vecColBOffs * (uint32_t) (numColsB);

    /*
     * The following loop performs the dot-product of each row in pSrcA with each column in pSrcB
     */

    /*
     * row loop
     */
    rowCnt = row >> 2;
    while (rowCnt > 0U)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i;
        i = i + 4 * numColsB;
        /*
         * For every row-wise process, the column loop counter is to be initialized
         */
        col = numColsB;
        /*
         * For every row-wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (q31_t const *)pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0U)
        {
            /*
             * generate 4 output elements (one per row of A) for the current column
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            q31_t const *pSrcA0Vec, *pSrcA1Vec, *pSrcA2Vec, *pSrcA3Vec;
            q31_t const   *pInA0 = pInA;
            q31_t const   *pInA1 = pInA0 + numColsA;
            q31_t const   *pInA2 = pInA1 + numColsA;
            q31_t const   *pInA3 = pInA2 + numColsA;
            q63_t          acc0, acc1, acc2, acc3;

            acc0 = 0LL;
            acc1 = 0LL;
            acc2 = 0LL;
            acc3 = 0LL;

            pSrcA0Vec = (q31_t const *) pInA0;
            pSrcA1Vec = (q31_t const *) pInA1;
            pSrcA2Vec = (q31_t const *) pInA2;
            pSrcA3Vec = (q31_t const *) pInA3;

            vecOffs = vecColBOffs;

            /* process 1 x 4 block output */
            blkCnt = numColsA >> 2;
            while (blkCnt > 0U)
            {
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /* move Matrix B read offsets, 4 rows down */
                vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 4;
                acc1 = vrmlaldavhaq(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 4;
                acc2 = vrmlaldavhaq(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 4;
                acc3 = vrmlaldavhaq(acc3, vecA, vecB);
                blkCnt--;
            }

            /*
             * tail
             * (will be merged through tail predication)
             */
            blkCnt = numColsA & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                q31x4_t   vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                //vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);
                vecA = vld1q(pSrcA1Vec);  pSrcA1Vec += 4;
                acc1 = vrmlaldavhaq(acc1, vecA, vecB);
                vecA = vld1q(pSrcA2Vec);  pSrcA2Vec += 4;
                acc2 = vrmlaldavhaq(acc2, vecA, vecB);
                vecA = vld1q(pSrcA3Vec);  pSrcA3Vec += 4;
                acc3 = vrmlaldavhaq(acc3, vecA, vecB);
            }

            acc0 = asrl(acc0, 23);
            acc1 = asrl(acc1, 23);
            acc2 = asrl(acc2, 23);
            acc3 = asrl(acc3, 23);

            px[0] = (q31_t) acc0;
            px[1 * numColsB] = (q31_t) acc1;
            px[2 * numColsB] = (q31_t) acc2;
            px[3 * numColsB] = (q31_t) acc3;
            px++;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (q31_t const *)pSrcB->pData + (numColsB - col);
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += (numColsA * 4);
        /*
         * Decrement the row loop counter
         */
        rowCnt--;

    }
    rowCnt = row & 3;
    while (rowCnt > 0U)
    {
        /*
         * Output pointer is set to starting address of the row being processed
         */
        px = pOut + i;
        i = i + numColsB;
        /*
         * For every row-wise process, the column loop counter is to be initialized
         */
        col = numColsB;
        /*
         * For every row-wise process, the pInB pointer is set
         * to the starting address of the pSrcB data
         */
        pInB = (q31_t const *)pSrcB->pData;
        /*
         * column loop
         */
        while (col > 0U)
        {
            /*
             * generate one output element for the current column
             */
            /*
             * Matrix A columns number of MAC operations are to be performed
             */

            q31_t const *pSrcA0Vec;
            q31_t const   *pInA0 = pInA;
            q63_t          acc0;

            acc0 = 0LL;


            pSrcA0Vec = (q31_t const *) pInA0;

            vecOffs = vecColBOffs;

            /* process 1 x 4 block output */
            blkCnt = numColsA >> 2;
            while (blkCnt > 0U)
            {
                q31x4_t vecB, vecA;

                vecB = vldrwq_gather_shifted_offset(pInB, vecOffs);
                /* move Matrix B read offsets, 4 rows down */
                vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);  pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);

                blkCnt--;
            }

            /*
             * tail
             * (will be merged through tail predication)
             */
            blkCnt = numColsA & 3;
            if (blkCnt > 0U)
            {
                mve_pred16_t p0 = vctp32q(blkCnt);
                q31x4_t   vecB, vecA;

                vecB = vldrwq_gather_shifted_offset_z(pInB, vecOffs, p0);
                //vecOffs = vecOffs + (uint32_t) (numColsB * 4);

                vecA = vld1q(pSrcA0Vec);
                pSrcA0Vec += 4;
                acc0 = vrmlaldavhaq(acc0, vecA, vecB);

            }

            acc0 = asrl(acc0, 23);


            px[0] = (q31_t) acc0;
            px++;
            /*
             * Decrement the column loop counter
             */
            col--;
            /*
             * Update the pointer pInB to point to the starting address of the next column
             */
            pInB = (q31_t const *)pSrcB->pData + (numColsB - col);
        }

        /*
         * Update the pointer pInA to point to the starting address of the next row
         */
        pInA += numColsA;
        /*
         * Decrement the row loop counter
         */
        rowCnt--;
    }

    /*
     * set status as ARM_MATH_SUCCESS
     */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}


#else
arm_status arm_mat_mult_q31(
  const arm_matrix_instance_q31 * pSrcA,
  const arm_matrix_instance_q31 * pSrcB,
        arm_matrix_instance_q31 * pDst)
{
  q31_t *pIn1 = pSrcA->pData;                    /* Input data matrix pointer A */
  q31_t *pIn2 = pSrcB->pData;                    /* Input data matrix pointer B */
  q31_t *pInA = pSrcA->pData;                    /* Input data matrix pointer A */
  q31_t *pInB = pSrcB->pData;                    /* Input data matrix pointer B */
  q31_t *pOut = pDst->pData;                     /* Output data matrix pointer */
  q31_t *px;                                     /* Temporary output data matrix pointer */
  q63_t sum;                                     /* Accumulator */
  uint16_t numRowsA = pSrcA->numRows;            /* Number of rows of input matrix A */
  uint16_t numColsB = pSrcB->numCols;            /* Number of columns of input matrix B */
  uint16_t numColsA = pSrcA->numCols;            /* Number of columns of input matrix A */
  uint32_t col, i = 0U, row = numRowsA, colCnt;  /* Loop counters */
  arm_status status;                             /* Status of matrix multiplication */

#ifdef ARM_MATH_MATRIX_CHECK

  /* Check for matrix mismatch condition */
  if ((pSrcA->numCols != pSrcB->numRows) ||
      (pSrcA->numRows != pDst->numRows)  ||
      (pSrcB->numCols != pDst->numCols)    )
  {
    /* Set status as ARM_MATH_SIZE_MISMATCH */
    status = ARM_MATH_SIZE_MISMATCH;
  }
  else

#endif /* #ifdef ARM_MATH_MATRIX_CHECK */

  {
    /* The following loop performs the dot-product of each row in pSrcA with each column in pSrcB */
    /* row loop */
    do
    {
      /* Output pointer is set to starting address of row being processed */
      px = pOut + i;

      /* For every row-wise process, the column loop counter is to be initialized */
      col = numColsB;

      /* For every row-wise process, pIn2 pointer is set to starting address of pSrcB data */
      pIn2 = pSrcB->pData;

      /* column loop */
      do
      {
        /* Set the variable sum, that acts as accumulator, to zero */
        sum = 0;

        /* Initialize pointer pIn1 to point to starting address of the row of pSrcA being processed */
        pIn1 = pInA;

#if defined (ARM_MATH_LOOPUNROLL)

        /* Loop unrolling: Compute 4 MACs at a time. */
        colCnt = numColsA >> 2U;

        /* matrix multiplication */
        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* Perform the multiply-accumulates */
          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Loop unrolling: Compute remaining MACs */
        colCnt = numColsA % 0x4U;

#else

        /* Initialize colCnt with number of columns */
        colCnt = numColsA;

#endif /* #if defined (ARM_MATH_LOOPUNROLL) */

        while (colCnt > 0U)
        {
          /* c(m,n) = a(1,1) * b(1,1) + a(1,2) * b(2,1) + .... + a(m,p) * b(p,n) */

          /* Perform the multiply-accumulates */
          sum += (q63_t) *pIn1++ * *pIn2;
          pIn2 += numColsB;

          /* Decrement loop counter */
          colCnt--;
        }

        /* Convert result from 2.62 to 1.31 format and store in destination buffer */
        *px++ = (q31_t) (sum >> 31);

        /* Decrement column loop counter */
        col--;

        /* Update pointer pIn2 to point to starting address of next column */
        pIn2 = pInB + (numColsB - col);

      } while (col > 0U);

      /* Update pointer pInA to point to starting address of next row */
      i = i + numColsB;
      pInA = pInA + numColsA;

      /* Decrement row loop counter */
      row--;

    } while (row > 0U);

    /* Set status as ARM_MATH_SUCCESS */
    status = ARM_MATH_SUCCESS;
  }

  /* Return to application */
  return (status);
}
#endif /* defined(ARM_MATH_MVEI) */

/**
  @} end of MatrixMult group
 */